Changes to eliminate the need for the MI_GPU_ARCH environment variable.
The AMDGPU arch is now parsed out of the rocminfo dump.
This commit is contained in:
@@ -44,11 +44,6 @@ endif()
|
|||||||
|
|
||||||
if (TRITON_USE_ROCM)
|
if (TRITON_USE_ROCM)
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17 -Wno-unused-result -Wno-attributes")
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17 -Wno-unused-result -Wno-attributes")
|
||||||
set(MI_GPU_ARCH $ENV{MI_GPU_ARCH})
|
|
||||||
if (NOT MI_GPU_ARCH)
|
|
||||||
set(MI_GPU_ARCH "gfx90a")
|
|
||||||
endif()
|
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMI_GPU_ARCH=${MI_GPU_ARCH}")
|
|
||||||
else()
|
else()
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17")
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17")
|
||||||
endif()
|
endif()
|
||||||
|
@@ -300,7 +300,7 @@ namespace triton
|
|||||||
std::string triple = "amdgcn-amd-amdhsa";
|
std::string triple = "amdgcn-amd-amdhsa";
|
||||||
std::string layout = "";
|
std::string layout = "";
|
||||||
std::string features = "+sramecc,-xnack";
|
std::string features = "+sramecc,-xnack";
|
||||||
std::string proc = STRINGIFY(MI_GPU_ARCH);
|
std::string proc = _proc;
|
||||||
// name kernel
|
// name kernel
|
||||||
auto in_time_t = std::chrono::system_clock::to_time_t(std::chrono::system_clock::now());
|
auto in_time_t = std::chrono::system_clock::to_time_t(std::chrono::system_clock::now());
|
||||||
std::stringstream cur_time;
|
std::stringstream cur_time;
|
||||||
|
@@ -493,7 +493,7 @@ void init_triton_codegen(py::module &&m) {
|
|||||||
py::return_value_policy::take_ownership);
|
py::return_value_policy::take_ownership);
|
||||||
|
|
||||||
m.def("compile_ttir_to_amdgcn",
|
m.def("compile_ttir_to_amdgcn",
|
||||||
[](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, py::dict& extern_libs, size_t cc) {
|
[](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, py::dict& extern_libs, const std::string& gfx_arch) {
|
||||||
std::ostringstream ttir;
|
std::ostringstream ttir;
|
||||||
int n_shared_bytes;
|
int n_shared_bytes;
|
||||||
std::string tmp;
|
std::string tmp;
|
||||||
@@ -515,13 +515,6 @@ void init_triton_codegen(py::module &&m) {
|
|||||||
extern_lib_map.emplace(
|
extern_lib_map.emplace(
|
||||||
name, triton::codegen::create_extern_lib(name, path));
|
name, triton::codegen::create_extern_lib(name, path));
|
||||||
}
|
}
|
||||||
// device properties
|
|
||||||
if (cc == 0) {
|
|
||||||
hipDevice_t dev = (hipDevice_t)device;
|
|
||||||
size_t major = hipGetInfo<hipDeviceAttributeComputeCapabilityMajor>(dev);
|
|
||||||
size_t minor = hipGetInfo<hipDeviceAttributeComputeCapabilityMinor>(dev);
|
|
||||||
cc = major*10 + minor;
|
|
||||||
}
|
|
||||||
int version;
|
int version;
|
||||||
// std::string ptxas_path = drv::path_to_ptxas(version);
|
// std::string ptxas_path = drv::path_to_ptxas(version);
|
||||||
// Triton-IR -> AMDGCN LLVM-IR
|
// Triton-IR -> AMDGCN LLVM-IR
|
||||||
@@ -536,7 +529,7 @@ void init_triton_codegen(py::module &&m) {
|
|||||||
std::cout << "\t" << llir.str() << std::endl;
|
std::cout << "\t" << llir.str() << std::endl;
|
||||||
llir.flush();
|
llir.flush();
|
||||||
// LLVM-IR -> AMDGPU
|
// LLVM-IR -> AMDGPU
|
||||||
std::tuple<std::string, std::string> amdgpu = drv::llir_to_amdgcn(llvm.get(), "gfx90a");
|
std::tuple<std::string, std::string> amdgpu = drv::llir_to_amdgcn(llvm.get(), gfx_arch);
|
||||||
amdgcn = std::get<0>(amdgpu);
|
amdgcn = std::get<0>(amdgpu);
|
||||||
hsaco_path = std::get<1>(amdgpu);
|
hsaco_path = std::get<1>(amdgpu);
|
||||||
std::cout << "amdgcn:" << std::endl;
|
std::cout << "amdgcn:" << std::endl;
|
||||||
|
@@ -7,6 +7,7 @@ import hashlib
|
|||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
@@ -24,6 +25,12 @@ import triton
|
|||||||
import triton._C.libtriton.triton as _triton
|
import triton._C.libtriton.triton as _triton
|
||||||
from .tools.disasm import extract
|
from .tools.disasm import extract
|
||||||
|
|
||||||
|
def static_vars(**kwargs):
|
||||||
|
def decorate(func):
|
||||||
|
for k in kwargs:
|
||||||
|
setattr(func, k, kwargs[k])
|
||||||
|
return func
|
||||||
|
return decorate
|
||||||
|
|
||||||
def str_to_ty(name):
|
def str_to_ty(name):
|
||||||
if name[0] == "*":
|
if name[0] == "*":
|
||||||
@@ -880,7 +887,19 @@ def ptx_get_kernel_name(ptx: str) -> str:
|
|||||||
if line.startswith('// .globl'):
|
if line.startswith('// .globl'):
|
||||||
return line.split()[-1]
|
return line.split()[-1]
|
||||||
|
|
||||||
|
@functools.lru_cache()
|
||||||
|
def rocm_path_dir():
|
||||||
|
return os.getenv("ROCM_PATH", default="/opt/rocm")
|
||||||
|
|
||||||
|
def _get_amdgpu_arch():
|
||||||
|
try:
|
||||||
|
rocminfo = subprocess.check_output(rocm_path_dir() + '/bin/rocminfo').decode()
|
||||||
|
gfx_arch = re.search('Name:\\s+.*(gfx\\d+)', rocminfo)
|
||||||
|
return gfx_arch.group(1).strip()
|
||||||
|
except:
|
||||||
|
return None
|
||||||
|
|
||||||
|
@static_vars(discovered_gfx_arch = _get_amdgpu_arch())
|
||||||
def _compile(fn, signature: str, device: int = -1, constants=dict(),
|
def _compile(fn, signature: str, device: int = -1, constants=dict(),
|
||||||
specialization=_triton.code_gen.instance_descriptor(),
|
specialization=_triton.code_gen.instance_descriptor(),
|
||||||
num_warps: int = 4, num_stages: int = 3, extern_libs=None,
|
num_warps: int = 4, num_stages: int = 3, extern_libs=None,
|
||||||
@@ -905,7 +924,10 @@ def _compile(fn, signature: str, device: int = -1, constants=dict(),
|
|||||||
|
|
||||||
# compile ttir
|
# compile ttir
|
||||||
if torch.version.hip is not None:
|
if torch.version.hip is not None:
|
||||||
name, asm, shared_mem = _triton.code_gen.compile_ttir_to_amdgcn(backend, module, device, num_warps, num_stages, extern_libs, cc)
|
gfx_arch = os.environ.get('MI_GPU_ARCH', _compile.discovered_gfx_arch)
|
||||||
|
if gfx_arch is None:
|
||||||
|
raise RuntimeError('AMDGCN gfx arch is not defined.')
|
||||||
|
name, asm, shared_mem = _triton.code_gen.compile_ttir_to_amdgcn(backend, module, device, num_warps, num_stages, extern_libs, gfx_arch)
|
||||||
else:
|
else:
|
||||||
name, asm, shared_mem = _triton.code_gen.compile_ttir_to_ptx(backend, module, device, num_warps, num_stages, extern_libs, cc)
|
name, asm, shared_mem = _triton.code_gen.compile_ttir_to_ptx(backend, module, device, num_warps, num_stages, extern_libs, cc)
|
||||||
return asm, shared_mem, name
|
return asm, shared_mem, name
|
||||||
@@ -1145,7 +1167,6 @@ def libcuda_dirs():
|
|||||||
locs = subprocess.check_output(["whereis", "libcuda.so"]).decode().strip().split()[1:]
|
locs = subprocess.check_output(["whereis", "libcuda.so"]).decode().strip().split()[1:]
|
||||||
return [os.path.dirname(loc) for loc in locs]
|
return [os.path.dirname(loc) for loc in locs]
|
||||||
|
|
||||||
|
|
||||||
@functools.lru_cache()
|
@functools.lru_cache()
|
||||||
def libhip_dirs():
|
def libhip_dirs():
|
||||||
return ["/opt/rocm/lib"]
|
return ["/opt/rocm/lib"]
|
||||||
@@ -1162,6 +1183,9 @@ def hip_home_dirs():
|
|||||||
default_dir = "/opt/rocm"
|
default_dir = "/opt/rocm"
|
||||||
return os.getenv("ROCM_HOME", default=default_dir)
|
return os.getenv("ROCM_HOME", default=default_dir)
|
||||||
|
|
||||||
|
@functools.lru_cache()
|
||||||
|
def rocm_path_dir():
|
||||||
|
return os.getenv("ROCM_PATH", default="/opt/rocm")
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def quiet():
|
def quiet():
|
||||||
|
Reference in New Issue
Block a user