[RUNTIME] remove fixed cu_include_dir (#739)

Use environment variable `CUDA_HOME` with default value`/usr/local/cuda` for `cu_include_dir` #731
This commit is contained in:
shenggan
2022-10-05 10:49:57 +08:00
committed by GitHub
parent d3c925db8a
commit 77c752dc78

View File

@@ -1107,6 +1107,12 @@ def libcuda_dirs():
return [os.path.dirname(loc) for loc in locs]
@functools.lru_cache()
def cuda_home_dirs():
default_dir = "/usr/local/cuda"
return os.getenv("CUDA_HOME", default=default_dir)
@contextlib.contextmanager
def quiet():
old_stdout, old_stderr = sys.stdout, sys.stderr
@@ -1119,7 +1125,7 @@ def quiet():
def _build(name, src, srcdir):
cuda_lib_dirs = libcuda_dirs()
cu_include_dir = "/usr/local/cuda/include"
cu_include_dir = os.path.join(cuda_home_dirs(), "include")
suffix = sysconfig.get_config_var('EXT_SUFFIX')
so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix))
# try to avoid setuptools if possible