This commit is contained in:
Michael Melesse
2022-10-21 17:58:38 +00:00
parent d022f5cf2c
commit 8785793445

View File

@@ -886,7 +886,7 @@ def _compile(fn, signature: str, device: int = -1, constants=dict(),
num_warps: int = 4, num_stages: int = 3, extern_libs=None,
output: str = "ttgir", cc=0) -> Tuple[str, int, str]:
valid_outputs = ("ttir", "ttgir", "ptx", "cubin")
assert output in valid_outputs, "output should be one of [%s], but get \"%s\"" % (','.join(valid_outputs), output)
# assert output in valid_outputs, "output should be one of [%s], but get \"%s\"" % (','.join(valid_outputs), output)
# triton-ir
module, _ = make_triton_ir(fn, signature, specialization, constants)
@@ -1137,6 +1137,7 @@ def libcuda_dirs():
locs = subprocess.check_output(["whereis", "libcuda.so"]).decode().strip().split()[1:]
return [os.path.dirname(loc) for loc in locs]
@functools.lru_cache()
def libhip_dirs():
return ["/opt/rocm/lib/libamdhip64.so"]
@@ -1147,11 +1148,13 @@ def cuda_home_dirs():
default_dir = "/usr/local/cuda"
return os.getenv("CUDA_HOME", default=default_dir)
@functools.lru_cache()
def hip_home_dirs():
default_dir = "/opt/rocm"
return os.getenv("HIP_HOME", default=default_dir)
@contextlib.contextmanager
def quiet():
old_stdout, old_stderr = sys.stdout, sys.stderr
@@ -1283,7 +1286,7 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: i
if torch.version.hip is not None:
asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
extern_libs, "hsaco", cc)
else
else:
asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
extern_libs, "cubin", cc)
metadata = {"name": kernel_name, "shared": shared, "num_warps": num_warps, "num_stages": num_stages}