[LANG] Added support for device functions (#484)
This commit is contained in:
@@ -224,6 +224,7 @@ struct scanline_layout: public distributed_layout {
|
|||||||
int nts(size_t k) { return nts_.at(k); }
|
int nts(size_t k) { return nts_.at(k); }
|
||||||
int contig_per_thread(size_t k) { return nts_.at(k); }
|
int contig_per_thread(size_t k) { return nts_.at(k); }
|
||||||
|
|
||||||
|
int per_thread(size_t k) { return nts(k) * shape_[k] / shape_per_cta(k);}
|
||||||
public:
|
public:
|
||||||
// micro tile size. The size of a tile held by a thread block.
|
// micro tile size. The size of a tile held by a thread block.
|
||||||
std::vector<int> mts_;
|
std::vector<int> mts_;
|
||||||
|
@@ -24,6 +24,7 @@ namespace llvm{
|
|||||||
class IRBuilder;
|
class IRBuilder;
|
||||||
class ArrayType;
|
class ArrayType;
|
||||||
class Function;
|
class Function;
|
||||||
|
class StructType;
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace triton{
|
namespace triton{
|
||||||
@@ -114,6 +115,8 @@ private:
|
|||||||
private:
|
private:
|
||||||
Type *cvt(ir::type *ty);
|
Type *cvt(ir::type *ty);
|
||||||
llvm::Attribute cvt(ir::attribute attr);
|
llvm::Attribute cvt(ir::attribute attr);
|
||||||
|
llvm::StructType* packed_type(ir::value* i);
|
||||||
|
void forward_declare(ir::function* fn);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
generator(analysis::axes *a_axes,
|
generator(analysis::axes *a_axes,
|
||||||
@@ -125,6 +128,8 @@ public:
|
|||||||
unsigned num_warps);
|
unsigned num_warps);
|
||||||
|
|
||||||
void visit_value(ir::value* v);
|
void visit_value(ir::value* v);
|
||||||
|
void visit_call_inst(ir::call_inst*);
|
||||||
|
void visit_launch_inst(ir::launch_inst *);
|
||||||
void visit_phi_node(ir::phi_node*);
|
void visit_phi_node(ir::phi_node*);
|
||||||
void visit_binary_operator(ir::binary_operator*);
|
void visit_binary_operator(ir::binary_operator*);
|
||||||
void visit_getelementptr_inst(ir::getelementptr_inst*);
|
void visit_getelementptr_inst(ir::getelementptr_inst*);
|
||||||
@@ -148,6 +153,8 @@ public:
|
|||||||
void visit_unmasked_store_inst(ir::unmasked_store_inst*);
|
void visit_unmasked_store_inst(ir::unmasked_store_inst*);
|
||||||
void visit_masked_store_inst(ir::masked_store_inst*);
|
void visit_masked_store_inst(ir::masked_store_inst*);
|
||||||
void visit_cat_inst(ir::cat_inst*);
|
void visit_cat_inst(ir::cat_inst*);
|
||||||
|
void visit_extract_value_inst(ir::extract_value_inst *);
|
||||||
|
void visit_insert_value_inst(ir::insert_value_inst *);
|
||||||
void visit_reshape_inst(ir::reshape_inst*);
|
void visit_reshape_inst(ir::reshape_inst*);
|
||||||
void visit_splat_inst(ir::splat_inst*);
|
void visit_splat_inst(ir::splat_inst*);
|
||||||
void visit_broadcast_inst(ir::broadcast_inst*);
|
void visit_broadcast_inst(ir::broadcast_inst*);
|
||||||
@@ -242,6 +249,7 @@ private:
|
|||||||
/// triton bb -> llvm bb
|
/// triton bb -> llvm bb
|
||||||
std::map<ir::value*, BasicBlock *> bbs_;
|
std::map<ir::value*, BasicBlock *> bbs_;
|
||||||
std::map<ir::value*, std::vector<int>> ords_;
|
std::map<ir::value*, std::vector<int>> ords_;
|
||||||
|
std::map<ir::value*, Function*> fns_;
|
||||||
|
|
||||||
// helper for creating llvm values
|
// helper for creating llvm values
|
||||||
adder add;
|
adder add;
|
||||||
|
31
include/triton/codegen/transform/inline.h
Normal file
31
include/triton/codegen/transform/inline.h
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <list>
|
||||||
|
|
||||||
|
namespace triton {
|
||||||
|
|
||||||
|
namespace ir {
|
||||||
|
class module;
|
||||||
|
class function;
|
||||||
|
class call_inst;
|
||||||
|
class builder;
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace codegen{
|
||||||
|
namespace transform{
|
||||||
|
|
||||||
|
struct fncmp {
|
||||||
|
bool operator()(ir::function* x, ir::function* y) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
class inliner {
|
||||||
|
public:
|
||||||
|
inliner() {}
|
||||||
|
void do_inline(ir::function* fn, ir::call_inst* callsite, ir::builder& builder, std::list<ir::call_inst*>& callsites);
|
||||||
|
void run(ir::module &mod);
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@@ -30,6 +30,9 @@ private:
|
|||||||
bool rewrite_dot_hmma(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D);
|
bool rewrite_dot_hmma(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D);
|
||||||
bool rewrite_dot(ir::instruction *value, ir::builder& builder);
|
bool rewrite_dot(ir::instruction *value, ir::builder& builder);
|
||||||
bool rewrite_mult(ir::instruction *value, ir::builder& builder);
|
bool rewrite_mult(ir::instruction *value, ir::builder& builder);
|
||||||
|
bool rewrite_insert_extract(ir::instruction *value, ir::builder& builder);
|
||||||
|
|
||||||
|
|
||||||
bool rewrite_unit_red(ir::instruction *value, ir::builder& builder);
|
bool rewrite_unit_red(ir::instruction *value, ir::builder& builder);
|
||||||
bool rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder);
|
bool rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder);
|
||||||
bool rewrite_select_masked_load(ir::instruction *value, ir::builder& builder);
|
bool rewrite_select_masked_load(ir::instruction *value, ir::builder& builder);
|
||||||
|
@@ -89,6 +89,7 @@ public:
|
|||||||
static CUresult cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib, CUdevice dev);
|
static CUresult cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib, CUdevice dev);
|
||||||
static CUresult cuDeviceGetCount(int *count);
|
static CUresult cuDeviceGetCount(int *count);
|
||||||
// link management
|
// link management
|
||||||
|
static CUresult cuLinkAddFile_v2(CUlinkState state, CUjitInputType type, const char *path, unsigned int numOptions, CUjit_option *options, void **optionValues);
|
||||||
static CUresult cuLinkAddData_v2(CUlinkState state, CUjitInputType type, void* data, size_t size, const char* name, unsigned int numOptions, CUjit_option* options, void** optionValues);
|
static CUresult cuLinkAddData_v2(CUlinkState state, CUjitInputType type, void* data, size_t size, const char* name, unsigned int numOptions, CUjit_option* options, void** optionValues);
|
||||||
static CUresult cuLinkCreate_v2(unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut);
|
static CUresult cuLinkCreate_v2(unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut);
|
||||||
static CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut);
|
static CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut);
|
||||||
@@ -214,6 +215,7 @@ private:
|
|||||||
static void* cuDeviceGetAttribute_;
|
static void* cuDeviceGetAttribute_;
|
||||||
static void* cuDeviceGetCount_;
|
static void* cuDeviceGetCount_;
|
||||||
// link management
|
// link management
|
||||||
|
static void* cuLinkAddFile_v2_;
|
||||||
static void* cuLinkAddData_v2_;
|
static void* cuLinkAddData_v2_;
|
||||||
static void* cuLinkCreate_v2_;
|
static void* cuLinkCreate_v2_;
|
||||||
static void* cuLinkDestroy_;
|
static void* cuLinkDestroy_;
|
||||||
|
244
include/triton/external/CUDA/cuda.h
vendored
Executable file → Normal file
244
include/triton/external/CUDA/cuda.h
vendored
Executable file → Normal file
@@ -224,7 +224,7 @@ typedef uint64_t cuuint64_t;
|
|||||||
/**
|
/**
|
||||||
* CUDA API version number
|
* CUDA API version number
|
||||||
*/
|
*/
|
||||||
#define CUDA_VERSION 11050
|
#define CUDA_VERSION 11040
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
extern "C" {
|
extern "C" {
|
||||||
@@ -496,33 +496,7 @@ typedef enum CUarray_format_enum {
|
|||||||
CU_AD_FORMAT_SIGNED_INT32 = 0x0a, /**< Signed 32-bit integers */
|
CU_AD_FORMAT_SIGNED_INT32 = 0x0a, /**< Signed 32-bit integers */
|
||||||
CU_AD_FORMAT_HALF = 0x10, /**< 16-bit floating point */
|
CU_AD_FORMAT_HALF = 0x10, /**< 16-bit floating point */
|
||||||
CU_AD_FORMAT_FLOAT = 0x20, /**< 32-bit floating point */
|
CU_AD_FORMAT_FLOAT = 0x20, /**< 32-bit floating point */
|
||||||
CU_AD_FORMAT_NV12 = 0xb0, /**< 8-bit YUV planar format, with 4:2:0 sampling */
|
CU_AD_FORMAT_NV12 = 0xb0
|
||||||
CU_AD_FORMAT_UNORM_INT8X1 = 0xc0, /**< 1 channel unsigned 8-bit normalized integer */
|
|
||||||
CU_AD_FORMAT_UNORM_INT8X2 = 0xc1, /**< 2 channel unsigned 8-bit normalized integer */
|
|
||||||
CU_AD_FORMAT_UNORM_INT8X4 = 0xc2, /**< 4 channel unsigned 8-bit normalized integer */
|
|
||||||
CU_AD_FORMAT_UNORM_INT16X1 = 0xc3, /**< 1 channel unsigned 16-bit normalized integer */
|
|
||||||
CU_AD_FORMAT_UNORM_INT16X2 = 0xc4, /**< 2 channel unsigned 16-bit normalized integer */
|
|
||||||
CU_AD_FORMAT_UNORM_INT16X4 = 0xc5, /**< 4 channel unsigned 16-bit normalized integer */
|
|
||||||
CU_AD_FORMAT_SNORM_INT8X1 = 0xc6, /**< 1 channel signed 8-bit normalized integer */
|
|
||||||
CU_AD_FORMAT_SNORM_INT8X2 = 0xc7, /**< 2 channel signed 8-bit normalized integer */
|
|
||||||
CU_AD_FORMAT_SNORM_INT8X4 = 0xc8, /**< 4 channel signed 8-bit normalized integer */
|
|
||||||
CU_AD_FORMAT_SNORM_INT16X1 = 0xc9, /**< 1 channel signed 16-bit normalized integer */
|
|
||||||
CU_AD_FORMAT_SNORM_INT16X2 = 0xca, /**< 2 channel signed 16-bit normalized integer */
|
|
||||||
CU_AD_FORMAT_SNORM_INT16X4 = 0xcb, /**< 4 channel signed 16-bit normalized integer */
|
|
||||||
CU_AD_FORMAT_BC1_UNORM = 0x91, /**< 4 channel unsigned normalized block-compressed (BC1 compression) format */
|
|
||||||
CU_AD_FORMAT_BC1_UNORM_SRGB = 0x92, /**< 4 channel unsigned normalized block-compressed (BC1 compression) format with sRGB encoding*/
|
|
||||||
CU_AD_FORMAT_BC2_UNORM = 0x93, /**< 4 channel unsigned normalized block-compressed (BC2 compression) format */
|
|
||||||
CU_AD_FORMAT_BC2_UNORM_SRGB = 0x94, /**< 4 channel unsigned normalized block-compressed (BC2 compression) format with sRGB encoding*/
|
|
||||||
CU_AD_FORMAT_BC3_UNORM = 0x95, /**< 4 channel unsigned normalized block-compressed (BC3 compression) format */
|
|
||||||
CU_AD_FORMAT_BC3_UNORM_SRGB = 0x96, /**< 4 channel unsigned normalized block-compressed (BC3 compression) format with sRGB encoding*/
|
|
||||||
CU_AD_FORMAT_BC4_UNORM = 0x97, /**< 1 channel unsigned normalized block-compressed (BC4 compression) format */
|
|
||||||
CU_AD_FORMAT_BC4_SNORM = 0x98, /**< 1 channel signed normalized block-compressed (BC4 compression) format */
|
|
||||||
CU_AD_FORMAT_BC5_UNORM = 0x99, /**< 2 channel unsigned normalized block-compressed (BC5 compression) format */
|
|
||||||
CU_AD_FORMAT_BC5_SNORM = 0x9a, /**< 2 channel signed normalized block-compressed (BC5 compression) format */
|
|
||||||
CU_AD_FORMAT_BC6H_UF16 = 0x9b, /**< 3 channel unsigned half-float block-compressed (BC6H compression) format */
|
|
||||||
CU_AD_FORMAT_BC6H_SF16 = 0x9c, /**< 3 channel signed half-float block-compressed (BC6H compression) format */
|
|
||||||
CU_AD_FORMAT_BC7_UNORM = 0x9d, /**< 4 channel unsigned normalized block-compressed (BC7 compression) format */
|
|
||||||
CU_AD_FORMAT_BC7_UNORM_SRGB = 0x9e /**< 4 channel unsigned normalized block-compressed (BC7 compression) format with sRGB encoding */
|
|
||||||
} CUarray_format;
|
} CUarray_format;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -657,7 +631,7 @@ typedef enum CUdevice_attribute_enum {
|
|||||||
CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED = 102, /**< Device supports virtual memory management APIs like ::cuMemAddressReserve, ::cuMemCreate, ::cuMemMap and related APIs */
|
CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED = 102, /**< Device supports virtual memory management APIs like ::cuMemAddressReserve, ::cuMemCreate, ::cuMemMap and related APIs */
|
||||||
CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR_SUPPORTED = 103, /**< Device supports exporting memory to a posix file descriptor with ::cuMemExportToShareableHandle, if requested via ::cuMemCreate */
|
CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR_SUPPORTED = 103, /**< Device supports exporting memory to a posix file descriptor with ::cuMemExportToShareableHandle, if requested via ::cuMemCreate */
|
||||||
CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_WIN32_HANDLE_SUPPORTED = 104, /**< Device supports exporting memory to a Win32 NT handle with ::cuMemExportToShareableHandle, if requested via ::cuMemCreate */
|
CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_WIN32_HANDLE_SUPPORTED = 104, /**< Device supports exporting memory to a Win32 NT handle with ::cuMemExportToShareableHandle, if requested via ::cuMemCreate */
|
||||||
CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_WIN32_KMT_HANDLE_SUPPORTED = 105, /**< Device supports exporting memory to a Win32 KMT handle with ::cuMemExportToShareableHandle, if requested via ::cuMemCreate */
|
CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_WIN32_KMT_HANDLE_SUPPORTED = 105, /**< Device supports exporting memory to a Win32 KMT handle with ::cuMemExportToShareableHandle, if requested via ::cuMemCreate */
|
||||||
CU_DEVICE_ATTRIBUTE_MAX_BLOCKS_PER_MULTIPROCESSOR = 106, /**< Maximum number of blocks per multiprocessor */
|
CU_DEVICE_ATTRIBUTE_MAX_BLOCKS_PER_MULTIPROCESSOR = 106, /**< Maximum number of blocks per multiprocessor */
|
||||||
CU_DEVICE_ATTRIBUTE_GENERIC_COMPRESSION_SUPPORTED = 107, /**< Device supports compression of memory */
|
CU_DEVICE_ATTRIBUTE_GENERIC_COMPRESSION_SUPPORTED = 107, /**< Device supports compression of memory */
|
||||||
CU_DEVICE_ATTRIBUTE_MAX_PERSISTING_L2_CACHE_SIZE = 108, /**< Maximum L2 persisting lines capacity setting in bytes. */
|
CU_DEVICE_ATTRIBUTE_MAX_PERSISTING_L2_CACHE_SIZE = 108, /**< Maximum L2 persisting lines capacity setting in bytes. */
|
||||||
@@ -665,7 +639,7 @@ typedef enum CUdevice_attribute_enum {
|
|||||||
CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WITH_CUDA_VMM_SUPPORTED = 110, /**< Device supports specifying the GPUDirect RDMA flag with ::cuMemCreate */
|
CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WITH_CUDA_VMM_SUPPORTED = 110, /**< Device supports specifying the GPUDirect RDMA flag with ::cuMemCreate */
|
||||||
CU_DEVICE_ATTRIBUTE_RESERVED_SHARED_MEMORY_PER_BLOCK = 111, /**< Shared memory reserved by CUDA driver per block in bytes */
|
CU_DEVICE_ATTRIBUTE_RESERVED_SHARED_MEMORY_PER_BLOCK = 111, /**< Shared memory reserved by CUDA driver per block in bytes */
|
||||||
CU_DEVICE_ATTRIBUTE_SPARSE_CUDA_ARRAY_SUPPORTED = 112, /**< Device supports sparse CUDA arrays and sparse CUDA mipmapped arrays */
|
CU_DEVICE_ATTRIBUTE_SPARSE_CUDA_ARRAY_SUPPORTED = 112, /**< Device supports sparse CUDA arrays and sparse CUDA mipmapped arrays */
|
||||||
CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED = 113, /**< Device supports using the ::cuMemHostRegister flag ::CU_MEMHOSTREGISTER_READ_ONLY to register memory that must be mapped as read-only to the GPU */
|
CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED = 113, /**< Device supports using the ::cuMemHostRegister flag CU_MEMHOSTREGISTER_READ_ONLY to register memory that must be mapped as read-only to the GPU */
|
||||||
CU_DEVICE_ATTRIBUTE_TIMELINE_SEMAPHORE_INTEROP_SUPPORTED = 114, /**< External timeline semaphore interop is supported on the device */
|
CU_DEVICE_ATTRIBUTE_TIMELINE_SEMAPHORE_INTEROP_SUPPORTED = 114, /**< External timeline semaphore interop is supported on the device */
|
||||||
CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED = 115, /**< Device supports using the ::cuMemAllocAsync and ::cuMemPool family of APIs */
|
CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED = 115, /**< Device supports using the ::cuMemAllocAsync and ::cuMemPool family of APIs */
|
||||||
CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_SUPPORTED = 116, /**< Device supports GPUDirect RDMA APIs, like nvidia_p2p_get_pages (see https://docs.nvidia.com/cuda/gpudirect-rdma for more information) */
|
CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_SUPPORTED = 116, /**< Device supports GPUDirect RDMA APIs, like nvidia_p2p_get_pages (see https://docs.nvidia.com/cuda/gpudirect-rdma for more information) */
|
||||||
@@ -1650,8 +1624,7 @@ typedef enum cudaError_enum {
|
|||||||
CUDA_ERROR_UNSUPPORTED_EXEC_AFFINITY = 224,
|
CUDA_ERROR_UNSUPPORTED_EXEC_AFFINITY = 224,
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This indicates that the device kernel source is invalid. This includes
|
* This indicates that the device kernel source is invalid.
|
||||||
* compilation/linker errors encountered in device code or user error.
|
|
||||||
*/
|
*/
|
||||||
CUDA_ERROR_INVALID_SOURCE = 300,
|
CUDA_ERROR_INVALID_SOURCE = 300,
|
||||||
|
|
||||||
@@ -2068,9 +2041,9 @@ typedef size_t (CUDA_CB *CUoccupancyB2DSize)(int blockSize);
|
|||||||
* On Windows the flag is a no-op.
|
* On Windows the flag is a no-op.
|
||||||
* On Linux that memory is marked as non cache-coherent for the GPU and
|
* On Linux that memory is marked as non cache-coherent for the GPU and
|
||||||
* is expected to be physically contiguous. It may return
|
* is expected to be physically contiguous. It may return
|
||||||
* ::CUDA_ERROR_NOT_PERMITTED if run as an unprivileged user,
|
* CUDA_ERROR_NOT_PERMITTED if run as an unprivileged user,
|
||||||
* ::CUDA_ERROR_NOT_SUPPORTED on older Linux kernel versions.
|
* CUDA_ERROR_NOT_SUPPORTED on older Linux kernel versions.
|
||||||
* On all other platforms, it is not supported and ::CUDA_ERROR_NOT_SUPPORTED
|
* On all other platforms, it is not supported and CUDA_ERROR_NOT_SUPPORTED
|
||||||
* is returned.
|
* is returned.
|
||||||
* Flag for ::cuMemHostRegister()
|
* Flag for ::cuMemHostRegister()
|
||||||
*/
|
*/
|
||||||
@@ -2079,12 +2052,12 @@ typedef size_t (CUDA_CB *CUoccupancyB2DSize)(int blockSize);
|
|||||||
/**
|
/**
|
||||||
* If set, the passed memory pointer is treated as pointing to memory that is
|
* If set, the passed memory pointer is treated as pointing to memory that is
|
||||||
* considered read-only by the device. On platforms without
|
* considered read-only by the device. On platforms without
|
||||||
* ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES, this flag is
|
* CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES, this flag is
|
||||||
* required in order to register memory mapped to the CPU as read-only. Support
|
* required in order to register memory mapped to the CPU as read-only. Support
|
||||||
* for the use of this flag can be queried from the device attribute
|
* for the use of this flag can be queried from the device attribute
|
||||||
* ::CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED. Using this flag with
|
* CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED. Using this flag with
|
||||||
* a current context associated with a device that does not have this attribute
|
* a current context associated with a device that does not have this attribute
|
||||||
* set will cause ::cuMemHostRegister to error with ::CUDA_ERROR_NOT_SUPPORTED.
|
* set will cause ::cuMemHostRegister to error with CUDA_ERROR_NOT_SUPPORTED.
|
||||||
*/
|
*/
|
||||||
#define CU_MEMHOSTREGISTER_READ_ONLY 0x08
|
#define CU_MEMHOSTREGISTER_READ_ONLY 0x08
|
||||||
|
|
||||||
@@ -3735,117 +3708,117 @@ CUresult CUDAAPI cuDeviceGetTexture1DLinearMaxWidth(size_t *maxWidthInElements,
|
|||||||
* \p dev. The supported attributes are:
|
* \p dev. The supported attributes are:
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK: Maximum number of threads per
|
* - ::CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK: Maximum number of threads per
|
||||||
* block;
|
* block;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X: Maximum x-dimension of a block
|
* - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X: Maximum x-dimension of a block;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y: Maximum y-dimension of a block
|
* - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y: Maximum y-dimension of a block;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z: Maximum z-dimension of a block
|
* - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z: Maximum z-dimension of a block;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X: Maximum x-dimension of a grid
|
* - ::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X: Maximum x-dimension of a grid;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y: Maximum y-dimension of a grid
|
* - ::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y: Maximum y-dimension of a grid;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z: Maximum z-dimension of a grid
|
* - ::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z: Maximum z-dimension of a grid;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK: Maximum amount of
|
* - ::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK: Maximum amount of
|
||||||
* shared memory available to a thread block in bytes
|
* shared memory available to a thread block in bytes;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_TOTAL_CONSTANT_MEMORY: Memory available on device for
|
* - ::CU_DEVICE_ATTRIBUTE_TOTAL_CONSTANT_MEMORY: Memory available on device for
|
||||||
* __constant__ variables in a CUDA C kernel in bytes
|
* __constant__ variables in a CUDA C kernel in bytes;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_WARP_SIZE: Warp size in threads
|
* - ::CU_DEVICE_ATTRIBUTE_WARP_SIZE: Warp size in threads;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAX_PITCH: Maximum pitch in bytes allowed by the
|
* - ::CU_DEVICE_ATTRIBUTE_MAX_PITCH: Maximum pitch in bytes allowed by the
|
||||||
* memory copy functions that involve memory regions allocated through
|
* memory copy functions that involve memory regions allocated through
|
||||||
* ::cuMemAllocPitch()
|
* ::cuMemAllocPitch();
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_WIDTH: Maximum 1D
|
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_WIDTH: Maximum 1D
|
||||||
* texture width
|
* texture width;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LINEAR_WIDTH: Maximum width
|
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LINEAR_WIDTH: Maximum width
|
||||||
* for a 1D texture bound to linear memory
|
* for a 1D texture bound to linear memory;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_MIPMAPPED_WIDTH: Maximum
|
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_MIPMAPPED_WIDTH: Maximum
|
||||||
* mipmapped 1D texture width
|
* mipmapped 1D texture width;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_WIDTH: Maximum 2D
|
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_WIDTH: Maximum 2D
|
||||||
* texture width
|
* texture width;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_HEIGHT: Maximum 2D
|
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_HEIGHT: Maximum 2D
|
||||||
* texture height
|
* texture height;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_WIDTH: Maximum width
|
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_WIDTH: Maximum width
|
||||||
* for a 2D texture bound to linear memory
|
* for a 2D texture bound to linear memory;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_HEIGHT: Maximum height
|
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_HEIGHT: Maximum height
|
||||||
* for a 2D texture bound to linear memory
|
* for a 2D texture bound to linear memory;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_PITCH: Maximum pitch
|
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_PITCH: Maximum pitch
|
||||||
* in bytes for a 2D texture bound to linear memory
|
* in bytes for a 2D texture bound to linear memory;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_MIPMAPPED_WIDTH: Maximum
|
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_MIPMAPPED_WIDTH: Maximum
|
||||||
* mipmapped 2D texture width
|
* mipmapped 2D texture width;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_MIPMAPPED_HEIGHT: Maximum
|
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_MIPMAPPED_HEIGHT: Maximum
|
||||||
* mipmapped 2D texture height
|
* mipmapped 2D texture height;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_WIDTH: Maximum 3D
|
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_WIDTH: Maximum 3D
|
||||||
* texture width
|
* texture width;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_HEIGHT: Maximum 3D
|
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_HEIGHT: Maximum 3D
|
||||||
* texture height
|
* texture height;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_DEPTH: Maximum 3D
|
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_DEPTH: Maximum 3D
|
||||||
* texture depth
|
* texture depth;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_WIDTH_ALTERNATE:
|
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_WIDTH_ALTERNATE:
|
||||||
* Alternate maximum 3D texture width, 0 if no alternate
|
* Alternate maximum 3D texture width, 0 if no alternate
|
||||||
* maximum 3D texture size is supported
|
* maximum 3D texture size is supported;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_HEIGHT_ALTERNATE:
|
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_HEIGHT_ALTERNATE:
|
||||||
* Alternate maximum 3D texture height, 0 if no alternate
|
* Alternate maximum 3D texture height, 0 if no alternate
|
||||||
* maximum 3D texture size is supported
|
* maximum 3D texture size is supported;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_DEPTH_ALTERNATE:
|
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_DEPTH_ALTERNATE:
|
||||||
* Alternate maximum 3D texture depth, 0 if no alternate
|
* Alternate maximum 3D texture depth, 0 if no alternate
|
||||||
* maximum 3D texture size is supported
|
* maximum 3D texture size is supported;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_WIDTH:
|
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_WIDTH:
|
||||||
* Maximum cubemap texture width or height
|
* Maximum cubemap texture width or height;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LAYERED_WIDTH:
|
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LAYERED_WIDTH:
|
||||||
* Maximum 1D layered texture width
|
* Maximum 1D layered texture width;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LAYERED_LAYERS:
|
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LAYERED_LAYERS:
|
||||||
* Maximum layers in a 1D layered texture
|
* Maximum layers in a 1D layered texture;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_WIDTH:
|
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_WIDTH:
|
||||||
* Maximum 2D layered texture width
|
* Maximum 2D layered texture width;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_HEIGHT:
|
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_HEIGHT:
|
||||||
* Maximum 2D layered texture height
|
* Maximum 2D layered texture height;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_LAYERS:
|
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_LAYERS:
|
||||||
* Maximum layers in a 2D layered texture
|
* Maximum layers in a 2D layered texture;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_LAYERED_WIDTH:
|
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_LAYERED_WIDTH:
|
||||||
* Maximum cubemap layered texture width or height
|
* Maximum cubemap layered texture width or height;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_LAYERED_LAYERS:
|
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_LAYERED_LAYERS:
|
||||||
* Maximum layers in a cubemap layered texture
|
* Maximum layers in a cubemap layered texture;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_WIDTH:
|
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_WIDTH:
|
||||||
* Maximum 1D surface width
|
* Maximum 1D surface width;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_WIDTH:
|
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_WIDTH:
|
||||||
* Maximum 2D surface width
|
* Maximum 2D surface width;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_HEIGHT:
|
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_HEIGHT:
|
||||||
* Maximum 2D surface height
|
* Maximum 2D surface height;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_WIDTH:
|
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_WIDTH:
|
||||||
* Maximum 3D surface width
|
* Maximum 3D surface width;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_HEIGHT:
|
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_HEIGHT:
|
||||||
* Maximum 3D surface height
|
* Maximum 3D surface height;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_DEPTH:
|
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_DEPTH:
|
||||||
* Maximum 3D surface depth
|
* Maximum 3D surface depth;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_LAYERED_WIDTH:
|
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_LAYERED_WIDTH:
|
||||||
* Maximum 1D layered surface width
|
* Maximum 1D layered surface width;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_LAYERED_LAYERS:
|
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_LAYERED_LAYERS:
|
||||||
* Maximum layers in a 1D layered surface
|
* Maximum layers in a 1D layered surface;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_WIDTH:
|
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_WIDTH:
|
||||||
* Maximum 2D layered surface width
|
* Maximum 2D layered surface width;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_HEIGHT:
|
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_HEIGHT:
|
||||||
* Maximum 2D layered surface height
|
* Maximum 2D layered surface height;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_LAYERS:
|
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_LAYERS:
|
||||||
* Maximum layers in a 2D layered surface
|
* Maximum layers in a 2D layered surface;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_WIDTH:
|
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_WIDTH:
|
||||||
* Maximum cubemap surface width
|
* Maximum cubemap surface width;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_LAYERED_WIDTH:
|
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_LAYERED_WIDTH:
|
||||||
* Maximum cubemap layered surface width
|
* Maximum cubemap layered surface width;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_LAYERED_LAYERS:
|
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_LAYERED_LAYERS:
|
||||||
* Maximum layers in a cubemap layered surface
|
* Maximum layers in a cubemap layered surface;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK: Maximum number of 32-bit
|
* - ::CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK: Maximum number of 32-bit
|
||||||
* registers available to a thread block
|
* registers available to a thread block;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_CLOCK_RATE: The typical clock frequency in kilohertz
|
* - ::CU_DEVICE_ATTRIBUTE_CLOCK_RATE: The typical clock frequency in kilohertz;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_TEXTURE_ALIGNMENT: Alignment requirement; texture
|
* - ::CU_DEVICE_ATTRIBUTE_TEXTURE_ALIGNMENT: Alignment requirement; texture
|
||||||
* base addresses aligned to ::textureAlign bytes do not need an offset
|
* base addresses aligned to ::textureAlign bytes do not need an offset
|
||||||
* applied to texture fetches
|
* applied to texture fetches;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_TEXTURE_PITCH_ALIGNMENT: Pitch alignment requirement
|
* - ::CU_DEVICE_ATTRIBUTE_TEXTURE_PITCH_ALIGNMENT: Pitch alignment requirement
|
||||||
* for 2D texture references bound to pitched memory
|
* for 2D texture references bound to pitched memory;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_GPU_OVERLAP: 1 if the device can concurrently copy
|
* - ::CU_DEVICE_ATTRIBUTE_GPU_OVERLAP: 1 if the device can concurrently copy
|
||||||
* memory between host and device while executing a kernel, or 0 if not
|
* memory between host and device while executing a kernel, or 0 if not;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT: Number of multiprocessors on
|
* - ::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT: Number of multiprocessors on
|
||||||
* the device
|
* the device;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_KERNEL_EXEC_TIMEOUT: 1 if there is a run time limit
|
* - ::CU_DEVICE_ATTRIBUTE_KERNEL_EXEC_TIMEOUT: 1 if there is a run time limit
|
||||||
* for kernels executed on the device, or 0 if not
|
* for kernels executed on the device, or 0 if not;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_INTEGRATED: 1 if the device is integrated with the
|
* - ::CU_DEVICE_ATTRIBUTE_INTEGRATED: 1 if the device is integrated with the
|
||||||
* memory subsystem, or 0 if not
|
* memory subsystem, or 0 if not;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_CAN_MAP_HOST_MEMORY: 1 if the device can map host
|
* - ::CU_DEVICE_ATTRIBUTE_CAN_MAP_HOST_MEMORY: 1 if the device can map host
|
||||||
* memory into the CUDA address space, or 0 if not
|
* memory into the CUDA address space, or 0 if not;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_COMPUTE_MODE: Compute mode that device is currently
|
* - ::CU_DEVICE_ATTRIBUTE_COMPUTE_MODE: Compute mode that device is currently
|
||||||
* in. Available modes are as follows:
|
* in. Available modes are as follows:
|
||||||
* - ::CU_COMPUTEMODE_DEFAULT: Default mode - Device is not restricted and
|
* - ::CU_COMPUTEMODE_DEFAULT: Default mode - Device is not restricted and
|
||||||
@@ -3858,33 +3831,33 @@ CUresult CUDAAPI cuDeviceGetTexture1DLinearMaxWidth(size_t *maxWidthInElements,
|
|||||||
* executing multiple kernels within the same context simultaneously, or 0 if
|
* executing multiple kernels within the same context simultaneously, or 0 if
|
||||||
* not. It is not guaranteed that multiple kernels will be resident
|
* not. It is not guaranteed that multiple kernels will be resident
|
||||||
* on the device concurrently so this feature should not be relied upon for
|
* on the device concurrently so this feature should not be relied upon for
|
||||||
* correctness.
|
* correctness;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_ECC_ENABLED: 1 if error correction is enabled on the
|
* - ::CU_DEVICE_ATTRIBUTE_ECC_ENABLED: 1 if error correction is enabled on the
|
||||||
* device, 0 if error correction is disabled or not supported by the device
|
* device, 0 if error correction is disabled or not supported by the device;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_PCI_BUS_ID: PCI bus identifier of the device
|
* - ::CU_DEVICE_ATTRIBUTE_PCI_BUS_ID: PCI bus identifier of the device;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID: PCI device (also known as slot) identifier
|
* - ::CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID: PCI device (also known as slot) identifier
|
||||||
* of the device
|
* of the device;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID: PCI domain identifier of the device
|
* - ::CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID: PCI domain identifier of the device
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_TCC_DRIVER: 1 if the device is using a TCC driver. TCC
|
* - ::CU_DEVICE_ATTRIBUTE_TCC_DRIVER: 1 if the device is using a TCC driver. TCC
|
||||||
* is only available on Tesla hardware running Windows Vista or later
|
* is only available on Tesla hardware running Windows Vista or later;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE: Peak memory clock frequency in kilohertz
|
* - ::CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE: Peak memory clock frequency in kilohertz;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH: Global memory bus width in bits
|
* - ::CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH: Global memory bus width in bits;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE: Size of L2 cache in bytes. 0 if the device doesn't have L2 cache
|
* - ::CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE: Size of L2 cache in bytes. 0 if the device doesn't have L2 cache;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR: Maximum resident threads per multiprocessor
|
* - ::CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR: Maximum resident threads per multiprocessor;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING: 1 if the device shares a unified address space with
|
* - ::CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING: 1 if the device shares a unified address space with
|
||||||
* the host, or 0 if not
|
* the host, or 0 if not;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR: Major compute capability version number
|
* - ::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR: Major compute capability version number;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR: Minor compute capability version number
|
* - ::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR: Minor compute capability version number;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_GLOBAL_L1_CACHE_SUPPORTED: 1 if device supports caching globals
|
* - ::CU_DEVICE_ATTRIBUTE_GLOBAL_L1_CACHE_SUPPORTED: 1 if device supports caching globals
|
||||||
* in L1 cache, 0 if caching globals in L1 cache is not supported by the device
|
* in L1 cache, 0 if caching globals in L1 cache is not supported by the device;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_LOCAL_L1_CACHE_SUPPORTED: 1 if device supports caching locals
|
* - ::CU_DEVICE_ATTRIBUTE_LOCAL_L1_CACHE_SUPPORTED: 1 if device supports caching locals
|
||||||
* in L1 cache, 0 if caching locals in L1 cache is not supported by the device
|
* in L1 cache, 0 if caching locals in L1 cache is not supported by the device;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR: Maximum amount of
|
* - ::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR: Maximum amount of
|
||||||
* shared memory available to a multiprocessor in bytes; this amount is shared
|
* shared memory available to a multiprocessor in bytes; this amount is shared
|
||||||
* by all thread blocks simultaneously resident on a multiprocessor
|
* by all thread blocks simultaneously resident on a multiprocessor;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_MULTIPROCESSOR: Maximum number of 32-bit
|
* - ::CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_MULTIPROCESSOR: Maximum number of 32-bit
|
||||||
* registers available to a multiprocessor; this number is shared by all thread
|
* registers available to a multiprocessor; this number is shared by all thread
|
||||||
* blocks simultaneously resident on a multiprocessor
|
* blocks simultaneously resident on a multiprocessor;
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MANAGED_MEMORY: 1 if device supports allocating managed memory
|
* - ::CU_DEVICE_ATTRIBUTE_MANAGED_MEMORY: 1 if device supports allocating managed memory
|
||||||
* on this system, 0 if allocating managed memory is not supported by the device on this system.
|
* on this system, 0 if allocating managed memory is not supported by the device on this system.
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MULTI_GPU_BOARD: 1 if device is on a multi-GPU board, 0 if not.
|
* - ::CU_DEVICE_ATTRIBUTE_MULTI_GPU_BOARD: 1 if device is on a multi-GPU board, 0 if not.
|
||||||
@@ -3910,20 +3883,14 @@ CUresult CUDAAPI cuDeviceGetTexture1DLinearMaxWidth(size_t *maxWidthInElements,
|
|||||||
* - ::CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED: Device supports virtual memory management APIs like ::cuMemAddressReserve, ::cuMemCreate, ::cuMemMap and related APIs
|
* - ::CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED: Device supports virtual memory management APIs like ::cuMemAddressReserve, ::cuMemCreate, ::cuMemMap and related APIs
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR_SUPPORTED: Device supports exporting memory to a posix file descriptor with ::cuMemExportToShareableHandle, if requested via ::cuMemCreate
|
* - ::CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR_SUPPORTED: Device supports exporting memory to a posix file descriptor with ::cuMemExportToShareableHandle, if requested via ::cuMemCreate
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_WIN32_HANDLE_SUPPORTED: Device supports exporting memory to a Win32 NT handle with ::cuMemExportToShareableHandle, if requested via ::cuMemCreate
|
* - ::CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_WIN32_HANDLE_SUPPORTED: Device supports exporting memory to a Win32 NT handle with ::cuMemExportToShareableHandle, if requested via ::cuMemCreate
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_WIN32_KMT_HANDLE_SUPPORTED: Device supports exporting memory to a Win32 KMT handle with ::cuMemExportToShareableHandle, if requested via ::cuMemCreate
|
* - ::CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_WIN32_KMT_HANDLE_SUPPORTED: Device supports exporting memory to a Win32 KMT handle with ::cuMemExportToShareableHandle, if requested ::cuMemCreate
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCKS_PER_MULTIPROCESSOR: Maximum number of thread blocks that can reside on a multiprocessor
|
* - ::CU_DEVICE_ATTRIBUTE_MAX_PERSISTING_L2_CACHE_SIZE: Maximum L2 persisting lines capacity setting in bytes.
|
||||||
|
* - ::CU_DEVICE_ATTRIBUTE_MAX_ACCESS_POLICY_WINDOW_SIZE: Maximum value of CUaccessPolicyWindow::num_bytes.
|
||||||
|
* - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCKS_PER_MULTIPROCESSOR: Maximum number of thread blocks that can reside on a multiprocessor.
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_GENERIC_COMPRESSION_SUPPORTED: Device supports compressible memory allocation via ::cuMemCreate
|
* - ::CU_DEVICE_ATTRIBUTE_GENERIC_COMPRESSION_SUPPORTED: Device supports compressible memory allocation via ::cuMemCreate
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAX_PERSISTING_L2_CACHE_SIZE: Maximum L2 persisting lines capacity setting in bytes
|
* - ::CU_DEVICE_ATTRIBUTE_RESERVED_SHARED_MEMORY_PER_BLOCK: Amount of shared memory per block reserved by CUDA driver in bytes.
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MAX_ACCESS_POLICY_WINDOW_SIZE: Maximum value of CUaccessPolicyWindow::num_bytes
|
* - ::CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED: Device supports using the ::cuMemHostRegister flag CU_MEMHOSTREGISTER_READ_ONLY to register memory that must be mapped as read-only to the GPU
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WITH_CUDA_VMM_SUPPORTED: Device supports specifying the GPUDirect RDMA flag with ::cuMemCreate.
|
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_RESERVED_SHARED_MEMORY_PER_BLOCK: Amount of shared memory per block reserved by CUDA driver in bytes
|
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_SPARSE_CUDA_ARRAY_SUPPORTED: Device supports sparse CUDA arrays and sparse CUDA mipmapped arrays.
|
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED: Device supports using the ::cuMemHostRegister flag ::CU_MEMHOSTREGISTER_READ_ONLY to register memory that must be mapped as read-only to the GPU
|
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED: Device supports using the ::cuMemAllocAsync and ::cuMemPool family of APIs
|
* - ::CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED: Device supports using the ::cuMemAllocAsync and ::cuMemPool family of APIs
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_SUPPORTED: Device supports GPUDirect RDMA APIs, like nvidia_p2p_get_pages (see https://docs.nvidia.com/cuda/gpudirect-rdma for more information)
|
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_FLUSH_WRITES_OPTIONS: The returned attribute shall be interpreted as a bitmask, where the individual bits are described by the ::CUflushGPUDirectRDMAWritesOptions enum
|
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WRITES_ORDERING: GPUDirect RDMA writes to the device do not need to be flushed for consumers within the scope indicated by the returned attribute. See ::CUGPUDirectRDMAWritesOrdering for the numerical values returned here.
|
|
||||||
* - ::CU_DEVICE_ATTRIBUTE_MEMPOOL_SUPPORTED_HANDLE_TYPES: Bitmask of handle types supported with mempool based IPC
|
|
||||||
*
|
*
|
||||||
* \param pi - Returned device attribute value
|
* \param pi - Returned device attribute value
|
||||||
* \param attrib - Device attribute to query
|
* \param attrib - Device attribute to query
|
||||||
@@ -4690,13 +4657,6 @@ CUresult CUDAAPI cuCtxCreate_v3(CUcontext *pctx, CUexecAffinityParam *paramsArra
|
|||||||
* It is the responsibility of the calling function to ensure that no API
|
* It is the responsibility of the calling function to ensure that no API
|
||||||
* call issues using \p ctx while ::cuCtxDestroy() is executing.
|
* call issues using \p ctx while ::cuCtxDestroy() is executing.
|
||||||
*
|
*
|
||||||
* Destroys and cleans up all resources associated with the context.
|
|
||||||
* It is the caller's responsibility to ensure that the context or its resources
|
|
||||||
* are not accessed or passed in subsequent API calls and doing so will result in undefined behavior.
|
|
||||||
* These resources include CUDA types such as ::CUmodule, ::CUfunction, ::CUstream, ::CUevent,
|
|
||||||
* ::CUarray, ::CUmipmappedArray, ::CUtexObject, ::CUsurfObject, ::CUtexref, ::CUsurfref,
|
|
||||||
* ::CUgraphicsResource, ::CUlinkState, ::CUexternalMemory and ::CUexternalSemaphore.
|
|
||||||
*
|
|
||||||
* If \p ctx is current to the calling thread then \p ctx will also be
|
* If \p ctx is current to the calling thread then \p ctx will also be
|
||||||
* popped from the current thread's context stack (as though ::cuCtxPopCurrent()
|
* popped from the current thread's context stack (as though ::cuCtxPopCurrent()
|
||||||
* were called). If \p ctx is current to other threads, then \p ctx will
|
* were called). If \p ctx is current to other threads, then \p ctx will
|
||||||
@@ -5672,7 +5632,6 @@ CUresult CUDAAPI cuModuleLoadFatBinary(CUmodule *module, const void *fatCubin);
|
|||||||
* ::CUDA_ERROR_INVALID_CONTEXT,
|
* ::CUDA_ERROR_INVALID_CONTEXT,
|
||||||
* ::CUDA_ERROR_INVALID_VALUE
|
* ::CUDA_ERROR_INVALID_VALUE
|
||||||
* \notefnerr
|
* \notefnerr
|
||||||
* \note_destroy_ub
|
|
||||||
*
|
*
|
||||||
* \sa ::cuModuleGetFunction,
|
* \sa ::cuModuleGetFunction,
|
||||||
* ::cuModuleGetGlobal,
|
* ::cuModuleGetGlobal,
|
||||||
@@ -5993,9 +5952,8 @@ cuLinkDestroy(CUlinkState state);
|
|||||||
/**
|
/**
|
||||||
* \brief Gets free and total memory
|
* \brief Gets free and total memory
|
||||||
*
|
*
|
||||||
* Returns in \p *total the total amount of memory available to the current context.
|
* Returns in \p *free and \p *total respectively, the free and total amount of
|
||||||
* Returns in \p *free the amount of memory on the device that is free according to the OS.
|
* memory available for allocation by the CUDA context, in bytes.
|
||||||
* CUDA is not guaranteed to be able to allocate all of the memory that the OS reports as free.
|
|
||||||
*
|
*
|
||||||
* \param free - Returned free memory in bytes
|
* \param free - Returned free memory in bytes
|
||||||
* \param total - Returned total memory in bytes
|
* \param total - Returned total memory in bytes
|
||||||
@@ -6839,10 +6797,10 @@ CUresult CUDAAPI cuIpcCloseMemHandle(CUdeviceptr dptr);
|
|||||||
*
|
*
|
||||||
* - ::CU_MEMHOSTREGISTER_READ_ONLY: The pointer is treated as pointing to memory
|
* - ::CU_MEMHOSTREGISTER_READ_ONLY: The pointer is treated as pointing to memory
|
||||||
* that is considered read-only by the device. On platforms without
|
* that is considered read-only by the device. On platforms without
|
||||||
* ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES, this flag is
|
* CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES, this flag is
|
||||||
* required in order to register memory mapped to the CPU as read-only. Support
|
* required in order to register memory mapped to the CPU as read-only. Support
|
||||||
* for the use of this flag can be queried from the device attribute
|
* for the use of this flag can be queried from the device attribute
|
||||||
* ::CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED. Using this flag with
|
* CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED. Using this flag with
|
||||||
* a current context associated with a device that does not have this attribute
|
* a current context associated with a device that does not have this attribute
|
||||||
* set will cause ::cuMemHostRegister to error with CUDA_ERROR_NOT_SUPPORTED.
|
* set will cause ::cuMemHostRegister to error with CUDA_ERROR_NOT_SUPPORTED.
|
||||||
*
|
*
|
||||||
@@ -8987,7 +8945,7 @@ CUresult CUDAAPI cuMemsetD2D32Async(CUdeviceptr dstDevice, size_t dstPitch, unsi
|
|||||||
* float16's:
|
* float16's:
|
||||||
* \code
|
* \code
|
||||||
CUDA_ARRAY_DESCRIPTOR desc;
|
CUDA_ARRAY_DESCRIPTOR desc;
|
||||||
desc.Format = CU_AD_FORMAT_HALF;
|
desc.FormatFlags = CU_AD_FORMAT_HALF;
|
||||||
desc.NumChannels = 4;
|
desc.NumChannels = 4;
|
||||||
desc.Width = width;
|
desc.Width = width;
|
||||||
desc.Height = height;
|
desc.Height = height;
|
||||||
@@ -8997,7 +8955,7 @@ CUresult CUDAAPI cuMemsetD2D32Async(CUdeviceptr dstDevice, size_t dstPitch, unsi
|
|||||||
* of which is two 8-bit unsigned chars:
|
* of which is two 8-bit unsigned chars:
|
||||||
* \code
|
* \code
|
||||||
CUDA_ARRAY_DESCRIPTOR arrayDesc;
|
CUDA_ARRAY_DESCRIPTOR arrayDesc;
|
||||||
desc.Format = CU_AD_FORMAT_UNSIGNED_INT8;
|
desc.FormatFlags = CU_AD_FORMAT_UNSIGNED_INT8;
|
||||||
desc.NumChannels = 2;
|
desc.NumChannels = 2;
|
||||||
desc.Width = width;
|
desc.Width = width;
|
||||||
desc.Height = height;
|
desc.Height = height;
|
||||||
@@ -9323,7 +9281,7 @@ CUresult CUDAAPI cuArrayDestroy(CUarray hArray);
|
|||||||
* 4x16-bit float16's:
|
* 4x16-bit float16's:
|
||||||
* \code
|
* \code
|
||||||
CUDA_ARRAY3D_DESCRIPTOR desc;
|
CUDA_ARRAY3D_DESCRIPTOR desc;
|
||||||
desc.Format = CU_AD_FORMAT_HALF;
|
desc.FormatFlags = CU_AD_FORMAT_HALF;
|
||||||
desc.NumChannels = 4;
|
desc.NumChannels = 4;
|
||||||
desc.Width = width;
|
desc.Width = width;
|
||||||
desc.Height = height;
|
desc.Height = height;
|
||||||
@@ -15180,7 +15138,7 @@ CUresult CUDAAPI cuGraphExternalSemaphoresWaitNodeSetParams(CUgraphNode hNode, c
|
|||||||
* \param nodeParams - Parameters for the node
|
* \param nodeParams - Parameters for the node
|
||||||
*
|
*
|
||||||
* When ::cuGraphAddMemAllocNode creates an allocation node, it returns the address of the allocation in
|
* When ::cuGraphAddMemAllocNode creates an allocation node, it returns the address of the allocation in
|
||||||
* \p nodeParams.dptr. The allocation's address remains fixed across instantiations and launches.
|
* \param nodeParams.dptr. The allocation's address remains fixed across instantiations and launches.
|
||||||
*
|
*
|
||||||
* If the allocation is freed in the same graph, by creating a free node using ::cuGraphAddMemFreeNode,
|
* If the allocation is freed in the same graph, by creating a free node using ::cuGraphAddMemFreeNode,
|
||||||
* the allocation can be accessed by nodes ordered after the allocation node but before the free node.
|
* the allocation can be accessed by nodes ordered after the allocation node but before the free node.
|
||||||
@@ -15356,9 +15314,7 @@ CUresult CUDAAPI cuGraphMemFreeNodeGetParams(CUgraphNode hNode, CUdeviceptr *dpt
|
|||||||
*
|
*
|
||||||
* \sa
|
* \sa
|
||||||
* ::cuGraphAddMemAllocNode,
|
* ::cuGraphAddMemAllocNode,
|
||||||
* ::cuGraphAddMemFreeNode,
|
* ::cuGraphAddMemFreeNode
|
||||||
* ::cuDeviceSetGraphMemAttribute,
|
|
||||||
* ::cuDeviceGetGraphMemAttribute
|
|
||||||
*/
|
*/
|
||||||
CUresult CUDAAPI cuDeviceGraphMemTrim(CUdevice device);
|
CUresult CUDAAPI cuDeviceGraphMemTrim(CUdevice device);
|
||||||
|
|
||||||
@@ -15384,7 +15340,6 @@ CUresult CUDAAPI cuDeviceGraphMemTrim(CUdevice device);
|
|||||||
* ::CUDA_ERROR_INVALID_DEVICE
|
* ::CUDA_ERROR_INVALID_DEVICE
|
||||||
*
|
*
|
||||||
* \sa
|
* \sa
|
||||||
* ::cuDeviceSetGraphMemAttribute,
|
|
||||||
* ::cuGraphAddMemAllocNode,
|
* ::cuGraphAddMemAllocNode,
|
||||||
* ::cuGraphAddMemFreeNode
|
* ::cuGraphAddMemFreeNode
|
||||||
*/
|
*/
|
||||||
@@ -15409,7 +15364,6 @@ CUresult CUDAAPI cuDeviceGetGraphMemAttribute(CUdevice device, CUgraphMem_attrib
|
|||||||
* ::CUDA_ERROR_INVALID_DEVICE
|
* ::CUDA_ERROR_INVALID_DEVICE
|
||||||
*
|
*
|
||||||
* \sa
|
* \sa
|
||||||
* ::cuDeviceGetGraphMemAttribute,
|
|
||||||
* ::cuGraphAddMemAllocNode,
|
* ::cuGraphAddMemAllocNode,
|
||||||
* ::cuGraphAddMemFreeNode
|
* ::cuGraphAddMemFreeNode
|
||||||
*/
|
*/
|
||||||
|
@@ -1,4 +1,4 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#ifndef _TRITON_IR_BASIC_BLOCK_H_
|
#ifndef _TRITON_IR_BASIC_BLOCK_H_
|
||||||
#define _TRITON_IR_BASIC_BLOCK_H_
|
#define _TRITON_IR_BASIC_BLOCK_H_
|
||||||
@@ -27,7 +27,7 @@ public:
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
// constructors
|
// constructors
|
||||||
basic_block(context &ctx, const std::string &name, function *parent);
|
basic_block(context &ctx, const std::string &name, function *parent, basic_block *next);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
// accessors
|
// accessors
|
||||||
@@ -35,6 +35,7 @@ public:
|
|||||||
context& get_context() { return ctx_; }
|
context& get_context() { return ctx_; }
|
||||||
|
|
||||||
// get iterator to first instruction that is not a phi
|
// get iterator to first instruction that is not a phi
|
||||||
|
void replace_phi_uses_with(basic_block* before, basic_block* after);
|
||||||
iterator get_first_non_phi();
|
iterator get_first_non_phi();
|
||||||
|
|
||||||
// get instruction list
|
// get instruction list
|
||||||
@@ -60,13 +61,16 @@ public:
|
|||||||
inline const instruction &back() const { return *inst_list_.back(); }
|
inline const instruction &back() const { return *inst_list_.back(); }
|
||||||
inline instruction &back() { return *inst_list_.back(); }
|
inline instruction &back() { return *inst_list_.back(); }
|
||||||
|
|
||||||
|
void append_instruction(ir::instruction* i);
|
||||||
|
// split
|
||||||
|
basic_block* split_before(ir::instruction* loc, const std::string& name);
|
||||||
|
|
||||||
// predecessors
|
// predecessors
|
||||||
const std::vector<basic_block*>& get_predecessors() const { return preds_; }
|
std::vector<basic_block*> get_predecessors() const;
|
||||||
const std::vector<basic_block*>& get_successors() const { return succs_; }
|
std::vector<basic_block*> get_successors() const;
|
||||||
void add_predecessor(basic_block* pred);
|
|
||||||
|
|
||||||
// factory functions
|
// factory functions
|
||||||
static basic_block* create(context &ctx, const std::string &name, function *parent);
|
static basic_block* create(context &ctx, const std::string &name, function *parent, basic_block *next = nullptr);
|
||||||
|
|
||||||
void print(std::ostream &os);
|
void print(std::ostream &os);
|
||||||
|
|
||||||
|
@@ -22,6 +22,7 @@ class phi_node;
|
|||||||
|
|
||||||
/* Builder */
|
/* Builder */
|
||||||
class builder{
|
class builder{
|
||||||
|
public:
|
||||||
typedef basic_block::iterator iterator;
|
typedef basic_block::iterator iterator;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
@@ -75,6 +76,7 @@ public:
|
|||||||
value* create_br(basic_block *dest);
|
value* create_br(basic_block *dest);
|
||||||
value* create_cond_br(value *cond, basic_block* if_dest, basic_block* else_dest);
|
value* create_cond_br(value *cond, basic_block* if_dest, basic_block* else_dest);
|
||||||
value* create_ret_void();
|
value* create_ret_void();
|
||||||
|
value* create_ret(value *ret);
|
||||||
// Cast instructions
|
// Cast instructions
|
||||||
value *create_cast(cast_op_t op, value *v, type *dst_ty);
|
value *create_cast(cast_op_t op, value *v, type *dst_ty);
|
||||||
value* create_ptr_to_int(value *src, type *dst_ty);
|
value* create_ptr_to_int(value *src, type *dst_ty);
|
||||||
@@ -86,6 +88,9 @@ public:
|
|||||||
value* create_fp_trunc(value *src, type *dst_ty);
|
value* create_fp_trunc(value *src, type *dst_ty);
|
||||||
value* create_int_cast(value *src, type *dst_ty, bool is_signed);
|
value* create_int_cast(value *src, type *dst_ty, bool is_signed);
|
||||||
value *create_downcast(value *arg);
|
value *create_downcast(value *arg);
|
||||||
|
// Call instruction
|
||||||
|
value* create_call(function* fn, const std::vector<value*>& args);
|
||||||
|
value* create_launch(function* fn, const std::vector<value*>& args, const std::vector<value*>& grid, value* num_warps);
|
||||||
// Phi instruction
|
// Phi instruction
|
||||||
phi_node* create_phi(type *ty, unsigned num_reserved);
|
phi_node* create_phi(type *ty, unsigned num_reserved);
|
||||||
// Binary instructions
|
// Binary instructions
|
||||||
@@ -142,6 +147,9 @@ public:
|
|||||||
value *create_store(value *ptr, value *val);
|
value *create_store(value *ptr, value *val);
|
||||||
value *create_masked_load(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile);
|
value *create_masked_load(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile);
|
||||||
value *create_masked_store(value *ptr, value *val, value *mask);
|
value *create_masked_store(value *ptr, value *val, value *mask);
|
||||||
|
// Struct instructions
|
||||||
|
value *create_insert_value(value* val, value *elt, size_t idx);
|
||||||
|
value *create_extract_value(value* val, size_t idx);
|
||||||
// Block instruction
|
// Block instruction
|
||||||
value *create_splat(value *arg, const type::block_shapes_t &shapes);
|
value *create_splat(value *arg, const type::block_shapes_t &shapes);
|
||||||
value *create_reshape(value *arg, const type::block_shapes_t &shapes);
|
value *create_reshape(value *arg, const type::block_shapes_t &shapes);
|
||||||
|
@@ -31,7 +31,8 @@ public:
|
|||||||
std::map<std::pair<type*, unsigned>, std::unique_ptr<pointer_type>> ptr_tys;
|
std::map<std::pair<type*, unsigned>, std::unique_ptr<pointer_type>> ptr_tys;
|
||||||
// Block types
|
// Block types
|
||||||
std::map<std::pair<type*, type::block_shapes_t>, std::unique_ptr<block_type>> block_tys;
|
std::map<std::pair<type*, type::block_shapes_t>, std::unique_ptr<block_type>> block_tys;
|
||||||
|
// Struct types
|
||||||
|
std::map<type::contained_tys_vec_t, struct_type*> struct_tys;
|
||||||
// Int constants
|
// Int constants
|
||||||
std::map<std::pair<type*, uint64_t>, std::unique_ptr<constant_int>> int_constants_;
|
std::map<std::pair<type*, uint64_t>, std::unique_ptr<constant_int>> int_constants_;
|
||||||
// Float constants
|
// Float constants
|
||||||
|
@@ -95,6 +95,9 @@ enum value_id_t: unsigned {
|
|||||||
INSTRUCTIONS
|
INSTRUCTIONS
|
||||||
* ------------ */
|
* ------------ */
|
||||||
INST_BEGIN,
|
INST_BEGIN,
|
||||||
|
// call
|
||||||
|
INST_CALL,
|
||||||
|
INST_LAUNCH,
|
||||||
// phi
|
// phi
|
||||||
INST_PHI,
|
INST_PHI,
|
||||||
// arithmetic
|
// arithmetic
|
||||||
@@ -129,6 +132,9 @@ enum value_id_t: unsigned {
|
|||||||
INST_MASKED_LOAD_ASYNC,
|
INST_MASKED_LOAD_ASYNC,
|
||||||
INST_UNMASKED_STORE,
|
INST_UNMASKED_STORE,
|
||||||
INST_MASKED_STORE,
|
INST_MASKED_STORE,
|
||||||
|
// struct
|
||||||
|
INST_EXTRACT_VALUE,
|
||||||
|
INST_INSERT_VALUE,
|
||||||
// retile
|
// retile
|
||||||
INST_RESHAPE,
|
INST_RESHAPE,
|
||||||
INST_SPLAT,
|
INST_SPLAT,
|
||||||
|
@@ -121,6 +121,8 @@ public:
|
|||||||
const attr_map_t &attrs() { return attrs_; }
|
const attr_map_t &attrs() { return attrs_; }
|
||||||
bool has_attr(unsigned arg_id) const { return attrs_.find(arg_id) != attrs_.end(); }
|
bool has_attr(unsigned arg_id) const { return attrs_.find(arg_id) != attrs_.end(); }
|
||||||
std::set<attribute> get_attributes(const argument* arg) { return attrs_[arg->get_arg_no() + 1]; }
|
std::set<attribute> get_attributes(const argument* arg) { return attrs_[arg->get_arg_no() + 1]; }
|
||||||
|
void set_is_kernel(bool new_val) { is_kernel_ = new_val; }
|
||||||
|
bool get_is_kernel() { return is_kernel_; }
|
||||||
|
|
||||||
void print(std::ostream &os);
|
void print(std::ostream &os);
|
||||||
|
|
||||||
@@ -134,6 +136,7 @@ private:
|
|||||||
args_t args_;
|
args_t args_;
|
||||||
blocks_t blocks_;
|
blocks_t blocks_;
|
||||||
attr_map_t attrs_;
|
attr_map_t attrs_;
|
||||||
|
bool is_kernel_;
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@@ -81,6 +81,51 @@ private:
|
|||||||
value_id_t id_;
|
value_id_t id_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// call_inst classes
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
class call_inst: public instruction {
|
||||||
|
private:
|
||||||
|
std::string repr_impl() const;
|
||||||
|
call_inst(ir::function* fn, const std::vector<ir::value*>& values, const std::string& name, instruction* next);
|
||||||
|
|
||||||
|
public:
|
||||||
|
static call_inst* create(ir::function* fn, const std::vector<ir::value*>& values, const std::string &name = "", instruction *next = nullptr);
|
||||||
|
ir::function* get_fn() { return fn_; }
|
||||||
|
|
||||||
|
_TRITON_DEFINE_CLONE(call_inst)
|
||||||
|
_TRITON_DEFINE_ACCEPT(call_inst)
|
||||||
|
|
||||||
|
private:
|
||||||
|
ir::function* fn_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class launch_inst: public instruction {
|
||||||
|
private:
|
||||||
|
std::string repr_impl() const { return "launch"; }
|
||||||
|
launch_inst(ir::function* fn, const std::vector<ir::value*>& values, const std::vector<ir::value*>& grid, ir::value* num_warps,
|
||||||
|
const std::string &name = "", instruction *next = nullptr);
|
||||||
|
|
||||||
|
public:
|
||||||
|
static launch_inst* create(ir::function* fn, const std::vector<ir::value*>& values, const std::vector<ir::value*>& grid, ir::value* num_warps,
|
||||||
|
const std::string& name = "", instruction* next = nullptr);
|
||||||
|
|
||||||
|
ir::function* get_fn();
|
||||||
|
std::vector<ir::value*> get_values();
|
||||||
|
std::vector<ir::value*> get_grid();
|
||||||
|
ir::value* get_num_warps();
|
||||||
|
|
||||||
|
|
||||||
|
_TRITON_DEFINE_CLONE(launch_inst)
|
||||||
|
_TRITON_DEFINE_ACCEPT(launch_inst)
|
||||||
|
|
||||||
|
private:
|
||||||
|
unsigned val_begin;
|
||||||
|
unsigned val_end;
|
||||||
|
unsigned grid_begin;
|
||||||
|
unsigned grid_end;
|
||||||
|
};
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// phi_node classes
|
// phi_node classes
|
||||||
@@ -546,6 +591,44 @@ public:
|
|||||||
_TRITON_DEFINE_ACCEPT(masked_store_inst)
|
_TRITON_DEFINE_ACCEPT(masked_store_inst)
|
||||||
};
|
};
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// struct classes
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// insert_value
|
||||||
|
|
||||||
|
class insert_value_inst: public instruction {
|
||||||
|
private:
|
||||||
|
std::string repr_impl() const { return "insertvalue"; }
|
||||||
|
insert_value_inst(value *val, value *elt, size_t idx, const std::string &name, instruction *next);
|
||||||
|
|
||||||
|
public:
|
||||||
|
static insert_value_inst* create(value *val, value* elt, size_t idx, const std::string &name = "", instruction *next = nullptr);
|
||||||
|
size_t get_idx() { return idx_; }
|
||||||
|
_TRITON_DEFINE_CLONE(insert_value_inst)
|
||||||
|
_TRITON_DEFINE_ACCEPT(insert_value_inst)
|
||||||
|
|
||||||
|
private:
|
||||||
|
size_t idx_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// extract_value
|
||||||
|
|
||||||
|
class extract_value_inst: public instruction {
|
||||||
|
private:
|
||||||
|
std::string repr_impl() const { return "extractvalue"; }
|
||||||
|
extract_value_inst(value *val, size_t idx, const std::string &name, instruction *next);
|
||||||
|
|
||||||
|
public:
|
||||||
|
static extract_value_inst* create(value *val, size_t idx, const std::string &name = "", instruction *next = nullptr);
|
||||||
|
size_t get_idx() { return idx_; }
|
||||||
|
_TRITON_DEFINE_CLONE(extract_value_inst)
|
||||||
|
_TRITON_DEFINE_ACCEPT(extract_value_inst)
|
||||||
|
|
||||||
|
private:
|
||||||
|
size_t idx_;
|
||||||
|
};
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// retile_inst classes
|
// retile_inst classes
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@@ -34,79 +34,97 @@ class constant;
|
|||||||
class global_value;
|
class global_value;
|
||||||
class alloc_const;
|
class alloc_const;
|
||||||
|
|
||||||
/* Module */
|
class value_constructor {
|
||||||
|
|
||||||
class module {
|
|
||||||
typedef std::pair<std::string, basic_block*> val_key_t;
|
typedef std::pair<std::string, basic_block*> val_key_t;
|
||||||
friend class function;
|
|
||||||
typedef std::pair<ir::metadata::kind_t, unsigned> md_pair_t;
|
typedef std::pair<ir::metadata::kind_t, unsigned> md_pair_t;
|
||||||
|
|
||||||
public:
|
|
||||||
typedef std::map<std::string, global_value*> symbols_map_t;
|
|
||||||
typedef std::vector<function*> functions_list_t;
|
|
||||||
struct current_iteration_info_t{
|
|
||||||
lang::iteration_statement *statement;
|
|
||||||
basic_block *block;
|
|
||||||
};
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
phi_node *make_phi(type *ty, unsigned num_values, basic_block *block);
|
phi_node *make_phi(type *ty, unsigned num_values, basic_block *block);
|
||||||
value *try_remove_trivial_phis(ir::phi_node *&phi);
|
value *try_remove_trivial_phis(ir::phi_node *&phi);
|
||||||
value *add_phi_operands(const std::string& name, phi_node *&phi);
|
value *add_phi_operands(const std::string& name, phi_node *&phi);
|
||||||
value *get_value_recursive(const std::string& name, basic_block *block);
|
value *get_value_recursive(const std::string& name, basic_block *block);
|
||||||
|
|
||||||
|
public:
|
||||||
|
value_constructor(builder &builder);
|
||||||
|
|
||||||
|
void set_value(const std::string& name, basic_block* block, value *x);
|
||||||
|
void set_value(const std::string& name, value* x);
|
||||||
|
const std::map<val_key_t, value*>& get_values() { return values_; }
|
||||||
|
void set_values(const std::map<val_key_t, value*>& values) { values_ = values; }
|
||||||
|
value *get_value(const std::string& name, basic_block* block);
|
||||||
|
value *get_value(const std::string& name);
|
||||||
|
void set_type(const std::string& name, ir::type* ty) { types_[name] = ty; }
|
||||||
|
// Seal block -- no more predecessors will be added
|
||||||
|
void seal_block(basic_block *block);
|
||||||
|
// Metadata
|
||||||
|
void add_metadata(const std::string &name, md_pair_t x) { metadatas_[name] = x; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
ir::builder& builder_;
|
||||||
|
std::map<val_key_t, value*> values_;
|
||||||
|
std::map<std::string, type*> types_;
|
||||||
|
std::set<basic_block*> sealed_blocks_;
|
||||||
|
std::map<basic_block*, std::map<std::string, phi_node*>> incomplete_phis_;
|
||||||
|
std::map<value*, value**> current_phi_;
|
||||||
|
std::map<std::string, md_pair_t> metadatas_;
|
||||||
|
};
|
||||||
|
|
||||||
|
/* Module */
|
||||||
|
|
||||||
|
class module {
|
||||||
|
typedef std::pair<std::string, basic_block*> val_key_t;
|
||||||
|
friend class function;
|
||||||
|
|
||||||
|
public:
|
||||||
|
typedef std::map<std::string, global_value*> symbols_map_t;
|
||||||
|
typedef std::vector<function*> functions_list_t;
|
||||||
|
|
||||||
|
private:
|
||||||
void push_function(function *fn) { functions_.push_back(fn); }
|
void push_function(function *fn) { functions_.push_back(fn); }
|
||||||
|
|
||||||
public:
|
public:
|
||||||
module(const std::string &name, builder& builder);
|
module(const std::string &name, builder& builder);
|
||||||
builder& get_builder();
|
builder& get_builder();
|
||||||
// Setters
|
// Setters
|
||||||
void set_value(const std::string& name, basic_block* block, value *x);
|
|
||||||
void set_value(const std::string& name, value* x);
|
|
||||||
void set_const(const std::string& name);
|
|
||||||
void set_continue_fn(std::function<ir::value*()> fn);
|
void set_continue_fn(std::function<ir::value*()> fn);
|
||||||
// Getters
|
// Getters
|
||||||
const std::map<val_key_t, value*>& get_values() { return values_; }
|
|
||||||
const std::map<std::string, type*>& get_types() { return types_; }
|
|
||||||
void set_values(const std::map<val_key_t, value*>& values) { values_ = values; }
|
|
||||||
void set_types(const std::map<std::string, type*>& types) { types_ = types; }
|
|
||||||
|
|
||||||
value *get_value(const std::string& name, basic_block* block);
|
|
||||||
value *get_value(const std::string& name);
|
|
||||||
void set_type(const std::string& name, ir::type* ty) { types_[name] = ty; }
|
|
||||||
const std::string& get_name();
|
const std::string& get_name();
|
||||||
std::function<ir::value*()> get_continue_fn();
|
std::function<ir::value*()> get_continue_fn();
|
||||||
// Seal block -- no more predecessors will be added
|
|
||||||
void seal_block(basic_block *block);
|
|
||||||
// Functions
|
// Functions
|
||||||
const functions_list_t &get_function_list() const { return functions_; }
|
const functions_list_t &get_function_list() const { return functions_; }
|
||||||
functions_list_t &get_function_list() { return functions_; }
|
functions_list_t &get_function_list() { return functions_; }
|
||||||
|
function *get_function(const std::string& name) {
  // A single lookup serves both the existence check and the returned entry.
  auto entry = symbols_.find(name);
  if(entry == symbols_.end())
    throw std::runtime_error("function " + name + " is not declared");
  // The stored global_value* is reinterpreted as a function* without any
  // runtime check, exactly as the symbol table is used elsewhere in this class.
  return (function*)entry->second;
}
|
||||||
function *get_or_insert_function(const std::string &name, function_type *ty);
|
function *get_or_insert_function(const std::string &name, function_type *ty);
|
||||||
|
bool has_function(const std::string& name){
|
||||||
|
return symbols_.find(name) != symbols_.end();
|
||||||
|
}
|
||||||
|
void remove_function(ir::function* fn){
  // Move every entry different from `fn` to the front, then drop the tail.
  // Only `functions_` is modified here; `symbols_` is not touched.
  auto kept_end = std::remove(functions_.begin(), functions_.end(), fn);
  functions_.erase(kept_end, functions_.end());
}
|
||||||
|
|
||||||
|
void reset_ret_ty(const std::string& name, type* ty);
|
||||||
|
|
||||||
// Const allocation
|
// Const allocation
|
||||||
void add_alloc(ir::alloc_const* x) { allocs_.push_back(x); }
|
void add_alloc(ir::alloc_const* x) { allocs_.push_back(x); }
|
||||||
const std::vector<ir::alloc_const*>& allocs() { return allocs_; }
|
const std::vector<ir::alloc_const*>& allocs() { return allocs_; }
|
||||||
// Register global
|
// Register global
|
||||||
void register_global(const std::string& name, ir::value *x) { globals_[name] = x; }
|
void register_global(const std::string& name, ir::value *x) { globals_[name] = x; }
|
||||||
const std::map<std::string, ir::value*>& globals() const { return globals_; }
|
const std::map<std::string, ir::value*>& globals() const { return globals_; }
|
||||||
// Metadata
|
//
|
||||||
void add_metadata(const std::string &name, md_pair_t x) { metadatas_[name] = x; }
|
|
||||||
|
|
||||||
void print(std::ostream &os);
|
void print(std::ostream &os);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::string name_;
|
std::string name_;
|
||||||
builder& builder_;
|
builder& builder_;
|
||||||
std::map<val_key_t, value*> values_;
|
|
||||||
std::map<std::string, type*> types_;
|
|
||||||
std::set<std::string> const_;
|
|
||||||
std::set<basic_block*> sealed_blocks_;
|
|
||||||
std::map<basic_block*, std::map<std::string, phi_node*>> incomplete_phis_;
|
|
||||||
functions_list_t functions_;
|
functions_list_t functions_;
|
||||||
symbols_map_t symbols_;
|
symbols_map_t symbols_;
|
||||||
std::function<ir::value*()> continue_fn_;
|
std::function<ir::value*()> continue_fn_;
|
||||||
std::map<value*, value**> current_phi_;
|
|
||||||
std::vector<ir::alloc_const*> allocs_;
|
std::vector<ir::alloc_const*> allocs_;
|
||||||
std::map<std::string, ir::value*> globals_;
|
std::map<std::string, ir::value*> globals_;
|
||||||
std::map<std::string, md_pair_t> metadatas_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@@ -1,4 +1,4 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#ifndef _TRITON_IR_TYPE_H_
|
#ifndef _TRITON_IR_TYPE_H_
|
||||||
#define _TRITON_IR_TYPE_H_
|
#define _TRITON_IR_TYPE_H_
|
||||||
@@ -73,6 +73,8 @@ public:
|
|||||||
type *get_tile_element_ty() const;
|
type *get_tile_element_ty() const;
|
||||||
unsigned get_pointer_address_space() const;
|
unsigned get_pointer_address_space() const;
|
||||||
type *get_pointer_element_ty() const;
|
type *get_pointer_element_ty() const;
|
||||||
|
unsigned get_struct_numel() const { return contained_tys_.size(); }
|
||||||
|
type *get_struct_type(unsigned int i) const { return contained_tys_[i]; }
|
||||||
|
|
||||||
// primitive predicates
|
// primitive predicates
|
||||||
bool is_void_ty() const { return id_ == VoidTyID; }
|
bool is_void_ty() const { return id_ == VoidTyID; }
|
||||||
@@ -91,6 +93,7 @@ public:
|
|||||||
bool is_bool_ty() const { return is_integer_ty(1); }
|
bool is_bool_ty() const { return is_integer_ty(1); }
|
||||||
bool is_pointer_ty() const { return id_ == PointerTyID; }
|
bool is_pointer_ty() const { return id_ == PointerTyID; }
|
||||||
bool is_block_ty() const { return id_ == BlockTyID; }
|
bool is_block_ty() const { return id_ == BlockTyID; }
|
||||||
|
bool is_struct_ty() const { return id_ == StructTyID; }
|
||||||
|
|
||||||
// Composite predicates
|
// Composite predicates
|
||||||
bool is_int_or_tileint_ty();
|
bool is_int_or_tileint_ty();
|
||||||
@@ -138,10 +141,10 @@ public:
|
|||||||
switch(id_) {
|
switch(id_) {
|
||||||
case VoidTyID: return "void";
|
case VoidTyID: return "void";
|
||||||
case FP8TyID: return "fp8";
|
case FP8TyID: return "fp8";
|
||||||
|
case BF16TyID: return "bf16";
|
||||||
case FP16TyID: return "f16";
|
case FP16TyID: return "f16";
|
||||||
case FP32TyID: return "f32";
|
case FP32TyID: return "f32";
|
||||||
case FP64TyID: return "f64";
|
case FP64TyID: return "f64";
|
||||||
case BF16TyID: return "bf16";
|
|
||||||
case LabelTyID: return "label";
|
case LabelTyID: return "label";
|
||||||
case MetadataTyID: return "md";
|
case MetadataTyID: return "md";
|
||||||
case TokenTyID: return "tok";
|
case TokenTyID: return "tok";
|
||||||
@@ -194,6 +197,16 @@ public:
|
|||||||
type* get_type_at_index(value *idx) const;
|
type* get_type_at_index(value *idx) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class struct_type: public composite_type {
|
||||||
|
public:
|
||||||
|
struct_type(const contained_tys_vec_t& tys, bool is_packed);
|
||||||
|
unsigned get_num_types() const { return contained_tys_.size(); }
|
||||||
|
static struct_type* get(const contained_tys_vec_t& tys, bool is_packed);
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool is_packed_;
|
||||||
|
};
|
||||||
|
|
||||||
class block_type: public composite_type {
|
class block_type: public composite_type {
|
||||||
private:
|
private:
|
||||||
block_type(type *ty, const block_shapes_t &shapes);
|
block_type(type *ty, const block_shapes_t &shapes);
|
||||||
@@ -242,6 +255,7 @@ public:
|
|||||||
ty_iterator params_end() { return contained_tys_.end(); }
|
ty_iterator params_end() { return contained_tys_.end(); }
|
||||||
type* get_param_ty(unsigned i) const { return contained_tys_.at(1 + i); }
|
type* get_param_ty(unsigned i) const { return contained_tys_.at(1 + i); }
|
||||||
type* get_return_ty() const { return contained_tys_.at(0); }
|
type* get_return_ty() const { return contained_tys_.at(0); }
|
||||||
|
void reset_ret_ty(type* ty) { contained_tys_[0] = ty;}
|
||||||
// factory methods
|
// factory methods
|
||||||
static function_type* get(type *ret_ty, const std::vector<type*>& param_tys);
|
static function_type* get(type *ret_ty, const std::vector<type*>& param_tys);
|
||||||
};
|
};
|
||||||
|
@@ -21,7 +21,7 @@ class visitor;
|
|||||||
|
|
||||||
class value {
|
class value {
|
||||||
public:
|
public:
|
||||||
typedef std::set<user*> users_t;
|
typedef std::vector<user*> users_t;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
// constructor
|
// constructor
|
||||||
@@ -30,7 +30,7 @@ public:
|
|||||||
// uses
|
// uses
|
||||||
void add_use(user* arg);
|
void add_use(user* arg);
|
||||||
users_t::iterator erase_use(user* arg);
|
users_t::iterator erase_use(user* arg);
|
||||||
const std::set<user*> &get_users() { return users_; }
|
const std::vector<user*> &get_users() { return users_; }
|
||||||
void replace_all_uses_with(value *target);
|
void replace_all_uses_with(value *target);
|
||||||
// name
|
// name
|
||||||
void set_name(const std::string &name);
|
void set_name(const std::string &name);
|
||||||
|
@@ -11,6 +11,9 @@ class value;
|
|||||||
|
|
||||||
class instruction;
|
class instruction;
|
||||||
|
|
||||||
|
class call_inst;
|
||||||
|
class launch_inst;
|
||||||
|
|
||||||
class phi_node;
|
class phi_node;
|
||||||
class binary_operator;
|
class binary_operator;
|
||||||
class getelementptr_inst;
|
class getelementptr_inst;
|
||||||
@@ -42,6 +45,9 @@ class masked_load_inst;
|
|||||||
class unmasked_store_inst;
|
class unmasked_store_inst;
|
||||||
class masked_store_inst;
|
class masked_store_inst;
|
||||||
|
|
||||||
|
class extract_value_inst;
|
||||||
|
class insert_value_inst;
|
||||||
|
|
||||||
class retile_inst;
|
class retile_inst;
|
||||||
class reshape_inst;
|
class reshape_inst;
|
||||||
class splat_inst;
|
class splat_inst;
|
||||||
@@ -105,6 +111,8 @@ public:
|
|||||||
virtual ~visitor() {}
|
virtual ~visitor() {}
|
||||||
|
|
||||||
virtual void visit_value(ir::value*);
|
virtual void visit_value(ir::value*);
|
||||||
|
virtual void visit_call_inst(ir::call_inst*) = 0;
|
||||||
|
virtual void visit_launch_inst(ir::launch_inst*) = 0;
|
||||||
|
|
||||||
virtual void visit_basic_block(basic_block*) = 0;
|
virtual void visit_basic_block(basic_block*) = 0;
|
||||||
virtual void visit_argument(argument*) = 0;
|
virtual void visit_argument(argument*) = 0;
|
||||||
@@ -132,6 +140,9 @@ public:
|
|||||||
virtual void visit_sin_inst(sin_inst*) = 0;
|
virtual void visit_sin_inst(sin_inst*) = 0;
|
||||||
virtual void visit_log_inst(log_inst*) = 0;
|
virtual void visit_log_inst(log_inst*) = 0;
|
||||||
|
|
||||||
|
virtual void visit_extract_value_inst(extract_value_inst*) = 0;
|
||||||
|
virtual void visit_insert_value_inst(insert_value_inst*) = 0;
|
||||||
|
|
||||||
virtual void visit_reshape_inst(reshape_inst*) = 0;
|
virtual void visit_reshape_inst(reshape_inst*) = 0;
|
||||||
virtual void visit_splat_inst(splat_inst*) = 0;
|
virtual void visit_splat_inst(splat_inst*) = 0;
|
||||||
virtual void visit_cat_inst(cat_inst*) = 0;
|
virtual void visit_cat_inst(cat_inst*) = 0;
|
||||||
|
@@ -608,6 +608,8 @@ void layouts::run(ir::module &mod) {
|
|||||||
// create temporaries
|
// create temporaries
|
||||||
size_t id = values_.size();
|
size_t id = values_.size();
|
||||||
ir::for_each_instruction(mod, [this, &id](ir::instruction* i) {
|
ir::for_each_instruction(mod, [this, &id](ir::instruction* i) {
|
||||||
|
// std::cout << "layout: " << std::endl;
|
||||||
|
// i->print(std::cout);
|
||||||
if(auto *red = dynamic_cast<ir::reduce_inst*>(i)) {
|
if(auto *red = dynamic_cast<ir::reduce_inst*>(i)) {
|
||||||
id++;
|
id++;
|
||||||
ir::value *arg = red->get_operand(0);
|
ir::value *arg = red->get_operand(0);
|
||||||
|
@@ -13,6 +13,7 @@
|
|||||||
#include "triton/codegen/transform/peephole.h"
|
#include "triton/codegen/transform/peephole.h"
|
||||||
#include "triton/codegen/transform/pipeline.h"
|
#include "triton/codegen/transform/pipeline.h"
|
||||||
#include "triton/codegen/transform/prefetch.h"
|
#include "triton/codegen/transform/prefetch.h"
|
||||||
|
#include "triton/codegen/transform/inline.h"
|
||||||
#include "triton/ir/function.h"
|
#include "triton/ir/function.h"
|
||||||
#include "triton/ir/module.h"
|
#include "triton/ir/module.h"
|
||||||
#include "triton/ir/print.h"
|
#include "triton/ir/print.h"
|
||||||
@@ -33,6 +34,7 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
|
|||||||
bool cts_use_async = target->as_nvidia() && target->as_nvidia()->sm() >= 80;
|
bool cts_use_async = target->as_nvidia() && target->as_nvidia()->sm() >= 80;
|
||||||
// create passes
|
// create passes
|
||||||
codegen::analysis::align align;
|
codegen::analysis::align align;
|
||||||
|
codegen::transform::inliner inliner;
|
||||||
codegen::analysis::axes axes;
|
codegen::analysis::axes axes;
|
||||||
codegen::transform::cts cts(cts_use_async);
|
codegen::transform::cts cts(cts_use_async);
|
||||||
codegen::transform::pipeline pipeline(cts_use_async, num_stages);
|
codegen::transform::pipeline pipeline(cts_use_async, num_stages);
|
||||||
@@ -48,6 +50,7 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
|
|||||||
codegen::transform::membar barriers(&liveness, &layouts, &allocation, &prefetch_s, target);
|
codegen::transform::membar barriers(&liveness, &layouts, &allocation, &prefetch_s, target);
|
||||||
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target, num_warps);
|
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target, num_warps);
|
||||||
// run passes
|
// run passes
|
||||||
|
inliner.run(ir);
|
||||||
dce.run(ir);
|
dce.run(ir);
|
||||||
peephole.run(ir);
|
peephole.run(ir);
|
||||||
dce.run(ir);
|
dce.run(ir);
|
||||||
|
@@ -13,6 +13,7 @@
|
|||||||
#include "triton/ir/module.h"
|
#include "triton/ir/module.h"
|
||||||
#include "triton/ir/function.h"
|
#include "triton/ir/function.h"
|
||||||
#include "triton/ir/type.h"
|
#include "triton/ir/type.h"
|
||||||
|
#include "triton/ir/utils.h"
|
||||||
#include "llvm/IR/Module.h"
|
#include "llvm/IR/Module.h"
|
||||||
#include "llvm/IR/IRBuilder.h"
|
#include "llvm/IR/IRBuilder.h"
|
||||||
#include "llvm/IR/IntrinsicsNVPTX.h"
|
#include "llvm/IR/IntrinsicsNVPTX.h"
|
||||||
@@ -139,6 +140,14 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
|
|||||||
* \brief Convert Triton-IR Type to LLVM-IR Type
|
* \brief Convert Triton-IR Type to LLVM-IR Type
|
||||||
*/
|
*/
|
||||||
Type *generator::cvt(ir::type *ty) {
|
Type *generator::cvt(ir::type *ty) {
|
||||||
|
// struct
|
||||||
|
if(ty->is_struct_ty()){
|
||||||
|
std::vector<Type*> tys;
|
||||||
|
for(size_t i = 0; i < ty->get_struct_numel(); i++)
|
||||||
|
tys.push_back(cvt(ty->get_struct_type(i)));
|
||||||
|
return StructType::get(builder_->getContext(), tys, true);
|
||||||
|
}
|
||||||
|
|
||||||
// function
|
// function
|
||||||
if(auto* tt = dynamic_cast<ir::function_type*>(ty)){
|
if(auto* tt = dynamic_cast<ir::function_type*>(ty)){
|
||||||
Type *ret_ty = cvt(tt->get_return_ty());
|
Type *ret_ty = cvt(tt->get_return_ty());
|
||||||
@@ -266,6 +275,7 @@ void generator::visit_value(ir::value* v) {
|
|||||||
builder_->SetInsertPoint(&*current->getFirstNonPHI());
|
builder_->SetInsertPoint(&*current->getFirstNonPHI());
|
||||||
// visit user
|
// visit user
|
||||||
if(auto *usr = dynamic_cast<ir::user*>(v)){
|
if(auto *usr = dynamic_cast<ir::user*>(v)){
|
||||||
|
if(!dynamic_cast<ir::function*>(usr))
|
||||||
usr->accept(this);
|
usr->accept(this);
|
||||||
}
|
}
|
||||||
// revert insert point
|
// revert insert point
|
||||||
@@ -282,6 +292,81 @@ void generator::visit_phi_node(ir::phi_node* x) {
|
|||||||
vals_[x][idx] = phi(ty, x->get_num_operands());
|
vals_[x][idx] = phi(ty, x->get_num_operands());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief Code Generation for `call`
|
||||||
|
*/
|
||||||
|
void generator::visit_call_inst(ir::call_inst* call) {
  // No code is generated for `call`: the error message states that calls
  // are expected to have been inlined before reaching the generator.
  (void)call;
  throw std::runtime_error("call not supported! Triton should be inlining everything.");
}
|
||||||
|
|
||||||
|
/**
 * \brief Code Generation for `launch`
 *
 * Emits, guarded so that only the thread whose local id is 0 executes it:
 *   1. a call to `cudaGetParameterBufferV2` with the callee, the grid
 *      dimensions and a block of (32 * num_warps, 1, 1) threads;
 *   2. if the returned buffer is non-null, stores of every launch argument
 *      into that buffer followed by a call to `cudaLaunchDeviceV2`.
 * The insertion point is left in the `launch_done` block.
 *
 * \throws std::runtime_error if the byte size of an argument evaluates to 0
 */
void generator::visit_launch_inst(ir::launch_inst *launch) {
  // Reuse an existing declaration when there is one: creating a second
  // function with the same name makes LLVM rename it, and the renamed
  // symbol would not resolve to the runtime entry point.
  auto get_or_declare = [this](const char* name, FunctionType* ty) -> Function* {
    if(Function* existing = mod_->getFunction(name))
      return existing;
    return Function::Create(ty, Function::ExternalLinkage, name, mod_);
  };

  ir::function* fn = (ir::function*)launch->get_operand(0);
  // forward-declare cudaGetParameterBufferV2
  std::vector<Type*> get_param_arg_tys = {PointerType::get(builder_->getInt8Ty(), 0),
                                          ArrayType::get(builder_->getInt32Ty(), 3),
                                          ArrayType::get(builder_->getInt32Ty(), 3),
                                          builder_->getInt32Ty()};
  FunctionType* get_param_ty = FunctionType::get(PointerType::get(builder_->getInt8Ty(), 0), get_param_arg_tys, false);
  Function* get_param_buffer = get_or_declare("cudaGetParameterBufferV2", get_param_ty);
  // stack slots holding the grid and block dimensions (3 x i32 each)
  AllocaInst* grid = builder_->CreateAlloca(get_param_arg_tys[1]);
  AllocaInst* block = builder_->CreateAlloca(get_param_arg_tys[2]);
  ConstantInt* _0 = builder_->getInt32(0);
  ConstantInt* _1 = builder_->getInt32(1);
  ConstantInt* _2 = builder_->getInt32(2);
  // create basic blocks: only the thread with local id 0 enters `launch`
  BasicBlock* launch_done_bb = BasicBlock::Create(builder_->getContext(), "launch_done", builder_->GetInsertBlock()->getParent());
  BasicBlock* launch_bb = BasicBlock::Create(builder_->getContext(), "launch", launch_done_bb->getParent(), launch_done_bb);
  Value *tid = tgt_->get_local_id(mod_, *builder_, 0);
  Value *is_first_thread = builder_->CreateICmpEQ(tid, i32(0));
  builder_->CreateCondBr(is_first_thread, launch_bb, launch_done_bb);
  builder_->SetInsertPoint(launch_bb);

  // grid dimensions come from the launch instruction
  builder_->CreateStore(vals_[launch->get_grid()[0]][{}], builder_->CreateGEP(grid, {_0, _0}));
  builder_->CreateStore(vals_[launch->get_grid()[1]][{}], builder_->CreateGEP(grid, {_0, _1}));
  builder_->CreateStore(vals_[launch->get_grid()[2]][{}], builder_->CreateGEP(grid, {_0, _2}));
  // block dimensions are (32 * num_warps, 1, 1)
  Value* num_warps = mul(builder_->getInt32(32), vals_[launch->get_num_warps()][{}]);
  builder_->CreateStore(num_warps, builder_->CreateGEP(block, {_0, _0}));
  builder_->CreateStore(builder_->getInt32(1), builder_->CreateGEP(block, {_0, _1}));
  builder_->CreateStore(builder_->getInt32(1), builder_->CreateGEP(block, {_0, _2}));
  // request the parameter buffer for the callee
  Function* called_fn = fns_[fn];
  Value* callee = ConstantExpr::getCast(Instruction::BitCast, called_fn, get_param_arg_tys[0]);
  Value* arg_ptr = builder_->CreateCall(get_param_buffer, {callee, builder_->CreateLoad(grid), builder_->CreateLoad(block), builder_->getInt32(0)});
  // forward-declare cudaLaunchDeviceV2
  std::vector<Type*> launch_device_arg_tys = {get_param_ty->getReturnType(), builder_->getInt64Ty()};
  FunctionType* launch_device_ty = FunctionType::get(builder_->getInt32Ty(), launch_device_arg_tys, false);
  Function* launch_device = get_or_declare("cudaLaunchDeviceV2", launch_device_ty);
  // skip the launch when the parameter buffer is null
  Value* do_not_launch = builder_->CreateICmpEQ(builder_->CreatePtrToInt(arg_ptr, builder_->getInt64Ty()),
                                                builder_->getInt64(0));
  BasicBlock* launch2_bb = BasicBlock::Create(builder_->getContext(), "launch2", launch_done_bb->getParent(), launch_done_bb);
  builder_->CreateCondBr(do_not_launch, launch_done_bb, launch2_bb);
  builder_->SetInsertPoint(launch2_bb);

  // copy the arguments into the parameter buffer
  unsigned addr_space = arg_ptr->getType()->getPointerAddressSpace();
  unsigned off = 0;
  unsigned last_size = 0;
  for(ir::value* arg: launch->get_values()){
    Value* curr_arg = vals_[arg][{}];
    Type* curr_arg_ty = curr_arg->getType();
    // handle struct alignment: each argument is aligned to its own size
    off += last_size;
    unsigned size = curr_arg_ty->isPointerTy() ? 8 : curr_arg_ty->getPrimitiveSizeInBits() / 8;
    // a zero byte size (sub-byte or non-primitive type) would make the
    // rounding below divide by zero, so reject it explicitly
    if(size == 0)
      throw std::runtime_error("launch: cannot compute the byte size of an argument");
    off = (off + size - 1) / size * size;
    // get pointer to current arg
    Value* curr_arg_ptr = builder_->CreateGEP(arg_ptr, builder_->getInt32(off));
    curr_arg_ptr = builder_->CreateBitCast(curr_arg_ptr, curr_arg_ty->getPointerTo(addr_space));
    // store arg
    builder_->CreateStore(curr_arg, curr_arg_ptr);
    last_size = size;
  }
  builder_->CreateCall(launch_device, {arg_ptr, builder_->getInt64(0)});
  builder_->CreateBr(launch_done_bb);
  // done
  builder_->SetInsertPoint(launch_done_bb);
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \brief Code Generation for `binary_operator`
|
* \brief Code Generation for `binary_operator`
|
||||||
*/
|
*/
|
||||||
@@ -311,6 +396,7 @@ void generator::visit_binary_operator(ir::binary_operator*x) {
|
|||||||
default: throw std::runtime_error("unreachable switch");
|
default: throw std::runtime_error("unreachable switch");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
// x->print(std::cout);
|
||||||
for(indices_t idx: idxs_.at(x)){
|
for(indices_t idx: idxs_.at(x)){
|
||||||
Value *lhs = vals_[x->get_operand(0)][idx];
|
Value *lhs = vals_[x->get_operand(0)][idx];
|
||||||
Value *rhs = vals_[x->get_operand(1)][idx];
|
Value *rhs = vals_[x->get_operand(1)][idx];
|
||||||
@@ -852,6 +938,31 @@ void generator::visit_masked_store_inst(ir::masked_store_inst* x) {
|
|||||||
visit_store_inst(x);
|
visit_store_inst(x);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// --
|
||||||
|
|
||||||
|
/**
 * \brief Code Generation for `extract_value`
 */
void generator::visit_extract_value_inst(ir::extract_value_inst *x) {
  const auto& indices = idxs_.at(x);
  // operand 0 is the aggregate; get_idx() is the position of the member to read
  ir::value* aggregate = x->get_operand(0);
  unsigned member = x->get_idx();
  // emit one extractvalue per index registered for `x`
  for(const auto& idx: indices)
    vals_[x][idx] = builder_->CreateExtractValue(vals_[aggregate][idx], {member});
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
 * \brief Code Generation for `insert_value`
 */
void generator::visit_insert_value_inst(ir::insert_value_inst *x){
  const auto& indices = idxs_.at(x);
  // operand 0 is the aggregate, operand 1 the value written at position get_idx()
  ir::value* aggregate = x->get_operand(0);
  ir::value* inserted = x->get_operand(1);
  unsigned member = x->get_idx();
  // emit one insertvalue per index registered for `x`
  for(const auto& idx: indices)
    vals_[x][idx] = builder_->CreateInsertValue(vals_[aggregate][idx], vals_[inserted][idx],{member});
}
|
||||||
|
|
||||||
|
// --
|
||||||
/**
|
/**
|
||||||
* \brief Code Generation for `cat`
|
* \brief Code Generation for `cat`
|
||||||
*/
|
*/
|
||||||
@@ -2686,7 +2797,8 @@ void generator::visit_make_range(ir::make_range* x) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void generator::visit_undef_value(ir::undef_value *x) {
|
void generator::visit_undef_value(ir::undef_value *x) {
|
||||||
Type* ty = cvt(x->get_type()->get_scalar_ty());
|
ir::type* sca_ty = x->get_type()->get_scalar_ty();
|
||||||
|
Type* ty = cvt(sca_ty);
|
||||||
for(indices_t idx: idxs_.at(x))
|
for(indices_t idx: idxs_.at(x))
|
||||||
vals_[x][idx] = llvm::UndefValue::get(ty);
|
vals_[x][idx] = llvm::UndefValue::get(ty);
|
||||||
}
|
}
|
||||||
@@ -2713,8 +2825,7 @@ void generator::visit_alloc_const(ir::alloc_const *alloc) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
void generator::visit_function(ir::function* fn) {
|
void generator::forward_declare(ir::function* fn){
|
||||||
LLVMContext &ctx = builder_->getContext();
|
|
||||||
FunctionType *fn_ty = (FunctionType*)cvt(fn->get_fn_type());
|
FunctionType *fn_ty = (FunctionType*)cvt(fn->get_fn_type());
|
||||||
if(!tgt_->is_gpu()){
|
if(!tgt_->is_gpu()){
|
||||||
Type *fn_ret_ty = fn_ty->getReturnType();
|
Type *fn_ret_ty = fn_ty->getReturnType();
|
||||||
@@ -2727,6 +2838,18 @@ void generator::visit_function(ir::function* fn) {
|
|||||||
fn_ty = FunctionType::get(fn_ret_ty, fn_args_ty, false);
|
fn_ty = FunctionType::get(fn_ret_ty, fn_args_ty, false);
|
||||||
}
|
}
|
||||||
Function *ret = Function::Create(fn_ty, Function::ExternalLinkage, fn->get_name(), mod_);
|
Function *ret = Function::Create(fn_ty, Function::ExternalLinkage, fn->get_name(), mod_);
|
||||||
|
fns_[fn] = ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
void generator::visit_function(ir::function* fn) {
|
||||||
|
idxs_.clear();
|
||||||
|
vals_.clear();
|
||||||
|
seen_.clear();
|
||||||
|
LLVMContext &ctx = builder_->getContext();
|
||||||
|
|
||||||
|
Function* ret = fns_[fn];
|
||||||
|
|
||||||
|
|
||||||
// set attributes
|
// set attributes
|
||||||
for(auto attr_pair: fn->attrs()){
|
for(auto attr_pair: fn->attrs()){
|
||||||
unsigned id = attr_pair.first;
|
unsigned id = attr_pair.first;
|
||||||
@@ -2751,7 +2874,8 @@ void generator::visit_function(ir::function* fn) {
|
|||||||
for(unsigned i = 0; i < fn->args().size(); i++)
|
for(unsigned i = 0; i < fn->args().size(); i++)
|
||||||
vals_[fn->args()[i]][{}] = &*(ret->arg_begin() + i);
|
vals_[fn->args()[i]][{}] = &*(ret->arg_begin() + i);
|
||||||
// create blocks
|
// create blocks
|
||||||
for(ir::basic_block *block: fn->blocks()) {
|
auto blocks = ir::cfg::reverse_post_order(fn);
|
||||||
|
for(ir::basic_block *block: blocks) {
|
||||||
BasicBlock *dst_block = BasicBlock::Create(ctx, block->get_name(), ret);
|
BasicBlock *dst_block = BasicBlock::Create(ctx, block->get_name(), ret);
|
||||||
bbs_[block] = dst_block;
|
bbs_[block] = dst_block;
|
||||||
}
|
}
|
||||||
@@ -2761,7 +2885,7 @@ void generator::visit_function(ir::function* fn) {
|
|||||||
visit_layout(x.second);
|
visit_layout(x.second);
|
||||||
}
|
}
|
||||||
// generate LLVM-IR code
|
// generate LLVM-IR code
|
||||||
for(ir::basic_block *block: fn->blocks())
|
for(ir::basic_block *block: blocks)
|
||||||
visit_basic_block(block);
|
visit_basic_block(block);
|
||||||
// finalize
|
// finalize
|
||||||
finalize_function(fn);
|
finalize_function(fn);
|
||||||
@@ -2982,10 +3106,12 @@ void generator::visit_layout_shared(analysis::shared_layout* layout) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void generator::visit_basic_block(ir::basic_block * block) {
|
void generator::visit_basic_block(ir::basic_block * block) {
|
||||||
|
|
||||||
BasicBlock *parent = bbs_[block];
|
BasicBlock *parent = bbs_[block];
|
||||||
builder_->SetInsertPoint(parent);
|
builder_->SetInsertPoint(parent);
|
||||||
for(ir::instruction *i: block->get_inst_list())
|
for(ir::instruction *i: block->get_inst_list()){
|
||||||
visit_value(i);
|
visit_value(i);
|
||||||
|
}
|
||||||
// Update ir bb -> llvm bb mapping
|
// Update ir bb -> llvm bb mapping
|
||||||
bbs_[block] = builder_->GetInsertBlock();
|
bbs_[block] = builder_->GetInsertBlock();
|
||||||
}
|
}
|
||||||
@@ -3168,6 +3294,12 @@ void generator::finalize_phi_node(ir::phi_node *x) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * \brief Packed LLVM struct type for value `i` (not implemented)
 *
 * The original body checked that `i` has a scanline layout and then fell off
 * the end of a non-void function, which is undefined behaviour for any
 * caller. Until the struct type is actually built, the failure is explicit.
 *
 * \throws std::runtime_error always
 */
StructType* generator::packed_type(ir::value* i){
  // precondition kept from the original stub: `i` must use a scanline layout
  auto* layout = dynamic_cast<analysis::scanline_layout*>(layouts_->get(i));
  assert(layout);
  (void)layout; // avoids an unused-variable warning when asserts are compiled out
  throw std::runtime_error("generator::packed_type is not implemented");
}
|
||||||
|
|
||||||
void generator::visit(ir::module &src, llvm::Module &dst) {
|
void generator::visit(ir::module &src, llvm::Module &dst) {
|
||||||
mod_ = &dst;
|
mod_ = &dst;
|
||||||
ctx_ = &dst.getContext();
|
ctx_ = &dst.getContext();
|
||||||
@@ -3184,7 +3316,16 @@ void generator::visit(ir::module &src, llvm::Module &dst) {
|
|||||||
nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3);
|
nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3);
|
||||||
shmem_ = bit_cast(sh_mem_array, ptr_ty);
|
shmem_ = bit_cast(sh_mem_array, ptr_ty);
|
||||||
}
|
}
|
||||||
|
// instantiate device functions
|
||||||
|
// for(ir::function *fn: src.get_function_list())
|
||||||
|
// for(ir::basic_block *bb: fn->blocks())
|
||||||
|
// for(ir::instruction *i: bb->get_inst_list())
|
||||||
|
// if(auto *call = dynamic_cast<ir::call_inst*>(i)){
|
||||||
|
// std::cout << "call??" << std::endl;
|
||||||
|
// }
|
||||||
// visit functions
|
// visit functions
|
||||||
|
for(ir::function *fn: src.get_function_list())
|
||||||
|
forward_declare(fn);
|
||||||
for(ir::function *fn: src.get_function_list())
|
for(ir::function *fn: src.get_function_list())
|
||||||
visit_function(fn);
|
visit_function(fn);
|
||||||
}
|
}
|
||||||
|
@@ -3,6 +3,7 @@
|
|||||||
#include "triton/ir/basic_block.h"
|
#include "triton/ir/basic_block.h"
|
||||||
#include "triton/ir/module.h"
|
#include "triton/ir/module.h"
|
||||||
#include "triton/ir/utils.h"
|
#include "triton/ir/utils.h"
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
namespace triton {
|
namespace triton {
|
||||||
namespace codegen{
|
namespace codegen{
|
||||||
@@ -28,6 +29,8 @@ void dce::run(ir::module &mod) {
|
|||||||
case ir::INST_ATOMIC_CAS:
|
case ir::INST_ATOMIC_CAS:
|
||||||
case ir::INST_ATOMIC_RMW:
|
case ir::INST_ATOMIC_RMW:
|
||||||
case ir::INST_ATOMIC_EXCH:
|
case ir::INST_ATOMIC_EXCH:
|
||||||
|
case ir::INST_CALL:
|
||||||
|
case ir::INST_LAUNCH:
|
||||||
case ir::INST_BARRIER: {
|
case ir::INST_BARRIER: {
|
||||||
work_list.push_back(i);
|
work_list.push_back(i);
|
||||||
marked.insert(i);
|
marked.insert(i);
|
||||||
@@ -65,6 +68,7 @@ void dce::run(ir::module &mod) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// delete
|
// delete
|
||||||
for(ir::instruction* i: to_delete)
|
for(ir::instruction* i: to_delete)
|
||||||
i->erase_from_parent();
|
i->erase_from_parent();
|
||||||
|
127
lib/codegen/transform/inline.cc
Normal file
127
lib/codegen/transform/inline.cc
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
#include <iostream>
|
||||||
|
#include "triton/codegen/transform/inline.h"
|
||||||
|
#include "triton/ir/module.h"
|
||||||
|
#include "triton/ir/function.h"
|
||||||
|
#include "triton/ir/utils.h"
|
||||||
|
|
||||||
|
namespace triton{
|
||||||
|
namespace codegen{
|
||||||
|
namespace transform{
|
||||||
|
|
||||||
|
|
||||||
|
/**
 * \brief Orders functions by their position in the parent module's function list
 *
 * Returns true when `x` appears before `y` in the function list of `x`'s
 * parent module. A function absent from that list compares as if it were
 * placed at the end.
 */
bool fncmp::operator()(ir::function* x, ir::function* y) const {
  // bind by reference: plain `auto` would copy the whole list on every comparison
  const auto& fn_list = x->get_parent()->get_function_list();
  auto x_pos = std::find(fn_list.begin(), fn_list.end(), x);
  auto y_pos = std::find(fn_list.begin(), fn_list.end(), y);
  return x_pos < y_pos;
}
|
||||||
|
|
||||||
|
// Replaces the call instruction `callsite` by a copy of the body of `fn`.
//
// The block holding `callsite` is split right before the call. Every block of
// `fn` is cloned into the calling function, call operands and previously
// cloned instructions are substituted into the clones, and each `ret` is
// turned into a branch to the block that resumes the caller. Uses of the call
// are redirected to a phi node created in that resume block (or directly to
// the returned value when there is exactly one incoming value).
//
//   fn         function whose body is inlined
//   callsite   call instruction being replaced; `erase_from_parent()` is
//              called on it
//   builder    used to create the phi and insert the clones; its insert point
//              is left on the resume block when the function returns
//   callsites  not read by this function
void inliner::do_inline(ir::function* fn, ir::call_inst* callsite, ir::builder& builder,
                        std::list<ir::call_inst*>& callsites){
  ir::basic_block* parent_block = callsite->get_parent();
  ir::function* parent_fn = parent_block->get_parent();
  // the parent block is split into block A and block B:
  //   - block A (`new_blocks[0]`) is the entry block of the inlined function
  //   - block B (`exit`) resumes execution of the parent function
  ir::basic_block* entry = parent_block->split_before(callsite, fn->get_name());
  ir::basic_block* exit = entry->get_successors()[0];
  std::vector<ir::basic_block*> new_blocks = {entry};
  for(size_t i = 1; i < fn->blocks().size(); i++){
    ir::basic_block* block = fn->blocks()[i];
    ir::context& ctx = block->get_context();
    const std::string name = block->get_parent()->get_name() + "_" + block->get_name();
    new_blocks.push_back(ir::basic_block::create(ctx, name, parent_fn));
  }
  // a phi node holds the return values of the inlined function
  if(exit->get_inst_list().empty())
    builder.set_insert_point(exit);
  else
    builder.set_insert_point(exit->get_first_non_phi());
  ir::phi_node* exit_val = builder.create_phi(fn->get_fn_type()->get_return_ty(), 0);
  callsite->replace_all_uses_with(exit_val);
  // Snapshot the values `fn` is called with *before* the call is erased, so
  // that nothing below has to read operands of an already-erased instruction.
  std::vector<ir::value*> tgt_args(callsite->op_begin(), callsite->op_end());
  callsite->erase_from_parent();
  // Actually generate the instructions:
  //   - Remove the branch created by basic_block::split_before
  //   - Clone all instructions
  //   - Replace `ret` with incoming nodes to `exit_val` and branches to `exit`
  ir::instruction* terminator = new_blocks[0]->get_inst_list().back();
  terminator->erase_from_parent();
  // old instruction -> clone, filled in as clones are created
  std::map<ir::instruction*, ir::instruction*> inst_map;
  // formal argument of `fn` -> value passed at the call site
  std::map<ir::argument*, ir::value*> arg_map;
  for(size_t k = 0; k < fn->args().size(); k++)
    arg_map[fn->args()[k]] = tgt_args.at(k);
  for(size_t i = 0; i < new_blocks.size(); i++){
    ir::basic_block* old_block = fn->blocks()[i];
    ir::basic_block* new_block = new_blocks[i];
    builder.set_insert_point(new_block);
    for(ir::instruction* old_inst: old_block->get_inst_list()){
      // clone instruction
      ir::instruction* new_inst = old_inst->clone();
      // replace basic block
      for(size_t k = 0; k < new_blocks.size(); k++)
        new_inst->replace_uses_of_with(fn->blocks()[k], new_blocks[k]);
      // replace values
      for(size_t k = 0; k < new_inst->get_num_operands(); k++){
        ir::value* op = new_inst->get_operand(k);
        if(auto arg_op = dynamic_cast<ir::argument*>(op))
          new_inst->set_operand(k, arg_map.at(arg_op));
        // operands whose definition has not been cloned yet are left as-is
        if(auto inst_op = dynamic_cast<ir::instruction*>(op))
          if(inst_map.find(inst_op) != inst_map.end())
            new_inst->set_operand(k, inst_map.at(inst_op));
      }
      // `ret` instruction is a special case:
      // instead of returning we need to branch to after the function call
      if(ir::return_inst* ret = dynamic_cast<ir::return_inst*>(new_inst)){
        if(ir::value* ret_val = ret->get_return_value())
          exit_val->add_incoming(ret_val, new_block);
        new_inst = ir::branch_inst::create(exit);
      }
      inst_map[old_inst] = new_inst;
      builder.insert(new_inst);
    }
  }
  // a single return site needs no merge: use the returned value directly
  if(exit_val->get_num_incoming() == 1)
    exit_val->replace_all_uses_with(exit_val->get_incoming_value(0));
  // done -- make sure insert point is properly set to exit block
  builder.set_insert_point(exit);
}
|
||||||
|
|
||||||
|
// Inlines every call instruction of `mod`, one sweep at a time, until a sweep
// finds no call. During each sweep, non-kernel functions that are not the
// target of any call are removed from their parent module.
void inliner::run(ir::module &mod) {
  for(;;){
    // number of call instructions targeting each function, seeded with zero
    // so that functions that are never called still get an entry
    std::map<ir::function*, size_t> num_calls;
    for(ir::function* fn: mod.get_function_list())
      num_calls[fn] = 0;

    // gather all call sites
    std::list<ir::call_inst*> pending;
    for(ir::function* fn: mod.get_function_list())
      for(ir::basic_block* bb: fn->blocks())
        for(ir::instruction* inst: bb->get_inst_list()){
          ir::call_inst* call = dynamic_cast<ir::call_inst*>(inst);
          if(!call)
            continue;
          pending.push_back(call);
          ++num_calls[call->get_fn()];
        }

    // drop device functions nobody calls
    for(auto& entry: num_calls){
      ir::function* candidate = entry.first;
      if(candidate->get_is_kernel())
        continue;
      if(entry.second == 0)
        candidate->get_parent()->remove_function(candidate);
    }

    // nothing left to inline
    if(pending.empty())
      return;

    for(ir::call_inst* call: pending)
      do_inline(call->get_fn(), call, mod.get_builder(), pending);
  }
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@@ -178,6 +178,27 @@ bool peephole::rewrite_mult(ir::instruction *value, ir::builder& builder) {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Forwards the element read by an extract_value instruction when it was
// written by an insert_value instruction in the chain feeding its aggregate.
//
// Starting from the aggregate operand of `value`, follows operand 0 through
// consecutive insert_value instructions. At the first one whose index equals
// the extracted index, all uses of the extract are replaced by the inserted
// element (operand 1) and true is returned. Returns false when `value` is not
// an extract_value instruction or no matching insert is found.
// `builder` is not used.
bool peephole::rewrite_insert_extract(ir::instruction *value, ir::builder& builder){
  auto extracted = dynamic_cast<ir::extract_value_inst*>(value);
  if(!extracted)
    return false;
  const size_t extract_idx = extracted->get_idx();
  // The chain is advanced in exactly one place (the loop increment); the
  // previous version performed the same dynamic_cast twice per iteration.
  for(auto insert = dynamic_cast<ir::insert_value_inst*>(extracted->get_operand(0));
      insert;
      insert = dynamic_cast<ir::insert_value_inst*>(insert->get_operand(0))){
    if(insert->get_idx() == extract_idx){
      extracted->replace_all_uses_with(insert->get_operand(1));
      return true;
    }
  }
  return false;
}
|
||||||
|
|
||||||
|
|
||||||
bool peephole::rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder) {
|
bool peephole::rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder) {
|
||||||
auto x = dynamic_cast<ir::getelementptr_inst*>(value);
|
auto x = dynamic_cast<ir::getelementptr_inst*>(value);
|
||||||
@@ -291,6 +312,7 @@ void peephole::run(ir::module &mod) {
|
|||||||
was_modified = was_modified || rewrite_mult(i, builder);
|
was_modified = was_modified || rewrite_mult(i, builder);
|
||||||
// was_modified = was_modified || rewrite_cts_cfs(i, builder);
|
// was_modified = was_modified || rewrite_cts_cfs(i, builder);
|
||||||
// was_modified = was_modified || rewrite_trans_phi(i, builder);
|
// was_modified = was_modified || rewrite_trans_phi(i, builder);
|
||||||
|
was_modified = was_modified || rewrite_insert_extract(i, builder);
|
||||||
was_modified = was_modified || rewrite_unit_red(i, builder);
|
was_modified = was_modified || rewrite_unit_red(i, builder);
|
||||||
was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder);
|
was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder);
|
||||||
// TODO: DOESN'T WORK FOR VECTORIZED MASKED LOAD
|
// TODO: DOESN'T WORK FOR VECTORIZED MASKED LOAD
|
||||||
|
@@ -134,6 +134,7 @@ void pipeline::run(ir::module &mod) {
|
|||||||
ir::builder &builder = mod.get_builder();
|
ir::builder &builder = mod.get_builder();
|
||||||
const int num_stages = num_stages_;
|
const int num_stages = num_stages_;
|
||||||
std::vector<std::pair<ir::phi_node*, std::vector<ir::value*>>> preheader_loads; // Used to reorder loads
|
std::vector<std::pair<ir::phi_node*, std::vector<ir::value*>>> preheader_loads; // Used to reorder loads
|
||||||
|
|
||||||
for(auto info: to_pipeline){
|
for(auto info: to_pipeline){
|
||||||
ir::load_inst* load = info.load;
|
ir::load_inst* load = info.load;
|
||||||
ir::phi_node* ptr = info.ptr;
|
ir::phi_node* ptr = info.ptr;
|
||||||
|
@@ -138,6 +138,7 @@ CUDA_DEFINE3(CUresult, cuDeviceGetAttribute, int *, CUdevice_attribute, CUdevice
|
|||||||
CUDA_DEFINE1(CUresult, cuDeviceGetCount, int*)
|
CUDA_DEFINE1(CUresult, cuDeviceGetCount, int*)
|
||||||
|
|
||||||
// link management
|
// link management
|
||||||
|
CUDA_DEFINE6(CUresult, cuLinkAddFile_v2, CUlinkState, CUjitInputType, const char *, unsigned int , CUjit_option *, void **);
|
||||||
CUDA_DEFINE8(CUresult, cuLinkAddData_v2, CUlinkState, CUjitInputType, void*, size_t, const char*, unsigned int, CUjit_option*, void**);
|
CUDA_DEFINE8(CUresult, cuLinkAddData_v2, CUlinkState, CUjitInputType, void*, size_t, const char*, unsigned int, CUjit_option*, void**);
|
||||||
CUDA_DEFINE4(CUresult, cuLinkCreate_v2, unsigned int, CUjit_option*, void**, CUlinkState*);
|
CUDA_DEFINE4(CUresult, cuLinkCreate_v2, unsigned int, CUjit_option*, void**, CUlinkState*);
|
||||||
CUDA_DEFINE1(CUresult, cuLinkDestroy, CUlinkState);
|
CUDA_DEFINE1(CUresult, cuLinkDestroy, CUlinkState);
|
||||||
|
@@ -90,7 +90,7 @@ void check(CUresult err)
|
|||||||
case CUDA_ERROR_NOT_PERMITTED : throw not_permitted();
|
case CUDA_ERROR_NOT_PERMITTED : throw not_permitted();
|
||||||
case CUDA_ERROR_NOT_SUPPORTED : throw not_supported();
|
case CUDA_ERROR_NOT_SUPPORTED : throw not_supported();
|
||||||
case CUDA_ERROR_UNKNOWN : throw unknown();
|
case CUDA_ERROR_UNKNOWN : throw unknown();
|
||||||
default : throw unknown();
|
default : throw std::runtime_error("unimplemented code: " + std::to_string(err));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -174,6 +174,7 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
|
|||||||
init_llvm();
|
init_llvm();
|
||||||
// verify and store llvm
|
// verify and store llvm
|
||||||
llvm::legacy::PassManager pm;
|
llvm::legacy::PassManager pm;
|
||||||
|
// pm.add(llvm::createPrintModulePass(llvm::outs()));
|
||||||
pm.add(llvm::createVerifierPass());
|
pm.add(llvm::createVerifierPass());
|
||||||
pm.run(*module);
|
pm.run(*module);
|
||||||
// module->print(llvm::outs(), nullptr);
|
// module->print(llvm::outs(), nullptr);
|
||||||
@@ -213,6 +214,7 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas, int cc) {
|
std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas, int cc) {
|
||||||
// compile ptx with ptxas
|
// compile ptx with ptxas
|
||||||
char _fsrc[L_tmpnam];
|
char _fsrc[L_tmpnam];
|
||||||
|
@@ -1,3 +1,5 @@
|
|||||||
|
#include <iostream>
|
||||||
|
#include <algorithm>
|
||||||
#include "triton/ir/basic_block.h"
|
#include "triton/ir/basic_block.h"
|
||||||
#include "triton/ir/instructions.h"
|
#include "triton/ir/instructions.h"
|
||||||
#include "triton/ir/type.h"
|
#include "triton/ir/type.h"
|
||||||
@@ -9,23 +11,68 @@ namespace ir {
|
|||||||
class phi_node;
|
class phi_node;
|
||||||
|
|
||||||
|
|
||||||
basic_block::basic_block(context &ctx, const std::string &name, function *parent):
|
basic_block::basic_block(context &ctx, const std::string &name, function *parent, basic_block* next):
|
||||||
value(type::get_label_ty(ctx), name), ctx_(ctx), parent_(parent) {
|
value(type::get_label_ty(ctx), name), ctx_(ctx), parent_(parent) {
|
||||||
if(parent_)
|
if(parent_)
|
||||||
parent_->insert_block(this);
|
parent_->insert_block(this, next);
|
||||||
}
|
}
|
||||||
|
|
||||||
basic_block* basic_block::create(context &ctx, const std::string &name, function *parent){
|
basic_block* basic_block::create(context &ctx, const std::string &name, function *parent, basic_block* next){
|
||||||
return new basic_block(ctx, name, parent);
|
return new basic_block(ctx, name, parent, next);
|
||||||
}
|
}
|
||||||
|
|
||||||
void basic_block::add_predecessor(basic_block *pred) {
|
void basic_block::replace_phi_uses_with(basic_block* before, basic_block* after) {
|
||||||
preds_.push_back(pred);
|
for(ir::instruction* i: inst_list_){
|
||||||
if(pred)
|
auto* curr_phi = dynamic_cast<ir::phi_node*>(i);
|
||||||
pred->succs_.push_back(this);
|
if(!curr_phi)
|
||||||
|
break;
|
||||||
|
curr_phi->replace_uses_of_with(before, after);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Makes this block the parent of `i`, then places `i` at the end of the
// block's instruction list.
void basic_block::append_instruction(ir::instruction* i){
  i->set_parent(this);
  inst_list_.insert(inst_list_.end(), i);
}
|
||||||
|
|
||||||
|
// Splits this block in two right before `loc` and returns the newly created
// head block.
//
// The head block is created in the same function with `this` passed as its
// `next` block. It takes over this block's current name and receives every
// instruction that precedes `loc`; this block is renamed "after_" + `name`
// and keeps `loc` and everything after it. Terminators of the predecessors
// are retargeted from this block to the head, and the head ends with an
// unconditional branch to this block.
basic_block* basic_block::split_before(ir::instruction* loc, const std::string& name) {
  basic_block* head = basic_block::create(ctx_, name, parent_, this);
  head->set_name(get_name());
  set_name("after_" + name);

  // move the instructions in [begin, loc) to the head block
  auto split_pos = std::find(inst_list_.begin(), inst_list_.end(), loc);
  head->get_inst_list().splice(head->get_inst_list().begin(), inst_list_,
                               inst_list_.begin(), split_pos);
  for(ir::instruction* moved: head->get_inst_list())
    moved->set_parent(head);

  // the predecessors of `this` become the predecessors of `head`
  for(ir::basic_block* pred: get_predecessors()){
    auto* last = pred->get_inst_list().back();
    auto* term = dynamic_cast<ir::terminator_inst*>(last);
    assert(term);
    term->replace_uses_of_with(this, head);
    replace_phi_uses_with(pred, head);
  }

  // fall through from the head into the remainder
  head->append_instruction(branch_inst::create(this));
  return head;
}
|
||||||
|
|
||||||
|
std::vector<basic_block*> basic_block::get_predecessors() const {
|
||||||
|
std::vector<basic_block*> ret;
|
||||||
|
for(ir::user* u: users_)
|
||||||
|
if(auto term = dynamic_cast<ir::terminator_inst*>(u))
|
||||||
|
ret.push_back(term->get_parent());
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<basic_block*> basic_block::get_successors() const {
|
||||||
|
std::vector<basic_block*> ret;
|
||||||
|
for(ir::instruction* i: inst_list_)
|
||||||
|
for(ir::value* v: i->ops())
|
||||||
|
if(auto block = dynamic_cast<ir::basic_block*>(v))
|
||||||
|
ret.push_back(block);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
basic_block::iterator basic_block::get_first_non_phi(){
|
basic_block::iterator basic_block::get_first_non_phi(){
|
||||||
auto it = begin();
|
auto it = begin();
|
||||||
|
@@ -117,13 +117,10 @@ type *builder::get_double_ty()
|
|||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
value* builder::create_br(basic_block *dest){
|
value* builder::create_br(basic_block *dest){
|
||||||
dest->add_predecessor(block_);
|
|
||||||
return insert(branch_inst::create(dest));
|
return insert(branch_inst::create(dest));
|
||||||
}
|
}
|
||||||
|
|
||||||
value* builder::create_cond_br(value *cond, basic_block *if_dest, basic_block *else_dest){
|
value* builder::create_cond_br(value *cond, basic_block *if_dest, basic_block *else_dest){
|
||||||
if_dest->add_predecessor(block_);
|
|
||||||
else_dest->add_predecessor(block_);
|
|
||||||
return insert(branch_inst::create(cond, if_dest, else_dest));
|
return insert(branch_inst::create(cond, if_dest, else_dest));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -131,6 +128,10 @@ value *builder::create_ret_void() {
|
|||||||
return insert(return_inst::create(ctx_));
|
return insert(return_inst::create(ctx_));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Creates a return instruction for `val` and inserts it through the builder.
value *builder::create_ret(value* val) {
  auto* ret = return_inst::create(ctx_, val);
  return insert(ret);
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// cast instructions
|
// cast instructions
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@@ -163,6 +164,19 @@ phi_node* builder::create_phi(type *ty, unsigned num_reserved){
|
|||||||
return insert(phi_node::create(ty, num_reserved));
|
return insert(phi_node::create(ty, num_reserved));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// call instructions
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// Creates a call to `fn` with operands `args` and inserts it through the
// builder.
value *builder::create_call(function* fn, const std::vector<value*>& args){
  auto* call = call_inst::create(fn, args);
  return insert(call);
}
|
||||||
|
|
||||||
|
// Creates a launch instruction for `fn` with the given arguments, grid and
// `num_warps`, and inserts it through the builder.
value* builder::create_launch(function* fn, const std::vector<value*>& args, const std::vector<value*>& grid, value* num_warps){
  auto* launch = launch_inst::create(fn, args, grid, num_warps);
  return insert(launch);
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// binary float instructions
|
// binary float instructions
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@@ -307,6 +321,19 @@ value *builder::create_masked_store(value *ptr, value *val, value *mask){
|
|||||||
return insert(masked_store_inst::create(ptr, val, mask));
|
return insert(masked_store_inst::create(ptr, val, mask));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// struct instructions
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
|
||||||
|
// Struct instructions
|
||||||
|
// Creates an insert_value instruction placing `elt` at position `idx` of the
// aggregate `val`, and inserts it through the builder.
value *builder::create_insert_value(value* val, value *elt, size_t idx){
  auto* inserted = insert_value_inst::create(val, elt, idx);
  return insert(inserted);
}
|
||||||
|
|
||||||
|
// Creates an extract_value instruction reading position `idx` of the
// aggregate `val`, and inserts it through the builder.
value *builder::create_extract_value(value* val, size_t idx) {
  auto* extracted = extract_value_inst::create(val, idx);
  return insert(extracted);
}
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// block instructions
|
// block instructions
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@@ -865,6 +865,9 @@ ir::value *dispatch::clock(ir::builder *builder) {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Control Flow
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
//
|
//
|
||||||
|
|
||||||
ir::value *dispatch::multiple_of(ir::value *x, int value, ir::builder *){
|
ir::value *dispatch::multiple_of(ir::value *x, int value, ir::builder *){
|
||||||
|
@@ -33,8 +33,10 @@ void argument::accept(visitor *v) {
|
|||||||
/* function */
|
/* function */
|
||||||
function::function(function_type *ty, linkage_types_t linkage,
|
function::function(function_type *ty, linkage_types_t linkage,
|
||||||
const std::string &name, module *parent)
|
const std::string &name, module *parent)
|
||||||
: global_object(ty, 0, linkage, name), parent_(parent), fn_ty_(ty) {
|
: global_object(ty, 0, linkage, name), parent_(parent), fn_ty_(ty), is_kernel_(false) {
|
||||||
unsigned num_params = fn_ty_->get_num_params();
|
unsigned num_params = fn_ty_->get_num_params();
|
||||||
|
if(parent)
|
||||||
|
parent->push_function(this);
|
||||||
// skip if no parameter
|
// skip if no parameter
|
||||||
if(num_params == 0)
|
if(num_params == 0)
|
||||||
return;
|
return;
|
||||||
@@ -44,8 +46,6 @@ function::function(function_type *ty, linkage_types_t linkage,
|
|||||||
type *param_ty = fn_ty_->get_param_ty(i);
|
type *param_ty = fn_ty_->get_param_ty(i);
|
||||||
args_[i] = argument::create(param_ty, "", this, i);
|
args_[i] = argument::create(param_ty, "", this, i);
|
||||||
}
|
}
|
||||||
if(parent)
|
|
||||||
parent->push_function(this);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* basic block */
|
/* basic block */
|
||||||
|
@@ -5,6 +5,7 @@
|
|||||||
#include "triton/ir/instructions.h"
|
#include "triton/ir/instructions.h"
|
||||||
#include "triton/ir/constant.h"
|
#include "triton/ir/constant.h"
|
||||||
#include "triton/ir/type.h"
|
#include "triton/ir/type.h"
|
||||||
|
#include "triton/ir/function.h"
|
||||||
|
|
||||||
namespace triton{
|
namespace triton{
|
||||||
namespace ir{
|
namespace ir{
|
||||||
@@ -79,6 +80,70 @@ phi_node* phi_node::create(type *ty, unsigned num_reserved, const std::string &n
|
|||||||
return new phi_node(ty, num_reserved, name, next);
|
return new phi_node(ty, num_reserved, name, next);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// call_inst classes
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// Textual form of the instruction: "call " followed by the callee's name.
std::string call_inst::repr_impl() const {
  std::string repr = "call ";
  repr += fn_->get_name();
  return repr;
}
|
||||||
|
|
||||||
|
// Builds a call to `fn`. The instruction is typed with the return type of
// `fn` and has one operand per entry of `values`, in the same order; the
// callee itself is kept in `fn_` rather than as an operand.
call_inst::call_inst(ir::function* fn, const std::vector<ir::value*>& values, const std::string& name, instruction* next)
  : instruction(fn->get_fn_type()->get_return_ty(), INST_CALL, values.size(), name, next), fn_(fn){
  size_t slot = 0;
  for(ir::value* arg: values)
    set_operand(slot++, arg);
}
|
||||||
|
|
||||||
|
// Allocates a new call_inst with `new`, forwarding all arguments to the
// constructor.
call_inst* call_inst::create(ir::function* fn, const std::vector<ir::value*>& values, const std::string &name, instruction *next) {
  call_inst* created = new call_inst(fn, values, name, next);
  return created;
}
|
||||||
|
|
||||||
|
|
||||||
|
// launch
|
||||||
|
|
||||||
|
// Builds a launch of `fn`. Operands are laid out as:
//   [0]                      the launched function
//   [val_begin, val_end)     the entries of `values`
//   [grid_begin, grid_end)   the three entries of `grid`
//   [grid_end]               `num_warps`
// Throws std::runtime_error when `grid` does not hold exactly 3 elements.
launch_inst::launch_inst(ir::function* fn, const std::vector<ir::value*>& values, const std::vector<ir::value*>& grid, ir::value* num_warps, const std::string& name, instruction* next)
  : instruction(fn->get_fn_type()->get_return_ty(), INST_LAUNCH, 1 + values.size() + grid.size() + 1, name, next){
  if(grid.size() != 3)
    throw std::runtime_error("grid must have 3 elements");
  // running index of the next operand slot to fill
  int slot = 0;
  set_operand(slot++, fn);
  val_begin = slot;
  for(ir::value* arg: values)
    set_operand(slot++, arg);
  val_end = slot;
  grid_begin = slot;
  for(ir::value* dim: grid)
    set_operand(slot++, dim);
  grid_end = slot;
  set_operand(slot++, num_warps);
}
|
||||||
|
|
||||||
|
|
||||||
|
// Returns the launched function, which the constructor stores as operand 0.
ir::function* launch_inst::get_fn() {
  ir::value* callee = get_operand(0);
  return static_cast<ir::function*>(callee);
}
|
||||||
|
|
||||||
|
std::vector<ir::value*> launch_inst::get_values() {
|
||||||
|
std::vector<ir::value*> ret;
|
||||||
|
for(int i = val_begin; i < val_end; i++)
|
||||||
|
ret.push_back(get_operand(i));
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<ir::value*> launch_inst::get_grid() {
|
||||||
|
std::vector<ir::value*> ret;
|
||||||
|
for(int i = grid_begin; i < grid_end; i++)
|
||||||
|
ret.push_back(get_operand(i));
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the operand stored right after the grid operands, i.e. the
// `num_warps` value given to the constructor.
ir::value* launch_inst::get_num_warps() {
  ir::value* warps = get_operand(grid_end);
  return warps;
}
|
||||||
|
|
||||||
|
|
||||||
|
// Allocates a new launch_inst with `new`, forwarding all arguments to the
// constructor.
launch_inst* launch_inst::create(ir::function *fn, const std::vector<ir::value *> &values, const std::vector<ir::value *> &grid, ir::value *num_warps, const std::string &name, instruction *next) {
  launch_inst* created = new launch_inst(fn, values, grid, num_warps, name, next);
  return created;
}
|
||||||
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// binary_operator classes
|
// binary_operator classes
|
||||||
@@ -324,7 +389,7 @@ cast_inst *cast_inst::create_integer_cast(value *arg, type *ty, bool is_signed,
|
|||||||
|
|
||||||
// return_inst
|
// return_inst
|
||||||
return_inst::return_inst(context &ctx, value *ret_val, instruction *next)
|
return_inst::return_inst(context &ctx, value *ret_val, instruction *next)
|
||||||
: terminator_inst(type::get_void_ty(ctx), INST_RETURN, ret_val!=nullptr, "", next){
|
: terminator_inst(ret_val?ret_val->get_type():type::get_void_ty(ctx), INST_RETURN, ret_val!=nullptr, "", next){
|
||||||
if(ret_val)
|
if(ret_val)
|
||||||
set_operand(0, ret_val);
|
set_operand(0, ret_val);
|
||||||
}
|
}
|
||||||
@@ -521,6 +586,36 @@ masked_store_inst::masked_store_inst(value *ptr, value *val, value *mask,
|
|||||||
masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask, const std::string &name, instruction *next) {
|
masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask, const std::string &name, instruction *next) {
|
||||||
return new masked_store_inst(ptr, val, mask, name, next);
|
return new masked_store_inst(ptr, val, mask, name, next);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// struct classes
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// insert value
|
||||||
|
|
||||||
|
// Builds an insert_value instruction. The result has the type of the
// aggregate `val`; operand 0 is the aggregate, operand 1 the element `elt`,
// and the target position is kept in `idx_`.
insert_value_inst::insert_value_inst(value *val, value *elt, size_t idx, const std::string& name, instruction *next)
  : instruction(val->get_type(), INST_INSERT_VALUE, 2, name, next), idx_(idx) {
  set_operand(0, val);  // aggregate
  set_operand(1, elt);  // element written at idx_
}
|
||||||
|
|
||||||
|
// Allocates a new insert_value_inst with `new`, forwarding all arguments to
// the constructor.
insert_value_inst* insert_value_inst::create(value *val, value *elt, size_t idx, const std::string& name, instruction *next){
  insert_value_inst* created = new insert_value_inst(val, elt, idx, name, next);
  return created;
}
|
||||||
|
|
||||||
|
|
||||||
|
// extract value
|
||||||
|
|
||||||
|
// Builds an extract_value instruction. The result type is obtained from the
// aggregate's type via get_struct_type(idx); operand 0 is the aggregate and
// the source position is kept in `idx_`.
extract_value_inst::extract_value_inst(value *val, size_t idx, const std::string& name, instruction *next)
  : instruction(val->get_type()->get_struct_type(idx), INST_EXTRACT_VALUE, 1, name, next), idx_(idx) {
  set_operand(0, val);  // aggregate being read
}
|
||||||
|
|
||||||
|
// Allocates a new extract_value_inst with `new`, forwarding all arguments to
// the constructor.
extract_value_inst* extract_value_inst::create(value *val, size_t idx, const std::string& name, instruction *next){
  extract_value_inst* created = new extract_value_inst(val, idx, name, next);
  return created;
}
|
||||||
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// retile_inst classes
|
// retile_inst classes
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@@ -575,6 +670,9 @@ instruction* downcast_inst::create(value *arg, const std::string &name, instruct
|
|||||||
return new downcast_inst(arg->get_type()->get_scalar_ty(), INST_DOWNCAST, arg, name, next);
|
return new downcast_inst(arg->get_type()->get_scalar_ty(), INST_DOWNCAST, arg, name, next);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// matmul_inst classes
|
// matmul_inst classes
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@@ -9,17 +9,12 @@
|
|||||||
namespace triton{
|
namespace triton{
|
||||||
namespace ir{
|
namespace ir{
|
||||||
|
|
||||||
/* Module */
|
/* */
|
||||||
module::module(const std::string &name, builder &builder)
|
value_constructor::value_constructor(ir::builder& builder): builder_(builder){
|
||||||
: name_(name), builder_(builder) {
|
|
||||||
sealed_blocks_.insert(nullptr);
|
sealed_blocks_.insert(nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
ir::builder& module::get_builder() {
|
void value_constructor::set_value(const std::string& name, ir::basic_block *block, ir::value *value){
|
||||||
return builder_;
|
|
||||||
}
|
|
||||||
|
|
||||||
void module::set_value(const std::string& name, ir::basic_block *block, ir::value *value){
|
|
||||||
values_[val_key_t{name, block}] = value;
|
values_[val_key_t{name, block}] = value;
|
||||||
auto it = metadatas_.find(name);
|
auto it = metadatas_.find(name);
|
||||||
if(auto *x = dynamic_cast<ir::instruction*>(value))
|
if(auto *x = dynamic_cast<ir::instruction*>(value))
|
||||||
@@ -29,23 +24,11 @@ void module::set_value(const std::string& name, ir::basic_block *block, ir::valu
|
|||||||
// value->set_name(name);
|
// value->set_name(name);
|
||||||
}
|
}
|
||||||
|
|
||||||
void module::set_value(const std::string& name, ir::value *value){
|
void value_constructor::set_value(const std::string& name, ir::value *value){
|
||||||
return set_value(name, builder_.get_insert_block(), value);
|
return set_value(name, builder_.get_insert_block(), value);
|
||||||
}
|
}
|
||||||
|
|
||||||
void module::set_const(const std::string& name){
|
ir::phi_node* value_constructor::make_phi(ir::type *ty, unsigned num_values, ir::basic_block *block){
|
||||||
const_.insert(name);
|
|
||||||
}
|
|
||||||
|
|
||||||
void module::set_continue_fn(std::function<ir::value*()> fn) {
|
|
||||||
continue_fn_ = fn;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::function<ir::value*()> module::get_continue_fn() {
|
|
||||||
return continue_fn_;
|
|
||||||
}
|
|
||||||
|
|
||||||
ir::phi_node* module::make_phi(ir::type *ty, unsigned num_values, ir::basic_block *block){
|
|
||||||
basic_block::iterator insert = block->get_first_non_phi();
|
basic_block::iterator insert = block->get_first_non_phi();
|
||||||
if(insert != block->end()){
|
if(insert != block->end()){
|
||||||
builder_.set_insert_point(insert);
|
builder_.set_insert_point(insert);
|
||||||
@@ -56,7 +39,7 @@ ir::phi_node* module::make_phi(ir::type *ty, unsigned num_values, ir::basic_bloc
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
ir::value *module::try_remove_trivial_phis(ir::phi_node *&phi){
|
ir::value *value_constructor::try_remove_trivial_phis(ir::phi_node *&phi){
|
||||||
// find non-self references
|
// find non-self references
|
||||||
std::set<ir::value*> non_self_ref;
|
std::set<ir::value*> non_self_ref;
|
||||||
std::copy_if(phi->ops().begin(), phi->ops().end(), std::inserter(non_self_ref, non_self_ref.begin()),
|
std::copy_if(phi->ops().begin(), phi->ops().end(), std::inserter(non_self_ref, non_self_ref.begin()),
|
||||||
@@ -69,7 +52,7 @@ ir::value *module::try_remove_trivial_phis(ir::phi_node *&phi){
|
|||||||
assert(same != nullptr);
|
assert(same != nullptr);
|
||||||
phi->replace_all_uses_with(same);
|
phi->replace_all_uses_with(same);
|
||||||
phi->erase_from_parent();
|
phi->erase_from_parent();
|
||||||
std::set<ir::user*> users = phi->get_users();
|
std::vector<ir::user*> users = phi->get_users();
|
||||||
for(ir::user* u: users)
|
for(ir::user* u: users)
|
||||||
if(auto *uphi = dynamic_cast<ir::phi_node*>(u))
|
if(auto *uphi = dynamic_cast<ir::phi_node*>(u))
|
||||||
if(uphi != phi)
|
if(uphi != phi)
|
||||||
@@ -78,7 +61,7 @@ ir::value *module::try_remove_trivial_phis(ir::phi_node *&phi){
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
ir::value *module::add_phi_operands(const std::string& name, ir::phi_node *&phi){
|
ir::value *value_constructor::add_phi_operands(const std::string& name, ir::phi_node *&phi){
|
||||||
// already initialized
|
// already initialized
|
||||||
if(phi->get_num_operands())
|
if(phi->get_num_operands())
|
||||||
return phi;
|
return phi;
|
||||||
@@ -90,12 +73,11 @@ ir::value *module::add_phi_operands(const std::string& name, ir::phi_node *&phi)
|
|||||||
return phi;
|
return phi;
|
||||||
}
|
}
|
||||||
|
|
||||||
ir::value *module::get_value_recursive(const std::string& name, ir::basic_block *block) {
|
ir::value *value_constructor::get_value_recursive(const std::string& name, ir::basic_block *block) {
|
||||||
ir::value *result;
|
ir::value *result;
|
||||||
bool is_const = const_.find(name) != const_.end();
|
auto preds = block->get_predecessors();
|
||||||
auto &preds = block->get_predecessors();
|
|
||||||
ir::type *ty = types_.at(name);
|
ir::type *ty = types_.at(name);
|
||||||
if(block && !is_const && sealed_blocks_.find(block) == sealed_blocks_.end()){
|
if(block && sealed_blocks_.find(block) == sealed_blocks_.end()){
|
||||||
incomplete_phis_[block][name] = make_phi(ty, 1, block);
|
incomplete_phis_[block][name] = make_phi(ty, 1, block);
|
||||||
result = (ir::value*)incomplete_phis_[block][name];
|
result = (ir::value*)incomplete_phis_[block][name];
|
||||||
}
|
}
|
||||||
@@ -117,10 +99,12 @@ ir::value *module::get_value_recursive(const std::string& name, ir::basic_block
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
ir::value *module::get_value(const std::string& name, ir::basic_block *block) {
|
ir::value *value_constructor::get_value(const std::string& name, ir::basic_block *block) {
|
||||||
ir::basic_block* save_block = builder_.get_insert_block();
|
ir::basic_block* save_block = builder_.get_insert_block();
|
||||||
ir::basic_block::iterator save_pt = builder_.get_insert_point();
|
ir::basic_block::iterator save_pt = builder_.get_insert_point();
|
||||||
val_key_t key(name, block);
|
val_key_t key(name, block);
|
||||||
|
// std::cout << values_.size() << std::endl;
|
||||||
|
// std::cout << name << " " << block << " " << values_.begin()->first.first << " " << values_.begin()->first.second << std::endl;
|
||||||
if(values_.find(key) != values_.end()){
|
if(values_.find(key) != values_.end()){
|
||||||
return values_.at(key);
|
return values_.at(key);
|
||||||
}
|
}
|
||||||
@@ -131,15 +115,11 @@ ir::value *module::get_value(const std::string& name, ir::basic_block *block) {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
ir::value *module::get_value(const std::string& name) {
|
ir::value *value_constructor::get_value(const std::string& name) {
|
||||||
return get_value(name, builder_.get_insert_block());
|
return get_value(name, builder_.get_insert_block());
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::string& module::get_name() {
|
void value_constructor::seal_block(ir::basic_block *block){
|
||||||
return name_;
|
|
||||||
}
|
|
||||||
|
|
||||||
void module::seal_block(ir::basic_block *block){
|
|
||||||
for(auto &x: incomplete_phis_[block]){
|
for(auto &x: incomplete_phis_[block]){
|
||||||
add_phi_operands(x.first, x.second);
|
add_phi_operands(x.first, x.second);
|
||||||
if(get_value(x.first) == x.second)
|
if(get_value(x.first) == x.second)
|
||||||
@@ -149,11 +129,40 @@ void module::seal_block(ir::basic_block *block){
|
|||||||
incomplete_phis_[block].clear();
|
incomplete_phis_[block].clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
/* Module */
|
||||||
|
|
||||||
|
module::module(const std::string &name, builder &builder)
|
||||||
|
: name_(name), builder_(builder) {
|
||||||
|
}
|
||||||
|
|
||||||
|
void module::reset_ret_ty(const std::string& name, type* ty) {
|
||||||
|
get_function(name)->get_fn_type()->reset_ret_ty(ty);
|
||||||
|
}
|
||||||
|
|
||||||
|
ir::builder& module::get_builder() {
|
||||||
|
return builder_;
|
||||||
|
}
|
||||||
|
|
||||||
|
void module::set_continue_fn(std::function<ir::value*()> fn) {
|
||||||
|
continue_fn_ = fn;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::function<ir::value*()> module::get_continue_fn() {
|
||||||
|
return continue_fn_;
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::string& module::get_name() {
|
||||||
|
return name_;
|
||||||
|
}
|
||||||
|
|
||||||
/* functions */
|
/* functions */
|
||||||
function *module::get_or_insert_function(const std::string &name, function_type *ty) {
|
function *module::get_or_insert_function(const std::string &name, function_type *ty) {
|
||||||
function *&fn = (function*&)symbols_[name];
|
function *&fn = (function*&)symbols_[name];
|
||||||
if(fn == nullptr)
|
if(fn == nullptr){
|
||||||
return fn = function::create(ty, global_value::external, name, this);
|
fn = function::create(ty, global_value::external, name, this);
|
||||||
|
}
|
||||||
return fn;
|
return fn;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -188,7 +188,26 @@ bool composite_type::index_valid(value *idx) const{
|
|||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// tile_type class
|
// struct_type class
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
struct_type::struct_type(const contained_tys_vec_t& tys, bool is_packed)
|
||||||
|
: composite_type(tys[0]->get_context(), StructTyID), is_packed_(is_packed) {
|
||||||
|
contained_tys_ = tys;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct_type* struct_type::get(const contained_tys_vec_t& tys, bool is_packed) {
|
||||||
|
assert(tys.size());
|
||||||
|
context_impl* impl = tys[0]->get_context().p_impl.get();
|
||||||
|
struct_type *& entry = impl->struct_tys[tys];
|
||||||
|
if(!entry)
|
||||||
|
entry = new struct_type(tys, is_packed);
|
||||||
|
return entry;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// block_type class
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
block_type::block_type(type *ty, const block_shapes_t &shapes)
|
block_type::block_type(type *ty, const block_shapes_t &shapes)
|
||||||
|
@@ -1,5 +1,6 @@
|
|||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
#include <algorithm>
|
||||||
#include "triton/ir/value.h"
|
#include "triton/ir/value.h"
|
||||||
#include "triton/ir/instructions.h"
|
#include "triton/ir/instructions.h"
|
||||||
|
|
||||||
@@ -17,11 +18,11 @@ value::value(type *ty, const std::string &name): ty_(ty){
|
|||||||
}
|
}
|
||||||
|
|
||||||
void value::add_use(user *arg) {
|
void value::add_use(user *arg) {
|
||||||
users_.insert(arg);
|
users_.push_back(arg);
|
||||||
}
|
}
|
||||||
|
|
||||||
value::users_t::iterator value::erase_use(user *arg){
|
value::users_t::iterator value::erase_use(user *arg){
|
||||||
auto it = users_.find(arg);
|
auto it = std::find(users_.begin(), users_.end(), arg);
|
||||||
if(it == users_.end())
|
if(it == users_.end())
|
||||||
return it;
|
return it;
|
||||||
return users_.erase(it);
|
return users_.erase(it);
|
||||||
|
@@ -79,7 +79,7 @@ class CMakeBuild(build_ext):
|
|||||||
|
|
||||||
def build_extension(self, ext):
|
def build_extension(self, ext):
|
||||||
llvm_include_dir, llvm_library_dir = get_llvm()
|
llvm_include_dir, llvm_library_dir = get_llvm()
|
||||||
self.debug = True
|
# self.debug = True
|
||||||
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
|
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
|
||||||
# create build directories
|
# create build directories
|
||||||
build_suffix = 'debug' if self.debug else 'release'
|
build_suffix = 'debug' if self.debug else 'release'
|
||||||
|
@@ -659,6 +659,8 @@ void init_triton_ir(py::module &&m) {
|
|||||||
py::class_<ir::type>(m, "type")
|
py::class_<ir::type>(m, "type")
|
||||||
.def("is_ptr", &ir::type::is_pointer_ty)
|
.def("is_ptr", &ir::type::is_pointer_ty)
|
||||||
.def("is_int", static_cast<bool (ir::type::*)() const>(&ir::type::is_integer_ty))
|
.def("is_int", static_cast<bool (ir::type::*)() const>(&ir::type::is_integer_ty))
|
||||||
|
.def("get_int_width", &ir::type::get_integer_bitwidth)
|
||||||
|
|
||||||
.def("is_floating", &ir::type::is_floating_point_ty)
|
.def("is_floating", &ir::type::is_floating_point_ty)
|
||||||
.def("is_block", &ir::type::is_block_ty)
|
.def("is_block", &ir::type::is_block_ty)
|
||||||
.def("make_ptr", &ir::pointer_type::get, ret::reference)
|
.def("make_ptr", &ir::pointer_type::get, ret::reference)
|
||||||
@@ -695,6 +697,7 @@ void init_triton_ir(py::module &&m) {
|
|||||||
.def("is_uint16", [](ir::type *self) { return self->is_integer_ty(16, ir::signedness::UNSIGNED); })
|
.def("is_uint16", [](ir::type *self) { return self->is_integer_ty(16, ir::signedness::UNSIGNED); })
|
||||||
.def("is_uint32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::UNSIGNED); })
|
.def("is_uint32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::UNSIGNED); })
|
||||||
.def("is_uint64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::UNSIGNED); })
|
.def("is_uint64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::UNSIGNED); })
|
||||||
|
.def("is_struct", &ir::type::is_struct_ty)
|
||||||
|
|
||||||
.def("repr", &ir::type::repr)
|
.def("repr", &ir::type::repr)
|
||||||
.def_property_readonly("fp_mantissa_width", &ir::type::get_fp_mantissa_width)
|
.def_property_readonly("fp_mantissa_width", &ir::type::get_fp_mantissa_width)
|
||||||
@@ -704,23 +707,37 @@ void init_triton_ir(py::module &&m) {
|
|||||||
py::class_<ir::pointer_type, ir::type>(m, "pointer_type")
|
py::class_<ir::pointer_type, ir::type>(m, "pointer_type")
|
||||||
.def_property_readonly("element", &ir::pointer_type::get_element_ty, ret::reference);
|
.def_property_readonly("element", &ir::pointer_type::get_element_ty, ret::reference);
|
||||||
|
|
||||||
py::class_<ir::function_type, ir::type>(m, "function_type");
|
py::class_<ir::function_type, ir::type>(m, "function_type")
|
||||||
|
.def_property_readonly("ret_ty", &ir::function_type::get_return_ty)
|
||||||
|
.def_property_readonly("arg_tys", [](ir::function_type* self){
|
||||||
|
return std::vector<ir::type*>(self->params_begin(), self->params_end());
|
||||||
|
});
|
||||||
|
|
||||||
py::class_<ir::integer_type, ir::type>(m, "integer_type");
|
py::class_<ir::integer_type, ir::type>(m, "integer_type");
|
||||||
|
|
||||||
py::class_<ir::block_type, ir::type>(m, "block_type")
|
py::class_<ir::block_type, ir::type>(m, "block_type")
|
||||||
.def_property_readonly("shape", &ir::block_type::get_shapes)
|
.def_property_readonly("shape", &ir::block_type::get_shapes)
|
||||||
.def_property_readonly("numel", &ir::type::get_tile_num_elements);
|
.def_property_readonly("numel", &ir::type::get_tile_num_elements);
|
||||||
|
|
||||||
|
py::class_<ir::struct_type, ir::type>(m, "struct_type")
|
||||||
|
.def("get", &ir::struct_type::get, ret::reference)
|
||||||
|
.def_property_readonly("num_types", &ir::struct_type::get_num_types);
|
||||||
|
|
||||||
|
py::class_<ir::value_constructor>(m, "value_constructor")
|
||||||
|
.def(py::init<ir::builder&>())
|
||||||
|
.def("seal_block", &ir::value_constructor::seal_block)
|
||||||
|
.def("set_value", (void (ir::value_constructor::*)(const std::string &, ir::value *)) & ir::value_constructor::set_value)
|
||||||
|
.def("set_type", &ir::value_constructor::set_type)
|
||||||
|
.def("get_value", (ir::value * (ir::value_constructor::*)(const std::string &)) & ir::value_constructor::get_value, ret::reference)
|
||||||
|
.def("get_values", &ir::value_constructor::get_values, ret::reference)
|
||||||
|
.def("set_values", &ir::value_constructor::set_values);
|
||||||
|
|
||||||
py::class_<ir::module>(m, "module")
|
py::class_<ir::module>(m, "module")
|
||||||
.def(py::init<std::string, ir::builder &>())
|
.def(py::init<std::string, ir::builder &>())
|
||||||
|
.def("has_function", &ir::module::has_function)
|
||||||
|
.def("get_function", &ir::module::get_function, ret::reference)
|
||||||
.def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference)
|
.def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference)
|
||||||
.def("seal_block", &ir::module::seal_block)
|
.def("reset_ret_ty", &ir::module::reset_ret_ty)
|
||||||
.def("set_value", (void (ir::module::*)(const std::string &, ir::value *)) & ir::module::set_value)
|
|
||||||
.def("set_type", &ir::module::set_type)
|
|
||||||
.def("get_value", (ir::value * (ir::module::*)(const std::string &)) & ir::module::get_value, ret::reference)
|
|
||||||
.def("get_values", &ir::module::get_values, ret::reference)
|
|
||||||
.def("set_values", &ir::module::set_values)
|
|
||||||
.def("get_types", &ir::module::get_types, ret::reference)
|
|
||||||
.def("set_types", &ir::module::set_types)
|
|
||||||
.def_property_readonly("builder", &ir::module::get_builder, ret::reference);
|
.def_property_readonly("builder", &ir::module::get_builder, ret::reference);
|
||||||
|
|
||||||
using eattr = ir::attribute_kind_t;
|
using eattr = ir::attribute_kind_t;
|
||||||
@@ -734,29 +751,45 @@ void init_triton_ir(py::module &&m) {
|
|||||||
.value("not_implemented", eattr::not_implemented);
|
.value("not_implemented", eattr::not_implemented);
|
||||||
|
|
||||||
py::class_<ir::attribute>(m, "attribute")
|
py::class_<ir::attribute>(m, "attribute")
|
||||||
.def(py::init<eattr, int>());
|
.def(py::init<eattr, int>())
|
||||||
|
.def_property_readonly("value", &ir::attribute::get_value);
|
||||||
|
|
||||||
py::class_<ir::function>(m, "function")
|
py::class_<ir::function>(m, "function")
|
||||||
.def_property_readonly("args", &ir::function::args)
|
.def_property_readonly("args", &ir::function::args)
|
||||||
.def_property_readonly("attrs", &ir::function::attrs)
|
.def_property_readonly("attrs", &ir::function::attrs)
|
||||||
.def("add_attr", &ir::function::add_attr);
|
.def("set_is_kernel", &ir::function::set_is_kernel)
|
||||||
|
.def("add_attr", &ir::function::add_attr)
|
||||||
|
.def("has_attr", &ir::function::has_attr)
|
||||||
|
.def("get_attrs", &ir::function::get_attributes);
|
||||||
|
|
||||||
py::class_<ir::argument, ir::value>(m, "argument");
|
py::class_<ir::argument, ir::value>(m, "argument")
|
||||||
|
.def_property_readonly("parent", &ir::argument::get_parent, ret::reference)
|
||||||
|
.def_property_readonly("arg_no", &ir::argument::get_arg_no);
|
||||||
|
|
||||||
py::class_<ir::basic_block, ir::value>(m, "basic_block")
|
py::class_<ir::basic_block, ir::value>(m, "basic_block")
|
||||||
.def("create", &ir::basic_block::create, ret::reference)
|
.def("create", &ir::basic_block::create, ret::reference, py::arg(), py::arg(), py::arg() = nullptr)
|
||||||
.def_property_readonly("parent", &ir::basic_block::get_parent, ret::reference);
|
.def_property_readonly("parent", &ir::basic_block::get_parent, ret::reference);
|
||||||
|
|
||||||
|
py::class_<ir::builder::iterator>(m, "bb_iterator");
|
||||||
|
|
||||||
py::class_<ir::builder>(m, "builder", py::dynamic_attr())
|
py::class_<ir::builder>(m, "builder", py::dynamic_attr())
|
||||||
.def(py::init<ir::context &>())
|
.def(py::init<ir::context &>())
|
||||||
// getters
|
// getters
|
||||||
.def_property_readonly("context", &ir::builder::get_context, ret::reference)
|
.def_property_readonly("context", &ir::builder::get_context, ret::reference)
|
||||||
// control flow
|
// control flow
|
||||||
|
.def("call", &ir::builder::create_call, ret::reference)
|
||||||
|
.def("launch", &ir::builder::create_launch, ret::reference)
|
||||||
.def("br", &ir::builder::create_br, ret::reference)
|
.def("br", &ir::builder::create_br, ret::reference)
|
||||||
.def("cond_br", &ir::builder::create_cond_br, ret::reference)
|
.def("cond_br", &ir::builder::create_cond_br, ret::reference)
|
||||||
.def("ret_void", &ir::builder::create_ret_void, ret::reference)
|
.def("ret_void", &ir::builder::create_ret_void, ret::reference)
|
||||||
|
.def("ret", &ir::builder::create_ret, ret::reference)
|
||||||
|
.def("get_insert_point", &ir::builder::get_insert_point)
|
||||||
|
.def("set_insert_point", (void (ir::builder::*)(ir::builder::iterator))&ir::builder::set_insert_point)
|
||||||
.def("get_insert_block", &ir::builder::get_insert_block, ret::reference)
|
.def("get_insert_block", &ir::builder::get_insert_block, ret::reference)
|
||||||
.def("set_insert_block", (void (ir::builder::*)(ir::basic_block *)) & ir::builder::set_insert_point)
|
.def("set_insert_block", (void (ir::builder::*)(ir::basic_block *)) & ir::builder::set_insert_point)
|
||||||
|
// struct
|
||||||
|
.def("insert_value", &ir::builder::create_insert_value, ret::reference)
|
||||||
|
.def("extract_value", &ir::builder::create_extract_value, ret::reference)
|
||||||
// constants
|
// constants
|
||||||
.def("get_int1", &ir::builder::get_int1, ret::reference)
|
.def("get_int1", &ir::builder::get_int1, ret::reference)
|
||||||
.def("get_int32", &ir::builder::get_int32, ret::reference)
|
.def("get_int32", &ir::builder::get_int32, ret::reference)
|
||||||
|
@@ -585,7 +585,6 @@ def test_f8_f16_roundtrip():
|
|||||||
|
|
||||||
f8_output_tensor = torch.empty_like(f16, dtype=torch.int8)
|
f8_output_tensor = torch.empty_like(f16, dtype=torch.int8)
|
||||||
f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
|
f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
|
||||||
print(f16.dtype, f8_output.dtype)
|
|
||||||
copy_kernel[grid](f16, f8_output, n_elements, BLOCK_SIZE=1024)
|
copy_kernel[grid](f16, f8_output, n_elements, BLOCK_SIZE=1024)
|
||||||
|
|
||||||
assert torch.all(f8_tensor == f8_output_tensor)
|
assert torch.all(f8_tensor == f8_output_tensor)
|
||||||
@@ -1009,8 +1008,8 @@ def test_value_specialization(value: int, value_type: str, device='cuda') -> Non
|
|||||||
|
|
||||||
# Parse out the type of the 'VALUE' parameter from the Triton IR.
|
# Parse out the type of the 'VALUE' parameter from the Triton IR.
|
||||||
triton_ir = pgm.asm['ttir']
|
triton_ir = pgm.asm['ttir']
|
||||||
ir_value_match = re.match(r'\s*def void kernel\((\w+) VALUE ', triton_ir)
|
ir_value_match = re.match(r'\s*def void (\w+)\((\w+) VALUE ', triton_ir)
|
||||||
ir_value_type = None if ir_value_match is None else ir_value_match.group(1)
|
ir_value_type = None if ir_value_match is None else ir_value_match.group(2)
|
||||||
assert ir_value_type == value_type
|
assert ir_value_type == value_type
|
||||||
|
|
||||||
|
|
||||||
@@ -1031,3 +1030,28 @@ def test_value_specialization_overflow(value: int, overflow: bool, device='cuda'
|
|||||||
kernel[(1, )](value, x)
|
kernel[(1, )](value, x)
|
||||||
else:
|
else:
|
||||||
kernel[(1, )](value, x)
|
kernel[(1, )](value, x)
|
||||||
|
# -------------------------
|
||||||
|
# test dynamic parallelism
|
||||||
|
# -------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def mult(x, alpha):
|
||||||
|
tl.store(x + tl.program_id(0), alpha)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def stub(X, alpha, grid_0, grid_1, grid_2):
|
||||||
|
tl.launch(mult, [X, alpha], [grid_0, grid_1, grid_2])
|
||||||
|
|
||||||
|
|
||||||
|
def test_dyn_par(cond=True, device='cuda'):
|
||||||
|
n_pids = 10
|
||||||
|
# pids = torch.arange(n_pids, device=device)
|
||||||
|
# alpha = 2.0
|
||||||
|
# x_ref = pids * alpha
|
||||||
|
x_tri = torch.full((10,), fill_value=-1., device=device)
|
||||||
|
# cond = torch.tensor([cond], device=device)
|
||||||
|
stub[(1,)](x_tri, 3.14, n_pids, 1, 1)
|
||||||
|
print(x_tri)
|
||||||
|
# triton.testing.assert_almost_equal(x_ref, x_tri)
|
||||||
|
@@ -21,6 +21,41 @@ import triton._C.libtriton.triton as _triton
|
|||||||
from .tools.disasm import extract
|
from .tools.disasm import extract
|
||||||
|
|
||||||
|
|
||||||
|
def mangle_ty(type):
|
||||||
|
if type.is_ptr():
|
||||||
|
return 'P' + mangle_ty(type.element)
|
||||||
|
if type.is_int():
|
||||||
|
return 'i' + str(type.get_int_width())
|
||||||
|
if type.is_fp8():
|
||||||
|
return 'fp8'
|
||||||
|
if type.is_fp16():
|
||||||
|
return 'fp16'
|
||||||
|
if type.is_bf16():
|
||||||
|
return 'bf16'
|
||||||
|
if type.is_fp32():
|
||||||
|
return 'fp32'
|
||||||
|
if type.is_fp64():
|
||||||
|
return 'fp64'
|
||||||
|
if type.is_void():
|
||||||
|
return 'V'
|
||||||
|
if type.is_block():
|
||||||
|
elt = mangle_ty(type.scalar)
|
||||||
|
shape = '_'.join(map(str, type.shape))
|
||||||
|
return f'{elt}S{shape}S'
|
||||||
|
assert False, "Unsupport type"
|
||||||
|
|
||||||
|
|
||||||
|
def mangle_fn(name, arg_tys, constants):
|
||||||
|
# doesn't mangle ret type, which must be a function of arg tys
|
||||||
|
mangled_arg_names = '_'.join([mangle_ty(ty) for ty in arg_tys])
|
||||||
|
key = lambda x: x.__name__ if isinstance(x, JITFunction) else repr(x)
|
||||||
|
mangled_constants = '_'.join([f'{i}c{key(constants[i])}' for i in sorted(constants)])
|
||||||
|
mangled_constants = mangled_constants.replace('.', '_d_')
|
||||||
|
mangled_constants = mangled_constants.replace("'", '_sq_')
|
||||||
|
ret = f'{name}__{mangled_arg_names}__{mangled_constants}'
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
class CodeGenerator(ast.NodeVisitor):
|
class CodeGenerator(ast.NodeVisitor):
|
||||||
def get_value(self, name):
|
def get_value(self, name):
|
||||||
# search node.id in local scope
|
# search node.id in local scope
|
||||||
@@ -36,7 +71,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f'{name} is not defined')
|
raise ValueError(f'{name} is not defined')
|
||||||
if isinstance(ret, triton.language.block):
|
if isinstance(ret, triton.language.block):
|
||||||
handle = self.module.get_value(name)
|
handle = self.value_constructor.get_value(name)
|
||||||
return triton.language.block(handle)
|
return triton.language.block(handle)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
@@ -44,8 +79,8 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
if isinstance(value, _triton.ir.value):
|
if isinstance(value, _triton.ir.value):
|
||||||
value = triton.language.block(value)
|
value = triton.language.block(value)
|
||||||
if isinstance(value, triton.language.block):
|
if isinstance(value, triton.language.block):
|
||||||
self.module.set_value(name, value.handle)
|
self.value_constructor.set_value(name, value.handle)
|
||||||
self.module.set_type(name, value.handle.type)
|
self.value_constructor.set_type(name, value.handle.type)
|
||||||
self.lscope[name] = value
|
self.lscope[name] = value
|
||||||
|
|
||||||
def is_triton_object(self, value):
|
def is_triton_object(self, value):
|
||||||
@@ -58,16 +93,17 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
break
|
break
|
||||||
return stmts and isinstance(stmt, ast.Return)
|
return stmts and isinstance(stmt, ast.Return)
|
||||||
|
|
||||||
def __init__(self, context, prototype, gscope, attributes, constants, kwargs):
|
def __init__(self, context, prototype, gscope, attributes, constants, module=None, is_kernel=False):
|
||||||
self.builder = _triton.ir.builder(context)
|
self.builder = _triton.ir.builder(context)
|
||||||
self.module = _triton.ir.module('', self.builder)
|
self.value_constructor = _triton.ir.value_constructor(self.builder)
|
||||||
|
self.module = _triton.ir.module('', self.builder) if module is None else module
|
||||||
self.prototype = prototype
|
self.prototype = prototype
|
||||||
self.gscope = gscope
|
self.gscope = gscope
|
||||||
self.lscope = dict()
|
self.lscope = dict()
|
||||||
self.attributes = attributes
|
self.attributes = attributes
|
||||||
self.constants = constants
|
self.constants = constants
|
||||||
self.kwargs = kwargs
|
|
||||||
self.last_node = None
|
self.last_node = None
|
||||||
|
self.is_kernel = is_kernel
|
||||||
self.builtins = {
|
self.builtins = {
|
||||||
'range': range,
|
'range': range,
|
||||||
'min': triton.language.minimum,
|
'min': triton.language.minimum,
|
||||||
@@ -92,9 +128,17 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
ret = self.visit(node.value)
|
ret = self.visit(node.value)
|
||||||
if ret is None:
|
if ret is None:
|
||||||
return self.builder.ret_void()
|
return self.builder.ret_void()
|
||||||
|
if isinstance(ret, _triton.ir.value):
|
||||||
|
ret = self.builder.ret(ret)
|
||||||
return ret
|
return ret
|
||||||
|
if isinstance(ret, triton.language.block):
|
||||||
|
ret = ret.handle
|
||||||
|
if isinstance(ret, triton.language.constexpr):
|
||||||
|
ret = triton.language.core._to_ir(ret, self.builder)
|
||||||
|
# TODO: should return tl.block
|
||||||
|
return self.builder.ret(ret)
|
||||||
|
|
||||||
def visit_FunctionDef(self, node, inline=False, arg_values=None):
|
def visit_FunctionDef(self, node):
|
||||||
arg_names, kwarg_names = self.visit(node.args)
|
arg_names, kwarg_names = self.visit(node.args)
|
||||||
# initialize defaults
|
# initialize defaults
|
||||||
for i, default_value in enumerate(node.args.defaults):
|
for i, default_value in enumerate(node.args.defaults):
|
||||||
@@ -107,13 +151,10 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
else:
|
else:
|
||||||
init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation)
|
init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation)
|
||||||
self.visit(init_node)
|
self.visit(init_node)
|
||||||
# store keyword arguments in local scope
|
|
||||||
self.lscope[kwarg_names] = self.kwargs
|
|
||||||
# initialize function
|
# initialize function
|
||||||
if inline:
|
fn_name = mangle_fn(node.name, self.prototype.arg_tys, self.constants)
|
||||||
pass
|
fn = self.module.get_or_insert_function(fn_name, self.prototype)
|
||||||
else:
|
fn.set_is_kernel(self.is_kernel)
|
||||||
fn = self.module.get_or_insert_function(node.name, self.prototype)
|
|
||||||
arg_values = []
|
arg_values = []
|
||||||
idx = 0
|
idx = 0
|
||||||
for i, arg_name in enumerate(arg_names):
|
for i, arg_name in enumerate(arg_names):
|
||||||
@@ -133,19 +174,21 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
arg_values.append(fn.args[idx])
|
arg_values.append(fn.args[idx])
|
||||||
idx += 1
|
idx += 1
|
||||||
|
|
||||||
|
insert_pt = self.builder.get_insert_block()
|
||||||
|
entry = _triton.ir.basic_block.create(self.builder.context, "entry", fn)
|
||||||
|
self.builder.set_insert_block(entry)
|
||||||
|
self.value_constructor.seal_block(entry)
|
||||||
for arg_name, arg_value in zip(arg_names, arg_values):
|
for arg_name, arg_value in zip(arg_names, arg_values):
|
||||||
self.set_value(arg_name, arg_value)
|
self.set_value(arg_name, arg_value)
|
||||||
if inline:
|
|
||||||
self.visit_compound_statement(node.body)
|
|
||||||
return self.last_ret
|
|
||||||
else:
|
|
||||||
entry = _triton.ir.basic_block.create(self.builder.context, "entry", fn)
|
|
||||||
self.module.seal_block(entry)
|
|
||||||
self.builder.set_insert_block(entry)
|
|
||||||
# visit function body
|
# visit function body
|
||||||
self.visit_compound_statement(node.body)
|
has_ret = self.visit_compound_statement(node.body)
|
||||||
# finalize function
|
# finalize
|
||||||
|
if not has_ret:
|
||||||
self.builder.ret_void()
|
self.builder.ret_void()
|
||||||
|
else:
|
||||||
|
self.module.reset_ret_ty(fn_name, self.last_ret.type)
|
||||||
|
# self.module.reset_ret_type(node.name)
|
||||||
|
self.builder.set_insert_block(insert_pt)
|
||||||
|
|
||||||
def visit_arguments(self, node):
|
def visit_arguments(self, node):
|
||||||
arg_names = []
|
arg_names = []
|
||||||
@@ -186,6 +229,12 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
names = [names]
|
names = [names]
|
||||||
if not isinstance(values, tuple):
|
if not isinstance(values, tuple):
|
||||||
values = [values]
|
values = [values]
|
||||||
|
if isinstance(values[0], _triton.ir.value):
|
||||||
|
struct = values[0]
|
||||||
|
ty = struct.type
|
||||||
|
if ty.is_struct():
|
||||||
|
values = [self.builder.extract_value(struct, i) for i in range(ty.num_types)]
|
||||||
|
assert len(values) == len(names)
|
||||||
for name, value in zip(names, values):
|
for name, value in zip(names, values):
|
||||||
# by default, constexpr are assigned into python variable
|
# by default, constexpr are assigned into python variable
|
||||||
if isinstance(value, triton.language.constexpr):
|
if isinstance(value, triton.language.constexpr):
|
||||||
@@ -215,6 +264,17 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
|
|
||||||
def visit_Tuple(self, node):
|
def visit_Tuple(self, node):
|
||||||
args = [self.visit(x) for x in node.elts]
|
args = [self.visit(x) for x in node.elts]
|
||||||
|
mode = type(args[0])
|
||||||
|
# tuple of values -- create a struct
|
||||||
|
if len(args) > 1 and mode == triton.language.block\
|
||||||
|
and all([type(arg) == mode for arg in args]):
|
||||||
|
args = [arg.handle for arg in args]
|
||||||
|
tys = [arg.type for arg in args]
|
||||||
|
struct_ty = _triton.ir.struct_type.get(tys, True)
|
||||||
|
ret = _triton.ir.undef.get(struct_ty)
|
||||||
|
for i, arg in enumerate(args):
|
||||||
|
ret = self.builder.insert_value(ret, arg, i)
|
||||||
|
return ret
|
||||||
return tuple(args)
|
return tuple(args)
|
||||||
|
|
||||||
def visit_BinOp(self, node):
|
def visit_BinOp(self, node):
|
||||||
@@ -254,9 +314,9 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
then_bb = _triton.ir.basic_block.create(self.builder.context, "then", current_bb.parent)
|
then_bb = _triton.ir.basic_block.create(self.builder.context, "then", current_bb.parent)
|
||||||
else_bb = _triton.ir.basic_block.create(self.builder.context, "else", current_bb.parent) if node.orelse else None
|
else_bb = _triton.ir.basic_block.create(self.builder.context, "else", current_bb.parent) if node.orelse else None
|
||||||
endif_bb = _triton.ir.basic_block.create(self.builder.context, "endif", current_bb.parent)
|
endif_bb = _triton.ir.basic_block.create(self.builder.context, "endif", current_bb.parent)
|
||||||
self.module.seal_block(then_bb)
|
self.value_constructor.seal_block(then_bb)
|
||||||
if else_bb:
|
if else_bb:
|
||||||
self.module.seal_block(else_bb)
|
self.value_constructor.seal_block(else_bb)
|
||||||
self.builder.cond_br(cond.handle, then_bb, else_bb)
|
self.builder.cond_br(cond.handle, then_bb, else_bb)
|
||||||
else:
|
else:
|
||||||
self.builder.cond_br(cond.handle, then_bb, endif_bb)
|
self.builder.cond_br(cond.handle, then_bb, endif_bb)
|
||||||
@@ -271,7 +331,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
# TODO: last statement is a terminator?
|
# TODO: last statement is a terminator?
|
||||||
if not is_terminator:
|
if not is_terminator:
|
||||||
self.builder.br(endif_bb)
|
self.builder.br(endif_bb)
|
||||||
self.module.seal_block(endif_bb)
|
self.value_constructor.seal_block(endif_bb)
|
||||||
self.builder.set_insert_block(endif_bb)
|
self.builder.set_insert_block(endif_bb)
|
||||||
else:
|
else:
|
||||||
if isinstance(cond, triton.language.constexpr):
|
if isinstance(cond, triton.language.constexpr):
|
||||||
@@ -350,9 +410,9 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
self.visit_compound_statement(node.body)
|
self.visit_compound_statement(node.body)
|
||||||
continue_fn()
|
continue_fn()
|
||||||
stop_bb = self.builder.get_insert_block()
|
stop_bb = self.builder.get_insert_block()
|
||||||
self.module.seal_block(stop_bb)
|
self.value_constructor.seal_block(stop_bb)
|
||||||
self.module.seal_block(loop_bb)
|
self.value_constructor.seal_block(loop_bb)
|
||||||
self.module.seal_block(next_bb)
|
self.value_constructor.seal_block(next_bb)
|
||||||
self.builder.set_insert_block(next_bb)
|
self.builder.set_insert_block(next_bb)
|
||||||
|
|
||||||
for stmt in node.orelse:
|
for stmt in node.orelse:
|
||||||
@@ -421,9 +481,9 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
# TODO: handle case where body breaks control flow
|
# TODO: handle case where body breaks control flow
|
||||||
continue_fn()
|
continue_fn()
|
||||||
stop_bb = self.builder.get_insert_block()
|
stop_bb = self.builder.get_insert_block()
|
||||||
self.module.seal_block(stop_bb)
|
self.value_constructor.seal_block(stop_bb)
|
||||||
self.module.seal_block(loop_bb)
|
self.value_constructor.seal_block(loop_bb)
|
||||||
self.module.seal_block(next_bb)
|
self.value_constructor.seal_block(next_bb)
|
||||||
self.builder.set_insert_block(next_bb)
|
self.builder.set_insert_block(next_bb)
|
||||||
|
|
||||||
for stmt in node.orelse:
|
for stmt in node.orelse:
|
||||||
@@ -449,15 +509,62 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
for keyword in node.keywords:
|
for keyword in node.keywords:
|
||||||
kws.update(self.visit(keyword))
|
kws.update(self.visit(keyword))
|
||||||
args = [self.visit(arg) for arg in node.args]
|
args = [self.visit(arg) for arg in node.args]
|
||||||
|
|
||||||
if isinstance(fn, JITFunction):
|
if isinstance(fn, JITFunction):
|
||||||
return fn(*args, generator=self, **kws)
|
from inspect import getcallargs
|
||||||
|
args = getcallargs(fn.fn, *args, **kws)
|
||||||
|
args = [args[name] for name in fn.arg_names]
|
||||||
|
args = [arg if isinstance(arg, triton.language.block)
|
||||||
|
else triton.language.constexpr(arg) for arg in args]
|
||||||
|
# generate function def
|
||||||
|
attributes = dict()
|
||||||
|
constexprs = [i for i, arg in enumerate(args) if isinstance(arg, triton.language.constexpr)]
|
||||||
|
constants = {i: args[i] for i in constexprs}
|
||||||
|
# generate call
|
||||||
|
args = [None if i in constexprs else arg for i, arg in enumerate(args)]
|
||||||
|
arg_vals = [arg.handle for arg in args if arg is not None]
|
||||||
|
arg_types = [arg.type for arg in arg_vals]
|
||||||
|
fn_name = mangle_fn(fn.__name__, arg_types, constants)
|
||||||
|
# generate function def if necessary
|
||||||
|
if not self.module.has_function(fn_name):
|
||||||
|
ret_type = _triton.ir.type.get_void(self.builder.context)
|
||||||
|
prototype = _triton.ir.type.make_function(ret_type, arg_types)
|
||||||
|
gscope = sys.modules[fn.fn.__module__].__dict__
|
||||||
|
generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module)
|
||||||
|
generator.visit(fn.parse())
|
||||||
|
symbol = self.module.get_function(fn_name)
|
||||||
|
ret = self.builder.call(symbol, arg_vals)
|
||||||
|
if not ret.type.is_void() and not ret.type.is_struct():
|
||||||
|
ret = triton.language.block(ret)
|
||||||
|
return ret
|
||||||
|
# built-in function
|
||||||
if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \
|
if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \
|
||||||
sys.modules[fn.__module__] is triton.language.core:
|
sys.modules[fn.__module__] is triton.language.core:
|
||||||
return fn(*args, _builder=self.builder, **kws)
|
ret = fn(*args, _builder=self.builder, **kws)
|
||||||
if fn in self.builtins.values():
|
if fn in self.builtins.values():
|
||||||
args = [arg.value if isinstance(arg, triton.language.constexpr) else arg
|
args = [arg.value if isinstance(arg, triton.language.constexpr) else arg
|
||||||
for arg in args]
|
for arg in args]
|
||||||
return fn(*args, **kws)
|
ret = fn(*args, **kws)
|
||||||
|
# special case: dynamic parallelism
|
||||||
|
# in this case the core primitive returns a proxy
|
||||||
|
# if isinstance(ret, triton.language.core.LaunchProxy):
|
||||||
|
# ret_type = _triton.ir.type.get_void(self.builder.context)
|
||||||
|
# arg_tys = [x.type for x in ret.args]
|
||||||
|
# prototype = _triton.ir.type.make_function(ret_type, arg_tys)
|
||||||
|
# gscope = sys.modules[ret.fn.fn.__module__].__dict__
|
||||||
|
# constants = ret.constants
|
||||||
|
# fn_name = mangle_fn(ret.fn.__name__, arg_tys, ret.constants)
|
||||||
|
# # TODO: clean-up attributes handling in function
|
||||||
|
# if not self.module.has_function(fn_name):
|
||||||
|
# attributes = {i: list(arg.parent.get_attrs(arg))[0].value for i, arg in enumerate(ret.args) \
|
||||||
|
# if isinstance(arg, _triton.ir.argument) and arg.parent.has_attr(i + 1) }
|
||||||
|
# generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, is_kernel=True)
|
||||||
|
# generator.visit(ret.fn.parse())
|
||||||
|
# symbol = self.module.get_function(fn_name)
|
||||||
|
# # TODO: should ret.args not include any constants ?
|
||||||
|
# ret = self.builder.launch(symbol, ret.args, ret.grid, ret.num_warps)
|
||||||
|
return ret
|
||||||
|
# return fn(*args, **kws)
|
||||||
|
|
||||||
def visit_Constant(self, node):
|
def visit_Constant(self, node):
|
||||||
return triton.language.constexpr(node.value)
|
return triton.language.constexpr(node.value)
|
||||||
@@ -669,6 +776,7 @@ class Kernel:
|
|||||||
|
|
||||||
def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages):
|
def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages):
|
||||||
tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
|
tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
|
||||||
|
|
||||||
# attributes
|
# attributes
|
||||||
attributes = dict()
|
attributes = dict()
|
||||||
for i, arg in enumerate(wargs):
|
for i, arg in enumerate(wargs):
|
||||||
@@ -881,7 +989,7 @@ class JITFunction:
|
|||||||
|
|
||||||
cache_hook = None
|
cache_hook = None
|
||||||
|
|
||||||
def __init__(self, fn, version=None, do_not_specialize=None):
|
def __init__(self, fn, version=None, inline=True, do_not_specialize=None):
|
||||||
# information of wrapped function
|
# information of wrapped function
|
||||||
self.fn = fn
|
self.fn = fn
|
||||||
self.module = fn.__module__
|
self.module = fn.__module__
|
||||||
@@ -890,6 +998,7 @@ class JITFunction:
|
|||||||
self.arg_defaults = [v.default for v in signature.parameters.values()]
|
self.arg_defaults = [v.default for v in signature.parameters.values()]
|
||||||
|
|
||||||
self.version = version
|
self.version = version
|
||||||
|
self.inline = inline
|
||||||
self.src = textwrap.dedent(inspect.getsource(fn))
|
self.src = textwrap.dedent(inspect.getsource(fn))
|
||||||
self.src = self.src[self.src.find("def"):]
|
self.src = self.src[self.src.find("def"):]
|
||||||
self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
|
self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
|
||||||
@@ -904,6 +1013,8 @@ class JITFunction:
|
|||||||
# annotations
|
# annotations
|
||||||
self.annotations = {self.arg_names.index(name): ty for name, ty in fn.__annotations__.items()}
|
self.annotations = {self.arg_names.index(name): ty for name, ty in fn.__annotations__.items()}
|
||||||
self.__annotations__ = fn.__annotations__
|
self.__annotations__ = fn.__annotations__
|
||||||
|
# constexprs
|
||||||
|
self.constexprs = [self.arg_names.index(ann) for ann in self.__annotations__.keys()]
|
||||||
# forward docs
|
# forward docs
|
||||||
self.__doc__ = fn.__doc__
|
self.__doc__ = fn.__doc__
|
||||||
self.__name__ = fn.__name__
|
self.__name__ = fn.__name__
|
||||||
@@ -930,31 +1041,8 @@ class JITFunction:
|
|||||||
assert isinstance(tree.body[0], ast.FunctionDef)
|
assert isinstance(tree.body[0], ast.FunctionDef)
|
||||||
return tree
|
return tree
|
||||||
|
|
||||||
def __call__(self, *args, generator: CodeGenerator, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
try:
|
raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel.")
|
||||||
from inspect import getcallargs
|
|
||||||
arg_values = getcallargs(self.fn, *args, **kwargs)
|
|
||||||
arg_values = [arg_values[name] for name in self.arg_names]
|
|
||||||
arg_values = [arg if isinstance(arg, triton.language.block)
|
|
||||||
else triton.language.constexpr(arg) for arg in arg_values]
|
|
||||||
|
|
||||||
gscope = generator.gscope.copy()
|
|
||||||
lscope = generator.lscope.copy()
|
|
||||||
values = generator.module.get_values().copy()
|
|
||||||
types = generator.module.get_types().copy()
|
|
||||||
generator.gscope = sys.modules[self.fn.__module__].__dict__
|
|
||||||
generator.lscope = dict()
|
|
||||||
ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=arg_values)
|
|
||||||
generator.gscope = gscope
|
|
||||||
generator.lscope = lscope
|
|
||||||
generator.module.set_values(values)
|
|
||||||
generator.module.set_types(types)
|
|
||||||
return ret
|
|
||||||
except Exception as e:
|
|
||||||
node = generator.last_node
|
|
||||||
if node is None or isinstance(e, (NotImplementedError, CompilationError)):
|
|
||||||
raise e
|
|
||||||
raise CompilationError(self.src, node) from e
|
|
||||||
|
|
||||||
# - when `.src` attribute is set, cache path needs
|
# - when `.src` attribute is set, cache path needs
|
||||||
# to be reinitialized
|
# to be reinitialized
|
||||||
@@ -1039,7 +1127,7 @@ class JITFunction:
|
|||||||
# generate Triton-IR
|
# generate Triton-IR
|
||||||
# export symbols visible from self into code-generator object
|
# export symbols visible from self into code-generator object
|
||||||
gscope = self.__globals__
|
gscope = self.__globals__
|
||||||
generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict())
|
generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, is_kernel=True)
|
||||||
try:
|
try:
|
||||||
generator.visit(self.parse())
|
generator.visit(self.parse())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -1199,9 +1287,21 @@ def jit(*args, **kwargs):
|
|||||||
return JITFunction(fn, **kwargs)
|
return JITFunction(fn, **kwargs)
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
######
|
||||||
|
|
||||||
|
# class ForwardDeclaration:
|
||||||
|
|
||||||
|
# def __init__(self, name, ret_ty, arg_tys) -> None:
|
||||||
|
# self.name = name
|
||||||
|
# self.ret_ty = ret_ty
|
||||||
|
# self.arg_tys = arg_tys
|
||||||
|
|
||||||
|
# def forward_declare(name, ret_ty, arg_tys):
|
||||||
|
# return ForwardDeclaration(name, ret_ty, arg_tys)
|
||||||
|
|
||||||
######
|
######
|
||||||
|
|
||||||
|
|
||||||
def cdiv(x, y):
|
def cdiv(x, y):
|
||||||
return (x + y - 1) // y
|
return (x + y - 1) // y
|
||||||
|
|
||||||
|
@@ -888,7 +888,7 @@ def sigmoid(x):
|
|||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
@_add_math_1arg_docstr("softmax")
|
@_add_math_1arg_docstr("softmax")
|
||||||
def softmax(x, ieee_rounding=False):
|
def softmax(x, ieee_rounding: constexpr = False):
|
||||||
z = x - triton.language.max(x, 0)
|
z = x - triton.language.max(x, 0)
|
||||||
num = triton.language.exp(z)
|
num = triton.language.exp(z)
|
||||||
den = triton.language.sum(num, 0)
|
den = triton.language.sum(num, 0)
|
||||||
@@ -942,3 +942,26 @@ def swizzle2d(i, j, size_i, size_j, size_g):
|
|||||||
@triton.jit
|
@triton.jit
|
||||||
def zeros_like(input):
|
def zeros_like(input):
|
||||||
return zeros(input.shape, input.dtype)
|
return zeros(input.shape, input.dtype)
|
||||||
|
# -----------------------
|
||||||
|
# Dynamic Parallelism
|
||||||
|
# -----------------------
|
||||||
|
|
||||||
|
|
||||||
|
class LaunchProxy:
|
||||||
|
|
||||||
|
def __init__(self, fn, args, constants, grid, num_warps) -> None:
|
||||||
|
self.args = args
|
||||||
|
self.grid = grid
|
||||||
|
self.constants = constants
|
||||||
|
self.num_warps = num_warps
|
||||||
|
self.fn = fn
|
||||||
|
|
||||||
|
|
||||||
|
@builtin
|
||||||
|
def launch(fn, args, grid, num_warps=None, _builder=None):
|
||||||
|
constants = {i: x for i, x in enumerate(args) if isinstance(x, constexpr)}
|
||||||
|
args = [_to_ir(x, builder=_builder) for x in args if not isinstance(x, constexpr)]
|
||||||
|
grid = [_to_ir(x, builder=_builder) for x in grid]
|
||||||
|
if num_warps is None:
|
||||||
|
num_warps = _to_ir(4, builder=_builder)
|
||||||
|
return LaunchProxy(fn, args, constants, grid, num_warps)
|
||||||
|
Reference in New Issue
Block a user