From d022f5cf2c65ce937387fef2ebd71a92e716753f Mon Sep 17 00:00:00 2001
From: Michael Melesse
Date: Fri, 21 Oct 2022 17:54:31 +0000
Subject: [PATCH] add compiling back to gcn

---
 python/src/triton.cc      | 58 ++++++++++++++++++++++++++++++++++++++-
 python/triton/compiler.py | 18 +++++++++---
 scripts/amd/run.sh        |  2 +-
 3 files changed, 72 insertions(+), 6 deletions(-)

diff --git a/python/src/triton.cc b/python/src/triton.cc
index 204317079..888ca179c 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -437,7 +437,7 @@ typedef std::map<std::string, py::object> asm_map_t;
 // ---------------------------------------
 
 void init_triton_codegen(py::module &&m) {
-  m.def("compile_ttir",
+  m.def("compile_ttir_to_ptx",
       [](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, py::dict& extern_libs, size_t cc) {
         std::ostringstream ttir;
         int n_shared_bytes;
@@ -492,6 +492,62 @@ void init_triton_codegen(py::module &&m) {
       },
       py::return_value_policy::take_ownership);
 
+  m.def("compile_ttir_to_amdgpu",
+      [](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, py::dict& extern_libs, size_t cc) {
+        std::ostringstream ttir;
+        int n_shared_bytes;
+        std::string tmp;
+        std::string amdgpu;
+        hipModule_t hipmodule = nullptr;
+        std::string name;
+        {
+          // Scope where the GIL is released
+          py::gil_scoped_release allow_threads;
+          name = ir.get_function_list()[0]->get_name();
+          ir.print(ttir);
+          llvm::LLVMContext ctx;
+          // construct extern lib map
+          triton::codegen::ExternLibMap extern_lib_map;
+          for (auto item : extern_libs) {
+            auto name = item.first.cast<std::string>();
+            auto path = item.second.cast<std::string>();
+            extern_lib_map.emplace(
+                name, triton::codegen::create_extern_lib(name, path));
+          }
+          // device properties
+          if (cc == 0) {
+            hipDevice_t dev = (hipDevice_t)device;
+            size_t major = hipGetInfo<hipDeviceAttributeComputeCapabilityMajor>(dev);
+            size_t minor = hipGetInfo<hipDeviceAttributeComputeCapabilityMinor>(dev);
+            cc = major * 10 + minor;
+          }
+          int version;
+          // std::string ptxas_path = drv::path_to_ptxas(version);
+          // Triton-IR -> AMDGCN LLVM-IR
+          triton::codegen::amd_cl_target target;
+          auto llvm = triton::codegen::add_passes_to_emit_bin(
+              ir, ctx, &target, num_warps, num_stages, n_shared_bytes, extern_lib_map);
+          llvm::raw_string_ostream llir(tmp);
+          llir << *llvm;
+          llir.flush();
+          // LLVM-IR -> AMDGCN assembly / HSACO
+          amdgpu = drv::llir_to_amdgpu(llvm.get(), "gfx90a");
+          // HSACO -> loaded hipModule handle
+          hipmodule = drv::amdgpu_to_hipmodule(amdgpu);
+        }
+        asm_map_t asm_map;
+        asm_map["ttir"] = py::cast(ttir.str());
+        asm_map["llir"] = py::cast(tmp);
+        asm_map["amdgpu"] = py::cast(amdgpu);
+
+        if (hipmodule != nullptr) {
+          // expose the loaded module handle as a plain integer
+          asm_map["hipmodule"] = py::cast((uint64_t)hipmodule);
+        }
+        return std::make_tuple(name, asm_map, n_shared_bytes);
+      },
+      py::return_value_policy::take_ownership);
+
 
 // ---------------------------------------
 // Load provided assembly code into driver
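The new compile_ttir_to_amdgpu binding mirrors the existing PTX path: it returns the same (kernel_name, asm_map, n_shared_bytes) tuple, with "amdgpu" and "hipmodule" artifacts taking the place of "ptx" and "cubin". Below is a minimal sketch of how the Python side is expected to drive the two bindings, assuming a ROCm build of the _triton extension and an already-lowered Triton-IR module; the import path and the helper name compile_ttir are illustrative assumptions, not part of this patch.

# Hedged sketch: dispatch between the two code-generation bindings renamed/added
# in the hunk above. `module` is a lowered Triton-IR module produced by the
# compiler.py frontend; it is not constructed here. The import path below is an
# assumption that mirrors compiler.py in this tree.
import torch
import triton._C.libtriton.triton as _triton

def compile_ttir(module, device=0, num_warps=4, num_stages=3, extern_libs=None, cc=0):
    extern_libs = extern_libs or dict()
    if torch.version.hip is not None:
        backend = _triton.runtime.backend.ROCM
        # Triton-IR -> LLVM-IR -> AMDGCN / HSACO -> hipModule
        return _triton.code_gen.compile_ttir_to_amdgpu(
            backend, module, device, num_warps, num_stages, extern_libs, cc)
    backend = _triton.runtime.backend.CUDA
    # Triton-IR -> LLVM-IR -> PTX -> cubin
    return _triton.code_gen.compile_ttir_to_ptx(
        backend, module, device, num_warps, num_stages, extern_libs, cc)

# name, asm, shared = compile_ttir(module)
# asm keys: 'ttir', 'llir', plus 'amdgpu'/'hipmodule' on ROCm or 'ptx'/'cubin' on CUDA.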
diff --git a/python/triton/compiler.py b/python/triton/compiler.py
index 9e4726b87..04b1b3e8d 100644
--- a/python/triton/compiler.py
+++ b/python/triton/compiler.py
@@ -893,14 +893,19 @@ def _compile(fn, signature: str, device: int = -1, constants=dict(),
 
     if output == "ttir":
         return module
 
-    assert output == "cubin"
+    assert output in ("cubin", "hsaco")
     if torch.version.hip is not None:
         backend = _triton.runtime.backend.ROCM
     else:
         backend = _triton.runtime.backend.CUDA
     if extern_libs is None:
         extern_libs = dict()
-    name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, module, device, num_warps, num_stages, extern_libs, cc)
+
+    # compile ttir down to the backend-specific binary
+    if torch.version.hip is not None:
+        name, asm, shared_mem = _triton.code_gen.compile_ttir_to_amdgpu(backend, module, device, num_warps, num_stages, extern_libs, cc)
+    else:
+        name, asm, shared_mem = _triton.code_gen.compile_ttir_to_ptx(backend, module, device, num_warps, num_stages, extern_libs, cc)
     return asm, shared_mem, name
@@ -1274,8 +1279,13 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: i
        not fn_cache_manager.has_file(ptx_name) or \
        not fn_cache_manager.has_file(ttir_name) or \
        not fn_cache_manager.has_file(llir_name):
-        asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
-                                            extern_libs, "cubin", cc)
+
+        if torch.version.hip is not None:
+            asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
+                                                extern_libs, "hsaco", cc)
+        else:
+            asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
+                                                extern_libs, "cubin", cc)
     metadata = {"name": kernel_name, "shared": shared, "num_warps": num_warps, "num_stages": num_stages}
     fn_cache_manager.put(asm["cubin"], cubin_name)
     fn_cache_manager.put(asm["ptx"], ptx_name, binary=False)
diff --git a/scripts/amd/run.sh b/scripts/amd/run.sh
index d396b733a..3bdea1a4e 100644
--- a/scripts/amd/run.sh
+++ b/scripts/amd/run.sh
@@ -15,4 +15,4 @@
 bash scripts/amd/test.sh 2>&1 |tee $LOG_DIR/test.log
 
 # bash scripts/amd/backtrace.sh 2>&1 |tee $LOG_DIR/backtrace.log
-bash scripts/amd/post.sh # dont double call
\ No newline at end of file
+# bash scripts/amd/post.sh # dont double call
\ No newline at end of file
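End to end, the patch means a @triton.jit kernel compiled under a HIP build of PyTorch (torch.version.hip is not None) goes through the hsaco path instead of the cubin path. A hedged usage sketch follows, assuming this branch builds against ROCm; the kernel, sizes, and grid are illustrative and not part of the patch. Note that the cache writes after _compile still reference asm["cubin"] and asm["ptx"], which the ROCm asm map does not populate, so running this sketch end to end may still require follow-up changes on this branch.

# Hedged usage sketch: standard Triton kernel launch. Under a HIP build of
# PyTorch the compile() path selected above produces AMDGCN/HSACO instead of
# PTX/cubin; the launch syntax itself is unchanged.
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)

n = 4096
x = torch.rand(n, device="cuda")   # ROCm PyTorch also exposes the "cuda" device string
y = torch.rand(n, device="cuda")
out = torch.empty_like(x)
grid = (triton.cdiv(n, 128),)
add_kernel[grid](x, y, out, n, BLOCK=128)
assert torch.allclose(out, x + y)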