[driver] added spirv-llvm dispatch functions
This commit is contained in:
@@ -103,7 +103,6 @@ int main() {
|
||||
stream->write(da, true, 0, ha);
|
||||
stream->write(db, true, 0, hb);
|
||||
stream->write(dc, true, 0, hc);
|
||||
stream->write(dlocks, true, 0, hlocks);
|
||||
stream->synchronize();
|
||||
|
||||
|
||||
@@ -116,6 +115,8 @@ int main() {
|
||||
unsigned nthreads = info.num_threads;
|
||||
unsigned GZ = jit.get_int("GZ");
|
||||
std::array<size_t, 3> grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, GZ};
|
||||
// init locks
|
||||
stream->write(dlocks, true, 0, hlocks);
|
||||
// set argument
|
||||
kernel->setArg(0, da);
|
||||
kernel->setArg(1, db);
|
||||
|
@@ -38,6 +38,11 @@
|
||||
#include <iostream>
|
||||
#include <stdexcept>
|
||||
|
||||
namespace llvm {
|
||||
class PassRegistry;
|
||||
class Module;
|
||||
}
|
||||
|
||||
namespace triton
|
||||
{
|
||||
namespace driver
|
||||
@@ -85,6 +90,7 @@ public:
|
||||
static bool cuinit();
|
||||
static bool cublasinit();
|
||||
static bool cudnninit();
|
||||
static bool spvllvminit();
|
||||
static void release();
|
||||
|
||||
// OpenCL
|
||||
@@ -123,10 +129,9 @@ public:
|
||||
static cl_program clCreateProgramWithSource(cl_context, cl_uint, const char **, const size_t *, cl_int *);
|
||||
static cl_int clReleaseKernel(cl_kernel);
|
||||
|
||||
//CUDA
|
||||
// CUDA
|
||||
static CUresult cuCtxGetCurrent(CUcontext *pctx);
|
||||
static CUresult cuCtxSetCurrent(CUcontext ctx);
|
||||
|
||||
static CUresult cuCtxDestroy_v2(CUcontext ctx);
|
||||
static CUresult cuEventCreate(CUevent *phEvent, unsigned int Flags);
|
||||
static CUresult cuDeviceGet(CUdevice *device, int ordinal);
|
||||
@@ -139,7 +144,6 @@ public:
|
||||
static CUresult cuDeviceGetName(char *name, int len, CUdevice dev);
|
||||
static CUresult cuDeviceGetPCIBusId(char *id, int len, CUdevice dev);
|
||||
static CUresult cuModuleGetGlobal_v2(CUdeviceptr *dptr, size_t* bytes, CUmodule hmod, const char *name);
|
||||
|
||||
static CUresult cuMemcpyHtoDAsync_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount, CUstream hStream);
|
||||
static CUresult cuModuleLoad(CUmodule *module, const char *fname);
|
||||
static CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void **kernelParams, void **extra);
|
||||
@@ -161,12 +165,12 @@ public:
|
||||
static CUresult cuPointerGetAttribute(void * data, CUpointer_attribute attribute, CUdeviceptr ptr);
|
||||
static CUresult cuCtxGetDevice(CUdevice* result);
|
||||
static CUresult cuMemsetD8Async(CUdeviceptr dst, unsigned char x, size_t N, CUstream stream);
|
||||
|
||||
// NVML
|
||||
static nvmlReturn_t nvmlDeviceGetHandleByPciBusId_v2( const char* pciBusId, nvmlDevice_t* device);
|
||||
static nvmlReturn_t nvmlDeviceGetClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock);
|
||||
static nvmlReturn_t nvmlDeviceGetMaxClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock);
|
||||
static nvmlReturn_t nvmlDeviceSetApplicationsClocks(nvmlDevice_t device, unsigned int mem_clock, unsigned int sm_clock);
|
||||
|
||||
// CUBLAS
|
||||
static cublasHandle_t cublasHandle(driver::cu_context const & ctx);
|
||||
static cublasStatus_t cublasCreate_v2(cublasHandle_t* h);
|
||||
static cublasStatus_t cublasGetStream_v2(cublasHandle_t h, cudaStream_t *streamId);
|
||||
@@ -175,7 +179,7 @@ public:
|
||||
static cublasStatus_t cublasDgemm_v2 (cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, double* alpha, const double *A, int lda, const double *B, int ldb, double* beta, double *C, int ldc);
|
||||
static cublasStatus_t cublasHgemm (cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, half* alpha, const half *A, int lda, const half *B, int ldb, half* beta, half *C, int ldc);
|
||||
static cublasStatus_t cublasGemmEx(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const void *alpha, const void *A, cudaDataType Atype, int lda, const void *B, cudaDataType Btype, int ldb, const void *beta, void *C, cudaDataType Ctype, int ldc, cudaDataType computeType, cublasGemmAlgo_t algo);
|
||||
|
||||
// CUDNN
|
||||
static cudnnHandle_t cudnnHandle(driver::cu_context const & ctx);
|
||||
static cudnnStatus_t cudnnCreatePoolingDescriptor(cudnnPoolingDescriptor_t *poolingDesc);
|
||||
static cudnnStatus_t cudnnCreateConvolutionDescriptor(cudnnConvolutionDescriptor_t* convDesc);
|
||||
@@ -196,6 +200,10 @@ public:
|
||||
static cudnnStatus_t cudnnSetStream(cudnnHandle_t handle, cudaStream_t streamId);
|
||||
static cudnnStatus_t cudnnTransformTensor(cudnnHandle_t handle, const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, const void *beta, const cudnnTensorDescriptor_t yDesc, void *y);
|
||||
|
||||
// SPIR-V libraries
|
||||
static int initializeLLVMToSPIRVPass(llvm::PassRegistry &);
|
||||
static bool writeSpirv(llvm::Module *M, std::ostream &OS, std::string &ErrMsg);
|
||||
|
||||
private:
|
||||
|
||||
// Libraries
|
||||
@@ -204,6 +212,10 @@ private:
|
||||
static void* nvml_;
|
||||
static void* cublas_;
|
||||
static void* cudnn_;
|
||||
static void* vulkan_;
|
||||
static void* spvllvm_;
|
||||
static void* spvcross_;
|
||||
static void* opengl_;
|
||||
|
||||
// OpenCL functions
|
||||
static void* clBuildProgram_;
|
||||
@@ -310,6 +322,10 @@ private:
|
||||
static void* cudnnPoolingForward_;
|
||||
static void* cudnnSetStream_;
|
||||
static void* cudnnTransformTensor_;
|
||||
|
||||
// LLVM to SPIR-V
|
||||
static void* initializeLLVMToSPIRVPass_;
|
||||
static void* writeSpirv_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -158,6 +158,12 @@ bool dispatch::cudnninit(){
|
||||
return cudnn_ != nullptr;
|
||||
}
|
||||
|
||||
bool dispatch::spvllvminit(){
|
||||
if(spvllvm_==nullptr)
|
||||
spvllvm_ = dlopen("libLLVMSPIRVLib.so", RTLD_LAZY);
|
||||
return spvllvm_ != nullptr;
|
||||
}
|
||||
|
||||
//CUDA
|
||||
CUDA_DEFINE1(CUresult, cuCtxDestroy_v2, CUcontext)
|
||||
CUDA_DEFINE2(CUresult, cuEventCreate, CUevent *, unsigned int)
|
||||
@@ -292,6 +298,15 @@ OCL_DEFINE5(cl_mem, clCreateBuffer, cl_context, cl_mem_flags, size_t, void *, cl
|
||||
OCL_DEFINE5(cl_program, clCreateProgramWithSource, cl_context, cl_uint, const char **, const size_t *, cl_int *)
|
||||
OCL_DEFINE1(cl_int, clReleaseKernel, cl_kernel)
|
||||
|
||||
// LLVM to SPIR-V
|
||||
int dispatch::initializeLLVMToSPIRVPass(llvm::PassRegistry ®istry){
|
||||
return f_impl<dispatch::spvllvminit>(spvllvm_, initializeLLVMToSPIRVPass, initializeLLVMToSPIRVPass_, "initializeLLVMToSPIRVPass", std::ref(registry));
|
||||
}
|
||||
|
||||
bool dispatch::writeSpirv(llvm::Module *M, std::ostream &OS, std::string &ErrMsg){
|
||||
return f_impl<dispatch::spvllvminit>(spvllvm_, writeSpirv, writeSpirv_, "writeSpirv", M, std::ref(OS), std::ref(ErrMsg));
|
||||
}
|
||||
|
||||
// Release
|
||||
void dispatch::release(){
|
||||
if(cuda_){
|
||||
@@ -313,6 +328,7 @@ void* dispatch::cuda_;
|
||||
void* dispatch::nvml_;
|
||||
void* dispatch::cublas_;
|
||||
void* dispatch::cudnn_;
|
||||
void* dispatch::spvllvm_;
|
||||
|
||||
//OpenCL
|
||||
void* dispatch::clBuildProgram_;
|
||||
@@ -421,5 +437,9 @@ void* dispatch::cudnnPoolingForward_;
|
||||
void* dispatch::cudnnSetStream_;
|
||||
void* dispatch::cudnnTransformTensor_;
|
||||
|
||||
// SPIR-V
|
||||
void* dispatch::initializeLLVMToSPIRVPass_;
|
||||
void* dispatch::writeSpirv_;
|
||||
|
||||
}
|
||||
}
|
||||
|
@@ -71,12 +71,11 @@ std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, passes_w
|
||||
llvm::Module* result = new llvm::Module(module.get_name(), llvm_context_);
|
||||
passes.selection.run(module, *result);
|
||||
// launch information
|
||||
launch_information info;
|
||||
launch_information& info = launch_info_map_[result->getName()];
|
||||
info.global_range_size.clear();
|
||||
for(unsigned i = 0; i < passes.tune.get_num_global_range(); i++)
|
||||
info.global_range_size.push_back(passes.tune.get_global_range_size(i));
|
||||
info.num_threads = passes.tune.get_num_threads();
|
||||
launch_info_map_.insert({result->getName(), info});
|
||||
return std::unique_ptr<llvm::Module>(result);
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user