From 4624fd4e1d63a9287515015067d8124f30775431 Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Wed, 19 Oct 2022 20:39:32 +0000 Subject: [PATCH] save compiler --- python/triton/compiler.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/triton/compiler.py b/python/triton/compiler.py index cfbca1b38..9e4726b87 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -894,8 +894,10 @@ def _compile(fn, signature: str, device: int = -1, constants=dict(), return module assert output == "cubin" - assert torch.version.hip is None - backend = _triton.runtime.backend.CUDA + if torch.version.hip is not None: + backend = _triton.runtime.backend.ROCM + else: + backend = _triton.runtime.backend.CUDA if extern_libs is None: extern_libs = dict() name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, module, device, num_warps, num_stages, extern_libs, cc) @@ -1174,7 +1176,7 @@ def _build(name, src, srcdir): cc = gcc if gcc is not None else clang py_include_dir = get_paths()["include"] if torch.version.hip is not None: - cc_cmd = [cc, src, "-O3", f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda", "-o", so] + cc_cmd = [cc, src, "-O3", f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-o", so] cc_cmd += [f"-L{dir}" for dir in hip_lib_dirs] else: cc_cmd = [cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda", "-o", so]