More python bindings

This commit is contained in:
Yan Da
2022-04-01 22:22:39 +08:00
parent 9dafa0e2e3
commit 61413b8a97
2 changed files with 54 additions and 18 deletions

View File

@@ -664,6 +664,7 @@ void init_triton_ir(py::module &&m) {
return self.getArgument(index);
})
.def("dump", &mlir::Block::dump)
.def("move_before", &mlir::Block::moveBefore)
;
// py::class_<mlir::ModuleOp>(m, "module")
@@ -705,15 +706,28 @@ void init_triton_ir(py::module &&m) {
;
// Ops
py::class_<mlir::scf::ForOp>(m, "ForOp")
.def("get_body", &mlir::scf::ForOp::getBody, ret::reference);
py::class_<mlir::scf::IfOp>(m, "IfOp")
py::class_<mlir::OpState>(m, "OpState")
.def("get_result", [](mlir::OpState &self, unsigned idx) -> mlir::Value {
return self->getResult(idx);
})
.def("get_region", [](mlir::OpState &self, unsigned idx) -> mlir::Region& {
return self->getRegion(idx);
}, ret::reference)
;
// scf Ops
py::class_<mlir::scf::ForOp, mlir::OpState>(m, "ForOp")
.def("get_body", [](mlir::scf::ForOp &self) -> mlir::Block* {
return self.getBody();
}, ret::reference);
py::class_<mlir::scf::IfOp, mlir::OpState>(m, "IfOp")
.def("get_then_block", &mlir::scf::IfOp::thenBlock, ret::reference)
.def("get_else_block", &mlir::scf::IfOp::elseBlock, ret::reference)
.def("get_then_yield", &mlir::scf::IfOp::thenYield)
.def("get_else_yield", &mlir::scf::IfOp::elseYield)
;
py::class_<mlir::scf::YieldOp>(m, "YieldOp");
py::class_<mlir::scf::YieldOp, mlir::OpState>(m, "YieldOp");
py::class_<mlir::OpBuilder::InsertPoint>(m, "InsertPoint");
py::class_<mlir::OpBuilder>(m, "builder", py::dynamic_attr())
.def(py::init<mlir::MLIRContext *>())
@@ -740,12 +754,8 @@ void init_triton_ir(py::module &&m) {
.def("get_insertion_block", [](mlir::OpBuilder &self) -> mlir::Block* {
return self.getInsertionBlock();
}, ret::reference)
// .def("get_insert_point", [](ir::builder *self) {
// ir::basic_block *bb = self->get_insert_block();
// ir::basic_block::iterator it = self->get_insert_point();
// ir::instruction *instr = it == bb->end() ? nullptr : *it;
// return std::make_pair(bb, instr);
// }, ret::reference)
.def("get_insertion_point", &mlir::OpBuilder::saveInsertionPoint)
.def("restore_insertion_point", &mlir::OpBuilder::restoreInsertionPoint)
// .def("set_insert_point", [](ir::builder *self, std::pair<ir::basic_block*, ir::instruction*> pt) {
// ir::basic_block *bb = pt.first;
// ir::instruction *instr = pt.second;
@@ -856,9 +866,9 @@ void init_triton_ir(py::module &&m) {
auto loc = self.getUnknownLoc();
return self.create<mlir::scf::IfOp>(loc, condition);
})
.def("create_yield_op", [](mlir::OpBuilder &self) -> mlir::scf::YieldOp {
.def("create_yield_op", [](mlir::OpBuilder &self, std::vector<mlir::Value> &yields) -> mlir::scf::YieldOp {
auto loc = self.getUnknownLoc();
return self.create<mlir::scf::YieldOp>(loc);
return self.create<mlir::scf::YieldOp>(loc, yields);
})
// // .def("create_while")