ForOp's SSA construction

This commit is contained in:
Yan Da
2022-04-03 19:11:47 +08:00
parent 61413b8a97
commit c71c50cd0c
2 changed files with 44 additions and 22 deletions

View File

@@ -659,12 +659,35 @@ void init_triton_ir(py::module &&m) {
py::class_<mlir::BlockArgument, mlir::Value>(m, "block_arguement")
;
py::class_<mlir::Region>(m, "region")
.def("get_parent_region", &mlir::Region::getParentRegion, ret::reference)
.def("size", [](mlir::Region &self) {
return self.getBlocks().size();
})
;
py::class_<mlir::Block>(m, "block")
.def("arg", [](mlir::Block &self, int index) -> mlir::BlockArgument {
return self.getArgument(index);
})
.def("dump", &mlir::Block::dump)
.def("move_before", &mlir::Block::moveBefore)
.def("insert_before", &mlir::Block::insertBefore)
.def("get_parent", &mlir::Block::getParent, ret::reference)
.def("merge_block_before", [](mlir::Block &self, mlir::Block &dst) {
// ref: RewriterBase::mergeBlocks()
if (self.getNumArguments() != 0)
throw std::runtime_error("This block has arguments, don't merge");
dst.getOperations().splice(dst.end(), self.getOperations());
self.dropAllUses();
self.erase();
})
.def("replace_use_in_block_with", [](mlir::Block &self, mlir::Value &v, mlir::Value &newVal) {
v.replaceUsesWithIf(newVal, [&](mlir::OpOperand &operand){
mlir::Operation *user = operand.getOwner();
return user->getBlock() == &self;
});
})
;
// py::class_<mlir::ModuleOp>(m, "module")
@@ -713,12 +736,13 @@ void init_triton_ir(py::module &&m) {
.def("get_region", [](mlir::OpState &self, unsigned idx) -> mlir::Region& {
return self->getRegion(idx);
}, ret::reference)
.def("get_body", [](mlir::scf::ForOp &self, unsigned idx) -> mlir::Block* {
return self.getBody(idx);
}, ret::reference)
.def("dump", [](mlir::OpState &self) { self->dump(); })
;
// scf Ops
py::class_<mlir::scf::ForOp, mlir::OpState>(m, "ForOp")
.def("get_body", [](mlir::scf::ForOp &self) -> mlir::Block* {
return self.getBody();
}, ret::reference);
py::class_<mlir::scf::ForOp, mlir::OpState>(m, "ForOp");
py::class_<mlir::scf::IfOp, mlir::OpState>(m, "IfOp")
.def("get_then_block", &mlir::scf::IfOp::thenBlock, ret::reference)
.def("get_else_block", &mlir::scf::IfOp::elseBlock, ret::reference)
@@ -856,11 +880,14 @@ void init_triton_ir(py::module &&m) {
mlir::Region *parent = self.getBlock()->getParent();
return self.createBlock(parent);
}, ret::reference)
.def("new_block", [](mlir::OpBuilder &self) -> mlir::Block* {
return new mlir::Block();
}, ret::reference)
// Structured control flow
.def("create_for_op", [](mlir::OpBuilder &self, mlir::Value &lb, mlir::Value &ub,
mlir::Value &step, std::vector<mlir::Value> &initArgs) -> mlir::scf::ForOp {
auto loc = self.getUnknownLoc();
return self.create<mlir::scf::ForOp>(loc, lb, ub, step);
return self.create<mlir::scf::ForOp>(loc, lb, ub, step, initArgs);
})
.def("create_if_of", [](mlir::OpBuilder &self, mlir::Value &condition) -> mlir::scf::IfOp {
auto loc = self.getUnknownLoc();