[PYTHON] Allow triton.code_gen.Binary to print Triton-IR asm. (#89)
This commit is contained in:
committed by
Philippe Tillet
parent
1112e2526e
commit
f6688372db
@@ -7,6 +7,7 @@
|
||||
#include "triton/ir/enums.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/print.h"
|
||||
#include <optional>
|
||||
#include <pybind11/buffer_info.h>
|
||||
#include <pybind11/functional.h>
|
||||
@@ -78,7 +79,9 @@ void init_triton_codegen(py::module &&m) {
|
||||
drv::kernel *ker;
|
||||
size_t shared_mem;
|
||||
triton::codegen::add_passes_to_emit_bin(ir, dev, num_warps, mod, ker, shared_mem);
|
||||
return std::make_tuple(mod, ker, shared_mem);
|
||||
std::stringstream ss;
|
||||
ir::print(ir, ss);
|
||||
return std::make_tuple(mod, ker, shared_mem, ss.str());
|
||||
},
|
||||
py::return_value_policy::take_ownership);
|
||||
}
|
||||
|
@@ -387,13 +387,17 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
|
||||
class Binary:
|
||||
def __init__(self, module, kernel, num_warps, shared_mem):
|
||||
def __init__(self, module, kernel, num_warps, shared_mem, ir_asm):
|
||||
# cache ir asm
|
||||
self.ir_asm = ir_asm
|
||||
self.module = module
|
||||
self.kernel = kernel
|
||||
self.shared_mem = shared_mem
|
||||
self.num_warps = num_warps
|
||||
|
||||
def asm(self, mode):
|
||||
if mode == 'ttir':
|
||||
return self.ir_asm
|
||||
if mode == 'ptx':
|
||||
return self.module.ptx()
|
||||
if mode == 'llir':
|
||||
@@ -495,8 +499,8 @@ class Kernel:
|
||||
raise CompilationError(self.fn.src, node, e)
|
||||
tt_device = _triton.driver.cu_device(device.index, False)
|
||||
# Compile to machine code
|
||||
mod, ker, shared_mem = _triton.code_gen.add_passes_to_emit_bin(generator.module, tt_device, num_warps)
|
||||
return Binary(mod, ker, num_warps, shared_mem)
|
||||
mod, ker, shared_mem, ir_asm = _triton.code_gen.add_passes_to_emit_bin(generator.module, tt_device, num_warps)
|
||||
return Binary(mod, ker, num_warps, shared_mem, ir_asm)
|
||||
|
||||
def __call__(self, *wargs, grid, num_warps=4, **meta):
|
||||
# device inference
|
||||
@@ -576,7 +580,7 @@ class Autotuner:
|
||||
config = self.cache[key]
|
||||
else:
|
||||
config = self.configs[0]
|
||||
self.kernel(*args, num_warps=config.num_warps, **meta, **config.meta)
|
||||
return self.kernel(*args, num_warps=config.num_warps, **meta, **config.meta)
|
||||
|
||||
|
||||
class JITFunction:
|
||||
|
Reference in New Issue
Block a user