diff --git a/lib/driver/dispatch.cc b/lib/driver/dispatch.cc index 4059ac235..9e2aca432 100755 --- a/lib/driver/dispatch.cc +++ b/lib/driver/dispatch.cc @@ -91,9 +91,13 @@ void* dispatch::fname ## _; bool dispatch::cuinit(){ if(cuda_==nullptr){ + #ifdef _WIN32 + cuda_ = dlopen("cudart64_110.dll", RTLD_LAZY); + #else cuda_ = dlopen("libcuda.so", RTLD_LAZY); if(!cuda_) cuda_ = dlopen("libcuda.so.1", RTLD_LAZY); + #endif if(!cuda_) throw std::runtime_error("Could not find `libcuda.so`. Make sure it is in your LD_LIBRARY_PATH."); } @@ -176,8 +180,13 @@ CUDA_DEFINE1(CUresult, cuEventDestroy_v2, CUevent) * NVML * ------------------- */ bool dispatch::nvmlinit(){ + #ifdef _WIN32 + if(nvml_==nullptr) + nvml_ = dlopen("nvml.dll", RTLD_LAZY); + #else if(nvml_==nullptr) nvml_ = dlopen("libnvidia-ml.so", RTLD_LAZY); + #endif nvmlReturn_t (*fptr)(); nvmlInit_v2_ = dlsym(nvml_, "nvmlInit_v2"); *reinterpret_cast(&fptr) = nvmlInit_v2_;