diff --git a/include/triton/codegen/analysis/layout.h b/include/triton/codegen/analysis/layout.h
index 833e96767..4d12e34c0 100644
--- a/include/triton/codegen/analysis/layout.h
+++ b/include/triton/codegen/analysis/layout.h
@@ -93,7 +93,22 @@ protected:
   shape_t shape_;
 };
 
-class mma_layout: public data_layout {
+class distributed_layout: public data_layout{
+public:
+  distributed_layout(id_t id,
+                     const std::vector<int>& axes,
+                     const std::vector<unsigned>& shape,
+                     const std::vector<ir::value*>& values,
+                     analysis::align* align);
+
+  int shape_per_cta(size_t k) { return shape_per_cta_.at(k); }
+  int rep_per_cta(size_t k) { return shape_[k] / shape_per_cta_[k]; }
+
+protected:
+  std::vector<int> shape_per_cta_;
+};
+
+class mma_layout: public distributed_layout {
 public:
   mma_layout(size_t num_warps,
              const std::vector<int>& axes,
@@ -107,7 +122,6 @@ public:
   int fpw(size_t k) { return fpw_.at(k); }
   int wpt(size_t k) { return wpt_.at(k); }
   int spw(size_t k) { return spw_.at(k); }
-  int spt(size_t k) { return spt_.at(k); }
   int rep(size_t k) { return rep_.at(k); }
 
 private:
@@ -123,7 +137,7 @@ private:
   std::vector<int> rep_;
 };
 
-struct scanline_layout: public data_layout {
+struct scanline_layout: public distributed_layout {
   scanline_layout(size_t num_warps,
                   const std::vector<int>& axes,
                   const std::vector<unsigned>& shape,
@@ -219,6 +233,7 @@ public:
   // accessors
   unsigned layout_of(ir::value *value) const { return groups_.at(value); }
+  bool has(ir::value* value) const { return groups_.find(value) != groups_.end(); }
   const std::vector<ir::value*>& values_of(unsigned id) const { return values_.at(id); }
   size_t num_layouts() const { return values_.size();}
   data_layout* get(size_t id) { return layouts_.at(id); }
@@ -226,7 +241,7 @@ public:
   std::map<size_t, data_layout*> &get_all() { return layouts_; }
   bool has_tmp(ir::value* i) { return tmp_.find(i) != tmp_.end(); }
   int tmp(ir::value* i) { return tmp_.at(i);}
-
+  void copy(ir::value* dst, ir::value* src) { groups_[dst] = groups_[src]; }
   // execution
   void run(ir::module &mod);
 
diff --git a/include/triton/codegen/selection/generator.h b/include/triton/codegen/selection/generator.h
index ddde9198a..f2e7263e3 100644
--- a/include/triton/codegen/selection/generator.h
+++ b/include/triton/codegen/selection/generator.h
@@ -171,7 +171,8 @@ public:
   void visit_reducend_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
   void visit_reduce_inst(ir::reduce_inst*);
   void visit_select_inst(ir::select_inst*);
-  void visit_recoalesce_inst(ir::recoalesce_inst*);
+  void visit_layout_convert(ir::value *out, ir::value *in);
+  void visit_cvt_layout_inst(ir::cvt_layout_inst*);
   void visit_masked_load_async_inst(ir::masked_load_async_inst*);
   void visit_copy_to_shared_inst(ir::copy_to_shared_inst*);
   void visit_copy_from_shared_inst(ir::copy_from_shared_inst*);
diff --git a/include/triton/codegen/transform/coalesce.h b/include/triton/codegen/transform/coalesce.h
index 1b15306f1..869ca9975 100644
--- a/include/triton/codegen/transform/coalesce.h
+++ b/include/triton/codegen/transform/coalesce.h
@@ -33,6 +33,7 @@ private:
 
 public:
   coalesce(analysis::align* align, triton::codegen::analysis::layouts *layouts);
+  triton::ir::value *simplify(ir::instruction* i, triton::ir::builder &builder);
   void run(ir::module &mod);
 
 private:
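To make the new `distributed_layout` accessors concrete, here is a minimal standalone sketch (plain C++, not part of the patch; the shapes are made-up values) of the arithmetic behind `shape_per_cta` and `rep_per_cta`:

```cpp
// Standalone sketch: a distributed layout covers a tensor by replicating
// one CTA-wide tile; rep_per_cta(k) counts the replicas along axis k.
#include <cassert>
#include <vector>

int main() {
  // e.g. a 128x128 tensor handled by a layout whose CTA tile is 64x32
  std::vector<int> shape         = {128, 128};
  std::vector<int> shape_per_cta = {64, 32};
  for (size_t k = 0; k < shape.size(); k++)
    assert(shape[k] % shape_per_cta[k] == 0);
  // rep_per_cta(k) = shape[k] / shape_per_cta[k]
  int rep0 = shape[0] / shape_per_cta[0]; // 2
  int rep1 = shape[1] / shape_per_cta[1]; // 4
  return (rep0 == 2 && rep1 == 4) ? 0 : 1;
}
```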
diff --git a/include/triton/codegen/transform/peephole.h b/include/triton/codegen/transform/peephole.h
index c14c74702..1b015fb41 100644
--- a/include/triton/codegen/transform/peephole.h
+++ b/include/triton/codegen/transform/peephole.h
@@ -34,8 +34,7 @@ private:
   bool rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder);
   bool rewrite_select_masked_load(ir::instruction *value, ir::builder& builder);
   bool rewrite_load_to_shared(ir::instruction *value, ir::builder& builder);
-
-private:
+  bool rewrite_cvt_layout(ir::instruction *value, ir::builder& builder);
 
 public:
   peephole(target* tgt, analysis::layouts* layouts): tgt_(tgt), layouts_(layouts) {}
diff --git a/include/triton/driver/dispatch.h b/include/triton/driver/dispatch.h
index ad6574f44..9b79d714d 100755
--- a/include/triton/driver/dispatch.h
+++ b/include/triton/driver/dispatch.h
@@ -60,136 +60,151 @@ protected:
 
 public:
   static bool nvmlinit();
   static bool cuinit();
-  static bool spvllvminit();
   static void release();
 
-  // CUDA
+  /* ------------------- *
+   * CUDA
+   * ------------------- */
+  // context management
+  static CUresult cuInit(unsigned int Flags);
   static CUresult cuCtxGetCurrent(CUcontext *pctx);
   static CUresult cuCtxSetCurrent(CUcontext ctx);
   static CUresult cuCtxDestroy_v2(CUcontext ctx);
-  static CUresult cuEventCreate(CUevent *phEvent, unsigned int Flags);
-  static CUresult cuDeviceGet(CUdevice *device, int ordinal);
-  static CUresult cuMemcpyDtoH_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount);
-  static CUresult cuStreamCreate(CUstream *phStream, unsigned int Flags);
-  static CUresult cuEventElapsedTime(float *pMilliseconds, CUevent hStart, CUevent hEnd);
-  static CUresult cuMemFree_v2(CUdeviceptr dptr);
-  static CUresult cuMemcpyDtoHAsync_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount, CUstream hStream);
+  static CUresult cuCtxCreate_v2(CUcontext *pctx, unsigned int flags, CUdevice dev);
+  static CUresult cuCtxPushCurrent_v2(CUcontext ctx);
+  static CUresult cuCtxPopCurrent_v2(CUcontext *pctx);
+  static CUresult cuCtxGetDevice(CUdevice* result);
+  static CUresult cuCtxEnablePeerAccess(CUcontext peerContext, unsigned int flags);
   static CUresult cuDriverGetVersion(int *driverVersion);
+  // device management
+  static CUresult cuDeviceGet(CUdevice *device, int ordinal);
   static CUresult cuDeviceGetName(char *name, int len, CUdevice dev);
   static CUresult cuDeviceGetPCIBusId(char *id, int len, CUdevice dev);
-  static CUresult cuModuleGetGlobal_v2(CUdeviceptr *dptr, size_t* bytes, CUmodule hmod, const char *name);
-  static CUresult cuMemcpyHtoDAsync_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount, CUstream hStream);
-  static CUresult cuModuleLoad(CUmodule *module, const char *fname);
-  static CUresult cuModuleLoadData(CUmodule* module, const void* image);
-  static CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void **kernelParams, void **extra);
-  static CUresult cuModuleUnload(CUmodule hmod);
-  static CUresult cuModuleLoadDataEx(CUmodule *module, const void *image, unsigned int numOptions, CUjit_option *options, void **optionValues);
-
+  static CUresult cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib, CUdevice dev);
+  static CUresult cuDeviceGetCount(int *count);
+  // link management
   static CUresult cuLinkAddData_v2(CUlinkState state, CUjitInputType type, void* data, size_t size, const char* name, unsigned int numOptions, CUjit_option* options, void** optionValues);
   static CUresult cuLinkCreate_v2(unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut);
   static CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut);
   static CUresult cuLinkDestroy(CUlinkState state);
-
-  static CUresult cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib, CUdevice dev);
-  static CUresult cuDeviceGetCount(int *count);
-  static CUresult cuMemcpyHtoD_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount);
-  static CUresult cuInit(unsigned int Flags);
-  static CUresult cuEventRecord(CUevent hEvent, CUstream hStream);
-  static CUresult cuCtxCreate_v2(CUcontext *pctx, unsigned int flags, CUdevice dev);
-  static CUresult cuCtxPushCurrent_v2(CUcontext ctx);
-  static CUresult cuCtxPopCurrent_v2(CUcontext *pctx);
+  // module management
+  static CUresult cuModuleGetGlobal_v2(CUdeviceptr *dptr, size_t* bytes, CUmodule hmod, const char *name);
+  static CUresult cuModuleLoad(CUmodule *module, const char *fname);
+  static CUresult cuModuleLoadData(CUmodule* module, const void* image);
+  static CUresult cuModuleUnload(CUmodule hmod);
+  static CUresult cuModuleLoadDataEx(CUmodule *module, const void *image, unsigned int numOptions, CUjit_option *options, void **optionValues);
   static CUresult cuModuleGetFunction(CUfunction *hfunc, CUmodule hmod, const char *name);
+  // stream management
+  static CUresult cuStreamCreate(CUstream *phStream, unsigned int Flags);
   static CUresult cuStreamSynchronize(CUstream hStream);
   static CUresult cuStreamGetCtx(CUstream hStream, CUcontext* pctx);
   static CUresult cuStreamDestroy_v2(CUstream hStream);
-  static CUresult cuEventDestroy_v2(CUevent hEvent);
-  static CUresult cuMemAlloc_v2(CUdeviceptr *dptr, size_t bytesize);
-  static CUresult cuPointerGetAttribute(void * data, CUpointer_attribute attribute, CUdeviceptr ptr);
-  static CUresult cuCtxGetDevice(CUdevice* result);
-  static CUresult cuMemsetD8Async(CUdeviceptr dst, unsigned char x, size_t N, CUstream stream);
+  static CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void **kernelParams, void **extra);
+  // function management
   static CUresult cuFuncGetAttribute(int* pi, CUfunction_attribute attrib, CUfunction hfunc);
   static CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value);
   static CUresult cuFuncSetCacheConfig(CUfunction hfunc, CUfunc_cache config);
-  static CUresult cuCtxEnablePeerAccess(CUcontext peerContext, unsigned int flags);
-  // NVML
+  // memory management
+  static CUresult cuMemAlloc_v2(CUdeviceptr *dptr, size_t bytesize);
+  static CUresult cuPointerGetAttribute(void * data, CUpointer_attribute attribute, CUdeviceptr ptr);
+  static CUresult cuMemsetD8Async(CUdeviceptr dst, unsigned char x, size_t N, CUstream stream);
+  static CUresult cuMemcpyDtoH_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount);
+  static CUresult cuMemFree_v2(CUdeviceptr dptr);
+  static CUresult cuMemcpyDtoHAsync_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount, CUstream hStream);
+  static CUresult cuMemcpyHtoDAsync_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount, CUstream hStream);
+  static CUresult cuMemcpyHtoD_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount);
+  // event management
+  static CUresult cuEventCreate(CUevent *phEvent, unsigned int Flags);
+  static CUresult cuEventElapsedTime(float *pMilliseconds, CUevent hStart, CUevent hEnd);
+  static CUresult cuEventRecord(CUevent hEvent, CUstream hStream);
+  static CUresult cuEventDestroy_v2(CUevent hEvent);
+
+
+  /* ------------------- *
+   * NVML
+   * ------------------- */
   static nvmlReturn_t nvmlDeviceGetHandleByPciBusId_v2( const char* pciBusId, nvmlDevice_t* device);
   static nvmlReturn_t nvmlDeviceGetClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock);
   static nvmlReturn_t nvmlDeviceGetMaxClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock);
   static nvmlReturn_t nvmlDeviceSetApplicationsClocks(nvmlDevice_t device, unsigned int mem_clock, unsigned int sm_clock);
 
-  // SPIR-V libraries
-  static int initializeLLVMToSPIRVPass(llvm::PassRegistry &);
-  static bool writeSpirv(llvm::Module *M, std::ostream &OS, std::string &ErrMsg);
-
 private:
 
   // Libraries
   static void* cuda_;
   static void* nvml_;
-  static void* vulkan_;
-  static void* spvllvm_;
-  static void* spvcross_;
-  static void* opengl_;
 
-  // CUDA functions
+  /* ------------------- *
+   * CUDA
+   * ------------------- */
+  // context management
   static void* cuCtxGetCurrent_;
   static void* cuCtxSetCurrent_;
   static void* cuCtxDestroy_v2_;
-  static void* cuEventCreate_;
-  static void* cuDeviceGet_;
-  static void* cuMemcpyDtoH_v2_;
-  static void* cuStreamCreate_;
-  static void* cuEventElapsedTime_;
-  static void* cuMemFree_v2_;
-  static void* cuMemcpyDtoHAsync_v2_;
+  static void* cuCtxCreate_v2_;
+  static void* cuCtxGetDevice_;
+  static void* cuCtxPushCurrent_v2_;
+  static void* cuCtxPopCurrent_v2_;
+  static void* cuCtxEnablePeerAccess_;
   static void* cuDriverGetVersion_;
+  static void* cuInit_;
+  // device management
+  static void* cuDeviceGet_;
   static void* cuDeviceGetName_;
   static void* cuDeviceGetPCIBusId_;
-  static void* cuModuleGetGlobal_v2_;
-  static void* cuMemcpyHtoDAsync_v2_;
-  static void* cuModuleLoad_;
-  static void* cuLaunchKernel_;
-  static void* cuModuleUnload_;
-  static void* cuModuleLoadDataEx_;
+  static void* cuDeviceGetAttribute_;
+  static void* cuDeviceGetCount_;
+  // link management
   static void* cuLinkAddData_v2_;
   static void* cuLinkCreate_v2_;
   static void* cuLinkDestroy_;
-  static void* cuModuleLoadData_;
   static void* cuLinkComplete_;
-  static void* cuDeviceGetAttribute_;
-  static void* cuDeviceGetCount_;
-  static void* cuMemcpyHtoD_v2_;
-  static void* cuInit_;
-  static void* cuEventRecord_;
-  static void* cuCtxCreate_v2_;
+  // module management
+  static void* cuModuleGetGlobal_v2_;
+  static void* cuModuleLoad_;
+  static void* cuModuleUnload_;
+  static void* cuModuleLoadDataEx_;
+  static void* cuModuleLoadData_;
   static void* cuModuleGetFunction_;
+  // stream management
+  static void* cuStreamCreate_;
   static void* cuStreamSynchronize_;
   static void* cuStreamDestroy_v2_;
   static void* cuStreamGetCtx_;
-  static void* cuEventDestroy_v2_;
-  static void* cuMemAlloc_v2_;
-  static void* cuPointerGetAttribute_;
-  static void* cuCtxGetDevice_;
-  static void* cuMemsetD8Async_;
-  static void* cuCtxPushCurrent_v2_;
-  static void* cuCtxPopCurrent_v2_;
+  static void* cuLaunchKernel_;
+  // function management
   static void* cuFuncGetAttribute_;
   static void* cuFuncSetAttribute_;
   static void* cuFuncSetCacheConfig_;
-  static void* cuCtxEnablePeerAccess_;
-  // NVML
+  // memory management
+  static void* cuMemcpyDtoH_v2_;
+  static void* cuMemFree_v2_;
+  static void* cuMemcpyDtoHAsync_v2_;
+  static void* cuMemcpyHtoDAsync_v2_;
+  static void* cuMemcpyHtoD_v2_;
+  static void* cuMemAlloc_v2_;
+  static void* cuMemsetD8Async_;
+  static void* cuPointerGetAttribute_;
+  // event management
+  static void* cuEventCreate_;
+  static void* cuEventElapsedTime_;
+  static void* cuEventRecord_;
+  static void* cuEventDestroy_v2_;
+
+
+
+  /* ------------------- *
+   * NVML
+   * ------------------- */
   static void* nvmlInit_v2_;
   static void* nvmlDeviceGetHandleByPciBusId_v2_;
   static void* nvmlDeviceGetClockInfo_;
   static void* nvmlDeviceGetMaxClockInfo_;
   static void* nvmlDeviceSetApplicationsClocks_;
-
-  // LLVM to SPIR-V
-  static void* initializeLLVMToSPIRVPass_;
-  static void* writeSpirv_;
 };
 
 }
diff --git a/include/triton/ir/enums.h b/include/triton/ir/enums.h
index 48bdb7c66..5be63d4d2 100644
--- a/include/triton/ir/enums.h
+++ b/include/triton/ir/enums.h
@@ -153,6 +153,9 @@ enum value_id_t: unsigned {
   // intrinsics
   INST_COPY_TO_SHARED,
   INST_COPY_FROM_SHARED,
+  INST_CVT_LAYOUT,
+  INST_CVT_SCANLINE,
+  INST_DECOALESCE,
   INST_RECOALESCE,
   INST_BARRIER,
   INST_ASYNC_WAIT,
diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h
index e3a389018..c9db25477 100644
--- a/include/triton/ir/instructions.h
+++ b/include/triton/ir/instructions.h
@@ -807,16 +807,15 @@ public:
   _TRITON_DEFINE_ACCEPT(copy_from_shared_inst)
 };
 
-
-class recoalesce_inst: public unary_inst{
+class cvt_layout_inst: public unary_inst {
 private:
   using unary_inst::unary_inst;
-  std::string repr_impl() const { return "recoalesce_inst"; }
+  std::string repr_impl() const { return "cvt_layout_inst"; }
 
 public:
-  static recoalesce_inst* create(value *arg, const std::string &name = "", instruction *next = nullptr);
-  _TRITON_DEFINE_CLONE(recoalesce_inst)
-  _TRITON_DEFINE_ACCEPT(recoalesce_inst)
+  static cvt_layout_inst* create(value *arg, const std::string &name = "", instruction *next = nullptr);
+  _TRITON_DEFINE_CLONE(cvt_layout_inst)
+  _TRITON_DEFINE_ACCEPT(cvt_layout_inst)
 };
 
 class barrier_inst: public instruction{
diff --git a/include/triton/ir/visitor.h b/include/triton/ir/visitor.h
index a96211227..8073a6b66 100644
--- a/include/triton/ir/visitor.h
+++ b/include/triton/ir/visitor.h
@@ -64,7 +64,7 @@ class sqrt_inst;
 class reduce_inst;
 class select_inst;
 
-class recoalesce_inst;
+class cvt_layout_inst;
 class copy_to_shared_inst;
 class copy_from_shared_inst;
 class masked_load_async_inst;
@@ -142,9 +142,11 @@ public:
   virtual void visit_reduce_inst(reduce_inst*) = 0;
   virtual void visit_select_inst(select_inst*) = 0;
 
-  virtual void visit_recoalesce_inst(recoalesce_inst*) = 0;
+  virtual void visit_cvt_layout_inst(cvt_layout_inst*) = 0;
   virtual void visit_copy_to_shared_inst(copy_to_shared_inst*) = 0;
   virtual void visit_copy_from_shared_inst(copy_from_shared_inst*) = 0;
+
+
   virtual void visit_masked_load_async_inst(masked_load_async_inst*)= 0;
   virtual void visit_barrier_inst(barrier_inst*) = 0;
   virtual void visit_async_wait_inst(async_wait_inst*) = 0;
diff --git a/include/triton/tools/graph.h b/include/triton/tools/graph.h
index b53e754cd..c2ba8d854 100644
--- a/include/triton/tools/graph.h
+++ b/include/triton/tools/graph.h
@@ -6,6 +6,7 @@
 #include <map>
 #include <set>
 #include <vector>
+#include <iostream>
 
 namespace triton {
 namespace tools{
@@ -40,8 +41,9 @@ public:
     nmap->clear();
     std::set<node_t> nodes = nodes_;
     unsigned id = 0;
-    while(!nodes.empty())
+    while(!nodes.empty()){
       connected_components_impl(*nodes.begin(), nodes, nmap, cmap, id++);
+    }
   }
 
   void add_edge(node_t x, node_t y) {
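For reference, a minimal standalone sketch (plain C++, not Triton code; the graph is made up) of the DFS-style grouping that `tools::graph::connected_components` performs, which the axes and layouts analyses use to group values:

```cpp
// Standalone sketch: every node reachable from a seed gets the same
// component id, mirroring how connected values share an axis/layout id.
#include <iostream>
#include <map>
#include <set>

void visit(int n, std::set<int>& nodes, std::map<int, std::set<int>>& edges,
           std::map<int, int>& cmap, int id) {
  if (!nodes.count(n)) return;   // already assigned
  nodes.erase(n);
  cmap[n] = id;
  for (int m : edges[n]) visit(m, nodes, edges, cmap, id);
}

int main() {
  std::map<int, std::set<int>> edges = {{0, {1}}, {1, {0}}, {2, {}}};
  std::set<int> nodes = {0, 1, 2};
  std::map<int, int> cmap;
  int id = 0;
  while (!nodes.empty())
    visit(*nodes.begin(), nodes, edges, cmap, id++);
  for (auto& [n, c] : cmap)
    std::cout << n << " -> component " << c << "\n"; // 0->0, 1->0, 2->1
}
```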
diff --git a/lib/codegen/analysis/allocation.cc b/lib/codegen/analysis/allocation.cc
index d0a66543a..3af40c2cc 100644
--- a/lib/codegen/analysis/allocation.cc
+++ b/lib/codegen/analysis/allocation.cc
@@ -50,7 +50,6 @@ void allocation::run(ir::module &mod) {
       J.erase(j_it);
     }
   }
-
   // Build interference graph
   std::map<shared_layout*, std::set<shared_layout*>> interferences;
   for(shared_layout* x: V)
@@ -66,13 +65,10 @@ void allocation::run(ir::module &mod) {
     && XS.intersect(YS))
       interferences[x].insert(y);
   }
-
   // Initialize colors
   std::map<shared_layout*, int> colors;
   for(shared_layout* X: V)
     colors[X] = (X==V[0])?0:-1;
-
-
   // First-fit graph coloring
   std::vector<bool> available(V.size());
   for(shared_layout* x: V){
@@ -87,7 +83,6 @@ void allocation::run(ir::module &mod) {
     auto It = std::find(available.begin(), available.end(), true);
     colors[x] = std::distance(available.begin(), It);
   }
-
   // Finalize allocation
   for(shared_layout* x: V){
     unsigned Adj = 0;
@@ -95,7 +90,6 @@ void allocation::run(ir::module &mod) {
       Adj = std::max(Adj, starts[y] + y->get_size());
     offsets_[x] = starts[x] + colors[x] * Adj;
   }
-
   // Save maximum size of induced memory space
   allocated_size_ = 0;
   for(shared_layout* x: V)
diff --git a/lib/codegen/analysis/axes.cc b/lib/codegen/analysis/axes.cc
index d68be1d82..13b8f8d05 100644
--- a/lib/codegen/analysis/axes.cc
+++ b/lib/codegen/analysis/axes.cc
@@ -105,17 +105,17 @@ void axes::update_graph_no_edge(ir::instruction *i) {
 
 void axes::update_graph(ir::instruction *i) {
   switch (i->get_id()) {
-    case ir::INST_REDUCE:           return update_graph_reduce(i);
-    case ir::INST_RESHAPE:          return update_graph_reshape(i);
-    case ir::INST_SPLAT:            return update_graph_no_edge(i);;
-    case ir::INST_TRANS:            return update_graph_trans(i);
-    case ir::INST_BROADCAST:        return update_graph_broadcast(i);
-    case ir::INST_DOT:              return update_graph_dot(i);
-    case ir::INST_COPY_TO_SHARED:   return update_graph_no_edge(i);
-    case ir::INST_MASKED_LOAD_ASYNC:return update_graph_elementwise(i, false);
-    case ir::INST_COPY_FROM_SHARED: return update_graph_no_edge(i);
-    case ir::INST_RECOALESCE:       return update_graph_no_edge(i);
-    default:                        return update_graph_elementwise(i);
+    case ir::INST_REDUCE:            return update_graph_reduce(i);
+    case ir::INST_RESHAPE:           return update_graph_reshape(i);
+    case ir::INST_SPLAT:             return update_graph_no_edge(i);;
+    case ir::INST_TRANS:             return update_graph_trans(i);
+    case ir::INST_BROADCAST:         return update_graph_broadcast(i);
+    case ir::INST_DOT:               return update_graph_dot(i);
+    case ir::INST_COPY_TO_SHARED:    return update_graph_no_edge(i);
+    case ir::INST_MASKED_LOAD_ASYNC: return update_graph_elementwise(i, false);
+    case ir::INST_COPY_FROM_SHARED:  return update_graph_no_edge(i);
+    case ir::INST_CVT_LAYOUT:        return update_graph_no_edge(i);
+    default:                         return update_graph_elementwise(i);
   }
   return;
 }
@@ -135,11 +135,15 @@ std::vector<int> axes::get(ir::value *value) {
 
 void axes::run(ir::module &mod) {
   // make graph
   graph_.clear();
+  axes_.clear();
   ir::for_each_instruction(mod, [this](ir::instruction *x) {
     update_graph(x);
   });
   // find connected components
   graph_.connected_components(nullptr, &axes_);
+  std::set<int> uniq;
+  for(auto x: axes_)
+    uniq.insert(x.second);
 }
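A standalone sketch (plain C++, made-up interference graph; not patch code) of the first-fit coloring that `allocation::run` above applies so that non-interfering shared-memory buffers can reuse the same offset:

```cpp
// Standalone sketch: each uncolored node takes the smallest color that
// none of its interference-graph neighbors already uses.
#include <algorithm>
#include <iostream>
#include <map>
#include <set>
#include <vector>

int main() {
  // buffers 0-2; an edge means the two buffers are live at the same time
  std::map<int, std::set<int>> interf = {{0, {1}}, {1, {0, 2}}, {2, {1}}};
  std::map<int, int> color;
  for (auto& [n, _] : interf) color[n] = (n == 0) ? 0 : -1;
  std::vector<bool> available(interf.size());
  for (auto& [n, neighbors] : interf) {
    if (color[n] != -1) continue;
    std::fill(available.begin(), available.end(), true);
    for (int m : neighbors)
      if (color[m] >= 0) available[color[m]] = false;  // taken by neighbor
    color[n] = std::find(available.begin(), available.end(), true)
               - available.begin();
  }
  for (auto& [n, c] : color)
    std::cout << "buffer " << n << " -> color " << c << "\n"; // 0, 1, 0
}
```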
diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc
index 46f1649b0..1693eff42 100644
--- a/lib/codegen/analysis/layout.cc
+++ b/lib/codegen/analysis/layout.cc
@@ -109,9 +109,6 @@ data_layout::data_layout(id_t id,
       max_contiguous = curr;
     }
   }
-  bool is_recoalesce = false;
-  for(ir::value* v: values)
-    is_recoalesce = is_recoalesce || dynamic_cast<ir::recoalesce_inst*>(v);
   if(max_contiguous.size() > 0){
     std::sort(order_.begin(), order_.end(), [&](unsigned a, unsigned b) {
       return max_contiguous[a] > max_contiguous[b];
@@ -129,6 +126,13 @@ int data_layout::find_axis(int to_find) const {
 }
 
+distributed_layout::distributed_layout(id_t id,
+                                       const std::vector<int> &axes,
+                                       const std::vector<unsigned> &shape,
+                                       const std::vector<ir::value*> &values,
+                                       analysis::align* align): data_layout(id, axes, shape, values, align)
+{ }
+
 /* -------------------------------- *
  * MMA Layout                       *
  * -------------------------------- */
@@ -138,20 +142,11 @@ mma_layout::mma_layout(size_t num_warps,
                        const std::vector<unsigned>& shape,
                        const std::vector<ir::value *> &values,
                        analysis::align* align, target* tgt,
-                       shared_layout *layout_a, shared_layout *layout_b): data_layout(MMA, axes, shape, values, align) {
+                       shared_layout *layout_a, shared_layout *layout_b): distributed_layout(MMA, axes, shape, values, align) {
   /* fragments per warp */
   // try to make things as square as possible to maximize data re-use
   if(tgt->as_nvidia()->sm() < 80){
     fpw_ = {2, 2, 1};
-//    std::vector<int> fpw_nm1;
-//    unsigned num_fragments = std::min((shape_[0]/8)*(shape_[1]/8), 4);
-//    do {
-//      fpw_nm1 = fpw_;
-//      if(fpw_[0]*fpw_[1] < num_fragments)
-//        fpw_[0] = clamp(fpw_[0]*2, 1, shape_[0] / 8);
-//      if(fpw_[0]*fpw_[1] < num_fragments)
-//        fpw_[1] = clamp(fpw_[1]*2, 1, shape_[1] / 8);
-//    }while(fpw_nm1 != fpw_);
     auto ord_a = layout_a->get_order();
     auto ord_b = layout_b->get_order();
     bool is_a_row = ord_a[0] != 0;
@@ -168,6 +163,7 @@ mma_layout::mma_layout(size_t num_warps,
     spw_ = {16, 8, 1};
     rep_ = {2, 2, 1};
   }
+  order_ = {0, 1};
 
   /* warps per tile */
   // try to make things as square as possible to maximize data re-use
@@ -182,7 +178,7 @@ mma_layout::mma_layout(size_t num_warps,
   }while(wpt_nm1 != wpt_);
 
   /* shape per block */
-  spt_ = {spw_[0]*wpt_[0], spw_[1]*wpt_[1], 1};
+  shape_per_cta_ = {spw_[0]*wpt_[0], spw_[1]*wpt_[1], 1};
 }
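A standalone sketch (plain C++; `num_warps`, `spw` and `shape` are illustrative values, not from the patch) of the "as square as possible" warps-per-tile search that `mma_layout` runs before computing `shape_per_cta_ = spw * wpt`:

```cpp
// Standalone sketch: double wpt along alternating dimensions until all
// warps are used or the tile is covered, then derive the CTA tile shape.
#include <algorithm>
#include <iostream>
#include <vector>

int clamp(int x, int lo, int hi) { return std::max(lo, std::min(x, hi)); }

int main() {
  int num_warps = 8;
  std::vector<int> shape = {128, 128, 1};
  std::vector<int> spw   = {16, 8, 1};   // shape per warp (Ampere-style mma)
  std::vector<int> wpt   = {1, 1, 1};    // warps per tile
  std::vector<int> wpt_nm1;
  do {
    wpt_nm1 = wpt;
    if (wpt[0] * wpt[1] < num_warps)
      wpt[0] = clamp(wpt[0] * 2, 1, shape[0] / spw[0]);
    if (wpt[0] * wpt[1] < num_warps)
      wpt[1] = clamp(wpt[1] * 2, 1, shape[1] / spw[1]);
  } while (wpt_nm1 != wpt);               // stop once nothing changes
  // shape_per_cta = spw * wpt
  std::cout << spw[0] * wpt[0] << " x " << spw[1] * wpt[1] << "\n"; // 64 x 16
}
```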
@@ -194,7 +190,7 @@ scanline_layout::scanline_layout(size_t num_warps,
                                  const std::vector<int>& axes,
                                  const std::vector<unsigned>& shape,
                                  const std::vector<ir::value *> &values,
-                                 analysis::align* align, target *tgt): data_layout(SCANLINE, axes, shape, values, align){
+                                 analysis::align* align, target *tgt): distributed_layout(SCANLINE, axes, shape, values, align){
   unsigned size = std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies<int>());
   unsigned num_threads = tgt->is_gpu() ? num_warps * 32 : 1;
   nts_.resize(shape_.size());
@@ -230,6 +226,10 @@ scanline_layout::scanline_layout(size_t num_warps,
     mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
     num_threads = num_threads / mts_[i];
   }
+
+  shape_per_cta_.resize(shape_.size());
+  for(size_t d = 0; d < shape_.size(); d++)
+    shape_per_cta_[d] = mts_[d]*nts_[d];
 }
@@ -489,6 +489,9 @@ void layouts::create(size_t id, const std::vector<ir::value*>& values) {
 void layouts::run(ir::module &mod) {
   // make graph
   graph_.clear();
+  layouts_.clear();
+  groups_.clear();
+
   ir::for_each_instruction(mod, [this](ir::instruction* i) {
     make_graph(i);
   });
@@ -515,23 +518,18 @@ void layouts::run(ir::module &mod) {
     layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_);
     tmp_[red] = id;
   }
-  if(auto *recoalasce = dynamic_cast<ir::recoalesce_inst*>(i)){
-    ir::value *val = recoalasce->get_operand(0);
-    mma_layout* in_layout = get(val)->to_mma();
-    scanline_layout* out_layout = get(i)->to_scanline();
-    if(!in_layout || !out_layout)
-      return;
+  if(auto *val = dynamic_cast<ir::cvt_layout_inst*>(i)){
+    distributed_layout* out_layout = dynamic_cast<distributed_layout*>(get(val));
+    distributed_layout* in_layout = dynamic_cast<distributed_layout*>(get(i->get_operand(0)));
     id++;
-    ir::type::block_shapes_t in_shape = val->get_type()->get_block_shapes();
-    ir::type::block_shapes_t shape(in_shape.size());
-    size_t ld = out_layout->get_order(0);
-    shape[ld] = in_shape[ld];
-    for(size_t k = 0; k < in_shape.size(); k++)
-      if(k != ld)
-        shape[k] = in_layout->to_mma()->spt(k);
-    // create layout
-    layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {recoalasce}, val->get_type()->get_scalar_ty(), align_);
-    tmp_[recoalasce] = id;
+    size_t dim = val->get_type()->get_tile_rank();
+    ir::type::block_shapes_t shape(dim);
+    for(size_t k = 0; k < dim; k++){
+      shape[k] = std::max(in_layout->shape_per_cta(k),
+                          out_layout->shape_per_cta(k));
+    }
+    layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {val}, val->get_type()->get_scalar_ty(), align_);
+    tmp_[val] = id;
   }
   if(auto *atom = dynamic_cast<ir::atomic_inst*>(i)){
     id++;
diff --git a/lib/codegen/pass.cc b/lib/codegen/pass.cc
index ba4547d10..82fe61257 100644
--- a/lib/codegen/pass.cc
+++ b/lib/codegen/pass.cc
@@ -56,10 +56,8 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
   dce.run(ir);
   peephole.run(ir);
   dce.run(ir);
-  // ir::print(ir, std::cout);
   pipeline.run(ir);
   dce.run(ir);
-  // ir::print(ir, std::cout);
   disassociate.run(ir);
   dce.run(ir);
   align.run(ir);
@@ -74,14 +72,15 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
   layouts.run(ir);
   coalesce.run(ir);
   dce.run(ir);
+//  exit(1);
+
   align.run(ir);
   dce.run(ir);
-  if (target->is_gpu()) {
-//    reassociate.run(ir);
+  if (target->is_gpu())
     cts.run(ir);
-  }
   dce.run(ir);
   align.run(ir);
+//  ir::print(ir, std::cout);
   axes.run(ir);
   layouts.run(ir);
   peephole.run(ir);
@@ -93,10 +92,7 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
   liveness.run(ir);
   allocation.run(ir);
   prefetch_s.run(ir);
-//  ir::print(ir, std::cout);
   barriers.run(ir);
-//  ir::print(ir, std::cout);
-//  ir::print(ir, std::cout);
   isel.visit(ir, *llvm);
   mod = driver::module::create(dev, std::move(llvm));
   ker = driver::kernel::create(&*mod, name.c_str());
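A standalone sketch (plain C++, illustrative numbers) of how the `scanline_layout` constructor above splits each axis into `nts` values per thread times `mts` threads, so that `shape_per_cta = mts * nts`:

```cpp
// Standalone sketch: distribute num_threads over the axes, contiguous
// axis first, then derive the per-CTA tile size for each dimension.
#include <algorithm>
#include <iostream>
#include <vector>

int clamp(int x, int lo, int hi) { return std::max(lo, std::min(x, hi)); }

int main() {
  int num_threads = 4 * 32;              // 4 warps on a GPU target
  std::vector<int> shape = {32, 128};
  std::vector<int> order = {1, 0};       // axis 1 is the contiguous one
  std::vector<int> nts   = {1, 4};       // per-thread vector width
  std::vector<int> mts(shape.size());
  for (int i : order) {
    mts[i] = clamp(num_threads, 1, shape[i] / nts[i]);
    num_threads /= mts[i];               // leftover threads go to next axis
  }
  for (size_t d = 0; d < shape.size(); d++)
    std::cout << "dim " << d << ": shape_per_cta = "
              << mts[d] * nts[d] << "\n"; // dim 0: 4*1=4, dim 1: 32*4=128
}
```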
diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc
index 7f189e4d2..eee50d2c2 100644
--- a/lib/codegen/selection/generator.cc
+++ b/lib/codegen/selection/generator.cc
@@ -586,7 +586,7 @@ void generator::visit_load_inst(ir::load_inst* x){
   Type* ty = cvt(op->get_type()->get_scalar_ty()->get_pointer_element_ty());
   // compute vector width
   size_t vec = 1;
-  if(op->get_type()->is_block_ty()){
+  if(op->get_type()->is_block_ty() && op->get_type()->get_tile_rank() > 1){
     auto   ord = ords_.at(op);
     size_t aln = alignment_->get(op, ord[0]);
     size_t nts = layouts_->get(x)->to_scanline()->nts(ord[0]);
@@ -626,10 +626,10 @@ void generator::visit_load_inst(ir::load_inst* x){
     // -----
     std::ostringstream asm_oss;
     asm_oss << "@$" << n_words; // predicate
-    if(force_nc_cache_)
-      asm_oss << " ld.global.nc";
-    else
-      asm_oss << " ld.global.cg";
+//    if(force_nc_cache_)
    asm_oss << " ld.global";
+//    else
+//      asm_oss << " ld.global.cg";
     if(n_words > 1)
       asm_oss << ".v" << n_words; // vector width
     asm_oss << ".b" << width; // word size
@@ -1058,7 +1058,8 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
   /* --------------------------------- */
   BasicBlock* curr_bb = builder_->GetInsertBlock();
   BasicBlock* entry = &curr_bb->getParent()->getEntryBlock();
-  builder_->SetInsertPoint(entry->getTerminator());
+  if(entry != curr_bb)
+    builder_->SetInsertPoint(entry->getTerminator());
   Value* off_a0 = is_a_row ? offset_a_k_[layout_c] : offset_a_m_[layout_c];
   Value* off_a1 = is_a_row ? offset_a_m_[layout_c] : offset_a_k_[layout_c];
   Value* phase_a = urem(udiv(off_a1, i32(per_phase_a)), i32(max_phase_a));
@@ -1116,8 +1117,8 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
   for(indices_t idx: idxs_.at(C))
     acc.push_back(vals_[D][idx]);
 
-  unsigned num_m = layout_c->rep(0) * shape_c[0] / layout_c->spt(0);
-  unsigned num_n = layout_c->rep(1) * shape_c[1] / layout_c->spt(1);
+  unsigned num_m = layout_c->rep(0) * shape_c[0] / layout_c->shape_per_cta(0);
+  unsigned num_n = layout_c->rep(1) * shape_c[1] / layout_c->shape_per_cta(1);
 
   // create mma & unpack result
   auto call_mma = [&](unsigned m, unsigned n, unsigned K) {
@@ -1333,7 +1334,8 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
 
   BasicBlock* CurrBB = builder_->GetInsertBlock();
   BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock();
-  builder_->SetInsertPoint(FirstBB->getTerminator());
+  if(FirstBB != CurrBB)
+    builder_->SetInsertPoint(FirstBB->getTerminator());
 
   Value* thread = tgt_->get_local_id(mod_, *builder_, 0);
   Value *lane   = urem(thread, i32(32));
@@ -1396,8 +1398,8 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
                                            "{$10, $11, $12, $13};",
                                            "=f,=f,=f,=f,r,r,r,r,r,r,0,1,2,3", true);
 
-  unsigned num_rep_0 = shapes[0] / layout->spt(0);
-  unsigned num_rep_1 = shapes[1] / layout->spt(1);
+  unsigned num_rep_0 = shapes[0] / layout->shape_per_cta(0);
+  unsigned num_rep_1 = shapes[1] / layout->shape_per_cta(1);
 
   // create mma & unpack result
   auto call_mma = [&](unsigned m, unsigned n, unsigned K) {
@@ -1626,8 +1628,8 @@ void generator::visit_fmadot(ir::dot_inst* C, ir::value* A, ir::value* B, ir::va
   std::map<std::pair<unsigned, unsigned>, Value*> has, hbs;
   for(unsigned k = 0; k < NK; k++){
     int z = 0;
-    for(unsigned m = 0; m < shape_c[0]; m+=layout_c->mts(0)*layout_c->nts(0))
-    for(unsigned n = 0; n < shape_c[1]; n+=layout_c->mts(1)*layout_c->nts(1))
+    for(unsigned m = 0; m < shape_c[0]; m += layout_c->shape_per_cta(0))
+    for(unsigned n = 0; n < shape_c[1]; n += layout_c->shape_per_cta(1))
     for(unsigned mm = 0; mm < layout_c->nts(0); mm++)
     for(unsigned nn = 0; nn < layout_c->nts(1); nn++)
     {
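A standalone sketch (plain C++, illustrative values) of the vector-width choice in `visit_load_inst` above: the width is bounded by both the layout's per-thread contiguity (`nts`) and the known pointer alignment, then mapped onto a `ld.global.vN.bW` PTX template:

```cpp
// Standalone sketch: pick the widest load both the layout and the
// alignment allow, then format the corresponding PTX opcode.
#include <algorithm>
#include <iostream>
#include <string>

int main() {
  size_t nts = 4;                    // contiguous values per thread (layout)
  size_t aln = 2;                    // known alignment, in elements
  size_t vec = std::min(nts, aln);   // -> 2 elements per load
  int width = 32;                    // element width in bits
  int n_words = std::max<int>(1, vec * width / 32);
  std::cout << "ld.global"
            << (n_words > 1 ? ".v" + std::to_string(n_words) : std::string())
            << ".b" << width << "\n"; // ld.global.v2.b32
}
```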
@@ -1818,6 +1820,7 @@ void generator::visit_reducend_inst(ir::reduce_inst* x, std::function<Val
 
-void generator::visit_recoalesce_inst(ir::recoalesce_inst *rc) {
-  ir::value *op = rc->get_operand(0);
-  ir::block_type::block_shapes_t shape = rc->get_type()->get_block_shapes();
+
+
+void generator::visit_layout_convert(ir::value *out, ir::value *in){
+  ir::block_type::block_shapes_t shape = out->get_type()->get_block_shapes();
   // pointer to temporary shared memory
-  Type *ty = cvt(rc->get_type()->get_scalar_ty());
-  // layout
-  analysis::mma_layout* in_layout = layouts_->get(op)->to_mma();
-  analysis::scanline_layout* out_layout = layouts_->get(rc)->to_scanline();
+  Type *ty = cvt(out->get_type()->get_scalar_ty());
   // Orders
-  auto ord = layouts_->get(rc)->to_scanline()->get_order();
+  analysis::distributed_layout* in_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(in));
+  analysis::distributed_layout* out_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(out));
+  auto in_ord = in_layout->get_order();
+  auto out_ord = out_layout->get_order();
   Value *base;
-  base = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(rc)))));
+  base = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(out)))));
   base = bit_cast(base, ptr_ty(ty, 3));
-  Value *ld = i32(shape[ord[0]]);
-  auto in_ord0 = axes_.at(a_axes_->get(op, ord[0])).values;
-  auto in_ord1 = axes_.at(a_axes_->get(op, ord[1])).values;
-  auto out_ord0 = axes_.at(a_axes_->get(rc, ord[0])).values;
-  auto out_ord1 = axes_.at(a_axes_->get(rc, ord[1])).values;
-  int in_spt0 = in_layout->spt(ord[0]);
-  int in_spt1 = in_layout->spt(ord[1]);
-  int out_spt0 = out_layout->mts(ord[0])*out_layout->nts(ord[0]);
-  int out_spt1 = out_layout->mts(ord[1])*out_layout->nts(ord[1]);
-  int max_spt1 = std::max(in_spt1, out_spt1);
-  indices_t idx(2);
-  int num_packs = shape[ord[1]]/max_spt1;
-  for(size_t j = 0; j < num_packs; j++){
-    add_barrier();
-    for(size_t k = 0; k < in_ord1.size()/num_packs; k++)
-    for(size_t i = 0; i < in_ord0.size(); i++){
-      idx[ord[0]] = in_ord0[i];
-      idx[ord[1]] = in_ord1[j*in_ord1.size()/num_packs + k];
-      Value *off = add(idx[ord[0]], mul(in_ord1[k], ld));
-      Value *ptr = gep(base, off);
-      store(vals_[op][idx], ptr);
-    }
-    add_barrier();
-    for(size_t k = 0; k < out_ord1.size()/num_packs; k++)
-    for(size_t i = 0; i < out_ord0.size(); i++){
-      idx[ord[0]] = out_ord0[i];
-      idx[ord[1]] = out_ord1[j*out_ord1.size()/num_packs + k];
-      Value *off = add(idx[ord[0]], mul(out_ord1[k], ld));
-      Value *ptr = gep(base, off);
-      vals_[rc][idx] = load(ptr);
-    }
+  std::vector<int> n_reps;
+  for(int i = 0; i < shape.size(); i++){
+    int in_per_cta = in_layout->shape_per_cta(i);
+    int out_per_cta = out_layout->shape_per_cta(i);
+    int max_per_cta = std::max(in_per_cta, out_per_cta);
+    n_reps.push_back(shape[i]/max_per_cta);
   }
+  std::vector<std::vector<Value*>> in_ax;
+  std::vector<std::vector<Value*>> out_ax;
+  for(int d = 0; d < shape.size(); d++){
+    in_ax.push_back(axes_.at(a_axes_->get(in, d)).values);
+    out_ax.push_back(axes_.at(a_axes_->get(out, d)).values);
+  }
+  in_ord = in_layout->to_mma() ? out_ord : in_ord;
+  out_ord = out_layout->to_mma() ? in_ord : out_ord;
+  Value *in_ld = i32(shape[in_ord[0]]);
+  Value *out_ld = i32(shape[out_ord[0]]);
+  for(int i = 0; i < n_reps[0]; i++)
+  for(int j = 0; j < n_reps[1]; j++){
+    int max_ii, max_jj;
+    add_barrier();
+    max_ii = in_ax[0].size()/n_reps[0];
+    max_jj = in_ax[1].size()/n_reps[1];
+    for(int ii = 0; ii < max_ii; ii++)
+    for(int jj = 0; jj < max_jj; jj++){
+      // shared mem pointer
+      indices_t offs = {in_ax[0][ii], in_ax[1][jj]};
+      Value *off = add(offs[out_ord[0]], mul(out_ld, offs[out_ord[1]]));
+      Value *ptr = gep(base, off);
+      // stash value to shared mem
+      indices_t idxs = {in_ax[0][i*max_ii + ii],
+                        in_ax[1][j*max_jj + jj]};
+      store(vals_[in][idxs], ptr);
+    }
+    add_barrier();
+    max_ii = out_ax[0].size()/n_reps[0];
+    max_jj = out_ax[1].size()/n_reps[1];
+    for(int ii = 0; ii < max_ii; ii++)
+    for(int jj = 0; jj < max_jj; jj++){
+      // shared mem pointer
+      indices_t offs = {out_ax[0][ii], out_ax[1][jj]};
+      Value *off = add(offs[out_ord[0]], mul(out_ld, offs[out_ord[1]]));
+      Value *ptr = gep(base, off);
+      // load value from shared mem
+      indices_t idxs = {out_ax[0][i*max_ii + ii],
+                        out_ax[1][j*max_jj + jj]};
+      vals_[out][idxs] = load(ptr);
+    }
+
+  }
+}
+
+void generator::visit_cvt_layout_inst(ir::cvt_layout_inst *rc) {
+  visit_layout_convert(rc, rc->get_operand(0));
 }
 
 void generator::visit_masked_load_async_inst(ir::masked_load_async_inst* x){
@@ -2325,12 +2348,12 @@ void generator::visit_layout_mma(analysis::mma_layout* layout) {
   offset_b_k_[layout] = and_(lane, _3);
   // i indices
   Value *offset_c_m = add(and_(lane, _1), offset_a_m_[layout]);
-  for(unsigned m = 0; m < shape[0]; m+=layout->spt(0))
+  for(unsigned m = 0; m < shape[0]; m+=layout->shape_per_cta(0))
   for(unsigned mm = 0; mm < layout->rep(0); mm++)
     idx_m.push_back(add(offset_c_m, i32(m + mm*2)));
   // j indices
   Value *offset_c_n = add(and_(lane, _2), add(off_warp_n, off_pair_n));
-  for(unsigned n = 0; n < shape[1]; n+=layout->spt(1))
+  for(unsigned n = 0; n < shape[1]; n+=layout->shape_per_cta(1))
   for(unsigned nn = 0; nn < layout->rep(1); nn++){
     idx_n.push_back(add(offset_c_n, i32(n + nn/2*4 + (nn%2)*2*layout->fpw(1)*layout->rep(1))));
     idx_n.push_back(add(offset_c_n, i32(n + nn/2*4 + (nn%2)*2*layout->fpw(1)*layout->rep(1) + 1)));
@@ -2366,11 +2389,11 @@ void generator::visit_layout_mma(analysis::mma_layout* layout) {
   // c offset
   Value *off_c_m = add(udiv(lane, _4), off_warp_m);
   Value *off_c_n = add(mul(_2, urem(lane, _4)), off_warp_n);
-  for(unsigned m = 0; m < shape[0]; m+=layout->spt(0)){
+  for(unsigned m = 0; m < shape[0]; m+=layout->shape_per_cta(0)){
     idx_m.push_back(add(off_c_m, i32(m)));
     idx_m.push_back(add(off_c_m, i32(m + 8)));
   }
-  for(unsigned n = 0; n < shape[1]; n+=layout->spt(1)){
+  for(unsigned n = 0; n < shape[1]; n+=layout->shape_per_cta(1)){
     idx_n.push_back(add(off_c_n, i32(n)));
     idx_n.push_back(add(off_c_n, i32(n + 1)));
   }
@@ -2406,11 +2429,11 @@ void generator::visit_layout_scanline(analysis::scanline_layout* layout) {
   std::string str_k = std::to_string(k);
   Value *contiguous_k = i32(nts);
   Value *scaled_thread_id = mul(thread_id[k], contiguous_k);
-  unsigned per_block  = nts * mts;
-  unsigned per_thread = nts * shape[k] / per_block;
+  unsigned per_cta = layout->shape_per_cta(k);
+  unsigned per_thread = nts * shape[k] / per_cta;
   std::vector<Value*> idx_list(per_thread);
   for(unsigned n = 0 ; n < per_thread; n++){
-    unsigned offset = n / nts * per_block + n % nts;
+    unsigned offset = n / nts * per_cta + n % nts;
     idx_list[n] = add(scaled_thread_id, i32(offset), "idx_" + str_k + "_" + std::to_string(n));
   }
   axes_[layout->get_axis(k)] = distributed_axis{nts, idx_list, thread_id[k]};
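A CPU-side sketch (plain C++, not Triton code) of the shared-memory round trip that `visit_layout_convert` emits: write values with the input layout's indexing, synchronize, then read them back with the output layout's indexing. Here loop iterations stand in for threads, a transpose stands in for the layout change, and the barrier is implicit:

```cpp
// Standalone sketch: exchange values through a scratch tile so that
// ownership can move from one thread mapping to another.
#include <cassert>
#include <vector>

int main() {
  const int M = 4, N = 4;
  std::vector<int> in(M * N), out(M * N), smem(M * N);
  for (int i = 0; i < M * N; i++) in[i] = i;
  // "store" phase: write with the input layout's indexing
  for (int m = 0; m < M; m++)
    for (int n = 0; n < N; n++)
      smem[m * N + n] = in[m * N + n];
  // a barrier (__syncthreads) would go here on a GPU
  // "load" phase: read with the output layout's indexing
  for (int n = 0; n < N; n++)
    for (int m = 0; m < M; m++)
      out[n * M + m] = smem[m * N + n];
  assert(out[1] == in[N]);  // element (1,0) landed in its transposed slot
  return 0;
}
```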
diff --git a/lib/codegen/transform/coalesce.cc b/lib/codegen/transform/coalesce.cc
index 78316c0df..d736d3570 100644
--- a/lib/codegen/transform/coalesce.cc
+++ b/lib/codegen/transform/coalesce.cc
@@ -15,128 +15,109 @@ namespace transform{
 coalesce::coalesce(analysis::align* align, analysis::layouts *layouts)
   : align_(align), layout_(layouts) { }
 
-// Find all values that are used as pointer operands in LD/ST
-void coalesce::extract_io_use(ir::value *v, std::set<ir::io_inst*>& result) {
-  for(ir::user* u: v->get_users()){
-    auto i = dynamic_cast<ir::io_inst*>(u);
-    if(i && i->get_pointer_operand() == v)
-      result.insert(i);
-  }
-}
 
-void coalesce::extract_ld(ir::io_inst* i, std::map<int, std::vector<ir::io_inst*>>& result) {
-  ir::value *ptr = i->get_pointer_operand();
-  auto contiguous = align_->contiguous(ptr);
-  auto it = std::max_element(contiguous.begin(), contiguous.end());
-  int axis = std::distance(contiguous.begin(), it);
-  result[axis].push_back(i);
-}
-
-ir::value* coalesce::rematerialize(ir::value *x, ir::builder &builder,
-                                   std::map<ir::value*, ir::value*>& seen) {
-  if(seen.find(x) != seen.end())
-    return seen.at(x);
-  auto i = dynamic_cast<ir::instruction*>(x);
-  // not an instruction -- forward value
-  if(!i)
-    return x;
-  // already in shared memory -- forward value
-  if(dynamic_cast<ir::copy_to_shared_inst*>(x)){
-    return x;
-  }
-  // set insert point
-  auto& inst_list = i->get_parent()->get_inst_list();
-  auto pos = ++std::find(inst_list.begin(), inst_list.end(), i);
-  builder.set_insert_point(pos);
-  if(dynamic_cast<ir::load_inst*>(x)){
-    ir::value *ret = builder.insert(ir::copy_to_shared_inst::create(x));
-    return ret;
-  }
-  // default -- recursive clone
-  ir::instruction *cloned = builder.insert(i->clone());
-  seen[i] = cloned;
-  // rematerialize operands
-  for(ir::value *op: cloned->ops())
-    cloned->replace_uses_of_with(op, rematerialize(op, builder, seen));
-  return cloned;
-}
+// simplify layout conversions using the following simple rules:
+//   - cvt_1(cvt_2(x)) if convert1 is the inverse of convert2
+//   - cvt_1(elementwise(x, y)) = elementwise(convert(x), convert(y))
+//ir::value* coalesce::simplify(ir::instruction *inst, ir::builder& builder){
+//  ir::value* _op = inst->get_operand(0);
+//  ir::instruction* op = dynamic_cast<ir::instruction*>(_op);
+//  analysis::mma_layout* mma_in  = layout_->get(op)  ->to_mma();
+//  analysis::mma_layout* mma_out = layout_->get(inst)->to_mma();
+//  std::cout << 1 << std::endl;
+//  // i must be layout conversion instruction
+//  if(!mma_in && !mma_out)
+//    return inst;
+//  // - cvt_1(cvt_2(x)) if convert1 is the inverse of convert2
+//  bool is_op_cvt = op->get_id() == ir::INST_CVT_LAYOUT;
+//  if((mma_in || mma_out) && is_op_cvt &&
+//     (layout_->get(inst) == layout_->get(op->get_operand(0))))
+//    return op->get_operand(0);
+//  // - cvt_1(elementwise(x, y)) = elementwise(cvt_1(x), cvt_2(y))
+//  if(op->get_id() != ir::INST_BINOP && op->get_id() != ir::INST_GETELEMENTPTR)
+//    return inst;
+//  std::cout << 1 << std::endl;
+//  for(size_t i = 0; i < op->get_num_operands(); i++){
+//    ir::value* arg_i = op->get_operand(i);
+//    builder.set_insert_point(op);
+//    // create new layout transform
+//    ir::instruction* new_arg_i = inst->clone();
+//    builder.insert(new_arg_i);
+//    // set the right args
+//    new_arg_i->replace_uses_of_with(new_arg_i->get_operand(0), arg_i);
+//    op->replace_uses_of_with(arg_i, simplify(new_arg_i, builder));
+//  }
+//  std::cout << 2 << std::endl;
+//  return op;
+//}
 
 void coalesce::run(ir::module &mod) {
-  size_t num_groups = layout_->num_layouts();
-
-
-  for(size_t id = 0; id < num_groups; id++) {
-    if(!layout_->get(id)->to_mma())
-      continue;
-    // extract memory stores
-    const auto& values = layout_->values_of(id);
-    ir::value* dot = nullptr;
-    for(ir::value *v: values)
-      if(auto x = dynamic_cast<ir::dot_inst*>(v))
-        dot = x;
-
-    ir::builder& builder = mod.get_builder();
-    std::vector<ir::value*> worklist = {dot};
-    std::set<ir::value*> seen;
-    while(!worklist.empty()) {
-      ir::value *current = worklist.back();
-      seen.insert(current);
-      worklist.pop_back();
-      // stop if trunc
-      if(auto x = dynamic_cast<ir::fp_trunc_inst*>(current)){
+  ir::builder& builder = mod.get_builder();
+  // add layout conversion instructions
+  for(ir::function *fn: mod.get_function_list())
+  for(ir::basic_block *block: fn->blocks())
+  for(ir::instruction* i: block->get_inst_list()){
+    // coalesce before store
+    if(auto x = dynamic_cast<ir::store_inst*>(i))
+    if(ir::value* op = x->get_value_operand())
+    if(op->get_type()->is_block_ty())
+    if(layout_->get(op)->to_mma()){
+      builder.set_insert_point(x);
+      ir::instruction* new_op = ir::cvt_layout_inst::create(op);
+      builder.insert(new_op);
+      x->replace_uses_of_with(op, new_op);
+    }
+    // uncoalesce after load
+    if(auto x = dynamic_cast<ir::load_inst*>(i))
+    if(x->get_type()->is_block_ty())
+    if(x->get_type()->get_tile_rank()==2)
+    if(layout_->get(x)->to_mma()){
       builder.set_insert_point_after(x);
-        ir::recoalesce_inst* rc = ir::recoalesce_inst::create(x);
-        builder.insert(rc);
-        x->replace_all_uses_with(rc);
-        rc->replace_uses_of_with(rc, x);
+      ir::instruction* new_x = ir::cvt_layout_inst::create(x);
+      builder.insert(new_x);
+      x->replace_all_uses_with(new_x);
+      new_x->replace_uses_of_with(new_x, x);
+//      new_x->replace_uses_of_with(new_x, new_x);
+    }
+    // re-arrange scanline to promote memory coalescing
+    if(auto x = dynamic_cast<ir::store_inst*>(i)){
+      ir::value* ptr = x->get_pointer_operand();
+      ir::value* val = x->get_value_operand();
+      auto out_contig = align_->contiguous(ptr);
+      auto val_inst = dynamic_cast<ir::instruction*>(val);
+      if(!val_inst)
         break;
+      if(dynamic_cast<ir::cvt_layout_inst*>(val))
+        break;
+      std::vector<unsigned> in_contig;
+      std::vector<ir::instruction*> queue = {val_inst};
+      std::set<ir::instruction*> seen;
+      std::vector<ir::io_inst*> ios;
+      while(!queue.empty()){
+        ir::instruction* curr = queue.back();
+        seen.insert(curr);
+        queue.pop_back();
+        if(auto io_inst = dynamic_cast<ir::io_inst*>(curr)){
+          in_contig = align_->contiguous(io_inst->get_pointer_operand());
+          break;
+        }
+        for(ir::value* op: curr->ops()){
+          auto inst_op = dynamic_cast<ir::instruction*>(op);
+          if(!inst_op || seen.find(inst_op) != seen.end())
+            continue;
+          if(!op->get_type()->is_block_ty() ||
+             !val->get_type()->is_block_ty())
+            continue;
+          if(op->get_type()->get_tile_num_elements() ==
+             val->get_type()->get_tile_num_elements())
+            queue.push_back(inst_op);
+        }
       }
-      // recurse
-      for(ir::user *u: current->get_users())
-        if(seen.find(u) == seen.end())
-          worklist.push_back(u);
-    }
-  }
-
-  // find values to rematerialize
-  std::vector<ir::io_inst*> remat;
-  for(size_t id = 0; id < num_groups; id++) {
-    const auto& values = layout_->values_of(id);
-    // extract pointers used in ld/st operations
-    std::set<ir::io_inst*> io;
-    for(ir::value *v: values)
-      extract_io_use(v, io);
-    // extract leading axes
-    std::map<int, std::vector<ir::io_inst*>> axes;
-    for(ir::io_inst *i: io){
-      if(i->get_pointer_operand()->get_type()->get_tile_rank() == layout_->get(id)->get_rank()){
-        extract_ld(i, axes);
-      }
-    }
-    // update list of values to rematerialize
-    if(axes.empty())
-      continue;
-    for(auto it = ++axes.rbegin(); it != axes.rend(); it++){
-      if(it->second.size() == 1)
+      if(in_contig.empty() || out_contig==in_contig)
         continue;
-      remat.insert(remat.begin(), it->second.begin(), it->second.end());
-    }
-  }
-  // rematerialize values
-  for(ir::io_inst *r: remat) {
-    ir::builder& builder = mod.get_builder();
-    // rematerialize operands
-    std::map<ir::value*, ir::value*> seen;
-    for(ir::value *op: r->ops())
-      r->replace_uses_of_with(op, rematerialize(op, mod.get_builder(), seen));
-    // copy to shared if load
-    auto& inst_list = r->get_parent()->get_inst_list();
-    auto pos = ++std::find(inst_list.begin(), inst_list.end(), r);
-    builder.set_insert_point(pos);
-    if(dynamic_cast<ir::load_inst*>(r)){
-      ir::instruction *cts = builder.insert(ir::copy_to_shared_inst::create(r));
-      r->replace_all_uses_with(cts);
-      cts->replace_uses_of_with(cts, r);
+      builder.set_insert_point_after(val_inst);
+      auto new_val = builder.insert(ir::cvt_layout_inst::create(val_inst));
+      x->replace_uses_of_with(val_inst, new_val);
     }
   }
 }
diff --git a/lib/codegen/transform/disassociate.cc b/lib/codegen/transform/disassociate.cc
index 1384e2102..0d9e1b8ef 100644
--- a/lib/codegen/transform/disassociate.cc
+++ b/lib/codegen/transform/disassociate.cc
@@ -9,67 +9,48 @@ namespace triton {
 namespace codegen{
 namespace transform{
 
-void extract_retile_chain(ir::user *root,
-                          std::map<int, std::set<ir::user*>>& result,
-                          int depth,
+ir::instruction* rematerialize(ir::builder& bld, ir::instruction *root,
                           std::set<ir::instruction*>& seen) {
   if(!seen.insert(root).second)
-    return;
-  result[depth].insert(root);
-  if(dynamic_cast<ir::make_range*>(root) ||
-     dynamic_cast<ir::splat_inst*>(root)){
-    return;
-  }
+    return root;
+  if(!root->get_type()->is_block_ty())
+    return root;
+
+  bld.set_insert_point(root);
+  ir::instruction *new_root = bld.insert(root->clone());
   for(ir::value *op: root->ops()){
-    ir::user *u = dynamic_cast<ir::user*>(op);
-    if(!u)
+    ir::instruction *i = dynamic_cast<ir::instruction*>(op);
+    if(!i || i->get_id() == ir::INST_REDUCE)
       continue;
-    extract_retile_chain(u, result, depth + 1, seen);
+    ir::instruction* new_op = rematerialize(bld, i, seen);
+    new_root->replace_uses_of_with(op, new_op);
   }
+  return new_root;
 }
 
 void disassociate::run(ir::module &mod) {
   ir::builder &bld = mod.get_builder();
 
-  std::map<ir::user*, std::map<int, std::set<ir::user*>>> clone_info;
+//  ir::for_each_instruction(mod, [&](ir::instruction *i){
+//    bld.set_insert_point(i);
+//    for(ir::value* op: i->ops()){
+//      auto reshape = dynamic_cast<ir::reshape_inst*>(op);
+//      if(!reshape)
+//        continue;
+//      ir::instruction* new_op = bld.insert(reshape->clone());
+//      i->replace_uses_of_with(op, new_op);
+//    }
+//  });
+
+
   ir::for_each_instruction(mod, [&](ir::instruction *i){
-    if(dynamic_cast<ir::reshape_inst*>(i)){
-      ir::value* op = i->get_operand(0);
-      if(!dynamic_cast<ir::user*>(op))
-        return;
-      if(op->get_type()->get_tile_rank() > i->get_type()->get_tile_rank())
-        return;
-      std::map<int, std::set<ir::user*>> chains;
+    if(dynamic_cast<ir::reshape_inst*>(i) || dynamic_cast<ir::splat_inst*>(i)){
       std::set<ir::instruction*> seen;
-      extract_retile_chain(i, chains, 0, seen);
-      if(chains.size())
-        clone_info[i] = chains;
+      ir::instruction* new_i = rematerialize(bld, i, seen);
+      i->replace_all_uses_with(new_i);
     }
   });
 
-  for(const auto& x: clone_info){
-    int depth = 1;
-    std::map<ir::instruction*, ir::instruction*> clone_map;
-    while(x.second.find(depth) != x.second.end()){
-      // clone all users
-      const auto& remat = x.second.at(depth);
-      for(ir::user* u: remat){
-        ir::instruction *y = (ir::instruction*)u;
-        ir::instruction *cloned = y->clone();
-        bld.set_insert_point(y);
-        bld.insert(cloned);
-        clone_map[y] = cloned;
-        // replace operands of parents
-        if(depth > 1)
-          for(ir::user* ux: x.second.at(depth - 1))
-            clone_map.at((ir::instruction*)ux)->replace_uses_of_with(y, cloned);
-        else
-          x.first->replace_uses_of_with(y, cloned);
-      }
-      depth += 1;
-    }
-  }
-
 }
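A standalone sketch (plain C++; the `Inst` struct is illustrative, not Triton's IR) of the recursive rematerialization now used by `disassociate::run`: clone an instruction and, transitively, its operands, so that reshapes no longer share their input chains. The real pass additionally skips non-block values and reductions:

```cpp
// Standalone sketch: recursively clone a small instruction DAG,
// forwarding nodes that were already processed.
#include <iostream>
#include <memory>
#include <set>
#include <string>
#include <vector>

struct Inst {
  std::string name;
  std::vector<Inst*> ops;
};

Inst* rematerialize(Inst* root, std::set<Inst*>& seen,
                    std::vector<std::unique_ptr<Inst>>& pool) {
  if (!seen.insert(root).second)
    return root;                                   // already handled
  pool.push_back(std::make_unique<Inst>(*root));   // clone the node
  Inst* clone = pool.back().get();
  for (Inst*& op : clone->ops)                     // rematerialize operands
    op = rematerialize(op, seen, pool);
  return clone;
}

int main() {
  std::vector<std::unique_ptr<Inst>> pool;
  Inst a{"a", {}}, b{"b", {&a}}, c{"reshape", {&b}};
  std::set<Inst*> seen;
  Inst* c2 = rematerialize(&c, seen, pool);
  std::cout << (c2 != &c) << " " << (c2->ops[0] != &b) << "\n"; // 1 1
}
```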
diff --git a/lib/codegen/transform/peephole.cc b/lib/codegen/transform/peephole.cc
index f5eeeb5d0..043e64c05 100644
--- a/lib/codegen/transform/peephole.cc
+++ b/lib/codegen/transform/peephole.cc
@@ -211,6 +211,42 @@ bool peephole::rewrite_select_masked_load(ir::instruction *value, ir::builder& b
   return true;
 }
 
+bool peephole::rewrite_cvt_layout(ir::instruction *value, ir::builder& builder){
+  auto cvt = dynamic_cast<ir::cvt_layout_inst*>(value);
+  if(!cvt)
+    return false;
+  ir::instruction* op = dynamic_cast<ir::instruction*>(cvt->get_operand(0));
+  if(!op)
+    return false;
+  // convert(elementwise(x, y)) = elementwise(convert(x), convert(y))
+  if(op->get_id() == ir::INST_BINOP){
+    for(size_t i = 0; i < op->get_num_operands(); i++){
+      ir::value* arg_i = op->get_operand(i);
+      builder.set_insert_point(op);
+      // create new layout transform
+      ir::instruction* new_arg_i = cvt->clone();
+      layouts_->copy(new_arg_i, op);
+      builder.insert(new_arg_i);
+      // set the right args
+      new_arg_i->replace_uses_of_with(new_arg_i->get_operand(0), arg_i);
+      op->replace_uses_of_with(arg_i, new_arg_i);
+    }
+    cvt->replace_all_uses_with(op);
+    return true;
+  }
+  auto cvt_op = dynamic_cast<ir::cvt_layout_inst*>(op);
+  if(!cvt_op)
+    return false;
+  // convert1(convert2(x)) if convert1 is the inverse of convert2
+  ir::value* op_op = cvt_op->get_operand(0);
+  if(layouts_->has(cvt) && layouts_->has(op_op) &&
+     layouts_->get(cvt) && layouts_->get(op_op)){
+    cvt->replace_all_uses_with(op_op);
+    return true;
+  }
+  return false;
+}
+
 void peephole::run(ir::module &mod) {
   ir::builder &builder = mod.get_builder();
   // keep track of whether any modification was made
@@ -248,6 +284,7 @@ void peephole::run(ir::module &mod) {
       was_modified = was_modified || rewrite_unit_red(i, builder);
       was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder);
      was_modified = was_modified || rewrite_select_masked_load(i, builder);
+      was_modified = was_modified || rewrite_cvt_layout(i, builder);
       if(tgt_->as_nvidia()->sm() >= 80)
         was_modified = was_modified || rewrite_load_to_shared(i, builder);
       if(was_modified)
diff --git a/lib/codegen/transform/pipeline.cc b/lib/codegen/transform/pipeline.cc
index 8b9fe1a5d..096edb0f4 100644
--- a/lib/codegen/transform/pipeline.cc
+++ b/lib/codegen/transform/pipeline.cc
@@ -311,4 +311,4 @@ void pipeline::run(ir::module &mod) {
   }
 }
 
-}
\ No newline at end of file
+}
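A standalone sketch (plain C++; the `Expr` type and layout ids are illustrative, not Triton's IR) of the two rewrites in `rewrite_cvt_layout` above, on a toy expression tree:

```cpp
// Standalone sketch of the two peephole rules:
//   cvt(cvt(x)) -> x            when the outer cvt restores x's layout
//   cvt(add(a, b)) -> add(cvt(a), cvt(b))   (push conversions up)
#include <iostream>
#include <memory>
#include <string>

struct Expr {
  std::string op;                 // "cvt", "add", or a leaf name
  int layout;                     // layout id produced by this node
  std::shared_ptr<Expr> lhs, rhs;
};
using E = std::shared_ptr<Expr>;

E cvt(E x, int layout) { return std::make_shared<Expr>(Expr{"cvt", layout, x, nullptr}); }

E simplify(E e) {
  if (e->op != "cvt") return e;
  E in = e->lhs;
  if (in->op == "cvt" && in->lhs->layout == e->layout)
    return in->lhs;                             // inverse conversions cancel
  if (in->op == "add")                          // distribute over elementwise
    return std::make_shared<Expr>(Expr{"add", e->layout,
                                       simplify(cvt(in->lhs, e->layout)),
                                       simplify(cvt(in->rhs, e->layout))});
  return e;
}

int main() {
  E x = std::make_shared<Expr>(Expr{"x", 0, nullptr, nullptr});
  E roundtrip = cvt(cvt(x, 1), 0);              // convert away and back
  std::cout << (simplify(roundtrip) == x) << "\n"; // 1
}
```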
diff --git a/lib/driver/dispatch.cc b/lib/driver/dispatch.cc
index f2a2c519f..69fa2e39e 100755
--- a/lib/driver/dispatch.cc
+++ b/lib/driver/dispatch.cc
@@ -126,12 +126,6 @@ bool dispatch::nvmlinit(){
   return res;
 }
 
-bool dispatch::spvllvminit(){
-  if(spvllvm_==nullptr)
-    spvllvm_ = dlopen("libLLVMSPIRVLib.so", RTLD_LAZY);
-  return spvllvm_ != nullptr;
-}
-
 //CUDA
 CUDA_DEFINE1(CUresult, cuCtxDestroy_v2, CUcontext)
 CUDA_DEFINE2(CUresult, cuEventCreate, CUevent *, unsigned int)
@@ -185,14 +179,6 @@ NVML_DEFINE3(nvmlReturn_t, nvmlDeviceGetClockInfo, nvmlDevice_t, nvmlClockType_t
 NVML_DEFINE3(nvmlReturn_t, nvmlDeviceGetMaxClockInfo, nvmlDevice_t, nvmlClockType_t, unsigned int*)
 NVML_DEFINE3(nvmlReturn_t, nvmlDeviceSetApplicationsClocks, nvmlDevice_t, unsigned int, unsigned int)
 
-// LLVM to SPIR-V
-int dispatch::initializeLLVMToSPIRVPass(llvm::PassRegistry &registry){
-  return f_impl<dispatch::spvllvminit>(spvllvm_, initializeLLVMToSPIRVPass, initializeLLVMToSPIRVPass_, "initializeLLVMToSPIRVPass", std::ref(registry));
-}
-
-bool dispatch::writeSpirv(llvm::Module *M, std::ostream &OS, std::string &ErrMsg){
-  return f_impl<dispatch::spvllvminit>(spvllvm_, writeSpirv, writeSpirv_, "writeSpirv", M, std::ref(OS), std::ref(ErrMsg));
-}
-
 // Release
 void dispatch::release(){
@@ -204,7 +190,6 @@ void dispatch::release(){
 
 void* dispatch::cuda_;
 void* dispatch::nvml_;
-void* dispatch::spvllvm_;
 
 //CUDA
 void* dispatch::cuCtxGetCurrent_;
@@ -261,9 +246,5 @@ void* dispatch::nvmlDeviceGetClockInfo_;
 void* dispatch::nvmlDeviceGetMaxClockInfo_;
 void* dispatch::nvmlDeviceSetApplicationsClocks_;
 
-// SPIR-V
-void* dispatch::initializeLLVMToSPIRVPass_;
-void* dispatch::writeSpirv_;
-
 }
 }
diff --git a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc
index 21169c031..c1c7a591f 100644
--- a/lib/ir/dispatch.cc
+++ b/lib/ir/dispatch.cc
@@ -374,12 +374,15 @@ ir::value *dispatch::broadcast(ir::value *input, shape_t shape, ir::builder *bui
   auto src_shape = input->get_type()->get_block_shapes();
   if (src_shape.size() != shape.size())
     throw std::runtime_error("Cannot broadcast");
+  if(shape == src_shape)
+    return input;
   return builder->create_broadcast(input, shape);
 }
 
 std::tuple<ir::value*, ir::value*> dispatch::broadcast(ir::value *lhs, ir::value* rhs, ir::builder *builder) {
   ir::type *lhs_ty = lhs->get_type();
   ir::type *rhs_ty = rhs->get_type();
+  // make_shape_compatible(block, scalar)
   if (lhs_ty->is_block_ty() && !rhs_ty->is_block_ty())
     rhs = builder->create_splat(rhs, lhs_ty->get_block_shapes());
diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc
index 18da3d3ae..e50fd790f 100644
--- a/lib/ir/instructions.cc
+++ b/lib/ir/instructions.cc
@@ -806,6 +806,11 @@ instruction* log_inst::create(value *val, const std::string& name, instruction *
 // intrinsic instructions
 //===----------------------------------------------------------------------===//
 
+// cvt_layout
+cvt_layout_inst* cvt_layout_inst::create(value *arg, const std::string &name, instruction *next) {
+  return new cvt_layout_inst(arg->get_type(), INST_CVT_LAYOUT, arg, name, next);
+}
+
 // copy to shared
 copy_to_shared_inst* copy_to_shared_inst::create(value *arg, const std::string &name,
                                                  instruction *next) {
@@ -818,13 +823,6 @@ copy_from_shared_inst* copy_from_shared_inst::create(value *arg, const std::stri
   return new copy_from_shared_inst(arg->get_type(), INST_COPY_FROM_SHARED, arg, name, next);
 }
 
-// recoalesce
-recoalesce_inst* recoalesce_inst::create(value *arg, const std::string &name, instruction *next) {
-  return new recoalesce_inst(arg->get_type(), INST_RECOALESCE, arg, name, next);
-}
-
-
-
 // barrier
 barrier_inst::barrier_inst(context &ctx, const std::string &name,
                            instruction *next)
diff --git a/python/test/test_language.py b/python/test/test_language.py
index e15bbc6bb..d5a436bd9 100644
--- a/python/test/test_language.py
+++ b/python/test/test_language.py
@@ -363,6 +363,133 @@ def test_reduce1d(dtype, shape, device='cuda'):
     triton.testing.assert_almost_equal(z_tri, z_ref)
 
 
+@pytest.mark.parametrize("dtype, shape, axis",
+                         [(dtype, shape, 1) \
+                             for dtype in ['float32']\
+                             for shape in [(1, 1024)]])
+def test_reduce2d(dtype, shape, axis, device='cuda'):
+    dtype = cvt[dtype]
+    # triton kernel
+    @triton.jit
+    def kernel(X, Z, **meta):
+        range_m = tl.arange(0, meta['BLOCK_M'])
+        range_n = tl.arange(0, meta['BLOCK_N'])
+        x = tl.load(X + range_m[:, None]*meta['BLOCK_N'] + range_n[None, :])
+        z = tl.sum(x, axis=meta['AXIS'])
+        tl.store(Z + range_m, z)
+    # input
+    x = triton.testing.random(shape, dtype=dtype, device=device)
+    # triton result
+    z_tri = torch.empty((shape[0],), dtype=dtype, device=device)
+    kernel[(1,)](x, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis)
+    # torch result
+    z_ref = torch.sum(x, axis=axis).to(dtype)
+    # compare
+    triton.testing.assert_almost_equal(z_tri, z_ref)
+
+# ---------------
+# test permute
+# ---------------
+
+@pytest.mark.parametrize("dtype, shape, perm",
+                         [(dtype, shape, perm) \
+                             for dtype in ['float32']\
+                             for shape in [(128, 128)]\
+                             for perm in [(1, 0)]])
+def test_permute(dtype, shape, perm, device='cuda'):
+    dtype = cvt[dtype]
+    # triton kernel
+    @triton.jit
+    def kernel(X, stride_xm, stride_xn,
+               Z, stride_zm, stride_zn, **meta):
+        BLOCK_M = meta['BLOCK_M']
+        BLOCK_N = meta['BLOCK_N']
+        off_m = tl.arange(0, BLOCK_M)
+        off_n = tl.arange(0, BLOCK_N)
+        Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn
+        Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
+        tl.store(Zs, tl.load(Xs))
+    # input
+    x = triton.testing.random(shape, dtype=dtype, device=device)
+    # triton result
+    z_tri = torch.empty_like(x)
+    pgm = kernel[(1, 1)](x, x.stride(0), x.stride(1),
+                         z_tri, z_tri.stride(1), z_tri.stride(0),
+                         BLOCK_M=shape[0], BLOCK_N=shape[1])
+    # torch result
+    z_ref = x.permute(*perm).contiguous()
+    # compare
+    triton.testing.assert_almost_equal(z_tri, z_ref)
+    # parse ptx to make sure ld/st are vectorized
+    ptx = pgm.asm('ptx')
+    assert 'ld.global.v4' in ptx
+    assert 'st.global.v4' in ptx
+
+# ---------------
+# test dot
+# ---------------
+
+@pytest.mark.parametrize("epilogue", ['none', 'add-matrix', 'add-rows', 'add-cols'])
+def test_dot(epilogue, device='cuda'):
+    torch.manual_seed(0)
+    # triton kernel
+    @triton.jit
+    def kernel(X, stride_xm, stride_xk,
+               Y, stride_yk, stride_yn,
+               Z, stride_zm, stride_zn, **meta):
+        BLOCK_M = meta['BLOCK_M']
+        BLOCK_K = meta['BLOCK_K']
+        BLOCK_N = meta['BLOCK_N']
+        off_m = tl.arange(0, BLOCK_M)
+        off_n = tl.arange(0, BLOCK_N)
+        off_k = tl.arange(0, BLOCK_K)
+        Xs = X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk
+        Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn
+        Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
+        z = tl.dot(tl.load(Xs), tl.load(Ys))
+        if meta['ADD_MATRIX']:
+            z += tl.load(Zs)
+        if meta['ADD_ROWS']:
+            ZRs = Z + off_m * stride_zm
+            z += tl.load(ZRs)[:, None]
+        if meta['ADD_COLS']:
+            ZCs = Z + off_n * stride_zn
+            z += tl.load(ZCs)[None, :]
+        tl.store(Zs, z)
+    # input
+    M, N, K = 64, 64, 32
+    x = triton.testing.random((M, K), dtype=torch.float16, device=device)
+    y = triton.testing.random((K, N), dtype=torch.float16, device=device)
+    # triton result
+    z = triton.testing.random((M, N), dtype=torch.float16, device=device)
+    z_tri = z.clone()
+    pgm = kernel[(1, 1)](x, x.stride(0), x.stride(1),
+                         y, y.stride(0), y.stride(1),
+                         z_tri, z_tri.stride(0), z_tri.stride(1),
+                         BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
+                         ADD_MATRIX = epilogue=='add-matrix',
+                         ADD_ROWS = epilogue=='add-rows',
+                         ADD_COLS = epilogue=='add-cols')
+    # torch result
+    z_ref = torch.matmul(x.float(), y.float())
+    if epilogue == 'add-matrix':
+        z_ref += z
+    if epilogue == 'add-rows':
+        z_ref += z[:,0][:, None]
+    if epilogue == 'add-cols':
+        z_ref += z[0,:][None, :]
+    z_ref = z_ref.to(torch.float16)
+    # compare
+    ptx = pgm.asm('ptx')
+    # print(ptx)
+    triton.testing.assert_almost_equal(z_tri, z_ref)
+    # make sure ld/st are vectorized
+    assert 'ld.global.v4' in ptx
+    assert 'st.global.v4' in ptx
 
 
 # ---------------
+ """ + return frontend.max_contiguous(input, value, _builder) + + # ----------------------- # Standard library # ----------------------- diff --git a/python/triton/ops/matmul.py b/python/triton/ops/matmul.py index 7cac0a264..96c458285 100644 --- a/python/triton/ops/matmul.py +++ b/python/triton/ops/matmul.py @@ -46,8 +46,8 @@ def _kernel(A, B, C, M, N, K, pid_n = (pid % width) // (group_size) # do matrix multiplication rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) rk = tl.arange(0, BLOCK_K) # pointers diff --git a/python/triton/testing.py b/python/triton/testing.py index 5e8236a9e..0bee4375d 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -87,6 +87,7 @@ def assert_allclose(x, y, tol=1e-2): def random(shape, dtype, device): + torch.manual_seed(0) if isinstance(shape, int): shape = (shape, ) if dtype == torch.bool: