cuBLAS: fixed CUDA context import
This commit is contained in:
@@ -19,11 +19,17 @@ class ISAACAPI Context
|
||||
friend class Program;
|
||||
friend class CommandQueue;
|
||||
friend class Buffer;
|
||||
|
||||
static std::string cache_path();
|
||||
|
||||
static CUdevice device(CUcontext)
|
||||
{
|
||||
CUdevice res;
|
||||
cuda::check(dispatch::cuCtxGetDevice(&res));
|
||||
return res;
|
||||
}
|
||||
|
||||
public:
|
||||
explicit Context(CUcontext const & context, CUdevice const & device, bool take_ownership = true);
|
||||
explicit Context(CUcontext const & context, bool take_ownership = true);
|
||||
explicit Context(cl_context const & context, bool take_ownership = true);
|
||||
explicit Context(Device const & device);
|
||||
|
||||
|
@@ -114,6 +114,7 @@ public:
|
||||
static CUresult cuEventDestroy_v2(CUevent hEvent);
|
||||
static CUresult cuMemAlloc_v2(CUdeviceptr *dptr, size_t bytesize);
|
||||
static CUresult cuPointerGetAttribute(void * data, CUpointer_attribute attribute, CUdeviceptr ptr);
|
||||
static CUresult cuCtxGetDevice(CUdevice* result);
|
||||
|
||||
static nvrtcResult nvrtcCompileProgram(nvrtcProgram prog, int numOptions, const char **options);
|
||||
static nvrtcResult nvrtcGetProgramLogSize(nvrtcProgram prog, size_t *logSizeRet);
|
||||
@@ -190,6 +191,7 @@ private:
|
||||
static void* cuEventDestroy_v2_;
|
||||
static void* cuMemAlloc_v2_;
|
||||
static void* cuPointerGetAttribute_;
|
||||
static void* cuCtxGetDevice_;
|
||||
|
||||
static void* nvrtcCompileProgram_;
|
||||
static void* nvrtcGetProgramLogSize_;
|
||||
|
@@ -34,7 +34,9 @@ std::string Context::cache_path()
|
||||
return "";
|
||||
}
|
||||
|
||||
Context::Context(CUcontext const & context, CUdevice const & device, bool take_ownership) : backend_(CUDA), device_(device, false), cache_path_(cache_path()), h_(backend_, take_ownership)
|
||||
|
||||
|
||||
Context::Context(CUcontext const & context, bool take_ownership) : backend_(CUDA), device_(device(context), false), cache_path_(cache_path()), h_(backend_, take_ownership)
|
||||
{
|
||||
h_.cu() = context;
|
||||
}
|
||||
@@ -59,6 +61,7 @@ Context::Context(Device const & device) : backend_(device.backend_), device_(dev
|
||||
default:
|
||||
throw;
|
||||
}
|
||||
std::cout << "Shouldn't happen" << std::endl;
|
||||
}
|
||||
|
||||
bool Context::operator==(Context const & other) const
|
||||
|
@@ -166,6 +166,7 @@ CUDA_DEFINE1(CUresult, cuStreamDestroy_v2, CUstream)
|
||||
CUDA_DEFINE1(CUresult, cuEventDestroy_v2, CUevent)
|
||||
CUDA_DEFINE2(CUresult, cuMemAlloc_v2, CUdeviceptr*, size_t)
|
||||
CUDA_DEFINE3(CUresult, cuPointerGetAttribute, void*, CUpointer_attribute, CUdeviceptr)
|
||||
CUDA_DEFINE1(CUresult, cuCtxGetDevice, CUdevice*)
|
||||
|
||||
NVRTC_DEFINE3(nvrtcResult, nvrtcCompileProgram, nvrtcProgram, int, const char **)
|
||||
NVRTC_DEFINE2(nvrtcResult, nvrtcGetProgramLogSize, nvrtcProgram, size_t *)
|
||||
@@ -257,6 +258,7 @@ void* dispatch::cuStreamDestroy_v2_;
|
||||
void* dispatch::cuEventDestroy_v2_;
|
||||
void* dispatch::cuMemAlloc_v2_;
|
||||
void* dispatch::cuPointerGetAttribute_;
|
||||
void* dispatch::cuCtxGetDevice_;
|
||||
|
||||
void* dispatch::nvrtcCompileProgram_;
|
||||
void* dispatch::nvrtcGetProgramLogSize_;
|
||||
|
@@ -30,6 +30,7 @@ extern "C"
|
||||
cublasStatus_t cublasDestroy_v2 (cublasHandle_t handle)
|
||||
{
|
||||
delete handle;
|
||||
cublasShutdown();
|
||||
return CUBLAS_STATUS_SUCCESS;
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user