bindings for ModuleOp

This commit is contained in:
Yan Da
2022-03-30 13:32:52 +08:00
parent 38e67b4293
commit e95d98a886
4 changed files with 29 additions and 8 deletions

View File

@@ -641,6 +641,23 @@ void init_triton_ir(py::module &&m) {
// // py::class_<ir::undef_value, ir::constant>(m, "undef") // // py::class_<ir::undef_value, ir::constant>(m, "undef")
// // .def("get", &ir::undef_value::get, ret::reference); // // .def("get", &ir::undef_value::get, ret::reference);
py::class_<MlirModule>(m, "module")
// .def("set_attr")
.def("dump", [](MlirModule &self) -> void {
unwrap(self).dump();
})
.def("push_back", [](MlirModule &self, MlirOperation &funcOperation) {
if (auto info = unwrap(funcOperation)->getRegisteredInfo()) {
if (mlir::TypeID::get<mlir::FuncOp>() == info->getTypeID()) {
auto funcOp = mlir::FuncOp::getFromOpaquePointer(unwrap(funcOperation));
unwrap(self).push_back(funcOp);
} else
throw std::runtime_error("Only FuncOp can call push_back");
} else
throw std::runtime_error("Unknown error");
})
;
py::class_<MlirType>(m, "type") py::class_<MlirType>(m, "type")
.def("is_integer", [](MlirType &self) -> bool { .def("is_integer", [](MlirType &self) -> bool {
return mlirTypeIsAInteger(self); return mlirTypeIsAInteger(self);
@@ -654,8 +671,8 @@ void init_triton_ir(py::module &&m) {
.def("add_entry_block", [](MlirOperation &self) -> MlirBlock { .def("add_entry_block", [](MlirOperation &self) -> MlirBlock {
if (auto info = unwrap(self)->getRegisteredInfo()) { if (auto info = unwrap(self)->getRegisteredInfo()) {
if (mlir::TypeID::get<mlir::FuncOp>() == info->getTypeID()) { if (mlir::TypeID::get<mlir::FuncOp>() == info->getTypeID()) {
auto FunctionOp = mlir::FuncOp::getFromOpaquePointer(unwrap(self)); auto funcOp = mlir::FuncOp::getFromOpaquePointer(unwrap(self));
mlir::Block *entry = FunctionOp.addEntryBlock(); mlir::Block *entry = funcOp.addEntryBlock();
return wrap(entry); return wrap(entry);
} }
throw std::runtime_error("Only FuncOp can call add_entry_block"); throw std::runtime_error("Only FuncOp can call add_entry_block");
@@ -716,6 +733,10 @@ void init_triton_ir(py::module &&m) {
.def(py::init<mlir::MLIRContext *>()) .def(py::init<mlir::MLIRContext *>())
// // getters // // getters
// .def_property_readonly("context", &ir::builder::get_context, ret::reference); // .def_property_readonly("context", &ir::builder::get_context, ret::reference);
.def("create_module", [](mlir::OpBuilder &self) -> MlirModule {
auto loc = self.getUnknownLoc();
return wrap(self.create<mlir::ModuleOp>(loc));
})
// // control flow // // control flow
// .def("br", &ir::builder::create_br, ret::reference) // .def("br", &ir::builder::create_br, ret::reference)
// .def("cond_br", &ir::builder::create_cond_br, ret::reference) // .def("cond_br", &ir::builder::create_cond_br, ret::reference)

View File

@@ -26,7 +26,7 @@ from .tools.disasm import extract
class CodeGenerator(ast.NodeVisitor): class CodeGenerator(ast.NodeVisitor):
def __init__(self, context, prototype, gscope, attributes, constants, kwargs): def __init__(self, context, prototype, gscope, attributes, constants, kwargs):
self.builder = _triton.ir.builder(context) self.builder = _triton.ir.builder(context)
self.module = _triton.ir.module('', self.builder) self.module = self.builder.create_module()
self.prototype = prototype self.prototype = prototype
self.gscope = gscope self.gscope = gscope
self.lscope = dict() self.lscope = dict()

View File

@@ -128,4 +128,4 @@ def benchmark(size, provider):
# %% # %%
# We can now run the decorated function above. Pass `print_data=True` to see the performance number, `show_plots=True` to plot them, and/or # We can now run the decorated function above. Pass `print_data=True` to see the performance number, `show_plots=True` to plot them, and/or
# `save_path='/path/to/results/' to save them to disk along with raw CSV data # `save_path='/path/to/results/' to save them to disk along with raw CSV data
benchmark.run(print_data=True, show_plots=True) # benchmark.run(print_data=True, show_plots=True)

View File

@@ -6,7 +6,7 @@ ctx.load_triton()
# TODO # TODO
builder = ir.builder(ctx) builder = ir.builder(ctx)
# module = builder.create_module() module = builder.create_module()
i1_ty = builder.get_int1_ty() i1_ty = builder.get_int1_ty()
@@ -49,7 +49,7 @@ c_ptrs = builder.create_broadcast(entry.arg(2), [128])
c_ptrs = builder.create_gep(c_ptrs, offsets) c_ptrs = builder.create_gep(c_ptrs, offsets)
builder.create_store(c_ptrs, c) builder.create_store(c_ptrs, c)
func.dump() # func.dump()
# module.push_back(func) module.push_back(func)
# module.dump() module.dump()