send amdgcn to cache
This commit is contained in:
@@ -13,7 +13,7 @@ std::string path_to_ptxas(int& version);
|
|||||||
std::string llir_to_ptx(llvm::Module* module, int cc, int version);
|
std::string llir_to_ptx(llvm::Module* module, int cc, int version);
|
||||||
std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas_path, int cc);
|
std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas_path, int cc);
|
||||||
CUmodule ptx_to_cumodule(const std::string& ptx, int cc);
|
CUmodule ptx_to_cumodule(const std::string& ptx, int cc);
|
||||||
std::string llir_to_amdgpu(llvm::Module* module, const std::string& proc);
|
std::tuple<std::string, std::string> llir_to_amdgcn(llvm::Module* module, const std::string& proc);
|
||||||
hipModule_t amdgpu_to_hipmodule(const std::string& path);
|
hipModule_t amdgpu_to_hipmodule(const std::string& path);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@@ -269,8 +269,9 @@ namespace triton
|
|||||||
/* ------------------------ */
|
/* ------------------------ */
|
||||||
// HIP //
|
// HIP //
|
||||||
/* ------------------------ */
|
/* ------------------------ */
|
||||||
std::string llir_to_amdgpu(llvm::Module *module, const std::string &_proc)
|
std::tuple<std::string, std::string> llir_to_amdgcn(llvm::Module *module, const std::string &_proc)
|
||||||
{
|
{
|
||||||
|
std::cout << "llvm.cc: llir_to_amdgcn:" << std::endl;
|
||||||
init_llvm();
|
init_llvm();
|
||||||
// proc = std::get<0>(GetFeatureStrFromGCNArchName(rocminfo));
|
// proc = std::get<0>(GetFeatureStrFromGCNArchName(rocminfo));
|
||||||
// features = std::get<1>(GetFeatureStrFromGCNArchName(rocminfo));
|
// features = std::get<1>(GetFeatureStrFromGCNArchName(rocminfo));
|
||||||
@@ -329,21 +330,13 @@ namespace triton
|
|||||||
machine->addPassesToEmitFile(pass, *isabin_fs, nullptr, llvm::CGFT_ObjectFile);
|
machine->addPassesToEmitFile(pass, *isabin_fs, nullptr, llvm::CGFT_ObjectFile);
|
||||||
pass.run(*module);
|
pass.run(*module);
|
||||||
|
|
||||||
#ifdef DEBUG_ROCM
|
// Save GCN ISA.
|
||||||
std::cout << "Generating GCN ISA file" << std::endl;
|
|
||||||
llvm::SmallVector<char, 0> debugBuffer;
|
llvm::SmallVector<char, 0> debugBuffer;
|
||||||
llvm::legacy::PassManager debugPass;
|
llvm::legacy::PassManager debugPass;
|
||||||
llvm::raw_svector_ostream debugStream(debugBuffer);
|
llvm::raw_svector_ostream debugStream(debugBuffer);
|
||||||
machine->addPassesToEmitFile(debugPass, debugStream, nullptr, llvm::CodeGenFileType::CGFT_AssemblyFile); // TODO:cause segfault on REM ops also cause @llvm.amdgcn.if bug
|
machine->addPassesToEmitFile(debugPass, debugStream, nullptr, llvm::CodeGenFileType::CGFT_AssemblyFile); // TODO:cause segfault on REM ops also cause @llvm.amdgcn.if bug
|
||||||
debugPass.run(*module);
|
debugPass.run(*module);
|
||||||
|
std::string amdgcn(debugBuffer.begin(), debugBuffer.end());
|
||||||
// Save GCN ISA.
|
|
||||||
std::string amdgcn_path = std::string("/tmp/") + kernel_name + std::string(".gcn");
|
|
||||||
std::string result(debugBuffer.begin(), debugBuffer.end());
|
|
||||||
std::ofstream amdgcn(amdgcn_path);
|
|
||||||
amdgcn << result;
|
|
||||||
amdgcn.close();
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// generate HASCO file
|
// generate HASCO file
|
||||||
std::string hsaco_path = std::string("/tmp/") + kernel_name + std::string(".hsaco");
|
std::string hsaco_path = std::string("/tmp/") + kernel_name + std::string(".hsaco");
|
||||||
@@ -359,13 +352,14 @@ namespace triton
|
|||||||
std::cout << lld_result << std::endl;
|
std::cout << lld_result << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
return hsaco_path;
|
return std::make_tuple(amdgcn, hsaco_path);
|
||||||
}
|
}
|
||||||
|
|
||||||
hipModule_t amdgpu_to_hipmodule(const std::string &path)
|
hipModule_t amdgpu_to_hipmodule(const std::string &hsaco_path)
|
||||||
{
|
{
|
||||||
|
std::cout << "llvm.cc: amdgpu_to_hipmodule:" << std::endl;
|
||||||
// Read HSACO.
|
// Read HSACO.
|
||||||
std::ifstream hsaco_file(path, std::ios::binary | std::ios::ate);
|
std::ifstream hsaco_file(hsaco_path, std::ios::binary | std::ios::ate);
|
||||||
std::ifstream::pos_type hsaco_file_size = hsaco_file.tellg();
|
std::ifstream::pos_type hsaco_file_size = hsaco_file.tellg();
|
||||||
|
|
||||||
std::vector<unsigned char> hsaco(hsaco_file_size);
|
std::vector<unsigned char> hsaco(hsaco_file_size);
|
||||||
|
@@ -492,16 +492,16 @@ void init_triton_codegen(py::module &&m) {
|
|||||||
},
|
},
|
||||||
py::return_value_policy::take_ownership);
|
py::return_value_policy::take_ownership);
|
||||||
|
|
||||||
m.def("compile_ttir_to_amdgpu",
|
m.def("compile_ttir_to_amdgcn",
|
||||||
[](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, py::dict& extern_libs, size_t cc) {
|
[](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, py::dict& extern_libs, size_t cc) {
|
||||||
std::ostringstream ttir;
|
std::ostringstream ttir;
|
||||||
int n_shared_bytes;
|
int n_shared_bytes;
|
||||||
std::string tmp;
|
std::string tmp;
|
||||||
std::string amdgpu;
|
std::string amdgcn;
|
||||||
std::string hipmodule;
|
std::string hsaco_path;
|
||||||
std::string name;
|
std::string name;
|
||||||
{
|
{
|
||||||
std::cout << "triton.cc: compile_ttir_to_amdgpu:" << std::endl;
|
std::cout << "triton.cc: compile_ttir_to_amdgcn:" << std::endl;
|
||||||
// Scope where the GIL is released
|
// Scope where the GIL is released
|
||||||
py::gil_scoped_release allow_threads;
|
py::gil_scoped_release allow_threads;
|
||||||
name = ir.get_function_list()[0]->get_name();
|
name = ir.get_function_list()[0]->get_name();
|
||||||
@@ -534,17 +534,15 @@ void init_triton_codegen(py::module &&m) {
|
|||||||
llir << *llvm;
|
llir << *llvm;
|
||||||
llir.flush();
|
llir.flush();
|
||||||
// LLVM-IR -> AMDGPU
|
// LLVM-IR -> AMDGPU
|
||||||
std::string amdgpu = drv::llir_to_amdgpu(llvm.get(), "gfx90a");
|
std::tuple<std::string, std::string> amdgpu = drv::llir_to_amdgcn(llvm.get(), "gfx90a");
|
||||||
std::cout << "amdgpu = " << amdgpu << std::endl;
|
amdgcn = std::get<0>(amdgpu);
|
||||||
// AMDGPU -> Binary
|
hsaco_path = std::get<1>(amdgpu);
|
||||||
hipModule_t hipmodule = drv::amdgpu_to_hipmodule(amdgpu);
|
|
||||||
std::cout << "hipmodule = " << hipmodule << std::endl;
|
|
||||||
}
|
}
|
||||||
asm_map_t asm_map;
|
asm_map_t asm_map;
|
||||||
asm_map["ttir"] = py::cast(ttir.str());
|
asm_map["ttir"] = py::cast(ttir.str());
|
||||||
asm_map["llir"] = py::cast(tmp);
|
asm_map["llir"] = py::cast(tmp);
|
||||||
asm_map["amdgpu"] = py::cast(amdgpu);
|
asm_map["amdgcn"] = py::cast(amdgcn);
|
||||||
asm_map["hipmodule"] = py::bytes(hipmodule);
|
asm_map["hsaco_path"] = py::cast(hsaco_path);
|
||||||
|
|
||||||
return std::make_tuple(name, asm_map, n_shared_bytes);
|
return std::make_tuple(name, asm_map, n_shared_bytes);
|
||||||
},
|
},
|
||||||
@@ -585,34 +583,36 @@ void init_triton_codegen(py::module &&m) {
|
|||||||
// ---------------------------------------
|
// ---------------------------------------
|
||||||
// Load provided assembly code into driver
|
// Load provided assembly code into driver
|
||||||
// ---------------------------------------
|
// ---------------------------------------
|
||||||
m.def("load_binary_hipmodule", [](const std::string& name, const std::string& data, size_t n_shared_bytes, uint64_t device){
|
m.def("load_binary_hsaco", [](const std::string& name, const std::string& data, size_t n_shared_bytes, uint64_t device){
|
||||||
std::cout << "triton.cc: load_binary_hipmodule:" << std::endl;
|
std::cout << "triton.cc: load_binary_hsaco:" << std::endl;
|
||||||
std::cout << "\tname:" << name << std::endl;
|
std::cout << "\tname:" << name << std::endl;
|
||||||
std::cout << "\tdata:" << data << std::endl;
|
std::cout << "\tdata:" << data << std::endl;
|
||||||
std::cout << "\tn_shared_bytes:" << n_shared_bytes << std::endl;
|
std::cout << "\tn_shared_bytes:" << n_shared_bytes << std::endl;
|
||||||
std::cout << "\tdevice:" << device << std::endl;
|
std::cout << "\tdevice:" << device << std::endl;
|
||||||
py::gil_scoped_release allow_threads;
|
py::gil_scoped_release allow_threads;
|
||||||
// create driver handles
|
// create driver handles
|
||||||
|
std::cout << "\t" << "// create driver handles" << std::endl;
|
||||||
hipFunction_t fun;
|
hipFunction_t fun;
|
||||||
hipModule_t mod;
|
hipModule_t mod = drv::amdgpu_to_hipmodule(data);
|
||||||
drv::dispatch::hipModuleLoadData(&mod, data.c_str());
|
|
||||||
drv::dispatch::hipModuleGetFunction(&fun, mod, name.c_str());
|
drv::dispatch::hipModuleGetFunction(&fun, mod, name.c_str());
|
||||||
// get allocated registers and spilled registers from the function
|
// get allocated registers and spilled registers from the function
|
||||||
|
std::cout << "\t" << "// get allocated registers and spilled registers from the function" << std::endl;
|
||||||
int n_regs = 0;
|
int n_regs = 0;
|
||||||
int n_spills = 0;
|
int n_spills = 0;
|
||||||
hipFuncAttributes attr;
|
hipFuncAttributes attr;
|
||||||
drv::dispatch::hipFuncGetAttributes(&attr, fun);
|
// drv::dispatch::hipFuncGetAttributes(&attr, fun);
|
||||||
drv::dispatch::hipFuncGetAttributes(&attr, fun);
|
// drv::dispatch::hipFuncGetAttributes(&attr, fun);
|
||||||
n_regs = attr.numRegs;
|
n_regs = attr.numRegs;
|
||||||
n_spills = attr.localSizeBytes / 4;
|
n_spills = attr.localSizeBytes / 4;
|
||||||
// set dynamic shared memory if necessary
|
// set dynamic shared memory if necessary
|
||||||
|
std::cout << "\t" << "// set dynamic shared memory if necessary" << std::endl;
|
||||||
int shared_optin;
|
int shared_optin;
|
||||||
drv::dispatch::hipDeviceGetAttribute(&shared_optin, hipDeviceAttributeSharedMemPerBlockOptin, device);
|
drv::dispatch::hipDeviceGetAttribute(&shared_optin, hipDeviceAttributeSharedMemPerBlockOptin, device);
|
||||||
if(n_shared_bytes > 49152 && shared_optin > 49152){
|
if(n_shared_bytes > 49152 && shared_optin > 49152){
|
||||||
drv::dispatch::hipFuncSetCacheConfig(fun, hipFuncCachePreferShared);
|
// drv::dispatch::hipFuncSetCacheConfig(fun, hipFuncCachePreferShared);
|
||||||
int shared_total, shared_static;
|
int shared_total, shared_static;
|
||||||
drv::dispatch::hipDeviceGetAttribute(&shared_total, hipDeviceAttributeMaxSharedMemoryPerMultiprocessor, device);
|
drv::dispatch::hipDeviceGetAttribute(&shared_total, hipDeviceAttributeMaxSharedMemoryPerMultiprocessor, device);
|
||||||
drv::dispatch::hipFuncGetAttributes(&attr, fun);
|
// drv::dispatch::hipFuncGetAttributes(&attr, fun);
|
||||||
shared_total = attr.sharedSizeBytes;
|
shared_total = attr.sharedSizeBytes;
|
||||||
// drv::dispatch::hipFuncSetAttribute(fun, hipFuncAttributeMaxDynamicSharedMemorySize, shared_optin - shared_static);
|
// drv::dispatch::hipFuncSetAttribute(fun, hipFuncAttributeMaxDynamicSharedMemorySize, shared_optin - shared_static);
|
||||||
}
|
}
|
||||||
|
@@ -895,7 +895,7 @@ def _compile(fn, signature: str, device: int = -1, constants=dict(),
|
|||||||
if output == "ttir":
|
if output == "ttir":
|
||||||
return module
|
return module
|
||||||
|
|
||||||
assert (output == "cubin" or output == "hipmodule")
|
assert (output == "cubin" or output == "hsaco")
|
||||||
if torch.version.hip is not None:
|
if torch.version.hip is not None:
|
||||||
backend = _triton.runtime.backend.ROCM
|
backend = _triton.runtime.backend.ROCM
|
||||||
else:
|
else:
|
||||||
@@ -905,7 +905,7 @@ def _compile(fn, signature: str, device: int = -1, constants=dict(),
|
|||||||
|
|
||||||
# compile ttir
|
# compile ttir
|
||||||
if torch.version.hip is not None:
|
if torch.version.hip is not None:
|
||||||
name, asm, shared_mem = _triton.code_gen.compile_ttir_to_amdgpu(backend, module, device, num_warps, num_stages, extern_libs, cc)
|
name, asm, shared_mem = _triton.code_gen.compile_ttir_to_amdgcn(backend, module, device, num_warps, num_stages, extern_libs, cc)
|
||||||
else:
|
else:
|
||||||
name, asm, shared_mem = _triton.code_gen.compile_ttir_to_ptx(backend, module, device, num_warps, num_stages, extern_libs, cc)
|
name, asm, shared_mem = _triton.code_gen.compile_ttir_to_ptx(backend, module, device, num_warps, num_stages, extern_libs, cc)
|
||||||
return asm, shared_mem, name
|
return asm, shared_mem, name
|
||||||
@@ -1275,7 +1275,7 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: i
|
|||||||
src_path = os.path.join(tmpdir, "main.c")
|
src_path = os.path.join(tmpdir, "main.c")
|
||||||
with open(src_path, "w") as f:
|
with open(src_path, "w") as f:
|
||||||
f.write(src)
|
f.write(src)
|
||||||
so = _build(fn.__name__, src_path, tmpdir)
|
so = _build(fn.__name__, src_path, tmpdir) # build step
|
||||||
with open(so, "rb") as f:
|
with open(so, "rb") as f:
|
||||||
so_cache_manager.put(f.read(), so_name, binary=True)
|
so_cache_manager.put(f.read(), so_name, binary=True)
|
||||||
|
|
||||||
@@ -1283,10 +1283,10 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: i
|
|||||||
fn_cache_key = make_fn_cache_key(fn.cache_key, signature, configs, constants, num_warps, num_stages)
|
fn_cache_key = make_fn_cache_key(fn.cache_key, signature, configs, constants, num_warps, num_stages)
|
||||||
fn_cache_manager = CacheManager(fn_cache_key)
|
fn_cache_manager = CacheManager(fn_cache_key)
|
||||||
if torch.version.hip is not None:
|
if torch.version.hip is not None:
|
||||||
amdgpu_name = f"{name}.gcn"
|
amdgcn_name = f"{name}.gcn"
|
||||||
hipmodule_name = f"{name}.hipmodule"
|
hasco_name = f"{name}.hsaco"
|
||||||
assembly_name = amdgpu_name
|
assembly_name = amdgcn_name
|
||||||
binary_name = hipmodule_name
|
binary_name = hasco_name
|
||||||
else:
|
else:
|
||||||
ptx_name = f"{name}.ptx"
|
ptx_name = f"{name}.ptx"
|
||||||
cubin_name = f"{name}.cubin"
|
cubin_name = f"{name}.cubin"
|
||||||
@@ -1304,10 +1304,10 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: i
|
|||||||
|
|
||||||
if torch.version.hip is not None:
|
if torch.version.hip is not None:
|
||||||
asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
|
asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
|
||||||
extern_libs, "hipmodule", cc)
|
extern_libs, "hsaco", cc)
|
||||||
# cache AMD assembly and binary
|
# cache AMD assembly and binary
|
||||||
fn_cache_manager.put(asm["hipmodule"], binary_name)
|
fn_cache_manager.put(asm["hsaco_path"], binary_name, binary=False)
|
||||||
fn_cache_manager.put(asm["amdgpu"], assembly_name, binary=False)
|
fn_cache_manager.put(asm["amdgcn"], assembly_name, binary=False)
|
||||||
else:
|
else:
|
||||||
asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
|
asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
|
||||||
extern_libs, "cubin", cc)
|
extern_libs, "cubin", cc)
|
||||||
@@ -1354,10 +1354,10 @@ class CompiledKernel:
|
|||||||
# initialize asm dict
|
# initialize asm dict
|
||||||
self.asm = dict()
|
self.asm = dict()
|
||||||
if torch.version.hip is not None:
|
if torch.version.hip is not None:
|
||||||
with open(os.path.join(cache_dir, f"{fn_name}.hipmodule"), "rb") as f:
|
with open(os.path.join(cache_dir, f"{fn_name}.hsaco"), "rb") as f:
|
||||||
self.asm["hipmodule"] = f.read()
|
self.asm["hsaco_path"] = f.read()
|
||||||
with open(os.path.join(cache_dir, f"{fn_name}.gcn"), "r") as f:
|
with open(os.path.join(cache_dir, f"{fn_name}.gcn"), "r") as f:
|
||||||
self.asm["amdgpu"] = f.read()
|
self.asm["amdgcn"] = f.read()
|
||||||
else:
|
else:
|
||||||
with open(os.path.join(cache_dir, f"{fn_name}.cubin"), "rb") as f:
|
with open(os.path.join(cache_dir, f"{fn_name}.cubin"), "rb") as f:
|
||||||
self.asm["cubin"] = f.read()
|
self.asm["cubin"] = f.read()
|
||||||
@@ -1369,7 +1369,7 @@ class CompiledKernel:
|
|||||||
self.asm["ttir"] = f.read()
|
self.asm["ttir"] = f.read()
|
||||||
|
|
||||||
if torch.version.hip is not None:
|
if torch.version.hip is not None:
|
||||||
mod, func, n_regs, n_spills = _triton.code_gen.load_binary_hipmodule(metadata["name"], self.asm["hipmodule"], self.shared, device)
|
mod, func, n_regs, n_spills = _triton.code_gen.load_binary_hsaco(metadata["name"], self.asm["hsaco_path"], self.shared, device)
|
||||||
else:
|
else:
|
||||||
mod, func, n_regs, n_spills = _triton.code_gen.load_binary_cubin(metadata["name"], self.asm["cubin"], self.shared, device)
|
mod, func, n_regs, n_spills = _triton.code_gen.load_binary_cubin(metadata["name"], self.asm["cubin"], self.shared, device)
|
||||||
self.fn_name = fn_name
|
self.fn_name = fn_name
|
||||||
|
Reference in New Issue
Block a user