diff --git a/python/triton/compiler.py b/python/triton/compiler.py index 04b1b3e8d..ce5c77e93 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -886,7 +886,7 @@ def _compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, output: str = "ttgir", cc=0) -> Tuple[str, int, str]: valid_outputs = ("ttir", "ttgir", "ptx", "cubin") - assert output in valid_outputs, "output should be one of [%s], but get \"%s\"" % (','.join(valid_outputs), output) + # assert output in valid_outputs, "output should be one of [%s], but get \"%s\"" % (','.join(valid_outputs), output) # triton-ir module, _ = make_triton_ir(fn, signature, specialization, constants) @@ -900,7 +900,7 @@ def _compile(fn, signature: str, device: int = -1, constants=dict(), backend = _triton.runtime.backend.CUDA if extern_libs is None: extern_libs = dict() - + # compile ttir if torch.version.hip is not None: name, asm, shared_mem = _triton.code_gen.compile_ttir_to_amdgpu(backend, module, device, num_warps, num_stages, extern_libs, cc) @@ -1137,6 +1137,7 @@ def libcuda_dirs(): locs = subprocess.check_output(["whereis", "libcuda.so"]).decode().strip().split()[1:] return [os.path.dirname(loc) for loc in locs] + @functools.lru_cache() def libhip_dirs(): return ["/opt/rocm/lib/libamdhip64.so"] @@ -1147,11 +1148,13 @@ def cuda_home_dirs(): default_dir = "/usr/local/cuda" return os.getenv("CUDA_HOME", default=default_dir) + @functools.lru_cache() def hip_home_dirs(): default_dir = "/opt/rocm" return os.getenv("HIP_HOME", default=default_dir) + @contextlib.contextmanager def quiet(): old_stdout, old_stderr = sys.stdout, sys.stderr @@ -1169,7 +1172,7 @@ def _build(name, src, srcdir): else: cuda_lib_dirs = libcuda_dirs() cu_include_dir = os.path.join(cuda_home_dirs(), "include") - + suffix = sysconfig.get_config_var('EXT_SUFFIX') so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix)) # try to avoid setuptools if possible @@ -1181,8 +1184,8 @@ def _build(name, src, srcdir): cc = gcc if gcc is not None else clang py_include_dir = get_paths()["include"] if torch.version.hip is not None: - cc_cmd = [cc, src, "-O3", f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-o", so] - cc_cmd += [f"-L{dir}" for dir in hip_lib_dirs] + cc_cmd = [cc, src, "-O3", f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-o", so] + cc_cmd += [f"-L{dir}" for dir in hip_lib_dirs] else: cc_cmd = [cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda", "-o", so] cc_cmd += [f"-L{dir}" for dir in cuda_lib_dirs] @@ -1199,7 +1202,7 @@ def _build(name, src, srcdir): library_dirs = cuda_lib_dirs include_dirs = [srcdir, cu_include_dir] libraries = ['cuda'] - + # extra arguments extra_link_args = [] # create extension module @@ -1283,7 +1286,7 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: i if torch.version.hip is not None: asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages, extern_libs, "hsaco", cc) - else + else: asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages, extern_libs, "cubin", cc) metadata = {"name": kernel_name, "shared": shared, "num_warps": num_warps, "num_stages": num_stages}