diff --git a/include/triton/codegen/analysis/layout.h b/include/triton/codegen/analysis/layout.h index 56fb1e4b9..28dfad18d 100644 --- a/include/triton/codegen/analysis/layout.h +++ b/include/triton/codegen/analysis/layout.h @@ -224,6 +224,7 @@ struct scanline_layout: public distributed_layout { int nts(size_t k) { return nts_.at(k); } int contig_per_thread(size_t k) { return nts_.at(k); } + int per_thread(size_t k) { return nts(k) * shape_[k] / shape_per_cta(k);} public: // micro tile size. The size of a tile held by a thread block. std::vector mts_; diff --git a/include/triton/codegen/selection/generator.h b/include/triton/codegen/selection/generator.h index 293aa8908..e3191efb1 100644 --- a/include/triton/codegen/selection/generator.h +++ b/include/triton/codegen/selection/generator.h @@ -24,6 +24,7 @@ namespace llvm{ class IRBuilder; class ArrayType; class Function; + class StructType; } namespace triton{ @@ -114,6 +115,8 @@ private: private: Type *cvt(ir::type *ty); llvm::Attribute cvt(ir::attribute attr); + llvm::StructType* packed_type(ir::value* i); + void forward_declare(ir::function* fn); public: generator(analysis::axes *a_axes, @@ -125,6 +128,8 @@ public: unsigned num_warps); void visit_value(ir::value* v); + void visit_call_inst(ir::call_inst*); + void visit_launch_inst(ir::launch_inst *); void visit_phi_node(ir::phi_node*); void visit_binary_operator(ir::binary_operator*); void visit_getelementptr_inst(ir::getelementptr_inst*); @@ -148,6 +153,8 @@ public: void visit_unmasked_store_inst(ir::unmasked_store_inst*); void visit_masked_store_inst(ir::masked_store_inst*); void visit_cat_inst(ir::cat_inst*); + void visit_extract_value_inst(ir::extract_value_inst *); + void visit_insert_value_inst(ir::insert_value_inst *); void visit_reshape_inst(ir::reshape_inst*); void visit_splat_inst(ir::splat_inst*); void visit_broadcast_inst(ir::broadcast_inst*); @@ -242,6 +249,7 @@ private: /// triton bb -> llvm bb std::map bbs_; std::map> ords_; + std::map fns_; // 
helper for creating llvm values adder add; diff --git a/include/triton/codegen/transform/inline.h b/include/triton/codegen/transform/inline.h new file mode 100644 index 000000000..c79079b61 --- /dev/null +++ b/include/triton/codegen/transform/inline.h @@ -0,0 +1,31 @@ +#pragma once + +#include + +namespace triton { + +namespace ir { + class module; + class function; + class call_inst; + class builder; +} + +namespace codegen{ +namespace transform{ + +struct fncmp { + bool operator()(ir::function* x, ir::function* y) const; +}; + +class inliner { +public: + inliner() {} + void do_inline(ir::function* fn, ir::call_inst* callsite, ir::builder& builder, std::list& callsites); + void run(ir::module &mod); +}; + + +} +} +} diff --git a/include/triton/codegen/transform/peephole.h b/include/triton/codegen/transform/peephole.h index 0e1ed222e..5b84a813b 100644 --- a/include/triton/codegen/transform/peephole.h +++ b/include/triton/codegen/transform/peephole.h @@ -30,6 +30,9 @@ private: bool rewrite_dot_hmma(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D); bool rewrite_dot(ir::instruction *value, ir::builder& builder); bool rewrite_mult(ir::instruction *value, ir::builder& builder); + bool rewrite_insert_extract(ir::instruction *value, ir::builder& builder); + + bool rewrite_unit_red(ir::instruction *value, ir::builder& builder); bool rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder); bool rewrite_select_masked_load(ir::instruction *value, ir::builder& builder); diff --git a/include/triton/driver/dispatch.h b/include/triton/driver/dispatch.h index 5503bacaf..2384b4cba 100755 --- a/include/triton/driver/dispatch.h +++ b/include/triton/driver/dispatch.h @@ -89,6 +89,7 @@ public: static CUresult cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib, CUdevice dev); static CUresult cuDeviceGetCount(int *count); // link management + static CUresult cuLinkAddFile_v2(CUlinkState state, 
CUjitInputType type, const char *path, unsigned int numOptions, CUjit_option *options, void **optionValues); static CUresult cuLinkAddData_v2(CUlinkState state, CUjitInputType type, void* data, size_t size, const char* name, unsigned int numOptions, CUjit_option* options, void** optionValues); static CUresult cuLinkCreate_v2(unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut); static CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut); @@ -214,6 +215,7 @@ private: static void* cuDeviceGetAttribute_; static void* cuDeviceGetCount_; // link management + static void* cuLinkAddFile_v2_; static void* cuLinkAddData_v2_; static void* cuLinkCreate_v2_; static void* cuLinkDestroy_; diff --git a/include/triton/external/CUDA/cuda.h b/include/triton/external/CUDA/cuda.h old mode 100755 new mode 100644 index f7bf9fc12..2f32c80fa --- a/include/triton/external/CUDA/cuda.h +++ b/include/triton/external/CUDA/cuda.h @@ -224,7 +224,7 @@ typedef uint64_t cuuint64_t; /** * CUDA API version number */ -#define CUDA_VERSION 11050 +#define CUDA_VERSION 11040 #ifdef __cplusplus extern "C" { @@ -496,33 +496,7 @@ typedef enum CUarray_format_enum { CU_AD_FORMAT_SIGNED_INT32 = 0x0a, /**< Signed 32-bit integers */ CU_AD_FORMAT_HALF = 0x10, /**< 16-bit floating point */ CU_AD_FORMAT_FLOAT = 0x20, /**< 32-bit floating point */ - CU_AD_FORMAT_NV12 = 0xb0, /**< 8-bit YUV planar format, with 4:2:0 sampling */ - CU_AD_FORMAT_UNORM_INT8X1 = 0xc0, /**< 1 channel unsigned 8-bit normalized integer */ - CU_AD_FORMAT_UNORM_INT8X2 = 0xc1, /**< 2 channel unsigned 8-bit normalized integer */ - CU_AD_FORMAT_UNORM_INT8X4 = 0xc2, /**< 4 channel unsigned 8-bit normalized integer */ - CU_AD_FORMAT_UNORM_INT16X1 = 0xc3, /**< 1 channel unsigned 16-bit normalized integer */ - CU_AD_FORMAT_UNORM_INT16X2 = 0xc4, /**< 2 channel unsigned 16-bit normalized integer */ - CU_AD_FORMAT_UNORM_INT16X4 = 0xc5, /**< 4 channel unsigned 16-bit normalized integer */ - 
CU_AD_FORMAT_SNORM_INT8X1 = 0xc6, /**< 1 channel signed 8-bit normalized integer */ - CU_AD_FORMAT_SNORM_INT8X2 = 0xc7, /**< 2 channel signed 8-bit normalized integer */ - CU_AD_FORMAT_SNORM_INT8X4 = 0xc8, /**< 4 channel signed 8-bit normalized integer */ - CU_AD_FORMAT_SNORM_INT16X1 = 0xc9, /**< 1 channel signed 16-bit normalized integer */ - CU_AD_FORMAT_SNORM_INT16X2 = 0xca, /**< 2 channel signed 16-bit normalized integer */ - CU_AD_FORMAT_SNORM_INT16X4 = 0xcb, /**< 4 channel signed 16-bit normalized integer */ - CU_AD_FORMAT_BC1_UNORM = 0x91, /**< 4 channel unsigned normalized block-compressed (BC1 compression) format */ - CU_AD_FORMAT_BC1_UNORM_SRGB = 0x92, /**< 4 channel unsigned normalized block-compressed (BC1 compression) format with sRGB encoding*/ - CU_AD_FORMAT_BC2_UNORM = 0x93, /**< 4 channel unsigned normalized block-compressed (BC2 compression) format */ - CU_AD_FORMAT_BC2_UNORM_SRGB = 0x94, /**< 4 channel unsigned normalized block-compressed (BC2 compression) format with sRGB encoding*/ - CU_AD_FORMAT_BC3_UNORM = 0x95, /**< 4 channel unsigned normalized block-compressed (BC3 compression) format */ - CU_AD_FORMAT_BC3_UNORM_SRGB = 0x96, /**< 4 channel unsigned normalized block-compressed (BC3 compression) format with sRGB encoding*/ - CU_AD_FORMAT_BC4_UNORM = 0x97, /**< 1 channel unsigned normalized block-compressed (BC4 compression) format */ - CU_AD_FORMAT_BC4_SNORM = 0x98, /**< 1 channel signed normalized block-compressed (BC4 compression) format */ - CU_AD_FORMAT_BC5_UNORM = 0x99, /**< 2 channel unsigned normalized block-compressed (BC5 compression) format */ - CU_AD_FORMAT_BC5_SNORM = 0x9a, /**< 2 channel signed normalized block-compressed (BC5 compression) format */ - CU_AD_FORMAT_BC6H_UF16 = 0x9b, /**< 3 channel unsigned half-float block-compressed (BC6H compression) format */ - CU_AD_FORMAT_BC6H_SF16 = 0x9c, /**< 3 channel signed half-float block-compressed (BC6H compression) format */ - CU_AD_FORMAT_BC7_UNORM = 0x9d, /**< 4 channel unsigned 
normalized block-compressed (BC7 compression) format */ - CU_AD_FORMAT_BC7_UNORM_SRGB = 0x9e /**< 4 channel unsigned normalized block-compressed (BC7 compression) format with sRGB encoding */ + CU_AD_FORMAT_NV12 = 0xb0 } CUarray_format; /** @@ -657,7 +631,7 @@ typedef enum CUdevice_attribute_enum { CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED = 102, /**< Device supports virtual memory management APIs like ::cuMemAddressReserve, ::cuMemCreate, ::cuMemMap and related APIs */ CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR_SUPPORTED = 103, /**< Device supports exporting memory to a posix file descriptor with ::cuMemExportToShareableHandle, if requested via ::cuMemCreate */ CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_WIN32_HANDLE_SUPPORTED = 104, /**< Device supports exporting memory to a Win32 NT handle with ::cuMemExportToShareableHandle, if requested via ::cuMemCreate */ - CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_WIN32_KMT_HANDLE_SUPPORTED = 105, /**< Device supports exporting memory to a Win32 KMT handle with ::cuMemExportToShareableHandle, if requested via ::cuMemCreate */ + CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_WIN32_KMT_HANDLE_SUPPORTED = 105, /**< Device supports exporting memory to a Win32 KMT handle with ::cuMemExportToShareableHandle, if requested via ::cuMemCreate */ CU_DEVICE_ATTRIBUTE_MAX_BLOCKS_PER_MULTIPROCESSOR = 106, /**< Maximum number of blocks per multiprocessor */ CU_DEVICE_ATTRIBUTE_GENERIC_COMPRESSION_SUPPORTED = 107, /**< Device supports compression of memory */ CU_DEVICE_ATTRIBUTE_MAX_PERSISTING_L2_CACHE_SIZE = 108, /**< Maximum L2 persisting lines capacity setting in bytes.
*/ @@ -665,7 +639,7 @@ typedef enum CUdevice_attribute_enum { CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WITH_CUDA_VMM_SUPPORTED = 110, /**< Device supports specifying the GPUDirect RDMA flag with ::cuMemCreate */ CU_DEVICE_ATTRIBUTE_RESERVED_SHARED_MEMORY_PER_BLOCK = 111, /**< Shared memory reserved by CUDA driver per block in bytes */ CU_DEVICE_ATTRIBUTE_SPARSE_CUDA_ARRAY_SUPPORTED = 112, /**< Device supports sparse CUDA arrays and sparse CUDA mipmapped arrays */ - CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED = 113, /**< Device supports using the ::cuMemHostRegister flag ::CU_MEMHOSTERGISTER_READ_ONLY to register memory that must be mapped as read-only to the GPU */ + CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED = 113, /**< Device supports using the ::cuMemHostRegister flag CU_MEMHOSTREGISTER_READ_ONLY to register memory that must be mapped as read-only to the GPU */ CU_DEVICE_ATTRIBUTE_TIMELINE_SEMAPHORE_INTEROP_SUPPORTED = 114, /**< External timeline semaphore interop is supported on the device */ CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED = 115, /**< Device supports using the ::cuMemAllocAsync and ::cuMemPool family of APIs */ CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_SUPPORTED = 116, /**< Device supports GPUDirect RDMA APIs, like nvidia_p2p_get_pages (see https://docs.nvidia.com/cuda/gpudirect-rdma for more information) */ @@ -1650,8 +1624,7 @@ typedef enum cudaError_enum { CUDA_ERROR_UNSUPPORTED_EXEC_AFFINITY = 224, /** - * This indicates that the device kernel source is invalid. This includes - * compilation/linker errors encountered in device code or user error. + * This indicates that the device kernel source is invalid. */ CUDA_ERROR_INVALID_SOURCE = 300, @@ -2068,9 +2041,9 @@ typedef size_t (CUDA_CB *CUoccupancyB2DSize)(int blockSize); * On Windows the flag is a no-op. * On Linux that memory is marked as non cache-coherent for the GPU and * is expected to be physically contiguous.
It may return - * ::CUDA_ERROR_NOT_PERMITTED if run as an unprivileged user, - * ::CUDA_ERROR_NOT_SUPPORTED on older Linux kernel versions. - * On all other platforms, it is not supported and ::CUDA_ERROR_NOT_SUPPORTED + * CUDA_ERROR_NOT_PERMITTED if run as an unprivileged user, + * CUDA_ERROR_NOT_SUPPORTED on older Linux kernel versions. + * On all other platforms, it is not supported and CUDA_ERROR_NOT_SUPPORTED * is returned. * Flag for ::cuMemHostRegister() */ @@ -2079,12 +2052,12 @@ typedef size_t (CUDA_CB *CUoccupancyB2DSize)(int blockSize); /** * If set, the passed memory pointer is treated as pointing to memory that is * considered read-only by the device. On platforms without -* ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES, this flag is +* CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES, this flag is * required in order to register memory mapped to the CPU as read-only. Support * for the use of this flag can be queried from the device attribute -* ::CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED. Using this flag with +* CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED. Using this flag with * a current context associated with a device that does not have this attribute -* set will cause ::cuMemHostRegister to error with ::CUDA_ERROR_NOT_SUPPORTED. +* set will cause ::cuMemHostRegister to error with CUDA_ERROR_NOT_SUPPORTED. */ #define CU_MEMHOSTREGISTER_READ_ONLY 0x08 @@ -3735,117 +3708,117 @@ CUresult CUDAAPI cuDeviceGetTexture1DLinearMaxWidth(size_t *maxWidthInElements, * \p dev. 
The supported attributes are: * - ::CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK: Maximum number of threads per * block; - * - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X: Maximum x-dimension of a block - * - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y: Maximum y-dimension of a block - * - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z: Maximum z-dimension of a block - * - ::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X: Maximum x-dimension of a grid - * - ::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y: Maximum y-dimension of a grid - * - ::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z: Maximum z-dimension of a grid + * - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X: Maximum x-dimension of a block; + * - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y: Maximum y-dimension of a block; + * - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z: Maximum z-dimension of a block; + * - ::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X: Maximum x-dimension of a grid; + * - ::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y: Maximum y-dimension of a grid; + * - ::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z: Maximum z-dimension of a grid; * - ::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK: Maximum amount of - * shared memory available to a thread block in bytes + * shared memory available to a thread block in bytes; * - ::CU_DEVICE_ATTRIBUTE_TOTAL_CONSTANT_MEMORY: Memory available on device for - * __constant__ variables in a CUDA C kernel in bytes - * - ::CU_DEVICE_ATTRIBUTE_WARP_SIZE: Warp size in threads + * __constant__ variables in a CUDA C kernel in bytes; + * - ::CU_DEVICE_ATTRIBUTE_WARP_SIZE: Warp size in threads; * - ::CU_DEVICE_ATTRIBUTE_MAX_PITCH: Maximum pitch in bytes allowed by the * memory copy functions that involve memory regions allocated through - * ::cuMemAllocPitch() + * ::cuMemAllocPitch(); * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_WIDTH: Maximum 1D - * texture width + * texture width; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LINEAR_WIDTH: Maximum width - * for a 1D texture bound to linear memory + * for a 1D texture bound to linear memory; * - 
::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_MIPMAPPED_WIDTH: Maximum - * mipmapped 1D texture width + * mipmapped 1D texture width; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_WIDTH: Maximum 2D - * texture width + * texture width; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_HEIGHT: Maximum 2D - * texture height + * texture height; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_WIDTH: Maximum width - * for a 2D texture bound to linear memory + * for a 2D texture bound to linear memory; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_HEIGHT: Maximum height - * for a 2D texture bound to linear memory + * for a 2D texture bound to linear memory; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_PITCH: Maximum pitch - * in bytes for a 2D texture bound to linear memory + * in bytes for a 2D texture bound to linear memory; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_MIPMAPPED_WIDTH: Maximum - * mipmapped 2D texture width + * mipmapped 2D texture width; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_MIPMAPPED_HEIGHT: Maximum - * mipmapped 2D texture height + * mipmapped 2D texture height; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_WIDTH: Maximum 3D - * texture width + * texture width; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_HEIGHT: Maximum 3D - * texture height + * texture height; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_DEPTH: Maximum 3D - * texture depth + * texture depth; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_WIDTH_ALTERNATE: * Alternate maximum 3D texture width, 0 if no alternate - * maximum 3D texture size is supported + * maximum 3D texture size is supported; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_HEIGHT_ALTERNATE: * Alternate maximum 3D texture height, 0 if no alternate - * maximum 3D texture size is supported + * maximum 3D texture size is supported; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_DEPTH_ALTERNATE: * Alternate maximum 3D texture depth, 0 if no alternate - * maximum 3D texture size is supported + * maximum 3D texture size is 
supported; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_WIDTH: - * Maximum cubemap texture width or height + * Maximum cubemap texture width or height; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LAYERED_WIDTH: - * Maximum 1D layered texture width + * Maximum 1D layered texture width; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LAYERED_LAYERS: - * Maximum layers in a 1D layered texture + * Maximum layers in a 1D layered texture; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_WIDTH: - * Maximum 2D layered texture width + * Maximum 2D layered texture width; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_HEIGHT: - * Maximum 2D layered texture height + * Maximum 2D layered texture height; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_LAYERS: - * Maximum layers in a 2D layered texture + * Maximum layers in a 2D layered texture; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_LAYERED_WIDTH: - * Maximum cubemap layered texture width or height + * Maximum cubemap layered texture width or height; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_LAYERED_LAYERS: - * Maximum layers in a cubemap layered texture + * Maximum layers in a cubemap layered texture; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_WIDTH: - * Maximum 1D surface width + * Maximum 1D surface width; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_WIDTH: - * Maximum 2D surface width + * Maximum 2D surface width; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_HEIGHT: - * Maximum 2D surface height + * Maximum 2D surface height; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_WIDTH: - * Maximum 3D surface width + * Maximum 3D surface width; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_HEIGHT: - * Maximum 3D surface height + * Maximum 3D surface height; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_DEPTH: - * Maximum 3D surface depth + * Maximum 3D surface depth; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_LAYERED_WIDTH: - * Maximum 1D layered surface width + * Maximum 1D layered surface width; * - 
::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_LAYERED_LAYERS: - * Maximum layers in a 1D layered surface + * Maximum layers in a 1D layered surface; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_WIDTH: - * Maximum 2D layered surface width + * Maximum 2D layered surface width; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_HEIGHT: - * Maximum 2D layered surface height + * Maximum 2D layered surface height; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_LAYERS: - * Maximum layers in a 2D layered surface + * Maximum layers in a 2D layered surface; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_WIDTH: - * Maximum cubemap surface width + * Maximum cubemap surface width; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_LAYERED_WIDTH: - * Maximum cubemap layered surface width + * Maximum cubemap layered surface width; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_LAYERED_LAYERS: - * Maximum layers in a cubemap layered surface + * Maximum layers in a cubemap layered surface; * - ::CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK: Maximum number of 32-bit - * registers available to a thread block - * - ::CU_DEVICE_ATTRIBUTE_CLOCK_RATE: The typical clock frequency in kilohertz + * registers available to a thread block; + * - ::CU_DEVICE_ATTRIBUTE_CLOCK_RATE: The typical clock frequency in kilohertz; * - ::CU_DEVICE_ATTRIBUTE_TEXTURE_ALIGNMENT: Alignment requirement; texture * base addresses aligned to ::textureAlign bytes do not need an offset - * applied to texture fetches + * applied to texture fetches; * - ::CU_DEVICE_ATTRIBUTE_TEXTURE_PITCH_ALIGNMENT: Pitch alignment requirement - * for 2D texture references bound to pitched memory + * for 2D texture references bound to pitched memory; * - ::CU_DEVICE_ATTRIBUTE_GPU_OVERLAP: 1 if the device can concurrently copy - * memory between host and device while executing a kernel, or 0 if not + * memory between host and device while executing a kernel, or 0 if not; * - ::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT: Number 
of multiprocessors on - * the device + * the device; * - ::CU_DEVICE_ATTRIBUTE_KERNEL_EXEC_TIMEOUT: 1 if there is a run time limit - * for kernels executed on the device, or 0 if not + * for kernels executed on the device, or 0 if not; * - ::CU_DEVICE_ATTRIBUTE_INTEGRATED: 1 if the device is integrated with the - * memory subsystem, or 0 if not + * memory subsystem, or 0 if not; * - ::CU_DEVICE_ATTRIBUTE_CAN_MAP_HOST_MEMORY: 1 if the device can map host - * memory into the CUDA address space, or 0 if not + * memory into the CUDA address space, or 0 if not; * - ::CU_DEVICE_ATTRIBUTE_COMPUTE_MODE: Compute mode that device is currently * in. Available modes are as follows: * - ::CU_COMPUTEMODE_DEFAULT: Default mode - Device is not restricted and @@ -3858,33 +3831,33 @@ CUresult CUDAAPI cuDeviceGetTexture1DLinearMaxWidth(size_t *maxWidthInElements, * executing multiple kernels within the same context simultaneously, or 0 if * not. It is not guaranteed that multiple kernels will be resident * on the device concurrently so this feature should not be relied upon for - * correctness. + * correctness; * - ::CU_DEVICE_ATTRIBUTE_ECC_ENABLED: 1 if error correction is enabled on the - * device, 0 if error correction is disabled or not supported by the device - * - ::CU_DEVICE_ATTRIBUTE_PCI_BUS_ID: PCI bus identifier of the device + * device, 0 if error correction is disabled or not supported by the device; + * - ::CU_DEVICE_ATTRIBUTE_PCI_BUS_ID: PCI bus identifier of the device; * - ::CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID: PCI device (also known as slot) identifier - * of the device + * of the device; * - ::CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID: PCI domain identifier of the device * - ::CU_DEVICE_ATTRIBUTE_TCC_DRIVER: 1 if the device is using a TCC driver. 
TCC - * is only available on Tesla hardware running Windows Vista or later - * - ::CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE: Peak memory clock frequency in kilohertz - * - ::CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH: Global memory bus width in bits - * - ::CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE: Size of L2 cache in bytes. 0 if the device doesn't have L2 cache - * - ::CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR: Maximum resident threads per multiprocessor + * is only available on Tesla hardware running Windows Vista or later; + * - ::CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE: Peak memory clock frequency in kilohertz; + * - ::CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH: Global memory bus width in bits; + * - ::CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE: Size of L2 cache in bytes. 0 if the device doesn't have L2 cache; + * - ::CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR: Maximum resident threads per multiprocessor; * - ::CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING: 1 if the device shares a unified address space with - * the host, or 0 if not - * - ::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR: Major compute capability version number - * - ::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR: Minor compute capability version number + * the host, or 0 if not; + * - ::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR: Major compute capability version number; + * - ::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR: Minor compute capability version number; * - ::CU_DEVICE_ATTRIBUTE_GLOBAL_L1_CACHE_SUPPORTED: 1 if device supports caching globals - * in L1 cache, 0 if caching globals in L1 cache is not supported by the device + * in L1 cache, 0 if caching globals in L1 cache is not supported by the device; * - ::CU_DEVICE_ATTRIBUTE_LOCAL_L1_CACHE_SUPPORTED: 1 if device supports caching locals - * in L1 cache, 0 if caching locals in L1 cache is not supported by the device + * in L1 cache, 0 if caching locals in L1 cache is not supported by the device; * - 
::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR: Maximum amount of * shared memory available to a multiprocessor in bytes; this amount is shared - * by all thread blocks simultaneously resident on a multiprocessor + * by all thread blocks simultaneously resident on a multiprocessor; * - ::CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_MULTIPROCESSOR: Maximum number of 32-bit * registers available to a multiprocessor; this number is shared by all thread - * blocks simultaneously resident on a multiprocessor + * blocks simultaneously resident on a multiprocessor; * - ::CU_DEVICE_ATTRIBUTE_MANAGED_MEMORY: 1 if device supports allocating managed memory * on this system, 0 if allocating managed memory is not supported by the device on this system. * - ::CU_DEVICE_ATTRIBUTE_MULTI_GPU_BOARD: 1 if device is on a multi-GPU board, 0 if not. @@ -3910,20 +3883,14 @@ CUresult CUDAAPI cuDeviceGetTexture1DLinearMaxWidth(size_t *maxWidthInElements, * - ::CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED: Device supports virtual memory management APIs like ::cuMemAddressReserve, ::cuMemCreate, ::cuMemMap and related APIs * - ::CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR_SUPPORTED: Device supports exporting memory to a posix file descriptor with ::cuMemExportToShareableHandle, if requested via ::cuMemCreate * - ::CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_WIN32_HANDLE_SUPPORTED: Device supports exporting memory to a Win32 NT handle with ::cuMemExportToShareableHandle, if requested via ::cuMemCreate - * - ::CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_WIN32_KMT_HANDLE_SUPPORTED: Device supports exporting memory to a Win32 KMT handle with ::cuMemExportToShareableHandle, if requested via ::cuMemCreate - * - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCKS_PER_MULTIPROCESSOR: Maximum number of thread blocks that can reside on a multiprocessor + * - ::CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_WIN32_KMT_HANDLE_SUPPORTED: Device supports exporting memory to a Win32 KMT handle with ::cuMemExportToShareableHandle, if 
requested via ::cuMemCreate + * - ::CU_DEVICE_ATTRIBUTE_MAX_PERSISTING_L2_CACHE_SIZE: Maximum L2 persisting lines capacity setting in bytes. + * - ::CU_DEVICE_ATTRIBUTE_MAX_ACCESS_POLICY_WINDOW_SIZE: Maximum value of CUaccessPolicyWindow::num_bytes. + * - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCKS_PER_MULTIPROCESSOR: Maximum number of thread blocks that can reside on a multiprocessor. * - ::CU_DEVICE_ATTRIBUTE_GENERIC_COMPRESSION_SUPPORTED: Device supports compressible memory allocation via ::cuMemCreate - * - ::CU_DEVICE_ATTRIBUTE_MAX_PERSISTING_L2_CACHE_SIZE: Maximum L2 persisting lines capacity setting in bytes - * - ::CU_DEVICE_ATTRIBUTE_MAX_ACCESS_POLICY_WINDOW_SIZE: Maximum value of CUaccessPolicyWindow::num_bytes - * - ::CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WITH_CUDA_VMM_SUPPORTED: Device supports specifying the GPUDirect RDMA flag with ::cuMemCreate. - * - ::CU_DEVICE_ATTRIBUTE_RESERVED_SHARED_MEMORY_PER_BLOCK: Amount of shared memory per block reserved by CUDA driver in bytes - * - ::CU_DEVICE_ATTRIBUTE_SPARSE_CUDA_ARRAY_SUPPORTED: Device supports sparse CUDA arrays and sparse CUDA mipmapped arrays. - * - ::CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED: Device supports using the ::cuMemHostRegister flag ::CU_MEMHOSTERGISTER_READ_ONLY to register memory that must be mapped as read-only to the GPU + * - ::CU_DEVICE_ATTRIBUTE_RESERVED_SHARED_MEMORY_PER_BLOCK: Amount of shared memory per block reserved by CUDA driver in bytes.
+ * - ::CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED: Device supports using the ::cuMemHostRegister flag CU_MEMHOSTREGISTER_READ_ONLY to register memory that must be mapped as read-only to the GPU * - ::CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED: Device supports using the ::cuMemAllocAsync and ::cuMemPool family of APIs - * - ::CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_SUPPORTED: Device supports GPUDirect RDMA APIs, like nvidia_p2p_get_pages (see https://docs.nvidia.com/cuda/gpudirect-rdma for more information) - * - ::CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_FLUSH_WRITES_OPTIONS: The returned attribute shall be interpreted as a bitmask, where the individual bits are described by the ::CUflushGPUDirectRDMAWritesOptions enum - * - ::CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WRITES_ORDERING: GPUDirect RDMA writes to the device do not need to be flushed for consumers within the scope indicated by the returned attribute. See ::CUGPUDirectRDMAWritesOrdering for the numerical values returned here. - * - ::CU_DEVICE_ATTRIBUTE_MEMPOOL_SUPPORTED_HANDLE_TYPES: Bitmask of handle types supported with mempool based IPC * * \param pi - Returned device attribute value * \param attrib - Device attribute to query @@ -4690,13 +4657,6 @@ CUresult CUDAAPI cuCtxCreate_v3(CUcontext *pctx, CUexecAffinityParam *paramsArra * It is the responsibility of the calling function to ensure that no API * call issues using \p ctx while ::cuCtxDestroy() is executing. * - * Destroys and cleans up all resources associated with the context. - * It is the caller's responsibility to ensure that the context or its resources - * are not accessed or passed in subsequent API calls and doing so will result in undefined behavior. - * These resources include CUDA types such as ::CUmodule, ::CUfunction, ::CUstream, ::CUevent, - * ::CUarray, ::CUmipmappedArray, ::CUtexObject, ::CUsurfObject, ::CUtexref, ::CUsurfref, - * ::CUgraphicsResource, ::CUlinkState, ::CUexternalMemory and ::CUexternalSemaphore.
- * * If \p ctx is current to the calling thread then \p ctx will also be * popped from the current thread's context stack (as though ::cuCtxPopCurrent() * were called). If \p ctx is current to other threads, then \p ctx will @@ -5672,7 +5632,6 @@ CUresult CUDAAPI cuModuleLoadFatBinary(CUmodule *module, const void *fatCubin); * ::CUDA_ERROR_INVALID_CONTEXT, * ::CUDA_ERROR_INVALID_VALUE * \notefnerr - * \note_destroy_ub * * \sa ::cuModuleGetFunction, * ::cuModuleGetGlobal, @@ -5993,9 +5952,8 @@ cuLinkDestroy(CUlinkState state); /** * \brief Gets free and total memory * - * Returns in \p *total the total amount of memory available to the the current context. - * Returns in \p *free the amount of memory on the device that is free according to the OS. - * CUDA is not guaranteed to be able to allocate all of the memory that the OS reports as free. + * Returns in \p *free and \p *total respectively, the free and total amount of + * memory available for allocation by the CUDA context, in bytes. * * \param free - Returned free memory in bytes * \param total - Returned total memory in bytes @@ -6839,10 +6797,10 @@ CUresult CUDAAPI cuIpcCloseMemHandle(CUdeviceptr dptr); * * - ::CU_MEMHOSTREGISTER_READ_ONLY: The pointer is treated as pointing to memory * that is considered read-only by the device. On platforms without - * ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES, this flag is + * CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES, this flag is * required in order to register memory mapped to the CPU as read-only. Support * for the use of this flag can be queried from the device attribute - * ::CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED. Using this flag with + * CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED. Using this flag with * a current context associated with a device that does not have this attribute * set will cause ::cuMemHostRegister to error with CUDA_ERROR_NOT_SUPPORTED. 
* @@ -8987,7 +8945,7 @@ CUresult CUDAAPI cuMemsetD2D32Async(CUdeviceptr dstDevice, size_t dstPitch, unsi * float16's: * \code CUDA_ARRAY_DESCRIPTOR desc; - desc.Format = CU_AD_FORMAT_HALF; + desc.Format = CU_AD_FORMAT_HALF; desc.NumChannels = 4; desc.Width = width; desc.Height = height; @@ -8997,7 +8955,7 @@ CUresult CUDAAPI cuMemsetD2D32Async(CUdeviceptr dstDevice, size_t dstPitch, unsi * of which is two 8-bit unsigned chars: * \code CUDA_ARRAY_DESCRIPTOR arrayDesc; - desc.Format = CU_AD_FORMAT_UNSIGNED_INT8; + desc.Format = CU_AD_FORMAT_UNSIGNED_INT8; desc.NumChannels = 2; desc.Width = width; desc.Height = height; @@ -9323,7 +9281,7 @@ CUresult CUDAAPI cuArrayDestroy(CUarray hArray); * 4x16-bit float16's: * \code CUDA_ARRAY3D_DESCRIPTOR desc; - desc.Format = CU_AD_FORMAT_HALF; + desc.Format = CU_AD_FORMAT_HALF; desc.NumChannels = 4; desc.Width = width; desc.Height = height; @@ -15180,7 +15138,7 @@ CUresult CUDAAPI cuGraphExternalSemaphoresWaitNodeSetParams(CUgraphNode hNode, c * \param nodeParams - Parameters for the node * * When ::cuGraphAddMemAllocNode creates an allocation node, it returns the address of the allocation in - * \p nodeParams.dptr. The allocation's address remains fixed across instantiations and launches. + * \p nodeParams.dptr. The allocation's address remains fixed across instantiations and launches. * * If the allocation is freed in the same graph, by creating a free node using ::cuGraphAddMemFreeNode, * the allocation can be accessed by nodes ordered after the allocation node but before the free node.
@@ -15356,9 +15314,7 @@ CUresult CUDAAPI cuGraphMemFreeNodeGetParams(CUgraphNode hNode, CUdeviceptr *dpt * * \sa * ::cuGraphAddMemAllocNode, - * ::cuGraphAddMemFreeNode, - * ::cuDeviceSetGraphMemAttribute, - * ::cuDeviceGetGraphMemAttribute + * ::cuGraphAddMemFreeNode */ CUresult CUDAAPI cuDeviceGraphMemTrim(CUdevice device); @@ -15384,7 +15340,6 @@ CUresult CUDAAPI cuDeviceGraphMemTrim(CUdevice device); * ::CUDA_ERROR_INVALID_DEVICE * * \sa - * ::cuDeviceSetGraphMemAttribute, * ::cuGraphAddMemAllocNode, * ::cuGraphAddMemFreeNode */ @@ -15409,7 +15364,6 @@ CUresult CUDAAPI cuDeviceGetGraphMemAttribute(CUdevice device, CUgraphMem_attrib * ::CUDA_ERROR_INVALID_DEVICE * * \sa - * ::cuDeviceGetGraphMemAttribute, * ::cuGraphAddMemAllocNode, * ::cuGraphAddMemFreeNode */ diff --git a/include/triton/ir/basic_block.h b/include/triton/ir/basic_block.h index 840145246..26d406baf 100644 --- a/include/triton/ir/basic_block.h +++ b/include/triton/ir/basic_block.h @@ -1,4 +1,4 @@ -#pragma once +#pragma once #ifndef _TRITON_IR_BASIC_BLOCK_H_ #define _TRITON_IR_BASIC_BLOCK_H_ @@ -27,7 +27,7 @@ public: private: // constructors - basic_block(context &ctx, const std::string &name, function *parent); + basic_block(context &ctx, const std::string &name, function *parent, basic_block *next); public: // accessors @@ -35,6 +35,7 @@ public: context& get_context() { return ctx_; } // get iterator to first instruction that is not a phi + void replace_phi_uses_with(basic_block* before, basic_block* after); iterator get_first_non_phi(); // get instruction list @@ -60,13 +61,16 @@ public: inline const instruction &back() const { return *inst_list_.back(); } inline instruction &back() { return *inst_list_.back(); } + void append_instruction(ir::instruction* i); + // split + basic_block* split_before(ir::instruction* loc, const std::string& name); + // predecessors - const std::vector& get_predecessors() const { return preds_; } - const std::vector& get_successors() const { return succs_; } - void 
add_predecessor(basic_block* pred); + std::vector get_predecessors() const; + std::vector get_successors() const; // factory functions - static basic_block* create(context &ctx, const std::string &name, function *parent); + static basic_block* create(context &ctx, const std::string &name, function *parent, basic_block *next = nullptr); void print(std::ostream &os); diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h index 45a7d5111..ff8447124 100644 --- a/include/triton/ir/builder.h +++ b/include/triton/ir/builder.h @@ -22,6 +22,7 @@ class phi_node; /* Builder */ class builder{ +public: typedef basic_block::iterator iterator; public: @@ -75,6 +76,7 @@ public: value* create_br(basic_block *dest); value* create_cond_br(value *cond, basic_block* if_dest, basic_block* else_dest); value* create_ret_void(); + value* create_ret(value *ret); // Cast instructions value *create_cast(cast_op_t op, value *v, type *dst_ty); value* create_ptr_to_int(value *src, type *dst_ty); @@ -86,6 +88,9 @@ public: value* create_fp_trunc(value *src, type *dst_ty); value* create_int_cast(value *src, type *dst_ty, bool is_signed); value *create_downcast(value *arg); + // Call instruction + value* create_call(function* fn, const std::vector& args); + value* create_launch(function* fn, const std::vector& args, const std::vector& grid, value* num_warps); // Phi instruction phi_node* create_phi(type *ty, unsigned num_reserved); // Binary instructions @@ -142,6 +147,9 @@ public: value *create_store(value *ptr, value *val); value *create_masked_load(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile); value *create_masked_store(value *ptr, value *val, value *mask); + // Struct instructions + value *create_insert_value(value* val, value *elt, size_t idx); + value *create_extract_value(value* val, size_t idx); // Block instruction value *create_splat(value *arg, const type::block_shapes_t &shapes); value 
*create_reshape(value *arg, const type::block_shapes_t &shapes); diff --git a/include/triton/ir/context_impl.h b/include/triton/ir/context_impl.h index 081ea249d..619ae4c87 100644 --- a/include/triton/ir/context_impl.h +++ b/include/triton/ir/context_impl.h @@ -31,7 +31,8 @@ public: std::map, std::unique_ptr> ptr_tys; // Block types std::map, std::unique_ptr> block_tys; - + // Struct types + std::map struct_tys; // Int constants std::map, std::unique_ptr> int_constants_; // Float constants diff --git a/include/triton/ir/enums.h b/include/triton/ir/enums.h index 2d4c09d79..3fa008606 100644 --- a/include/triton/ir/enums.h +++ b/include/triton/ir/enums.h @@ -95,6 +95,9 @@ enum value_id_t: unsigned { INSTRUCTIONS * ------------ */ INST_BEGIN, + // call + INST_CALL, + INST_LAUNCH, // phi INST_PHI, // arithmetic @@ -129,6 +132,9 @@ enum value_id_t: unsigned { INST_MASKED_LOAD_ASYNC, INST_UNMASKED_STORE, INST_MASKED_STORE, + // struct + INST_EXTRACT_VALUE, + INST_INSERT_VALUE, // retile INST_RESHAPE, INST_SPLAT, diff --git a/include/triton/ir/function.h b/include/triton/ir/function.h index 9e1bc981a..4e76e60a4 100644 --- a/include/triton/ir/function.h +++ b/include/triton/ir/function.h @@ -24,7 +24,7 @@ public: static argument* create(type *ty, const std::string &name, function *parent = nullptr, unsigned arg_no = 0); function* get_parent() const; - unsigned get_arg_no() const; + unsigned get_arg_no() const; void accept(visitor *v); @@ -121,6 +121,8 @@ public: const attr_map_t &attrs() { return attrs_; } bool has_attr(unsigned arg_id) const { return attrs_.find(arg_id) != attrs_.end(); } std::set get_attributes(const argument* arg) { return attrs_[arg->get_arg_no() + 1]; } + void set_is_kernel(bool new_val) { is_kernel_ = new_val; } + bool get_is_kernel() { return is_kernel_; } void print(std::ostream &os); @@ -134,6 +136,7 @@ private: args_t args_; blocks_t blocks_; attr_map_t attrs_; + bool is_kernel_; }; } diff --git a/include/triton/ir/instructions.h 
b/include/triton/ir/instructions.h index e9e0f0f11..c2d427ae8 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -81,6 +81,51 @@ private: value_id_t id_; }; +//===----------------------------------------------------------------------===// +// call_inst classes +//===----------------------------------------------------------------------===// + +class call_inst: public instruction { +private: + std::string repr_impl() const; + call_inst(ir::function* fn, const std::vector& values, const std::string& name, instruction* next); + +public: + static call_inst* create(ir::function* fn, const std::vector& values, const std::string &name = "", instruction *next = nullptr); + ir::function* get_fn() { return fn_; } + + _TRITON_DEFINE_CLONE(call_inst) + _TRITON_DEFINE_ACCEPT(call_inst) + +private: + ir::function* fn_; +}; + +class launch_inst: public instruction { +private: + std::string repr_impl() const { return "launch"; } + launch_inst(ir::function* fn, const std::vector& values, const std::vector& grid, ir::value* num_warps, + const std::string &name = "", instruction *next = nullptr); + +public: + static launch_inst* create(ir::function* fn, const std::vector& values, const std::vector& grid, ir::value* num_warps, + const std::string& name = "", instruction* next = nullptr); + + ir::function* get_fn(); + std::vector get_values(); + std::vector get_grid(); + ir::value* get_num_warps(); + + + _TRITON_DEFINE_CLONE(launch_inst) + _TRITON_DEFINE_ACCEPT(launch_inst) + +private: + unsigned val_begin; + unsigned val_end; + unsigned grid_begin; + unsigned grid_end; +}; //===----------------------------------------------------------------------===// // phi_node classes @@ -546,6 +591,44 @@ public: _TRITON_DEFINE_ACCEPT(masked_store_inst) }; +//===----------------------------------------------------------------------===// +// struct classes +//===----------------------------------------------------------------------===// + +// insert_value + 
+class insert_value_inst: public instruction { +private: + std::string repr_impl() const { return "insertvalue"; } + insert_value_inst(value *val, value *elt, size_t idx, const std::string &name, instruction *next); + +public: + static insert_value_inst* create(value *val, value* elt, size_t idx, const std::string &name = "", instruction *next = nullptr); + size_t get_idx() { return idx_; } + _TRITON_DEFINE_CLONE(insert_value_inst) + _TRITON_DEFINE_ACCEPT(insert_value_inst) + +private: + size_t idx_; +}; + +// extract_value + +class extract_value_inst: public instruction { +private: + std::string repr_impl() const { return "extractvalue"; } + extract_value_inst(value *val, size_t idx, const std::string &name, instruction *next); + +public: + static extract_value_inst* create(value *val, size_t idx, const std::string &name = "", instruction *next = nullptr); + size_t get_idx() { return idx_; } + _TRITON_DEFINE_CLONE(extract_value_inst) + _TRITON_DEFINE_ACCEPT(extract_value_inst) + +private: + size_t idx_; +}; + //===----------------------------------------------------------------------===// // retile_inst classes //===----------------------------------------------------------------------===// diff --git a/include/triton/ir/module.h b/include/triton/ir/module.h index 30881fd49..f8f033eb7 100644 --- a/include/triton/ir/module.h +++ b/include/triton/ir/module.h @@ -34,79 +34,97 @@ class constant; class global_value; class alloc_const; -/* Module */ - -class module { +class value_constructor { typedef std::pair val_key_t; - friend class function; typedef std::pair md_pair_t; -public: - typedef std::map symbols_map_t; - typedef std::vector functions_list_t; - struct current_iteration_info_t{ - lang::iteration_statement *statement; - basic_block *block; - }; - private: phi_node *make_phi(type *ty, unsigned num_values, basic_block *block); value *try_remove_trivial_phis(ir::phi_node *&phi); value *add_phi_operands(const std::string& name, phi_node *&phi); value 
*get_value_recursive(const std::string& name, basic_block *block); + +public: + value_constructor(builder &builder); + + void set_value(const std::string& name, basic_block* block, value *x); + void set_value(const std::string& name, value* x); + const std::map& get_values() { return values_; } + void set_values(const std::map& values) { values_ = values; } + value *get_value(const std::string& name, basic_block* block); + value *get_value(const std::string& name); + void set_type(const std::string& name, ir::type* ty) { types_[name] = ty; } + // Seal block -- no more predecessors will be added + void seal_block(basic_block *block); + // Metadata + void add_metadata(const std::string &name, md_pair_t x) { metadatas_[name] = x; } + +private: + ir::builder& builder_; + std::map values_; + std::map types_; + std::set sealed_blocks_; + std::map> incomplete_phis_; + std::map current_phi_; + std::map metadatas_; +}; + +/* Module */ + +class module { + typedef std::pair val_key_t; + friend class function; + +public: + typedef std::map symbols_map_t; + typedef std::vector functions_list_t; + +private: void push_function(function *fn) { functions_.push_back(fn); } public: module(const std::string &name, builder& builder); builder& get_builder(); // Setters - void set_value(const std::string& name, basic_block* block, value *x); - void set_value(const std::string& name, value* x); - void set_const(const std::string& name); void set_continue_fn(std::function fn); // Getters - const std::map& get_values() { return values_; } - const std::map& get_types() { return types_; } - void set_values(const std::map& values) { values_ = values; } - void set_types(const std::map& types) { types_ = types; } - - value *get_value(const std::string& name, basic_block* block); - value *get_value(const std::string& name); - void set_type(const std::string& name, ir::type* ty) { types_[name] = ty; } const std::string& get_name(); std::function get_continue_fn(); - // Seal block -- no more 
predecessors will be added - void seal_block(basic_block *block); // Functions const functions_list_t &get_function_list() const { return functions_; } functions_list_t &get_function_list() { return functions_; } + function *get_function(const std::string& name) { + if(symbols_.find(name) == symbols_.end()) + throw std::runtime_error("function " + name + " is not declared"); + return (function*)symbols_.at(name); + } function *get_or_insert_function(const std::string &name, function_type *ty); + bool has_function(const std::string& name){ + return symbols_.find(name) != symbols_.end(); + } + void remove_function(ir::function* fn){ + functions_.erase(std::remove(functions_.begin(), functions_.end(), fn), functions_.end()); + } + + void reset_ret_ty(const std::string& name, type* ty); + // Const allocation void add_alloc(ir::alloc_const* x) { allocs_.push_back(x); } const std::vector& allocs() { return allocs_; } // Register global void register_global(const std::string& name, ir::value *x) { globals_[name] = x; } const std::map& globals() const { return globals_; } - // Metadata - void add_metadata(const std::string &name, md_pair_t x) { metadatas_[name] = x; } - + // void print(std::ostream &os); private: std::string name_; builder& builder_; - std::map values_; - std::map types_; - std::set const_; - std::set sealed_blocks_; - std::map> incomplete_phis_; functions_list_t functions_; symbols_map_t symbols_; std::function continue_fn_; - std::map current_phi_; std::vector allocs_; std::map globals_; - std::map metadatas_; }; } diff --git a/include/triton/ir/type.h b/include/triton/ir/type.h index 47c9b5f85..d7919b4c8 100644 --- a/include/triton/ir/type.h +++ b/include/triton/ir/type.h @@ -1,4 +1,4 @@ -#pragma once +#pragma once #ifndef _TRITON_IR_TYPE_H_ #define _TRITON_IR_TYPE_H_ @@ -73,6 +73,8 @@ public: type *get_tile_element_ty() const; unsigned get_pointer_address_space() const; type *get_pointer_element_ty() const; + unsigned get_struct_numel() const { return 
contained_tys_.size(); } + type *get_struct_type(unsigned int i) const { return contained_tys_[i]; } // primitive predicates bool is_void_ty() const { return id_ == VoidTyID; } @@ -91,6 +93,7 @@ public: bool is_bool_ty() const { return is_integer_ty(1); } bool is_pointer_ty() const { return id_ == PointerTyID; } bool is_block_ty() const { return id_ == BlockTyID; } + bool is_struct_ty() const { return id_ == StructTyID; } // Composite predicates bool is_int_or_tileint_ty(); @@ -138,10 +141,10 @@ public: switch(id_) { case VoidTyID: return "void"; case FP8TyID: return "fp8"; + case BF16TyID: return "bf16"; case FP16TyID: return "f16"; case FP32TyID: return "f32"; case FP64TyID: return "f64"; - case BF16TyID: return "bf16"; case LabelTyID: return "label"; case MetadataTyID: return "md"; case TokenTyID: return "tok"; @@ -194,6 +197,16 @@ public: type* get_type_at_index(value *idx) const; }; +class struct_type: public composite_type { +public: + struct_type(const contained_tys_vec_t& tys, bool is_packed); + unsigned get_num_types() const { return contained_tys_.size(); } + static struct_type* get(const contained_tys_vec_t& tys, bool is_packed); + +private: + bool is_packed_; +}; + class block_type: public composite_type { private: block_type(type *ty, const block_shapes_t &shapes); @@ -242,6 +255,7 @@ public: ty_iterator params_end() { return contained_tys_.end(); } type* get_param_ty(unsigned i) const { return contained_tys_.at(1 + i); } type* get_return_ty() const { return contained_tys_.at(0); } + void reset_ret_ty(type* ty) { contained_tys_[0] = ty;} // factory methods static function_type* get(type *ret_ty, const std::vector& param_tys); }; diff --git a/include/triton/ir/value.h b/include/triton/ir/value.h index 7a132d5e2..fde09121a 100644 --- a/include/triton/ir/value.h +++ b/include/triton/ir/value.h @@ -21,7 +21,7 @@ class visitor; class value { public: - typedef std::set users_t; + typedef std::vector users_t; public: // constructor @@ -30,7 +30,7 @@ public: 
// uses void add_use(user* arg); users_t::iterator erase_use(user* arg); - const std::set &get_users() { return users_; } + const std::vector &get_users() { return users_; } void replace_all_uses_with(value *target); // name void set_name(const std::string &name); diff --git a/include/triton/ir/visitor.h b/include/triton/ir/visitor.h index 25ce578e3..774f2e172 100644 --- a/include/triton/ir/visitor.h +++ b/include/triton/ir/visitor.h @@ -11,6 +11,9 @@ class value; class instruction; +class call_inst; +class launch_inst; + class phi_node; class binary_operator; class getelementptr_inst; @@ -42,6 +45,9 @@ class masked_load_inst; class unmasked_store_inst; class masked_store_inst; +class extract_value_inst; +class insert_value_inst; + class retile_inst; class reshape_inst; class splat_inst; @@ -105,6 +111,8 @@ public: virtual ~visitor() {} virtual void visit_value(ir::value*); + virtual void visit_call_inst(ir::call_inst*) = 0; + virtual void visit_launch_inst(ir::launch_inst*) = 0; virtual void visit_basic_block(basic_block*) = 0; virtual void visit_argument(argument*) = 0; @@ -132,6 +140,9 @@ public: virtual void visit_sin_inst(sin_inst*) = 0; virtual void visit_log_inst(log_inst*) = 0; + virtual void visit_extract_value_inst(extract_value_inst*) = 0; + virtual void visit_insert_value_inst(insert_value_inst*) = 0; + virtual void visit_reshape_inst(reshape_inst*) = 0; virtual void visit_splat_inst(splat_inst*) = 0; virtual void visit_cat_inst(cat_inst*) = 0; diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc index 5d30a2f45..cec512fec 100644 --- a/lib/codegen/analysis/layout.cc +++ b/lib/codegen/analysis/layout.cc @@ -608,6 +608,8 @@ void layouts::run(ir::module &mod) { // create temporaries size_t id = values_.size(); ir::for_each_instruction(mod, [this, &id](ir::instruction* i) { +// std::cout << "layout: " << std::endl; +// i->print(std::cout); if(auto *red = dynamic_cast(i)) { id++; ir::value *arg = red->get_operand(0); diff --git 
a/lib/codegen/pass.cc b/lib/codegen/pass.cc index 8921d6c84..e2cd6d228 100644 --- a/lib/codegen/pass.cc +++ b/lib/codegen/pass.cc @@ -13,6 +13,7 @@ #include "triton/codegen/transform/peephole.h" #include "triton/codegen/transform/pipeline.h" #include "triton/codegen/transform/prefetch.h" +#include "triton/codegen/transform/inline.h" #include "triton/ir/function.h" #include "triton/ir/module.h" #include "triton/ir/print.h" @@ -33,6 +34,7 @@ std::unique_ptr add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC bool cts_use_async = target->as_nvidia() && target->as_nvidia()->sm() >= 80; // create passes codegen::analysis::align align; + codegen::transform::inliner inliner; codegen::analysis::axes axes; codegen::transform::cts cts(cts_use_async); codegen::transform::pipeline pipeline(cts_use_async, num_stages); @@ -48,6 +50,7 @@ std::unique_ptr add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC codegen::transform::membar barriers(&liveness, &layouts, &allocation, &prefetch_s, target); codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target, num_warps); // run passes + inliner.run(ir); dce.run(ir); peephole.run(ir); dce.run(ir); diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index b36f51d92..0e6ae4539 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -13,6 +13,7 @@ #include "triton/ir/module.h" #include "triton/ir/function.h" #include "triton/ir/type.h" +#include "triton/ir/utils.h" #include "llvm/IR/Module.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/IntrinsicsNVPTX.h" @@ -139,6 +140,14 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){ * \brief Convert Triton-IR Type to LLVM-IR Type */ Type *generator::cvt(ir::type *ty) { + // struct + if(ty->is_struct_ty()){ + std::vector tys; + for(size_t i = 0; i < ty->get_struct_numel(); i++) + tys.push_back(cvt(ty->get_struct_type(i))); + return StructType::get(builder_->getContext(), tys, true); + 
} + // function if(auto* tt = dynamic_cast(ty)){ Type *ret_ty = cvt(tt->get_return_ty()); @@ -266,7 +275,8 @@ void generator::visit_value(ir::value* v) { builder_->SetInsertPoint(&*current->getFirstNonPHI()); // visit user if(auto *usr = dynamic_cast(v)){ - usr->accept(this); + if(!dynamic_cast(usr)) + usr->accept(this); } // revert insert point if(phi && !current->empty() && current->getFirstNonPHI()) @@ -282,6 +292,81 @@ void generator::visit_phi_node(ir::phi_node* x) { vals_[x][idx] = phi(ty, x->get_num_operands()); } +/** + * \brief Code Generation for `call` + */ +void generator::visit_call_inst(ir::call_inst* call) { + throw std::runtime_error("call not supported! Triton should be inlining everything."); +} + +void generator::visit_launch_inst(ir::launch_inst *launch) { + ir::function* fn = (ir::function*)launch->get_operand(0); + // forward-declare cudaGetParameterBufferV2 + std::vector get_param_arg_tys = {PointerType::get(builder_->getInt8Ty(), 0), + ArrayType::get(builder_->getInt32Ty(), 3), + ArrayType::get(builder_->getInt32Ty(), 3), + builder_->getInt32Ty()}; + FunctionType* get_param_ty = FunctionType::get(PointerType::get(builder_->getInt8Ty(), 0), get_param_arg_tys, false); + Function* get_param_buffer = Function::Create(get_param_ty, Function::ExternalLinkage, "cudaGetParameterBufferV2", mod_); + AllocaInst* grid = builder_->CreateAlloca(get_param_arg_tys[1]); + AllocaInst* block = builder_->CreateAlloca(get_param_arg_tys[2]); + ConstantInt* _0 = builder_->getInt32(0); + ConstantInt* _1 = builder_->getInt32(1); + ConstantInt* _2 = builder_->getInt32(2); + // create basic block + BasicBlock* launch_done_bb = BasicBlock::Create(builder_->getContext(), "launch_done", builder_->GetInsertBlock()->getParent()); + BasicBlock* launch_bb = BasicBlock::Create(builder_->getContext(), "launch", launch_done_bb->getParent(), launch_done_bb); + Value *tid = tgt_->get_local_id(mod_, *builder_, 0); + Value *is_first_thread = builder_->CreateICmpEQ(tid, i32(0)); + 
builder_->CreateCondBr(is_first_thread, launch_bb, launch_done_bb); + builder_->SetInsertPoint(launch_bb); + + // + builder_->CreateStore(vals_[launch->get_grid()[0]][{}], builder_->CreateGEP(grid, {_0, _0})); + builder_->CreateStore(vals_[launch->get_grid()[1]][{}], builder_->CreateGEP(grid, {_0, _1})); + builder_->CreateStore(vals_[launch->get_grid()[2]][{}], builder_->CreateGEP(grid, {_0, _2})); + Value* num_warps = mul(builder_->getInt32(32), vals_[launch->get_num_warps()][{}]); + builder_->CreateStore(num_warps, builder_->CreateGEP(block, {_0, _0})); + builder_->CreateStore(builder_->getInt32(1), builder_->CreateGEP(block, {_0, _1})); + builder_->CreateStore(builder_->getInt32(1), builder_->CreateGEP(block, {_0, _2})); + Function* called_fn = fns_[fn]; + Value* callee = ConstantExpr::getCast(Instruction::BitCast, called_fn, get_param_arg_tys[0]); + Value* arg_ptr = builder_->CreateCall(get_param_buffer, {callee, builder_->CreateLoad(grid), builder_->CreateLoad(block), builder_->getInt32(0)}); + // forward-declare cudaLaunchDeviceV2 + std::vector launch_device_arg_tys = {get_param_ty->getReturnType(), builder_->getInt64Ty()}; + FunctionType* launch_device_ty = FunctionType::get(builder_->getInt32Ty(), launch_device_arg_tys, false); + Function* launch_device = Function::Create(launch_device_ty, Function::ExternalLinkage, "cudaLaunchDeviceV2", mod_); + // TODO: add branch + Value* do_not_launch = builder_->CreateICmpEQ(builder_->CreatePtrToInt(arg_ptr, builder_->getInt64Ty()), + builder_->getInt64(0)); + BasicBlock* launch2_bb = BasicBlock::Create(builder_->getContext(), "launch2", launch_done_bb->getParent(), launch_done_bb); + builder_->CreateCondBr(do_not_launch, launch_done_bb, launch2_bb); + builder_->SetInsertPoint(launch2_bb); + + unsigned addr_space = arg_ptr->getType()->getPointerAddressSpace(); + unsigned off = 0; + unsigned last_size = 0; + for(ir::value* arg: launch->get_values()){ + Value* curr_arg = vals_[arg][{}]; + Type* curr_arg_ty =
curr_arg->getType(); + // handle struct alignment + off += last_size; + unsigned size = curr_arg_ty->isPointerTy() ? 8 : curr_arg_ty->getPrimitiveSizeInBits() / 8; + off = (off + size - 1) / size * size; + // get pointer to current arg + Value* curr_arg_ptr = builder_->CreateGEP(arg_ptr, builder_->getInt32(off)); + curr_arg_ptr = builder_->CreateBitCast(curr_arg_ptr, curr_arg_ty->getPointerTo(addr_space)); + // store arg + builder_->CreateStore(curr_arg, curr_arg_ptr); + last_size = size; + } + builder_->CreateCall(launch_device, {arg_ptr, builder_->getInt64(0)}); + builder_->CreateBr(launch_done_bb); + // done + builder_->SetInsertPoint(launch_done_bb); + +} + /** * \brief Code Generation for `binary_operator` */ @@ -311,6 +396,7 @@ void generator::visit_binary_operator(ir::binary_operator*x) { default: throw std::runtime_error("unreachable switch"); } }; +// x->print(std::cout); for(indices_t idx: idxs_.at(x)){ Value *lhs = vals_[x->get_operand(0)][idx]; Value *rhs = vals_[x->get_operand(1)][idx]; @@ -852,6 +938,31 @@ void generator::visit_masked_store_inst(ir::masked_store_inst* x) { visit_store_inst(x); } +// -- + +void generator::visit_extract_value_inst(ir::extract_value_inst *x) { + auto idxs = idxs_.at(x); + ir::value* agg = x->get_operand(0); + unsigned insert_idx = x->get_idx(); + for(size_t i = 0; i < idxs.size(); i++){ + auto idx = idxs[i]; + vals_[x][idx] = builder_->CreateExtractValue(vals_[agg][idx], {insert_idx}); + } +} + + +void generator::visit_insert_value_inst(ir::insert_value_inst *x){ + auto idxs = idxs_.at(x); + ir::value* agg = x->get_operand(0); + ir::value* val = x->get_operand(1); + unsigned insert_idx = x->get_idx(); + for(size_t i = 0; i < idxs.size(); i++){ + auto idx = idxs[i]; + vals_[x][idx] = builder_->CreateInsertValue(vals_[agg][idx], vals_[val][idx],{insert_idx}); + } +} + +// -- /** * \brief Code Generation for `cat` */ @@ -2686,7 +2797,8 @@ void generator::visit_make_range(ir::make_range* x) { } void 
generator::visit_undef_value(ir::undef_value *x) { - Type* ty = cvt(x->get_type()->get_scalar_ty()); + ir::type* sca_ty = x->get_type()->get_scalar_ty(); + Type* ty = cvt(sca_ty); for(indices_t idx: idxs_.at(x)) vals_[x][idx] = llvm::UndefValue::get(ty); } @@ -2713,8 +2825,7 @@ void generator::visit_alloc_const(ir::alloc_const *alloc) { } -void generator::visit_function(ir::function* fn) { - LLVMContext &ctx = builder_->getContext(); +void generator::forward_declare(ir::function* fn){ FunctionType *fn_ty = (FunctionType*)cvt(fn->get_fn_type()); if(!tgt_->is_gpu()){ Type *fn_ret_ty = fn_ty->getReturnType(); @@ -2727,6 +2838,18 @@ void generator::visit_function(ir::function* fn) { fn_ty = FunctionType::get(fn_ret_ty, fn_args_ty, false); } Function *ret = Function::Create(fn_ty, Function::ExternalLinkage, fn->get_name(), mod_); + fns_[fn] = ret; +} + +void generator::visit_function(ir::function* fn) { + idxs_.clear(); + vals_.clear(); + seen_.clear(); + LLVMContext &ctx = builder_->getContext(); + + Function* ret = fns_[fn]; + + // set attributes for(auto attr_pair: fn->attrs()){ unsigned id = attr_pair.first; @@ -2751,7 +2874,8 @@ void generator::visit_function(ir::function* fn) { for(unsigned i = 0; i < fn->args().size(); i++) vals_[fn->args()[i]][{}] = &*(ret->arg_begin() + i); // create blocks - for(ir::basic_block *block: fn->blocks()) { + auto blocks = ir::cfg::reverse_post_order(fn); + for(ir::basic_block *block: blocks) { BasicBlock *dst_block = BasicBlock::Create(ctx, block->get_name(), ret); bbs_[block] = dst_block; } @@ -2761,7 +2885,7 @@ void generator::visit_function(ir::function* fn) { visit_layout(x.second); } // generate LLVM-IR code - for(ir::basic_block *block: fn->blocks()) + for(ir::basic_block *block: blocks) visit_basic_block(block); // finalize finalize_function(fn); @@ -2982,10 +3106,12 @@ void generator::visit_layout_shared(analysis::shared_layout* layout) { } void generator::visit_basic_block(ir::basic_block * block) { + BasicBlock *parent = 
bbs_[block]; builder_->SetInsertPoint(parent); - for(ir::instruction *i: block->get_inst_list()) + for(ir::instruction *i: block->get_inst_list()){ visit_value(i); + } // Update ir bb -> llvm bb mapping bbs_[block] = builder_->GetInsertBlock(); } @@ -3168,6 +3294,12 @@ void generator::finalize_phi_node(ir::phi_node *x) { } } +StructType* generator::packed_type(ir::value* i){ + Type* dtype = cvt(i->get_type()->get_tile_element_ty()); + auto* layout = dynamic_cast(layouts_->get(i)); + assert(layout); +} + void generator::visit(ir::module &src, llvm::Module &dst) { mod_ = &dst; ctx_ = &dst.getContext(); @@ -3184,7 +3316,16 @@ void generator::visit(ir::module &src, llvm::Module &dst) { nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3); shmem_ = bit_cast(sh_mem_array, ptr_ty); } + // instantiate device functions +// for(ir::function *fn: src.get_function_list()) +// for(ir::basic_block *bb: fn->blocks()) +// for(ir::instruction *i: bb->get_inst_list()) +// if(auto *call = dynamic_cast(i)){ +// std::cout << "call??" 
<< std::endl; +// } // visit functions + for(ir::function *fn: src.get_function_list()) + forward_declare(fn); for(ir::function *fn: src.get_function_list()) visit_function(fn); } diff --git a/lib/codegen/transform/dce.cc b/lib/codegen/transform/dce.cc index c555290f8..7416ff6e8 100644 --- a/lib/codegen/transform/dce.cc +++ b/lib/codegen/transform/dce.cc @@ -3,6 +3,7 @@ #include "triton/ir/basic_block.h" #include "triton/ir/module.h" #include "triton/ir/utils.h" +#include namespace triton { namespace codegen{ @@ -28,6 +29,8 @@ void dce::run(ir::module &mod) { case ir::INST_ATOMIC_CAS: case ir::INST_ATOMIC_RMW: case ir::INST_ATOMIC_EXCH: + case ir::INST_CALL: + case ir::INST_LAUNCH: case ir::INST_BARRIER: { work_list.push_back(i); marked.insert(i); @@ -65,6 +68,7 @@ void dce::run(ir::module &mod) { } } + // delete for(ir::instruction* i: to_delete) i->erase_from_parent(); diff --git a/lib/codegen/transform/inline.cc b/lib/codegen/transform/inline.cc new file mode 100644 index 000000000..fa22e5354 --- /dev/null +++ b/lib/codegen/transform/inline.cc @@ -0,0 +1,127 @@ +#include +#include "triton/codegen/transform/inline.h" +#include "triton/ir/module.h" +#include "triton/ir/function.h" +#include "triton/ir/utils.h" + +namespace triton{ +namespace codegen{ +namespace transform{ + + +bool fncmp::operator()(ir::function* x, ir::function* y) const { + auto fn_list = x->get_parent()->get_function_list(); + return std::find(fn_list.begin(), fn_list.end(), x) < std::find(fn_list.begin(), fn_list.end(), y); +}; + +void inliner::do_inline(ir::function* fn, ir::call_inst* callsite, ir::builder& builder, + std::list& callsites){ + ir::basic_block* parent_block = callsite->get_parent(); + ir::function* parent_fn = parent_block->get_parent(); + // the parent block is split into block A and block B: + // - block A (`new_blocks[0]`) is the entry block of the inlined function + // - block B (`exit`) resumes execution of the parent function + ir::basic_block* entry = 
parent_block->split_before(callsite, fn->get_name()); + ir::basic_block* exit = entry->get_successors()[0]; + std::vector new_blocks = {entry}; + for(size_t i = 1; i < fn->blocks().size(); i++){ + ir::basic_block* block = fn->blocks()[i]; + ir::context& ctx = block->get_context(); + const std::string& name = block->get_parent()->get_name() + "_" + block->get_name(); + new_blocks.push_back(ir::basic_block::create(ctx, name, parent_fn)); + } + // a phi node holds the return values of the inlined function + if(exit->get_inst_list().empty()) + builder.set_insert_point(exit); + else + builder.set_insert_point(exit->get_first_non_phi()); + ir::phi_node* exit_val = builder.create_phi(fn->get_fn_type()->get_return_ty(), 0); + callsite->replace_all_uses_with(exit_val); + callsite->erase_from_parent(); + // get arguments `fn` is called with + std::vector tgt_args(callsite->op_begin(), callsite->op_end()); + std::vector src_args(fn->args().begin(), fn->args().end()); + // Actually generate the instructions: + // - Remove the branch created by basic_block::split_before + // - Clone all instructions + // - Replace `ret` with incoming nodes to `exit_val` and branches to `exit` + ir::instruction* terminator = new_blocks[0]->get_inst_list().back(); +// new_blocks[0]->get_inst_list().back()->erase_from_parent(); + terminator->erase_from_parent(); + std::map inst_map; + std::map arg_map; + for(size_t k = 0; k < fn->args().size(); k++) + arg_map[fn->args()[k]] = callsite->ops()[k]; + std::vector rpo = ir::cfg::reverse_post_order(fn); + for(size_t i = 0; i < new_blocks.size(); i++){ + ir::basic_block* old_block = fn->blocks()[i]; + ir::basic_block* new_block = new_blocks[i]; + builder.set_insert_point(new_block); + for(ir::instruction* old_inst: old_block->get_inst_list()){ + // clone instruction + ir::instruction* new_inst = old_inst->clone(); + // replace basic block + for(size_t k = 0; k < new_blocks.size(); k++) + new_inst->replace_uses_of_with(fn->blocks()[k], new_blocks[k]); + 
// replace values + for(size_t k = 0; k < new_inst->get_num_operands(); k++){ + ir::value* op = new_inst->get_operand(k); + if(auto arg_op = dynamic_cast(op)) + new_inst->set_operand(k, arg_map.at(arg_op)); + if(auto inst_op = dynamic_cast(op)) + if(inst_map.find(inst_op) != inst_map.end()) + new_inst->set_operand(k, inst_map.at(inst_op)); + } + // `ret` instruction is a special case: + // instead of returning we need to branch to after the function call + if(ir::return_inst* ret = dynamic_cast(new_inst)){ + if(ir::value* ret_val = ret->get_return_value()) + exit_val->add_incoming(ret_val, new_block); + new_inst = ir::branch_inst::create(exit); + } + inst_map[old_inst] = new_inst; + builder.insert(new_inst); + } + } + if(exit_val->get_num_incoming() == 1) + exit_val->replace_all_uses_with(exit_val->get_incoming_value(0)); + // done -- make sure insert point is properly set to exit block + builder.set_insert_point(exit); +} + +void inliner::run(ir::module &mod) { + + // gather all call sites + while(true){ + std::map counts; + for(ir::function* fn: mod.get_function_list()) + counts[fn] = 0; + + std::list callsites; + for(ir::function* fn: mod.get_function_list()){ + for(ir::basic_block* block: fn->blocks()) + for(ir::instruction* instr: block->get_inst_list()) + if(ir::call_inst* call = dynamic_cast(instr)){ + callsites.push_back(call); + counts[call->get_fn()] += 1; + } + } + + for(auto& count: counts){ + if(!count.first->get_is_kernel() && count.second == 0) + count.first->get_parent()->remove_function(count.first); + } + + if(callsites.empty()) + break; + + for(ir::call_inst* call: callsites) + do_inline(call->get_fn(), call, mod.get_builder(), callsites); + } + + +} + +} +} +} diff --git a/lib/codegen/transform/peephole.cc b/lib/codegen/transform/peephole.cc index 0961efc9c..c25a252a8 100644 --- a/lib/codegen/transform/peephole.cc +++ b/lib/codegen/transform/peephole.cc @@ -150,32 +150,53 @@ bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& 
builder){ } bool peephole::rewrite_mult(ir::instruction *value, ir::builder& builder) { - auto binop = dynamic_cast(value); - if(binop && binop->get_op() == ir::binary_op_t::Mul) { - ir::value *lhs = binop->get_operand(0); - ir::value *rhs = binop->get_operand(1); - ir::constant_int *_1_lhs = nullptr; - if(ir::splat_inst *splat = dynamic_cast(lhs)){ - auto *cst = dynamic_cast(splat->get_operand(0)); - if(cst && cst->get_value() == 1) - _1_lhs = cst; - } - ir::constant_int *_1_rhs = nullptr; - if(ir::splat_inst *splat = dynamic_cast(rhs)){ - auto *cst = dynamic_cast(splat->get_operand(0)); - if(cst && cst->get_value() == 1) - _1_rhs = cst; - } - if(_1_lhs){ - binop->replace_all_uses_with(rhs); - return true; - } - else if(_1_rhs){ - binop->replace_all_uses_with(lhs); - return true; - } + auto binop = dynamic_cast(value); + if(binop && binop->get_op() == ir::binary_op_t::Mul) { + ir::value *lhs = binop->get_operand(0); + ir::value *rhs = binop->get_operand(1); + ir::constant_int *_1_lhs = nullptr; + if(ir::splat_inst *splat = dynamic_cast(lhs)){ + auto *cst = dynamic_cast(splat->get_operand(0)); + if(cst && cst->get_value() == 1) + _1_lhs = cst; } + ir::constant_int *_1_rhs = nullptr; + if(ir::splat_inst *splat = dynamic_cast(rhs)){ + auto *cst = dynamic_cast(splat->get_operand(0)); + if(cst && cst->get_value() == 1) + _1_rhs = cst; + } + if(_1_lhs){ + binop->replace_all_uses_with(rhs); + return true; + } + else if(_1_rhs){ + binop->replace_all_uses_with(lhs); + return true; + } + } + return false; +} + +bool peephole::rewrite_insert_extract(ir::instruction *value, ir::builder& builder){ + auto extracted = dynamic_cast(value); + if(!extracted) return false; + size_t extract_idx = extracted->get_idx(); + ir::value* agg = extracted->get_operand(0); + auto insert = dynamic_cast(agg); + while(insert){ + agg = insert->get_operand(0); + ir::value* inserted = insert->get_operand(1); + size_t insert_idx = insert->get_idx(); + insert = dynamic_cast(agg); + if(extract_idx == 
insert_idx){ + extracted->replace_all_uses_with(inserted); + return true; + } + insert = dynamic_cast(agg); + } + return false; } @@ -291,6 +312,7 @@ void peephole::run(ir::module &mod) { was_modified = was_modified || rewrite_mult(i, builder); // was_modified = was_modified || rewrite_cts_cfs(i, builder); // was_modified = was_modified || rewrite_trans_phi(i, builder); + was_modified = was_modified || rewrite_insert_extract(i, builder); was_modified = was_modified || rewrite_unit_red(i, builder); was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder); // TODO: DOESN'T WORK FOR VECTORIZED MASKED LOAD diff --git a/lib/codegen/transform/pipeline.cc b/lib/codegen/transform/pipeline.cc index c85ba43a1..0c5c0b292 100644 --- a/lib/codegen/transform/pipeline.cc +++ b/lib/codegen/transform/pipeline.cc @@ -134,6 +134,7 @@ void pipeline::run(ir::module &mod) { ir::builder &builder = mod.get_builder(); const int num_stages = num_stages_; std::vector>> preheader_loads; // Used to reorder loads + for(auto info: to_pipeline){ ir::load_inst* load = info.load; ir::phi_node* ptr = info.ptr; diff --git a/lib/driver/dispatch.cc b/lib/driver/dispatch.cc index 9e2aca432..de6f1901b 100755 --- a/lib/driver/dispatch.cc +++ b/lib/driver/dispatch.cc @@ -138,6 +138,7 @@ CUDA_DEFINE3(CUresult, cuDeviceGetAttribute, int *, CUdevice_attribute, CUdevice CUDA_DEFINE1(CUresult, cuDeviceGetCount, int*) // link management +CUDA_DEFINE6(CUresult, cuLinkAddFile_v2, CUlinkState, CUjitInputType, const char *, unsigned int , CUjit_option *, void **); CUDA_DEFINE8(CUresult, cuLinkAddData_v2, CUlinkState, CUjitInputType, void*, size_t, const char*, unsigned int, CUjit_option*, void**); CUDA_DEFINE4(CUresult, cuLinkCreate_v2, unsigned int, CUjit_option*, void**, CUlinkState*); CUDA_DEFINE1(CUresult, cuLinkDestroy, CUlinkState); diff --git a/lib/driver/error.cc b/lib/driver/error.cc index f723351c2..fda2b7f33 100755 --- a/lib/driver/error.cc +++ b/lib/driver/error.cc @@ -90,7 +90,7 @@ 
void check(CUresult err) case CUDA_ERROR_NOT_PERMITTED : throw not_permitted(); case CUDA_ERROR_NOT_SUPPORTED : throw not_supported(); case CUDA_ERROR_UNKNOWN : throw unknown(); - default : throw unknown(); + default : throw std::runtime_error("unimplemented code: " + std::to_string(err)); } } diff --git a/lib/driver/llvm.cc b/lib/driver/llvm.cc index 463f45712..92a6b75de 100644 --- a/lib/driver/llvm.cc +++ b/lib/driver/llvm.cc @@ -174,6 +174,7 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){ init_llvm(); // verify and store llvm llvm::legacy::PassManager pm; +// pm.add(llvm::createPrintModulePass(llvm::outs())); pm.add(llvm::createVerifierPass()); pm.run(*module); // module->print(llvm::outs(), nullptr); @@ -213,6 +214,7 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){ return result; } + std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas, int cc) { // compile ptx with ptxas char _fsrc[L_tmpnam]; diff --git a/lib/ir/basic_block.cc b/lib/ir/basic_block.cc index 0654156a3..93caef2c3 100644 --- a/lib/ir/basic_block.cc +++ b/lib/ir/basic_block.cc @@ -1,3 +1,5 @@ +#include +#include #include "triton/ir/basic_block.h" #include "triton/ir/instructions.h" #include "triton/ir/type.h" @@ -9,23 +11,68 @@ namespace ir { class phi_node; -basic_block::basic_block(context &ctx, const std::string &name, function *parent): +basic_block::basic_block(context &ctx, const std::string &name, function *parent, basic_block* next): value(type::get_label_ty(ctx), name), ctx_(ctx), parent_(parent) { if(parent_) - parent_->insert_block(this); + parent_->insert_block(this, next); } -basic_block* basic_block::create(context &ctx, const std::string &name, function *parent){ - return new basic_block(ctx, name, parent); +basic_block* basic_block::create(context &ctx, const std::string &name, function *parent, basic_block* next){ + return new basic_block(ctx, name, parent, next); } -void basic_block::add_predecessor(basic_block 
*pred) { - preds_.push_back(pred); - if(pred) - pred->succs_.push_back(this); +void basic_block::replace_phi_uses_with(basic_block* before, basic_block* after) { + for(ir::instruction* i: inst_list_){ + auto* curr_phi = dynamic_cast(i); + if(!curr_phi) + break; + curr_phi->replace_uses_of_with(before, after); + } } +void basic_block::append_instruction(ir::instruction* i){ + i->set_parent(this); + inst_list_.push_back(i); +} +basic_block* basic_block::split_before(ir::instruction* loc, const std::string& name) { + basic_block* ret = basic_block::create(ctx_, name, parent_, this); + ret->set_name(get_name()); + set_name("after_" + name); + + // splice instruction list + auto loc_it = std::find(inst_list_.begin(), inst_list_.end(), loc); + ret->get_inst_list().splice(ret->get_inst_list().begin(), inst_list_, inst_list_.begin(), loc_it); + for(ir::instruction* i: ret->get_inst_list()) + i->set_parent(ret); + // the predecessors of `this` becomes the predecessors of `ret` + for(ir::basic_block* pred: get_predecessors()){ + auto* term = dynamic_cast(pred->get_inst_list().back()); + assert(term); + term->replace_uses_of_with(this, ret); + replace_phi_uses_with(pred, ret); + } + ir::branch_inst* br = branch_inst::create(this); + ret->append_instruction(br); + return ret; +} + +std::vector basic_block::get_predecessors() const { + std::vector ret; + for(ir::user* u: users_) + if(auto term = dynamic_cast(u)) + ret.push_back(term->get_parent()); + return ret; +} + +std::vector basic_block::get_successors() const { + std::vector ret; + for(ir::instruction* i: inst_list_) + for(ir::value* v: i->ops()) + if(auto block = dynamic_cast(v)) + ret.push_back(block); + return ret; +} basic_block::iterator basic_block::get_first_non_phi(){ auto it = begin(); diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc index fff73e665..58174aa7a 100644 --- a/lib/ir/builder.cc +++ b/lib/ir/builder.cc @@ -117,13 +117,10 @@ type *builder::get_double_ty() 
//===----------------------------------------------------------------------===// value* builder::create_br(basic_block *dest){ - dest->add_predecessor(block_); return insert(branch_inst::create(dest)); } value* builder::create_cond_br(value *cond, basic_block *if_dest, basic_block *else_dest){ - if_dest->add_predecessor(block_); - else_dest->add_predecessor(block_); return insert(branch_inst::create(cond, if_dest, else_dest)); } @@ -131,6 +128,10 @@ value *builder::create_ret_void() { return insert(return_inst::create(ctx_)); } +value *builder::create_ret(value* val) { + return insert(return_inst::create(ctx_, val)); +} + //===----------------------------------------------------------------------===// // cast instructions //===----------------------------------------------------------------------===// @@ -163,6 +164,19 @@ phi_node* builder::create_phi(type *ty, unsigned num_reserved){ return insert(phi_node::create(ty, num_reserved)); } +//===----------------------------------------------------------------------===// +// call instructions +//===----------------------------------------------------------------------===// + +value *builder::create_call(function* fn, const std::vector& args){ + return insert(call_inst::create(fn, args)); +} + +value* builder::create_launch(function* fn, const std::vector& args, const std::vector& grid, value* num_warps){ + return insert(launch_inst::create(fn, args, grid, num_warps)); + +} + //===----------------------------------------------------------------------===// // binary float instructions //===----------------------------------------------------------------------===// @@ -307,6 +321,19 @@ value *builder::create_masked_store(value *ptr, value *val, value *mask){ return insert(masked_store_inst::create(ptr, val, mask)); } +//===----------------------------------------------------------------------===// +// struct instructions +//===----------------------------------------------------------------------===// + + +// Struct 
instructions +value *builder::create_insert_value(value* val, value *elt, size_t idx){ + return insert(insert_value_inst::create(val, elt, idx)); +} + +value *builder::create_extract_value(value* val, size_t idx) { + return insert(extract_value_inst::create(val, idx)); +} //===----------------------------------------------------------------------===// // block instructions //===----------------------------------------------------------------------===// diff --git a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc index aabbc4385..fc0252dbf 100644 --- a/lib/ir/dispatch.cc +++ b/lib/ir/dispatch.cc @@ -865,6 +865,9 @@ ir::value *dispatch::clock(ir::builder *builder) { } +//===----------------------------------------------------------------------===// +// Control FLow +//===----------------------------------------------------------------------===// // ir::value *dispatch::multiple_of(ir::value *x, int value, ir::builder *){ diff --git a/lib/ir/function.cc b/lib/ir/function.cc index 84d52df72..4f3cd5ac6 100644 --- a/lib/ir/function.cc +++ b/lib/ir/function.cc @@ -33,8 +33,10 @@ void argument::accept(visitor *v) { /* function */ function::function(function_type *ty, linkage_types_t linkage, const std::string &name, module *parent) - : global_object(ty, 0, linkage, name), parent_(parent), fn_ty_(ty) { + : global_object(ty, 0, linkage, name), parent_(parent), fn_ty_(ty), is_kernel_(false) { unsigned num_params = fn_ty_->get_num_params(); + if(parent) + parent->push_function(this); // skip if no parameter if(num_params == 0) return; @@ -44,8 +46,6 @@ function::function(function_type *ty, linkage_types_t linkage, type *param_ty = fn_ty_->get_param_ty(i); args_[i] = argument::create(param_ty, "", this, i); } - if(parent) - parent->push_function(this); } /* basic block */ diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc index d1f81f136..1bcbfa9ff 100644 --- a/lib/ir/instructions.cc +++ b/lib/ir/instructions.cc @@ -5,6 +5,7 @@ #include "triton/ir/instructions.h" #include 
"triton/ir/constant.h" #include "triton/ir/type.h" +#include "triton/ir/function.h" namespace triton{ namespace ir{ @@ -79,6 +80,70 @@ phi_node* phi_node::create(type *ty, unsigned num_reserved, const std::string &n return new phi_node(ty, num_reserved, name, next); } +//===----------------------------------------------------------------------===// +// call_inst classes +//===----------------------------------------------------------------------===// + +std::string call_inst::repr_impl() const { return "call " + fn_->get_name(); } + +call_inst::call_inst(ir::function* fn, const std::vector& values, const std::string& name, instruction* next) + : instruction(fn->get_fn_type()->get_return_ty(), INST_CALL, values.size(), name, next), fn_(fn){ + for(size_t i = 0; i < values.size(); i++) + set_operand(i, values.at(i)); +} + +call_inst* call_inst::create(ir::function* fn, const std::vector& values, const std::string &name, instruction *next) { + return new call_inst(fn, values, name, next); +} + + +// launch + +launch_inst::launch_inst(ir::function* fn, const std::vector& values, const std::vector& grid, ir::value* num_warps, const std::string& name, instruction* next) + : instruction(fn->get_fn_type()->get_return_ty(), INST_LAUNCH, 1 + values.size() + grid.size() + 1, name, next){ + int k = 0; + if(grid.size() != 3) + throw std::runtime_error("grid must have 3 elements"); + set_operand(k++, fn); + val_begin = k; + for(ir::value* v: values) + set_operand(k++, v); + val_end = k; + grid_begin = k; + for(ir::value* g: grid) + set_operand(k++, g); + grid_end = k; + set_operand(k++, num_warps); +} + + +ir::function* launch_inst::get_fn() { + return (ir::function*)get_operand(0); +} + +std::vector launch_inst::get_values() { + std::vector ret; + for(int i = val_begin; i < val_end; i++) + ret.push_back(get_operand(i)); + return ret; +} + +std::vector launch_inst::get_grid() { + std::vector ret; + for(int i = grid_begin; i < grid_end; i++) + ret.push_back(get_operand(i)); + 
return ret; +} + +ir::value* launch_inst::get_num_warps() { + return get_operand(grid_end); +} + + +launch_inst* launch_inst::create(ir::function *fn, const std::vector &values, const std::vector &grid, ir::value *num_warps, const std::string &name, instruction *next) { + return new launch_inst(fn, values, grid, num_warps, name, next); +} + //===----------------------------------------------------------------------===// // binary_operator classes @@ -324,7 +389,7 @@ cast_inst *cast_inst::create_integer_cast(value *arg, type *ty, bool is_signed, // return_inst return_inst::return_inst(context &ctx, value *ret_val, instruction *next) - : terminator_inst(type::get_void_ty(ctx), INST_RETURN, ret_val!=nullptr, "", next){ + : terminator_inst(ret_val?ret_val->get_type():type::get_void_ty(ctx), INST_RETURN, ret_val!=nullptr, "", next){ if(ret_val) set_operand(0, ret_val); } @@ -521,6 +586,36 @@ masked_store_inst::masked_store_inst(value *ptr, value *val, value *mask, masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask, const std::string &name, instruction *next) { return new masked_store_inst(ptr, val, mask, name, next); } + +//===----------------------------------------------------------------------===// +// struct classes +//===----------------------------------------------------------------------===// + +// insert value + +insert_value_inst::insert_value_inst(value *val, value *elt, size_t idx, const std::string& name, instruction *next) + : instruction(val->get_type(), INST_INSERT_VALUE, 2, name, next), idx_(idx) { + set_operand(0, val); + set_operand(1, elt); +} + +insert_value_inst* insert_value_inst::create(value *val, value *elt, size_t idx, const std::string& name, instruction *next){ + return new insert_value_inst(val, elt, idx, name, next); +} + + +// extract value + +extract_value_inst::extract_value_inst(value *val, size_t idx, const std::string& name, instruction *next) + : instruction(val->get_type()->get_struct_type(idx), 
INST_EXTRACT_VALUE, 1, name, next), idx_(idx) { + set_operand(0, val); +} + +extract_value_inst* extract_value_inst::create(value *val, size_t idx, const std::string& name, instruction *next){ + return new extract_value_inst(val, idx, name, next); +} + + //===----------------------------------------------------------------------===// // retile_inst classes //===----------------------------------------------------------------------===// @@ -575,6 +670,9 @@ instruction* downcast_inst::create(value *arg, const std::string &name, instruct return new downcast_inst(arg->get_type()->get_scalar_ty(), INST_DOWNCAST, arg, name, next); } + + + //===----------------------------------------------------------------------===// // matmul_inst classes //===----------------------------------------------------------------------===// diff --git a/lib/ir/module.cc b/lib/ir/module.cc index 33b39de3a..7df196c8f 100644 --- a/lib/ir/module.cc +++ b/lib/ir/module.cc @@ -9,17 +9,12 @@ namespace triton{ namespace ir{ -/* Module */ -module::module(const std::string &name, builder &builder) - : name_(name), builder_(builder) { +/* */ +value_constructor::value_constructor(ir::builder& builder): builder_(builder){ sealed_blocks_.insert(nullptr); } -ir::builder& module::get_builder() { - return builder_; -} - -void module::set_value(const std::string& name, ir::basic_block *block, ir::value *value){ +void value_constructor::set_value(const std::string& name, ir::basic_block *block, ir::value *value){ values_[val_key_t{name, block}] = value; auto it = metadatas_.find(name); if(auto *x = dynamic_cast(value)) @@ -29,23 +24,11 @@ void module::set_value(const std::string& name, ir::basic_block *block, ir::valu // value->set_name(name); } -void module::set_value(const std::string& name, ir::value *value){ +void value_constructor::set_value(const std::string& name, ir::value *value){ return set_value(name, builder_.get_insert_block(), value); } -void module::set_const(const std::string& name){ - 
const_.insert(name); -} - -void module::set_continue_fn(std::function fn) { - continue_fn_ = fn; -} - -std::function module::get_continue_fn() { - return continue_fn_; -} - -ir::phi_node* module::make_phi(ir::type *ty, unsigned num_values, ir::basic_block *block){ +ir::phi_node* value_constructor::make_phi(ir::type *ty, unsigned num_values, ir::basic_block *block){ basic_block::iterator insert = block->get_first_non_phi(); if(insert != block->end()){ builder_.set_insert_point(insert); @@ -56,7 +39,7 @@ ir::phi_node* module::make_phi(ir::type *ty, unsigned num_values, ir::basic_bloc return res; } -ir::value *module::try_remove_trivial_phis(ir::phi_node *&phi){ +ir::value *value_constructor::try_remove_trivial_phis(ir::phi_node *&phi){ // find non-self references std::set non_self_ref; std::copy_if(phi->ops().begin(), phi->ops().end(), std::inserter(non_self_ref, non_self_ref.begin()), @@ -69,7 +52,7 @@ ir::value *module::try_remove_trivial_phis(ir::phi_node *&phi){ assert(same != nullptr); phi->replace_all_uses_with(same); phi->erase_from_parent(); - std::set users = phi->get_users(); + std::vector users = phi->get_users(); for(ir::user* u: users) if(auto *uphi = dynamic_cast(u)) if(uphi != phi) @@ -78,7 +61,7 @@ ir::value *module::try_remove_trivial_phis(ir::phi_node *&phi){ } -ir::value *module::add_phi_operands(const std::string& name, ir::phi_node *&phi){ +ir::value *value_constructor::add_phi_operands(const std::string& name, ir::phi_node *&phi){ // already initialized if(phi->get_num_operands()) return phi; @@ -90,12 +73,11 @@ ir::value *module::add_phi_operands(const std::string& name, ir::phi_node *&phi) return phi; } -ir::value *module::get_value_recursive(const std::string& name, ir::basic_block *block) { +ir::value *value_constructor::get_value_recursive(const std::string& name, ir::basic_block *block) { ir::value *result; - bool is_const = const_.find(name) != const_.end(); - auto &preds = block->get_predecessors(); + auto preds = 
block->get_predecessors(); ir::type *ty = types_.at(name); - if(block && !is_const && sealed_blocks_.find(block) == sealed_blocks_.end()){ + if(block && sealed_blocks_.find(block) == sealed_blocks_.end()){ incomplete_phis_[block][name] = make_phi(ty, 1, block); result = (ir::value*)incomplete_phis_[block][name]; } @@ -117,10 +99,12 @@ ir::value *module::get_value_recursive(const std::string& name, ir::basic_block return result; } -ir::value *module::get_value(const std::string& name, ir::basic_block *block) { +ir::value *value_constructor::get_value(const std::string& name, ir::basic_block *block) { ir::basic_block* save_block = builder_.get_insert_block(); ir::basic_block::iterator save_pt = builder_.get_insert_point(); val_key_t key(name, block); +// std::cout << values_.size() << std::endl; +// std::cout << name << " " << block << " " << values_.begin()->first.first << " " << values_.begin()->first.second << std::endl; if(values_.find(key) != values_.end()){ return values_.at(key); } @@ -131,15 +115,11 @@ ir::value *module::get_value(const std::string& name, ir::basic_block *block) { return result; } -ir::value *module::get_value(const std::string& name) { +ir::value *value_constructor::get_value(const std::string& name) { return get_value(name, builder_.get_insert_block()); } -const std::string& module::get_name() { - return name_; -} - -void module::seal_block(ir::basic_block *block){ +void value_constructor::seal_block(ir::basic_block *block){ for(auto &x: incomplete_phis_[block]){ add_phi_operands(x.first, x.second); if(get_value(x.first) == x.second) @@ -149,11 +129,40 @@ void module::seal_block(ir::basic_block *block){ incomplete_phis_[block].clear(); } + + +/* Module */ + +module::module(const std::string &name, builder &builder) + : name_(name), builder_(builder) { +} + +void module::reset_ret_ty(const std::string& name, type* ty) { + get_function(name)->get_fn_type()->reset_ret_ty(ty); +} + +ir::builder& module::get_builder() { + return builder_; +} + 
+void module::set_continue_fn(std::function fn) { + continue_fn_ = fn; +} + +std::function module::get_continue_fn() { + return continue_fn_; +} + +const std::string& module::get_name() { + return name_; +} + /* functions */ function *module::get_or_insert_function(const std::string &name, function_type *ty) { function *&fn = (function*&)symbols_[name]; - if(fn == nullptr) - return fn = function::create(ty, global_value::external, name, this); + if(fn == nullptr){ + fn = function::create(ty, global_value::external, name, this); + } return fn; } diff --git a/lib/ir/type.cc b/lib/ir/type.cc index 7e4e4e5d7..735fad965 100644 --- a/lib/ir/type.cc +++ b/lib/ir/type.cc @@ -188,7 +188,26 @@ bool composite_type::index_valid(value *idx) const{ } //===----------------------------------------------------------------------===// -// tile_type class +// struct_type class +//===----------------------------------------------------------------------===// + +struct_type::struct_type(const contained_tys_vec_t& tys, bool is_packed) + : composite_type(tys[0]->get_context(), StructTyID), is_packed_(is_packed) { + contained_tys_ = tys; +} + +struct_type* struct_type::get(const contained_tys_vec_t& tys, bool is_packed) { + assert(tys.size()); + context_impl* impl = tys[0]->get_context().p_impl.get(); + struct_type *& entry = impl->struct_tys[tys]; + if(!entry) + entry = new struct_type(tys, is_packed); + return entry; +} + + +//===----------------------------------------------------------------------===// +// block_type class //===----------------------------------------------------------------------===// block_type::block_type(type *ty, const block_shapes_t &shapes) diff --git a/lib/ir/value.cc b/lib/ir/value.cc index b970e07d7..251d64479 100644 --- a/lib/ir/value.cc +++ b/lib/ir/value.cc @@ -1,5 +1,6 @@ #include #include +#include #include "triton/ir/value.h" #include "triton/ir/instructions.h" @@ -17,11 +18,11 @@ value::value(type *ty, const std::string &name): ty_(ty){ } void 
value::add_use(user *arg) { - users_.insert(arg); + users_.push_back(arg); } value::users_t::iterator value::erase_use(user *arg){ - auto it = users_.find(arg); + auto it = std::find(users_.begin(), users_.end(), arg); if(it == users_.end()) return it; return users_.erase(it); diff --git a/python/setup.py b/python/setup.py index 6a04a4e42..9179baa5b 100644 --- a/python/setup.py +++ b/python/setup.py @@ -79,7 +79,7 @@ class CMakeBuild(build_ext): def build_extension(self, ext): llvm_include_dir, llvm_library_dir = get_llvm() - self.debug = True + # self.debug = True extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path))) # create build directories build_suffix = 'debug' if self.debug else 'release' diff --git a/python/src/triton.cc b/python/src/triton.cc index 22017ebf5..b97044421 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -659,6 +659,8 @@ void init_triton_ir(py::module &&m) { py::class_(m, "type") .def("is_ptr", &ir::type::is_pointer_ty) .def("is_int", static_cast(&ir::type::is_integer_ty)) + .def("get_int_width", &ir::type::get_integer_bitwidth) + .def("is_floating", &ir::type::is_floating_point_ty) .def("is_block", &ir::type::is_block_ty) .def("make_ptr", &ir::pointer_type::get, ret::reference) @@ -695,6 +697,7 @@ void init_triton_ir(py::module &&m) { .def("is_uint16", [](ir::type *self) { return self->is_integer_ty(16, ir::signedness::UNSIGNED); }) .def("is_uint32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::UNSIGNED); }) .def("is_uint64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::UNSIGNED); }) + .def("is_struct", &ir::type::is_struct_ty) .def("repr", &ir::type::repr) .def_property_readonly("fp_mantissa_width", &ir::type::get_fp_mantissa_width) @@ -704,23 +707,37 @@ void init_triton_ir(py::module &&m) { py::class_(m, "pointer_type") .def_property_readonly("element", &ir::pointer_type::get_element_ty, ret::reference); - py::class_(m, "function_type"); + py::class_(m, 
"function_type") + .def_property_readonly("ret_ty", &ir::function_type::get_return_ty) + .def_property_readonly("arg_tys", [](ir::function_type* self){ + return std::vector(self->params_begin(), self->params_end()); + }); + py::class_(m, "integer_type"); + py::class_(m, "block_type") .def_property_readonly("shape", &ir::block_type::get_shapes) .def_property_readonly("numel", &ir::type::get_tile_num_elements); + + py::class_(m, "struct_type") + .def("get", &ir::struct_type::get, ret::reference) + .def_property_readonly("num_types", &ir::struct_type::get_num_types); + + py::class_(m, "value_constructor") + .def(py::init()) + .def("seal_block", &ir::value_constructor::seal_block) + .def("set_value", (void (ir::value_constructor::*)(const std::string &, ir::value *)) & ir::value_constructor::set_value) + .def("set_type", &ir::value_constructor::set_type) + .def("get_value", (ir::value * (ir::value_constructor::*)(const std::string &)) & ir::value_constructor::get_value, ret::reference) + .def("get_values", &ir::value_constructor::get_values, ret::reference) + .def("set_values", &ir::value_constructor::set_values); py::class_(m, "module") .def(py::init()) + .def("has_function", &ir::module::has_function) + .def("get_function", &ir::module::get_function, ret::reference) .def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference) - .def("seal_block", &ir::module::seal_block) - .def("set_value", (void (ir::module::*)(const std::string &, ir::value *)) & ir::module::set_value) - .def("set_type", &ir::module::set_type) - .def("get_value", (ir::value * (ir::module::*)(const std::string &)) & ir::module::get_value, ret::reference) - .def("get_values", &ir::module::get_values, ret::reference) - .def("set_values", &ir::module::set_values) - .def("get_types", &ir::module::get_types, ret::reference) - .def("set_types", &ir::module::set_types) + .def("reset_ret_ty", &ir::module::reset_ret_ty) .def_property_readonly("builder", &ir::module::get_builder, 
ret::reference); using eattr = ir::attribute_kind_t; @@ -734,29 +751,45 @@ void init_triton_ir(py::module &&m) { .value("not_implemented", eattr::not_implemented); py::class_(m, "attribute") - .def(py::init()); + .def(py::init()) + .def_property_readonly("value", &ir::attribute::get_value); py::class_(m, "function") .def_property_readonly("args", &ir::function::args) .def_property_readonly("attrs", &ir::function::attrs) - .def("add_attr", &ir::function::add_attr); + .def("set_is_kernel", &ir::function::set_is_kernel) + .def("add_attr", &ir::function::add_attr) + .def("has_attr", &ir::function::has_attr) + .def("get_attrs", &ir::function::get_attributes); - py::class_(m, "argument"); + py::class_(m, "argument") + .def_property_readonly("parent", &ir::argument::get_parent, ret::reference) + .def_property_readonly("arg_no", &ir::argument::get_arg_no); py::class_(m, "basic_block") - .def("create", &ir::basic_block::create, ret::reference) + .def("create", &ir::basic_block::create, ret::reference, py::arg(), py::arg(), py::arg() = nullptr) .def_property_readonly("parent", &ir::basic_block::get_parent, ret::reference); + py::class_(m, "bb_iterator"); + py::class_(m, "builder", py::dynamic_attr()) .def(py::init()) // getters .def_property_readonly("context", &ir::builder::get_context, ret::reference) // control flow + .def("call", &ir::builder::create_call, ret::reference) + .def("launch", &ir::builder::create_launch, ret::reference) .def("br", &ir::builder::create_br, ret::reference) .def("cond_br", &ir::builder::create_cond_br, ret::reference) .def("ret_void", &ir::builder::create_ret_void, ret::reference) + .def("ret", &ir::builder::create_ret, ret::reference) + .def("get_insert_point", &ir::builder::get_insert_point) + .def("set_insert_point", (void (ir::builder::*)(ir::builder::iterator))&ir::builder::set_insert_point) .def("get_insert_block", &ir::builder::get_insert_block, ret::reference) .def("set_insert_block", (void (ir::builder::*)(ir::basic_block *)) & 
ir::builder::set_insert_point) + // struct + .def("insert_value", &ir::builder::create_insert_value, ret::reference) + .def("extract_value", &ir::builder::create_extract_value, ret::reference) // constants .def("get_int1", &ir::builder::get_int1, ret::reference) .def("get_int32", &ir::builder::get_int32, ret::reference) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index a49b47585..08e5b721f 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -585,7 +585,6 @@ def test_f8_f16_roundtrip(): f8_output_tensor = torch.empty_like(f16, dtype=torch.int8) f8_output = triton.reinterpret(f8_output_tensor, tl.float8) - print(f16.dtype, f8_output.dtype) copy_kernel[grid](f16, f8_output, n_elements, BLOCK_SIZE=1024) assert torch.all(f8_tensor == f8_output_tensor) @@ -1009,8 +1008,8 @@ def test_value_specialization(value: int, value_type: str, device='cuda') -> Non # Parse out the type of the 'VALUE' parameter from the Triton IR. 
triton_ir = pgm.asm['ttir'] - ir_value_match = re.match(r'\s*def void kernel\((\w+) VALUE ', triton_ir) - ir_value_type = None if ir_value_match is None else ir_value_match.group(1) + ir_value_match = re.match(r'\s*def void (\w+)\((\w+) VALUE ', triton_ir) + ir_value_type = None if ir_value_match is None else ir_value_match.group(2) assert ir_value_type == value_type @@ -1031,3 +1030,28 @@ def test_value_specialization_overflow(value: int, overflow: bool, device='cuda' kernel[(1, )](value, x) else: kernel[(1, )](value, x) +# ------------------------- +# test dynamic parallelism +# ------------------------- + + +@triton.jit +def mult(x, alpha): + tl.store(x + tl.program_id(0), alpha) + + +@triton.jit +def stub(X, alpha, grid_0, grid_1, grid_2): + tl.launch(mult, [X, alpha], [grid_0, grid_1, grid_2]) + + +def test_dyn_par(cond=True, device='cuda'): + n_pids = 10 + # pids = torch.arange(n_pids, device=device) + # alpha = 2.0 + # x_ref = pids * alpha + x_tri = torch.full((10,), fill_value=-1., device=device) + # cond = torch.tensor([cond], device=device) + stub[(1,)](x_tri, 3.14, n_pids, 1, 1) + print(x_tri) + # triton.testing.assert_almost_equal(x_ref, x_tri) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index cb705aaa6..e6102366a 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -21,6 +21,41 @@ import triton._C.libtriton.triton as _triton from .tools.disasm import extract +def mangle_ty(type): + if type.is_ptr(): + return 'P' + mangle_ty(type.element) + if type.is_int(): + return 'i' + str(type.get_int_width()) + if type.is_fp8(): + return 'fp8' + if type.is_fp16(): + return 'fp16' + if type.is_bf16(): + return 'bf16' + if type.is_fp32(): + return 'fp32' + if type.is_fp64(): + return 'fp64' + if type.is_void(): + return 'V' + if type.is_block(): + elt = mangle_ty(type.scalar) + shape = '_'.join(map(str, type.shape)) + return f'{elt}S{shape}S' + assert False, "Unsupport type" + + +def mangle_fn(name, arg_tys, constants): + 
# doesn't mangle ret type, which must be a function of arg tys + mangled_arg_names = '_'.join([mangle_ty(ty) for ty in arg_tys]) + key = lambda x: x.__name__ if isinstance(x, JITFunction) else repr(x) + mangled_constants = '_'.join([f'{i}c{key(constants[i])}' for i in sorted(constants)]) + mangled_constants = mangled_constants.replace('.', '_d_') + mangled_constants = mangled_constants.replace("'", '_sq_') + ret = f'{name}__{mangled_arg_names}__{mangled_constants}' + return ret + + class CodeGenerator(ast.NodeVisitor): def get_value(self, name): # search node.id in local scope @@ -36,7 +71,7 @@ class CodeGenerator(ast.NodeVisitor): else: raise ValueError(f'{name} is not defined') if isinstance(ret, triton.language.block): - handle = self.module.get_value(name) + handle = self.value_constructor.get_value(name) return triton.language.block(handle) return ret @@ -44,8 +79,8 @@ class CodeGenerator(ast.NodeVisitor): if isinstance(value, _triton.ir.value): value = triton.language.block(value) if isinstance(value, triton.language.block): - self.module.set_value(name, value.handle) - self.module.set_type(name, value.handle.type) + self.value_constructor.set_value(name, value.handle) + self.value_constructor.set_type(name, value.handle.type) self.lscope[name] = value def is_triton_object(self, value): @@ -58,16 +93,17 @@ class CodeGenerator(ast.NodeVisitor): break return stmts and isinstance(stmt, ast.Return) - def __init__(self, context, prototype, gscope, attributes, constants, kwargs): + def __init__(self, context, prototype, gscope, attributes, constants, module=None, is_kernel=False): self.builder = _triton.ir.builder(context) - self.module = _triton.ir.module('', self.builder) + self.value_constructor = _triton.ir.value_constructor(self.builder) + self.module = _triton.ir.module('', self.builder) if module is None else module self.prototype = prototype self.gscope = gscope self.lscope = dict() self.attributes = attributes self.constants = constants - self.kwargs = 
kwargs self.last_node = None + self.is_kernel = is_kernel self.builtins = { 'range': range, 'min': triton.language.minimum, @@ -92,9 +128,17 @@ class CodeGenerator(ast.NodeVisitor): ret = self.visit(node.value) if ret is None: return self.builder.ret_void() - return ret + if isinstance(ret, _triton.ir.value): + ret = self.builder.ret(ret) + return ret + if isinstance(ret, triton.language.block): + ret = ret.handle + if isinstance(ret, triton.language.constexpr): + ret = triton.language.core._to_ir(ret, self.builder) + # TODO: should return tl.block + return self.builder.ret(ret) - def visit_FunctionDef(self, node, inline=False, arg_values=None): + def visit_FunctionDef(self, node): arg_names, kwarg_names = self.visit(node.args) # initialize defaults for i, default_value in enumerate(node.args.defaults): @@ -107,45 +151,44 @@ class CodeGenerator(ast.NodeVisitor): else: init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation) self.visit(init_node) - # store keyword arguments in local scope - self.lscope[kwarg_names] = self.kwargs # initialize function - if inline: - pass - else: - fn = self.module.get_or_insert_function(node.name, self.prototype) - arg_values = [] - idx = 0 - for i, arg_name in enumerate(arg_names): - if i in self.constants: - cst = self.constants[i] - if not isinstance(cst, triton.language.constexpr): - cst = triton.language.constexpr(self.constants[i]) - arg_values.append(cst) - else: - if i in self.attributes: - is_ptr = fn.args[idx].type.is_ptr() - attr = 'aligned' if is_ptr else 'multiple_of' - attr = getattr(_triton.ir.attribute_kind, attr) - attr = _triton.ir.attribute(attr, self.attributes[i]) - fn.add_attr(idx + 1, attr) - fn.args[idx].name = arg_name - arg_values.append(fn.args[idx]) - idx += 1 + fn_name = mangle_fn(node.name, self.prototype.arg_tys, self.constants) + fn = self.module.get_or_insert_function(fn_name, self.prototype) + fn.set_is_kernel(self.is_kernel) + arg_values = [] + idx = 0 + for i, 
arg_name in enumerate(arg_names): + if i in self.constants: + cst = self.constants[i] + if not isinstance(cst, triton.language.constexpr): + cst = triton.language.constexpr(self.constants[i]) + arg_values.append(cst) + else: + if i in self.attributes: + is_ptr = fn.args[idx].type.is_ptr() + attr = 'aligned' if is_ptr else 'multiple_of' + attr = getattr(_triton.ir.attribute_kind, attr) + attr = _triton.ir.attribute(attr, self.attributes[i]) + fn.add_attr(idx + 1, attr) + fn.args[idx].name = arg_name + arg_values.append(fn.args[idx]) + idx += 1 + insert_pt = self.builder.get_insert_block() + entry = _triton.ir.basic_block.create(self.builder.context, "entry", fn) + self.builder.set_insert_block(entry) + self.value_constructor.seal_block(entry) for arg_name, arg_value in zip(arg_names, arg_values): self.set_value(arg_name, arg_value) - if inline: - self.visit_compound_statement(node.body) - return self.last_ret - else: - entry = _triton.ir.basic_block.create(self.builder.context, "entry", fn) - self.module.seal_block(entry) - self.builder.set_insert_block(entry) - # visit function body - self.visit_compound_statement(node.body) - # finalize function + # visit function body + has_ret = self.visit_compound_statement(node.body) + # finalize + if not has_ret: self.builder.ret_void() + else: + self.module.reset_ret_ty(fn_name, self.last_ret.type) + # self.module.reset_ret_type(node.name) + self.builder.set_insert_block(insert_pt) def visit_arguments(self, node): arg_names = [] @@ -186,6 +229,12 @@ class CodeGenerator(ast.NodeVisitor): names = [names] if not isinstance(values, tuple): values = [values] + if isinstance(values[0], _triton.ir.value): + struct = values[0] + ty = struct.type + if ty.is_struct(): + values = [self.builder.extract_value(struct, i) for i in range(ty.num_types)] + assert len(values) == len(names) for name, value in zip(names, values): # by default, constexpr are assigned into python variable if isinstance(value, triton.language.constexpr): @@ -215,6 
+264,17 @@ class CodeGenerator(ast.NodeVisitor): def visit_Tuple(self, node): args = [self.visit(x) for x in node.elts] + mode = type(args[0]) + # tuple of values -- create a struct + if len(args) > 1 and mode == triton.language.block\ + and all([type(arg) == mode for arg in args]): + args = [arg.handle for arg in args] + tys = [arg.type for arg in args] + struct_ty = _triton.ir.struct_type.get(tys, True) + ret = _triton.ir.undef.get(struct_ty) + for i, arg in enumerate(args): + ret = self.builder.insert_value(ret, arg, i) + return ret return tuple(args) def visit_BinOp(self, node): @@ -254,9 +314,9 @@ class CodeGenerator(ast.NodeVisitor): then_bb = _triton.ir.basic_block.create(self.builder.context, "then", current_bb.parent) else_bb = _triton.ir.basic_block.create(self.builder.context, "else", current_bb.parent) if node.orelse else None endif_bb = _triton.ir.basic_block.create(self.builder.context, "endif", current_bb.parent) - self.module.seal_block(then_bb) + self.value_constructor.seal_block(then_bb) if else_bb: - self.module.seal_block(else_bb) + self.value_constructor.seal_block(else_bb) self.builder.cond_br(cond.handle, then_bb, else_bb) else: self.builder.cond_br(cond.handle, then_bb, endif_bb) @@ -271,7 +331,7 @@ class CodeGenerator(ast.NodeVisitor): # TODO: last statement is a terminator? 
if not is_terminator: self.builder.br(endif_bb) - self.module.seal_block(endif_bb) + self.value_constructor.seal_block(endif_bb) self.builder.set_insert_block(endif_bb) else: if isinstance(cond, triton.language.constexpr): @@ -350,9 +410,9 @@ class CodeGenerator(ast.NodeVisitor): self.visit_compound_statement(node.body) continue_fn() stop_bb = self.builder.get_insert_block() - self.module.seal_block(stop_bb) - self.module.seal_block(loop_bb) - self.module.seal_block(next_bb) + self.value_constructor.seal_block(stop_bb) + self.value_constructor.seal_block(loop_bb) + self.value_constructor.seal_block(next_bb) self.builder.set_insert_block(next_bb) for stmt in node.orelse: @@ -421,9 +481,9 @@ class CodeGenerator(ast.NodeVisitor): # TODO: handle case where body breaks control flow continue_fn() stop_bb = self.builder.get_insert_block() - self.module.seal_block(stop_bb) - self.module.seal_block(loop_bb) - self.module.seal_block(next_bb) + self.value_constructor.seal_block(stop_bb) + self.value_constructor.seal_block(loop_bb) + self.value_constructor.seal_block(next_bb) self.builder.set_insert_block(next_bb) for stmt in node.orelse: @@ -449,15 +509,62 @@ class CodeGenerator(ast.NodeVisitor): for keyword in node.keywords: kws.update(self.visit(keyword)) args = [self.visit(arg) for arg in node.args] + if isinstance(fn, JITFunction): - return fn(*args, generator=self, **kws) + from inspect import getcallargs + args = getcallargs(fn.fn, *args, **kws) + args = [args[name] for name in fn.arg_names] + args = [arg if isinstance(arg, triton.language.block) + else triton.language.constexpr(arg) for arg in args] + # generate function def + attributes = dict() + constexprs = [i for i, arg in enumerate(args) if isinstance(arg, triton.language.constexpr)] + constants = {i: args[i] for i in constexprs} + # generate call + args = [None if i in constexprs else arg for i, arg in enumerate(args)] + arg_vals = [arg.handle for arg in args if arg is not None] + arg_types = [arg.type for arg 
in arg_vals] + fn_name = mangle_fn(fn.__name__, arg_types, constants) + # generate function def if necessary + if not self.module.has_function(fn_name): + ret_type = _triton.ir.type.get_void(self.builder.context) + prototype = _triton.ir.type.make_function(ret_type, arg_types) + gscope = sys.modules[fn.fn.__module__].__dict__ + generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module) + generator.visit(fn.parse()) + symbol = self.module.get_function(fn_name) + ret = self.builder.call(symbol, arg_vals) + if not ret.type.is_void() and not ret.type.is_struct(): + ret = triton.language.block(ret) + return ret + # built-in function if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \ sys.modules[fn.__module__] is triton.language.core: - return fn(*args, _builder=self.builder, **kws) + ret = fn(*args, _builder=self.builder, **kws) if fn in self.builtins.values(): args = [arg.value if isinstance(arg, triton.language.constexpr) else arg for arg in args] - return fn(*args, **kws) + ret = fn(*args, **kws) + # special case: dynamic parallelism + # in this case the core primitive returns a proxy + # if isinstance(ret, triton.language.core.LaunchProxy): + # ret_type = _triton.ir.type.get_void(self.builder.context) + # arg_tys = [x.type for x in ret.args] + # prototype = _triton.ir.type.make_function(ret_type, arg_tys) + # gscope = sys.modules[ret.fn.fn.__module__].__dict__ + # constants = ret.constants + # fn_name = mangle_fn(ret.fn.__name__, arg_tys, ret.constants) + # # TODO: clean-up attributes handling in function + # if not self.module.has_function(fn_name): + # attributes = {i: list(arg.parent.get_attrs(arg))[0].value for i, arg in enumerate(ret.args) \ + # if isinstance(arg, _triton.ir.argument) and arg.parent.has_attr(i + 1) } + # generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, is_kernel=True) + # generator.visit(ret.fn.parse()) + # symbol 
= self.module.get_function(fn_name) + # # TODO: should ret.args not include any constants ? + # ret = self.builder.launch(symbol, ret.args, ret.grid, ret.num_warps) + return ret + # return fn(*args, **kws) def visit_Constant(self, node): return triton.language.constexpr(node.value) @@ -669,6 +776,7 @@ class Kernel: def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages): tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')] + # attributes attributes = dict() for i, arg in enumerate(wargs): @@ -881,7 +989,7 @@ class JITFunction: cache_hook = None - def __init__(self, fn, version=None, do_not_specialize=None): + def __init__(self, fn, version=None, inline=True, do_not_specialize=None): # information of wrapped function self.fn = fn self.module = fn.__module__ @@ -890,6 +998,7 @@ class JITFunction: self.arg_defaults = [v.default for v in signature.parameters.values()] self.version = version + self.inline = inline self.src = textwrap.dedent(inspect.getsource(fn)) self.src = self.src[self.src.find("def"):] self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize @@ -904,6 +1013,8 @@ class JITFunction: # annotations self.annotations = {self.arg_names.index(name): ty for name, ty in fn.__annotations__.items()} self.__annotations__ = fn.__annotations__ + # constexprs + self.constexprs = [self.arg_names.index(ann) for ann in self.__annotations__.keys()] # forward docs self.__doc__ = fn.__doc__ self.__name__ = fn.__name__ @@ -930,31 +1041,8 @@ class JITFunction: assert isinstance(tree.body[0], ast.FunctionDef) return tree - def __call__(self, *args, generator: CodeGenerator, **kwargs): - try: - from inspect import getcallargs - arg_values = getcallargs(self.fn, *args, **kwargs) - arg_values = [arg_values[name] for name in self.arg_names] - arg_values = [arg if isinstance(arg, triton.language.block) - else triton.language.constexpr(arg) for arg in arg_values] - - gscope = generator.gscope.copy() - lscope = 
generator.lscope.copy() - values = generator.module.get_values().copy() - types = generator.module.get_types().copy() - generator.gscope = sys.modules[self.fn.__module__].__dict__ - generator.lscope = dict() - ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=arg_values) - generator.gscope = gscope - generator.lscope = lscope - generator.module.set_values(values) - generator.module.set_types(types) - return ret - except Exception as e: - node = generator.last_node - if node is None or isinstance(e, (NotImplementedError, CompilationError)): - raise e - raise CompilationError(self.src, node) from e + def __call__(self, *args, **kwargs): + raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel.") # - when `.src` attribute is set, cache path needs # to be reinitialized @@ -1039,7 +1127,7 @@ class JITFunction: # generate Triton-IR # export symbols visible from self into code-generator object gscope = self.__globals__ - generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict()) + generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, is_kernel=True) try: generator.visit(self.parse()) except Exception as e: @@ -1199,9 +1287,21 @@ def jit(*args, **kwargs): return JITFunction(fn, **kwargs) return decorator +###### + +# class ForwardDeclaration: + +# def __init__(self, name, ret_ty, arg_tys) -> None: +# self.name = name +# self.ret_ty = ret_ty +# self.arg_tys = arg_tys + +# def forward_declare(name, ret_ty, arg_tys): +# return ForwardDeclaration(name, ret_ty, arg_tys) ###### + def cdiv(x, y): return (x + y - 1) // y diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 0312d8146..cad4edfe4 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -888,7 +888,7 @@ def sigmoid(x): @triton.jit @_add_math_1arg_docstr("softmax") -def softmax(x, 
ieee_rounding=False): +def softmax(x, ieee_rounding: constexpr = False): z = x - triton.language.max(x, 0) num = triton.language.exp(z) den = triton.language.sum(num, 0) @@ -942,3 +942,26 @@ def swizzle2d(i, j, size_i, size_j, size_g): @triton.jit def zeros_like(input): return zeros(input.shape, input.dtype) +# ----------------------- +# Dynamic Parallelism +# ----------------------- + + +class LaunchProxy: + + def __init__(self, fn, args, constants, grid, num_warps) -> None: + self.args = args + self.grid = grid + self.constants = constants + self.num_warps = num_warps + self.fn = fn + + +@builtin +def launch(fn, args, grid, num_warps=None, _builder=None): + constants = {i: x for i, x in enumerate(args) if isinstance(x, constexpr)} + args = [_to_ir(x, builder=_builder) for x in args if not isinstance(x, constexpr)] + grid = [_to_ir(x, builder=_builder) for x in grid] + if num_warps is None: + num_warps = _to_ir(4, builder=_builder) + return LaunchProxy(fn, args, constants, grid, num_warps)