cuBLAS: fixed CUDA context import

This commit is contained in:
Philippe Tillet
2015-11-26 21:09:34 -05:00
parent 6fc94c0c0b
commit c0b9bbee43
5 changed files with 17 additions and 3 deletions

View File

@@ -19,11 +19,17 @@ class ISAACAPI Context
friend class Program;
friend class CommandQueue;
friend class Buffer;
static std::string cache_path();
static CUdevice device(CUcontext)
{
CUdevice res;
cuda::check(dispatch::cuCtxGetDevice(&res));
return res;
}
public:
explicit Context(CUcontext const & context, CUdevice const & device, bool take_ownership = true);
explicit Context(CUcontext const & context, bool take_ownership = true);
explicit Context(cl_context const & context, bool take_ownership = true);
explicit Context(Device const & device);

View File

@@ -114,6 +114,7 @@ public:
static CUresult cuEventDestroy_v2(CUevent hEvent);
static CUresult cuMemAlloc_v2(CUdeviceptr *dptr, size_t bytesize);
static CUresult cuPointerGetAttribute(void * data, CUpointer_attribute attribute, CUdeviceptr ptr);
static CUresult cuCtxGetDevice(CUdevice* result);
static nvrtcResult nvrtcCompileProgram(nvrtcProgram prog, int numOptions, const char **options);
static nvrtcResult nvrtcGetProgramLogSize(nvrtcProgram prog, size_t *logSizeRet);
@@ -190,6 +191,7 @@ private:
static void* cuEventDestroy_v2_;
static void* cuMemAlloc_v2_;
static void* cuPointerGetAttribute_;
static void* cuCtxGetDevice_;
static void* nvrtcCompileProgram_;
static void* nvrtcGetProgramLogSize_;

View File

@@ -34,7 +34,9 @@ std::string Context::cache_path()
return "";
}
Context::Context(CUcontext const & context, CUdevice const & device, bool take_ownership) : backend_(CUDA), device_(device, false), cache_path_(cache_path()), h_(backend_, take_ownership)
Context::Context(CUcontext const & context, bool take_ownership) : backend_(CUDA), device_(device(context), false), cache_path_(cache_path()), h_(backend_, take_ownership)
{
h_.cu() = context;
}
@@ -59,6 +61,7 @@ Context::Context(Device const & device) : backend_(device.backend_), device_(dev
default:
throw;
}
std::cout << "Shouldn't happen" << std::endl;
}
bool Context::operator==(Context const & other) const

View File

@@ -166,6 +166,7 @@ CUDA_DEFINE1(CUresult, cuStreamDestroy_v2, CUstream)
CUDA_DEFINE1(CUresult, cuEventDestroy_v2, CUevent)
CUDA_DEFINE2(CUresult, cuMemAlloc_v2, CUdeviceptr*, size_t)
CUDA_DEFINE3(CUresult, cuPointerGetAttribute, void*, CUpointer_attribute, CUdeviceptr)
CUDA_DEFINE1(CUresult, cuCtxGetDevice, CUdevice*)
NVRTC_DEFINE3(nvrtcResult, nvrtcCompileProgram, nvrtcProgram, int, const char **)
NVRTC_DEFINE2(nvrtcResult, nvrtcGetProgramLogSize, nvrtcProgram, size_t *)
@@ -257,6 +258,7 @@ void* dispatch::cuStreamDestroy_v2_;
void* dispatch::cuEventDestroy_v2_;
void* dispatch::cuMemAlloc_v2_;
void* dispatch::cuPointerGetAttribute_;
void* dispatch::cuCtxGetDevice_;
void* dispatch::nvrtcCompileProgram_;
void* dispatch::nvrtcGetProgramLogSize_;

View File

@@ -30,6 +30,7 @@ extern "C"
cublasStatus_t cublasDestroy_v2 (cublasHandle_t handle)
{
delete handle;
cublasShutdown();
return CUBLAS_STATUS_SUCCESS;
}