More on scf Ops

This commit is contained in:
Yan Da
2022-03-31 21:42:48 +08:00
parent 2041b67fbf
commit 4ad432f1fc
2 changed files with 119 additions and 77 deletions

View File

@@ -711,6 +711,17 @@ void init_triton_ir(py::module &&m) {
.def("dump", [](mlir::FuncOp &self) { self.dump(); })
;
// Ops
py::class_<mlir::scf::ForOp>(m, "ForOp")
.def("get_body", &mlir::scf::ForOp::getBody, ret::reference);
py::class_<mlir::scf::IfOp>(m, "IfOp")
.def("get_then_block", &mlir::scf::IfOp::thenBlock, ret::reference)
.def("get_else_block", &mlir::scf::IfOp::elseBlock, ret::reference)
.def("get_then_yield", &mlir::scf::IfOp::thenYield)
.def("get_else_yield", &mlir::scf::IfOp::elseYield)
;
py::class_<mlir::scf::YieldOp>(m, "YieldOp");
py::class_<mlir::OpBuilder>(m, "builder", py::dynamic_attr())
.def(py::init<mlir::MLIRContext *>())
// // getters
@@ -839,21 +850,20 @@ void init_triton_ir(py::module &&m) {
}
throw std::runtime_error("invalid function type");
})
// // Structured control flow
// .def("create_for", [](mlir::OpBuilder &self, mlir::Value &lb, mlir::Value &ub,
// mlir::Value &step, std::vector<mlir::Value> &initArgs) -> MlirOperation {
// auto loc = self.getUnknownLoc();
// return wrap(self.create<mlir::scf::ForOp>(
// loc, unwrap(lb), unwrap(ub), unwrap(step)).getOperation());
// })
// .def("create_if", [](mlir::OpBuilder &self, mlir::Value &condition) -> MlirOperation {
// auto loc = self.getUnknownLoc();
// return wrap(self.create<mlir::scf::IfOp>(loc, unwrap(condition)).getOperation());
// })
// .def("create_yield", [](mlir::OpBuilder &self) -> MlirOperation {
// auto loc = self.getUnknownLoc();
// return wrap(self.create<mlir::scf::YieldOp>(loc).getOperation());
// })
// Structured control flow
.def("create_for_op", [](mlir::OpBuilder &self, mlir::Value &lb, mlir::Value &ub,
mlir::Value &step, std::vector<mlir::Value> &initArgs) -> mlir::scf::ForOp {
auto loc = self.getUnknownLoc();
return self.create<mlir::scf::ForOp>(loc, lb, ub, step);
})
.def("create_if_of", [](mlir::OpBuilder &self, mlir::Value &condition) -> mlir::scf::IfOp {
auto loc = self.getUnknownLoc();
return self.create<mlir::scf::IfOp>(loc, condition);
})
.def("create_yield_op", [](mlir::OpBuilder &self) -> mlir::scf::YieldOp {
auto loc = self.getUnknownLoc();
return self.create<mlir::scf::YieldOp>(loc);
})
// // .def("create_while")
// miscellious