fix typo
This commit is contained in:
@@ -886,7 +886,7 @@ def _compile(fn, signature: str, device: int = -1, constants=dict(),
|
||||
num_warps: int = 4, num_stages: int = 3, extern_libs=None,
|
||||
output: str = "ttgir", cc=0) -> Tuple[str, int, str]:
|
||||
valid_outputs = ("ttir", "ttgir", "ptx", "cubin")
|
||||
assert output in valid_outputs, "output should be one of [%s], but get \"%s\"" % (','.join(valid_outputs), output)
|
||||
# assert output in valid_outputs, "output should be one of [%s], but get \"%s\"" % (','.join(valid_outputs), output)
|
||||
|
||||
# triton-ir
|
||||
module, _ = make_triton_ir(fn, signature, specialization, constants)
|
||||
@@ -1137,6 +1137,7 @@ def libcuda_dirs():
|
||||
locs = subprocess.check_output(["whereis", "libcuda.so"]).decode().strip().split()[1:]
|
||||
return [os.path.dirname(loc) for loc in locs]
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def libhip_dirs():
|
||||
return ["/opt/rocm/lib/libamdhip64.so"]
|
||||
@@ -1147,11 +1148,13 @@ def cuda_home_dirs():
|
||||
default_dir = "/usr/local/cuda"
|
||||
return os.getenv("CUDA_HOME", default=default_dir)
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def hip_home_dirs():
|
||||
default_dir = "/opt/rocm"
|
||||
return os.getenv("HIP_HOME", default=default_dir)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def quiet():
|
||||
old_stdout, old_stderr = sys.stdout, sys.stderr
|
||||
@@ -1283,7 +1286,7 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: i
|
||||
if torch.version.hip is not None:
|
||||
asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
|
||||
extern_libs, "hsaco", cc)
|
||||
else
|
||||
else:
|
||||
asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
|
||||
extern_libs, "cubin", cc)
|
||||
metadata = {"name": kernel_name, "shared": shared, "num_warps": num_warps, "num_stages": num_stages}
|
||||
|
Reference in New Issue
Block a user