diff --git a/include/isaac/driver/context.h b/include/isaac/driver/context.h index 0c548a1d4..aacf6660c 100644 --- a/include/isaac/driver/context.h +++ b/include/isaac/driver/context.h @@ -19,11 +19,17 @@ class ISAACAPI Context friend class Program; friend class CommandQueue; friend class Buffer; - static std::string cache_path(); + static CUdevice device(CUcontext) + { + CUdevice res; + cuda::check(dispatch::cuCtxGetDevice(&res)); + return res; + } + public: - explicit Context(CUcontext const & context, CUdevice const & device, bool take_ownership = true); + explicit Context(CUcontext const & context, bool take_ownership = true); explicit Context(cl_context const & context, bool take_ownership = true); explicit Context(Device const & device); diff --git a/include/isaac/driver/dispatch.h b/include/isaac/driver/dispatch.h index dba45232f..64ca4fd10 100644 --- a/include/isaac/driver/dispatch.h +++ b/include/isaac/driver/dispatch.h @@ -114,6 +114,7 @@ public: static CUresult cuEventDestroy_v2(CUevent hEvent); static CUresult cuMemAlloc_v2(CUdeviceptr *dptr, size_t bytesize); static CUresult cuPointerGetAttribute(void * data, CUpointer_attribute attribute, CUdeviceptr ptr); + static CUresult cuCtxGetDevice(CUdevice* result); static nvrtcResult nvrtcCompileProgram(nvrtcProgram prog, int numOptions, const char **options); static nvrtcResult nvrtcGetProgramLogSize(nvrtcProgram prog, size_t *logSizeRet); @@ -190,6 +191,7 @@ private: static void* cuEventDestroy_v2_; static void* cuMemAlloc_v2_; static void* cuPointerGetAttribute_; + static void* cuCtxGetDevice_; static void* nvrtcCompileProgram_; static void* nvrtcGetProgramLogSize_; diff --git a/lib/driver/context.cpp b/lib/driver/context.cpp index 33d23e032..c1e0bbf5c 100644 --- a/lib/driver/context.cpp +++ b/lib/driver/context.cpp @@ -34,7 +34,9 @@ std::string Context::cache_path() return ""; } -Context::Context(CUcontext const & context, CUdevice const & device, bool take_ownership) : backend_(CUDA), device_(device, false), cache_path_(cache_path()), h_(backend_, take_ownership) + + +Context::Context(CUcontext const & context, bool take_ownership) : backend_(CUDA), device_(device(context), false), cache_path_(cache_path()), h_(backend_, take_ownership) { h_.cu() = context; } @@ -59,6 +61,7 @@ Context::Context(Device const & device) : backend_(device.backend_), device_(dev default: throw; } + std::cout << "Shouldn't happen" << std::endl; } bool Context::operator==(Context const & other) const diff --git a/lib/driver/dispatch.cpp b/lib/driver/dispatch.cpp index ef2a3fd01..449d6626a 100644 --- a/lib/driver/dispatch.cpp +++ b/lib/driver/dispatch.cpp @@ -166,6 +166,7 @@ CUDA_DEFINE1(CUresult, cuStreamDestroy_v2, CUstream) CUDA_DEFINE1(CUresult, cuEventDestroy_v2, CUevent) CUDA_DEFINE2(CUresult, cuMemAlloc_v2, CUdeviceptr*, size_t) CUDA_DEFINE3(CUresult, cuPointerGetAttribute, void*, CUpointer_attribute, CUdeviceptr) +CUDA_DEFINE1(CUresult, cuCtxGetDevice, CUdevice*) NVRTC_DEFINE3(nvrtcResult, nvrtcCompileProgram, nvrtcProgram, int, const char **) NVRTC_DEFINE2(nvrtcResult, nvrtcGetProgramLogSize, nvrtcProgram, size_t *) @@ -257,6 +258,7 @@ void* dispatch::cuStreamDestroy_v2_; void* dispatch::cuEventDestroy_v2_; void* dispatch::cuMemAlloc_v2_; void* dispatch::cuPointerGetAttribute_; +void* dispatch::cuCtxGetDevice_; void* dispatch::nvrtcCompileProgram_; void* dispatch::nvrtcGetProgramLogSize_; diff --git a/lib/wrap/cublas.cpp b/lib/wrap/cublas.cpp index 3269d4332..3d2b20c7f 100644 --- a/lib/wrap/cublas.cpp +++ b/lib/wrap/cublas.cpp @@ -30,6 +30,7 @@ extern "C" cublasStatus_t cublasDestroy_v2 (cublasHandle_t handle) { delete handle; + cublasShutdown(); return CUBLAS_STATUS_SUCCESS; }