[PYTHON] Allow triton.code_gen.Binary to print Triton-IR asm. (#89)

This commit is contained in:
daadaada
2021-04-24 02:43:38 +08:00
committed by Philippe Tillet
parent 1112e2526e
commit f6688372db
3 changed files with 14 additions and 7 deletions

View File

@@ -387,13 +387,17 @@ class CodeGenerator(ast.NodeVisitor):
class Binary:
def __init__(self, module, kernel, num_warps, shared_mem):
def __init__(self, module, kernel, num_warps, shared_mem, ir_asm):
# cache ir asm
self.ir_asm = ir_asm
self.module = module
self.kernel = kernel
self.shared_mem = shared_mem
self.num_warps = num_warps
def asm(self, mode):
if mode == 'ttir':
return self.ir_asm
if mode == 'ptx':
return self.module.ptx()
if mode == 'llir':
@@ -495,8 +499,8 @@ class Kernel:
raise CompilationError(self.fn.src, node, e)
tt_device = _triton.driver.cu_device(device.index, False)
# Compile to machine code
mod, ker, shared_mem = _triton.code_gen.add_passes_to_emit_bin(generator.module, tt_device, num_warps)
return Binary(mod, ker, num_warps, shared_mem)
mod, ker, shared_mem, ir_asm = _triton.code_gen.add_passes_to_emit_bin(generator.module, tt_device, num_warps)
return Binary(mod, ker, num_warps, shared_mem, ir_asm)
def __call__(self, *wargs, grid, num_warps=4, **meta):
# device inference
@@ -576,7 +580,7 @@ class Autotuner:
config = self.cache[key]
else:
config = self.configs[0]
self.kernel(*args, num_warps=config.num_warps, **meta, **config.meta)
return self.kernel(*args, num_warps=config.num_warps, **meta, **config.meta)
class JITFunction: