bindings for ModuleOp

This commit is contained in:
Yan Da
2022-03-30 13:32:52 +08:00
parent 38e67b4293
commit e95d98a886
4 changed files with 29 additions and 8 deletions

View File

@@ -641,6 +641,23 @@ void init_triton_ir(py::module &&m) {
// // py::class_<ir::undef_value, ir::constant>(m, "undef")
// // .def("get", &ir::undef_value::get, ret::reference);
py::class_<MlirModule>(m, "module")
// .def("set_attr")
.def("dump", [](MlirModule &self) -> void {
unwrap(self).dump();
})
.def("push_back", [](MlirModule &self, MlirOperation &funcOperation) {
if (auto info = unwrap(funcOperation)->getRegisteredInfo()) {
if (mlir::TypeID::get<mlir::FuncOp>() == info->getTypeID()) {
auto funcOp = mlir::FuncOp::getFromOpaquePointer(unwrap(funcOperation));
unwrap(self).push_back(funcOp);
} else
throw std::runtime_error("Only FuncOp can call push_back");
} else
throw std::runtime_error("Unknown error");
})
;
py::class_<MlirType>(m, "type")
.def("is_integer", [](MlirType &self) -> bool {
return mlirTypeIsAInteger(self);
@@ -654,8 +671,8 @@ void init_triton_ir(py::module &&m) {
.def("add_entry_block", [](MlirOperation &self) -> MlirBlock {
if (auto info = unwrap(self)->getRegisteredInfo()) {
if (mlir::TypeID::get<mlir::FuncOp>() == info->getTypeID()) {
auto FunctionOp = mlir::FuncOp::getFromOpaquePointer(unwrap(self));
mlir::Block *entry = FunctionOp.addEntryBlock();
auto funcOp = mlir::FuncOp::getFromOpaquePointer(unwrap(self));
mlir::Block *entry = funcOp.addEntryBlock();
return wrap(entry);
}
throw std::runtime_error("Only FuncOp can call add_entry_block");
@@ -716,6 +733,10 @@ void init_triton_ir(py::module &&m) {
.def(py::init<mlir::MLIRContext *>())
// // getters
// .def_property_readonly("context", &ir::builder::get_context, ret::reference);
.def("create_module", [](mlir::OpBuilder &self) -> MlirModule {
auto loc = self.getUnknownLoc();
return wrap(self.create<mlir::ModuleOp>(loc));
})
// // control flow
// .def("br", &ir::builder::create_br, ret::reference)
// .def("cond_br", &ir::builder::create_cond_br, ret::reference)

View File

@@ -26,7 +26,7 @@ from .tools.disasm import extract
class CodeGenerator(ast.NodeVisitor):
def __init__(self, context, prototype, gscope, attributes, constants, kwargs):
self.builder = _triton.ir.builder(context)
self.module = _triton.ir.module('', self.builder)
self.module = self.builder.create_module()
self.prototype = prototype
self.gscope = gscope
self.lscope = dict()

View File

@@ -128,4 +128,4 @@ def benchmark(size, provider):
# %%
# We can now run the decorated function above. Pass `print_data=True` to see the performance number, `show_plots=True` to plot them, and/or
# `save_path='/path/to/results/' to save them to disk along with raw CSV data
benchmark.run(print_data=True, show_plots=True)
# benchmark.run(print_data=True, show_plots=True)

View File

@@ -6,7 +6,7 @@ ctx.load_triton()
# TODO
builder = ir.builder(ctx)
# module = builder.create_module()
module = builder.create_module()
i1_ty = builder.get_int1_ty()
@@ -49,7 +49,7 @@ c_ptrs = builder.create_broadcast(entry.arg(2), [128])
c_ptrs = builder.create_gep(c_ptrs, offsets)
builder.create_store(c_ptrs, c)
func.dump()
# func.dump()
# module.push_back(func)
# module.dump()
module.push_back(func)
module.dump()