diff --git a/python/src/triton.cc b/python/src/triton.cc index 27a31b981..66cece243 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -659,12 +659,35 @@ void init_triton_ir(py::module &&m) { py::class_(m, "block_arguement") ; + py::class_(m, "region") + .def("get_parent_region", &mlir::Region::getParentRegion, ret::reference) + .def("size", [](mlir::Region &self) { + return self.getBlocks().size(); + }) + ; + py::class_(m, "block") .def("arg", [](mlir::Block &self, int index) -> mlir::BlockArgument { return self.getArgument(index); }) .def("dump", &mlir::Block::dump) .def("move_before", &mlir::Block::moveBefore) + .def("insert_before", &mlir::Block::insertBefore) + .def("get_parent", &mlir::Block::getParent, ret::reference) + .def("merge_block_before", [](mlir::Block &self, mlir::Block &dst) { + // ref: RewriterBase::mergeBlocks() + if (self.getNumArguments() != 0) + throw std::runtime_error("This block has arguments, don't merge"); + dst.getOperations().splice(dst.end(), self.getOperations()); + self.dropAllUses(); + self.erase(); + }) + .def("replace_use_in_block_with", [](mlir::Block &self, mlir::Value &v, mlir::Value &newVal) { + v.replaceUsesWithIf(newVal, [&](mlir::OpOperand &operand){ + mlir::Operation *user = operand.getOwner(); + return user->getBlock() == &self; + }); + }) ; // py::class_(m, "module") @@ -713,12 +736,13 @@ void init_triton_ir(py::module &&m) { .def("get_region", [](mlir::OpState &self, unsigned idx) -> mlir::Region& { return self->getRegion(idx); }, ret::reference) + .def("get_body", [](mlir::scf::ForOp &self, unsigned idx) -> mlir::Block* { + return self.getBody(idx); + }, ret::reference) + .def("dump", [](mlir::OpState &self) { self->dump(); }) ; // scf Ops - py::class_(m, "ForOp") - .def("get_body", [](mlir::scf::ForOp &self) -> mlir::Block* { - return self.getBody(); - }, ret::reference); + py::class_(m, "ForOp"); py::class_(m, "IfOp") .def("get_then_block", &mlir::scf::IfOp::thenBlock, ret::reference) .def("get_else_block", &mlir::scf::IfOp::elseBlock, ret::reference) @@ -856,11 +880,14 @@ void init_triton_ir(py::module &&m) { mlir::Region *parent = self.getBlock()->getParent(); return self.createBlock(parent); }, ret::reference) + .def("new_block", [](mlir::OpBuilder &self) -> mlir::Block* { + return new mlir::Block(); + }, ret::reference) // Structured control flow .def("create_for_op", [](mlir::OpBuilder &self, mlir::Value &lb, mlir::Value &ub, mlir::Value &step, std::vector &initArgs) -> mlir::scf::ForOp { auto loc = self.getUnknownLoc(); - return self.create(loc, lb, ub, step); + return self.create(loc, lb, ub, step, initArgs); }) .def("create_if_of", [](mlir::OpBuilder &self, mlir::Value &condition) -> mlir::scf::IfOp { auto loc = self.getUnknownLoc(); diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 5890d7358..dcf038136 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -195,7 +195,6 @@ class CodeGenerator(ast.NodeVisitor): assert len(_names) == 1 names = _names[0] values = self.visit(node.value) - print(f'visit_Assign({names}, {values})') if not isinstance(names, tuple): names = [names] if not isinstance(values, tuple): @@ -440,34 +439,30 @@ class CodeGenerator(ast.NodeVisitor): init_args.append(triton.language.core._to_tensor(liveins[name], self.builder)) yields.append(triton.language.core._to_tensor(self.local_defs[name], self.builder)) - print(f'names: {names}') - print("After insert block body") - self.module.dump() self.builder.set_insertion_point_to_end(insert_block) for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args]) - print("After create for op") - self.module.dump() - # TODO: revisit blocks & regions of ForOp - block.move_before(for_op.get_body()) - self.builder.set_insertion_point_to_end(for_op.get_body()) + # FIXME: the body should be a region (?) + # FIXME: this won't work for nested control flow + block.merge_block_before(for_op.get_body(0)) + self.builder.set_insertion_point_to_end(for_op.get_body(0)) self.builder.create_yield_op([y.handle for y in yields]) - print("After create yield op") - self.module.dump() - + for_op_region = for_op.get_body(0).get_parent() + assert for_op_region.size() == 1, "(For developer) Should use region here" + for i, name in enumerate(names): + # arg0 is the induction variable + for_op.get_body(0).replace_use_in_block_with(init_args[i].handle, for_op.get_body(0).arg(i+1)) self.builder.set_insertion_point_to_end(insert_block) - print("After restoring insert point") - self.module.dump() self.lscope = liveins self.local_defs = prev_defs # ForOp defines new values for i, name in enumerate(names): - self.lscope[name] = for_op.get_result(i) - self.local_defs[name] = for_op.get_result(i) + self.lscope[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type) + self.local_defs[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type) for stmt in node.orelse: - assert False + assert False, "Don't know what to do with else after for" ast.NodeVisitor.generic_visit(self, stmt) def visit_Slice(self, node):