More progress on WhileOp codegen

This commit is contained in:
Yan Da
2022-04-05 15:55:48 +08:00
parent 76d9249724
commit c7ad928e60
5 changed files with 145 additions and 45 deletions

View File

@@ -638,17 +638,6 @@ void init_triton_ir(py::module &&m) {
// // py::class_<ir::undef_value, ir::constant>(m, "undef")
// // .def("get", &ir::undef_value::get, ret::reference);
py::class_<mlir::ModuleOp>(m, "module")
// .def("set_attr")
.def("dump", [](mlir::ModuleOp &self) -> void {
self.dump();
})
.def("push_back", [](mlir::ModuleOp &self, mlir::FuncOp &funcOp) -> void {
self.push_back(funcOp);
})
.def("get_context", &mlir::ModuleOp::getContext)
;
py::class_<mlir::Type>(m, "type")
.def("is_integer", &mlir::Type::isInteger)
.def("is_fp16", &mlir::Type::isF16)
@@ -753,13 +742,27 @@ void init_triton_ir(py::module &&m) {
.def("get_else_yield", &mlir::scf::IfOp::elseYield)
;
py::class_<mlir::scf::YieldOp, mlir::OpState>(m, "YieldOp");
py::class_<mlir::scf::WhileOp, mlir::OpState>(m, "WhileOp")
.def("get_before", &mlir::scf::WhileOp::getBefore, ret::reference)
.def("get_after", &mlir::scf::WhileOp::getAfter, ret::reference);
py::class_<mlir::scf::ConditionOp, mlir::OpState>(m, "CondtionOp");
py::class_<mlir::ModuleOp, mlir::OpState>(m, "module")
// .def("set_attr")
.def("dump", [](mlir::ModuleOp &self) -> void {
self.dump();
})
.def("push_back", [](mlir::ModuleOp &self, mlir::FuncOp &funcOp) -> void {
self.push_back(funcOp);
})
;
py::class_<mlir::OpBuilder::InsertPoint>(m, "InsertPoint");
py::class_<mlir::OpBuilder>(m, "builder", py::dynamic_attr())
.def(py::init<mlir::MLIRContext *>())
// // getters
// .def_property_readonly("context", &ir::builder::get_context, ret::reference);
.def_property_readonly("context", &mlir::OpBuilder::getContext, ret::reference)
.def("create_module", [](mlir::OpBuilder &self) -> mlir::ModuleOp {
auto loc = self.getUnknownLoc();
return self.create<mlir::ModuleOp>(loc);
@@ -883,6 +886,9 @@ void init_triton_ir(py::module &&m) {
mlir::Region *parent = self.getBlock()->getParent();
return self.createBlock(parent);
}, ret::reference)
.def("create_block_with_parent", [](mlir::OpBuilder &self, mlir::Region &parent) -> mlir::Block* {
return self.createBlock(&parent);
})
.def("new_block", [](mlir::OpBuilder &self) -> mlir::Block* {
return new mlir::Block();
}, ret::reference)
@@ -900,7 +906,16 @@ void init_triton_ir(py::module &&m) {
auto loc = self.getUnknownLoc();
return self.create<mlir::scf::YieldOp>(loc, yields);
})
// // .def("create_while")
.def("create_while_op", [](mlir::OpBuilder &self, std::vector<mlir::Type> &retTypes,
std::vector<mlir::Value> &initArgs) -> mlir::scf::WhileOp {
auto loc = self.getUnknownLoc();
return self.create<mlir::scf::WhileOp>(loc, retTypes, initArgs);
})
.def("create_condtion_op", [](mlir::OpBuilder &self, mlir::Value &cond,
std::vector<mlir::Value> &args) -> mlir::scf::ConditionOp {
auto loc = self.getUnknownLoc();
return self.create<mlir::scf::ConditionOp>(loc, cond, args);
})
// miscellious
.def("create_make_range", [](mlir::OpBuilder &self, int start, int end) -> mlir::Value {