add compiling back to gcn
This commit is contained in:
@@ -437,7 +437,7 @@ typedef std::map<std::string, py::object> asm_map_t;
|
||||
// ---------------------------------------
|
||||
|
||||
void init_triton_codegen(py::module &&m) {
|
||||
m.def("compile_ttir",
|
||||
m.def("compile_ttir_to_ptx",
|
||||
[](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, py::dict& extern_libs, size_t cc) {
|
||||
std::ostringstream ttir;
|
||||
int n_shared_bytes;
|
||||
@@ -492,6 +492,62 @@ void init_triton_codegen(py::module &&m) {
|
||||
},
|
||||
py::return_value_policy::take_ownership);
|
||||
|
||||
m.def("compile_ttir_to_amdgpu",
|
||||
[](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, py::dict& extern_libs, size_t cc) {
|
||||
std::ostringstream ttir;
|
||||
int n_shared_bytes;
|
||||
std::string tmp;
|
||||
std::string amdgpu;
|
||||
std::string hipmodule;
|
||||
std::string name;
|
||||
{
|
||||
// Scope where the GIL is released
|
||||
py::gil_scoped_release allow_threads;
|
||||
name = ir.get_function_list()[0]->get_name();
|
||||
ir.print(ttir);
|
||||
llvm::LLVMContext ctx;
|
||||
// construct extern lib map
|
||||
triton::codegen::ExternLibMap extern_lib_map;
|
||||
for (auto item : extern_libs) {
|
||||
auto name = item.first.cast<std::string>();
|
||||
auto path = item.second.cast<std::string>();
|
||||
extern_lib_map.emplace(
|
||||
name, triton::codegen::create_extern_lib(name, path));
|
||||
}
|
||||
// device properties
|
||||
if (cc == 0) {
|
||||
hipDevice_t dev = (hipDevice_t)device;
|
||||
size_t major = hipGetInfo<hipDeviceAttributeComputeCapabilityMajor>(dev);
|
||||
size_t minor = hipGetInfo<hipDeviceAttributeComputeCapabilityMinor>(dev);
|
||||
cc = major*10 + minor;
|
||||
}
|
||||
int version;
|
||||
// std::string ptxas_path = drv::path_to_ptxas(version);
|
||||
// Triton-IR -> AMDGCN LLVM-IR
|
||||
triton::codegen::amd_cl_target target;
|
||||
auto llvm = triton::codegen::add_passes_to_emit_bin(
|
||||
ir, ctx, &target, num_warps, num_stages, n_shared_bytes, extern_lib_map);
|
||||
llvm::raw_string_ostream llir(tmp);
|
||||
llir << *llvm;
|
||||
llir.flush();
|
||||
// LLVM-IR -> AMD HSACO
|
||||
std::string amdgpu = drv::llir_to_amdgpu(llvm.get(), "gfx90a");
|
||||
// HSACO -> GCN
|
||||
hipModule_t hipmodule = drv::amdgpu_to_hipmodule(amdgpu);
|
||||
}
|
||||
asm_map_t asm_map;
|
||||
asm_map["ttir"] = py::cast(ttir.str());
|
||||
asm_map["llir"] = py::cast(tmp);
|
||||
asm_map["amdgpu"] = py::cast(amdgpu);
|
||||
|
||||
if(!hipmodule.empty()){
|
||||
py::bytes bytes(hipmodule);
|
||||
asm_map["hipmodule"] = bytes;
|
||||
}
|
||||
return std::make_tuple(name, asm_map, n_shared_bytes);
|
||||
},
|
||||
py::return_value_policy::take_ownership);
|
||||
|
||||
|
||||
// ---------------------------------------
|
||||
// Load provided assembly code into driver
|
||||
|
@@ -893,14 +893,19 @@ def _compile(fn, signature: str, device: int = -1, constants=dict(),
|
||||
if output == "ttir":
|
||||
return module
|
||||
|
||||
assert output == "cubin"
|
||||
assert (output == "cubin" or output == "hsaco")
|
||||
if torch.version.hip is not None:
|
||||
backend = _triton.runtime.backend.ROCM
|
||||
else:
|
||||
backend = _triton.runtime.backend.CUDA
|
||||
if extern_libs is None:
|
||||
extern_libs = dict()
|
||||
name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, module, device, num_warps, num_stages, extern_libs, cc)
|
||||
|
||||
# compile ttir
|
||||
if torch.version.hip is not None:
|
||||
name, asm, shared_mem = _triton.code_gen.compile_ttir_to_amdgpu(backend, module, device, num_warps, num_stages, extern_libs, cc)
|
||||
else:
|
||||
name, asm, shared_mem = _triton.code_gen.compile_ttir_to_ptx(backend, module, device, num_warps, num_stages, extern_libs, cc)
|
||||
return asm, shared_mem, name
|
||||
|
||||
|
||||
@@ -1274,8 +1279,13 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: i
|
||||
not fn_cache_manager.has_file(ptx_name) or \
|
||||
not fn_cache_manager.has_file(ttir_name) or \
|
||||
not fn_cache_manager.has_file(llir_name):
|
||||
asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
|
||||
extern_libs, "cubin", cc)
|
||||
|
||||
if torch.version.hip is not None:
|
||||
asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
|
||||
extern_libs, "hsaco", cc)
|
||||
else
|
||||
asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
|
||||
extern_libs, "cubin", cc)
|
||||
metadata = {"name": kernel_name, "shared": shared, "num_warps": num_warps, "num_stages": num_stages}
|
||||
fn_cache_manager.put(asm["cubin"], cubin_name)
|
||||
fn_cache_manager.put(asm["ptx"], ptx_name, binary=False)
|
||||
|
@@ -15,4 +15,4 @@ bash scripts/amd/test.sh 2>&1 |tee $LOG_DIR/test.log
|
||||
# bash scripts/amd/backtrace.sh 2>&1 |tee $LOG_DIR/backtrace.log
|
||||
|
||||
|
||||
bash scripts/amd/post.sh # dont double call
|
||||
# bash scripts/amd/post.sh # dont double call
|
Reference in New Issue
Block a user