[driver] added TRITON_LIBCUDA environment variable to specify libcuda
path if not in LD_LIBRARY_PATH
This commit is contained in:
@@ -38,11 +38,11 @@ namespace tools
|
||||
std::size_t sz = 0;
|
||||
_dupenv_s(&cache_path, &sz, name);
|
||||
#else
|
||||
const char * cache_path = std::getenv(name);
|
||||
const char * cstr = std::getenv(name);
|
||||
#endif
|
||||
if(!cache_path)
|
||||
if(!cstr)
|
||||
return "";
|
||||
std::string result(cache_path);
|
||||
std::string result(cstr);
|
||||
#ifdef _MSC_VER
|
||||
free(cache_path);
|
||||
#endif
|
||||
|
@@ -22,6 +22,7 @@
|
||||
|
||||
#include "triton/driver/dispatch.h"
|
||||
#include "triton/driver/context.h"
|
||||
#include "triton/tools/sys/getenv.hpp"
|
||||
|
||||
namespace triton
|
||||
{
|
||||
@@ -108,8 +109,13 @@ bool dispatch::clinit()
|
||||
}
|
||||
|
||||
bool dispatch::cuinit(){
|
||||
if(cuda_==nullptr)
|
||||
cuda_ = dlopen("libcuda.so", RTLD_LAZY);
|
||||
if(cuda_==nullptr){
|
||||
std::string libcuda = tools::getenv("TRITON_LIBCUDA");
|
||||
if(libcuda.empty())
|
||||
cuda_ = dlopen("libcuda.so", RTLD_LAZY);
|
||||
else
|
||||
cuda_ = dlopen(libcuda.c_str(), RTLD_LAZY);
|
||||
}
|
||||
if(cuda_ == nullptr)
|
||||
return false;
|
||||
CUresult (*fptr)(unsigned int);
|
||||
|
Reference in New Issue
Block a user