cuBLAS: fixed CUDA context import
This commit is contained in:
@@ -19,11 +19,17 @@ class ISAACAPI Context
|
|||||||
friend class Program;
|
friend class Program;
|
||||||
friend class CommandQueue;
|
friend class CommandQueue;
|
||||||
friend class Buffer;
|
friend class Buffer;
|
||||||
|
|
||||||
static std::string cache_path();
|
static std::string cache_path();
|
||||||
|
|
||||||
|
static CUdevice device(CUcontext)
|
||||||
|
{
|
||||||
|
CUdevice res;
|
||||||
|
cuda::check(dispatch::cuCtxGetDevice(&res));
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
explicit Context(CUcontext const & context, CUdevice const & device, bool take_ownership = true);
|
explicit Context(CUcontext const & context, bool take_ownership = true);
|
||||||
explicit Context(cl_context const & context, bool take_ownership = true);
|
explicit Context(cl_context const & context, bool take_ownership = true);
|
||||||
explicit Context(Device const & device);
|
explicit Context(Device const & device);
|
||||||
|
|
||||||
|
@@ -114,6 +114,7 @@ public:
|
|||||||
static CUresult cuEventDestroy_v2(CUevent hEvent);
|
static CUresult cuEventDestroy_v2(CUevent hEvent);
|
||||||
static CUresult cuMemAlloc_v2(CUdeviceptr *dptr, size_t bytesize);
|
static CUresult cuMemAlloc_v2(CUdeviceptr *dptr, size_t bytesize);
|
||||||
static CUresult cuPointerGetAttribute(void * data, CUpointer_attribute attribute, CUdeviceptr ptr);
|
static CUresult cuPointerGetAttribute(void * data, CUpointer_attribute attribute, CUdeviceptr ptr);
|
||||||
|
static CUresult cuCtxGetDevice(CUdevice* result);
|
||||||
|
|
||||||
static nvrtcResult nvrtcCompileProgram(nvrtcProgram prog, int numOptions, const char **options);
|
static nvrtcResult nvrtcCompileProgram(nvrtcProgram prog, int numOptions, const char **options);
|
||||||
static nvrtcResult nvrtcGetProgramLogSize(nvrtcProgram prog, size_t *logSizeRet);
|
static nvrtcResult nvrtcGetProgramLogSize(nvrtcProgram prog, size_t *logSizeRet);
|
||||||
@@ -190,6 +191,7 @@ private:
|
|||||||
static void* cuEventDestroy_v2_;
|
static void* cuEventDestroy_v2_;
|
||||||
static void* cuMemAlloc_v2_;
|
static void* cuMemAlloc_v2_;
|
||||||
static void* cuPointerGetAttribute_;
|
static void* cuPointerGetAttribute_;
|
||||||
|
static void* cuCtxGetDevice_;
|
||||||
|
|
||||||
static void* nvrtcCompileProgram_;
|
static void* nvrtcCompileProgram_;
|
||||||
static void* nvrtcGetProgramLogSize_;
|
static void* nvrtcGetProgramLogSize_;
|
||||||
|
@@ -34,7 +34,9 @@ std::string Context::cache_path()
|
|||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
|
|
||||||
Context::Context(CUcontext const & context, CUdevice const & device, bool take_ownership) : backend_(CUDA), device_(device, false), cache_path_(cache_path()), h_(backend_, take_ownership)
|
|
||||||
|
|
||||||
|
Context::Context(CUcontext const & context, bool take_ownership) : backend_(CUDA), device_(device(context), false), cache_path_(cache_path()), h_(backend_, take_ownership)
|
||||||
{
|
{
|
||||||
h_.cu() = context;
|
h_.cu() = context;
|
||||||
}
|
}
|
||||||
@@ -59,6 +61,7 @@ Context::Context(Device const & device) : backend_(device.backend_), device_(dev
|
|||||||
default:
|
default:
|
||||||
throw;
|
throw;
|
||||||
}
|
}
|
||||||
|
std::cout << "Shouldn't happen" << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Context::operator==(Context const & other) const
|
bool Context::operator==(Context const & other) const
|
||||||
|
@@ -166,6 +166,7 @@ CUDA_DEFINE1(CUresult, cuStreamDestroy_v2, CUstream)
|
|||||||
CUDA_DEFINE1(CUresult, cuEventDestroy_v2, CUevent)
|
CUDA_DEFINE1(CUresult, cuEventDestroy_v2, CUevent)
|
||||||
CUDA_DEFINE2(CUresult, cuMemAlloc_v2, CUdeviceptr*, size_t)
|
CUDA_DEFINE2(CUresult, cuMemAlloc_v2, CUdeviceptr*, size_t)
|
||||||
CUDA_DEFINE3(CUresult, cuPointerGetAttribute, void*, CUpointer_attribute, CUdeviceptr)
|
CUDA_DEFINE3(CUresult, cuPointerGetAttribute, void*, CUpointer_attribute, CUdeviceptr)
|
||||||
|
CUDA_DEFINE1(CUresult, cuCtxGetDevice, CUdevice*)
|
||||||
|
|
||||||
NVRTC_DEFINE3(nvrtcResult, nvrtcCompileProgram, nvrtcProgram, int, const char **)
|
NVRTC_DEFINE3(nvrtcResult, nvrtcCompileProgram, nvrtcProgram, int, const char **)
|
||||||
NVRTC_DEFINE2(nvrtcResult, nvrtcGetProgramLogSize, nvrtcProgram, size_t *)
|
NVRTC_DEFINE2(nvrtcResult, nvrtcGetProgramLogSize, nvrtcProgram, size_t *)
|
||||||
@@ -257,6 +258,7 @@ void* dispatch::cuStreamDestroy_v2_;
|
|||||||
void* dispatch::cuEventDestroy_v2_;
|
void* dispatch::cuEventDestroy_v2_;
|
||||||
void* dispatch::cuMemAlloc_v2_;
|
void* dispatch::cuMemAlloc_v2_;
|
||||||
void* dispatch::cuPointerGetAttribute_;
|
void* dispatch::cuPointerGetAttribute_;
|
||||||
|
void* dispatch::cuCtxGetDevice_;
|
||||||
|
|
||||||
void* dispatch::nvrtcCompileProgram_;
|
void* dispatch::nvrtcCompileProgram_;
|
||||||
void* dispatch::nvrtcGetProgramLogSize_;
|
void* dispatch::nvrtcGetProgramLogSize_;
|
||||||
|
@@ -30,6 +30,7 @@ extern "C"
|
|||||||
cublasStatus_t cublasDestroy_v2 (cublasHandle_t handle)
|
cublasStatus_t cublasDestroy_v2 (cublasHandle_t handle)
|
||||||
{
|
{
|
||||||
delete handle;
|
delete handle;
|
||||||
|
cublasShutdown();
|
||||||
return CUBLAS_STATUS_SUCCESS;
|
return CUBLAS_STATUS_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user