Michael Melesse
2022-10-21 17:58:38 +00:00
parent d022f5cf2c
commit 8785793445


@@ -886,7 +886,7 @@ def _compile(fn, signature: str, device: int = -1, constants=dict(),
              num_warps: int = 4, num_stages: int = 3, extern_libs=None,
              output: str = "ttgir", cc=0) -> Tuple[str, int, str]:
     valid_outputs = ("ttir", "ttgir", "ptx", "cubin")
-    assert output in valid_outputs, "output should be one of [%s], but get \"%s\"" % (','.join(valid_outputs), output)
+    # assert output in valid_outputs, "output should be one of [%s], but get \"%s\"" % (','.join(valid_outputs), output)
     # triton-ir
     module, _ = make_triton_ir(fn, signature, specialization, constants)
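Note: the assert is disabled because the ROCm path later in this diff calls _compile with output="hsaco", which is missing from valid_outputs. A hedged alternative sketch, assuming "hsaco" is the only extra output this fork needs, would keep the check and extend the whitelist instead (it also fixes the "but get" typo in the message):

valid_outputs = ("ttir", "ttgir", "ptx", "cubin", "hsaco")
assert output in valid_outputs, "output should be one of [%s], but got \"%s\"" % (','.join(valid_outputs), output)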
@@ -900,7 +900,7 @@ def _compile(fn, signature: str, device: int = -1, constants=dict(),
     backend = _triton.runtime.backend.CUDA
     if extern_libs is None:
         extern_libs = dict()
     # compile ttir
     if torch.version.hip is not None:
         name, asm, shared_mem = _triton.code_gen.compile_ttir_to_amdgpu(backend, module, device, num_warps, num_stages, extern_libs, cc)
@@ -1137,6 +1137,7 @@ def libcuda_dirs():
     locs = subprocess.check_output(["whereis", "libcuda.so"]).decode().strip().split()[1:]
     return [os.path.dirname(loc) for loc in locs]
 
+@functools.lru_cache()
 def libhip_dirs():
     return ["/opt/rocm/lib/libamdhip64.so"]
@@ -1147,11 +1148,13 @@ def cuda_home_dirs():
default_dir = "/usr/local/cuda"
return os.getenv("CUDA_HOME", default=default_dir)
@functools.lru_cache()
def hip_home_dirs():
default_dir = "/opt/rocm"
return os.getenv("HIP_HOME", default=default_dir)
@contextlib.contextmanager
def quiet():
old_stdout, old_stderr = sys.stdout, sys.stderr
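Both HIP helpers get the same @functools.lru_cache() treatment as their CUDA counterparts: with no arguments the decorator caches the single result, so the environment lookup runs once per process. A self-contained sketch of the pattern, with a hypothetical helper name mirroring hip_home_dirs():

import functools
import os

@functools.lru_cache()
def rocm_home():
    # hypothetical example; the body mirrors hip_home_dirs() above
    # evaluated once, later calls return the cached string
    return os.getenv("HIP_HOME", default="/opt/rocm")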
@@ -1169,7 +1172,7 @@ def _build(name, src, srcdir):
     else:
         cuda_lib_dirs = libcuda_dirs()
         cu_include_dir = os.path.join(cuda_home_dirs(), "include")
     suffix = sysconfig.get_config_var('EXT_SUFFIX')
     so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix))
     # try to avoid setuptools if possible
@@ -1181,8 +1184,8 @@ def _build(name, src, srcdir):
     cc = gcc if gcc is not None else clang
     py_include_dir = get_paths()["include"]
     if torch.version.hip is not None:
-        cc_cmd = [cc, src, "-O3", f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-o", so]
-        cc_cmd += [f"-L{dir}" for dir in hip_lib_dirs]
+        cc_cmd = [cc, src, "-O3", f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-o", so]
+        cc_cmd += [f"-L{dir}" for dir in hip_lib_dirs]
     else:
         cc_cmd = [cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda", "-o", so]
         cc_cmd += [f"-L{dir}" for dir in cuda_lib_dirs]
@@ -1199,7 +1202,7 @@ def _build(name, src, srcdir):
     library_dirs = cuda_lib_dirs
     include_dirs = [srcdir, cu_include_dir]
     libraries = ['cuda']
     # extra arguments
     extra_link_args = []
     # create extension module
@@ -1283,7 +1286,7 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: i
     if torch.version.hip is not None:
         asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
                                             extern_libs, "hsaco", cc)
-    else
+    else:
         asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
                                             extern_libs, "cubin", cc)
     metadata = {"name": kernel_name, "shared": shared, "num_warps": num_warps, "num_stages": num_stages}