From 8cc448d92ef86af16e2c4c71e100ee109731ec57 Mon Sep 17 00:00:00 2001 From: Rohit Santhanam Date: Fri, 18 Nov 2022 12:58:51 +0000 Subject: [PATCH] Changes to eliminate the need for the MI_GPU_ARCH environment variable. The AMDGPU arch is now parsed out of the rocminfo dump. --- CMakeLists.txt | 5 ----- lib/driver/llvm.cc | 2 +- python/src/triton.cc | 11 ++--------- python/triton/compiler.py | 28 ++++++++++++++++++++++++++-- 4 files changed, 29 insertions(+), 17 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index cbf78cf03..b99690519 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -44,11 +44,6 @@ endif() if (TRITON_USE_ROCM) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17 -Wno-unused-result -Wno-attributes") - set(MI_GPU_ARCH $ENV{MI_GPU_ARCH}) - if (NOT MI_GPU_ARCH) - set(MI_GPU_ARCH "gfx90a") - endif() - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMI_GPU_ARCH=${MI_GPU_ARCH}") else() set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17") endif() diff --git a/lib/driver/llvm.cc b/lib/driver/llvm.cc index 58a267f5b..9952707e4 100644 --- a/lib/driver/llvm.cc +++ b/lib/driver/llvm.cc @@ -300,7 +300,7 @@ namespace triton std::string triple = "amdgcn-amd-amdhsa"; std::string layout = ""; std::string features = "+sramecc,-xnack"; - std::string proc = STRINGIFY(MI_GPU_ARCH); + std::string proc = _proc; // name kernel auto in_time_t = std::chrono::system_clock::to_time_t(std::chrono::system_clock::now()); std::stringstream cur_time; diff --git a/python/src/triton.cc b/python/src/triton.cc index 9393ead43..66d12a7c9 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -493,7 +493,7 @@ void init_triton_codegen(py::module &&m) { py::return_value_policy::take_ownership); m.def("compile_ttir_to_amdgcn", - [](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, py::dict& extern_libs, size_t cc) { + [](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, py::dict& extern_libs, const std::string& gfx_arch) { std::ostringstream ttir; int n_shared_bytes; std::string tmp; @@ -515,13 +515,6 @@ void init_triton_codegen(py::module &&m) { extern_lib_map.emplace( name, triton::codegen::create_extern_lib(name, path)); } - // device properties - if (cc == 0) { - hipDevice_t dev = (hipDevice_t)device; - size_t major = hipGetInfo(dev); - size_t minor = hipGetInfo(dev); - cc = major*10 + minor; - } int version; // std::string ptxas_path = drv::path_to_ptxas(version); // Triton-IR -> AMDGCN LLVM-IR @@ -536,7 +529,7 @@ void init_triton_codegen(py::module &&m) { std::cout << "\t" << llir.str() << std::endl; llir.flush(); // LLVM-IR -> AMDGPU - std::tuple amdgpu = drv::llir_to_amdgcn(llvm.get(), "gfx90a"); + std::tuple amdgpu = drv::llir_to_amdgcn(llvm.get(), gfx_arch); amdgcn = std::get<0>(amdgpu); hsaco_path = std::get<1>(amdgpu); std::cout << "amdgcn:" << std::endl; diff --git a/python/triton/compiler.py b/python/triton/compiler.py index 6eb331dbb..0d5f082c8 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -7,6 +7,7 @@ import hashlib import io import json import os +import re import shutil import subprocess import sys @@ -24,6 +25,12 @@ import triton import triton._C.libtriton.triton as _triton from .tools.disasm import extract +def static_vars(**kwargs): + def decorate(func): + for k in kwargs: + setattr(func, k, kwargs[k]) + return func + return decorate def str_to_ty(name): if name[0] == "*": @@ -880,7 +887,19 @@ def ptx_get_kernel_name(ptx: str) -> str: if line.startswith('// .globl'): return line.split()[-1] +@functools.lru_cache() +def rocm_path_dir(): + return os.getenv("ROCM_PATH", default="/opt/rocm") +def _get_amdgpu_arch(): + try: + rocminfo = subprocess.check_output(rocm_path_dir() + '/bin/rocminfo').decode() + gfx_arch = re.search('Name:\\s+.*(gfx\\d+)', rocminfo) + return gfx_arch.group(1).strip() + except: + return None + +@static_vars(discovered_gfx_arch = _get_amdgpu_arch()) def _compile(fn, signature: str, device: int = -1, constants=dict(), specialization=_triton.code_gen.instance_descriptor(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, @@ -905,7 +924,10 @@ def _compile(fn, signature: str, device: int = -1, constants=dict(), # compile ttir if torch.version.hip is not None: - name, asm, shared_mem = _triton.code_gen.compile_ttir_to_amdgcn(backend, module, device, num_warps, num_stages, extern_libs, cc) + gfx_arch = os.environ.get('MI_GPU_ARCH', _compile.discovered_gfx_arch) + if gfx_arch is None: + raise RuntimeError('AMDGCN gfx arch is not defined.') + name, asm, shared_mem = _triton.code_gen.compile_ttir_to_amdgcn(backend, module, device, num_warps, num_stages, extern_libs, gfx_arch) else: name, asm, shared_mem = _triton.code_gen.compile_ttir_to_ptx(backend, module, device, num_warps, num_stages, extern_libs, cc) return asm, shared_mem, name @@ -1145,7 +1167,6 @@ def libcuda_dirs(): locs = subprocess.check_output(["whereis", "libcuda.so"]).decode().strip().split()[1:] return [os.path.dirname(loc) for loc in locs] - @functools.lru_cache() def libhip_dirs(): return ["/opt/rocm/lib"] @@ -1162,6 +1183,9 @@ def hip_home_dirs(): default_dir = "/opt/rocm" return os.getenv("ROCM_HOME", default=default_dir) +@functools.lru_cache() +def rocm_path_dir(): + return os.getenv("ROCM_PATH", default="/opt/rocm") @contextlib.contextmanager def quiet():