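save compiler: select the ROCm backend when torch is built with HIP (instead of asserting CUDA-only) and drop -lcuda from the HIP build command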

This commit is contained in:
Michael Melesse
2022-10-19 20:39:32 +00:00
parent 41144f927f
commit 4624fd4e1d

View File

@@ -894,8 +894,10 @@ def _compile(fn, signature: str, device: int = -1, constants=dict(),
return module
assert output == "cubin"
assert torch.version.hip is None
backend = _triton.runtime.backend.CUDA
if torch.version.hip is not None:
backend = _triton.runtime.backend.ROCM
else:
backend = _triton.runtime.backend.CUDA
if extern_libs is None:
extern_libs = dict()
name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, module, device, num_warps, num_stages, extern_libs, cc)
@@ -1174,7 +1176,7 @@ def _build(name, src, srcdir):
cc = gcc if gcc is not None else clang
py_include_dir = get_paths()["include"]
if torch.version.hip is not None:
cc_cmd = [cc, src, "-O3", f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda", "-o", so]
cc_cmd = [cc, src, "-O3", f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-o", so]
cc_cmd += [f"-L{dir}" for dir in hip_lib_dirs]
else:
cc_cmd = [cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda", "-o", so]