[CODEGEN] Various bugfixes and stability improvements in compiler backend (#240)
This commit is contained in:
@@ -93,7 +93,22 @@ protected:
|
||||
shape_t shape_;
|
||||
};
|
||||
|
||||
class mma_layout: public data_layout {
|
||||
// Intermediate layout class: derives from data_layout and adds a per-axis
// `shape_per_cta_` vector together with two accessors built on it.
// In this diff, mma_layout and scanline_layout are re-parented onto this class.
// NOTE(review): "cta" presumably refers to a CUDA thread block (cooperative
// thread array) -- confirm against the constructor definition, which is not
// part of this view.
class distributed_layout: public data_layout{
|
||||
public:
|
||||
// Declaration only; the definition lives outside this view.
// NOTE(review): the parameters look like they are forwarded to data_layout,
// and `shape_per_cta_` is presumably filled in by derived classes -- confirm.
distributed_layout(id_t id,
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<unsigned>& shape,
|
||||
const std::vector<ir::value*>& values,
|
||||
analysis::align* align);
|
||||
|
||||
// Returns entry k of shape_per_cta_. Uses std::vector::at, so an
// out-of-range k throws std::out_of_range.
int shape_per_cta(size_t k) { return shape_per_cta_.at(k); }
|
||||
// Returns shape_[k] / shape_per_cta_[k], i.e. how many times entry k of
// shape_per_cta_ fits into entry k of the inherited shape_ (integer division).
// NOTE(review): unlike shape_per_cta(), both lookups use unchecked
// operator[], and a zero entry in shape_per_cta_ would divide by zero.
// Neither accessor is const-qualified.
int rep_per_cta(size_t k) { return shape_[k] / shape_per_cta_[k]; }
|
||||
|
||||
protected:
|
||||
// Per-axis values read by the two accessors above; nothing in this class
// body assigns to it.
std::vector<int> shape_per_cta_;
|
||||
};
|
||||
|
||||
class mma_layout: public distributed_layout {
|
||||
public:
|
||||
mma_layout(size_t num_warps,
|
||||
const std::vector<int>& axes,
|
||||
@@ -107,7 +122,6 @@ public:
|
||||
int fpw(size_t k) { return fpw_.at(k); }
|
||||
int wpt(size_t k) { return wpt_.at(k); }
|
||||
int spw(size_t k) { return spw_.at(k); }
|
||||
int spt(size_t k) { return spt_.at(k); }
|
||||
int rep(size_t k) { return rep_.at(k); }
|
||||
|
||||
private:
|
||||
@@ -123,7 +137,7 @@ private:
|
||||
std::vector<int> rep_;
|
||||
};
|
||||
|
||||
struct scanline_layout: public data_layout {
|
||||
struct scanline_layout: public distributed_layout {
|
||||
scanline_layout(size_t num_warps,
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<unsigned>& shape,
|
||||
@@ -219,6 +233,7 @@ public:
|
||||
|
||||
// accessors
|
||||
unsigned layout_of(ir::value *value) const { return groups_.at(value); }
|
||||
bool has(ir::value* value) const { return groups_.find(value) != groups_.end(); }
|
||||
const std::vector<ir::value*>& values_of(unsigned id) const { return values_.at(id); }
|
||||
size_t num_layouts() const { return values_.size();}
|
||||
data_layout* get(size_t id) { return layouts_.at(id); }
|
||||
@@ -226,7 +241,7 @@ public:
|
||||
std::map<size_t, data_layout*> &get_all() { return layouts_; }
|
||||
bool has_tmp(ir::value* i) { return tmp_.find(i) != tmp_.end(); }
|
||||
int tmp(ir::value* i) { return tmp_.at(i);}
|
||||
|
||||
void copy(ir::value* dst, ir::value* src) { groups_[dst] = groups_[src]; }
|
||||
// execution
|
||||
void run(ir::module &mod);
|
||||
|
||||
|
@@ -171,7 +171,8 @@ public:
|
||||
void visit_reducend_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
|
||||
void visit_reduce_inst(ir::reduce_inst*);
|
||||
void visit_select_inst(ir::select_inst*);
|
||||
void visit_recoalesce_inst(ir::recoalesce_inst*);
|
||||
void visit_layout_convert(ir::value *out, ir::value *in);
|
||||
void visit_cvt_layout_inst(ir::cvt_layout_inst*);
|
||||
void visit_masked_load_async_inst(ir::masked_load_async_inst*);
|
||||
void visit_copy_to_shared_inst(ir::copy_to_shared_inst*);
|
||||
void visit_copy_from_shared_inst(ir::copy_from_shared_inst*);
|
||||
|
@@ -33,6 +33,7 @@ private:
|
||||
|
||||
public:
|
||||
coalesce(analysis::align* align, triton::codegen::analysis::layouts *layouts);
|
||||
triton::ir::value *simplify(ir::instruction* i, triton::ir::builder &builder);
|
||||
void run(ir::module &mod);
|
||||
|
||||
private:
|
||||
|
@@ -34,8 +34,7 @@ private:
|
||||
bool rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder);
|
||||
bool rewrite_select_masked_load(ir::instruction *value, ir::builder& builder);
|
||||
bool rewrite_load_to_shared(ir::instruction *value, ir::builder& builder);
|
||||
|
||||
private:
|
||||
bool rewrite_cvt_layout(ir::instruction *value, ir::builder& builder);
|
||||
|
||||
public:
|
||||
peephole(target* tgt, analysis::layouts* layouts): tgt_(tgt), layouts_(layouts) {}
|
||||
|
@@ -60,136 +60,151 @@ protected:
|
||||
public:
|
||||
static bool nvmlinit();
|
||||
static bool cuinit();
|
||||
static bool spvllvminit();
|
||||
static void release();
|
||||
|
||||
// CUDA
|
||||
/* ------------------- *
|
||||
* CUDA
|
||||
* ------------------- */
|
||||
// context management
|
||||
static CUresult cuInit(unsigned int Flags);
|
||||
static CUresult cuCtxGetCurrent(CUcontext *pctx);
|
||||
static CUresult cuCtxSetCurrent(CUcontext ctx);
|
||||
static CUresult cuCtxDestroy_v2(CUcontext ctx);
|
||||
static CUresult cuEventCreate(CUevent *phEvent, unsigned int Flags);
|
||||
static CUresult cuDeviceGet(CUdevice *device, int ordinal);
|
||||
static CUresult cuMemcpyDtoH_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount);
|
||||
static CUresult cuStreamCreate(CUstream *phStream, unsigned int Flags);
|
||||
static CUresult cuEventElapsedTime(float *pMilliseconds, CUevent hStart, CUevent hEnd);
|
||||
static CUresult cuMemFree_v2(CUdeviceptr dptr);
|
||||
static CUresult cuMemcpyDtoHAsync_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount, CUstream hStream);
|
||||
static CUresult cuCtxCreate_v2(CUcontext *pctx, unsigned int flags, CUdevice dev);
|
||||
static CUresult cuCtxPushCurrent_v2(CUcontext ctx);
|
||||
static CUresult cuCtxPopCurrent_v2(CUcontext *pctx);
|
||||
static CUresult cuCtxGetDevice(CUdevice* result);
|
||||
static CUresult cuCtxEnablePeerAccess(CUcontext peerContext, unsigned int flags);
|
||||
static CUresult cuDriverGetVersion(int *driverVersion);
|
||||
// device management
|
||||
static CUresult cuDeviceGet(CUdevice *device, int ordinal);
|
||||
static CUresult cuDeviceGetName(char *name, int len, CUdevice dev);
|
||||
static CUresult cuDeviceGetPCIBusId(char *id, int len, CUdevice dev);
|
||||
static CUresult cuModuleGetGlobal_v2(CUdeviceptr *dptr, size_t* bytes, CUmodule hmod, const char *name);
|
||||
static CUresult cuMemcpyHtoDAsync_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount, CUstream hStream);
|
||||
static CUresult cuModuleLoad(CUmodule *module, const char *fname);
|
||||
static CUresult cuModuleLoadData(CUmodule* module, const void* image);
|
||||
static CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void **kernelParams, void **extra);
|
||||
static CUresult cuModuleUnload(CUmodule hmod);
|
||||
static CUresult cuModuleLoadDataEx(CUmodule *module, const void *image, unsigned int numOptions, CUjit_option *options, void **optionValues);
|
||||
|
||||
static CUresult cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib, CUdevice dev);
|
||||
static CUresult cuDeviceGetCount(int *count);
|
||||
// link management
|
||||
static CUresult cuLinkAddData_v2(CUlinkState state, CUjitInputType type, void* data, size_t size, const char* name, unsigned int numOptions, CUjit_option* options, void** optionValues);
|
||||
static CUresult cuLinkCreate_v2(unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut);
|
||||
static CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut);
|
||||
static CUresult cuLinkDestroy(CUlinkState state);
|
||||
|
||||
static CUresult cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib, CUdevice dev);
|
||||
static CUresult cuDeviceGetCount(int *count);
|
||||
static CUresult cuMemcpyHtoD_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount);
|
||||
static CUresult cuInit(unsigned int Flags);
|
||||
static CUresult cuEventRecord(CUevent hEvent, CUstream hStream);
|
||||
static CUresult cuCtxCreate_v2(CUcontext *pctx, unsigned int flags, CUdevice dev);
|
||||
static CUresult cuCtxPushCurrent_v2(CUcontext ctx);
|
||||
static CUresult cuCtxPopCurrent_v2(CUcontext *pctx);
|
||||
// module management
|
||||
static CUresult cuModuleGetGlobal_v2(CUdeviceptr *dptr, size_t* bytes, CUmodule hmod, const char *name);
|
||||
static CUresult cuModuleLoad(CUmodule *module, const char *fname);
|
||||
static CUresult cuModuleLoadData(CUmodule* module, const void* image);
|
||||
static CUresult cuModuleUnload(CUmodule hmod);
|
||||
static CUresult cuModuleLoadDataEx(CUmodule *module, const void *image, unsigned int numOptions, CUjit_option *options, void **optionValues);
|
||||
static CUresult cuModuleGetFunction(CUfunction *hfunc, CUmodule hmod, const char *name);
|
||||
// stream management
|
||||
static CUresult cuStreamCreate(CUstream *phStream, unsigned int Flags);
|
||||
static CUresult cuStreamSynchronize(CUstream hStream);
|
||||
static CUresult cuStreamGetCtx(CUstream hStream, CUcontext* pctx);
|
||||
static CUresult cuStreamDestroy_v2(CUstream hStream);
|
||||
static CUresult cuEventDestroy_v2(CUevent hEvent);
|
||||
static CUresult cuMemAlloc_v2(CUdeviceptr *dptr, size_t bytesize);
|
||||
static CUresult cuPointerGetAttribute(void * data, CUpointer_attribute attribute, CUdeviceptr ptr);
|
||||
static CUresult cuCtxGetDevice(CUdevice* result);
|
||||
static CUresult cuMemsetD8Async(CUdeviceptr dst, unsigned char x, size_t N, CUstream stream);
|
||||
static CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void **kernelParams, void **extra);
|
||||
// function management
|
||||
static CUresult cuFuncGetAttribute(int* pi, CUfunction_attribute attrib, CUfunction hfunc);
|
||||
static CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value);
|
||||
static CUresult cuFuncSetCacheConfig(CUfunction hfunc, CUfunc_cache config);
|
||||
static CUresult cuCtxEnablePeerAccess(CUcontext peerContext, unsigned int flags);
|
||||
// NVML
|
||||
// memory management
|
||||
static CUresult cuMemAlloc_v2(CUdeviceptr *dptr, size_t bytesize);
|
||||
static CUresult cuPointerGetAttribute(void * data, CUpointer_attribute attribute, CUdeviceptr ptr);
|
||||
static CUresult cuMemsetD8Async(CUdeviceptr dst, unsigned char x, size_t N, CUstream stream);
|
||||
static CUresult cuMemcpyDtoH_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount);
|
||||
static CUresult cuMemFree_v2(CUdeviceptr dptr);
|
||||
static CUresult cuMemcpyDtoHAsync_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount, CUstream hStream);
|
||||
static CUresult cuMemcpyHtoDAsync_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount, CUstream hStream);
|
||||
static CUresult cuMemcpyHtoD_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount);
|
||||
// event management
|
||||
static CUresult cuEventCreate(CUevent *phEvent, unsigned int Flags);
|
||||
static CUresult cuEventElapsedTime(float *pMilliseconds, CUevent hStart, CUevent hEnd);
|
||||
static CUresult cuEventRecord(CUevent hEvent, CUstream hStream);
|
||||
static CUresult cuEventDestroy_v2(CUevent hEvent);
|
||||
|
||||
|
||||
/* ------------------- *
|
||||
* NVML
|
||||
* ------------------- */
|
||||
static nvmlReturn_t nvmlDeviceGetHandleByPciBusId_v2( const char* pciBusId, nvmlDevice_t* device);
|
||||
static nvmlReturn_t nvmlDeviceGetClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock);
|
||||
static nvmlReturn_t nvmlDeviceGetMaxClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock);
|
||||
static nvmlReturn_t nvmlDeviceSetApplicationsClocks(nvmlDevice_t device, unsigned int mem_clock, unsigned int sm_clock);
|
||||
|
||||
|
||||
// SPIR-V libraries
|
||||
static int initializeLLVMToSPIRVPass(llvm::PassRegistry &);
|
||||
static bool writeSpirv(llvm::Module *M, std::ostream &OS, std::string &ErrMsg);
|
||||
|
||||
|
||||
private:
|
||||
|
||||
// Libraries
|
||||
static void* cuda_;
|
||||
static void* nvml_;
|
||||
static void* vulkan_;
|
||||
static void* spvllvm_;
|
||||
static void* spvcross_;
|
||||
static void* opengl_;
|
||||
|
||||
|
||||
// CUDA functions
|
||||
/* ------------------- *
|
||||
* CUDA
|
||||
* ------------------- */
|
||||
// context management
|
||||
static void* cuCtxGetCurrent_;
|
||||
static void* cuCtxSetCurrent_;
|
||||
static void* cuCtxDestroy_v2_;
|
||||
static void* cuEventCreate_;
|
||||
static void* cuDeviceGet_;
|
||||
static void* cuMemcpyDtoH_v2_;
|
||||
static void* cuStreamCreate_;
|
||||
static void* cuEventElapsedTime_;
|
||||
static void* cuMemFree_v2_;
|
||||
static void* cuMemcpyDtoHAsync_v2_;
|
||||
static void* cuCtxCreate_v2_;
|
||||
static void* cuCtxGetDevice_;
|
||||
static void* cuCtxPushCurrent_v2_;
|
||||
static void* cuCtxPopCurrent_v2_;
|
||||
static void* cuCtxEnablePeerAccess_;
|
||||
static void* cuDriverGetVersion_;
|
||||
static void* cuInit_;
|
||||
// device management
|
||||
static void* cuDeviceGet_;
|
||||
static void* cuDeviceGetName_;
|
||||
static void* cuDeviceGetPCIBusId_;
|
||||
static void* cuModuleGetGlobal_v2_;
|
||||
static void* cuMemcpyHtoDAsync_v2_;
|
||||
static void* cuModuleLoad_;
|
||||
static void* cuLaunchKernel_;
|
||||
static void* cuModuleUnload_;
|
||||
static void* cuModuleLoadDataEx_;
|
||||
static void* cuDeviceGetAttribute_;
|
||||
static void* cuDeviceGetCount_;
|
||||
// link management
|
||||
static void* cuLinkAddData_v2_;
|
||||
static void* cuLinkCreate_v2_;
|
||||
static void* cuLinkDestroy_;
|
||||
static void* cuModuleLoadData_;
|
||||
static void* cuLinkComplete_;
|
||||
static void* cuDeviceGetAttribute_;
|
||||
static void* cuDeviceGetCount_;
|
||||
static void* cuMemcpyHtoD_v2_;
|
||||
static void* cuInit_;
|
||||
static void* cuEventRecord_;
|
||||
static void* cuCtxCreate_v2_;
|
||||
// module management
|
||||
static void* cuModuleGetGlobal_v2_;
|
||||
static void* cuModuleLoad_;
|
||||
static void* cuModuleUnload_;
|
||||
static void* cuModuleLoadDataEx_;
|
||||
static void* cuModuleLoadData_;
|
||||
static void* cuModuleGetFunction_;
|
||||
// stream management
|
||||
static void* cuStreamCreate_;
|
||||
static void* cuStreamSynchronize_;
|
||||
static void* cuStreamDestroy_v2_;
|
||||
static void* cuStreamGetCtx_;
|
||||
static void* cuEventDestroy_v2_;
|
||||
static void* cuMemAlloc_v2_;
|
||||
static void* cuPointerGetAttribute_;
|
||||
static void* cuCtxGetDevice_;
|
||||
static void* cuMemsetD8Async_;
|
||||
static void* cuCtxPushCurrent_v2_;
|
||||
static void* cuCtxPopCurrent_v2_;
|
||||
static void* cuLaunchKernel_;
|
||||
// function management
|
||||
static void* cuFuncGetAttribute_;
|
||||
static void* cuFuncSetAttribute_;
|
||||
static void* cuFuncSetCacheConfig_;
|
||||
static void* cuCtxEnablePeerAccess_;
|
||||
// NVML
|
||||
// memory management
|
||||
static void* cuMemcpyDtoH_v2_;
|
||||
static void* cuMemFree_v2_;
|
||||
static void* cuMemcpyDtoHAsync_v2_;
|
||||
static void* cuMemcpyHtoDAsync_v2_;
|
||||
static void* cuMemcpyHtoD_v2_;
|
||||
static void* cuMemAlloc_v2_;
|
||||
static void* cuMemsetD8Async_;
|
||||
static void* cuPointerGetAttribute_;
|
||||
// event management
|
||||
static void* cuEventCreate_;
|
||||
static void* cuEventElapsedTime_;
|
||||
static void* cuEventRecord_;
|
||||
static void* cuEventDestroy_v2_;
|
||||
|
||||
|
||||
|
||||
|
||||
/* ------------------- *
|
||||
* NVML
|
||||
* ------------------- */
|
||||
static void* nvmlInit_v2_;
|
||||
static void* nvmlDeviceGetHandleByPciBusId_v2_;
|
||||
static void* nvmlDeviceGetClockInfo_;
|
||||
static void* nvmlDeviceGetMaxClockInfo_;
|
||||
static void* nvmlDeviceSetApplicationsClocks_;
|
||||
|
||||
// LLVM to SPIR-V
|
||||
static void* initializeLLVMToSPIRVPass_;
|
||||
static void* writeSpirv_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -153,6 +153,9 @@ enum value_id_t: unsigned {
|
||||
// intrinsics
|
||||
INST_COPY_TO_SHARED,
|
||||
INST_COPY_FROM_SHARED,
|
||||
INST_CVT_LAYOUT,
|
||||
INST_CVT_SCANLINE,
|
||||
INST_DECOALESCE,
|
||||
INST_RECOALESCE,
|
||||
INST_BARRIER,
|
||||
INST_ASYNC_WAIT,
|
||||
|
@@ -807,16 +807,15 @@ public:
|
||||
_TRITON_DEFINE_ACCEPT(copy_from_shared_inst)
|
||||
};
|
||||
|
||||
|
||||
class recoalesce_inst: public unary_inst{
|
||||
class cvt_layout_inst: public unary_inst {
|
||||
private:
|
||||
using unary_inst::unary_inst;
|
||||
std::string repr_impl() const { return "recoalesce_inst"; }
|
||||
std::string repr_impl() const { return "cvt_layout_inst"; }
|
||||
|
||||
public:
|
||||
static recoalesce_inst* create(value *arg, const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(recoalesce_inst)
|
||||
_TRITON_DEFINE_ACCEPT(recoalesce_inst)
|
||||
static cvt_layout_inst* create(value *arg, const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(cvt_layout_inst)
|
||||
_TRITON_DEFINE_ACCEPT(cvt_layout_inst)
|
||||
};
|
||||
|
||||
class barrier_inst: public instruction{
|
||||
|
@@ -64,7 +64,7 @@ class sqrt_inst;
|
||||
class reduce_inst;
|
||||
class select_inst;
|
||||
|
||||
class recoalesce_inst;
|
||||
class cvt_layout_inst;
|
||||
class copy_to_shared_inst;
|
||||
class copy_from_shared_inst;
|
||||
class masked_load_async_inst;
|
||||
@@ -142,9 +142,11 @@ public:
|
||||
virtual void visit_reduce_inst(reduce_inst*) = 0;
|
||||
virtual void visit_select_inst(select_inst*) = 0;
|
||||
|
||||
virtual void visit_recoalesce_inst(recoalesce_inst*) = 0;
|
||||
virtual void visit_cvt_layout_inst(cvt_layout_inst*) = 0;
|
||||
virtual void visit_copy_to_shared_inst(copy_to_shared_inst*) = 0;
|
||||
virtual void visit_copy_from_shared_inst(copy_from_shared_inst*) = 0;
|
||||
|
||||
|
||||
virtual void visit_masked_load_async_inst(masked_load_async_inst*)= 0;
|
||||
virtual void visit_barrier_inst(barrier_inst*) = 0;
|
||||
virtual void visit_async_wait_inst(async_wait_inst*) = 0;
|
||||
|
@@ -6,6 +6,7 @@
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
|
||||
namespace triton {
|
||||
namespace tools{
|
||||
@@ -40,8 +41,9 @@ public:
|
||||
nmap->clear();
|
||||
std::set<node_t> nodes = nodes_;
|
||||
unsigned id = 0;
|
||||
while(!nodes.empty())
|
||||
while(!nodes.empty()){
|
||||
connected_components_impl(*nodes.begin(), nodes, nmap, cmap, id++);
|
||||
}
|
||||
}
|
||||
|
||||
void add_edge(node_t x, node_t y) {
|
||||
|
Reference in New Issue
Block a user