send amdgcn to cache

Michael Melesse
2022-10-26 17:18:33 +00:00
parent df925f7187
commit 39381d99f8
4 changed files with 42 additions and 48 deletions

View File

@@ -13,7 +13,7 @@ std::string path_to_ptxas(int& version);
std::string llir_to_ptx(llvm::Module* module, int cc, int version);
std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas_path, int cc);
CUmodule ptx_to_cumodule(const std::string& ptx, int cc);
-std::string llir_to_amdgpu(llvm::Module* module, const std::string& proc);
+std::tuple<std::string, std::string> llir_to_amdgcn(llvm::Module* module, const std::string& proc);
hipModule_t amdgpu_to_hipmodule(const std::string& path);
}
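Note: the header now hands back an (assembly, binary-path) pair instead of a bare path. A toy sketch of consuming that shape; the stub name and values below are placeholders, not Triton's real API:

#include <string>
#include <tuple>

// Stand-in for the new two-artifact signature; body and values are illustrative.
std::tuple<std::string, std::string> llir_to_amdgcn_stub() {
  return {"<gcn asm>", "/tmp/kernel.hsaco"};
}

int main() {
  // C++17 structured bindings unpack assembly text and binary path together.
  auto [amdgcn, hsaco_path] = llir_to_amdgcn_stub();
  return amdgcn.empty() || hsaco_path.empty();
}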

View File

@@ -269,8 +269,9 @@ namespace triton
/* ------------------------ */
// HIP //
/* ------------------------ */
-std::string llir_to_amdgpu(llvm::Module *module, const std::string &_proc)
+std::tuple<std::string, std::string> llir_to_amdgcn(llvm::Module *module, const std::string &_proc)
{
std::cout << "llvm.cc: llir_to_amdgcn:" << std::endl;
init_llvm();
// proc = std::get<0>(GetFeatureStrFromGCNArchName(rocminfo));
// features = std::get<1>(GetFeatureStrFromGCNArchName(rocminfo));
@@ -329,21 +330,13 @@ namespace triton
machine->addPassesToEmitFile(pass, *isabin_fs, nullptr, llvm::CGFT_ObjectFile);
pass.run(*module);
-#ifdef DEBUG_ROCM
std::cout << "Generating GCN ISA file" << std::endl;
// Save GCN ISA.
llvm::SmallVector<char, 0> debugBuffer;
llvm::legacy::PassManager debugPass;
llvm::raw_svector_ostream debugStream(debugBuffer);
machine->addPassesToEmitFile(debugPass, debugStream, nullptr, llvm::CodeGenFileType::CGFT_AssemblyFile); // TODO:cause segfault on REM ops also cause @llvm.amdgcn.if bug
debugPass.run(*module);
-// Save GCN ISA.
-std::string amdgcn_path = std::string("/tmp/") + kernel_name + std::string(".gcn");
-std::string result(debugBuffer.begin(), debugBuffer.end());
-std::ofstream amdgcn(amdgcn_path);
-amdgcn << result;
-amdgcn.close();
-#endif
+std::string amdgcn(debugBuffer.begin(), debugBuffer.end());
// generate HSACO file
std::string hsaco_path = std::string("/tmp/") + kernel_name + std::string(".hsaco");
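For context, the hunk above runs a second legacy pass manager over the same module so the textual GCN ISA is produced alongside the object file. A minimal sketch of that pattern under the same assumption (error handling elided; re-running codegen over one module is fragile, as the TODO in the diff warns):

#include <string>
#include <system_error>
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"

// Emit the object file to disk and the GCN assembly into an in-memory buffer.
std::string emit_gcn_and_object(llvm::TargetMachine *machine,
                                llvm::Module *module,
                                const std::string &isabin_path) {
  std::error_code ec;
  llvm::raw_fd_ostream isabin_fs(isabin_path, ec, llvm::sys::fs::OF_None);
  llvm::legacy::PassManager obj_pm;
  machine->addPassesToEmitFile(obj_pm, isabin_fs, nullptr, llvm::CGFT_ObjectFile);
  obj_pm.run(*module);

  llvm::SmallVector<char, 0> buf;
  llvm::raw_svector_ostream asm_os(buf);
  llvm::legacy::PassManager asm_pm;
  machine->addPassesToEmitFile(asm_pm, asm_os, nullptr,
                               llvm::CodeGenFileType::CGFT_AssemblyFile);
  asm_pm.run(*module);  // second codegen run over the same module
  return std::string(buf.begin(), buf.end());
}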
@@ -359,13 +352,14 @@ namespace triton
std::cout << lld_result << std::endl;
}
-return hsaco_path;
+return std::make_tuple(amdgcn, hsaco_path);
}
-hipModule_t amdgpu_to_hipmodule(const std::string &path)
+hipModule_t amdgpu_to_hipmodule(const std::string &hsaco_path)
{
std::cout << "llvm.cc: amdgpu_to_hipmodule:" << std::endl;
// Read HSACO.
-std::ifstream hsaco_file(path, std::ios::binary | std::ios::ate);
+std::ifstream hsaco_file(hsaco_path, std::ios::binary | std::ios::ate);
std::ifstream::pos_type hsaco_file_size = hsaco_file.tellg();
std::vector<unsigned char> hsaco(hsaco_file_size);
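The loader side, for reference: read the linked code object from disk and hand the raw bytes to hipModuleLoadData. A hedged sketch mirroring amdgpu_to_hipmodule (error checks elided):

#include <fstream>
#include <string>
#include <vector>
#include <hip/hip_runtime.h>

hipModule_t load_hsaco(const std::string &hsaco_path) {
  // Open at the end to learn the size, then rewind and read everything.
  std::ifstream file(hsaco_path, std::ios::binary | std::ios::ate);
  std::ifstream::pos_type size = file.tellg();
  std::vector<char> bytes(size);
  file.seekg(0, std::ios::beg);
  file.read(bytes.data(), size);
  hipModule_t mod;
  (void)hipModuleLoadData(&mod, bytes.data());  // HIP copies the image
  return mod;
}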

View File

@@ -492,16 +492,16 @@ void init_triton_codegen(py::module &&m) {
},
py::return_value_policy::take_ownership);
m.def("compile_ttir_to_amdgpu",
m.def("compile_ttir_to_amdgcn",
[](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, py::dict& extern_libs, size_t cc) {
std::ostringstream ttir;
int n_shared_bytes;
std::string tmp;
-std::string amdgpu;
-std::string hipmodule;
+std::string amdgcn;
+std::string hsaco_path;
std::string name;
{
std::cout << "triton.cc: compile_ttir_to_amdgpu:" << std::endl;
std::cout << "triton.cc: compile_ttir_to_amdgcn:" << std::endl;
// Scope where the GIL is released
py::gil_scoped_release allow_threads;
name = ir.get_function_list()[0]->get_name();
@@ -534,17 +534,15 @@ void init_triton_codegen(py::module &&m) {
llir << *llvm;
llir.flush();
// LLVM-IR -> AMDGPU
-std::string amdgpu = drv::llir_to_amdgpu(llvm.get(), "gfx90a");
-std::cout << "amdgpu = " << amdgpu << std::endl;
-// AMDGPU -> Binary
-hipModule_t hipmodule = drv::amdgpu_to_hipmodule(amdgpu);
-std::cout << "hipmodule = " << hipmodule << std::endl;
+std::tuple<std::string, std::string> amdgpu = drv::llir_to_amdgcn(llvm.get(), "gfx90a");
+amdgcn = std::get<0>(amdgpu);
+hsaco_path = std::get<1>(amdgpu);
}
asm_map_t asm_map;
asm_map["ttir"] = py::cast(ttir.str());
asm_map["llir"] = py::cast(tmp);
asm_map["amdgpu"] = py::cast(amdgpu);
asm_map["hipmodule"] = py::bytes(hipmodule);
asm_map["amdgcn"] = py::cast(amdgcn);
asm_map["hsaco_path"] = py::cast(hsaco_path);
return std::make_tuple(name, asm_map, n_shared_bytes);
},
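A minimal pybind11 sketch of the dict-building pattern in this binding, with placeholder values standing in for the real artifacts (module and function names are illustrative):

#include <pybind11/pybind11.h>
#include <string>
#include <tuple>
namespace py = pybind11;

PYBIND11_MODULE(example, m) {
  m.def("artifacts", []() {
    // Stand-in for llir_to_amdgcn's (assembly, path) result.
    std::tuple<std::string, std::string> r{"<gcn asm>", "/tmp/kernel.hsaco"};
    py::dict asm_map;
    asm_map["amdgcn"] = py::cast(std::get<0>(r));      // text artifact
    asm_map["hsaco_path"] = py::cast(std::get<1>(r));  // path to the binary
    return asm_map;
  });
}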
@@ -585,34 +583,36 @@ void init_triton_codegen(py::module &&m) {
// ---------------------------------------
// Load provided assembly code into driver
// ---------------------------------------
m.def("load_binary_hipmodule", [](const std::string& name, const std::string& data, size_t n_shared_bytes, uint64_t device){
std::cout << "triton.cc: load_binary_hipmodule:" << std::endl;
m.def("load_binary_hsaco", [](const std::string& name, const std::string& data, size_t n_shared_bytes, uint64_t device){
std::cout << "triton.cc: load_binary_hsaco:" << std::endl;
std::cout << "\tname:" << name << std::endl;
std::cout << "\tdata:" << data << std::endl;
std::cout << "\tn_shared_bytes:" << n_shared_bytes << std::endl;
std::cout << "\tdevice:" << device << std::endl;
py::gil_scoped_release allow_threads;
// create driver handles
std::cout << "\t" << "// create driver handles" << std::endl;
hipFunction_t fun;
-hipModule_t mod;
-drv::dispatch::hipModuleLoadData(&mod, data.c_str());
+hipModule_t mod = drv::amdgpu_to_hipmodule(data);
drv::dispatch::hipModuleGetFunction(&fun, mod, name.c_str());
// get allocated registers and spilled registers from the function
std::cout << "\t" << "// get allocated registers and spilled registers from the function" << std::endl;
int n_regs = 0;
int n_spills = 0;
hipFuncAttributes attr;
-drv::dispatch::hipFuncGetAttributes(&attr, fun);
-drv::dispatch::hipFuncGetAttributes(&attr, fun);
+// drv::dispatch::hipFuncGetAttributes(&attr, fun);
+// drv::dispatch::hipFuncGetAttributes(&attr, fun);
n_regs = attr.numRegs;
n_spills = attr.localSizeBytes / 4;
// set dynamic shared memory if necessary
std::cout << "\t" << "// set dynamic shared memory if necessary" << std::endl;
int shared_optin;
drv::dispatch::hipDeviceGetAttribute(&shared_optin, hipDeviceAttributeSharedMemPerBlockOptin, device);
if(n_shared_bytes > 49152 && shared_optin > 49152){
-drv::dispatch::hipFuncSetCacheConfig(fun, hipFuncCachePreferShared);
+// drv::dispatch::hipFuncSetCacheConfig(fun, hipFuncCachePreferShared);
int shared_total, shared_static;
drv::dispatch::hipDeviceGetAttribute(&shared_total, hipDeviceAttributeMaxSharedMemoryPerMultiprocessor, device);
-drv::dispatch::hipFuncGetAttributes(&attr, fun);
+// drv::dispatch::hipFuncGetAttributes(&attr, fun);
shared_total = attr.sharedSizeBytes;
// drv::dispatch::hipFuncSetAttribute(fun, hipFuncAttributeMaxDynamicSharedMemorySize, shared_optin - shared_static);
}
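Since the hipFuncGetAttributes calls are commented out above, one hedged alternative for recovering the register and spill counts is the per-attribute query that pairs with module-loaded functions. The enum-based API below is standard HIP; treating localSizeBytes/4 as the spill count is this binding's own heuristic, carried over as-is:

#include <hip/hip_runtime.h>

// Query register and spill counts for a kernel in an already-loaded module.
void query_kernel(hipModule_t mod, const char *name, int *n_regs, int *n_spills) {
  hipFunction_t fun;
  (void)hipModuleGetFunction(&fun, mod, name);
  int local_bytes = 0;
  (void)hipFuncGetAttribute(n_regs, HIP_FUNC_ATTRIBUTE_NUM_REGS, fun);
  (void)hipFuncGetAttribute(&local_bytes, HIP_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun);
  *n_spills = local_bytes / 4;  // 4 bytes per spill slot, matching the code above
}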

View File

@@ -895,7 +895,7 @@ def _compile(fn, signature: str, device: int = -1, constants=dict(),
if output == "ttir":
return module
assert (output == "cubin" or output == "hipmodule")
assert (output == "cubin" or output == "hsaco")
if torch.version.hip is not None:
backend = _triton.runtime.backend.ROCM
else:
@@ -905,7 +905,7 @@ def _compile(fn, signature: str, device: int = -1, constants=dict(),
# compile ttir
if torch.version.hip is not None:
-name, asm, shared_mem = _triton.code_gen.compile_ttir_to_amdgpu(backend, module, device, num_warps, num_stages, extern_libs, cc)
+name, asm, shared_mem = _triton.code_gen.compile_ttir_to_amdgcn(backend, module, device, num_warps, num_stages, extern_libs, cc)
else:
name, asm, shared_mem = _triton.code_gen.compile_ttir_to_ptx(backend, module, device, num_warps, num_stages, extern_libs, cc)
return asm, shared_mem, name
@@ -1275,7 +1275,7 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: i
src_path = os.path.join(tmpdir, "main.c")
with open(src_path, "w") as f:
f.write(src)
-so = _build(fn.__name__, src_path, tmpdir)
+so = _build(fn.__name__, src_path, tmpdir) # build step
with open(so, "rb") as f:
so_cache_manager.put(f.read(), so_name, binary=True)
@@ -1283,10 +1283,10 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: i
fn_cache_key = make_fn_cache_key(fn.cache_key, signature, configs, constants, num_warps, num_stages)
fn_cache_manager = CacheManager(fn_cache_key)
if torch.version.hip is not None:
-amdgpu_name = f"{name}.gcn"
-hipmodule_name = f"{name}.hipmodule"
-assembly_name = amdgpu_name
-binary_name = hipmodule_name
+amdgcn_name = f"{name}.gcn"
+hsaco_name = f"{name}.hsaco"
+assembly_name = amdgcn_name
+binary_name = hsaco_name
else:
ptx_name = f"{name}.ptx"
cubin_name = f"{name}.cubin"
@@ -1304,10 +1304,10 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: i
if torch.version.hip is not None:
asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
extern_libs, "hipmodule", cc)
extern_libs, "hsaco", cc)
# cache AMD assembly and binary
fn_cache_manager.put(asm["hipmodule"], binary_name)
fn_cache_manager.put(asm["amdgpu"], assembly_name, binary=False)
fn_cache_manager.put(asm["hsaco_path"], binary_name, binary=False)
fn_cache_manager.put(asm["amdgcn"], assembly_name, binary=False)
else:
asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
extern_libs, "cubin", cc)
@@ -1354,10 +1354,10 @@ class CompiledKernel:
# initialize asm dict
self.asm = dict()
if torch.version.hip is not None:
with open(os.path.join(cache_dir, f"{fn_name}.hipmodule"), "rb") as f:
self.asm["hipmodule"] = f.read()
with open(os.path.join(cache_dir, f"{fn_name}.hsaco"), "rb") as f:
self.asm["hsaco_path"] = f.read()
with open(os.path.join(cache_dir, f"{fn_name}.gcn"), "r") as f:
self.asm["amdgpu"] = f.read()
self.asm["amdgcn"] = f.read()
else:
with open(os.path.join(cache_dir, f"{fn_name}.cubin"), "rb") as f:
self.asm["cubin"] = f.read()
@@ -1369,7 +1369,7 @@ class CompiledKernel:
self.asm["ttir"] = f.read()
if torch.version.hip is not None:
-mod, func, n_regs, n_spills = _triton.code_gen.load_binary_hipmodule(metadata["name"], self.asm["hipmodule"], self.shared, device)
+mod, func, n_regs, n_spills = _triton.code_gen.load_binary_hsaco(metadata["name"], self.asm["hsaco_path"], self.shared, device)
else:
mod, func, n_regs, n_spills = _triton.code_gen.load_binary_cubin(metadata["name"], self.asm["cubin"], self.shared, device)
self.fn_name = fn_name