[CODEGEN] Various bugfixes and stability improvements in compiler backend (#240)

This commit is contained in:
Philippe Tillet
2021-08-30 11:50:35 -07:00
committed by GitHub
parent 85426dbaf7
commit 4ff3714d61
25 changed files with 568 additions and 399 deletions

View File

@@ -93,7 +93,22 @@ protected:
shape_t shape_;
};
class mma_layout: public data_layout {
// Base class for layouts that keep a per-axis extent in shape_per_cta_.
// NOTE(review): "CTA" presumably means a CUDA thread block — confirm.
class distributed_layout: public data_layout{
public:
  // Defined out of line; see the implementation file for how the arguments
  // are consumed.
  distributed_layout(id_t id,
                     const std::vector<int>& axes,
                     const std::vector<unsigned>& shape,
                     const std::vector<ir::value*>& values,
                     analysis::align* align);

  // Returns the entry of shape_per_cta_ for axis k.
  // Throws std::out_of_range when k is not a valid index.
  int shape_per_cta(size_t k) { return shape_per_cta_.at(k); }

  // Returns shape_[k] divided by shape_per_cta_[k].
  // Both lookups are bounds-checked so that an invalid axis throws
  // std::out_of_range, consistent with shape_per_cta(), instead of reading
  // past the end of either container.
  // NOTE(review): assumes shape_t (the type of shape_) provides .at(), as a
  // std::vector does — confirm.
  int rep_per_cta(size_t k) { return shape_.at(k) / shape_per_cta_.at(k); }

protected:
  // Indexed by axis. Nothing in this class body populates it; presumably the
  // constructors of this class or its subclasses do — verify.
  std::vector<int> shape_per_cta_;
};
class mma_layout: public distributed_layout {
public:
mma_layout(size_t num_warps,
const std::vector<int>& axes,
@@ -107,7 +122,6 @@ public:
int fpw(size_t k) { return fpw_.at(k); }
int wpt(size_t k) { return wpt_.at(k); }
int spw(size_t k) { return spw_.at(k); }
int spt(size_t k) { return spt_.at(k); }
int rep(size_t k) { return rep_.at(k); }
private:
@@ -123,7 +137,7 @@ private:
std::vector<int> rep_;
};
struct scanline_layout: public data_layout {
struct scanline_layout: public distributed_layout {
scanline_layout(size_t num_warps,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
@@ -219,6 +233,7 @@ public:
// accessors
// Looks up `value` in groups_ and returns the stored layout id.
// Uses .at(), so a value that was never registered is not silently inserted.
unsigned layout_of(ir::value *value) const {
  return groups_.at(value);
}
// True when `value` has an entry in groups_.
bool has(ir::value* value) const {
  return groups_.count(value) != 0;
}
// Returns, by const reference, the list of values stored in values_ under `id`.
// The lookup goes through .at(), so `id` must already be present.
const std::vector<ir::value*>& values_of(unsigned id) const {
  return values_.at(id);
}
// Number of entries currently held in values_.
size_t num_layouts() const {
  return values_.size();
}
// Returns the layout pointer stored in layouts_ under `id`.
// The lookup goes through .at(), so an unknown id is not inserted.
data_layout* get(size_t id) {
  return layouts_.at(id);
}
@@ -226,7 +241,7 @@ public:
std::map<size_t, data_layout*> &get_all() { return layouts_; }
bool has_tmp(ir::value* i) { return tmp_.find(i) != tmp_.end(); }
int tmp(ir::value* i) { return tmp_.at(i);}
void copy(ir::value* dst, ir::value* src) { groups_[dst] = groups_[src]; }
// execution
void run(ir::module &mod);

View File

@@ -171,7 +171,8 @@ public:
void visit_reducend_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
void visit_reduce_inst(ir::reduce_inst*);
void visit_select_inst(ir::select_inst*);
void visit_recoalesce_inst(ir::recoalesce_inst*);
void visit_layout_convert(ir::value *out, ir::value *in);
void visit_cvt_layout_inst(ir::cvt_layout_inst*);
void visit_masked_load_async_inst(ir::masked_load_async_inst*);
void visit_copy_to_shared_inst(ir::copy_to_shared_inst*);
void visit_copy_from_shared_inst(ir::copy_from_shared_inst*);

View File

@@ -33,6 +33,7 @@ private:
public:
coalesce(analysis::align* align, triton::codegen::analysis::layouts *layouts);
triton::ir::value *simplify(ir::instruction* i, triton::ir::builder &builder);
void run(ir::module &mod);
private:

View File

@@ -34,8 +34,7 @@ private:
bool rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder);
bool rewrite_select_masked_load(ir::instruction *value, ir::builder& builder);
bool rewrite_load_to_shared(ir::instruction *value, ir::builder& builder);
private:
bool rewrite_cvt_layout(ir::instruction *value, ir::builder& builder);
public:
peephole(target* tgt, analysis::layouts* layouts): tgt_(tgt), layouts_(layouts) {}

View File

@@ -60,136 +60,151 @@ protected:
public:
static bool nvmlinit();
static bool cuinit();
static bool spvllvminit();
static void release();
// CUDA
/* ------------------- *
* CUDA
* ------------------- */
// context management
static CUresult cuInit(unsigned int Flags);
static CUresult cuCtxGetCurrent(CUcontext *pctx);
static CUresult cuCtxSetCurrent(CUcontext ctx);
static CUresult cuCtxDestroy_v2(CUcontext ctx);
static CUresult cuEventCreate(CUevent *phEvent, unsigned int Flags);
static CUresult cuDeviceGet(CUdevice *device, int ordinal);
static CUresult cuMemcpyDtoH_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount);
static CUresult cuStreamCreate(CUstream *phStream, unsigned int Flags);
static CUresult cuEventElapsedTime(float *pMilliseconds, CUevent hStart, CUevent hEnd);
static CUresult cuMemFree_v2(CUdeviceptr dptr);
static CUresult cuMemcpyDtoHAsync_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount, CUstream hStream);
static CUresult cuCtxCreate_v2(CUcontext *pctx, unsigned int flags, CUdevice dev);
static CUresult cuCtxPushCurrent_v2(CUcontext ctx);
static CUresult cuCtxPopCurrent_v2(CUcontext *pctx);
static CUresult cuCtxGetDevice(CUdevice* result);
static CUresult cuCtxEnablePeerAccess(CUcontext peerContext, unsigned int flags);
static CUresult cuDriverGetVersion(int *driverVersion);
// device management
static CUresult cuDeviceGet(CUdevice *device, int ordinal);
static CUresult cuDeviceGetName(char *name, int len, CUdevice dev);
static CUresult cuDeviceGetPCIBusId(char *id, int len, CUdevice dev);
static CUresult cuModuleGetGlobal_v2(CUdeviceptr *dptr, size_t* bytes, CUmodule hmod, const char *name);
static CUresult cuMemcpyHtoDAsync_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount, CUstream hStream);
static CUresult cuModuleLoad(CUmodule *module, const char *fname);
static CUresult cuModuleLoadData(CUmodule* module, const void* image);
static CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void **kernelParams, void **extra);
static CUresult cuModuleUnload(CUmodule hmod);
static CUresult cuModuleLoadDataEx(CUmodule *module, const void *image, unsigned int numOptions, CUjit_option *options, void **optionValues);
static CUresult cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib, CUdevice dev);
static CUresult cuDeviceGetCount(int *count);
// link management
static CUresult cuLinkAddData_v2(CUlinkState state, CUjitInputType type, void* data, size_t size, const char* name, unsigned int numOptions, CUjit_option* options, void** optionValues);
static CUresult cuLinkCreate_v2(unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut);
static CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut);
static CUresult cuLinkDestroy(CUlinkState state);
static CUresult cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib, CUdevice dev);
static CUresult cuDeviceGetCount(int *count);
static CUresult cuMemcpyHtoD_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount);
static CUresult cuInit(unsigned int Flags);
static CUresult cuEventRecord(CUevent hEvent, CUstream hStream);
static CUresult cuCtxCreate_v2(CUcontext *pctx, unsigned int flags, CUdevice dev);
static CUresult cuCtxPushCurrent_v2(CUcontext ctx);
static CUresult cuCtxPopCurrent_v2(CUcontext *pctx);
// module management
static CUresult cuModuleGetGlobal_v2(CUdeviceptr *dptr, size_t* bytes, CUmodule hmod, const char *name);
static CUresult cuModuleLoad(CUmodule *module, const char *fname);
static CUresult cuModuleLoadData(CUmodule* module, const void* image);
static CUresult cuModuleUnload(CUmodule hmod);
static CUresult cuModuleLoadDataEx(CUmodule *module, const void *image, unsigned int numOptions, CUjit_option *options, void **optionValues);
static CUresult cuModuleGetFunction(CUfunction *hfunc, CUmodule hmod, const char *name);
// stream management
static CUresult cuStreamCreate(CUstream *phStream, unsigned int Flags);
static CUresult cuStreamSynchronize(CUstream hStream);
static CUresult cuStreamGetCtx(CUstream hStream, CUcontext* pctx);
static CUresult cuStreamDestroy_v2(CUstream hStream);
static CUresult cuEventDestroy_v2(CUevent hEvent);
static CUresult cuMemAlloc_v2(CUdeviceptr *dptr, size_t bytesize);
static CUresult cuPointerGetAttribute(void * data, CUpointer_attribute attribute, CUdeviceptr ptr);
static CUresult cuCtxGetDevice(CUdevice* result);
static CUresult cuMemsetD8Async(CUdeviceptr dst, unsigned char x, size_t N, CUstream stream);
static CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void **kernelParams, void **extra);
// function management
static CUresult cuFuncGetAttribute(int* pi, CUfunction_attribute attrib, CUfunction hfunc);
static CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value);
static CUresult cuFuncSetCacheConfig(CUfunction hfunc, CUfunc_cache config);
static CUresult cuCtxEnablePeerAccess(CUcontext peerContext, unsigned int flags);
// NVML
// memory management
static CUresult cuMemAlloc_v2(CUdeviceptr *dptr, size_t bytesize);
static CUresult cuPointerGetAttribute(void * data, CUpointer_attribute attribute, CUdeviceptr ptr);
static CUresult cuMemsetD8Async(CUdeviceptr dst, unsigned char x, size_t N, CUstream stream);
static CUresult cuMemcpyDtoH_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount);
static CUresult cuMemFree_v2(CUdeviceptr dptr);
static CUresult cuMemcpyDtoHAsync_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount, CUstream hStream);
static CUresult cuMemcpyHtoDAsync_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount, CUstream hStream);
static CUresult cuMemcpyHtoD_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount);
// event management
static CUresult cuEventCreate(CUevent *phEvent, unsigned int Flags);
static CUresult cuEventElapsedTime(float *pMilliseconds, CUevent hStart, CUevent hEnd);
static CUresult cuEventRecord(CUevent hEvent, CUstream hStream);
static CUresult cuEventDestroy_v2(CUevent hEvent);
/* ------------------- *
* NVML
* ------------------- */
static nvmlReturn_t nvmlDeviceGetHandleByPciBusId_v2( const char* pciBusId, nvmlDevice_t* device);
static nvmlReturn_t nvmlDeviceGetClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock);
static nvmlReturn_t nvmlDeviceGetMaxClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock);
static nvmlReturn_t nvmlDeviceSetApplicationsClocks(nvmlDevice_t device, unsigned int mem_clock, unsigned int sm_clock);
// SPIR-V libraries
static int initializeLLVMToSPIRVPass(llvm::PassRegistry &);
static bool writeSpirv(llvm::Module *M, std::ostream &OS, std::string &ErrMsg);
private:
// Libraries
static void* cuda_;
static void* nvml_;
static void* vulkan_;
static void* spvllvm_;
static void* spvcross_;
static void* opengl_;
// CUDA functions
/* ------------------- *
* CUDA
* ------------------- */
// context management
static void* cuCtxGetCurrent_;
static void* cuCtxSetCurrent_;
static void* cuCtxDestroy_v2_;
static void* cuEventCreate_;
static void* cuDeviceGet_;
static void* cuMemcpyDtoH_v2_;
static void* cuStreamCreate_;
static void* cuEventElapsedTime_;
static void* cuMemFree_v2_;
static void* cuMemcpyDtoHAsync_v2_;
static void* cuCtxCreate_v2_;
static void* cuCtxGetDevice_;
static void* cuCtxPushCurrent_v2_;
static void* cuCtxPopCurrent_v2_;
static void* cuCtxEnablePeerAccess_;
static void* cuDriverGetVersion_;
static void* cuInit_;
// device management
static void* cuDeviceGet_;
static void* cuDeviceGetName_;
static void* cuDeviceGetPCIBusId_;
static void* cuModuleGetGlobal_v2_;
static void* cuMemcpyHtoDAsync_v2_;
static void* cuModuleLoad_;
static void* cuLaunchKernel_;
static void* cuModuleUnload_;
static void* cuModuleLoadDataEx_;
static void* cuDeviceGetAttribute_;
static void* cuDeviceGetCount_;
// link management
static void* cuLinkAddData_v2_;
static void* cuLinkCreate_v2_;
static void* cuLinkDestroy_;
static void* cuModuleLoadData_;
static void* cuLinkComplete_;
static void* cuDeviceGetAttribute_;
static void* cuDeviceGetCount_;
static void* cuMemcpyHtoD_v2_;
static void* cuInit_;
static void* cuEventRecord_;
static void* cuCtxCreate_v2_;
// module management
static void* cuModuleGetGlobal_v2_;
static void* cuModuleLoad_;
static void* cuModuleUnload_;
static void* cuModuleLoadDataEx_;
static void* cuModuleLoadData_;
static void* cuModuleGetFunction_;
// stream management
static void* cuStreamCreate_;
static void* cuStreamSynchronize_;
static void* cuStreamDestroy_v2_;
static void* cuStreamGetCtx_;
static void* cuEventDestroy_v2_;
static void* cuMemAlloc_v2_;
static void* cuPointerGetAttribute_;
static void* cuCtxGetDevice_;
static void* cuMemsetD8Async_;
static void* cuCtxPushCurrent_v2_;
static void* cuCtxPopCurrent_v2_;
static void* cuLaunchKernel_;
// function management
static void* cuFuncGetAttribute_;
static void* cuFuncSetAttribute_;
static void* cuFuncSetCacheConfig_;
static void* cuCtxEnablePeerAccess_;
// NVML
// memory management
static void* cuMemcpyDtoH_v2_;
static void* cuMemFree_v2_;
static void* cuMemcpyDtoHAsync_v2_;
static void* cuMemcpyHtoDAsync_v2_;
static void* cuMemcpyHtoD_v2_;
static void* cuMemAlloc_v2_;
static void* cuMemsetD8Async_;
static void* cuPointerGetAttribute_;
// event management
static void* cuEventCreate_;
static void* cuEventElapsedTime_;
static void* cuEventRecord_;
static void* cuEventDestroy_v2_;
/* ------------------- *
* NVML
* ------------------- */
static void* nvmlInit_v2_;
static void* nvmlDeviceGetHandleByPciBusId_v2_;
static void* nvmlDeviceGetClockInfo_;
static void* nvmlDeviceGetMaxClockInfo_;
static void* nvmlDeviceSetApplicationsClocks_;
// LLVM to SPIR-V
static void* initializeLLVMToSPIRVPass_;
static void* writeSpirv_;
};
}

View File

@@ -153,6 +153,9 @@ enum value_id_t: unsigned {
// intrinsics
INST_COPY_TO_SHARED,
INST_COPY_FROM_SHARED,
INST_CVT_LAYOUT,
INST_CVT_SCANLINE,
INST_DECOALESCE,
INST_RECOALESCE,
INST_BARRIER,
INST_ASYNC_WAIT,

View File

@@ -807,16 +807,15 @@ public:
_TRITON_DEFINE_ACCEPT(copy_from_shared_inst)
};
class recoalesce_inst: public unary_inst{
class cvt_layout_inst: public unary_inst {
private:
using unary_inst::unary_inst;
std::string repr_impl() const { return "recoalesce_inst"; }
std::string repr_impl() const { return "cvt_layout_inst"; }
public:
static recoalesce_inst* create(value *arg, const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(recoalesce_inst)
_TRITON_DEFINE_ACCEPT(recoalesce_inst)
static cvt_layout_inst* create(value *arg, const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(cvt_layout_inst)
_TRITON_DEFINE_ACCEPT(cvt_layout_inst)
};
class barrier_inst: public instruction{

View File

@@ -64,7 +64,7 @@ class sqrt_inst;
class reduce_inst;
class select_inst;
class recoalesce_inst;
class cvt_layout_inst;
class copy_to_shared_inst;
class copy_from_shared_inst;
class masked_load_async_inst;
@@ -142,9 +142,11 @@ public:
virtual void visit_reduce_inst(reduce_inst*) = 0;
virtual void visit_select_inst(select_inst*) = 0;
virtual void visit_recoalesce_inst(recoalesce_inst*) = 0;
virtual void visit_cvt_layout_inst(cvt_layout_inst*) = 0;
virtual void visit_copy_to_shared_inst(copy_to_shared_inst*) = 0;
virtual void visit_copy_from_shared_inst(copy_from_shared_inst*) = 0;
virtual void visit_masked_load_async_inst(masked_load_async_inst*)= 0;
virtual void visit_barrier_inst(barrier_inst*) = 0;
virtual void visit_async_wait_inst(async_wait_inst*) = 0;

View File

@@ -6,6 +6,7 @@
#include <map>
#include <set>
#include <vector>
#include <iostream>
namespace triton {
namespace tools{
@@ -40,8 +41,9 @@ public:
nmap->clear();
std::set<node_t> nodes = nodes_;
unsigned id = 0;
while(!nodes.empty())
while(!nodes.empty()){
connected_components_impl(*nodes.begin(), nodes, nmap, cmap, id++);
}
}
void add_edge(node_t x, node_t y) {