[CODEGEN] Various bugfixes and stability improvements in compiler backend (#240)
This commit is contained in:
		| @@ -93,7 +93,22 @@ protected: | |||||||
|   shape_t shape_; |   shape_t shape_; | ||||||
| }; | }; | ||||||
|  |  | ||||||
| class mma_layout: public data_layout { | class distributed_layout: public data_layout{ | ||||||
|  | public: | ||||||
|  |   distributed_layout(id_t id, | ||||||
|  |                      const std::vector<int>& axes, | ||||||
|  |                      const std::vector<unsigned>& shape, | ||||||
|  |                      const std::vector<ir::value*>& values, | ||||||
|  |                      analysis::align* align); | ||||||
|  |  | ||||||
|  |   int shape_per_cta(size_t k) { return shape_per_cta_.at(k); } | ||||||
|  |   int rep_per_cta(size_t k) { return shape_[k] / shape_per_cta_[k]; } | ||||||
|  |  | ||||||
|  | protected: | ||||||
|  |   std::vector<int> shape_per_cta_; | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | class mma_layout: public distributed_layout { | ||||||
| public: | public: | ||||||
|   mma_layout(size_t num_warps, |   mma_layout(size_t num_warps, | ||||||
|                 const std::vector<int>& axes, |                 const std::vector<int>& axes, | ||||||
| @@ -107,7 +122,6 @@ public: | |||||||
|   int fpw(size_t k) { return fpw_.at(k); } |   int fpw(size_t k) { return fpw_.at(k); } | ||||||
|   int wpt(size_t k) { return wpt_.at(k); } |   int wpt(size_t k) { return wpt_.at(k); } | ||||||
|   int spw(size_t k) { return spw_.at(k); } |   int spw(size_t k) { return spw_.at(k); } | ||||||
|   int spt(size_t k) { return spt_.at(k); } |  | ||||||
|   int rep(size_t k) { return rep_.at(k); } |   int rep(size_t k) { return rep_.at(k); } | ||||||
|  |  | ||||||
| private: | private: | ||||||
| @@ -123,7 +137,7 @@ private: | |||||||
|   std::vector<int> rep_; |   std::vector<int> rep_; | ||||||
| }; | }; | ||||||
|  |  | ||||||
| struct scanline_layout: public data_layout { | struct scanline_layout: public distributed_layout { | ||||||
|   scanline_layout(size_t num_warps, |   scanline_layout(size_t num_warps, | ||||||
|                     const std::vector<int>& axes, |                     const std::vector<int>& axes, | ||||||
|                     const std::vector<unsigned>& shape, |                     const std::vector<unsigned>& shape, | ||||||
| @@ -219,6 +233,7 @@ public: | |||||||
|  |  | ||||||
|   // accessors |   // accessors | ||||||
|   unsigned layout_of(ir::value *value) const                  { return groups_.at(value); } |   unsigned layout_of(ir::value *value) const                  { return groups_.at(value); } | ||||||
|  |   bool has(ir::value* value) const { return groups_.find(value) != groups_.end(); } | ||||||
|   const std::vector<ir::value*>& values_of(unsigned id) const { return values_.at(id); } |   const std::vector<ir::value*>& values_of(unsigned id) const { return values_.at(id); } | ||||||
|   size_t num_layouts() const                                  { return values_.size();} |   size_t num_layouts() const                                  { return values_.size();} | ||||||
|   data_layout* get(size_t id)                                 { return layouts_.at(id); } |   data_layout* get(size_t id)                                 { return layouts_.at(id); } | ||||||
| @@ -226,7 +241,7 @@ public: | |||||||
|   std::map<size_t, data_layout*> &get_all()                   { return layouts_; } |   std::map<size_t, data_layout*> &get_all()                   { return layouts_; } | ||||||
|   bool has_tmp(ir::value* i)                                  { return tmp_.find(i) != tmp_.end(); } |   bool has_tmp(ir::value* i)                                  { return tmp_.find(i) != tmp_.end(); } | ||||||
|   int tmp(ir::value* i)                                       { return tmp_.at(i);} |   int tmp(ir::value* i)                                       { return tmp_.at(i);} | ||||||
|  |   void copy(ir::value* dst, ir::value* src)                   { groups_[dst] = groups_[src]; } | ||||||
|   // execution |   // execution | ||||||
|   void run(ir::module &mod); |   void run(ir::module &mod); | ||||||
|  |  | ||||||
|   | |||||||
| @@ -171,7 +171,8 @@ public: | |||||||
|   void visit_reducend_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*); |   void visit_reducend_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*); | ||||||
|   void visit_reduce_inst(ir::reduce_inst*); |   void visit_reduce_inst(ir::reduce_inst*); | ||||||
|   void visit_select_inst(ir::select_inst*); |   void visit_select_inst(ir::select_inst*); | ||||||
|   void visit_recoalesce_inst(ir::recoalesce_inst*); |   void visit_layout_convert(ir::value *out, ir::value *in); | ||||||
|  |   void visit_cvt_layout_inst(ir::cvt_layout_inst*); | ||||||
|   void visit_masked_load_async_inst(ir::masked_load_async_inst*); |   void visit_masked_load_async_inst(ir::masked_load_async_inst*); | ||||||
|   void visit_copy_to_shared_inst(ir::copy_to_shared_inst*); |   void visit_copy_to_shared_inst(ir::copy_to_shared_inst*); | ||||||
|   void visit_copy_from_shared_inst(ir::copy_from_shared_inst*); |   void visit_copy_from_shared_inst(ir::copy_from_shared_inst*); | ||||||
|   | |||||||
| @@ -33,6 +33,7 @@ private: | |||||||
|  |  | ||||||
| public: | public: | ||||||
|   coalesce(analysis::align* align, triton::codegen::analysis::layouts *layouts); |   coalesce(analysis::align* align, triton::codegen::analysis::layouts *layouts); | ||||||
|  |   triton::ir::value *simplify(ir::instruction* i, triton::ir::builder &builder); | ||||||
|   void run(ir::module &mod); |   void run(ir::module &mod); | ||||||
|  |  | ||||||
| private: | private: | ||||||
|   | |||||||
| @@ -34,8 +34,7 @@ private: | |||||||
|   bool rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder); |   bool rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder); | ||||||
|   bool rewrite_select_masked_load(ir::instruction *value, ir::builder& builder); |   bool rewrite_select_masked_load(ir::instruction *value, ir::builder& builder); | ||||||
|   bool rewrite_load_to_shared(ir::instruction *value, ir::builder& builder); |   bool rewrite_load_to_shared(ir::instruction *value, ir::builder& builder); | ||||||
|  |   bool rewrite_cvt_layout(ir::instruction *value, ir::builder& builder); | ||||||
| private: |  | ||||||
|  |  | ||||||
| public: | public: | ||||||
|   peephole(target* tgt, analysis::layouts* layouts): tgt_(tgt), layouts_(layouts) {} |   peephole(target* tgt, analysis::layouts* layouts): tgt_(tgt), layouts_(layouts) {} | ||||||
|   | |||||||
| @@ -60,136 +60,151 @@ protected: | |||||||
| public: | public: | ||||||
|   static bool nvmlinit(); |   static bool nvmlinit(); | ||||||
|   static bool cuinit(); |   static bool cuinit(); | ||||||
|   static bool spvllvminit(); |  | ||||||
|   static void release(); |   static void release(); | ||||||
|  |  | ||||||
|   // CUDA |   /* ------------------- * | ||||||
|  |    * CUDA | ||||||
|  |    * ------------------- */ | ||||||
|  |   // context management | ||||||
|  |   static CUresult cuInit(unsigned int Flags); | ||||||
|   static CUresult cuCtxGetCurrent(CUcontext *pctx); |   static CUresult cuCtxGetCurrent(CUcontext *pctx); | ||||||
|   static CUresult cuCtxSetCurrent(CUcontext ctx); |   static CUresult cuCtxSetCurrent(CUcontext ctx); | ||||||
|   static CUresult cuCtxDestroy_v2(CUcontext ctx); |   static CUresult cuCtxDestroy_v2(CUcontext ctx); | ||||||
|   static CUresult cuEventCreate(CUevent *phEvent, unsigned int Flags); |   static CUresult cuCtxCreate_v2(CUcontext *pctx, unsigned int flags, CUdevice dev); | ||||||
|   static CUresult cuDeviceGet(CUdevice *device, int ordinal); |   static CUresult cuCtxPushCurrent_v2(CUcontext ctx); | ||||||
|   static CUresult cuMemcpyDtoH_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount); |   static CUresult cuCtxPopCurrent_v2(CUcontext *pctx); | ||||||
|   static CUresult cuStreamCreate(CUstream *phStream, unsigned int Flags); |   static CUresult cuCtxGetDevice(CUdevice* result); | ||||||
|   static CUresult cuEventElapsedTime(float *pMilliseconds, CUevent hStart, CUevent hEnd); |   static CUresult cuCtxEnablePeerAccess(CUcontext peerContext, unsigned int flags); | ||||||
|   static CUresult cuMemFree_v2(CUdeviceptr dptr); |  | ||||||
|   static CUresult cuMemcpyDtoHAsync_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount, CUstream hStream); |  | ||||||
|   static CUresult cuDriverGetVersion(int *driverVersion); |   static CUresult cuDriverGetVersion(int *driverVersion); | ||||||
|  |   // device management | ||||||
|  |   static CUresult cuDeviceGet(CUdevice *device, int ordinal); | ||||||
|   static CUresult cuDeviceGetName(char *name, int len, CUdevice dev); |   static CUresult cuDeviceGetName(char *name, int len, CUdevice dev); | ||||||
|   static CUresult cuDeviceGetPCIBusId(char *id, int len, CUdevice dev); |   static CUresult cuDeviceGetPCIBusId(char *id, int len, CUdevice dev); | ||||||
|   static CUresult cuModuleGetGlobal_v2(CUdeviceptr *dptr, size_t* bytes, CUmodule hmod, const char *name); |   static CUresult cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib, CUdevice dev); | ||||||
|   static CUresult cuMemcpyHtoDAsync_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount, CUstream hStream); |   static CUresult cuDeviceGetCount(int *count); | ||||||
|   static CUresult cuModuleLoad(CUmodule *module, const char *fname); |   // link management | ||||||
|   static CUresult cuModuleLoadData(CUmodule* module, const void* image); |  | ||||||
|   static CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void **kernelParams, void **extra); |  | ||||||
|   static CUresult cuModuleUnload(CUmodule hmod); |  | ||||||
|   static CUresult cuModuleLoadDataEx(CUmodule *module, const void *image, unsigned int numOptions, CUjit_option *options, void **optionValues); |  | ||||||
|  |  | ||||||
|   static CUresult cuLinkAddData_v2(CUlinkState state, CUjitInputType type, void* data, size_t size, const char* name, unsigned int numOptions, CUjit_option* options, void** optionValues); |   static CUresult cuLinkAddData_v2(CUlinkState state, CUjitInputType type, void* data, size_t size, const char* name, unsigned int numOptions, CUjit_option* options, void** optionValues); | ||||||
|   static CUresult cuLinkCreate_v2(unsigned int  numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut); |   static CUresult cuLinkCreate_v2(unsigned int  numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut); | ||||||
|   static CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut); |   static CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut); | ||||||
|   static CUresult cuLinkDestroy(CUlinkState state); |   static CUresult cuLinkDestroy(CUlinkState state); | ||||||
|  |   // module management | ||||||
|   static CUresult cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib, CUdevice dev); |   static CUresult cuModuleGetGlobal_v2(CUdeviceptr *dptr, size_t* bytes, CUmodule hmod, const char *name); | ||||||
|   static CUresult cuDeviceGetCount(int *count); |   static CUresult cuModuleLoad(CUmodule *module, const char *fname); | ||||||
|   static CUresult cuMemcpyHtoD_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount); |   static CUresult cuModuleLoadData(CUmodule* module, const void* image); | ||||||
|   static CUresult cuInit(unsigned int Flags); |   static CUresult cuModuleUnload(CUmodule hmod); | ||||||
|   static CUresult cuEventRecord(CUevent hEvent, CUstream hStream); |   static CUresult cuModuleLoadDataEx(CUmodule *module, const void *image, unsigned int numOptions, CUjit_option *options, void **optionValues); | ||||||
|   static CUresult cuCtxCreate_v2(CUcontext *pctx, unsigned int flags, CUdevice dev); |  | ||||||
|   static CUresult cuCtxPushCurrent_v2(CUcontext ctx); |  | ||||||
|   static CUresult cuCtxPopCurrent_v2(CUcontext *pctx); |  | ||||||
|   static CUresult cuModuleGetFunction(CUfunction *hfunc, CUmodule hmod, const char *name); |   static CUresult cuModuleGetFunction(CUfunction *hfunc, CUmodule hmod, const char *name); | ||||||
|  |   // stream management | ||||||
|  |   static CUresult cuStreamCreate(CUstream *phStream, unsigned int Flags); | ||||||
|   static CUresult cuStreamSynchronize(CUstream hStream); |   static CUresult cuStreamSynchronize(CUstream hStream); | ||||||
|   static CUresult cuStreamGetCtx(CUstream hStream, CUcontext* pctx); |   static CUresult cuStreamGetCtx(CUstream hStream, CUcontext* pctx); | ||||||
|   static CUresult cuStreamDestroy_v2(CUstream hStream); |   static CUresult cuStreamDestroy_v2(CUstream hStream); | ||||||
|   static CUresult cuEventDestroy_v2(CUevent hEvent); |   static CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void **kernelParams, void **extra); | ||||||
|   static CUresult cuMemAlloc_v2(CUdeviceptr *dptr, size_t bytesize); |   // function management | ||||||
|   static CUresult cuPointerGetAttribute(void * data, CUpointer_attribute attribute, CUdeviceptr ptr); |  | ||||||
|   static CUresult cuCtxGetDevice(CUdevice* result); |  | ||||||
|   static CUresult cuMemsetD8Async(CUdeviceptr dst, unsigned char x, size_t N, CUstream stream); |  | ||||||
|   static CUresult cuFuncGetAttribute(int* pi, CUfunction_attribute attrib, CUfunction hfunc); |   static CUresult cuFuncGetAttribute(int* pi, CUfunction_attribute attrib, CUfunction hfunc); | ||||||
|   static CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value); |   static CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value); | ||||||
|   static CUresult cuFuncSetCacheConfig(CUfunction hfunc, CUfunc_cache config); |   static CUresult cuFuncSetCacheConfig(CUfunction hfunc, CUfunc_cache config); | ||||||
|   static CUresult cuCtxEnablePeerAccess(CUcontext peerContext, unsigned int flags); |   // memory management | ||||||
|   // NVML |   static CUresult cuMemAlloc_v2(CUdeviceptr *dptr, size_t bytesize); | ||||||
|  |   static CUresult cuPointerGetAttribute(void * data, CUpointer_attribute attribute, CUdeviceptr ptr); | ||||||
|  |   static CUresult cuMemsetD8Async(CUdeviceptr dst, unsigned char x, size_t N, CUstream stream); | ||||||
|  |   static CUresult cuMemcpyDtoH_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount); | ||||||
|  |   static CUresult cuMemFree_v2(CUdeviceptr dptr); | ||||||
|  |   static CUresult cuMemcpyDtoHAsync_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount, CUstream hStream); | ||||||
|  |   static CUresult cuMemcpyHtoDAsync_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount, CUstream hStream); | ||||||
|  |   static CUresult cuMemcpyHtoD_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount); | ||||||
|  |   // event management | ||||||
|  |   static CUresult cuEventCreate(CUevent *phEvent, unsigned int Flags); | ||||||
|  |   static CUresult cuEventElapsedTime(float *pMilliseconds, CUevent hStart, CUevent hEnd); | ||||||
|  |   static CUresult cuEventRecord(CUevent hEvent, CUstream hStream); | ||||||
|  |   static CUresult cuEventDestroy_v2(CUevent hEvent); | ||||||
|  |  | ||||||
|  |  | ||||||
|  |   /* ------------------- * | ||||||
|  |    * NVML | ||||||
|  |    * ------------------- */ | ||||||
|   static nvmlReturn_t nvmlDeviceGetHandleByPciBusId_v2( const char* pciBusId, nvmlDevice_t* device); |   static nvmlReturn_t nvmlDeviceGetHandleByPciBusId_v2( const char* pciBusId, nvmlDevice_t* device); | ||||||
|   static nvmlReturn_t nvmlDeviceGetClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock); |   static nvmlReturn_t nvmlDeviceGetClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock); | ||||||
|   static nvmlReturn_t nvmlDeviceGetMaxClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock); |   static nvmlReturn_t nvmlDeviceGetMaxClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock); | ||||||
|   static nvmlReturn_t nvmlDeviceSetApplicationsClocks(nvmlDevice_t device, unsigned int mem_clock, unsigned int sm_clock); |   static nvmlReturn_t nvmlDeviceSetApplicationsClocks(nvmlDevice_t device, unsigned int mem_clock, unsigned int sm_clock); | ||||||
|  |  | ||||||
|  |  | ||||||
|   // SPIR-V libraries |  | ||||||
|   static int initializeLLVMToSPIRVPass(llvm::PassRegistry &); |  | ||||||
|   static bool writeSpirv(llvm::Module *M, std::ostream &OS, std::string &ErrMsg); |  | ||||||
|  |  | ||||||
|  |  | ||||||
| private: | private: | ||||||
|  |  | ||||||
|   // Libraries |   // Libraries | ||||||
|   static void* cuda_; |   static void* cuda_; | ||||||
|   static void* nvml_; |   static void* nvml_; | ||||||
|   static void* vulkan_; |  | ||||||
|   static void* spvllvm_; |  | ||||||
|   static void* spvcross_; |  | ||||||
|   static void* opengl_; |  | ||||||
|  |  | ||||||
|  |  | ||||||
|   // CUDA functions |   /* ------------------- * | ||||||
|  |    * CUDA | ||||||
|  |    * ------------------- */ | ||||||
|  |   // context management | ||||||
|   static void* cuCtxGetCurrent_; |   static void* cuCtxGetCurrent_; | ||||||
|   static void* cuCtxSetCurrent_; |   static void* cuCtxSetCurrent_; | ||||||
|   static void* cuCtxDestroy_v2_; |   static void* cuCtxDestroy_v2_; | ||||||
|   static void* cuEventCreate_; |   static void* cuCtxCreate_v2_; | ||||||
|   static void* cuDeviceGet_; |   static void* cuCtxGetDevice_; | ||||||
|   static void* cuMemcpyDtoH_v2_; |   static void* cuCtxPushCurrent_v2_; | ||||||
|   static void* cuStreamCreate_; |   static void* cuCtxPopCurrent_v2_; | ||||||
|   static void* cuEventElapsedTime_; |   static void* cuCtxEnablePeerAccess_; | ||||||
|   static void* cuMemFree_v2_; |  | ||||||
|   static void* cuMemcpyDtoHAsync_v2_; |  | ||||||
|   static void* cuDriverGetVersion_; |   static void* cuDriverGetVersion_; | ||||||
|  |   static void* cuInit_; | ||||||
|  |   // device management | ||||||
|  |   static void* cuDeviceGet_; | ||||||
|   static void* cuDeviceGetName_; |   static void* cuDeviceGetName_; | ||||||
|   static void* cuDeviceGetPCIBusId_; |   static void* cuDeviceGetPCIBusId_; | ||||||
|   static void* cuModuleGetGlobal_v2_; |   static void* cuDeviceGetAttribute_; | ||||||
|   static void* cuMemcpyHtoDAsync_v2_; |   static void* cuDeviceGetCount_; | ||||||
|   static void* cuModuleLoad_; |   // link management | ||||||
|   static void* cuLaunchKernel_; |  | ||||||
|   static void* cuModuleUnload_; |  | ||||||
|   static void* cuModuleLoadDataEx_; |  | ||||||
|   static void* cuLinkAddData_v2_; |   static void* cuLinkAddData_v2_; | ||||||
|   static void* cuLinkCreate_v2_; |   static void* cuLinkCreate_v2_; | ||||||
|   static void* cuLinkDestroy_; |   static void* cuLinkDestroy_; | ||||||
|   static void* cuModuleLoadData_; |  | ||||||
|   static void* cuLinkComplete_; |   static void* cuLinkComplete_; | ||||||
|   static void* cuDeviceGetAttribute_; |   // module management | ||||||
|   static void* cuDeviceGetCount_; |   static void* cuModuleGetGlobal_v2_; | ||||||
|   static void* cuMemcpyHtoD_v2_; |   static void* cuModuleLoad_; | ||||||
|   static void* cuInit_; |   static void* cuModuleUnload_; | ||||||
|   static void* cuEventRecord_; |   static void* cuModuleLoadDataEx_; | ||||||
|   static void* cuCtxCreate_v2_; |   static void* cuModuleLoadData_; | ||||||
|   static void* cuModuleGetFunction_; |   static void* cuModuleGetFunction_; | ||||||
|  |   // stream management | ||||||
|  |   static void* cuStreamCreate_; | ||||||
|   static void* cuStreamSynchronize_; |   static void* cuStreamSynchronize_; | ||||||
|   static void* cuStreamDestroy_v2_; |   static void* cuStreamDestroy_v2_; | ||||||
|   static void* cuStreamGetCtx_; |   static void* cuStreamGetCtx_; | ||||||
|   static void* cuEventDestroy_v2_; |   static void* cuLaunchKernel_; | ||||||
|   static void* cuMemAlloc_v2_; |   // function management | ||||||
|   static void* cuPointerGetAttribute_; |  | ||||||
|   static void* cuCtxGetDevice_; |  | ||||||
|   static void* cuMemsetD8Async_; |  | ||||||
|   static void* cuCtxPushCurrent_v2_; |  | ||||||
|   static void* cuCtxPopCurrent_v2_; |  | ||||||
|   static void* cuFuncGetAttribute_; |   static void* cuFuncGetAttribute_; | ||||||
|   static void* cuFuncSetAttribute_; |   static void* cuFuncSetAttribute_; | ||||||
|   static void* cuFuncSetCacheConfig_; |   static void* cuFuncSetCacheConfig_; | ||||||
|   static void* cuCtxEnablePeerAccess_; |   // memory management | ||||||
|   // NVML |   static void* cuMemcpyDtoH_v2_; | ||||||
|  |   static void* cuMemFree_v2_; | ||||||
|  |   static void* cuMemcpyDtoHAsync_v2_; | ||||||
|  |   static void* cuMemcpyHtoDAsync_v2_; | ||||||
|  |   static void* cuMemcpyHtoD_v2_; | ||||||
|  |   static void* cuMemAlloc_v2_; | ||||||
|  |   static void* cuMemsetD8Async_; | ||||||
|  |   static void* cuPointerGetAttribute_; | ||||||
|  |   // event management | ||||||
|  |   static void* cuEventCreate_; | ||||||
|  |   static void* cuEventElapsedTime_; | ||||||
|  |   static void* cuEventRecord_; | ||||||
|  |   static void* cuEventDestroy_v2_; | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |   /* ------------------- * | ||||||
|  |    * NVML | ||||||
|  |    * ------------------- */ | ||||||
|   static void* nvmlInit_v2_; |   static void* nvmlInit_v2_; | ||||||
|   static void* nvmlDeviceGetHandleByPciBusId_v2_; |   static void* nvmlDeviceGetHandleByPciBusId_v2_; | ||||||
|   static void* nvmlDeviceGetClockInfo_; |   static void* nvmlDeviceGetClockInfo_; | ||||||
|   static void* nvmlDeviceGetMaxClockInfo_; |   static void* nvmlDeviceGetMaxClockInfo_; | ||||||
|   static void* nvmlDeviceSetApplicationsClocks_; |   static void* nvmlDeviceSetApplicationsClocks_; | ||||||
|  |  | ||||||
|   // LLVM to SPIR-V |  | ||||||
|   static void* initializeLLVMToSPIRVPass_; |  | ||||||
|   static void* writeSpirv_; |  | ||||||
| }; | }; | ||||||
|  |  | ||||||
| } | } | ||||||
|   | |||||||
| @@ -153,6 +153,9 @@ enum value_id_t: unsigned { | |||||||
|   // intrinsics |   // intrinsics | ||||||
|   INST_COPY_TO_SHARED, |   INST_COPY_TO_SHARED, | ||||||
|   INST_COPY_FROM_SHARED, |   INST_COPY_FROM_SHARED, | ||||||
|  |   INST_CVT_LAYOUT, | ||||||
|  |   INST_CVT_SCANLINE, | ||||||
|  |   INST_DECOALESCE, | ||||||
|   INST_RECOALESCE, |   INST_RECOALESCE, | ||||||
|   INST_BARRIER, |   INST_BARRIER, | ||||||
|   INST_ASYNC_WAIT, |   INST_ASYNC_WAIT, | ||||||
|   | |||||||
| @@ -807,16 +807,15 @@ public: | |||||||
|   _TRITON_DEFINE_ACCEPT(copy_from_shared_inst) |   _TRITON_DEFINE_ACCEPT(copy_from_shared_inst) | ||||||
| }; | }; | ||||||
|  |  | ||||||
|  | class cvt_layout_inst: public unary_inst { | ||||||
| class recoalesce_inst: public unary_inst{ |  | ||||||
| private: | private: | ||||||
|   using unary_inst::unary_inst; |   using unary_inst::unary_inst; | ||||||
|   std::string repr_impl() const { return "recoalesce_inst"; } |   std::string repr_impl() const { return "cvt_layout_inst"; } | ||||||
|  |  | ||||||
| public: | public: | ||||||
|   static recoalesce_inst* create(value *arg, const std::string &name = "", instruction *next = nullptr); |   static cvt_layout_inst* create(value *arg, const std::string &name = "", instruction *next = nullptr); | ||||||
|   _TRITON_DEFINE_CLONE(recoalesce_inst) |   _TRITON_DEFINE_CLONE(cvt_layout_inst) | ||||||
|   _TRITON_DEFINE_ACCEPT(recoalesce_inst) |   _TRITON_DEFINE_ACCEPT(cvt_layout_inst) | ||||||
| }; | }; | ||||||
|  |  | ||||||
| class barrier_inst: public instruction{ | class barrier_inst: public instruction{ | ||||||
|   | |||||||
| @@ -64,7 +64,7 @@ class sqrt_inst; | |||||||
| class reduce_inst; | class reduce_inst; | ||||||
| class select_inst; | class select_inst; | ||||||
|  |  | ||||||
| class recoalesce_inst; | class cvt_layout_inst; | ||||||
| class copy_to_shared_inst; | class copy_to_shared_inst; | ||||||
| class copy_from_shared_inst; | class copy_from_shared_inst; | ||||||
| class masked_load_async_inst; | class masked_load_async_inst; | ||||||
| @@ -142,9 +142,11 @@ public: | |||||||
|   virtual void visit_reduce_inst(reduce_inst*) = 0; |   virtual void visit_reduce_inst(reduce_inst*) = 0; | ||||||
|   virtual void visit_select_inst(select_inst*) = 0; |   virtual void visit_select_inst(select_inst*) = 0; | ||||||
|  |  | ||||||
|   virtual void visit_recoalesce_inst(recoalesce_inst*) = 0; |   virtual void visit_cvt_layout_inst(cvt_layout_inst*) = 0; | ||||||
|   virtual void visit_copy_to_shared_inst(copy_to_shared_inst*) = 0; |   virtual void visit_copy_to_shared_inst(copy_to_shared_inst*) = 0; | ||||||
|   virtual void visit_copy_from_shared_inst(copy_from_shared_inst*) = 0; |   virtual void visit_copy_from_shared_inst(copy_from_shared_inst*) = 0; | ||||||
|  |  | ||||||
|  |  | ||||||
|   virtual void visit_masked_load_async_inst(masked_load_async_inst*)= 0; |   virtual void visit_masked_load_async_inst(masked_load_async_inst*)= 0; | ||||||
|   virtual void visit_barrier_inst(barrier_inst*) = 0; |   virtual void visit_barrier_inst(barrier_inst*) = 0; | ||||||
|   virtual void visit_async_wait_inst(async_wait_inst*) = 0; |   virtual void visit_async_wait_inst(async_wait_inst*) = 0; | ||||||
|   | |||||||
| @@ -6,6 +6,7 @@ | |||||||
| #include <map> | #include <map> | ||||||
| #include <set> | #include <set> | ||||||
| #include <vector> | #include <vector> | ||||||
|  | #include <iostream> | ||||||
|  |  | ||||||
| namespace triton { | namespace triton { | ||||||
| namespace tools{ | namespace tools{ | ||||||
| @@ -40,9 +41,10 @@ public: | |||||||
|       nmap->clear(); |       nmap->clear(); | ||||||
|     std::set<node_t> nodes = nodes_; |     std::set<node_t> nodes = nodes_; | ||||||
|     unsigned id = 0; |     unsigned id = 0; | ||||||
|     while(!nodes.empty()) |     while(!nodes.empty()){ | ||||||
|       connected_components_impl(*nodes.begin(), nodes, nmap, cmap, id++); |       connected_components_impl(*nodes.begin(), nodes, nmap, cmap, id++); | ||||||
|     } |     } | ||||||
|  |   } | ||||||
|  |  | ||||||
|   void add_edge(node_t x, node_t y) { |   void add_edge(node_t x, node_t y) { | ||||||
|     nodes_.insert(x); |     nodes_.insert(x); | ||||||
|   | |||||||
| @@ -50,7 +50,6 @@ void allocation::run(ir::module &mod) { | |||||||
|       J.erase(j_it); |       J.erase(j_it); | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   // Build interference graph |   // Build interference graph | ||||||
|   std::map<shared_layout*, std::set<shared_layout*>> interferences; |   std::map<shared_layout*, std::set<shared_layout*>> interferences; | ||||||
|   for(shared_layout* x: V) |   for(shared_layout* x: V) | ||||||
| @@ -66,13 +65,10 @@ void allocation::run(ir::module &mod) { | |||||||
|         && XS.intersect(YS)) |         && XS.intersect(YS)) | ||||||
|       interferences[x].insert(y); |       interferences[x].insert(y); | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   // Initialize colors |   // Initialize colors | ||||||
|   std::map<shared_layout*, int> colors; |   std::map<shared_layout*, int> colors; | ||||||
|   for(shared_layout* X: V) |   for(shared_layout* X: V) | ||||||
|     colors[X] = (X==V[0])?0:-1; |     colors[X] = (X==V[0])?0:-1; | ||||||
|  |  | ||||||
|  |  | ||||||
|   // First-fit graph coloring |   // First-fit graph coloring | ||||||
|   std::vector<bool> available(V.size()); |   std::vector<bool> available(V.size()); | ||||||
|   for(shared_layout* x: V){ |   for(shared_layout* x: V){ | ||||||
| @@ -87,7 +83,6 @@ void allocation::run(ir::module &mod) { | |||||||
|     auto It = std::find(available.begin(), available.end(), true); |     auto It = std::find(available.begin(), available.end(), true); | ||||||
|     colors[x] = std::distance(available.begin(), It); |     colors[x] = std::distance(available.begin(), It); | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   // Finalize allocation |   // Finalize allocation | ||||||
|   for(shared_layout* x: V){ |   for(shared_layout* x: V){ | ||||||
|     unsigned Adj = 0; |     unsigned Adj = 0; | ||||||
| @@ -95,7 +90,6 @@ void allocation::run(ir::module &mod) { | |||||||
|       Adj = std::max<unsigned>(Adj, starts[y] + y->get_size()); |       Adj = std::max<unsigned>(Adj, starts[y] + y->get_size()); | ||||||
|     offsets_[x] = starts[x] + colors[x] * Adj; |     offsets_[x] = starts[x] + colors[x] * Adj; | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   // Save maximum size of induced memory space |   // Save maximum size of induced memory space | ||||||
|   allocated_size_ = 0; |   allocated_size_ = 0; | ||||||
|   for(shared_layout* x: V) |   for(shared_layout* x: V) | ||||||
|   | |||||||
| @@ -112,9 +112,9 @@ void axes::update_graph(ir::instruction *i) { | |||||||
|     case ir::INST_BROADCAST:         return update_graph_broadcast(i); |     case ir::INST_BROADCAST:         return update_graph_broadcast(i); | ||||||
|     case ir::INST_DOT:               return update_graph_dot(i); |     case ir::INST_DOT:               return update_graph_dot(i); | ||||||
|     case ir::INST_COPY_TO_SHARED:    return update_graph_no_edge(i); |     case ir::INST_COPY_TO_SHARED:    return update_graph_no_edge(i); | ||||||
|     case ir::INST_MASKED_LOAD_ASYNC:return update_graph_elementwise(i, false); |     case ir::INST_MASKED_LOAD_ASYNC: return update_graph_elementwise(i, false); | ||||||
|     case ir::INST_COPY_FROM_SHARED:  return update_graph_no_edge(i); |     case ir::INST_COPY_FROM_SHARED:  return update_graph_no_edge(i); | ||||||
|     case ir::INST_RECOALESCE:       return update_graph_no_edge(i); |     case ir::INST_CVT_LAYOUT:        return update_graph_no_edge(i); | ||||||
|     default:                         return update_graph_elementwise(i); |     default:                         return update_graph_elementwise(i); | ||||||
|   } |   } | ||||||
|   return; |   return; | ||||||
| @@ -135,11 +135,15 @@ std::vector<int> axes::get(ir::value *value) { | |||||||
| void axes::run(ir::module &mod) { | void axes::run(ir::module &mod) { | ||||||
|   // make graph |   // make graph | ||||||
|   graph_.clear(); |   graph_.clear(); | ||||||
|  |   axes_.clear(); | ||||||
|   ir::for_each_instruction(mod, [this](ir::instruction *x) { |   ir::for_each_instruction(mod, [this](ir::instruction *x) { | ||||||
|     update_graph(x); |     update_graph(x); | ||||||
|   }); |   }); | ||||||
|   // find connected components |   // find connected components | ||||||
|   graph_.connected_components(nullptr, &axes_); |   graph_.connected_components(nullptr, &axes_); | ||||||
|  |   std::set<size_t> uniq; | ||||||
|  |   for(auto x: axes_) | ||||||
|  |     uniq.insert(x.second); | ||||||
| } | } | ||||||
|  |  | ||||||
| } | } | ||||||
|   | |||||||
| @@ -109,9 +109,6 @@ data_layout::data_layout(id_t id, | |||||||
|         max_contiguous = curr; |         max_contiguous = curr; | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
|   bool is_recoalesce = false; |  | ||||||
|   for(ir::value* v: values) |  | ||||||
|     is_recoalesce = is_recoalesce || dynamic_cast<ir::recoalesce_inst*>(v); |  | ||||||
|   if(max_contiguous.size() > 0){ |   if(max_contiguous.size() > 0){ | ||||||
|     std::sort(order_.begin(), order_.end(), [&](unsigned a, unsigned b) { |     std::sort(order_.begin(), order_.end(), [&](unsigned a, unsigned b) { | ||||||
|       return max_contiguous[a] > max_contiguous[b]; |       return max_contiguous[a] > max_contiguous[b]; | ||||||
| @@ -129,6 +126,13 @@ int data_layout::find_axis(int to_find) const { | |||||||
| } | } | ||||||
|  |  | ||||||
|  |  | ||||||
|  | distributed_layout::distributed_layout(id_t id, | ||||||
|  |                          const std::vector<int> &axes, | ||||||
|  |                          const std::vector<unsigned> &shape, | ||||||
|  |                          const std::vector<ir::value *> &values, | ||||||
|  |                          analysis::align* align): data_layout(id, axes, shape, values, align) | ||||||
|  | { } | ||||||
|  |  | ||||||
| /* -------------------------------- * | /* -------------------------------- * | ||||||
|  *           MMA Layout             * |  *           MMA Layout             * | ||||||
|  * -------------------------------- */ |  * -------------------------------- */ | ||||||
| @@ -138,20 +142,11 @@ mma_layout::mma_layout(size_t num_warps, | |||||||
|                        const std::vector<unsigned>& shape, |                        const std::vector<unsigned>& shape, | ||||||
|                        const std::vector<ir::value *> &values, |                        const std::vector<ir::value *> &values, | ||||||
|                        analysis::align* align, target* tgt, |                        analysis::align* align, target* tgt, | ||||||
|                        shared_layout *layout_a, shared_layout *layout_b): data_layout(MMA, axes, shape, values, align) { |                        shared_layout *layout_a, shared_layout *layout_b): distributed_layout(MMA, axes, shape, values, align) { | ||||||
|   /* fragments per warp */ |   /* fragments per warp */ | ||||||
|   // try to make things as square as possible to maximize data re-use |   // try to make things as square as possible to maximize data re-use | ||||||
|   if(tgt->as_nvidia()->sm() < 80){ |   if(tgt->as_nvidia()->sm() < 80){ | ||||||
|     fpw_ = {2, 2, 1}; |     fpw_ = {2, 2, 1}; | ||||||
| //    std::vector<int> fpw_nm1; |  | ||||||
| //    unsigned num_fragments = std::min<unsigned>((shape_[0]/8)*(shape_[1]/8), 4); |  | ||||||
| //    do { |  | ||||||
| //      fpw_nm1 = fpw_; |  | ||||||
| //      if(fpw_[0]*fpw_[1] < num_fragments) |  | ||||||
| //        fpw_[0] = clamp(fpw_[0]*2, 1, shape_[0] / 8); |  | ||||||
| //      if(fpw_[0]*fpw_[1] < num_fragments) |  | ||||||
| //        fpw_[1] = clamp(fpw_[1]*2, 1, shape_[1] / 8); |  | ||||||
| //    }while(fpw_nm1 != fpw_); |  | ||||||
|     auto ord_a = layout_a->get_order(); |     auto ord_a = layout_a->get_order(); | ||||||
|     auto ord_b = layout_b->get_order(); |     auto ord_b = layout_b->get_order(); | ||||||
|     bool is_a_row = ord_a[0] != 0; |     bool is_a_row = ord_a[0] != 0; | ||||||
| @@ -168,6 +163,7 @@ mma_layout::mma_layout(size_t num_warps, | |||||||
|     spw_ = {16, 8, 1}; |     spw_ = {16, 8, 1}; | ||||||
|     rep_ = {2,  2, 1}; |     rep_ = {2,  2, 1}; | ||||||
|   } |   } | ||||||
|  |   order_ = {0, 1}; | ||||||
|  |  | ||||||
|   /* warps per tile */ |   /* warps per tile */ | ||||||
|   // try to make things as square as possible to maximize data re-use |   // try to make things as square as possible to maximize data re-use | ||||||
| @@ -182,7 +178,7 @@ mma_layout::mma_layout(size_t num_warps, | |||||||
|   }while(wpt_nm1 != wpt_); |   }while(wpt_nm1 != wpt_); | ||||||
|  |  | ||||||
|   /* shape per block */ |   /* shape per block */ | ||||||
|   spt_ = {spw_[0]*wpt_[0], spw_[1]*wpt_[1], 1}; |   shape_per_cta_ = {spw_[0]*wpt_[0], spw_[1]*wpt_[1], 1}; | ||||||
| } | } | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -194,7 +190,7 @@ scanline_layout::scanline_layout(size_t num_warps, | |||||||
|                                  const std::vector<int>& axes, |                                  const std::vector<int>& axes, | ||||||
|                                  const std::vector<unsigned>& shape, |                                  const std::vector<unsigned>& shape, | ||||||
|                                  const std::vector<ir::value *> &values, |                                  const std::vector<ir::value *> &values, | ||||||
|                                  analysis::align* align, target *tgt): data_layout(SCANLINE, axes, shape, values, align){ |                                  analysis::align* align, target *tgt): distributed_layout(SCANLINE, axes, shape, values, align){ | ||||||
|   unsigned size = std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies<int>()); |   unsigned size = std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies<int>()); | ||||||
|   unsigned num_threads = tgt->is_gpu() ? num_warps * 32 : 1; |   unsigned num_threads = tgt->is_gpu() ? num_warps * 32 : 1; | ||||||
|   nts_.resize(shape_.size()); |   nts_.resize(shape_.size()); | ||||||
| @@ -230,6 +226,10 @@ scanline_layout::scanline_layout(size_t num_warps, | |||||||
|     mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]); |     mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]); | ||||||
|     num_threads = num_threads / mts_[i]; |     num_threads = num_threads / mts_[i]; | ||||||
|   } |   } | ||||||
|  |  | ||||||
|  |   shape_per_cta_.resize(shape_.size()); | ||||||
|  |   for(size_t d = 0; d < shape_.size(); d++) | ||||||
|  |     shape_per_cta_[d] = mts_[d]*nts_[d]; | ||||||
| } | } | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -489,6 +489,9 @@ void layouts::create(size_t id, const std::vector<ir::value*>& values) { | |||||||
| void layouts::run(ir::module &mod) { | void layouts::run(ir::module &mod) { | ||||||
|   // make graph |   // make graph | ||||||
|   graph_.clear(); |   graph_.clear(); | ||||||
|  |   layouts_.clear(); | ||||||
|  |   groups_.clear(); | ||||||
|  |  | ||||||
|   ir::for_each_instruction(mod, [this](ir::instruction* i) { |   ir::for_each_instruction(mod, [this](ir::instruction* i) { | ||||||
|     make_graph(i); |     make_graph(i); | ||||||
|   }); |   }); | ||||||
| @@ -515,23 +518,18 @@ void layouts::run(ir::module &mod) { | |||||||
|       layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_); |       layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_); | ||||||
|       tmp_[red] = id; |       tmp_[red] = id; | ||||||
|     } |     } | ||||||
|     if(auto *recoalasce = dynamic_cast<ir::recoalesce_inst*>(i)){ |     if(auto *val = dynamic_cast<ir::cvt_layout_inst*>(i)){ | ||||||
|       ir::value *val = recoalasce->get_operand(0); |       distributed_layout* out_layout = dynamic_cast<distributed_layout*>(get(val)); | ||||||
|       mma_layout* in_layout = get(val)->to_mma(); |       distributed_layout* in_layout = dynamic_cast<distributed_layout*>(get(i->get_operand(0))); | ||||||
|       scanline_layout* out_layout = get(i)->to_scanline(); |  | ||||||
|       if(!in_layout || !out_layout) |  | ||||||
|         return; |  | ||||||
|       id++; |       id++; | ||||||
|       ir::type::block_shapes_t in_shape = val->get_type()->get_block_shapes(); |       size_t dim = val->get_type()->get_tile_rank(); | ||||||
|       ir::type::block_shapes_t shape(in_shape.size()); |       ir::type::block_shapes_t shape(dim); | ||||||
|       size_t ld = out_layout->get_order(0); |       for(size_t k = 0; k < dim; k++){ | ||||||
|       shape[ld] = in_shape[ld]; |         shape[k] = std::max(in_layout->shape_per_cta(k), | ||||||
|       for(size_t k = 0; k < in_shape.size(); k++) |                             out_layout->shape_per_cta(k)); | ||||||
|         if(k != ld) |       } | ||||||
|           shape[k] = in_layout->to_mma()->spt(k); |       layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {val}, val->get_type()->get_scalar_ty(), align_); | ||||||
|       // create layout |       tmp_[val] = id; | ||||||
|       layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {recoalasce}, val->get_type()->get_scalar_ty(), align_); |  | ||||||
|       tmp_[recoalasce] = id; |  | ||||||
|     } |     } | ||||||
|     if(auto *atom = dynamic_cast<ir::atomic_inst*>(i)){ |     if(auto *atom = dynamic_cast<ir::atomic_inst*>(i)){ | ||||||
|       id++; |       id++; | ||||||
|   | |||||||
| @@ -56,10 +56,8 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps, | |||||||
|   dce.run(ir); |   dce.run(ir); | ||||||
|   peephole.run(ir); |   peephole.run(ir); | ||||||
|   dce.run(ir); |   dce.run(ir); | ||||||
|   // ir::print(ir, std::cout); |  | ||||||
|   pipeline.run(ir); |   pipeline.run(ir); | ||||||
|   dce.run(ir); |   dce.run(ir); | ||||||
|   // ir::print(ir, std::cout); |  | ||||||
|   disassociate.run(ir); |   disassociate.run(ir); | ||||||
|   dce.run(ir); |   dce.run(ir); | ||||||
|   align.run(ir); |   align.run(ir); | ||||||
| @@ -74,14 +72,15 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps, | |||||||
|   layouts.run(ir); |   layouts.run(ir); | ||||||
|   coalesce.run(ir); |   coalesce.run(ir); | ||||||
|   dce.run(ir); |   dce.run(ir); | ||||||
|  | //  exit(1); | ||||||
|  |  | ||||||
|   align.run(ir); |   align.run(ir); | ||||||
|   dce.run(ir); |   dce.run(ir); | ||||||
|   if (target->is_gpu()) { |   if (target->is_gpu()) | ||||||
| //    reassociate.run(ir); |  | ||||||
|     cts.run(ir); |     cts.run(ir); | ||||||
|   } |  | ||||||
|   dce.run(ir); |   dce.run(ir); | ||||||
|   align.run(ir); |   align.run(ir); | ||||||
|  | //  ir::print(ir, std::cout); | ||||||
|   axes.run(ir); |   axes.run(ir); | ||||||
|   layouts.run(ir); |   layouts.run(ir); | ||||||
|   peephole.run(ir); |   peephole.run(ir); | ||||||
| @@ -93,10 +92,7 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps, | |||||||
|   liveness.run(ir); |   liveness.run(ir); | ||||||
|   allocation.run(ir); |   allocation.run(ir); | ||||||
|   prefetch_s.run(ir); |   prefetch_s.run(ir); | ||||||
| //  ir::print(ir, std::cout); |  | ||||||
|   barriers.run(ir); |   barriers.run(ir); | ||||||
| //  ir::print(ir, std::cout); |  | ||||||
| //  ir::print(ir, std::cout); |  | ||||||
|   isel.visit(ir, *llvm); |   isel.visit(ir, *llvm); | ||||||
|   mod = driver::module::create(dev, std::move(llvm)); |   mod = driver::module::create(dev, std::move(llvm)); | ||||||
|   ker = driver::kernel::create(&*mod, name.c_str()); |   ker = driver::kernel::create(&*mod, name.c_str()); | ||||||
|   | |||||||
| @@ -586,7 +586,7 @@ void generator::visit_load_inst(ir::load_inst* x){ | |||||||
|   Type* ty  = cvt(op->get_type()->get_scalar_ty()->get_pointer_element_ty()); |   Type* ty  = cvt(op->get_type()->get_scalar_ty()->get_pointer_element_ty()); | ||||||
|   // compute vector width |   // compute vector width | ||||||
|   size_t vec = 1; |   size_t vec = 1; | ||||||
|   if(op->get_type()->is_block_ty()){ |   if(op->get_type()->is_block_ty() && op->get_type()->get_tile_rank() > 1){ | ||||||
|     auto   ord = ords_.at(op); |     auto   ord = ords_.at(op); | ||||||
|     size_t aln = alignment_->get(op, ord[0]); |     size_t aln = alignment_->get(op, ord[0]); | ||||||
|     size_t nts = layouts_->get(x)->to_scanline()->nts(ord[0]); |     size_t nts = layouts_->get(x)->to_scanline()->nts(ord[0]); | ||||||
| @@ -626,10 +626,10 @@ void generator::visit_load_inst(ir::load_inst* x){ | |||||||
|     // ----- |     // ----- | ||||||
|     std::ostringstream asm_oss; |     std::ostringstream asm_oss; | ||||||
|     asm_oss << "@$" << n_words; // predicate |     asm_oss << "@$" << n_words; // predicate | ||||||
|     if(force_nc_cache_) | //    if(force_nc_cache_) | ||||||
|       asm_oss << " ld.global.nc"; |       asm_oss << " ld.global"; | ||||||
|     else | //    else | ||||||
|       asm_oss << " ld.global.cg"; | //      asm_oss << " ld.global.cg"; | ||||||
|     if(n_words > 1) |     if(n_words > 1) | ||||||
|       asm_oss << ".v" << n_words; // vector width |       asm_oss << ".v" << n_words; // vector width | ||||||
|     asm_oss << ".b" << width; // word size |     asm_oss << ".b" << width; // word size | ||||||
| @@ -1058,6 +1058,7 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va | |||||||
|   /* --------------------------------- */ |   /* --------------------------------- */ | ||||||
|   BasicBlock* curr_bb = builder_->GetInsertBlock(); |   BasicBlock* curr_bb = builder_->GetInsertBlock(); | ||||||
|   BasicBlock* entry = &curr_bb->getParent()->getEntryBlock(); |   BasicBlock* entry = &curr_bb->getParent()->getEntryBlock(); | ||||||
|  |   if(entry != curr_bb) | ||||||
|     builder_->SetInsertPoint(entry->getTerminator()); |     builder_->SetInsertPoint(entry->getTerminator()); | ||||||
|   Value* off_a0 = is_a_row ? offset_a_k_[layout_c] : offset_a_m_[layout_c]; |   Value* off_a0 = is_a_row ? offset_a_k_[layout_c] : offset_a_m_[layout_c]; | ||||||
|   Value* off_a1 = is_a_row ? offset_a_m_[layout_c] : offset_a_k_[layout_c]; |   Value* off_a1 = is_a_row ? offset_a_m_[layout_c] : offset_a_k_[layout_c]; | ||||||
| @@ -1116,8 +1117,8 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va | |||||||
|   for(indices_t idx: idxs_.at(C)) |   for(indices_t idx: idxs_.at(C)) | ||||||
|     acc.push_back(vals_[D][idx]); |     acc.push_back(vals_[D][idx]); | ||||||
|  |  | ||||||
|   unsigned num_m = layout_c->rep(0) * shape_c[0] / layout_c->spt(0); |   unsigned num_m = layout_c->rep(0) * shape_c[0] / layout_c->shape_per_cta(0); | ||||||
|   unsigned num_n = layout_c->rep(1) * shape_c[1] / layout_c->spt(1); |   unsigned num_n = layout_c->rep(1) * shape_c[1] / layout_c->shape_per_cta(1); | ||||||
|  |  | ||||||
|   // create mma & unpack result |   // create mma & unpack result | ||||||
|   auto call_mma = [&](unsigned m, unsigned n, unsigned K) { |   auto call_mma = [&](unsigned m, unsigned n, unsigned K) { | ||||||
| @@ -1333,6 +1334,7 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir:: | |||||||
|  |  | ||||||
|   BasicBlock* CurrBB = builder_->GetInsertBlock(); |   BasicBlock* CurrBB = builder_->GetInsertBlock(); | ||||||
|   BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock(); |   BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock(); | ||||||
|  |   if(FirstBB != CurrBB) | ||||||
|     builder_->SetInsertPoint(FirstBB->getTerminator()); |     builder_->SetInsertPoint(FirstBB->getTerminator()); | ||||||
|  |  | ||||||
|   Value* thread = tgt_->get_local_id(mod_, *builder_, 0); |   Value* thread = tgt_->get_local_id(mod_, *builder_, 0); | ||||||
| @@ -1396,8 +1398,8 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir:: | |||||||
|                                              "{$10, $11, $12, $13};", |                                              "{$10, $11, $12, $13};", | ||||||
|                                              "=f,=f,=f,=f,r,r,r,r,r,r,0,1,2,3", true); |                                              "=f,=f,=f,=f,r,r,r,r,r,r,0,1,2,3", true); | ||||||
|  |  | ||||||
|   unsigned num_rep_0 = shapes[0] / layout->spt(0); |   unsigned num_rep_0 = shapes[0] / layout->shape_per_cta(0); | ||||||
|   unsigned num_rep_1 = shapes[1] / layout->spt(1); |   unsigned num_rep_1 = shapes[1] / layout->shape_per_cta(1); | ||||||
|  |  | ||||||
|   // create mma & unpack result |   // create mma & unpack result | ||||||
|   auto call_mma = [&](unsigned m, unsigned n, unsigned K) { |   auto call_mma = [&](unsigned m, unsigned n, unsigned K) { | ||||||
| @@ -1626,8 +1628,8 @@ void generator::visit_fmadot(ir::dot_inst* C, ir::value* A, ir::value* B, ir::va | |||||||
|   std::map<std::pair<int, int>, Value*> has, hbs; |   std::map<std::pair<int, int>, Value*> has, hbs; | ||||||
|   for(unsigned k = 0; k < NK; k++){ |   for(unsigned k = 0; k < NK; k++){ | ||||||
|     int z = 0; |     int z = 0; | ||||||
|     for(unsigned m = 0; m < shape_c[0]; m+=layout_c->mts(0)*layout_c->nts(0)) |     for(unsigned m = 0; m < shape_c[0]; m += layout_c->shape_per_cta(0)) | ||||||
|     for(unsigned n = 0; n < shape_c[1]; n+=layout_c->mts(1)*layout_c->nts(1)) |     for(unsigned n = 0; n < shape_c[1]; n += layout_c->shape_per_cta(1)) | ||||||
|     for(unsigned mm = 0; mm < layout_c->nts(0); mm++) |     for(unsigned mm = 0; mm < layout_c->nts(0); mm++) | ||||||
|     for(unsigned nn = 0; nn < layout_c->nts(1); nn++) |     for(unsigned nn = 0; nn < layout_c->nts(1); nn++) | ||||||
|     { |     { | ||||||
| @@ -1818,6 +1820,7 @@ void generator::visit_reducend_inst(ir::reduce_inst* x, std::function<Value*(Val | |||||||
|       add_barrier(); |       add_barrier(); | ||||||
|       // update accumulator |       // update accumulator | ||||||
|       acc = do_acc(acc, load(read_ptr)); |       acc = do_acc(acc, load(read_ptr)); | ||||||
|  |       add_barrier(); | ||||||
|       store(acc, write_ptr); |       store(acc, write_ptr); | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| @@ -1884,56 +1887,76 @@ void generator::visit_select_inst(ir::select_inst* x) { | |||||||
|   } |   } | ||||||
| } | } | ||||||
|  |  | ||||||
| /** |  | ||||||
|  * \brief Code Generation for `recoalesce` |  | ||||||
|  */ | void generator::visit_layout_convert(ir::value *out, ir::value *in){ | ||||||
| void generator::visit_recoalesce_inst(ir::recoalesce_inst* rc) { |   ir::block_type::block_shapes_t shape = out->get_type()->get_block_shapes(); | ||||||
|   ir::value *op = rc->get_operand(0); |  | ||||||
|   ir::block_type::block_shapes_t shape = rc->get_type()->get_block_shapes(); |  | ||||||
|   // pointer to temporary shared memory |   // pointer to temporary shared memory | ||||||
|   Type *ty = cvt(rc->get_type()->get_scalar_ty()); |   Type *ty = cvt(out->get_type()->get_scalar_ty()); | ||||||
|   // layout |  | ||||||
|   analysis::mma_layout* in_layout = layouts_->get(op)->to_mma(); |  | ||||||
|   analysis::scanline_layout* out_layout = layouts_->get(rc)->to_scanline(); |  | ||||||
|   // Orders |   // Orders | ||||||
|   auto ord = layouts_->get(rc)->to_scanline()->get_order(); |   analysis::distributed_layout* in_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(in)); | ||||||
|  |   analysis::distributed_layout* out_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(out)); | ||||||
|  |   auto in_ord = in_layout->get_order(); | ||||||
|  |   auto out_ord = out_layout->get_order(); | ||||||
|   Value *base; |   Value *base; | ||||||
|   base = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(rc))))); |   base = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(out))))); | ||||||
|   base = bit_cast(base, ptr_ty(ty, 3)); |   base = bit_cast(base, ptr_ty(ty, 3)); | ||||||
|   Value *ld = i32(shape[ord[0]]); |   std::vector<int> n_reps; | ||||||
|   auto in_ord0 = axes_.at(a_axes_->get(op, ord[0])).values; |   for(int i = 0; i < shape.size(); i++){ | ||||||
|   auto in_ord1 = axes_.at(a_axes_->get(op, ord[1])).values; |     int in_per_cta = in_layout->shape_per_cta(i); | ||||||
|   auto out_ord0 = axes_.at(a_axes_->get(rc, ord[0])).values; |     int out_per_cta = out_layout->shape_per_cta(i); | ||||||
|   auto out_ord1 = axes_.at(a_axes_->get(rc, ord[1])).values; |     int max_per_cta = std::max(in_per_cta, out_per_cta); | ||||||
|   int in_spt0  = in_layout->spt(ord[0]); |     n_reps.push_back(shape[i]/max_per_cta); | ||||||
|   int in_spt1  = in_layout->spt(ord[1]); |   } | ||||||
|   int out_spt0 = out_layout->mts(ord[0])*out_layout->nts(ord[0]); |   std::vector<std::vector<Value*>> in_ax; | ||||||
|   int out_spt1 = out_layout->mts(ord[1])*out_layout->nts(ord[1]); |   std::vector<std::vector<Value*>> out_ax; | ||||||
|   int max_spt1 = std::max(in_spt1, out_spt1); |   for(int d = 0; d < shape.size(); d++){ | ||||||
|   indices_t idx(2); |     in_ax.push_back(axes_.at(a_axes_->get(in, d)).values); | ||||||
|   int num_packs = shape[ord[1]]/max_spt1; |     out_ax.push_back(axes_.at(a_axes_->get(out, d)).values); | ||||||
|   for(size_t j = 0; j < num_packs; j++){ |   } | ||||||
|  |   in_ord = in_layout->to_mma() ? out_ord : in_ord; | ||||||
|  |   out_ord = out_layout->to_mma() ? in_ord : out_ord; | ||||||
|  |   Value *in_ld = i32(shape[in_ord[0]]); | ||||||
|  |   Value *out_ld = i32(shape[out_ord[0]]); | ||||||
|  |   for(int i = 0; i < n_reps[0]; i++) | ||||||
|  |   for(int j = 0; j < n_reps[1]; j++){ | ||||||
|  |     int max_ii, max_jj; | ||||||
|     add_barrier(); |     add_barrier(); | ||||||
|     for(size_t k = 0; k < in_ord1.size()/num_packs; k++) |     max_ii = in_ax[0].size()/n_reps[0]; | ||||||
|     for(size_t i = 0; i < in_ord0.size(); i++){ |     max_jj = in_ax[1].size()/n_reps[1]; | ||||||
|       idx[ord[0]] = in_ord0[i]; |     for(int ii = 0; ii < max_ii; ii++) | ||||||
|       idx[ord[1]] = in_ord1[j*in_ord1.size()/num_packs + k]; |     for(int jj = 0; jj < max_jj; jj++){ | ||||||
|       Value *off = add(idx[ord[0]], mul(in_ord1[k], ld)); |       // shared mem pointer | ||||||
|  |       indices_t offs = {in_ax[0][ii], in_ax[1][jj]}; | ||||||
|  |       Value *off  = add(offs[out_ord[0]], mul(out_ld, offs[out_ord[1]])); | ||||||
|       Value *ptr = gep(base, off); |       Value *ptr = gep(base, off); | ||||||
|       store(vals_[op][idx], ptr); |       // stash value to shared mem | ||||||
|  |       indices_t idxs = {in_ax[0][i*max_ii + ii], | ||||||
|  |                         in_ax[1][j*max_jj + jj]}; | ||||||
|  |       store(vals_[in][idxs], ptr); | ||||||
|     } |     } | ||||||
|     add_barrier(); |     add_barrier(); | ||||||
|     for(size_t k = 0; k < out_ord1.size()/num_packs; k++) |     max_ii = out_ax[0].size()/n_reps[0]; | ||||||
|     for(size_t i = 0; i < out_ord0.size(); i++){ |     max_jj = out_ax[1].size()/n_reps[1]; | ||||||
|       idx[ord[0]] = out_ord0[i]; |     for(int ii = 0; ii < max_ii; ii++) | ||||||
|       idx[ord[1]] = out_ord1[j*out_ord1.size()/num_packs + k]; |     for(int jj = 0; jj < max_jj; jj++){ | ||||||
|       Value *off = add(idx[ord[0]], mul(out_ord1[k], ld)); |       // shared mem pointer | ||||||
|  |       indices_t offs = {out_ax[0][ii], out_ax[1][jj]}; | ||||||
|  |       Value *off  = add(offs[out_ord[0]], mul(out_ld, offs[out_ord[1]])); | ||||||
|       Value *ptr = gep(base, off); |       Value *ptr = gep(base, off); | ||||||
|       vals_[rc][idx] = load(ptr); |       // load value from shared mem | ||||||
|  |       indices_t idxs = {out_ax[0][i*max_ii + ii], | ||||||
|  |                         out_ax[1][j*max_jj + jj]}; | ||||||
|  |       vals_[out][idxs] = load(ptr); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|   } |   } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | void generator::visit_cvt_layout_inst(ir::cvt_layout_inst *rc) { | ||||||
|  |   visit_layout_convert(rc, rc->get_operand(0)); | ||||||
|  | } | ||||||
|  |  | ||||||
| void generator::visit_masked_load_async_inst(ir::masked_load_async_inst* x){ | void generator::visit_masked_load_async_inst(ir::masked_load_async_inst* x){ | ||||||
|   unsigned in_vec = 1; |   unsigned in_vec = 1; | ||||||
|   ir::value *arg = x->get_pointer_operand(); |   ir::value *arg = x->get_pointer_operand(); | ||||||
| @@ -2325,12 +2348,12 @@ void generator::visit_layout_mma(analysis::mma_layout* layout) { | |||||||
|     offset_b_k_[layout] = and_(lane, _3); |     offset_b_k_[layout] = and_(lane, _3); | ||||||
|     // i indices |     // i indices | ||||||
|     Value *offset_c_m = add(and_(lane, _1), offset_a_m_[layout]); |     Value *offset_c_m = add(and_(lane, _1), offset_a_m_[layout]); | ||||||
|     for(unsigned m = 0; m < shape[0]; m+=layout->spt(0)) |     for(unsigned m = 0; m < shape[0]; m+=layout->shape_per_cta(0)) | ||||||
|     for(unsigned mm = 0; mm < layout->rep(0); mm++) |     for(unsigned mm = 0; mm < layout->rep(0); mm++) | ||||||
|       idx_m.push_back(add(offset_c_m, i32(m + mm*2))); |       idx_m.push_back(add(offset_c_m, i32(m + mm*2))); | ||||||
|     // j indices |     // j indices | ||||||
|     Value *offset_c_n = add(and_(lane, _2), add(off_warp_n, off_pair_n)); |     Value *offset_c_n = add(and_(lane, _2), add(off_warp_n, off_pair_n)); | ||||||
|     for(unsigned n = 0; n < shape[1]; n+=layout->spt(1)) |     for(unsigned n = 0; n < shape[1]; n+=layout->shape_per_cta(1)) | ||||||
|     for(unsigned nn = 0; nn < layout->rep(1); nn++){ |     for(unsigned nn = 0; nn < layout->rep(1); nn++){ | ||||||
|       idx_n.push_back(add(offset_c_n, i32(n + nn/2*4 + (nn%2)*2*layout->fpw(1)*layout->rep(1)))); |       idx_n.push_back(add(offset_c_n, i32(n + nn/2*4 + (nn%2)*2*layout->fpw(1)*layout->rep(1)))); | ||||||
|       idx_n.push_back(add(offset_c_n, i32(n + nn/2*4 + (nn%2)*2*layout->fpw(1)*layout->rep(1) + 1))); |       idx_n.push_back(add(offset_c_n, i32(n + nn/2*4 + (nn%2)*2*layout->fpw(1)*layout->rep(1) + 1))); | ||||||
| @@ -2366,11 +2389,11 @@ void generator::visit_layout_mma(analysis::mma_layout* layout) { | |||||||
|     // c offset |     // c offset | ||||||
|     Value *off_c_m = add(udiv(lane, _4), off_warp_m); |     Value *off_c_m = add(udiv(lane, _4), off_warp_m); | ||||||
|     Value *off_c_n = add(mul(_2, urem(lane, _4)), off_warp_n); |     Value *off_c_n = add(mul(_2, urem(lane, _4)), off_warp_n); | ||||||
|     for(unsigned m = 0; m < shape[0]; m+=layout->spt(0)){ |     for(unsigned m = 0; m < shape[0]; m+=layout->shape_per_cta(0)){ | ||||||
|       idx_m.push_back(add(off_c_m, i32(m))); |       idx_m.push_back(add(off_c_m, i32(m))); | ||||||
|       idx_m.push_back(add(off_c_m, i32(m + 8))); |       idx_m.push_back(add(off_c_m, i32(m + 8))); | ||||||
|     } |     } | ||||||
|     for(unsigned n = 0; n < shape[1]; n+=layout->spt(1)){ |     for(unsigned n = 0; n < shape[1]; n+=layout->shape_per_cta(1)){ | ||||||
|       idx_n.push_back(add(off_c_n, i32(n))); |       idx_n.push_back(add(off_c_n, i32(n))); | ||||||
|       idx_n.push_back(add(off_c_n, i32(n + 1))); |       idx_n.push_back(add(off_c_n, i32(n + 1))); | ||||||
|     } |     } | ||||||
| @@ -2406,11 +2429,11 @@ void generator::visit_layout_scanline(analysis::scanline_layout* layout) { | |||||||
|     std::string str_k = std::to_string(k); |     std::string str_k = std::to_string(k); | ||||||
|     Value *contiguous_k = i32(nts); |     Value *contiguous_k = i32(nts); | ||||||
|     Value *scaled_thread_id = mul(thread_id[k], contiguous_k); |     Value *scaled_thread_id = mul(thread_id[k], contiguous_k); | ||||||
|     unsigned per_block  = nts * mts; |     unsigned per_cta  = layout->shape_per_cta(k); | ||||||
|     unsigned per_thread = nts * shape[k] / per_block; |     unsigned per_thread = nts * shape[k] / per_cta; | ||||||
|     std::vector<Value*> idx_list(per_thread); |     std::vector<Value*> idx_list(per_thread); | ||||||
|     for(unsigned n = 0 ; n < per_thread; n++){ |     for(unsigned n = 0 ; n < per_thread; n++){ | ||||||
|       unsigned offset = n / nts * per_block + n % nts; |       unsigned offset = n / nts * per_cta + n % nts; | ||||||
|       idx_list[n] = add(scaled_thread_id, i32(offset), "idx_" + str_k + "_" + std::to_string(n)); |       idx_list[n] = add(scaled_thread_id, i32(offset), "idx_" + str_k + "_" + std::to_string(n)); | ||||||
|     } |     } | ||||||
|     axes_[layout->get_axis(k)] = distributed_axis{nts, idx_list, thread_id[k]}; |     axes_[layout->get_axis(k)] = distributed_axis{nts, idx_list, thread_id[k]}; | ||||||
|   | |||||||
| @@ -15,128 +15,109 @@ namespace transform{ | |||||||
| coalesce::coalesce(analysis::align* align, analysis::layouts *layouts) | coalesce::coalesce(analysis::align* align, analysis::layouts *layouts) | ||||||
|   : align_(align), layout_(layouts) { } |   : align_(align), layout_(layouts) { } | ||||||
|  |  | ||||||
| // Find all values that are used as pointer operands in LD/ST |  | ||||||
| void coalesce::extract_io_use(ir::value *v, std::set<ir::io_inst*>& result) { |  | ||||||
|   for(ir::user* u: v->get_users()){ |  | ||||||
|     auto i = dynamic_cast<ir::io_inst*>(u); |  | ||||||
|     if(i && i->get_pointer_operand() == v) |  | ||||||
|       result.insert(i); |  | ||||||
|   } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| void coalesce::extract_ld(ir::io_inst* i, std::map<int, std::vector<ir::io_inst*>>& result) { | // simplify layout conversions using the following simple rules: | ||||||
|   ir::value *ptr = i->get_pointer_operand(); | //   - cvt_1(cvt_2(x)) = x if cvt_1 is the inverse of cvt_2 | ||||||
|   auto contiguous = align_->contiguous(ptr); | //   - cvt_1(elementwise(x, y)) = elementwise(cvt_1(x), cvt_1(y)) | ||||||
|   auto it = std::max_element(contiguous.begin(), contiguous.end()); | //ir::value* coalesce::simplify(ir::instruction *inst, ir::builder& builder){ | ||||||
|   int axis = std::distance(contiguous.begin(), it); | //  ir::value* _op = inst->get_operand(0); | ||||||
|   result[axis].push_back(i); | //  ir::instruction* op = dynamic_cast<ir::instruction*>(_op); | ||||||
| } | //  analysis::mma_layout* mma_in  = layout_->get(op)  ->to_mma(); | ||||||
|  | //  analysis::mma_layout* mma_out = layout_->get(inst)->to_mma(); | ||||||
| ir::value* coalesce::rematerialize(ir::value *x, ir::builder &builder, | //  std::cout << 1 << std::endl; | ||||||
|                                    std::map<ir::value*, ir::value*>& seen) { | //  // i must be layout conversion instruction | ||||||
|   if(seen.find(x) != seen.end()) | //  if(!mma_in && !mma_out) | ||||||
|     return seen.at(x); | //    return inst; | ||||||
|   auto i = dynamic_cast<ir::instruction*>(x); | //  //   - cvt_1(cvt_2(x)) = x if cvt_1 is the inverse of cvt_2 | ||||||
|   // not an instruction -- forward value | //  bool is_op_cvt = op->get_id() == ir::INST_CVT_LAYOUT; | ||||||
|   if(!i) | //  if((mma_in || mma_out) && is_op_cvt && | ||||||
|     return x; | //     (layout_->get(inst) == layout_->get(op->get_operand(0)))) | ||||||
|   // already in shared memory -- forward value | //    return op->get_operand(0); | ||||||
|   if(dynamic_cast<ir::copy_to_shared_inst*>(x)){ | //  //   - cvt_1(elementwise(x, y)) = elementwise(cvt_1(x), cvt_1(y)) | ||||||
|     return x; | //  if(op->get_id() != ir::INST_BINOP && op->get_id() != ir::INST_GETELEMENTPTR) | ||||||
|   } | //    return inst; | ||||||
|   // set insert point | //  std::cout << 1 << std::endl; | ||||||
|   auto& inst_list = i->get_parent()->get_inst_list(); | //  for(size_t i = 0; i < op->get_num_operands(); i++){ | ||||||
|   auto pos = ++std::find(inst_list.begin(), inst_list.end(), i); | //    ir::value* arg_i = op->get_operand(i); | ||||||
|   builder.set_insert_point(pos); | //    builder.set_insert_point(op); | ||||||
|   if(dynamic_cast<ir::load_inst*>(x)){ | //    // create new layout transform | ||||||
|     ir::value *ret = builder.insert(ir::copy_to_shared_inst::create(x)); | //    ir::instruction* new_arg_i = inst->clone(); | ||||||
|     return ret; | //    builder.insert(new_arg_i); | ||||||
|   } | //    // set the right args | ||||||
|   // default -- recursive clone | //    new_arg_i->replace_uses_of_with(new_arg_i->get_operand(0), arg_i); | ||||||
|   ir::instruction *cloned = builder.insert(i->clone()); | //    op->replace_uses_of_with(arg_i, simplify(new_arg_i, builder)); | ||||||
|   seen[i] = cloned; | //  } | ||||||
|   // rematerialize operands | //  std::cout << 2 << std::endl; | ||||||
|   for(ir::value *op: cloned->ops()) | //  return op; | ||||||
|     cloned->replace_uses_of_with(op, rematerialize(op, builder, seen)); | //} | ||||||
|   return cloned; |  | ||||||
| } |  | ||||||
|  |  | ||||||
| void coalesce::run(ir::module &mod) { | void coalesce::run(ir::module &mod) { | ||||||
|   size_t num_groups = layout_->num_layouts(); |  | ||||||
|  |  | ||||||
|  |  | ||||||
|   for(size_t id = 0; id < num_groups; id++) { |  | ||||||
|     if(!layout_->get(id)->to_mma()) |  | ||||||
|       continue; |  | ||||||
|     // extract memory stores |  | ||||||
|     const auto& values = layout_->values_of(id); |  | ||||||
|     ir::value* dot = nullptr; |  | ||||||
|     for(ir::value *v: values) |  | ||||||
|       if(auto x = dynamic_cast<ir::dot_inst*>(v)) |  | ||||||
|         dot = x; |  | ||||||
|  |  | ||||||
|   ir::builder& builder = mod.get_builder(); |   ir::builder& builder = mod.get_builder(); | ||||||
|     std::vector<ir::value*> worklist = {dot}; |   // add layout conversion instructions | ||||||
|     std::set<ir::value*> seen; |   for(ir::function *fn: mod.get_function_list()) | ||||||
|     while(!worklist.empty()) { |   for(ir::basic_block *block: fn->blocks()) | ||||||
|       ir::value *current = worklist.back(); |   for(ir::instruction* i: block->get_inst_list()){ | ||||||
|       seen.insert(current); |     // coalesce before store | ||||||
|       worklist.pop_back(); |     if(auto x = dynamic_cast<ir::store_inst*>(i)) | ||||||
|       // stop if trunc |     if(ir::value* op = x->get_value_operand()) | ||||||
|       if(auto x = dynamic_cast<ir::fp_trunc_inst*>(current)){ |     if(op->get_type()->is_block_ty()) | ||||||
|  |     if(layout_->get(op)->to_mma()){ | ||||||
|  |       builder.set_insert_point(x); | ||||||
|  |       ir::instruction* new_op = ir::cvt_layout_inst::create(op); | ||||||
|  |       builder.insert(new_op); | ||||||
|  |       x->replace_uses_of_with(op, new_op); | ||||||
|  |     } | ||||||
|  |     // uncoalesce after load | ||||||
|  |     if(auto x = dynamic_cast<ir::load_inst*>(i)) | ||||||
|  |     if(x->get_type()->is_block_ty()) | ||||||
|  |     if(x->get_type()->get_tile_rank()==2) | ||||||
|  |     if(layout_->get(x)->to_mma()){ | ||||||
|         builder.set_insert_point_after(x); |         builder.set_insert_point_after(x); | ||||||
|         ir::recoalesce_inst* rc = ir::recoalesce_inst::create(x); |         ir::instruction* new_x = ir::cvt_layout_inst::create(x); | ||||||
|         builder.insert(rc); |         builder.insert(new_x); | ||||||
|         x->replace_all_uses_with(rc); |         x->replace_all_uses_with(new_x); | ||||||
|         rc->replace_uses_of_with(rc, x); |         new_x->replace_uses_of_with(new_x, x); | ||||||
|  | //        new_x->replace_uses_of_with(new_x, new_x); | ||||||
|  |     } | ||||||
|  |     // re-arrange scanline to promote memory coalescing | ||||||
|  |     if(auto x = dynamic_cast<ir::store_inst*>(i)){ | ||||||
|  |       ir::value* ptr = x->get_pointer_operand(); | ||||||
|  |       ir::value* val = x->get_value_operand(); | ||||||
|  |       auto out_contig = align_->contiguous(ptr); | ||||||
|  |       auto val_inst = dynamic_cast<ir::instruction*>(val); | ||||||
|  |       if(!val_inst) | ||||||
|  |         break; | ||||||
|  |       if(dynamic_cast<ir::cvt_layout_inst*>(val)) | ||||||
|  |         break; | ||||||
|  |       std::vector<unsigned> in_contig; | ||||||
|  |       std::vector<ir::instruction*> queue = {val_inst}; | ||||||
|  |       std::set<ir::instruction*> seen; | ||||||
|  |       std::vector<ir::io_inst*> ios; | ||||||
|  |       while(!queue.empty()){ | ||||||
|  |         ir::instruction* curr = queue.back(); | ||||||
|  |         seen.insert(curr); | ||||||
|  |         queue.pop_back(); | ||||||
|  |         if(auto io_inst = dynamic_cast<ir::io_inst*>(curr)){ | ||||||
|  |           in_contig = align_->contiguous(io_inst->get_pointer_operand()); | ||||||
|           break; |           break; | ||||||
|         } |         } | ||||||
|       // recurse |         for(ir::value* op: curr->ops()){ | ||||||
|       for(ir::user *u: current->get_users()) |           auto inst_op = dynamic_cast<ir::instruction*>(op); | ||||||
|         if(seen.find(u) == seen.end()) |           if(!inst_op || seen.find(inst_op) != seen.end()) | ||||||
|           worklist.push_back(u); |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   // find values to rematerialize |  | ||||||
|   std::vector<ir::io_inst*> remat; |  | ||||||
|   for(size_t id = 0; id < num_groups; id++) { |  | ||||||
|     const auto& values = layout_->values_of(id); |  | ||||||
|     // extract pointers used in ld/st operations |  | ||||||
|     std::set<ir::io_inst*> io; |  | ||||||
|     for(ir::value *v: values) |  | ||||||
|       extract_io_use(v, io); |  | ||||||
|     // extract leading axes |  | ||||||
|     std::map<int, std::vector<ir::io_inst*>> axes; |  | ||||||
|     for(ir::io_inst *i: io){ |  | ||||||
|       if(i->get_pointer_operand()->get_type()->get_tile_rank() == layout_->get(id)->get_rank()){ |  | ||||||
|         extract_ld(i, axes); |  | ||||||
|       } |  | ||||||
|     } |  | ||||||
|     // update list of values to rematerialize |  | ||||||
|     if(axes.empty()) |  | ||||||
|             continue; |             continue; | ||||||
|     for(auto it = ++axes.rbegin(); it != axes.rend(); it++){ |           if(!op->get_type()->is_block_ty() || | ||||||
|       if(it->second.size() == 1) |              !val->get_type()->is_block_ty()) | ||||||
|             continue; |             continue; | ||||||
|       remat.insert(remat.begin(), it->second.begin(), it->second.end()); |           if(op->get_type()->get_tile_num_elements() == | ||||||
|  |              val->get_type()->get_tile_num_elements()) | ||||||
|  |             queue.push_back(inst_op); | ||||||
|         } |         } | ||||||
|       } |       } | ||||||
|   // rematerialize values |       if(in_contig.empty() || out_contig==in_contig) | ||||||
|   for(ir::io_inst *r: remat) { |         continue; | ||||||
|     ir::builder& builder = mod.get_builder(); |       builder.set_insert_point_after(val_inst); | ||||||
|     // rematerialize operands |       auto new_val = builder.insert(ir::cvt_layout_inst::create(val_inst)); | ||||||
|     std::map<ir::value*, ir::value*> seen; |       x->replace_uses_of_with(val_inst, new_val); | ||||||
|     for(ir::value *op: r->ops()) |  | ||||||
|       r->replace_uses_of_with(op, rematerialize(op, mod.get_builder(), seen)); |  | ||||||
|     // copy to shared if load |  | ||||||
|     auto& inst_list = r->get_parent()->get_inst_list(); |  | ||||||
|     auto pos = ++std::find(inst_list.begin(), inst_list.end(), r); |  | ||||||
|     builder.set_insert_point(pos); |  | ||||||
|     if(dynamic_cast<ir::load_inst*>(r)){ |  | ||||||
|       ir::instruction *cts = builder.insert(ir::copy_to_shared_inst::create(r)); |  | ||||||
|       r->replace_all_uses_with(cts); |  | ||||||
|       cts->replace_uses_of_with(cts, r); |  | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| } | } | ||||||
|   | |||||||
| @@ -9,67 +9,48 @@ namespace triton { | |||||||
| namespace codegen{ | namespace codegen{ | ||||||
| namespace transform{ | namespace transform{ | ||||||
|  |  | ||||||
| void extract_retile_chain(ir::user *root, | ir::instruction* rematerialize(ir::builder& bld, ir::instruction *root, | ||||||
|                           std::map<int, std::set<ir::user*>>& result, |  | ||||||
|                           int depth, |  | ||||||
|                           std::set<ir::value*>& seen) { |                           std::set<ir::value*>& seen) { | ||||||
|   if(!seen.insert(root).second) |   if(!seen.insert(root).second) | ||||||
|     return; |     return root; | ||||||
|   result[depth].insert(root); |   if(!root->get_type()->is_block_ty()) | ||||||
|   if(dynamic_cast<ir::make_range*>(root) || |     return root; | ||||||
|      dynamic_cast<ir::splat_inst*>(root)){ |  | ||||||
|     return; |   bld.set_insert_point(root); | ||||||
|   } |   ir::instruction *new_root = bld.insert(root->clone()); | ||||||
|   for(ir::value *op: root->ops()){ |   for(ir::value *op: root->ops()){ | ||||||
|     ir::user *u = dynamic_cast<ir::user*>(op); |     ir::instruction *i = dynamic_cast<ir::instruction*>(op); | ||||||
|     if(!u) |     if(!i || i->get_id() == ir::INST_REDUCE) | ||||||
|       continue; |       continue; | ||||||
|     extract_retile_chain(u, result, depth + 1, seen); |     ir::instruction* new_op = rematerialize(bld, i, seen); | ||||||
|  |     new_root->replace_uses_of_with(op, new_op); | ||||||
|   } |   } | ||||||
|  |   return new_root; | ||||||
| } | } | ||||||
|  |  | ||||||
| void disassociate::run(ir::module &mod) { | void disassociate::run(ir::module &mod) { | ||||||
|   ir::builder &bld = mod.get_builder(); |   ir::builder &bld = mod.get_builder(); | ||||||
|  |  | ||||||
|   std::map<ir::user*, std::map<int, std::set<ir::user*>>> clone_info; | //  ir::for_each_instruction(mod, [&](ir::instruction *i){ | ||||||
|  | //    bld.set_insert_point(i); | ||||||
|  | //    for(ir::value* op: i->ops()){ | ||||||
|  | //      auto reshape = dynamic_cast<ir::make_range*>(op); | ||||||
|  | //      if(!reshape) | ||||||
|  | //        continue; | ||||||
|  | //      ir::instruction* new_op = bld.insert(reshape->clone()); | ||||||
|  | //      i->replace_uses_of_with(op, new_op); | ||||||
|  | //    } | ||||||
|  | //  }); | ||||||
|  |  | ||||||
|  |  | ||||||
|   ir::for_each_instruction(mod, [&](ir::instruction *i){ |   ir::for_each_instruction(mod, [&](ir::instruction *i){ | ||||||
|     if(dynamic_cast<ir::reshape_inst*>(i)){ |     if(dynamic_cast<ir::reshape_inst*>(i) || dynamic_cast<ir::splat_inst*>(i)){ | ||||||
|       ir::value* op = i->get_operand(0); |  | ||||||
|       if(!dynamic_cast<ir::user*>(op)) |  | ||||||
|         return; |  | ||||||
|       if(op->get_type()->get_tile_rank() > i->get_type()->get_tile_rank()) |  | ||||||
|         return; |  | ||||||
|       std::map<int, std::set<ir::user*>> chains; |  | ||||||
|       std::set<ir::value*> seen; |       std::set<ir::value*> seen; | ||||||
|       extract_retile_chain(i, chains, 0, seen); |       ir::instruction* new_i = rematerialize(bld, i, seen); | ||||||
|       if(chains.size()) |       i->replace_all_uses_with(new_i); | ||||||
|         clone_info[i] = chains; |  | ||||||
|     } |     } | ||||||
|   }); |   }); | ||||||
|  |  | ||||||
|   for(const auto& x: clone_info){ |  | ||||||
|     int depth = 1; |  | ||||||
|     std::map<ir::instruction*, ir::instruction*> clone_map; |  | ||||||
|     while(x.second.find(depth) != x.second.end()){ |  | ||||||
|       // clone all users |  | ||||||
|       const auto& remat = x.second.at(depth); |  | ||||||
|       for(ir::user* u: remat){ |  | ||||||
|         ir::instruction *y = (ir::instruction*)u; |  | ||||||
|         ir::instruction *cloned = y->clone(); |  | ||||||
|         bld.set_insert_point(y); |  | ||||||
|         bld.insert(cloned); |  | ||||||
|         clone_map[y] = cloned; |  | ||||||
|         // replace operands of parents |  | ||||||
|         if(depth > 1) |  | ||||||
|           for(ir::user* ux: x.second.at(depth - 1)) |  | ||||||
|             clone_map.at((ir::instruction*)ux)->replace_uses_of_with(y, cloned); |  | ||||||
|         else |  | ||||||
|           x.first->replace_uses_of_with(y, cloned); |  | ||||||
|       } |  | ||||||
|       depth += 1; |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|  |  | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -211,6 +211,42 @@ bool peephole::rewrite_select_masked_load(ir::instruction *value, ir::builder& b | |||||||
|   return true; |   return true; | ||||||
| } | } | ||||||
|  |  | ||||||
|  | bool peephole::rewrite_cvt_layout(ir::instruction *value, ir::builder& builder){ | ||||||
|  |   auto cvt = dynamic_cast<ir::cvt_layout_inst*>(value); | ||||||
|  |   if(!cvt) | ||||||
|  |     return false; | ||||||
|  |   ir::instruction* op = dynamic_cast<ir::instruction*>(cvt->get_operand(0)); | ||||||
|  |   if(!op) | ||||||
|  |     return false; | ||||||
|  |   // convert(elementwise(x, y)) = elementwise(convert(x), convert(y)) | ||||||
|  |   if(op->get_id() == ir::INST_BINOP){ | ||||||
|  |     for(size_t i = 0; i < op->get_num_operands(); i++){ | ||||||
|  |       ir::value* arg_i = op->get_operand(i); | ||||||
|  |       builder.set_insert_point(op); | ||||||
|  |       // create new layout transform | ||||||
|  |       ir::instruction* new_arg_i = cvt->clone(); | ||||||
|  |       layouts_->copy(new_arg_i, op); | ||||||
|  |       builder.insert(new_arg_i); | ||||||
|  |       // set the right args | ||||||
|  |       new_arg_i->replace_uses_of_with(new_arg_i->get_operand(0), arg_i); | ||||||
|  |       op->replace_uses_of_with(arg_i, new_arg_i); | ||||||
|  |     } | ||||||
|  |     cvt->replace_all_uses_with(op); | ||||||
|  |     return true; | ||||||
|  |   } | ||||||
|  |   auto cvt_op = dynamic_cast<ir::cvt_layout_inst*>(op); | ||||||
|  |   if(!cvt_op) | ||||||
|  |     return false; | ||||||
|  |   // convert1(convert2(x)) = x if convert1 is the inverse of convert2 | ||||||
|  |   ir::value* op_op = cvt_op->get_operand(0); | ||||||
|  |   if(layouts_->has(cvt) && layouts_->has(op_op) && | ||||||
|  |      layouts_->get(cvt) && layouts_->get(op_op)){ | ||||||
|  |     cvt->replace_all_uses_with(op_op); | ||||||
|  |     return true; | ||||||
|  |   } | ||||||
|  |   return false; | ||||||
|  | } | ||||||
|  |  | ||||||
| void peephole::run(ir::module &mod) { | void peephole::run(ir::module &mod) { | ||||||
|   ir::builder &builder = mod.get_builder(); |   ir::builder &builder = mod.get_builder(); | ||||||
|   // keep track of whether any modification was made |   // keep track of whether any modification was made | ||||||
| @@ -248,6 +284,7 @@ void peephole::run(ir::module &mod) { | |||||||
|       was_modified = was_modified || rewrite_unit_red(i, builder); |       was_modified = was_modified || rewrite_unit_red(i, builder); | ||||||
|       was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder); |       was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder); | ||||||
|       was_modified = was_modified || rewrite_select_masked_load(i, builder); |       was_modified = was_modified || rewrite_select_masked_load(i, builder); | ||||||
|  |       was_modified = was_modified || rewrite_cvt_layout(i, builder); | ||||||
|       if(tgt_->as_nvidia()->sm() >= 80) |       if(tgt_->as_nvidia()->sm() >= 80) | ||||||
|         was_modified = was_modified || rewrite_load_to_shared(i, builder); |         was_modified = was_modified || rewrite_load_to_shared(i, builder); | ||||||
|       if(was_modified) |       if(was_modified) | ||||||
|   | |||||||
| @@ -126,12 +126,6 @@ bool dispatch::nvmlinit(){ | |||||||
|   return res; |   return res; | ||||||
| } | } | ||||||
|  |  | ||||||
| bool dispatch::spvllvminit(){ |  | ||||||
|   if(spvllvm_==nullptr) |  | ||||||
|     spvllvm_ = dlopen("libLLVMSPIRVLib.so", RTLD_LAZY); |  | ||||||
|   return spvllvm_ != nullptr; |  | ||||||
| } |  | ||||||
|  |  | ||||||
| //CUDA | //CUDA | ||||||
| CUDA_DEFINE1(CUresult, cuCtxDestroy_v2, CUcontext) | CUDA_DEFINE1(CUresult, cuCtxDestroy_v2, CUcontext) | ||||||
| CUDA_DEFINE2(CUresult, cuEventCreate, CUevent *, unsigned int) | CUDA_DEFINE2(CUresult, cuEventCreate, CUevent *, unsigned int) | ||||||
| @@ -185,14 +179,6 @@ NVML_DEFINE3(nvmlReturn_t, nvmlDeviceGetClockInfo, nvmlDevice_t, nvmlClockType_t | |||||||
| NVML_DEFINE3(nvmlReturn_t, nvmlDeviceGetMaxClockInfo, nvmlDevice_t, nvmlClockType_t, unsigned int*) | NVML_DEFINE3(nvmlReturn_t, nvmlDeviceGetMaxClockInfo, nvmlDevice_t, nvmlClockType_t, unsigned int*) | ||||||
| NVML_DEFINE3(nvmlReturn_t, nvmlDeviceSetApplicationsClocks, nvmlDevice_t, unsigned int, unsigned int) | NVML_DEFINE3(nvmlReturn_t, nvmlDeviceSetApplicationsClocks, nvmlDevice_t, unsigned int, unsigned int) | ||||||
|  |  | ||||||
| // LLVM to SPIR-V |  | ||||||
| int dispatch::initializeLLVMToSPIRVPass(llvm::PassRegistry &registry){ |  | ||||||
|   return f_impl<dispatch::spvllvminit>(spvllvm_, initializeLLVMToSPIRVPass, initializeLLVMToSPIRVPass_, "initializeLLVMToSPIRVPass", std::ref(registry)); |  | ||||||
| } |  | ||||||
|  |  | ||||||
| bool dispatch::writeSpirv(llvm::Module *M, std::ostream &OS, std::string &ErrMsg){ |  | ||||||
|   return f_impl<dispatch::spvllvminit>(spvllvm_, writeSpirv, writeSpirv_, "writeSpirv", M, std::ref(OS), std::ref(ErrMsg)); |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Release | // Release | ||||||
| void dispatch::release(){ | void dispatch::release(){ | ||||||
| @@ -204,7 +190,6 @@ void dispatch::release(){ | |||||||
|  |  | ||||||
| void* dispatch::cuda_; | void* dispatch::cuda_; | ||||||
| void* dispatch::nvml_; | void* dispatch::nvml_; | ||||||
| void* dispatch::spvllvm_; |  | ||||||
|  |  | ||||||
| //CUDA | //CUDA | ||||||
| void* dispatch::cuCtxGetCurrent_; | void* dispatch::cuCtxGetCurrent_; | ||||||
| @@ -261,9 +246,5 @@ void* dispatch::nvmlDeviceGetClockInfo_; | |||||||
| void* dispatch::nvmlDeviceGetMaxClockInfo_; | void* dispatch::nvmlDeviceGetMaxClockInfo_; | ||||||
| void* dispatch::nvmlDeviceSetApplicationsClocks_; | void* dispatch::nvmlDeviceSetApplicationsClocks_; | ||||||
|  |  | ||||||
| // SPIR-V |  | ||||||
| void* dispatch::initializeLLVMToSPIRVPass_; |  | ||||||
| void* dispatch::writeSpirv_; |  | ||||||
|  |  | ||||||
| } | } | ||||||
| } | } | ||||||
|   | |||||||
| @@ -374,12 +374,15 @@ ir::value *dispatch::broadcast(ir::value *input, shape_t shape, ir::builder *bui | |||||||
|   auto src_shape = input->get_type()->get_block_shapes(); |   auto src_shape = input->get_type()->get_block_shapes(); | ||||||
|   if (src_shape.size() != shape.size()) |   if (src_shape.size() != shape.size()) | ||||||
|     throw std::runtime_error("Cannot broadcast"); |     throw std::runtime_error("Cannot broadcast"); | ||||||
|  |   if(shape == src_shape) | ||||||
|  |     return input; | ||||||
|   return builder->create_broadcast(input, shape); |   return builder->create_broadcast(input, shape); | ||||||
| } | } | ||||||
|  |  | ||||||
| std::tuple<ir::value*, ir::value*> dispatch::broadcast(ir::value *lhs, ir::value* rhs, ir::builder *builder) { | std::tuple<ir::value*, ir::value*> dispatch::broadcast(ir::value *lhs, ir::value* rhs, ir::builder *builder) { | ||||||
|   ir::type *lhs_ty = lhs->get_type(); |   ir::type *lhs_ty = lhs->get_type(); | ||||||
|   ir::type *rhs_ty = rhs->get_type(); |   ir::type *rhs_ty = rhs->get_type(); | ||||||
|  |  | ||||||
|   // make_shape_compatible(block, scalar) |   // make_shape_compatible(block, scalar) | ||||||
|   if (lhs_ty->is_block_ty() && !rhs_ty->is_block_ty()) |   if (lhs_ty->is_block_ty() && !rhs_ty->is_block_ty()) | ||||||
|     rhs = builder->create_splat(rhs, lhs_ty->get_block_shapes()); |     rhs = builder->create_splat(rhs, lhs_ty->get_block_shapes()); | ||||||
|   | |||||||
| @@ -806,6 +806,11 @@ instruction* log_inst::create(value *val, const std::string& name, instruction * | |||||||
| //                               intrinsic instructions | //                               intrinsic instructions | ||||||
| //===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||||
|  |  | ||||||
|  | // cvt_layout | ||||||
|  | cvt_layout_inst* cvt_layout_inst::create(value *arg, const std::string &name, instruction *next) { | ||||||
|  |   return new cvt_layout_inst(arg->get_type(), INST_CVT_LAYOUT, arg, name, next); | ||||||
|  | } | ||||||
|  |  | ||||||
| // copy to shared | // copy to shared | ||||||
| copy_to_shared_inst* copy_to_shared_inst::create(value *arg, const std::string &name, | copy_to_shared_inst* copy_to_shared_inst::create(value *arg, const std::string &name, | ||||||
|                                                  instruction *next) { |                                                  instruction *next) { | ||||||
| @@ -818,13 +823,6 @@ copy_from_shared_inst* copy_from_shared_inst::create(value *arg, const std::stri | |||||||
|   return new copy_from_shared_inst(arg->get_type(), INST_COPY_FROM_SHARED, arg, name, next); |   return new copy_from_shared_inst(arg->get_type(), INST_COPY_FROM_SHARED, arg, name, next); | ||||||
| } | } | ||||||
|  |  | ||||||
| // recoalesce |  | ||||||
| recoalesce_inst* recoalesce_inst::create(value *arg, const std::string &name, instruction *next) { |  | ||||||
|   return new recoalesce_inst(arg->get_type(), INST_RECOALESCE, arg, name, next); |  | ||||||
| } |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| // barrier | // barrier | ||||||
| barrier_inst::barrier_inst(context &ctx, const std::string &name, | barrier_inst::barrier_inst(context &ctx, const std::string &name, | ||||||
|                                                        instruction *next) |                                                        instruction *next) | ||||||
|   | |||||||
| @@ -363,6 +363,133 @@ def test_reduce1d(dtype, shape, device='cuda'): | |||||||
|     triton.testing.assert_almost_equal(z_tri, z_ref) |     triton.testing.assert_almost_equal(z_tri, z_ref) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.parametrize("dtype, shape, axis",  | ||||||
|  |   [(dtype, shape, 1) \ | ||||||
|  |         for dtype in ['float32']\ | ||||||
|  |         for shape in [(1, 1024)]]) | ||||||
|  | def test_reduce2d(dtype, shape, axis, device='cuda'): | ||||||
|  |     dtype = cvt[dtype] | ||||||
|  |     # triton kernel | ||||||
|  |     @triton.jit | ||||||
|  |     def kernel(X, Z, **meta): | ||||||
|  |         range_m = tl.arange(0, meta['BLOCK_M']) | ||||||
|  |         range_n = tl.arange(0, meta['BLOCK_N']) | ||||||
|  |         x = tl.load(X + range_m[:, None]*meta['BLOCK_N'] + range_n[None, :]) | ||||||
|  |         z = tl.sum(x, axis=meta['AXIS']) | ||||||
|  |         tl.store(Z + range_m, z) | ||||||
|  |     # input | ||||||
|  |     x = triton.testing.random(shape, dtype=dtype, device=device) | ||||||
|  |     # triton result | ||||||
|  |     z_tri = torch.empty((shape[0],), dtype=dtype, device=device) | ||||||
|  |     kernel[(1,)](x, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis) | ||||||
|  |     # torch result | ||||||
|  |     z_ref = torch.sum(x, axis=axis).to(dtype) | ||||||
|  |     # compare | ||||||
|  |     triton.testing.assert_almost_equal(z_tri, z_ref) | ||||||
|  |  | ||||||
|  | # --------------- | ||||||
|  | # test permute | ||||||
|  | # --------------- | ||||||
|  |  | ||||||
|  | # --------------- | ||||||
|  | # test permute | ||||||
|  | # --------------- | ||||||
|  |  | ||||||
|  | @pytest.mark.parametrize("dtype, shape, perm", | ||||||
|  |   [(dtype, shape, perm) \ | ||||||
|  |         for dtype in ['float32']\ | ||||||
|  |         for shape in [(128, 128)]\ | ||||||
|  |         for perm  in [(1, 0)]]) | ||||||
|  | def test_permute(dtype, shape, perm, device='cuda'): | ||||||
|  |     dtype = cvt[dtype] | ||||||
|  |     # triton kernel | ||||||
|  |     @triton.jit | ||||||
|  |     def kernel(X, stride_xm, stride_xn,  | ||||||
|  |                Z, stride_zm, stride_zn, **meta): | ||||||
|  |         BLOCK_M = meta['BLOCK_M'] | ||||||
|  |         BLOCK_N = meta['BLOCK_N'] | ||||||
|  |         off_m = tl.arange(0, BLOCK_M) | ||||||
|  |         off_n = tl.arange(0, BLOCK_N) | ||||||
|  |         Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn | ||||||
|  |         Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn | ||||||
|  |         tl.store(Zs, tl.load(Xs)) | ||||||
|  |     # input | ||||||
|  |     x = triton.testing.random(shape, dtype=dtype, device=device) | ||||||
|  |     # triton result | ||||||
|  |     z_tri = torch.empty_like(x) | ||||||
|  |     pgm = kernel[(1, 1)](x, x.stride(0), x.stride(1),  | ||||||
|  |                         z_tri, z_tri.stride(1), z_tri.stride(0),  | ||||||
|  |                         BLOCK_M=shape[0], BLOCK_N=shape[1]) | ||||||
|  |     # torch result | ||||||
|  |     z_ref = x.permute(*perm).contiguous() | ||||||
|  |     # compare | ||||||
|  |     triton.testing.assert_almost_equal(z_tri, z_ref) | ||||||
|  |     # parse ptx to make sure ld/st are vectorized | ||||||
|  |     ptx = pgm.asm('ptx') | ||||||
|  |     assert 'ld.global.v4' in ptx | ||||||
|  |     assert 'st.global.v4' in ptx | ||||||
|  |  | ||||||
|  | # --------------- | ||||||
|  | # test dot | ||||||
|  | # --------------- | ||||||
|  |  | ||||||
|  | @pytest.mark.parametrize("epilogue", ['none', 'add-matrix', 'add-rows', 'add-cols']) | ||||||
|  | def test_dot(epilogue, device='cuda'): | ||||||
|  |     torch.manual_seed(0) | ||||||
|  |     # triton kernel | ||||||
|  |     @triton.jit | ||||||
|  |     def kernel(X, stride_xm, stride_xk,  | ||||||
|  |                Y, stride_yk, stride_yn, | ||||||
|  |                Z, stride_zm, stride_zn, **meta): | ||||||
|  |         BLOCK_M = meta['BLOCK_M'] | ||||||
|  |         BLOCK_K = meta['BLOCK_K'] | ||||||
|  |         BLOCK_N = meta['BLOCK_N'] | ||||||
|  |         off_m = tl.arange(0, BLOCK_M) | ||||||
|  |         off_n = tl.arange(0, BLOCK_N) | ||||||
|  |         off_k = tl.arange(0, BLOCK_K) | ||||||
|  |         Xs = X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk | ||||||
|  |         Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn | ||||||
|  |         Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn | ||||||
|  |         z = tl.dot(tl.load(Xs), tl.load(Ys)) | ||||||
|  |         if meta['ADD_MATRIX']: | ||||||
|  |             z += tl.load(Zs) | ||||||
|  |         if meta['ADD_ROWS']: | ||||||
|  |             ZRs = Z + off_m * stride_zm | ||||||
|  |             z += tl.load(ZRs)[:, None] | ||||||
|  |         if meta['ADD_COLS']: | ||||||
|  |             ZCs = Z + off_n * stride_zn  | ||||||
|  |             z += tl.load(ZCs)[None, :] | ||||||
|  |         tl.store(Zs, z) | ||||||
|  |     # input | ||||||
|  |     M, N, K = 64, 64, 32 | ||||||
|  |     x = triton.testing.random((M, K), dtype=torch.float16, device=device) | ||||||
|  |     y = triton.testing.random((K, N), dtype=torch.float16, device=device) | ||||||
|  |     # triton result | ||||||
|  |     z = triton.testing.random((M, N), dtype=torch.float16, device=device) | ||||||
|  |     z_tri = z.clone() | ||||||
|  |     pgm = kernel[(1, 1)](x, x.stride(0), x.stride(1), | ||||||
|  |                          y, y.stride(0), y.stride(1), | ||||||
|  |                          z_tri, z_tri.stride(0), z_tri.stride(1), | ||||||
|  |                          BLOCK_M=M, BLOCK_K=K, BLOCK_N=N, | ||||||
|  |                          ADD_MATRIX = epilogue=='add-matrix', | ||||||
|  |                          ADD_ROWS = epilogue=='add-rows', | ||||||
|  |                          ADD_COLS = epilogue=='add-cols') | ||||||
|  |     # torch result | ||||||
|  |     z_ref = torch.matmul(x.float(), y.float()) | ||||||
|  |     if epilogue == 'add-matrix': | ||||||
|  |         z_ref += z | ||||||
|  |     if epilogue == 'add-rows': | ||||||
|  |         z_ref += z[:,0][:, None] | ||||||
|  |     if epilogue == 'add-cols': | ||||||
|  |         z_ref += z[0,:][None, :] | ||||||
|  |     z_ref = z_ref.to(torch.float16) | ||||||
|  |     # compare | ||||||
|  |     ptx = pgm.asm('ptx') | ||||||
|  |     # print(ptx) | ||||||
|  |     triton.testing.assert_almost_equal(z_tri, z_ref) | ||||||
|  |     # make sure ld/st are vectorized | ||||||
|  |     assert 'ld.global.v4' in ptx | ||||||
|  |     assert 'st.global.v4' in ptx | ||||||
|  |  | ||||||
|  |  | ||||||
| # --------------- | # --------------- | ||||||
|   | |||||||
| @@ -624,6 +624,14 @@ def max_contiguous(input, value, _builder=None): | |||||||
|     return frontend.max_contiguous(input, value, _builder) |     return frontend.max_contiguous(input, value, _builder) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @builtin | ||||||
|  | def max_contiguous(input, value, _builder=None): | ||||||
|  |     """ | ||||||
|  |     Let the compiler know that the first `value` values in :code:`input` are contiguous. | ||||||
|  |     """ | ||||||
|  |     return frontend.max_contiguous(input, value, _builder) | ||||||
|  |  | ||||||
|  |  | ||||||
| # ----------------------- | # ----------------------- | ||||||
| # Standard library | # Standard library | ||||||
| # ----------------------- | # ----------------------- | ||||||
|   | |||||||
| @@ -46,8 +46,8 @@ def _kernel(A, B, C, M, N, K, | |||||||
|     pid_n = (pid % width) // (group_size) |     pid_n = (pid % width) // (group_size) | ||||||
|     # do matrix multiplication |     # do matrix multiplication | ||||||
|     rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) |     rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) | ||||||
|     ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) |  | ||||||
|     rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) |     rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) | ||||||
|  |     ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) | ||||||
|     rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) |     rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) | ||||||
|     rk = tl.arange(0, BLOCK_K) |     rk = tl.arange(0, BLOCK_K) | ||||||
|     # pointers |     # pointers | ||||||
|   | |||||||
| @@ -87,6 +87,7 @@ def assert_allclose(x, y, tol=1e-2): | |||||||
|  |  | ||||||
|  |  | ||||||
| def random(shape, dtype, device): | def random(shape, dtype, device): | ||||||
|  |     torch.manual_seed(0) | ||||||
|     if isinstance(shape, int): |     if isinstance(shape, int): | ||||||
|         shape = (shape, ) |         shape = (shape, ) | ||||||
|     if dtype == torch.bool: |     if dtype == torch.bool: | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user