Changes to eliminate the need for the MI_GPU_ARCH environment variable.
The AMDGPU arch is now parsed out of the rocminfo dump.
This commit is contained in:
@@ -44,11 +44,6 @@ endif()
|
||||
|
||||
if (TRITON_USE_ROCM)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17 -Wno-unused-result -Wno-attributes")
|
||||
set(MI_GPU_ARCH $ENV{MI_GPU_ARCH})
|
||||
if (NOT MI_GPU_ARCH)
|
||||
set(MI_GPU_ARCH "gfx90a")
|
||||
endif()
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMI_GPU_ARCH=${MI_GPU_ARCH}")
|
||||
else()
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17")
|
||||
endif()
|
||||
|
@@ -300,7 +300,7 @@ namespace triton
|
||||
std::string triple = "amdgcn-amd-amdhsa";
|
||||
std::string layout = "";
|
||||
std::string features = "+sramecc,-xnack";
|
||||
std::string proc = STRINGIFY(MI_GPU_ARCH);
|
||||
std::string proc = _proc;
|
||||
// name kernel
|
||||
auto in_time_t = std::chrono::system_clock::to_time_t(std::chrono::system_clock::now());
|
||||
std::stringstream cur_time;
|
||||
|
@@ -493,7 +493,7 @@ void init_triton_codegen(py::module &&m) {
|
||||
py::return_value_policy::take_ownership);
|
||||
|
||||
m.def("compile_ttir_to_amdgcn",
|
||||
[](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, py::dict& extern_libs, size_t cc) {
|
||||
[](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, py::dict& extern_libs, const std::string& gfx_arch) {
|
||||
std::ostringstream ttir;
|
||||
int n_shared_bytes;
|
||||
std::string tmp;
|
||||
@@ -515,13 +515,6 @@ void init_triton_codegen(py::module &&m) {
|
||||
extern_lib_map.emplace(
|
||||
name, triton::codegen::create_extern_lib(name, path));
|
||||
}
|
||||
// device properties
|
||||
if (cc == 0) {
|
||||
hipDevice_t dev = (hipDevice_t)device;
|
||||
size_t major = hipGetInfo<hipDeviceAttributeComputeCapabilityMajor>(dev);
|
||||
size_t minor = hipGetInfo<hipDeviceAttributeComputeCapabilityMinor>(dev);
|
||||
cc = major*10 + minor;
|
||||
}
|
||||
int version;
|
||||
// std::string ptxas_path = drv::path_to_ptxas(version);
|
||||
// Triton-IR -> AMDGCN LLVM-IR
|
||||
@@ -536,7 +529,7 @@ void init_triton_codegen(py::module &&m) {
|
||||
std::cout << "\t" << llir.str() << std::endl;
|
||||
llir.flush();
|
||||
// LLVM-IR -> AMDGPU
|
||||
std::tuple<std::string, std::string> amdgpu = drv::llir_to_amdgcn(llvm.get(), "gfx90a");
|
||||
std::tuple<std::string, std::string> amdgpu = drv::llir_to_amdgcn(llvm.get(), gfx_arch);
|
||||
amdgcn = std::get<0>(amdgpu);
|
||||
hsaco_path = std::get<1>(amdgpu);
|
||||
std::cout << "amdgcn:" << std::endl;
|
||||
|
@@ -7,6 +7,7 @@ import hashlib
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
@@ -24,6 +25,12 @@ import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
from .tools.disasm import extract
|
||||
|
||||
def static_vars(**kwargs):
    """Build a decorator that stores each keyword argument on the function.

    Every ``name=value`` pair passed here is set as an attribute ``name`` on
    the decorated function. The function itself is returned, not a wrapper.
    """
    def decorate(func):
        for attr_name, attr_value in kwargs.items():
            setattr(func, attr_name, attr_value)
        return func

    return decorate
|
||||
|
||||
def str_to_ty(name):
|
||||
if name[0] == "*":
|
||||
@@ -880,7 +887,19 @@ def ptx_get_kernel_name(ptx: str) -> str:
|
||||
if line.startswith('// .globl'):
|
||||
return line.split()[-1]
|
||||
|
||||
@functools.lru_cache()
def rocm_path_dir():
    """Return the ROCm root directory.

    Uses the ``ROCM_PATH`` environment variable when it is set and falls back
    to ``/opt/rocm`` otherwise. The value is memoized by ``lru_cache``, so the
    environment is only read on the first call.
    """
    fallback_root = "/opt/rocm"
    return os.environ.get("ROCM_PATH", fallback_root)
|
||||
|
||||
def _get_amdgpu_arch():
    """Discover the AMDGPU target name (e.g. ``gfx90a``) from ``rocminfo``.

    Runs ``<rocm_path_dir()>/bin/rocminfo`` and scans its output for the first
    ``Name:`` line that contains a ``gfx`` identifier.

    Returns:
        The gfx architecture string, or ``None`` when ``rocminfo`` cannot be
        executed, exits with an error, produces undecodable output, or reports
        no gfx agent. This function never raises for those cases because it is
        evaluated while the module is being imported.
    """
    try:
        rocminfo = subprocess.check_output(rocm_path_dir() + '/bin/rocminfo').decode()
    except (OSError, subprocess.SubprocessError, UnicodeDecodeError):
        # Best effort: a missing or failing rocminfo just means "unknown".
        return None

    # AMDGPU target names end in hexadecimal characters (gfx90a, gfx90c, ...),
    # so the suffix must accept a-f as well as digits; matching digits only
    # would truncate "gfx90a" to "gfx90".
    gfx_arch = re.search(r'Name:\s+.*(gfx[0-9a-f]+)', rocminfo)
    if gfx_arch is None:
        return None
    return gfx_arch.group(1)
|
||||
|
||||
@static_vars(discovered_gfx_arch = _get_amdgpu_arch())
|
||||
def _compile(fn, signature: str, device: int = -1, constants=dict(),
|
||||
specialization=_triton.code_gen.instance_descriptor(),
|
||||
num_warps: int = 4, num_stages: int = 3, extern_libs=None,
|
||||
@@ -905,7 +924,10 @@ def _compile(fn, signature: str, device: int = -1, constants=dict(),
|
||||
|
||||
# compile ttir
|
||||
if torch.version.hip is not None:
|
||||
name, asm, shared_mem = _triton.code_gen.compile_ttir_to_amdgcn(backend, module, device, num_warps, num_stages, extern_libs, cc)
|
||||
gfx_arch = os.environ.get('MI_GPU_ARCH', _compile.discovered_gfx_arch)
|
||||
if gfx_arch is None:
|
||||
raise RuntimeError('AMDGCN gfx arch is not defined.')
|
||||
name, asm, shared_mem = _triton.code_gen.compile_ttir_to_amdgcn(backend, module, device, num_warps, num_stages, extern_libs, gfx_arch)
|
||||
else:
|
||||
name, asm, shared_mem = _triton.code_gen.compile_ttir_to_ptx(backend, module, device, num_warps, num_stages, extern_libs, cc)
|
||||
return asm, shared_mem, name
|
||||
@@ -1145,7 +1167,6 @@ def libcuda_dirs():
|
||||
locs = subprocess.check_output(["whereis", "libcuda.so"]).decode().strip().split()[1:]
|
||||
return [os.path.dirname(loc) for loc in locs]
|
||||
|
||||
|
||||
@functools.lru_cache()
def libhip_dirs():
    """Return the HIP library directories as a one-element list.

    The single entry is the hard-coded path ``/opt/rocm/lib``. Because the
    function is wrapped in ``lru_cache``, every call returns the same list
    object.
    """
    default_hip_lib_dir = "/opt/rocm/lib"
    return [default_hip_lib_dir]
|
||||
@@ -1162,6 +1183,9 @@ def hip_home_dirs():
|
||||
default_dir = "/opt/rocm"
|
||||
return os.getenv("ROCM_HOME", default=default_dir)
|
||||
|
||||
@functools.lru_cache()
def rocm_path_dir():
    """Return the ROCm root: ``$ROCM_PATH`` if set, else ``/opt/rocm``.

    The lookup result is cached by ``lru_cache`` after the first call.
    """
    return os.environ.get("ROCM_PATH", "/opt/rocm")
|
||||
|
||||
@contextlib.contextmanager
|
||||
def quiet():
|
||||
|
Reference in New Issue
Block a user