|
|
@@ -38,6 +38,11 @@
|
|
|
|
#include <iostream>
|
|
|
|
#include <iostream>
|
|
|
|
#include <stdexcept>
|
|
|
|
#include <stdexcept>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Forward declarations of the two LLVM types named in the SPIR-V entry-point
// signatures of this header (initializeLLVMToSPIRVPass takes an
// llvm::PassRegistry&, writeSpirv takes an llvm::Module*). Both are used only
// by reference or pointer here, so an incomplete type is sufficient.
// NOTE(review): presumably chosen to keep LLVM headers out of this file - confirm.
namespace llvm {
|
|
|
|
|
|
|
|
// Only referenced as `llvm::PassRegistry &`.
class PassRegistry;
|
|
|
|
|
|
|
|
// Only referenced as `llvm::Module *`.
class Module;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
namespace triton
|
|
|
|
namespace triton
|
|
|
|
{
|
|
|
|
{
|
|
|
|
namespace driver
|
|
|
|
namespace driver
|
|
|
@@ -85,6 +90,7 @@ public:
|
|
|
|
// Per-library initialisation hooks followed by a single teardown hook.
// Every *init function takes no arguments and returns bool.
// NOTE(review): presumably `true` means the corresponding library is available
// and loaded - the definitions are not part of this excerpt, confirm there.
static bool cuinit();
|
|
|
|
static bool cuinit();
|
|
|
|
static bool cublasinit();
|
|
|
|
static bool cublasinit();
|
|
|
|
static bool cudnninit();
|
|
|
|
static bool cudnninit();
|
|
|
|
|
|
|
|
// Counterpart of the hooks above for the SPIR-V/LLVM library.
// NOTE(review): presumably tied to the spvllvm_ member declared further down - confirm.
static bool spvllvminit();
|
|
|
|
// Teardown hook; takes no arguments and returns nothing.
// NOTE(review): which libraries it releases is not visible here - confirm.
static void release();
|
|
|
|
static void release();
|
|
|
|
|
|
|
|
|
|
|
|
// OpenCL
|
|
|
|
// OpenCL
|
|
|
@@ -126,7 +132,6 @@ public:
|
|
|
|
// CUDA
// Static wrappers named after CUDA driver API functions. Each one returns a
// CUresult and repeats the parameter list spelled out in its declaration.
// NOTE(review): presumably each forwards to the same-named driver symbol
// resolved at run time - definitions are not shown here, confirm.
|
|
|
|
// CUDA
|
|
|
|
// Context query: writes through `pctx`.
static CUresult cuCtxGetCurrent(CUcontext *pctx);
|
|
|
|
static CUresult cuCtxGetCurrent(CUcontext *pctx);
|
|
|
|
static CUresult cuCtxSetCurrent(CUcontext ctx);
|
|
|
|
static CUresult cuCtxSetCurrent(CUcontext ctx);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// `_v2` suffix is part of the wrapped symbol name.
static CUresult cuCtxDestroy_v2(CUcontext ctx);
|
|
|
|
static CUresult cuCtxDestroy_v2(CUcontext ctx);
|
|
|
|
static CUresult cuEventCreate(CUevent *phEvent, unsigned int Flags);
|
|
|
|
static CUresult cuEventCreate(CUevent *phEvent, unsigned int Flags);
|
|
|
|
// `ordinal` selects the device; the handle is written through `device`.
static CUresult cuDeviceGet(CUdevice *device, int ordinal);
|
|
|
|
static CUresult cuDeviceGet(CUdevice *device, int ordinal);
|
|
|
@@ -139,7 +144,6 @@ public:
|
|
|
|
// More CUresult-returning wrappers named after CUDA driver API functions:
// device identification, module loading/lookup, async host-to-device copy and
// kernel launch.
// The caller supplies the output buffer together with its length `len`.
static CUresult cuDeviceGetName(char *name, int len, CUdevice dev);
|
|
|
|
static CUresult cuDeviceGetName(char *name, int len, CUdevice dev);
|
|
|
|
static CUresult cuDeviceGetPCIBusId(char *id, int len, CUdevice dev);
|
|
|
|
static CUresult cuDeviceGetPCIBusId(char *id, int len, CUdevice dev);
|
|
|
|
// Looks up `name` in module `hmod`; results are written through `dptr` and `bytes`.
static CUresult cuModuleGetGlobal_v2(CUdeviceptr *dptr, size_t* bytes, CUmodule hmod, const char *name);
|
|
|
|
static CUresult cuModuleGetGlobal_v2(CUdeviceptr *dptr, size_t* bytes, CUmodule hmod, const char *name);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Copy of `ByteCount` bytes from `srcHost` to `dstDevice`, issued on `hStream`.
static CUresult cuMemcpyHtoDAsync_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount, CUstream hStream);
|
|
|
|
static CUresult cuMemcpyHtoDAsync_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount, CUstream hStream);
|
|
|
|
static CUresult cuModuleLoad(CUmodule *module, const char *fname);
|
|
|
|
static CUresult cuModuleLoad(CUmodule *module, const char *fname);
|
|
|
|
// Launch parameters in declaration order: function, grid dims (X, Y, Z),
// block dims (X, Y, Z), shared memory bytes, stream, kernel params, extra.
static CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void **kernelParams, void **extra);
|
|
|
|
static CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void **kernelParams, void **extra);
|
|
|
@@ -161,12 +165,12 @@ public:
|
|
|
|
// Remaining CUresult-returning CUDA wrappers in this group: pointer attribute
// query, current-device query and asynchronous byte-wise memset.
// The queried value is written through the untyped `data` pointer.
static CUresult cuPointerGetAttribute(void * data, CUpointer_attribute attribute, CUdeviceptr ptr);
|
|
|
|
static CUresult cuPointerGetAttribute(void * data, CUpointer_attribute attribute, CUdeviceptr ptr);
|
|
|
|
static CUresult cuCtxGetDevice(CUdevice* result);
|
|
|
|
static CUresult cuCtxGetDevice(CUdevice* result);
|
|
|
|
// Arguments in order: destination, byte value `x`, count `N`, stream.
static CUresult cuMemsetD8Async(CUdeviceptr dst, unsigned char x, size_t N, CUstream stream);
|
|
|
|
static CUresult cuMemsetD8Async(CUdeviceptr dst, unsigned char x, size_t N, CUstream stream);
|
|
|
|
|
|
|
|
// NVML
// Static wrappers named after NVML functions; each returns nvmlReturn_t.
// NOTE(review): presumably forwarded to the library held in the nvml_ member -
// definitions are not shown here, confirm.
|
|
|
|
// Device lookup by PCI bus id string; the handle is written through `device`.
static nvmlReturn_t nvmlDeviceGetHandleByPciBusId_v2( const char* pciBusId, nvmlDevice_t* device);
|
|
|
|
static nvmlReturn_t nvmlDeviceGetHandleByPciBusId_v2( const char* pciBusId, nvmlDevice_t* device);
|
|
|
|
// Clock queries: the value for clock domain `type` is written through `clock`.
static nvmlReturn_t nvmlDeviceGetClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock);
|
|
|
|
static nvmlReturn_t nvmlDeviceGetClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock);
|
|
|
|
static nvmlReturn_t nvmlDeviceGetMaxClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock);
|
|
|
|
static nvmlReturn_t nvmlDeviceGetMaxClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock);
|
|
|
|
// Note the argument order: memory clock first, then SM clock.
static nvmlReturn_t nvmlDeviceSetApplicationsClocks(nvmlDevice_t device, unsigned int mem_clock, unsigned int sm_clock);
|
|
|
|
static nvmlReturn_t nvmlDeviceSetApplicationsClocks(nvmlDevice_t device, unsigned int mem_clock, unsigned int sm_clock);
|
|
|
|
|
|
|
|
// CUBLAS
// cuBLAS handle accessor plus wrappers named after cuBLAS functions
// (the wrappers return cublasStatus_t).
|
|
|
|
// Returns a cublasHandle_t by value for the given context.
// NOTE(review): whether the handle is created lazily and cached per context
// is not visible here - confirm in the definition.
static cublasHandle_t cublasHandle(driver::cu_context const & ctx);
|
|
|
|
static cublasHandle_t cublasHandle(driver::cu_context const & ctx);
|
|
|
|
static cublasStatus_t cublasCreate_v2(cublasHandle_t* h);
|
|
|
|
static cublasStatus_t cublasCreate_v2(cublasHandle_t* h);
|
|
|
|
// The stream associated with `h` is written through `streamId`.
static cublasStatus_t cublasGetStream_v2(cublasHandle_t h, cudaStream_t *streamId);
|
|
|
|
static cublasStatus_t cublasGetStream_v2(cublasHandle_t h, cudaStream_t *streamId);
|
|
|
@@ -175,7 +179,7 @@ public:
|
|
|
|
// GEMM wrappers named after cuBLAS functions, for double, half and
// mixed/explicitly-typed operands. All return cublasStatus_t.
// NOTE(review): `alpha` and `beta` are declared as non-const pointers in the
// Dgemm and Hgemm wrappers but as `const void *` in GemmEx; the upstream cuBLAS
// prototypes use const scalars - confirm whether the difference is intentional.
static cublasStatus_t cublasDgemm_v2 (cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, double* alpha, const double *A, int lda, const double *B, int ldb, double* beta, double *C, int ldc);
|
|
|
|
static cublasStatus_t cublasDgemm_v2 (cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, double* alpha, const double *A, int lda, const double *B, int ldb, double* beta, double *C, int ldc);
|
|
|
|
static cublasStatus_t cublasHgemm (cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, half* alpha, const half *A, int lda, const half *B, int ldb, half* beta, half *C, int ldc);
|
|
|
|
static cublasStatus_t cublasHgemm (cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, half* alpha, const half *A, int lda, const half *B, int ldb, half* beta, half *C, int ldc);
|
|
|
|
// Operand element types (Atype/Btype/Ctype), the compute type and the
// algorithm are passed explicitly; data pointers are untyped.
static cublasStatus_t cublasGemmEx(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const void *alpha, const void *A, cudaDataType Atype, int lda, const void *B, cudaDataType Btype, int ldb, const void *beta, void *C, cudaDataType Ctype, int ldc, cudaDataType computeType, cublasGemmAlgo_t algo);
|
|
|
|
static cublasStatus_t cublasGemmEx(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const void *alpha, const void *A, cudaDataType Atype, int lda, const void *B, cudaDataType Btype, int ldb, const void *beta, void *C, cudaDataType Ctype, int ldc, cudaDataType computeType, cublasGemmAlgo_t algo);
|
|
|
|
|
|
|
|
// CUDNN
// cuDNN handle accessor plus wrappers named after cuDNN functions
// (the wrappers return cudnnStatus_t).
|
|
|
|
// Returns a cudnnHandle_t by value for the given context.
// NOTE(review): handle lifetime/caching is not visible here - confirm in the definition.
static cudnnHandle_t cudnnHandle(driver::cu_context const & ctx);
|
|
|
|
static cudnnHandle_t cudnnHandle(driver::cu_context const & ctx);
|
|
|
|
// Descriptor creation: the new descriptor is written through the pointer argument.
static cudnnStatus_t cudnnCreatePoolingDescriptor(cudnnPoolingDescriptor_t *poolingDesc);
|
|
|
|
static cudnnStatus_t cudnnCreatePoolingDescriptor(cudnnPoolingDescriptor_t *poolingDesc);
|
|
|
|
static cudnnStatus_t cudnnCreateConvolutionDescriptor(cudnnConvolutionDescriptor_t* convDesc);
|
|
|
|
static cudnnStatus_t cudnnCreateConvolutionDescriptor(cudnnConvolutionDescriptor_t* convDesc);
|
|
|
@@ -196,6 +200,10 @@ public:
|
|
|
|
// Last cuDNN wrappers of the public section; both return cudnnStatus_t.
// Takes the cuDNN handle and the stream to associate with it.
static cudnnStatus_t cudnnSetStream(cudnnHandle_t handle, cudaStream_t streamId);
|
|
|
|
static cudnnStatus_t cudnnSetStream(cudnnHandle_t handle, cudaStream_t streamId);
|
|
|
|
// `x` (described by xDesc) is the read-only input, `y` (described by yDesc)
// is the writable output; `alpha` and `beta` are untyped scalar pointers.
static cudnnStatus_t cudnnTransformTensor(cudnnHandle_t handle, const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, const void *beta, const cudnnTensorDescriptor_t yDesc, void *y);
|
|
|
|
static cudnnStatus_t cudnnTransformTensor(cudnnHandle_t handle, const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, const void *beta, const cudnnTensorDescriptor_t yDesc, void *y);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// SPIR-V libraries
// Entry points for LLVM-to-SPIR-V translation. They rely on the forward
// declarations of llvm::PassRegistry and llvm::Module made in this header.
|
|
|
|
|
|
|
|
// Takes the pass registry by non-const reference and returns int.
// NOTE(review): the meaning of the int result is not visible here - confirm.
static int initializeLLVMToSPIRVPass(llvm::PassRegistry &);
|
|
|
|
|
|
|
|
// Takes module `M`, output stream `OS` and a non-const string `ErrMsg`; returns bool.
// NOTE(review): presumably writes SPIR-V for `M` to `OS`, returns false on
// failure and fills `ErrMsg` - confirm against the translator library's contract.
// NOTE(review): std::string is used here but only <iostream> and <stdexcept>
// are visible among the includes of this excerpt - confirm <string> is reachable.
static bool writeSpirv(llvm::Module *M, std::ostream &OS, std::string &ErrMsg);
|
|
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
private:
|
|
|
|
|
|
|
|
|
|
|
|
// Libraries
|
|
|
|
// Libraries
|
|
|
@@ -204,6 +212,10 @@ private:
|
|
|
|
// One opaque static pointer per external library (listed under the file's
// "Libraries" heading).
// NOTE(review): presumably these hold dynamic-loader handles, set by the
// *init functions and cleared by release() - the definitions are not shown
// here, confirm.
static void* nvml_;
|
|
|
|
static void* nvml_;
|
|
|
|
static void* cublas_;
|
|
|
|
static void* cublas_;
|
|
|
|
static void* cudnn_;
|
|
|
|
static void* cudnn_;
|
|
|
|
|
|
|
|
// Slots for the Vulkan, SPIR-V/LLVM, SPIRV-Cross and OpenGL libraries
// (named after the members; nothing else about them is visible here).
static void* vulkan_;
|
|
|
|
|
|
|
|
static void* spvllvm_;
|
|
|
|
|
|
|
|
static void* spvcross_;
|
|
|
|
|
|
|
|
static void* opengl_;
|
|
|
|
|
|
|
|
|
|
|
|
// OpenCL functions
// Untyped static pointer slots, one per OpenCL entry point, each named after
// the function with a trailing underscore.
// NOTE(review): presumably caches the resolved function address - confirm.
|
|
|
|
// OpenCL functions
|
|
|
|
static void* clBuildProgram_;
|
|
|
|
static void* clBuildProgram_;
|
|
|
@@ -310,6 +322,10 @@ private:
|
|
|
|
// Untyped static pointer slots named after cuDNN functions (function name plus
// trailing underscore). cudnnSetStream_ and cudnnTransformTensor_ match the
// public wrappers of the same names declared earlier in this header.
// NOTE(review): presumably each caches the resolved function address - confirm.
static void* cudnnPoolingForward_;
|
|
|
|
static void* cudnnPoolingForward_;
|
|
|
|
static void* cudnnSetStream_;
|
|
|
|
static void* cudnnSetStream_;
|
|
|
|
static void* cudnnTransformTensor_;
|
|
|
|
static void* cudnnTransformTensor_;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// LLVM to SPIR-V
// Untyped static pointer slots matching the two public SPIR-V entry points
// initializeLLVMToSPIRVPass and writeSpirv (same names plus trailing underscore).
// NOTE(review): presumably each caches the address resolved from the
// SPIR-V/LLVM library - confirm in the definitions.
|
|
|
|
|
|
|
|
static void* initializeLLVMToSPIRVPass_;
|
|
|
|
|
|
|
|
static void* writeSpirv_;
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|