diff --git a/python/src/triton.cc b/python/src/triton.cc index 9dae11ba5..e88e22641 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -14,6 +14,7 @@ #include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Verifier.h" +#include #include #include #include @@ -653,12 +654,14 @@ void init_triton_ir(py::module &&m) { .def("size", [](mlir::Region &self) { return self.getBlocks().size(); }) + .def("empty", &mlir::Region::empty) ; py::class_(m, "block") .def("arg", [](mlir::Block &self, int index) -> mlir::BlockArgument { return self.getArgument(index); }) + .def("get_num_arguments", &mlir::Block::getNumArguments) .def("dump", &mlir::Block::dump) .def("move_before", &mlir::Block::moveBefore) .def("insert_before", &mlir::Block::insertBefore) @@ -674,7 +677,14 @@ void init_triton_ir(py::module &&m) { .def("replace_use_in_block_with", [](mlir::Block &self, mlir::Value &v, mlir::Value &newVal) { v.replaceUsesWithIf(newVal, [&](mlir::OpOperand &operand){ mlir::Operation *user = operand.getOwner(); - return user->getBlock() == &self; + mlir::Block *currentBlock = user->getBlock(); + while (currentBlock) { + if (currentBlock == &self) + return true; + // Move up one level + currentBlock = currentBlock->getParent()->getParentOp()->getBlock(); + } + return false; }); }) ; @@ -886,8 +896,11 @@ void init_triton_ir(py::module &&m) { mlir::Region *parent = self.getBlock()->getParent(); return self.createBlock(parent); }, ret::reference) - .def("create_block_with_parent", [](mlir::OpBuilder &self, mlir::Region &parent) -> mlir::Block* { - return self.createBlock(&parent); + .def("create_block_with_parent", [](mlir::OpBuilder &self, mlir::Region &parent, + std::vector &argTypes) -> mlir::Block* { + auto argLoc = self.getUnknownLoc(); + llvm::SmallVector argLocs(argTypes.size(), argLoc); + return self.createBlock(&parent, {}, argTypes, argLocs); }, ret::reference) .def("new_block", [](mlir::OpBuilder &self) -> mlir::Block* { return new mlir::Block(); diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 3313c5328..d0dd38bf6 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -425,23 +425,32 @@ class CodeGenerator(ast.NodeVisitor): if loop_defs[name].type == liveins[name].type: # these are loop-carried values names.append(name) - ret_types.append(loop_defs[name].type.to_ir(self.builder)) + ret_types.append(loop_defs[name].type) init_args.append(liveins[name]) yields.append(loop_defs[name]) self.builder.set_insertion_point_to_end(insert_block) - while_op = self.builder.create_while_op(ret_types, init_args) + while_op = self.builder.create_while_op([ty.to_ir(self.builder) for ty in ret_types], + [arg.handle for arg in init_args]) # merge the condition region - before_block = self.builder.create_block_with_parent(while_op.get_before()) + before_block = self.builder.create_block_with_parent(while_op.get_before(), + [ty.to_ir(self.builder) for ty in ret_types]) cond_block.merge_block_before(before_block) self.builder.set_insertion_point_to_end(before_block) - self.builder.create_condtion_op(cond.handle, []) + # create CondtionOp: e.g., scf.condition(%cond) %arg0, %arg1, ... + self.builder.create_condtion_op(cond.handle, [before_block.arg(i) for i in range(len(init_args))]) # merge the loop body - after_block = self.builder.create_block_with_parent(while_op.get_after()) + after_block = self.builder.create_block_with_parent(while_op.get_after(), + [ty.to_ir(self.builder) for ty in ret_types]) loop_block.merge_block_before(after_block) self.builder.set_insertion_point_to_end(after_block) self.builder.create_yield_op([y.handle for y in yields]) + # update global_uses in while_op + for i, name in enumerate(names): + before_block.replace_use_in_block_with(init_args[i].handle, before_block.arg(i)) + after_block.replace_use_in_block_with(init_args[i].handle, after_block.arg(i)) + self.builder.set_insertion_point_to_end(insert_block) self.lscope = liveins self.local_defs = prev_defs diff --git a/rewrite-test/jit/vecadd-loop.py b/rewrite-test/jit/vecadd-loop.py index 1e02c27e0..49a81a230 100644 --- a/rewrite-test/jit/vecadd-loop.py +++ b/rewrite-test/jit/vecadd-loop.py @@ -47,6 +47,5 @@ z = torch.empty_like(x) # add_kernel[(1,)](x, y, z, size, 256) # print(add_kernel[(1,)].kernel.compile_to_ttir()) mod, ctx = add_kernel.compile_to_ttir(x, y, z, size, 128, 8, grid=(1,)) -mod.get_context() mod.dump() # print(mod) diff --git a/rewrite-test/jit/while.py b/rewrite-test/jit/while.py index 25f41e86d..15d49c75b 100644 --- a/rewrite-test/jit/while.py +++ b/rewrite-test/jit/while.py @@ -16,3 +16,6 @@ def generic_while(lb, value): locks = torch.zeros(32, dtype=torch.int32, device='cuda') mod_atomic, ctx_atomic = atomic.compile_to_ttir(locks, grid=(1,)) mod_atomic.dump() + +mod_generic_while, ctx_generic_while = generic_while.compile_to_ttir(8, 9, grid=(1,)) +mod_generic_while.dump()