Changes to eliminate the need for the MI_GPU_ARCH environment variable.

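The AMDGPU arch is now parsed from the output of rocminfo at runtime; if the MI_GPU_ARCH environment variable is set, it still overrides the detected value.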
This commit is contained in:
Rohit Santhanam
2022-11-18 12:58:51 +00:00
parent 9a9fabbba9
commit 8cc448d92e
4 changed files with 29 additions and 17 deletions

View File

@@ -44,11 +44,6 @@ endif()
# Select C++ compile flags depending on whether this is a ROCm build.
if (TRITON_USE_ROCM)
# ROCm build: common flags plus suppression of unused-result/attribute warnings.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17 -Wno-unused-result -Wno-attributes")
# Read the target GPU architecture from the MI_GPU_ARCH environment variable
# at configure time.
set(MI_GPU_ARCH $ENV{MI_GPU_ARCH})
# Fall back to gfx90a when the environment variable gave no usable value
# (e.g. it is unset or empty).
if (NOT MI_GPU_ARCH)
set(MI_GPU_ARCH "gfx90a")
endif()
# Expose the chosen architecture to the C++ sources as the MI_GPU_ARCH macro.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMI_GPU_ARCH=${MI_GPU_ARCH}")
else()
# Non-ROCm build: only the common flags.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17")
endif()

View File

@@ -300,7 +300,7 @@ namespace triton
std::string triple = "amdgcn-amd-amdhsa";
std::string layout = "";
std::string features = "+sramecc,-xnack";
std::string proc = STRINGIFY(MI_GPU_ARCH);
std::string proc = _proc;
// name kernel
auto in_time_t = std::chrono::system_clock::to_time_t(std::chrono::system_clock::now());
std::stringstream cur_time;

View File

@@ -493,7 +493,7 @@ void init_triton_codegen(py::module &&m) {
py::return_value_policy::take_ownership);
m.def("compile_ttir_to_amdgcn",
[](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, py::dict& extern_libs, size_t cc) {
[](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, py::dict& extern_libs, const std::string& gfx_arch) {
std::ostringstream ttir;
int n_shared_bytes;
std::string tmp;
@@ -515,13 +515,6 @@ void init_triton_codegen(py::module &&m) {
extern_lib_map.emplace(
name, triton::codegen::create_extern_lib(name, path));
}
// device properties
if (cc == 0) {
hipDevice_t dev = (hipDevice_t)device;
size_t major = hipGetInfo<hipDeviceAttributeComputeCapabilityMajor>(dev);
size_t minor = hipGetInfo<hipDeviceAttributeComputeCapabilityMinor>(dev);
cc = major*10 + minor;
}
int version;
// std::string ptxas_path = drv::path_to_ptxas(version);
// Triton-IR -> AMDGCN LLVM-IR
@@ -536,7 +529,7 @@ void init_triton_codegen(py::module &&m) {
std::cout << "\t" << llir.str() << std::endl;
llir.flush();
// LLVM-IR -> AMDGPU
std::tuple<std::string, std::string> amdgpu = drv::llir_to_amdgcn(llvm.get(), "gfx90a");
std::tuple<std::string, std::string> amdgpu = drv::llir_to_amdgcn(llvm.get(), gfx_arch);
amdgcn = std::get<0>(amdgpu);
hsaco_path = std::get<1>(amdgpu);
std::cout << "amdgcn:" << std::endl;

View File

@@ -7,6 +7,7 @@ import hashlib
import io
import json
import os
import re
import shutil
import subprocess
import sys
@@ -24,6 +25,12 @@ import triton
import triton._C.libtriton.triton as _triton
from .tools.disasm import extract
def static_vars(**kwargs):
    """Build a decorator that attaches attributes to the decorated function.

    Each keyword argument becomes an attribute of the same name on the
    function. The keyword values are evaluated once, when the decorator
    expression itself is evaluated.

    Args:
        **kwargs: attribute names mapped to the values to attach.

    Returns:
        A decorator that sets the attributes on the function it receives and
        returns that same function object (it is not wrapped).
    """
    def decorate(func):
        for attr_name, attr_value in kwargs.items():
            setattr(func, attr_name, attr_value)
        return func
    return decorate
def str_to_ty(name):
if name[0] == "*":
@@ -880,7 +887,19 @@ def ptx_get_kernel_name(ptx: str) -> str:
if line.startswith('// .globl'):
return line.split()[-1]
@functools.lru_cache()
def rocm_path_dir():
    """Return the ROCm installation directory.

    Uses the ROCM_PATH environment variable when it is set and falls back to
    "/opt/rocm" otherwise. The result is memoized by ``functools.lru_cache``,
    so changes to ROCM_PATH after the first call are not observed unless the
    cache is cleared.
    """
    return os.environ.get("ROCM_PATH", "/opt/rocm")
def _get_amdgpu_arch():
try:
rocminfo = subprocess.check_output(rocm_path_dir() + '/bin/rocminfo').decode()
gfx_arch = re.search('Name:\\s+.*(gfx\\d+)', rocminfo)
return gfx_arch.group(1).strip()
except:
return None
@static_vars(discovered_gfx_arch = _get_amdgpu_arch())
def _compile(fn, signature: str, device: int = -1, constants=dict(),
specialization=_triton.code_gen.instance_descriptor(),
num_warps: int = 4, num_stages: int = 3, extern_libs=None,
@@ -905,7 +924,10 @@ def _compile(fn, signature: str, device: int = -1, constants=dict(),
# compile ttir
if torch.version.hip is not None:
name, asm, shared_mem = _triton.code_gen.compile_ttir_to_amdgcn(backend, module, device, num_warps, num_stages, extern_libs, cc)
gfx_arch = os.environ.get('MI_GPU_ARCH', _compile.discovered_gfx_arch)
if gfx_arch is None:
raise RuntimeError('AMDGCN gfx arch is not defined.')
name, asm, shared_mem = _triton.code_gen.compile_ttir_to_amdgcn(backend, module, device, num_warps, num_stages, extern_libs, gfx_arch)
else:
name, asm, shared_mem = _triton.code_gen.compile_ttir_to_ptx(backend, module, device, num_warps, num_stages, extern_libs, cc)
return asm, shared_mem, name
@@ -1145,7 +1167,6 @@ def libcuda_dirs():
locs = subprocess.check_output(["whereis", "libcuda.so"]).decode().strip().split()[1:]
return [os.path.dirname(loc) for loc in locs]
@functools.lru_cache()
def libhip_dirs():
    """Return the list of directories to search for the HIP runtime library.

    Currently a single fixed location, "/opt/rocm/lib". The result is memoized
    by ``functools.lru_cache``, so every call returns the same list object;
    callers should not mutate it.
    """
    return ["/opt/rocm/lib"]
@@ -1162,6 +1183,9 @@ def hip_home_dirs():
default_dir = "/opt/rocm"
return os.getenv("ROCM_HOME", default=default_dir)
@functools.lru_cache()
def rocm_path_dir():
    """Return the ROCm installation directory.

    Uses the ROCM_PATH environment variable when it is set and falls back to
    "/opt/rocm" otherwise. The result is memoized by ``functools.lru_cache``,
    so changes to ROCM_PATH after the first call are not observed unless the
    cache is cleared.
    """
    return os.environ.get("ROCM_PATH", "/opt/rocm")
@contextlib.contextmanager
def quiet():