[CODEGEN] Various bugfixes and stability improvements in compiler backend (#240)

This commit is contained in:
Philippe Tillet
2021-08-30 11:50:35 -07:00
committed by GitHub
parent 85426dbaf7
commit 4ff3714d61
25 changed files with 568 additions and 399 deletions

View File

@@ -93,7 +93,22 @@ protected:
shape_t shape_; shape_t shape_;
}; };
class mma_layout: public data_layout { class distributed_layout: public data_layout{
// Common base for layouts that expose a per-axis `shape_per_cta_` vector.
// mma_layout and scanline_layout derive from it in this revision.
// The constructor takes the same arguments that are forwarded to data_layout;
// shape_per_cta_ itself is expected to be filled in by the derived class.
public:
// id/axes/shape/values/align are passed straight through to the data_layout base.
distributed_layout(id_t id,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
const std::vector<ir::value*>& values,
analysis::align* align);
// Bounds-checked read of shape_per_cta_ along axis k (uses .at()).
int shape_per_cta(size_t k) { return shape_per_cta_.at(k); }
// Integer quotient shape_[k] / shape_per_cta_[k].
// NOTE(review): unlike shape_per_cta(), this indexes with operator[] (no bounds
// check) and divides without guarding against a zero entry; the result is
// truncated when shape_[k] is not a multiple of shape_per_cta_[k].
int rep_per_cta(size_t k) { return shape_[k] / shape_per_cta_[k]; }
protected:
// Per-axis extents; written by derived-class constructors.
std::vector<int> shape_per_cta_;
};
class mma_layout: public distributed_layout {
public: public:
mma_layout(size_t num_warps, mma_layout(size_t num_warps,
const std::vector<int>& axes, const std::vector<int>& axes,
@@ -107,7 +122,6 @@ public:
int fpw(size_t k) { return fpw_.at(k); } int fpw(size_t k) { return fpw_.at(k); }
int wpt(size_t k) { return wpt_.at(k); } int wpt(size_t k) { return wpt_.at(k); }
int spw(size_t k) { return spw_.at(k); } int spw(size_t k) { return spw_.at(k); }
int spt(size_t k) { return spt_.at(k); }
int rep(size_t k) { return rep_.at(k); } int rep(size_t k) { return rep_.at(k); }
private: private:
@@ -123,7 +137,7 @@ private:
std::vector<int> rep_; std::vector<int> rep_;
}; };
struct scanline_layout: public data_layout { struct scanline_layout: public distributed_layout {
scanline_layout(size_t num_warps, scanline_layout(size_t num_warps,
const std::vector<int>& axes, const std::vector<int>& axes,
const std::vector<unsigned>& shape, const std::vector<unsigned>& shape,
@@ -219,6 +233,7 @@ public:
// accessors // accessors
unsigned layout_of(ir::value *value) const { return groups_.at(value); } unsigned layout_of(ir::value *value) const { return groups_.at(value); }
// Returns true when `value` has an entry in groups_, false otherwise.
// Does not modify groups_ (lookup is done with find()).
bool has(ir::value* value) const { return groups_.find(value) != groups_.end(); }
const std::vector<ir::value*>& values_of(unsigned id) const { return values_.at(id); } const std::vector<ir::value*>& values_of(unsigned id) const { return values_.at(id); }
size_t num_layouts() const { return values_.size();} size_t num_layouts() const { return values_.size();}
data_layout* get(size_t id) { return layouts_.at(id); } data_layout* get(size_t id) { return layouts_.at(id); }
@@ -226,7 +241,7 @@ public:
std::map<size_t, data_layout*> &get_all() { return layouts_; } std::map<size_t, data_layout*> &get_all() { return layouts_; }
bool has_tmp(ir::value* i) { return tmp_.find(i) != tmp_.end(); } bool has_tmp(ir::value* i) { return tmp_.find(i) != tmp_.end(); }
int tmp(ir::value* i) { return tmp_.at(i);} int tmp(ir::value* i) { return tmp_.at(i);}
// Assigns to `dst` the group entry currently stored for `src`.
// NOTE(review): both sides use operator[]; if groups_ is a std::map-like
// container, a missing `src` is default-inserted rather than reported —
// confirm callers only pass a `src` that is already registered.
void copy(ir::value* dst, ir::value* src) { groups_[dst] = groups_[src]; }
// execution // execution
void run(ir::module &mod); void run(ir::module &mod);

View File

@@ -171,7 +171,8 @@ public:
void visit_reducend_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*); void visit_reducend_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
void visit_reduce_inst(ir::reduce_inst*); void visit_reduce_inst(ir::reduce_inst*);
void visit_select_inst(ir::select_inst*); void visit_select_inst(ir::select_inst*);
void visit_recoalesce_inst(ir::recoalesce_inst*); void visit_layout_convert(ir::value *out, ir::value *in);
void visit_cvt_layout_inst(ir::cvt_layout_inst*);
void visit_masked_load_async_inst(ir::masked_load_async_inst*); void visit_masked_load_async_inst(ir::masked_load_async_inst*);
void visit_copy_to_shared_inst(ir::copy_to_shared_inst*); void visit_copy_to_shared_inst(ir::copy_to_shared_inst*);
void visit_copy_from_shared_inst(ir::copy_from_shared_inst*); void visit_copy_from_shared_inst(ir::copy_from_shared_inst*);

View File

@@ -33,6 +33,7 @@ private:
public: public:
coalesce(analysis::align* align, triton::codegen::analysis::layouts *layouts); coalesce(analysis::align* align, triton::codegen::analysis::layouts *layouts);
triton::ir::value *simplify(ir::instruction* i, triton::ir::builder &builder);
void run(ir::module &mod); void run(ir::module &mod);
private: private:

View File

@@ -34,8 +34,7 @@ private:
bool rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder); bool rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder);
bool rewrite_select_masked_load(ir::instruction *value, ir::builder& builder); bool rewrite_select_masked_load(ir::instruction *value, ir::builder& builder);
bool rewrite_load_to_shared(ir::instruction *value, ir::builder& builder); bool rewrite_load_to_shared(ir::instruction *value, ir::builder& builder);
bool rewrite_cvt_layout(ir::instruction *value, ir::builder& builder);
private:
public: public:
peephole(target* tgt, analysis::layouts* layouts): tgt_(tgt), layouts_(layouts) {} peephole(target* tgt, analysis::layouts* layouts): tgt_(tgt), layouts_(layouts) {}

View File

@@ -60,136 +60,151 @@ protected:
public: public:
static bool nvmlinit(); static bool nvmlinit();
static bool cuinit(); static bool cuinit();
static bool spvllvminit();
static void release(); static void release();
// CUDA /* ------------------- *
* CUDA
* ------------------- */
// context management
static CUresult cuInit(unsigned int Flags);
static CUresult cuCtxGetCurrent(CUcontext *pctx); static CUresult cuCtxGetCurrent(CUcontext *pctx);
static CUresult cuCtxSetCurrent(CUcontext ctx); static CUresult cuCtxSetCurrent(CUcontext ctx);
static CUresult cuCtxDestroy_v2(CUcontext ctx); static CUresult cuCtxDestroy_v2(CUcontext ctx);
static CUresult cuEventCreate(CUevent *phEvent, unsigned int Flags); static CUresult cuCtxCreate_v2(CUcontext *pctx, unsigned int flags, CUdevice dev);
static CUresult cuDeviceGet(CUdevice *device, int ordinal); static CUresult cuCtxPushCurrent_v2(CUcontext ctx);
static CUresult cuMemcpyDtoH_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount); static CUresult cuCtxPopCurrent_v2(CUcontext *pctx);
static CUresult cuStreamCreate(CUstream *phStream, unsigned int Flags); static CUresult cuCtxGetDevice(CUdevice* result);
static CUresult cuEventElapsedTime(float *pMilliseconds, CUevent hStart, CUevent hEnd); static CUresult cuCtxEnablePeerAccess(CUcontext peerContext, unsigned int flags);
static CUresult cuMemFree_v2(CUdeviceptr dptr);
static CUresult cuMemcpyDtoHAsync_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount, CUstream hStream);
static CUresult cuDriverGetVersion(int *driverVersion); static CUresult cuDriverGetVersion(int *driverVersion);
// device management
static CUresult cuDeviceGet(CUdevice *device, int ordinal);
static CUresult cuDeviceGetName(char *name, int len, CUdevice dev); static CUresult cuDeviceGetName(char *name, int len, CUdevice dev);
static CUresult cuDeviceGetPCIBusId(char *id, int len, CUdevice dev); static CUresult cuDeviceGetPCIBusId(char *id, int len, CUdevice dev);
static CUresult cuModuleGetGlobal_v2(CUdeviceptr *dptr, size_t* bytes, CUmodule hmod, const char *name); static CUresult cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib, CUdevice dev);
static CUresult cuMemcpyHtoDAsync_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount, CUstream hStream); static CUresult cuDeviceGetCount(int *count);
static CUresult cuModuleLoad(CUmodule *module, const char *fname); // link management
static CUresult cuModuleLoadData(CUmodule* module, const void* image);
static CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void **kernelParams, void **extra);
static CUresult cuModuleUnload(CUmodule hmod);
static CUresult cuModuleLoadDataEx(CUmodule *module, const void *image, unsigned int numOptions, CUjit_option *options, void **optionValues);
static CUresult cuLinkAddData_v2(CUlinkState state, CUjitInputType type, void* data, size_t size, const char* name, unsigned int numOptions, CUjit_option* options, void** optionValues); static CUresult cuLinkAddData_v2(CUlinkState state, CUjitInputType type, void* data, size_t size, const char* name, unsigned int numOptions, CUjit_option* options, void** optionValues);
static CUresult cuLinkCreate_v2(unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut); static CUresult cuLinkCreate_v2(unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut);
static CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut); static CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut);
static CUresult cuLinkDestroy(CUlinkState state); static CUresult cuLinkDestroy(CUlinkState state);
// module management
static CUresult cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib, CUdevice dev); static CUresult cuModuleGetGlobal_v2(CUdeviceptr *dptr, size_t* bytes, CUmodule hmod, const char *name);
static CUresult cuDeviceGetCount(int *count); static CUresult cuModuleLoad(CUmodule *module, const char *fname);
static CUresult cuMemcpyHtoD_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount); static CUresult cuModuleLoadData(CUmodule* module, const void* image);
static CUresult cuInit(unsigned int Flags); static CUresult cuModuleUnload(CUmodule hmod);
static CUresult cuEventRecord(CUevent hEvent, CUstream hStream); static CUresult cuModuleLoadDataEx(CUmodule *module, const void *image, unsigned int numOptions, CUjit_option *options, void **optionValues);
static CUresult cuCtxCreate_v2(CUcontext *pctx, unsigned int flags, CUdevice dev);
static CUresult cuCtxPushCurrent_v2(CUcontext ctx);
static CUresult cuCtxPopCurrent_v2(CUcontext *pctx);
static CUresult cuModuleGetFunction(CUfunction *hfunc, CUmodule hmod, const char *name); static CUresult cuModuleGetFunction(CUfunction *hfunc, CUmodule hmod, const char *name);
// stream management
static CUresult cuStreamCreate(CUstream *phStream, unsigned int Flags);
static CUresult cuStreamSynchronize(CUstream hStream); static CUresult cuStreamSynchronize(CUstream hStream);
static CUresult cuStreamGetCtx(CUstream hStream, CUcontext* pctx); static CUresult cuStreamGetCtx(CUstream hStream, CUcontext* pctx);
static CUresult cuStreamDestroy_v2(CUstream hStream); static CUresult cuStreamDestroy_v2(CUstream hStream);
static CUresult cuEventDestroy_v2(CUevent hEvent); static CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void **kernelParams, void **extra);
static CUresult cuMemAlloc_v2(CUdeviceptr *dptr, size_t bytesize); // function management
static CUresult cuPointerGetAttribute(void * data, CUpointer_attribute attribute, CUdeviceptr ptr);
static CUresult cuCtxGetDevice(CUdevice* result);
static CUresult cuMemsetD8Async(CUdeviceptr dst, unsigned char x, size_t N, CUstream stream);
static CUresult cuFuncGetAttribute(int* pi, CUfunction_attribute attrib, CUfunction hfunc); static CUresult cuFuncGetAttribute(int* pi, CUfunction_attribute attrib, CUfunction hfunc);
static CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value); static CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value);
static CUresult cuFuncSetCacheConfig(CUfunction hfunc, CUfunc_cache config); static CUresult cuFuncSetCacheConfig(CUfunction hfunc, CUfunc_cache config);
static CUresult cuCtxEnablePeerAccess(CUcontext peerContext, unsigned int flags); // memory management
// NVML static CUresult cuMemAlloc_v2(CUdeviceptr *dptr, size_t bytesize);
static CUresult cuPointerGetAttribute(void * data, CUpointer_attribute attribute, CUdeviceptr ptr);
static CUresult cuMemsetD8Async(CUdeviceptr dst, unsigned char x, size_t N, CUstream stream);
static CUresult cuMemcpyDtoH_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount);
static CUresult cuMemFree_v2(CUdeviceptr dptr);
static CUresult cuMemcpyDtoHAsync_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount, CUstream hStream);
static CUresult cuMemcpyHtoDAsync_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount, CUstream hStream);
static CUresult cuMemcpyHtoD_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount);
// event management
static CUresult cuEventCreate(CUevent *phEvent, unsigned int Flags);
static CUresult cuEventElapsedTime(float *pMilliseconds, CUevent hStart, CUevent hEnd);
static CUresult cuEventRecord(CUevent hEvent, CUstream hStream);
static CUresult cuEventDestroy_v2(CUevent hEvent);
/* ------------------- *
* NVML
* ------------------- */
static nvmlReturn_t nvmlDeviceGetHandleByPciBusId_v2( const char* pciBusId, nvmlDevice_t* device); static nvmlReturn_t nvmlDeviceGetHandleByPciBusId_v2( const char* pciBusId, nvmlDevice_t* device);
static nvmlReturn_t nvmlDeviceGetClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock); static nvmlReturn_t nvmlDeviceGetClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock);
static nvmlReturn_t nvmlDeviceGetMaxClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock); static nvmlReturn_t nvmlDeviceGetMaxClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock);
static nvmlReturn_t nvmlDeviceSetApplicationsClocks(nvmlDevice_t device, unsigned int mem_clock, unsigned int sm_clock); static nvmlReturn_t nvmlDeviceSetApplicationsClocks(nvmlDevice_t device, unsigned int mem_clock, unsigned int sm_clock);
// SPIR-V libraries
static int initializeLLVMToSPIRVPass(llvm::PassRegistry &);
static bool writeSpirv(llvm::Module *M, std::ostream &OS, std::string &ErrMsg);
private: private:
// Libraries // Libraries
static void* cuda_; static void* cuda_;
static void* nvml_; static void* nvml_;
static void* vulkan_;
static void* spvllvm_;
static void* spvcross_;
static void* opengl_;
// CUDA functions /* ------------------- *
* CUDA
* ------------------- */
// context management
static void* cuCtxGetCurrent_; static void* cuCtxGetCurrent_;
static void* cuCtxSetCurrent_; static void* cuCtxSetCurrent_;
static void* cuCtxDestroy_v2_; static void* cuCtxDestroy_v2_;
static void* cuEventCreate_; static void* cuCtxCreate_v2_;
static void* cuDeviceGet_; static void* cuCtxGetDevice_;
static void* cuMemcpyDtoH_v2_; static void* cuCtxPushCurrent_v2_;
static void* cuStreamCreate_; static void* cuCtxPopCurrent_v2_;
static void* cuEventElapsedTime_; static void* cuCtxEnablePeerAccess_;
static void* cuMemFree_v2_;
static void* cuMemcpyDtoHAsync_v2_;
static void* cuDriverGetVersion_; static void* cuDriverGetVersion_;
static void* cuInit_;
// device management
static void* cuDeviceGet_;
static void* cuDeviceGetName_; static void* cuDeviceGetName_;
static void* cuDeviceGetPCIBusId_; static void* cuDeviceGetPCIBusId_;
static void* cuModuleGetGlobal_v2_; static void* cuDeviceGetAttribute_;
static void* cuMemcpyHtoDAsync_v2_; static void* cuDeviceGetCount_;
static void* cuModuleLoad_; // link management
static void* cuLaunchKernel_;
static void* cuModuleUnload_;
static void* cuModuleLoadDataEx_;
static void* cuLinkAddData_v2_; static void* cuLinkAddData_v2_;
static void* cuLinkCreate_v2_; static void* cuLinkCreate_v2_;
static void* cuLinkDestroy_; static void* cuLinkDestroy_;
static void* cuModuleLoadData_;
static void* cuLinkComplete_; static void* cuLinkComplete_;
static void* cuDeviceGetAttribute_; // module management
static void* cuDeviceGetCount_; static void* cuModuleGetGlobal_v2_;
static void* cuMemcpyHtoD_v2_; static void* cuModuleLoad_;
static void* cuInit_; static void* cuModuleUnload_;
static void* cuEventRecord_; static void* cuModuleLoadDataEx_;
static void* cuCtxCreate_v2_; static void* cuModuleLoadData_;
static void* cuModuleGetFunction_; static void* cuModuleGetFunction_;
// stream management
static void* cuStreamCreate_;
static void* cuStreamSynchronize_; static void* cuStreamSynchronize_;
static void* cuStreamDestroy_v2_; static void* cuStreamDestroy_v2_;
static void* cuStreamGetCtx_; static void* cuStreamGetCtx_;
static void* cuEventDestroy_v2_; static void* cuLaunchKernel_;
static void* cuMemAlloc_v2_; // function management
static void* cuPointerGetAttribute_;
static void* cuCtxGetDevice_;
static void* cuMemsetD8Async_;
static void* cuCtxPushCurrent_v2_;
static void* cuCtxPopCurrent_v2_;
static void* cuFuncGetAttribute_; static void* cuFuncGetAttribute_;
static void* cuFuncSetAttribute_; static void* cuFuncSetAttribute_;
static void* cuFuncSetCacheConfig_; static void* cuFuncSetCacheConfig_;
static void* cuCtxEnablePeerAccess_; // memory management
// NVML static void* cuMemcpyDtoH_v2_;
static void* cuMemFree_v2_;
static void* cuMemcpyDtoHAsync_v2_;
static void* cuMemcpyHtoDAsync_v2_;
static void* cuMemcpyHtoD_v2_;
static void* cuMemAlloc_v2_;
static void* cuMemsetD8Async_;
static void* cuPointerGetAttribute_;
// event management
static void* cuEventCreate_;
static void* cuEventElapsedTime_;
static void* cuEventRecord_;
static void* cuEventDestroy_v2_;
/* ------------------- *
* NVML
* ------------------- */
static void* nvmlInit_v2_; static void* nvmlInit_v2_;
static void* nvmlDeviceGetHandleByPciBusId_v2_; static void* nvmlDeviceGetHandleByPciBusId_v2_;
static void* nvmlDeviceGetClockInfo_; static void* nvmlDeviceGetClockInfo_;
static void* nvmlDeviceGetMaxClockInfo_; static void* nvmlDeviceGetMaxClockInfo_;
static void* nvmlDeviceSetApplicationsClocks_; static void* nvmlDeviceSetApplicationsClocks_;
// LLVM to SPIR-V
static void* initializeLLVMToSPIRVPass_;
static void* writeSpirv_;
}; };
} }

View File

@@ -153,6 +153,9 @@ enum value_id_t: unsigned {
// intrinsics // intrinsics
INST_COPY_TO_SHARED, INST_COPY_TO_SHARED,
INST_COPY_FROM_SHARED, INST_COPY_FROM_SHARED,
INST_CVT_LAYOUT,
INST_CVT_SCANLINE,
INST_DECOALESCE,
INST_RECOALESCE, INST_RECOALESCE,
INST_BARRIER, INST_BARRIER,
INST_ASYNC_WAIT, INST_ASYNC_WAIT,

View File

@@ -807,16 +807,15 @@ public:
_TRITON_DEFINE_ACCEPT(copy_from_shared_inst) _TRITON_DEFINE_ACCEPT(copy_from_shared_inst)
}; };
class cvt_layout_inst: public unary_inst {
class recoalesce_inst: public unary_inst{
private: private:
using unary_inst::unary_inst; using unary_inst::unary_inst;
std::string repr_impl() const { return "recoalesce_inst"; } std::string repr_impl() const { return "cvt_layout_inst"; }
public: public:
static recoalesce_inst* create(value *arg, const std::string &name = "", instruction *next = nullptr); static cvt_layout_inst* create(value *arg, const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(recoalesce_inst) _TRITON_DEFINE_CLONE(cvt_layout_inst)
_TRITON_DEFINE_ACCEPT(recoalesce_inst) _TRITON_DEFINE_ACCEPT(cvt_layout_inst)
}; };
class barrier_inst: public instruction{ class barrier_inst: public instruction{

View File

@@ -64,7 +64,7 @@ class sqrt_inst;
class reduce_inst; class reduce_inst;
class select_inst; class select_inst;
class recoalesce_inst; class cvt_layout_inst;
class copy_to_shared_inst; class copy_to_shared_inst;
class copy_from_shared_inst; class copy_from_shared_inst;
class masked_load_async_inst; class masked_load_async_inst;
@@ -142,9 +142,11 @@ public:
virtual void visit_reduce_inst(reduce_inst*) = 0; virtual void visit_reduce_inst(reduce_inst*) = 0;
virtual void visit_select_inst(select_inst*) = 0; virtual void visit_select_inst(select_inst*) = 0;
virtual void visit_recoalesce_inst(recoalesce_inst*) = 0; virtual void visit_cvt_layout_inst(cvt_layout_inst*) = 0;
virtual void visit_copy_to_shared_inst(copy_to_shared_inst*) = 0; virtual void visit_copy_to_shared_inst(copy_to_shared_inst*) = 0;
virtual void visit_copy_from_shared_inst(copy_from_shared_inst*) = 0; virtual void visit_copy_from_shared_inst(copy_from_shared_inst*) = 0;
virtual void visit_masked_load_async_inst(masked_load_async_inst*)= 0; virtual void visit_masked_load_async_inst(masked_load_async_inst*)= 0;
virtual void visit_barrier_inst(barrier_inst*) = 0; virtual void visit_barrier_inst(barrier_inst*) = 0;
virtual void visit_async_wait_inst(async_wait_inst*) = 0; virtual void visit_async_wait_inst(async_wait_inst*) = 0;

View File

@@ -6,6 +6,7 @@
#include <map> #include <map>
#include <set> #include <set>
#include <vector> #include <vector>
#include <iostream>
namespace triton { namespace triton {
namespace tools{ namespace tools{
@@ -40,8 +41,9 @@ public:
nmap->clear(); nmap->clear();
std::set<node_t> nodes = nodes_; std::set<node_t> nodes = nodes_;
unsigned id = 0; unsigned id = 0;
while(!nodes.empty()) while(!nodes.empty()){
connected_components_impl(*nodes.begin(), nodes, nmap, cmap, id++); connected_components_impl(*nodes.begin(), nodes, nmap, cmap, id++);
}
} }
void add_edge(node_t x, node_t y) { void add_edge(node_t x, node_t y) {

View File

@@ -50,7 +50,6 @@ void allocation::run(ir::module &mod) {
J.erase(j_it); J.erase(j_it);
} }
} }
// Build interference graph // Build interference graph
std::map<shared_layout*, std::set<shared_layout*>> interferences; std::map<shared_layout*, std::set<shared_layout*>> interferences;
for(shared_layout* x: V) for(shared_layout* x: V)
@@ -66,13 +65,10 @@ void allocation::run(ir::module &mod) {
&& XS.intersect(YS)) && XS.intersect(YS))
interferences[x].insert(y); interferences[x].insert(y);
} }
// Initialize colors // Initialize colors
std::map<shared_layout*, int> colors; std::map<shared_layout*, int> colors;
for(shared_layout* X: V) for(shared_layout* X: V)
colors[X] = (X==V[0])?0:-1; colors[X] = (X==V[0])?0:-1;
// First-fit graph coloring // First-fit graph coloring
std::vector<bool> available(V.size()); std::vector<bool> available(V.size());
for(shared_layout* x: V){ for(shared_layout* x: V){
@@ -87,7 +83,6 @@ void allocation::run(ir::module &mod) {
auto It = std::find(available.begin(), available.end(), true); auto It = std::find(available.begin(), available.end(), true);
colors[x] = std::distance(available.begin(), It); colors[x] = std::distance(available.begin(), It);
} }
// Finalize allocation // Finalize allocation
for(shared_layout* x: V){ for(shared_layout* x: V){
unsigned Adj = 0; unsigned Adj = 0;
@@ -95,7 +90,6 @@ void allocation::run(ir::module &mod) {
Adj = std::max<unsigned>(Adj, starts[y] + y->get_size()); Adj = std::max<unsigned>(Adj, starts[y] + y->get_size());
offsets_[x] = starts[x] + colors[x] * Adj; offsets_[x] = starts[x] + colors[x] * Adj;
} }
// Save maximum size of induced memory space // Save maximum size of induced memory space
allocated_size_ = 0; allocated_size_ = 0;
for(shared_layout* x: V) for(shared_layout* x: V)

View File

@@ -105,17 +105,17 @@ void axes::update_graph_no_edge(ir::instruction *i) {
void axes::update_graph(ir::instruction *i) { void axes::update_graph(ir::instruction *i) {
switch (i->get_id()) { switch (i->get_id()) {
case ir::INST_REDUCE: return update_graph_reduce(i); case ir::INST_REDUCE: return update_graph_reduce(i);
case ir::INST_RESHAPE: return update_graph_reshape(i); case ir::INST_RESHAPE: return update_graph_reshape(i);
case ir::INST_SPLAT: return update_graph_no_edge(i);; case ir::INST_SPLAT: return update_graph_no_edge(i);;
case ir::INST_TRANS: return update_graph_trans(i); case ir::INST_TRANS: return update_graph_trans(i);
case ir::INST_BROADCAST: return update_graph_broadcast(i); case ir::INST_BROADCAST: return update_graph_broadcast(i);
case ir::INST_DOT: return update_graph_dot(i); case ir::INST_DOT: return update_graph_dot(i);
case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i); case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i);
case ir::INST_MASKED_LOAD_ASYNC:return update_graph_elementwise(i, false); case ir::INST_MASKED_LOAD_ASYNC: return update_graph_elementwise(i, false);
case ir::INST_COPY_FROM_SHARED: return update_graph_no_edge(i); case ir::INST_COPY_FROM_SHARED: return update_graph_no_edge(i);
case ir::INST_RECOALESCE: return update_graph_no_edge(i); case ir::INST_CVT_LAYOUT: return update_graph_no_edge(i);
default: return update_graph_elementwise(i); default: return update_graph_elementwise(i);
} }
return; return;
} }
@@ -135,11 +135,15 @@ std::vector<int> axes::get(ir::value *value) {
void axes::run(ir::module &mod) { void axes::run(ir::module &mod) {
// make graph // make graph
graph_.clear(); graph_.clear();
axes_.clear();
ir::for_each_instruction(mod, [this](ir::instruction *x) { ir::for_each_instruction(mod, [this](ir::instruction *x) {
update_graph(x); update_graph(x);
}); });
// find connected components // find connected components
graph_.connected_components(nullptr, &axes_); graph_.connected_components(nullptr, &axes_);
std::set<size_t> uniq;
for(auto x: axes_)
uniq.insert(x.second);
} }
} }

View File

@@ -109,9 +109,6 @@ data_layout::data_layout(id_t id,
max_contiguous = curr; max_contiguous = curr;
} }
} }
bool is_recoalesce = false;
for(ir::value* v: values)
is_recoalesce = is_recoalesce || dynamic_cast<ir::recoalesce_inst*>(v);
if(max_contiguous.size() > 0){ if(max_contiguous.size() > 0){
std::sort(order_.begin(), order_.end(), [&](unsigned a, unsigned b) { std::sort(order_.begin(), order_.end(), [&](unsigned a, unsigned b) {
return max_contiguous[a] > max_contiguous[b]; return max_contiguous[a] > max_contiguous[b];
@@ -129,6 +126,13 @@ int data_layout::find_axis(int to_find) const {
} }
// Forwards every argument unchanged to the data_layout base constructor.
// The body is intentionally empty: shape_per_cta_ is left default-constructed
// here and is assigned by the derived constructors (mma_layout and
// scanline_layout both write shape_per_cta_ at the end of their own bodies).
distributed_layout::distributed_layout(id_t id,
const std::vector<int> &axes,
const std::vector<unsigned> &shape,
const std::vector<ir::value *> &values,
analysis::align* align): data_layout(id, axes, shape, values, align)
{ }
/* -------------------------------- * /* -------------------------------- *
* MMA Layout * * MMA Layout *
* -------------------------------- */ * -------------------------------- */
@@ -138,20 +142,11 @@ mma_layout::mma_layout(size_t num_warps,
const std::vector<unsigned>& shape, const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values, const std::vector<ir::value *> &values,
analysis::align* align, target* tgt, analysis::align* align, target* tgt,
shared_layout *layout_a, shared_layout *layout_b): data_layout(MMA, axes, shape, values, align) { shared_layout *layout_a, shared_layout *layout_b): distributed_layout(MMA, axes, shape, values, align) {
/* fragments per warp */ /* fragments per warp */
// try to make things as square as possible to maximize data re-use // try to make things as square as possible to maximize data re-use
if(tgt->as_nvidia()->sm() < 80){ if(tgt->as_nvidia()->sm() < 80){
fpw_ = {2, 2, 1}; fpw_ = {2, 2, 1};
// std::vector<int> fpw_nm1;
// unsigned num_fragments = std::min<unsigned>((shape_[0]/8)*(shape_[1]/8), 4);
// do {
// fpw_nm1 = fpw_;
// if(fpw_[0]*fpw_[1] < num_fragments)
// fpw_[0] = clamp(fpw_[0]*2, 1, shape_[0] / 8);
// if(fpw_[0]*fpw_[1] < num_fragments)
// fpw_[1] = clamp(fpw_[1]*2, 1, shape_[1] / 8);
// }while(fpw_nm1 != fpw_);
auto ord_a = layout_a->get_order(); auto ord_a = layout_a->get_order();
auto ord_b = layout_b->get_order(); auto ord_b = layout_b->get_order();
bool is_a_row = ord_a[0] != 0; bool is_a_row = ord_a[0] != 0;
@@ -168,6 +163,7 @@ mma_layout::mma_layout(size_t num_warps,
spw_ = {16, 8, 1}; spw_ = {16, 8, 1};
rep_ = {2, 2, 1}; rep_ = {2, 2, 1};
} }
order_ = {0, 1};
/* warps per tile */ /* warps per tile */
// try to make things as square as possible to maximize data re-use // try to make things as square as possible to maximize data re-use
@@ -182,7 +178,7 @@ mma_layout::mma_layout(size_t num_warps,
}while(wpt_nm1 != wpt_); }while(wpt_nm1 != wpt_);
/* shape per block */ /* shape per block */
spt_ = {spw_[0]*wpt_[0], spw_[1]*wpt_[1], 1}; shape_per_cta_ = {spw_[0]*wpt_[0], spw_[1]*wpt_[1], 1};
} }
@@ -194,7 +190,7 @@ scanline_layout::scanline_layout(size_t num_warps,
const std::vector<int>& axes, const std::vector<int>& axes,
const std::vector<unsigned>& shape, const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values, const std::vector<ir::value *> &values,
analysis::align* align, target *tgt): data_layout(SCANLINE, axes, shape, values, align){ analysis::align* align, target *tgt): distributed_layout(SCANLINE, axes, shape, values, align){
unsigned size = std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies<int>()); unsigned size = std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies<int>());
unsigned num_threads = tgt->is_gpu() ? num_warps * 32 : 1; unsigned num_threads = tgt->is_gpu() ? num_warps * 32 : 1;
nts_.resize(shape_.size()); nts_.resize(shape_.size());
@@ -230,6 +226,10 @@ scanline_layout::scanline_layout(size_t num_warps,
mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]); mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
num_threads = num_threads / mts_[i]; num_threads = num_threads / mts_[i];
} }
shape_per_cta_.resize(shape_.size());
for(size_t d = 0; d < shape_.size(); d++)
shape_per_cta_[d] = mts_[d]*nts_[d];
} }
@@ -489,6 +489,9 @@ void layouts::create(size_t id, const std::vector<ir::value*>& values) {
void layouts::run(ir::module &mod) { void layouts::run(ir::module &mod) {
// make graph // make graph
graph_.clear(); graph_.clear();
layouts_.clear();
groups_.clear();
ir::for_each_instruction(mod, [this](ir::instruction* i) { ir::for_each_instruction(mod, [this](ir::instruction* i) {
make_graph(i); make_graph(i);
}); });
@@ -515,23 +518,18 @@ void layouts::run(ir::module &mod) {
layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_); layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_);
tmp_[red] = id; tmp_[red] = id;
} }
if(auto *recoalasce = dynamic_cast<ir::recoalesce_inst*>(i)){ if(auto *val = dynamic_cast<ir::cvt_layout_inst*>(i)){
ir::value *val = recoalasce->get_operand(0); distributed_layout* out_layout = dynamic_cast<distributed_layout*>(get(val));
mma_layout* in_layout = get(val)->to_mma(); distributed_layout* in_layout = dynamic_cast<distributed_layout*>(get(i->get_operand(0)));
scanline_layout* out_layout = get(i)->to_scanline();
if(!in_layout || !out_layout)
return;
id++; id++;
ir::type::block_shapes_t in_shape = val->get_type()->get_block_shapes(); size_t dim = val->get_type()->get_tile_rank();
ir::type::block_shapes_t shape(in_shape.size()); ir::type::block_shapes_t shape(dim);
size_t ld = out_layout->get_order(0); for(size_t k = 0; k < dim; k++){
shape[ld] = in_shape[ld]; shape[k] = std::max(in_layout->shape_per_cta(k),
for(size_t k = 0; k < in_shape.size(); k++) out_layout->shape_per_cta(k));
if(k != ld) }
shape[k] = in_layout->to_mma()->spt(k); layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {val}, val->get_type()->get_scalar_ty(), align_);
// create layout tmp_[val] = id;
layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {recoalasce}, val->get_type()->get_scalar_ty(), align_);
tmp_[recoalasce] = id;
} }
if(auto *atom = dynamic_cast<ir::atomic_inst*>(i)){ if(auto *atom = dynamic_cast<ir::atomic_inst*>(i)){
id++; id++;

View File

@@ -56,10 +56,8 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
dce.run(ir); dce.run(ir);
peephole.run(ir); peephole.run(ir);
dce.run(ir); dce.run(ir);
// ir::print(ir, std::cout);
pipeline.run(ir); pipeline.run(ir);
dce.run(ir); dce.run(ir);
// ir::print(ir, std::cout);
disassociate.run(ir); disassociate.run(ir);
dce.run(ir); dce.run(ir);
align.run(ir); align.run(ir);
@@ -74,14 +72,15 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
layouts.run(ir); layouts.run(ir);
coalesce.run(ir); coalesce.run(ir);
dce.run(ir); dce.run(ir);
// exit(1);
align.run(ir); align.run(ir);
dce.run(ir); dce.run(ir);
if (target->is_gpu()) { if (target->is_gpu())
// reassociate.run(ir);
cts.run(ir); cts.run(ir);
}
dce.run(ir); dce.run(ir);
align.run(ir); align.run(ir);
// ir::print(ir, std::cout);
axes.run(ir); axes.run(ir);
layouts.run(ir); layouts.run(ir);
peephole.run(ir); peephole.run(ir);
@@ -93,10 +92,7 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
liveness.run(ir); liveness.run(ir);
allocation.run(ir); allocation.run(ir);
prefetch_s.run(ir); prefetch_s.run(ir);
// ir::print(ir, std::cout);
barriers.run(ir); barriers.run(ir);
// ir::print(ir, std::cout);
// ir::print(ir, std::cout);
isel.visit(ir, *llvm); isel.visit(ir, *llvm);
mod = driver::module::create(dev, std::move(llvm)); mod = driver::module::create(dev, std::move(llvm));
ker = driver::kernel::create(&*mod, name.c_str()); ker = driver::kernel::create(&*mod, name.c_str());

View File

@@ -586,7 +586,7 @@ void generator::visit_load_inst(ir::load_inst* x){
Type* ty = cvt(op->get_type()->get_scalar_ty()->get_pointer_element_ty()); Type* ty = cvt(op->get_type()->get_scalar_ty()->get_pointer_element_ty());
// compute vector width // compute vector width
size_t vec = 1; size_t vec = 1;
if(op->get_type()->is_block_ty()){ if(op->get_type()->is_block_ty() && op->get_type()->get_tile_rank() > 1){
auto ord = ords_.at(op); auto ord = ords_.at(op);
size_t aln = alignment_->get(op, ord[0]); size_t aln = alignment_->get(op, ord[0]);
size_t nts = layouts_->get(x)->to_scanline()->nts(ord[0]); size_t nts = layouts_->get(x)->to_scanline()->nts(ord[0]);
@@ -626,10 +626,10 @@ void generator::visit_load_inst(ir::load_inst* x){
// ----- // -----
std::ostringstream asm_oss; std::ostringstream asm_oss;
asm_oss << "@$" << n_words; // predicate asm_oss << "@$" << n_words; // predicate
if(force_nc_cache_) // if(force_nc_cache_)
asm_oss << " ld.global.nc"; asm_oss << " ld.global";
else // else
asm_oss << " ld.global.cg"; // asm_oss << " ld.global.cg";
if(n_words > 1) if(n_words > 1)
asm_oss << ".v" << n_words; // vector width asm_oss << ".v" << n_words; // vector width
asm_oss << ".b" << width; // word size asm_oss << ".b" << width; // word size
@@ -1058,7 +1058,8 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
/* --------------------------------- */ /* --------------------------------- */
BasicBlock* curr_bb = builder_->GetInsertBlock(); BasicBlock* curr_bb = builder_->GetInsertBlock();
BasicBlock* entry = &curr_bb->getParent()->getEntryBlock(); BasicBlock* entry = &curr_bb->getParent()->getEntryBlock();
builder_->SetInsertPoint(entry->getTerminator()); if(entry != curr_bb)
builder_->SetInsertPoint(entry->getTerminator());
Value* off_a0 = is_a_row ? offset_a_k_[layout_c] : offset_a_m_[layout_c]; Value* off_a0 = is_a_row ? offset_a_k_[layout_c] : offset_a_m_[layout_c];
Value* off_a1 = is_a_row ? offset_a_m_[layout_c] : offset_a_k_[layout_c]; Value* off_a1 = is_a_row ? offset_a_m_[layout_c] : offset_a_k_[layout_c];
Value* phase_a = urem(udiv(off_a1, i32(per_phase_a)), i32(max_phase_a)); Value* phase_a = urem(udiv(off_a1, i32(per_phase_a)), i32(max_phase_a));
@@ -1116,8 +1117,8 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
for(indices_t idx: idxs_.at(C)) for(indices_t idx: idxs_.at(C))
acc.push_back(vals_[D][idx]); acc.push_back(vals_[D][idx]);
unsigned num_m = layout_c->rep(0) * shape_c[0] / layout_c->spt(0); unsigned num_m = layout_c->rep(0) * shape_c[0] / layout_c->shape_per_cta(0);
unsigned num_n = layout_c->rep(1) * shape_c[1] / layout_c->spt(1); unsigned num_n = layout_c->rep(1) * shape_c[1] / layout_c->shape_per_cta(1);
// create mma & unpack result // create mma & unpack result
auto call_mma = [&](unsigned m, unsigned n, unsigned K) { auto call_mma = [&](unsigned m, unsigned n, unsigned K) {
@@ -1333,7 +1334,8 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
BasicBlock* CurrBB = builder_->GetInsertBlock(); BasicBlock* CurrBB = builder_->GetInsertBlock();
BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock(); BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock();
builder_->SetInsertPoint(FirstBB->getTerminator()); if(FirstBB != CurrBB)
builder_->SetInsertPoint(FirstBB->getTerminator());
Value* thread = tgt_->get_local_id(mod_, *builder_, 0); Value* thread = tgt_->get_local_id(mod_, *builder_, 0);
Value *lane = urem(thread, i32(32)); Value *lane = urem(thread, i32(32));
@@ -1396,8 +1398,8 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
"{$10, $11, $12, $13};", "{$10, $11, $12, $13};",
"=f,=f,=f,=f,r,r,r,r,r,r,0,1,2,3", true); "=f,=f,=f,=f,r,r,r,r,r,r,0,1,2,3", true);
unsigned num_rep_0 = shapes[0] / layout->spt(0); unsigned num_rep_0 = shapes[0] / layout->shape_per_cta(0);
unsigned num_rep_1 = shapes[1] / layout->spt(1); unsigned num_rep_1 = shapes[1] / layout->shape_per_cta(1);
// create mma & unpack result // create mma & unpack result
auto call_mma = [&](unsigned m, unsigned n, unsigned K) { auto call_mma = [&](unsigned m, unsigned n, unsigned K) {
@@ -1626,8 +1628,8 @@ void generator::visit_fmadot(ir::dot_inst* C, ir::value* A, ir::value* B, ir::va
std::map<std::pair<int, int>, Value*> has, hbs; std::map<std::pair<int, int>, Value*> has, hbs;
for(unsigned k = 0; k < NK; k++){ for(unsigned k = 0; k < NK; k++){
int z = 0; int z = 0;
for(unsigned m = 0; m < shape_c[0]; m+=layout_c->mts(0)*layout_c->nts(0)) for(unsigned m = 0; m < shape_c[0]; m += layout_c->shape_per_cta(0))
for(unsigned n = 0; n < shape_c[1]; n+=layout_c->mts(1)*layout_c->nts(1)) for(unsigned n = 0; n < shape_c[1]; n += layout_c->shape_per_cta(1))
for(unsigned mm = 0; mm < layout_c->nts(0); mm++) for(unsigned mm = 0; mm < layout_c->nts(0); mm++)
for(unsigned nn = 0; nn < layout_c->nts(1); nn++) for(unsigned nn = 0; nn < layout_c->nts(1); nn++)
{ {
@@ -1818,6 +1820,7 @@ void generator::visit_reducend_inst(ir::reduce_inst* x, std::function<Value*(Val
add_barrier(); add_barrier();
// update accumulator // update accumulator
acc = do_acc(acc, load(read_ptr)); acc = do_acc(acc, load(read_ptr));
add_barrier();
store(acc, write_ptr); store(acc, write_ptr);
} }
} }
@@ -1884,54 +1887,74 @@ void generator::visit_select_inst(ir::select_inst* x) {
} }
} }
/**
* \brief Code Generation for `recoalesce`
*/ void generator::visit_layout_convert(ir::value *out, ir::value *in){
void generator::visit_recoalesce_inst(ir::recoalesce_inst* rc) { ir::block_type::block_shapes_t shape = out->get_type()->get_block_shapes();
ir::value *op = rc->get_operand(0);
ir::block_type::block_shapes_t shape = rc->get_type()->get_block_shapes();
// pointer to temporary shared memory // pointer to temporary shared memory
Type *ty = cvt(rc->get_type()->get_scalar_ty()); Type *ty = cvt(out->get_type()->get_scalar_ty());
// layout
analysis::mma_layout* in_layout = layouts_->get(op)->to_mma();
analysis::scanline_layout* out_layout = layouts_->get(rc)->to_scanline();
// Orders // Orders
auto ord = layouts_->get(rc)->to_scanline()->get_order(); analysis::distributed_layout* in_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(in));
analysis::distributed_layout* out_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(out));
auto in_ord = in_layout->get_order();
auto out_ord = out_layout->get_order();
Value *base; Value *base;
base = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(rc))))); base = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(out)))));
base = bit_cast(base, ptr_ty(ty, 3)); base = bit_cast(base, ptr_ty(ty, 3));
Value *ld = i32(shape[ord[0]]); std::vector<int> n_reps;
auto in_ord0 = axes_.at(a_axes_->get(op, ord[0])).values; for(int i = 0; i < shape.size(); i++){
auto in_ord1 = axes_.at(a_axes_->get(op, ord[1])).values; int in_per_cta = in_layout->shape_per_cta(i);
auto out_ord0 = axes_.at(a_axes_->get(rc, ord[0])).values; int out_per_cta = out_layout->shape_per_cta(i);
auto out_ord1 = axes_.at(a_axes_->get(rc, ord[1])).values; int max_per_cta = std::max(in_per_cta, out_per_cta);
int in_spt0 = in_layout->spt(ord[0]); n_reps.push_back(shape[i]/max_per_cta);
int in_spt1 = in_layout->spt(ord[1]);
int out_spt0 = out_layout->mts(ord[0])*out_layout->nts(ord[0]);
int out_spt1 = out_layout->mts(ord[1])*out_layout->nts(ord[1]);
int max_spt1 = std::max(in_spt1, out_spt1);
indices_t idx(2);
int num_packs = shape[ord[1]]/max_spt1;
for(size_t j = 0; j < num_packs; j++){
add_barrier();
for(size_t k = 0; k < in_ord1.size()/num_packs; k++)
for(size_t i = 0; i < in_ord0.size(); i++){
idx[ord[0]] = in_ord0[i];
idx[ord[1]] = in_ord1[j*in_ord1.size()/num_packs + k];
Value *off = add(idx[ord[0]], mul(in_ord1[k], ld));
Value *ptr = gep(base, off);
store(vals_[op][idx], ptr);
}
add_barrier();
for(size_t k = 0; k < out_ord1.size()/num_packs; k++)
for(size_t i = 0; i < out_ord0.size(); i++){
idx[ord[0]] = out_ord0[i];
idx[ord[1]] = out_ord1[j*out_ord1.size()/num_packs + k];
Value *off = add(idx[ord[0]], mul(out_ord1[k], ld));
Value *ptr = gep(base, off);
vals_[rc][idx] = load(ptr);
}
} }
std::vector<std::vector<Value*>> in_ax;
std::vector<std::vector<Value*>> out_ax;
for(int d = 0; d < shape.size(); d++){
in_ax.push_back(axes_.at(a_axes_->get(in, d)).values);
out_ax.push_back(axes_.at(a_axes_->get(out, d)).values);
}
in_ord = in_layout->to_mma() ? out_ord : in_ord;
out_ord = out_layout->to_mma() ? in_ord : out_ord;
Value *in_ld = i32(shape[in_ord[0]]);
Value *out_ld = i32(shape[out_ord[0]]);
for(int i = 0; i < n_reps[0]; i++)
for(int j = 0; j < n_reps[1]; j++){
int max_ii, max_jj;
add_barrier();
max_ii = in_ax[0].size()/n_reps[0];
max_jj = in_ax[1].size()/n_reps[1];
for(int ii = 0; ii < max_ii; ii++)
for(int jj = 0; jj < max_jj; jj++){
// shared mem pointer
indices_t offs = {in_ax[0][ii], in_ax[1][jj]};
Value *off = add(offs[out_ord[0]], mul(out_ld, offs[out_ord[1]]));
Value *ptr = gep(base, off);
// stash value to shared mem
indices_t idxs = {in_ax[0][i*max_ii + ii],
in_ax[1][j*max_jj + jj]};
store(vals_[in][idxs], ptr);
}
add_barrier();
max_ii = out_ax[0].size()/n_reps[0];
max_jj = out_ax[1].size()/n_reps[1];
for(int ii = 0; ii < max_ii; ii++)
for(int jj = 0; jj < max_jj; jj++){
// shared mem pointer
indices_t offs = {out_ax[0][ii], out_ax[1][jj]};
Value *off = add(offs[out_ord[0]], mul(out_ld, offs[out_ord[1]]));
Value *ptr = gep(base, off);
// load value from shared mem
indices_t idxs = {out_ax[0][i*max_ii + ii],
out_ax[1][j*max_jj + jj]};
vals_[out][idxs] = load(ptr);
}
}
}
void generator::visit_cvt_layout_inst(ir::cvt_layout_inst *rc) {
visit_layout_convert(rc, rc->get_operand(0));
} }
void generator::visit_masked_load_async_inst(ir::masked_load_async_inst* x){ void generator::visit_masked_load_async_inst(ir::masked_load_async_inst* x){
@@ -2325,12 +2348,12 @@ void generator::visit_layout_mma(analysis::mma_layout* layout) {
offset_b_k_[layout] = and_(lane, _3); offset_b_k_[layout] = and_(lane, _3);
// i indices // i indices
Value *offset_c_m = add(and_(lane, _1), offset_a_m_[layout]); Value *offset_c_m = add(and_(lane, _1), offset_a_m_[layout]);
for(unsigned m = 0; m < shape[0]; m+=layout->spt(0)) for(unsigned m = 0; m < shape[0]; m+=layout->shape_per_cta(0))
for(unsigned mm = 0; mm < layout->rep(0); mm++) for(unsigned mm = 0; mm < layout->rep(0); mm++)
idx_m.push_back(add(offset_c_m, i32(m + mm*2))); idx_m.push_back(add(offset_c_m, i32(m + mm*2)));
// j indices // j indices
Value *offset_c_n = add(and_(lane, _2), add(off_warp_n, off_pair_n)); Value *offset_c_n = add(and_(lane, _2), add(off_warp_n, off_pair_n));
for(unsigned n = 0; n < shape[1]; n+=layout->spt(1)) for(unsigned n = 0; n < shape[1]; n+=layout->shape_per_cta(1))
for(unsigned nn = 0; nn < layout->rep(1); nn++){ for(unsigned nn = 0; nn < layout->rep(1); nn++){
idx_n.push_back(add(offset_c_n, i32(n + nn/2*4 + (nn%2)*2*layout->fpw(1)*layout->rep(1)))); idx_n.push_back(add(offset_c_n, i32(n + nn/2*4 + (nn%2)*2*layout->fpw(1)*layout->rep(1))));
idx_n.push_back(add(offset_c_n, i32(n + nn/2*4 + (nn%2)*2*layout->fpw(1)*layout->rep(1) + 1))); idx_n.push_back(add(offset_c_n, i32(n + nn/2*4 + (nn%2)*2*layout->fpw(1)*layout->rep(1) + 1)));
@@ -2366,11 +2389,11 @@ void generator::visit_layout_mma(analysis::mma_layout* layout) {
// c offset // c offset
Value *off_c_m = add(udiv(lane, _4), off_warp_m); Value *off_c_m = add(udiv(lane, _4), off_warp_m);
Value *off_c_n = add(mul(_2, urem(lane, _4)), off_warp_n); Value *off_c_n = add(mul(_2, urem(lane, _4)), off_warp_n);
for(unsigned m = 0; m < shape[0]; m+=layout->spt(0)){ for(unsigned m = 0; m < shape[0]; m+=layout->shape_per_cta(0)){
idx_m.push_back(add(off_c_m, i32(m))); idx_m.push_back(add(off_c_m, i32(m)));
idx_m.push_back(add(off_c_m, i32(m + 8))); idx_m.push_back(add(off_c_m, i32(m + 8)));
} }
for(unsigned n = 0; n < shape[1]; n+=layout->spt(1)){ for(unsigned n = 0; n < shape[1]; n+=layout->shape_per_cta(1)){
idx_n.push_back(add(off_c_n, i32(n))); idx_n.push_back(add(off_c_n, i32(n)));
idx_n.push_back(add(off_c_n, i32(n + 1))); idx_n.push_back(add(off_c_n, i32(n + 1)));
} }
@@ -2406,11 +2429,11 @@ void generator::visit_layout_scanline(analysis::scanline_layout* layout) {
std::string str_k = std::to_string(k); std::string str_k = std::to_string(k);
Value *contiguous_k = i32(nts); Value *contiguous_k = i32(nts);
Value *scaled_thread_id = mul(thread_id[k], contiguous_k); Value *scaled_thread_id = mul(thread_id[k], contiguous_k);
unsigned per_block = nts * mts; unsigned per_cta = layout->shape_per_cta(k);
unsigned per_thread = nts * shape[k] / per_block; unsigned per_thread = nts * shape[k] / per_cta;
std::vector<Value*> idx_list(per_thread); std::vector<Value*> idx_list(per_thread);
for(unsigned n = 0 ; n < per_thread; n++){ for(unsigned n = 0 ; n < per_thread; n++){
unsigned offset = n / nts * per_block + n % nts; unsigned offset = n / nts * per_cta + n % nts;
idx_list[n] = add(scaled_thread_id, i32(offset), "idx_" + str_k + "_" + std::to_string(n)); idx_list[n] = add(scaled_thread_id, i32(offset), "idx_" + str_k + "_" + std::to_string(n));
} }
axes_[layout->get_axis(k)] = distributed_axis{nts, idx_list, thread_id[k]}; axes_[layout->get_axis(k)] = distributed_axis{nts, idx_list, thread_id[k]};

View File

@@ -15,128 +15,109 @@ namespace transform{
coalesce::coalesce(analysis::align* align, analysis::layouts *layouts) coalesce::coalesce(analysis::align* align, analysis::layouts *layouts)
: align_(align), layout_(layouts) { } : align_(align), layout_(layouts) { }
// Find all values that are used as pointer operands in LD/ST
void coalesce::extract_io_use(ir::value *v, std::set<ir::io_inst*>& result) {
for(ir::user* u: v->get_users()){
auto i = dynamic_cast<ir::io_inst*>(u);
if(i && i->get_pointer_operand() == v)
result.insert(i);
}
}
void coalesce::extract_ld(ir::io_inst* i, std::map<int, std::vector<ir::io_inst*>>& result) { // simplify layout conversions using the following simple rules:
ir::value *ptr = i->get_pointer_operand(); // - cvt_1(cvt_2(x)) if convert1 is the inverse of convert2
auto contiguous = align_->contiguous(ptr); // - cvt_1(elementwise(x, y)) = elementwise(convert(x), convert(y))
auto it = std::max_element(contiguous.begin(), contiguous.end()); //ir::value* coalesce::simplify(ir::instruction *inst, ir::builder& builder){
int axis = std::distance(contiguous.begin(), it); // ir::value* _op = inst->get_operand(0);
result[axis].push_back(i); // ir::instruction* op = dynamic_cast<ir::instruction*>(_op);
} // analysis::mma_layout* mma_in = layout_->get(op) ->to_mma();
// analysis::mma_layout* mma_out = layout_->get(inst)->to_mma();
ir::value* coalesce::rematerialize(ir::value *x, ir::builder &builder, // std::cout << 1 << std::endl;
std::map<ir::value*, ir::value*>& seen) { // // i must be layout conversion instruction
if(seen.find(x) != seen.end()) // if(!mma_in && !mma_out)
return seen.at(x); // return inst;
auto i = dynamic_cast<ir::instruction*>(x); // // - cvt_1(cvt_2(x)) if convert1 is the inverse of convert2
// not an instruction -- forward value // bool is_op_cvt = op->get_id() == ir::INST_CVT_LAYOUT;
if(!i) // if((mma_in || mma_out) && is_op_cvt &&
return x; // (layout_->get(inst) == layout_->get(op->get_operand(0))))
// already in shared memory -- forward value // return op->get_operand(0);
if(dynamic_cast<ir::copy_to_shared_inst*>(x)){ // // - cvt_1(elementwise(x, y)) = elementwise(cvt_1(x), cvt_2(y))
return x; // if(op->get_id() != ir::INST_BINOP && op->get_id() != ir::INST_GETELEMENTPTR)
} // return inst;
// set insert point // std::cout << 1 << std::endl;
auto& inst_list = i->get_parent()->get_inst_list(); // for(size_t i = 0; i < op->get_num_operands(); i++){
auto pos = ++std::find(inst_list.begin(), inst_list.end(), i); // ir::value* arg_i = op->get_operand(i);
builder.set_insert_point(pos); // builder.set_insert_point(op);
if(dynamic_cast<ir::load_inst*>(x)){ // // create new layout transform
ir::value *ret = builder.insert(ir::copy_to_shared_inst::create(x)); // ir::instruction* new_arg_i = inst->clone();
return ret; // builder.insert(new_arg_i);
} // // set the right args
// default -- recursive clone // new_arg_i->replace_uses_of_with(new_arg_i->get_operand(0), arg_i);
ir::instruction *cloned = builder.insert(i->clone()); // op->replace_uses_of_with(arg_i, simplify(new_arg_i, builder));
seen[i] = cloned; // }
// rematerialize operands // std::cout << 2 << std::endl;
for(ir::value *op: cloned->ops()) // return op;
cloned->replace_uses_of_with(op, rematerialize(op, builder, seen)); //}
return cloned;
}
void coalesce::run(ir::module &mod) { void coalesce::run(ir::module &mod) {
size_t num_groups = layout_->num_layouts(); ir::builder& builder = mod.get_builder();
// add layout conversion instructions
for(ir::function *fn: mod.get_function_list())
for(size_t id = 0; id < num_groups; id++) { for(ir::basic_block *block: fn->blocks())
if(!layout_->get(id)->to_mma()) for(ir::instruction* i: block->get_inst_list()){
continue; // coalesce before store
// extract memory stores if(auto x = dynamic_cast<ir::store_inst*>(i))
const auto& values = layout_->values_of(id); if(ir::value* op = x->get_value_operand())
ir::value* dot = nullptr; if(op->get_type()->is_block_ty())
for(ir::value *v: values) if(layout_->get(op)->to_mma()){
if(auto x = dynamic_cast<ir::dot_inst*>(v)) builder.set_insert_point(x);
dot = x; ir::instruction* new_op = ir::cvt_layout_inst::create(op);
builder.insert(new_op);
ir::builder& builder = mod.get_builder(); x->replace_uses_of_with(op, new_op);
std::vector<ir::value*> worklist = {dot}; }
std::set<ir::value*> seen; // uncoalesce after load
while(!worklist.empty()) { if(auto x = dynamic_cast<ir::load_inst*>(i))
ir::value *current = worklist.back(); if(x->get_type()->is_block_ty())
seen.insert(current); if(x->get_type()->get_tile_rank()==2)
worklist.pop_back(); if(layout_->get(x)->to_mma()){
// stop if trunc
if(auto x = dynamic_cast<ir::fp_trunc_inst*>(current)){
builder.set_insert_point_after(x); builder.set_insert_point_after(x);
ir::recoalesce_inst* rc = ir::recoalesce_inst::create(x); ir::instruction* new_x = ir::cvt_layout_inst::create(x);
builder.insert(rc); builder.insert(new_x);
x->replace_all_uses_with(rc); x->replace_all_uses_with(new_x);
rc->replace_uses_of_with(rc, x); new_x->replace_uses_of_with(new_x, x);
// new_x->replace_uses_of_with(new_x, new_x);
}
// re-arrange scanline to promote memory coalescing
if(auto x = dynamic_cast<ir::store_inst*>(i)){
ir::value* ptr = x->get_pointer_operand();
ir::value* val = x->get_value_operand();
auto out_contig = align_->contiguous(ptr);
auto val_inst = dynamic_cast<ir::instruction*>(val);
if(!val_inst)
break; break;
if(dynamic_cast<ir::cvt_layout_inst*>(val))
break;
std::vector<unsigned> in_contig;
std::vector<ir::instruction*> queue = {val_inst};
std::set<ir::instruction*> seen;
std::vector<ir::io_inst*> ios;
while(!queue.empty()){
ir::instruction* curr = queue.back();
seen.insert(curr);
queue.pop_back();
if(auto io_inst = dynamic_cast<ir::io_inst*>(curr)){
in_contig = align_->contiguous(io_inst->get_pointer_operand());
break;
}
for(ir::value* op: curr->ops()){
auto inst_op = dynamic_cast<ir::instruction*>(op);
if(!inst_op || seen.find(inst_op) != seen.end())
continue;
if(!op->get_type()->is_block_ty() ||
!val->get_type()->is_block_ty())
continue;
if(op->get_type()->get_tile_num_elements() ==
val->get_type()->get_tile_num_elements())
queue.push_back(inst_op);
}
} }
// recurse if(in_contig.empty() || out_contig==in_contig)
for(ir::user *u: current->get_users())
if(seen.find(u) == seen.end())
worklist.push_back(u);
}
}
// find values to rematerialize
std::vector<ir::io_inst*> remat;
for(size_t id = 0; id < num_groups; id++) {
const auto& values = layout_->values_of(id);
// extract pointers used in ld/st operations
std::set<ir::io_inst*> io;
for(ir::value *v: values)
extract_io_use(v, io);
// extract leading axes
std::map<int, std::vector<ir::io_inst*>> axes;
for(ir::io_inst *i: io){
if(i->get_pointer_operand()->get_type()->get_tile_rank() == layout_->get(id)->get_rank()){
extract_ld(i, axes);
}
}
// update list of values to rematerialize
if(axes.empty())
continue;
for(auto it = ++axes.rbegin(); it != axes.rend(); it++){
if(it->second.size() == 1)
continue; continue;
remat.insert(remat.begin(), it->second.begin(), it->second.end()); builder.set_insert_point_after(val_inst);
} auto new_val = builder.insert(ir::cvt_layout_inst::create(val_inst));
} x->replace_uses_of_with(val_inst, new_val);
// rematerialize values
for(ir::io_inst *r: remat) {
ir::builder& builder = mod.get_builder();
// rematerialize operands
std::map<ir::value*, ir::value*> seen;
for(ir::value *op: r->ops())
r->replace_uses_of_with(op, rematerialize(op, mod.get_builder(), seen));
// copy to shared if load
auto& inst_list = r->get_parent()->get_inst_list();
auto pos = ++std::find(inst_list.begin(), inst_list.end(), r);
builder.set_insert_point(pos);
if(dynamic_cast<ir::load_inst*>(r)){
ir::instruction *cts = builder.insert(ir::copy_to_shared_inst::create(r));
r->replace_all_uses_with(cts);
cts->replace_uses_of_with(cts, r);
} }
} }
} }

View File

@@ -9,67 +9,48 @@ namespace triton {
namespace codegen{ namespace codegen{
namespace transform{ namespace transform{
void extract_retile_chain(ir::user *root, ir::instruction* rematerialize(ir::builder& bld, ir::instruction *root,
std::map<int, std::set<ir::user*>>& result,
int depth,
std::set<ir::value*>& seen) { std::set<ir::value*>& seen) {
if(!seen.insert(root).second) if(!seen.insert(root).second)
return; return root;
result[depth].insert(root); if(!root->get_type()->is_block_ty())
if(dynamic_cast<ir::make_range*>(root) || return root;
dynamic_cast<ir::splat_inst*>(root)){
return; bld.set_insert_point(root);
} ir::instruction *new_root = bld.insert(root->clone());
for(ir::value *op: root->ops()){ for(ir::value *op: root->ops()){
ir::user *u = dynamic_cast<ir::user*>(op); ir::instruction *i = dynamic_cast<ir::instruction*>(op);
if(!u) if(!i || i->get_id() == ir::INST_REDUCE)
continue; continue;
extract_retile_chain(u, result, depth + 1, seen); ir::instruction* new_op = rematerialize(bld, i, seen);
new_root->replace_uses_of_with(op, new_op);
} }
return new_root;
} }
void disassociate::run(ir::module &mod) { void disassociate::run(ir::module &mod) {
ir::builder &bld = mod.get_builder(); ir::builder &bld = mod.get_builder();
std::map<ir::user*, std::map<int, std::set<ir::user*>>> clone_info; // ir::for_each_instruction(mod, [&](ir::instruction *i){
// bld.set_insert_point(i);
// for(ir::value* op: i->ops()){
// auto reshape = dynamic_cast<ir::make_range*>(op);
// if(!reshape)
// continue;
// ir::instruction* new_op = bld.insert(reshape->clone());
// i->replace_uses_of_with(op, new_op);
// }
// });
ir::for_each_instruction(mod, [&](ir::instruction *i){ ir::for_each_instruction(mod, [&](ir::instruction *i){
if(dynamic_cast<ir::reshape_inst*>(i)){ if(dynamic_cast<ir::reshape_inst*>(i) || dynamic_cast<ir::splat_inst*>(i)){
ir::value* op = i->get_operand(0);
if(!dynamic_cast<ir::user*>(op))
return;
if(op->get_type()->get_tile_rank() > i->get_type()->get_tile_rank())
return;
std::map<int, std::set<ir::user*>> chains;
std::set<ir::value*> seen; std::set<ir::value*> seen;
extract_retile_chain(i, chains, 0, seen); ir::instruction* new_i = rematerialize(bld, i, seen);
if(chains.size()) i->replace_all_uses_with(new_i);
clone_info[i] = chains;
} }
}); });
for(const auto& x: clone_info){
int depth = 1;
std::map<ir::instruction*, ir::instruction*> clone_map;
while(x.second.find(depth) != x.second.end()){
// clone all users
const auto& remat = x.second.at(depth);
for(ir::user* u: remat){
ir::instruction *y = (ir::instruction*)u;
ir::instruction *cloned = y->clone();
bld.set_insert_point(y);
bld.insert(cloned);
clone_map[y] = cloned;
// replace operands of parents
if(depth > 1)
for(ir::user* ux: x.second.at(depth - 1))
clone_map.at((ir::instruction*)ux)->replace_uses_of_with(y, cloned);
else
x.first->replace_uses_of_with(y, cloned);
}
depth += 1;
}
}
} }

View File

@@ -211,6 +211,42 @@ bool peephole::rewrite_select_masked_load(ir::instruction *value, ir::builder& b
return true; return true;
} }
/**
 * \brief Peephole rewrites for layout-conversion instructions
 *
 * Handles two patterns when `value` is a `cvt_layout_inst`:
 *  - convert(binop(x, y))  -> binop(convert(x), convert(y))
 *  - convert1(convert2(x)) -> x, but only when the layout recorded for the
 *    result of convert1 is the very layout recorded for x (i.e. the two
 *    conversions cancel out)
 *
 * \param value   instruction to inspect
 * \param builder builder used to insert the conversions created by the rewrite
 * \return true if the IR was modified, false otherwise
 */
bool peephole::rewrite_cvt_layout(ir::instruction *value, ir::builder& builder){
  auto cvt = dynamic_cast<ir::cvt_layout_inst*>(value);
  if(!cvt)
    return false;
  ir::instruction* op = dynamic_cast<ir::instruction*>(cvt->get_operand(0));
  if(!op)
    return false;
  // convert(elementwise(x, y)) = elementwise(convert(x), convert(y))
  if(op->get_id() == ir::INST_BINOP){
    for(size_t i = 0; i < op->get_num_operands(); i++){
      ir::value* arg_i = op->get_operand(i);
      builder.set_insert_point(op);
      // create new layout transform
      ir::instruction* new_arg_i = cvt->clone();
      // NOTE(review): presumably registers the clone with the layout
      // analysis using the entry recorded for `op` -- confirm
      layouts_->copy(new_arg_i, op);
      builder.insert(new_arg_i);
      // set the right args: the clone still points at cvt's operand
      new_arg_i->replace_uses_of_with(new_arg_i->get_operand(0), arg_i);
      op->replace_uses_of_with(arg_i, new_arg_i);
    }
    // operands are converted now, so the outer conversion is redundant
    cvt->replace_all_uses_with(op);
    return true;
  }
  auto cvt_op = dynamic_cast<ir::cvt_layout_inst*>(op);
  if(!cvt_op)
    return false;
  // convert1(convert2(x)) if convert1 is the inverse of convert2
  ir::value* op_op = cvt_op->get_operand(0);
  // The layouts must be *equal* for the pair to cancel; merely checking that
  // both pointers are non-null would fold conversions between unrelated
  // layouts and hand users a value in the wrong layout.
  if(layouts_->has(cvt) && layouts_->has(op_op) &&
     layouts_->get(cvt) == layouts_->get(op_op)){
    cvt->replace_all_uses_with(op_op);
    return true;
  }
  return false;
}
void peephole::run(ir::module &mod) { void peephole::run(ir::module &mod) {
ir::builder &builder = mod.get_builder(); ir::builder &builder = mod.get_builder();
// keep track of whether any modification was made // keep track of whether any modification was made
@@ -248,6 +284,7 @@ void peephole::run(ir::module &mod) {
was_modified = was_modified || rewrite_unit_red(i, builder); was_modified = was_modified || rewrite_unit_red(i, builder);
was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder); was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder);
was_modified = was_modified || rewrite_select_masked_load(i, builder); was_modified = was_modified || rewrite_select_masked_load(i, builder);
was_modified = was_modified || rewrite_cvt_layout(i, builder);
if(tgt_->as_nvidia()->sm() >= 80) if(tgt_->as_nvidia()->sm() >= 80)
was_modified = was_modified || rewrite_load_to_shared(i, builder); was_modified = was_modified || rewrite_load_to_shared(i, builder);
if(was_modified) if(was_modified)

View File

@@ -311,4 +311,4 @@ void pipeline::run(ir::module &mod) {
} }
} }
} }

View File

@@ -126,12 +126,6 @@ bool dispatch::nvmlinit(){
return res; return res;
} }
// Open libLLVMSPIRVLib.so on first use and report whether a handle is held.
bool dispatch::spvllvminit(){
  if(!spvllvm_){
    // no handle yet: try to load the shared library lazily
    spvllvm_ = dlopen("libLLVMSPIRVLib.so", RTLD_LAZY);
  }
  const bool loaded = (spvllvm_ != nullptr);
  return loaded;
}
//CUDA //CUDA
CUDA_DEFINE1(CUresult, cuCtxDestroy_v2, CUcontext) CUDA_DEFINE1(CUresult, cuCtxDestroy_v2, CUcontext)
CUDA_DEFINE2(CUresult, cuEventCreate, CUevent *, unsigned int) CUDA_DEFINE2(CUresult, cuEventCreate, CUevent *, unsigned int)
@@ -185,14 +179,6 @@ NVML_DEFINE3(nvmlReturn_t, nvmlDeviceGetClockInfo, nvmlDevice_t, nvmlClockType_t
NVML_DEFINE3(nvmlReturn_t, nvmlDeviceGetMaxClockInfo, nvmlDevice_t, nvmlClockType_t, unsigned int*) NVML_DEFINE3(nvmlReturn_t, nvmlDeviceGetMaxClockInfo, nvmlDevice_t, nvmlClockType_t, unsigned int*)
NVML_DEFINE3(nvmlReturn_t, nvmlDeviceSetApplicationsClocks, nvmlDevice_t, unsigned int, unsigned int) NVML_DEFINE3(nvmlReturn_t, nvmlDeviceSetApplicationsClocks, nvmlDevice_t, unsigned int, unsigned int)
// LLVM to SPIR-V
// Thin wrapper: hands the call to f_impl together with the spvllvm_ handle,
// the initializeLLVMToSPIRVPass_ slot and the symbol name.
int dispatch::initializeLLVMToSPIRVPass(llvm::PassRegistry &registry){
  const char *symbol = "initializeLLVMToSPIRVPass";
  return f_impl<dispatch::spvllvminit>(spvllvm_,
                                       initializeLLVMToSPIRVPass,
                                       initializeLLVMToSPIRVPass_,
                                       symbol,
                                       std::ref(registry));
}
// Thin wrapper: hands the call to f_impl together with the spvllvm_ handle,
// the writeSpirv_ slot and the symbol name; stream and error string are
// passed through std::ref.
bool dispatch::writeSpirv(llvm::Module *M, std::ostream &OS, std::string &ErrMsg){
  const char *symbol = "writeSpirv";
  return f_impl<dispatch::spvllvminit>(spvllvm_,
                                       writeSpirv,
                                       writeSpirv_,
                                       symbol,
                                       M,
                                       std::ref(OS),
                                       std::ref(ErrMsg));
}
// Release // Release
void dispatch::release(){ void dispatch::release(){
@@ -204,7 +190,6 @@ void dispatch::release(){
void* dispatch::cuda_; void* dispatch::cuda_;
void* dispatch::nvml_; void* dispatch::nvml_;
void* dispatch::spvllvm_;
//CUDA //CUDA
void* dispatch::cuCtxGetCurrent_; void* dispatch::cuCtxGetCurrent_;
@@ -261,9 +246,5 @@ void* dispatch::nvmlDeviceGetClockInfo_;
void* dispatch::nvmlDeviceGetMaxClockInfo_; void* dispatch::nvmlDeviceGetMaxClockInfo_;
void* dispatch::nvmlDeviceSetApplicationsClocks_; void* dispatch::nvmlDeviceSetApplicationsClocks_;
// SPIR-V
void* dispatch::initializeLLVMToSPIRVPass_;
void* dispatch::writeSpirv_;
} }
} }

View File

@@ -374,12 +374,15 @@ ir::value *dispatch::broadcast(ir::value *input, shape_t shape, ir::builder *bui
auto src_shape = input->get_type()->get_block_shapes(); auto src_shape = input->get_type()->get_block_shapes();
if (src_shape.size() != shape.size()) if (src_shape.size() != shape.size())
throw std::runtime_error("Cannot broadcast"); throw std::runtime_error("Cannot broadcast");
if(shape == src_shape)
return input;
return builder->create_broadcast(input, shape); return builder->create_broadcast(input, shape);
} }
std::tuple<ir::value*, ir::value*> dispatch::broadcast(ir::value *lhs, ir::value* rhs, ir::builder *builder) { std::tuple<ir::value*, ir::value*> dispatch::broadcast(ir::value *lhs, ir::value* rhs, ir::builder *builder) {
ir::type *lhs_ty = lhs->get_type(); ir::type *lhs_ty = lhs->get_type();
ir::type *rhs_ty = rhs->get_type(); ir::type *rhs_ty = rhs->get_type();
// make_shape_compatible(block, scalar) // make_shape_compatible(block, scalar)
if (lhs_ty->is_block_ty() && !rhs_ty->is_block_ty()) if (lhs_ty->is_block_ty() && !rhs_ty->is_block_ty())
rhs = builder->create_splat(rhs, lhs_ty->get_block_shapes()); rhs = builder->create_splat(rhs, lhs_ty->get_block_shapes());

View File

@@ -806,6 +806,11 @@ instruction* log_inst::create(value *val, const std::string& name, instruction *
// intrinsic instructions // intrinsic instructions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// cvt_layout: the result keeps the type of its argument
cvt_layout_inst* cvt_layout_inst::create(value *arg, const std::string &name, instruction *next) {
  auto *result_ty = arg->get_type();
  cvt_layout_inst *inst = new cvt_layout_inst(result_ty, INST_CVT_LAYOUT, arg, name, next);
  return inst;
}
// copy to shared // copy to shared
copy_to_shared_inst* copy_to_shared_inst::create(value *arg, const std::string &name, copy_to_shared_inst* copy_to_shared_inst::create(value *arg, const std::string &name,
instruction *next) { instruction *next) {
@@ -818,13 +823,6 @@ copy_from_shared_inst* copy_from_shared_inst::create(value *arg, const std::stri
return new copy_from_shared_inst(arg->get_type(), INST_COPY_FROM_SHARED, arg, name, next); return new copy_from_shared_inst(arg->get_type(), INST_COPY_FROM_SHARED, arg, name, next);
} }
// recoalesce: the result keeps the type of its argument
recoalesce_inst* recoalesce_inst::create(value *arg, const std::string &name, instruction *next) {
  auto *result_ty = arg->get_type();
  recoalesce_inst *inst = new recoalesce_inst(result_ty, INST_RECOALESCE, arg, name, next);
  return inst;
}
// barrier // barrier
barrier_inst::barrier_inst(context &ctx, const std::string &name, barrier_inst::barrier_inst(context &ctx, const std::string &name,
instruction *next) instruction *next)

View File

@@ -363,6 +363,133 @@ def test_reduce1d(dtype, shape, device='cuda'):
triton.testing.assert_almost_equal(z_tri, z_ref) triton.testing.assert_almost_equal(z_tri, z_ref)
@pytest.mark.parametrize(
    "dtype, shape, axis",
    [(dtype, shape, 1)
     for dtype in ['float32']
     for shape in [(1, 1024)]])
def test_reduce2d(dtype, shape, axis, device='cuda'):
    """
    Sum a 2D block along `axis` inside a Triton kernel and compare the result
    against `torch.sum` on the same input.
    """
    torch_dtype = cvt[dtype]
    block_m, block_n = shape[0], shape[1]

    # triton kernel: load a BLOCK_M x BLOCK_N tile, reduce it, store one value per row index
    @triton.jit
    def kernel(X, Z, **meta):
        rows = tl.arange(0, meta['BLOCK_M'])
        cols = tl.arange(0, meta['BLOCK_N'])
        tile = tl.load(X + rows[:, None]*meta['BLOCK_N'] + cols[None, :])
        reduced = tl.sum(tile, axis=meta['AXIS'])
        tl.store(Z + rows, reduced)

    # input
    x = triton.testing.random(shape, dtype=torch_dtype, device=device)
    # triton result (single program instance)
    z_tri = torch.empty((block_m,), dtype=torch_dtype, device=device)
    kernel[(1,)](x, z_tri, BLOCK_M=block_m, BLOCK_N=block_n, AXIS=axis)
    # torch reference
    z_ref = torch.sum(x, axis=axis).to(torch_dtype)
    # compare
    triton.testing.assert_almost_equal(z_tri, z_ref)
# ---------------
# test permute
# ---------------
@pytest.mark.parametrize(
    "dtype, shape, perm",
    [(dtype, shape, perm)
     for dtype in ['float32']
     for shape in [(128, 128)]
     for perm in [(1, 0)]])
def test_permute(dtype, shape, perm, device='cuda'):
    """
    Copy a 2D tensor through a Triton kernel with the output strides swapped,
    compare against `x.permute(*perm).contiguous()`, and check that the
    generated PTX contains vectorized global loads and stores.
    """
    torch_dtype = cvt[dtype]

    # triton kernel: strided load from X, strided store to Z
    @triton.jit
    def kernel(X, stride_xm, stride_xn,
               Z, stride_zm, stride_zn, **meta):
        BLOCK_M = meta['BLOCK_M']
        BLOCK_N = meta['BLOCK_N']
        offs_m = tl.arange(0, BLOCK_M)
        offs_n = tl.arange(0, BLOCK_N)
        src = X + offs_m[:, None] * stride_xm + offs_n[None, :] * stride_xn
        dst = Z + offs_m[:, None] * stride_zm + offs_n[None, :] * stride_zn
        tl.store(dst, tl.load(src))

    # input
    x = triton.testing.random(shape, dtype=torch_dtype, device=device)
    # triton result; the output strides are passed in swapped order
    z_tri = torch.empty_like(x)
    x_strides = (x.stride(0), x.stride(1))
    z_strides = (z_tri.stride(1), z_tri.stride(0))
    pgm = kernel[(1, 1)](x, x_strides[0], x_strides[1],
                         z_tri, z_strides[0], z_strides[1],
                         BLOCK_M=shape[0], BLOCK_N=shape[1])
    # torch reference
    z_ref = x.permute(*perm).contiguous()
    # compare
    triton.testing.assert_almost_equal(z_tri, z_ref)
    # parse ptx to make sure ld/st are vectorized
    ptx = pgm.asm('ptx')
    for needle in ('ld.global.v4', 'st.global.v4'):
        assert needle in ptx
# ---------------
# test dot
# ---------------
@pytest.mark.parametrize("epilogue", ['none', 'add-matrix', 'add-rows', 'add-cols'])
def test_dot(epilogue, device='cuda'):
    """
    Run `tl.dot` on fp16 inputs with an optional epilogue (add a full matrix,
    a per-row vector, or a per-column vector read from Z), compare against a
    float32 `torch.matmul` reference cast back to fp16, and check that the
    generated PTX contains vectorized global loads and stores.
    """
    torch.manual_seed(0)

    # triton kernel
    @triton.jit
    def kernel(X, stride_xm, stride_xk,
               Y, stride_yk, stride_yn,
               Z, stride_zm, stride_zn, **meta):
        BLOCK_M = meta['BLOCK_M']
        BLOCK_K = meta['BLOCK_K']
        BLOCK_N = meta['BLOCK_N']
        offs_m = tl.arange(0, BLOCK_M)
        offs_n = tl.arange(0, BLOCK_N)
        offs_k = tl.arange(0, BLOCK_K)
        Xs = X + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk
        Ys = Y + offs_k[:, None] * stride_yk + offs_n[None, :] * stride_yn
        Zs = Z + offs_m[:, None] * stride_zm + offs_n[None, :] * stride_zn
        acc = tl.dot(tl.load(Xs), tl.load(Ys))
        if meta['ADD_MATRIX']:
            acc += tl.load(Zs)
        if meta['ADD_ROWS']:
            ZRs = Z + offs_m * stride_zm
            acc += tl.load(ZRs)[:, None]
        if meta['ADD_COLS']:
            ZCs = Z + offs_n * stride_zn
            acc += tl.load(ZCs)[None, :]
        tl.store(Zs, acc)

    # input
    M, N, K = 64, 64, 32
    x = triton.testing.random((M, K), dtype=torch.float16, device=device)
    y = triton.testing.random((K, N), dtype=torch.float16, device=device)
    # epilogue selection: at most one flag is set for a given parametrization
    add_matrix = epilogue == 'add-matrix'
    add_rows = epilogue == 'add-rows'
    add_cols = epilogue == 'add-cols'
    # triton result; z_tri starts as a copy of z so the epilogue can read it
    z = triton.testing.random((M, N), dtype=torch.float16, device=device)
    z_tri = z.clone()
    pgm = kernel[(1, 1)](x, x.stride(0), x.stride(1),
                         y, y.stride(0), y.stride(1),
                         z_tri, z_tri.stride(0), z_tri.stride(1),
                         BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
                         ADD_MATRIX=add_matrix,
                         ADD_ROWS=add_rows,
                         ADD_COLS=add_cols)
    # torch reference
    z_ref = torch.matmul(x.float(), y.float())
    if add_matrix:
        z_ref += z
    elif add_rows:
        z_ref += z[:,0][:, None]
    elif add_cols:
        z_ref += z[0,:][None, :]
    z_ref = z_ref.to(torch.float16)
    # compare
    triton.testing.assert_almost_equal(z_tri, z_ref)
    # make sure ld/st are vectorized
    ptx = pgm.asm('ptx')
    for needle in ('ld.global.v4', 'st.global.v4'):
        assert needle in ptx
# --------------- # ---------------

View File

@@ -624,6 +624,14 @@ def max_contiguous(input, value, _builder=None):
return frontend.max_contiguous(input, value, _builder) return frontend.max_contiguous(input, value, _builder)
@builtin
def max_contiguous(input, value, _builder=None):
    """
    Let the compiler know that the first :code:`value` values in :code:`input` are contiguous.

    :param input: the values the hint applies to
    :param value: how many leading values of :code:`input` are contiguous
    """
    # NOTE(review): the hunk context just above this definition appears to show
    # an identical `max_contiguous` already defined - confirm whether this
    # second definition is intentional.
    hinted = frontend.max_contiguous(input, value, _builder)
    return hinted
# ----------------------- # -----------------------
# Standard library # Standard library
# ----------------------- # -----------------------

View File

@@ -46,8 +46,8 @@ def _kernel(A, B, C, M, N, K,
pid_n = (pid % width) // (group_size) pid_n = (pid % width) // (group_size)
# do matrix multiplication # do matrix multiplication
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
rk = tl.arange(0, BLOCK_K) rk = tl.arange(0, BLOCK_K)
# pointers # pointers

View File

@@ -87,6 +87,7 @@ def assert_allclose(x, y, tol=1e-2):
def random(shape, dtype, device): def random(shape, dtype, device):
torch.manual_seed(0)
if isinstance(shape, int): if isinstance(shape, int):
shape = (shape, ) shape = (shape, )
if dtype == torch.bool: if dtype == torch.bool: