Device function & PassManager
This commit is contained in:
@@ -8,6 +8,10 @@
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Verifier.h"
|
||||
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
|
||||
|
||||
#include "triton/ir/Dialect.h"
|
||||
#include "triton/ir/Types.h"
|
||||
|
||||
@@ -717,6 +721,9 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("set_attr", [](mlir::OpState &self, std::string &name, mlir::Attribute &attr) -> void {
|
||||
self->setAttr(name, attr);
|
||||
})
|
||||
.def("get_num_results", [](mlir::OpState &self) -> unsigned {
|
||||
return self->getNumResults();
|
||||
})
|
||||
.def("get_result", [](mlir::OpState &self, unsigned idx) -> mlir::Value {
|
||||
return self->getResult(idx);
|
||||
})
|
||||
@@ -755,12 +762,18 @@ void init_triton_ir(py::module &&m) {
|
||||
py::class_<mlir::scf::ConditionOp, mlir::OpState>(m, "CondtionOp");
|
||||
|
||||
py::class_<mlir::ModuleOp, mlir::OpState>(m, "module")
|
||||
.def("dump", [](mlir::ModuleOp &self) -> void {
|
||||
self.dump();
|
||||
})
|
||||
.def("dump", &mlir::ModuleOp::dump)
|
||||
.def("push_back", [](mlir::ModuleOp &self, mlir::FuncOp &funcOp) -> void {
|
||||
self.push_back(funcOp);
|
||||
})
|
||||
.def("has_function", [](mlir::ModuleOp &self, std::string &funcName) -> bool {
|
||||
if (self.lookupSymbol(funcName))
|
||||
return true;
|
||||
return false;
|
||||
})
|
||||
.def("get_function", [](mlir::ModuleOp &self, std::string &funcName) -> mlir::FuncOp {
|
||||
return self.lookupSymbol<mlir::FuncOp>(funcName);
|
||||
})
|
||||
;
|
||||
|
||||
py::class_<mlir::FuncOp, mlir::OpState>(m, "function")
|
||||
@@ -772,6 +785,7 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("add_entry_block", [](mlir::FuncOp &self) -> mlir::Block* {
|
||||
return self.addEntryBlock();
|
||||
}, ret::reference)
|
||||
.def("reset_type", &mlir::FuncOp::setType)
|
||||
;
|
||||
|
||||
py::class_<mlir::OpBuilder::InsertPoint>(m, "InsertPoint");
|
||||
@@ -784,11 +798,14 @@ void init_triton_ir(py::module &&m) {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::ModuleOp>(loc);
|
||||
})
|
||||
// control flow
|
||||
.def("ret_void", [](mlir::OpBuilder &self) {
|
||||
.def("ret", [](mlir::OpBuilder &self, std::vector<mlir::Value> &vals) -> void {
|
||||
auto loc = self.getUnknownLoc();
|
||||
self.create<mlir::ReturnOp>(loc);
|
||||
}, ret::reference)
|
||||
self.create<mlir::ReturnOp>(loc, vals);
|
||||
})
|
||||
.def("call", [](mlir::OpBuilder &self, mlir::FuncOp &func, std::vector<mlir::Value> &args) -> mlir::OpState {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::CallOp>(loc, func, args);
|
||||
})
|
||||
// insertion block/point
|
||||
.def("set_insertion_point_to_start", [](mlir::OpBuilder &self, mlir::Block &block) -> void {
|
||||
self.setInsertionPointToStart(&block);
|
||||
@@ -900,6 +917,16 @@ void init_triton_ir(py::module &&m) {
|
||||
}
|
||||
throw std::runtime_error("invalid function type");
|
||||
})
|
||||
.def("get_or_insert_function", [](mlir::OpBuilder &self, mlir::ModuleOp &module,
|
||||
std::string &funcName, mlir::Type &funcType) -> mlir::FuncOp {
|
||||
if (mlir::Operation *funcOperation = module.lookupSymbol(funcName))
|
||||
return llvm::dyn_cast<mlir::FuncOp>(funcOperation);
|
||||
auto loc = self.getUnknownLoc();
|
||||
if (auto funcTy = funcType.dyn_cast<mlir::FunctionType>()) {
|
||||
return self.create<mlir::FuncOp>(loc, funcName, funcTy);
|
||||
}
|
||||
throw std::runtime_error("invalid function type");
|
||||
})
|
||||
.def("create_block", [](mlir::OpBuilder &self) -> mlir::Block* {
|
||||
mlir::Region *parent = self.getBlock()->getParent();
|
||||
return self.createBlock(parent);
|
||||
@@ -1293,6 +1320,16 @@ void init_triton_ir(py::module &&m) {
|
||||
// .def("create_umulhi", &ir::builder::create_umulhi, ret::reference)
|
||||
// .def("create_barrier", &ir::builder::create_barrier, ret::reference);
|
||||
;
|
||||
|
||||
py::class_<mlir::PassManager>(m, "pass_manager")
|
||||
.def(py::init<mlir::MLIRContext *>())
|
||||
.def("run", [](mlir::PassManager &self, mlir::ModuleOp &mod) {
|
||||
self.run(mod.getOperation());
|
||||
})
|
||||
.def("add_inliner_pass", [](mlir::PassManager &self) {
|
||||
self.addPass(mlir::createInlinerPass());
|
||||
})
|
||||
;
|
||||
}
|
||||
|
||||
void init_triton(py::module &m) {
|
||||
|
Reference in New Issue
Block a user