diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc
index f143f8354..01bf51a28 100644
--- a/lib/codegen/selection/generator.cc
+++ b/lib/codegen/selection/generator.cc
@@ -817,9 +817,9 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::value *D, unsigned NK) {
   // update accumulators
   unsigned num_m = layout_c->rep(0) * shape_c[0] / layout_c->spt(0);
   unsigned num_n = layout_c->rep(1) * shape_c[1] / layout_c->spt(1);
+  for(unsigned K = 0; K < NK; K += 4)
   for(unsigned m = 0; m < num_m/2; m++)
-  for(unsigned n = 0; n < num_n/2; n++)
-  for(unsigned K = 0; K < NK; K += 4){
+  for(unsigned n = 0; n < num_n/2; n++) {
     if(has.find({m, K}) == has.end()){
       Value* ptra = ptr_a[(is_a_row ? K/4 : m) % num_ptr_a];
       int step_am = is_a_row ? m : m / (num_ptr_a)*(num_ptr_a);
diff --git a/python/src/triton.cc b/python/src/triton.cc
index 17d720f78..29efa6f6b 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -7,6 +7,7 @@
 #include "triton/ir/enums.h"
 #include "triton/ir/function.h"
 #include "triton/ir/module.h"
+#include "triton/ir/print.h"
 #include <pybind11/buffer_info.h>
 #include <pybind11/functional.h>
 #include <pybind11/pybind11.h>
@@ -78,7 +79,9 @@ void init_triton_codegen(py::module &&m) {
         drv::kernel *ker;
         size_t shared_mem;
         triton::codegen::add_passes_to_emit_bin(ir, dev, num_warps, mod, ker, shared_mem);
-        return std::make_tuple(mod, ker, shared_mem);
+        std::stringstream ss;
+        ir::print(ir, ss);
+        return std::make_tuple(mod, ker, shared_mem, ss.str());
       },
       py::return_value_policy::take_ownership);
 }
diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py
index 96a1a9d2a..a357eb5b0 100644
--- a/python/triton/code_gen.py
+++ b/python/triton/code_gen.py
@@ -387,13 +387,17 @@ class CodeGenerator(ast.NodeVisitor):
 
 
 class Binary:
-    def __init__(self, module, kernel, num_warps, shared_mem):
+    def __init__(self, module, kernel, num_warps, shared_mem, ir_asm):
+        # cache ir asm
+        self.ir_asm = ir_asm
         self.module = module
         self.kernel = kernel
         self.shared_mem = shared_mem
         self.num_warps = num_warps
 
     def asm(self, mode):
+        if mode == 'ttir':
+            return self.ir_asm
         if mode == 'ptx':
             return self.module.ptx()
         if mode == 'llir':
@@ -495,8 +499,8 @@ class Kernel:
             raise CompilationError(self.fn.src, node, e)
         tt_device = _triton.driver.cu_device(device.index, False)
         # Compile to machine code
-        mod, ker, shared_mem = _triton.code_gen.add_passes_to_emit_bin(generator.module, tt_device, num_warps)
-        return Binary(mod, ker, num_warps, shared_mem)
+        mod, ker, shared_mem, ir_asm = _triton.code_gen.add_passes_to_emit_bin(generator.module, tt_device, num_warps)
+        return Binary(mod, ker, num_warps, shared_mem, ir_asm)
 
     def __call__(self, *wargs, grid, num_warps=4, **meta):
         # device inference
@@ -576,7 +580,7 @@ class Autotuner:
             config = self.cache[key]
         else:
             config = self.configs[0]
-        self.kernel(*args, num_warps=config.num_warps, **meta, **config.meta)
+        return self.kernel(*args, num_warps=config.num_warps, **meta, **config.meta)
 
 
 class JITFunction:
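
Usage sketch (not part of the diff): with these changes, Binary.asm('ttir') exposes the Triton-IR string that was captured when the kernel was compiled. A minimal example, assuming `binary` is a triton.code_gen.Binary produced by Kernel's compile step (the variable name is hypothetical):

    # `binary` is a triton.code_gen.Binary; the IR string was cached at
    # compile time, so asm('ttir') does not re-run any passes.
    print(binary.asm('ttir'))  # Triton-IR captured by ir::print
    print(binary.asm('ptx'))   # PTX from the driver module, as before

Note that ir::print(ir, ss) runs after add_passes_to_emit_bin in the binding, so the cached string reflects the IR as transformed by the codegen passes, not the IR as first generated from the Python AST.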