diff --git a/python/src/triton.cc b/python/src/triton.cc index d7e76a3b8..27a31b981 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -664,6 +664,7 @@ void init_triton_ir(py::module &&m) { return self.getArgument(index); }) .def("dump", &mlir::Block::dump) + .def("move_before", &mlir::Block::moveBefore) ; // py::class_(m, "module") @@ -705,15 +706,28 @@ void init_triton_ir(py::module &&m) { ; // Ops - py::class_(m, "ForOp") - .def("get_body", &mlir::scf::ForOp::getBody, ret::reference); - py::class_(m, "IfOp") + py::class_(m, "OpState") + .def("get_result", [](mlir::OpState &self, unsigned idx) -> mlir::Value { + return self->getResult(idx); + }) + .def("get_region", [](mlir::OpState &self, unsigned idx) -> mlir::Region& { + return self->getRegion(idx); + }, ret::reference) + ; + // scf Ops + py::class_(m, "ForOp") + .def("get_body", [](mlir::scf::ForOp &self) -> mlir::Block* { + return self.getBody(); + }, ret::reference); + py::class_(m, "IfOp") .def("get_then_block", &mlir::scf::IfOp::thenBlock, ret::reference) .def("get_else_block", &mlir::scf::IfOp::elseBlock, ret::reference) .def("get_then_yield", &mlir::scf::IfOp::thenYield) .def("get_else_yield", &mlir::scf::IfOp::elseYield) ; - py::class_(m, "YieldOp"); + py::class_(m, "YieldOp"); + + py::class_(m, "InsertPoint"); py::class_(m, "builder", py::dynamic_attr()) .def(py::init()) @@ -740,12 +754,8 @@ void init_triton_ir(py::module &&m) { .def("get_insertion_block", [](mlir::OpBuilder &self) -> mlir::Block* { return self.getInsertionBlock(); }, ret::reference) - // .def("get_insert_point", [](ir::builder *self) { - // ir::basic_block *bb = self->get_insert_block(); - // ir::basic_block::iterator it = self->get_insert_point(); - // ir::instruction *instr = it == bb->end() ? nullptr : *it; - // return std::make_pair(bb, instr); - // }, ret::reference) + .def("get_insertion_point", &mlir::OpBuilder::saveInsertionPoint) + .def("restore_insertion_point", &mlir::OpBuilder::restoreInsertionPoint) // .def("set_insert_point", [](ir::builder *self, std::pair pt) { // ir::basic_block *bb = pt.first; // ir::instruction *instr = pt.second; @@ -856,9 +866,9 @@ void init_triton_ir(py::module &&m) { auto loc = self.getUnknownLoc(); return self.create(loc, condition); }) - .def("create_yield_op", [](mlir::OpBuilder &self) -> mlir::scf::YieldOp { + .def("create_yield_op", [](mlir::OpBuilder &self, std::vector &yields) -> mlir::scf::YieldOp { auto loc = self.getUnknownLoc(); - return self.create(loc); + return self.create(loc, yields); }) // // .def("create_while") diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 690f6cfc4..5890d7358 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -415,8 +415,10 @@ class CodeGenerator(ast.NodeVisitor): ub = triton.language.core._to_tensor(ub, self.builder).handle step = triton.language.core._to_tensor(step, self.builder).handle - loop_body = self.builder.create_block() - self.builder.set_insertion_point_to_start(loop_body) + insert_block = self.builder.get_insertion_block() + + block = self.builder.create_block() + self.builder.set_insertion_point_to_start(block) liveins = self.lscope.copy() prev_defs = self.local_defs.copy() @@ -425,20 +427,44 @@ class CodeGenerator(ast.NodeVisitor): # visit loop body self.visit_compound_statement(node.body) - # TODO: update insertion point init_args = [] yields = [] + names = [] for name in self.local_defs: if name in liveins: assert self.is_triton_tensor(self.local_defs[name]) assert self.is_triton_tensor(liveins[name]) if self.local_defs[name].type == liveins[name].type: - init_args.append(triton.language.core._to_tensor(liveins[name], self.builder).handle) - yields.append(triton.language.core._to_tensor(self.local_defs[name], self.builder).handle) - for_op = self.builder.create_for_op(lb, ub, step, init_args) + # TODO: better way to do this? + names.append(name) + init_args.append(triton.language.core._to_tensor(liveins[name], self.builder)) + yields.append(triton.language.core._to_tensor(self.local_defs[name], self.builder)) + + print(f'names: {names}') + print("After insert block body") + self.module.dump() + self.builder.set_insertion_point_to_end(insert_block) + for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args]) + print("After create for op") + self.module.dump() + # TODO: revisit blocks & regions of ForOp + block.move_before(for_op.get_body()) + self.builder.set_insertion_point_to_end(for_op.get_body()) + self.builder.create_yield_op([y.handle for y in yields]) + print("After create yield op") + self.module.dump() + + + self.builder.set_insertion_point_to_end(insert_block) + print("After restoring insert point") + self.module.dump() self.lscope = liveins self.local_defs = prev_defs + # ForOp defines new values + for i, name in enumerate(names): + self.lscope[name] = for_op.get_result(i) + self.local_defs[name] = for_op.get_result(i) for stmt in node.orelse: assert False