Device function & PassManager

This commit is contained in:
Yan Da
2022-04-15 14:41:57 +08:00
parent 44d75cf9bb
commit 1c52bd587d
5 changed files with 464 additions and 92 deletions

View File

@@ -8,6 +8,10 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "triton/ir/Dialect.h"
#include "triton/ir/Types.h"
@@ -717,6 +721,9 @@ void init_triton_ir(py::module &&m) {
.def("set_attr", [](mlir::OpState &self, std::string &name, mlir::Attribute &attr) -> void {
self->setAttr(name, attr);
})
.def("get_num_results", [](mlir::OpState &self) -> unsigned {
return self->getNumResults();
})
.def("get_result", [](mlir::OpState &self, unsigned idx) -> mlir::Value {
return self->getResult(idx);
})
@@ -755,12 +762,18 @@ void init_triton_ir(py::module &&m) {
py::class_<mlir::scf::ConditionOp, mlir::OpState>(m, "CondtionOp");
py::class_<mlir::ModuleOp, mlir::OpState>(m, "module")
.def("dump", [](mlir::ModuleOp &self) -> void {
self.dump();
})
.def("dump", &mlir::ModuleOp::dump)
.def("push_back", [](mlir::ModuleOp &self, mlir::FuncOp &funcOp) -> void {
self.push_back(funcOp);
})
.def("has_function", [](mlir::ModuleOp &self, std::string &funcName) -> bool {
if (self.lookupSymbol(funcName))
return true;
return false;
})
.def("get_function", [](mlir::ModuleOp &self, std::string &funcName) -> mlir::FuncOp {
return self.lookupSymbol<mlir::FuncOp>(funcName);
})
;
py::class_<mlir::FuncOp, mlir::OpState>(m, "function")
@@ -772,6 +785,7 @@ void init_triton_ir(py::module &&m) {
.def("add_entry_block", [](mlir::FuncOp &self) -> mlir::Block* {
return self.addEntryBlock();
}, ret::reference)
.def("reset_type", &mlir::FuncOp::setType)
;
py::class_<mlir::OpBuilder::InsertPoint>(m, "InsertPoint");
@@ -784,11 +798,14 @@ void init_triton_ir(py::module &&m) {
auto loc = self.getUnknownLoc();
return self.create<mlir::ModuleOp>(loc);
})
// control flow
.def("ret_void", [](mlir::OpBuilder &self) {
.def("ret", [](mlir::OpBuilder &self, std::vector<mlir::Value> &vals) -> void {
auto loc = self.getUnknownLoc();
self.create<mlir::ReturnOp>(loc);
}, ret::reference)
self.create<mlir::ReturnOp>(loc, vals);
})
.def("call", [](mlir::OpBuilder &self, mlir::FuncOp &func, std::vector<mlir::Value> &args) -> mlir::OpState {
auto loc = self.getUnknownLoc();
return self.create<mlir::CallOp>(loc, func, args);
})
// insertion block/point
.def("set_insertion_point_to_start", [](mlir::OpBuilder &self, mlir::Block &block) -> void {
self.setInsertionPointToStart(&block);
@@ -900,6 +917,16 @@ void init_triton_ir(py::module &&m) {
}
throw std::runtime_error("invalid function type");
})
.def("get_or_insert_function", [](mlir::OpBuilder &self, mlir::ModuleOp &module,
std::string &funcName, mlir::Type &funcType) -> mlir::FuncOp {
if (mlir::Operation *funcOperation = module.lookupSymbol(funcName))
return llvm::dyn_cast<mlir::FuncOp>(funcOperation);
auto loc = self.getUnknownLoc();
if (auto funcTy = funcType.dyn_cast<mlir::FunctionType>()) {
return self.create<mlir::FuncOp>(loc, funcName, funcTy);
}
throw std::runtime_error("invalid function type");
})
.def("create_block", [](mlir::OpBuilder &self) -> mlir::Block* {
mlir::Region *parent = self.getBlock()->getParent();
return self.createBlock(parent);
@@ -1293,6 +1320,16 @@ void init_triton_ir(py::module &&m) {
// .def("create_umulhi", &ir::builder::create_umulhi, ret::reference)
// .def("create_barrier", &ir::builder::create_barrier, ret::reference);
;
py::class_<mlir::PassManager>(m, "pass_manager")
.def(py::init<mlir::MLIRContext *>())
.def("run", [](mlir::PassManager &self, mlir::ModuleOp &mod) {
self.run(mod.getOperation());
})
.def("add_inliner_pass", [](mlir::PassManager &self) {
self.addPass(mlir::createInlinerPass());
})
;
}
void init_triton(py::module &m) {