diff --git a/python/src/triton.cc b/python/src/triton.cc index 2bd987d55..c486ba299 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -641,6 +641,23 @@ void init_triton_ir(py::module &&m) { // // py::class_(m, "undef") // // .def("get", &ir::undef_value::get, ret::reference); + py::class_(m, "module") + // .def("set_attr") + .def("dump", [](MlirModule &self) -> void { + unwrap(self).dump(); + }) + .def("push_back", [](MlirModule &self, MlirOperation &funcOperation) { + if (auto info = unwrap(funcOperation)->getRegisteredInfo()) { + if (mlir::TypeID::get() == info->getTypeID()) { + auto funcOp = mlir::FuncOp::getFromOpaquePointer(unwrap(funcOperation)); + unwrap(self).push_back(funcOp); + } else + throw std::runtime_error("Only FuncOp can call push_back"); + } else + throw std::runtime_error("Unknown error"); + }) + ; + py::class_(m, "type") .def("is_integer", [](MlirType &self) -> bool { return mlirTypeIsAInteger(self); @@ -654,8 +671,8 @@ void init_triton_ir(py::module &&m) { .def("add_entry_block", [](MlirOperation &self) -> MlirBlock { if (auto info = unwrap(self)->getRegisteredInfo()) { if (mlir::TypeID::get() == info->getTypeID()) { - auto FunctionOp = mlir::FuncOp::getFromOpaquePointer(unwrap(self)); - mlir::Block *entry = FunctionOp.addEntryBlock(); + auto funcOp = mlir::FuncOp::getFromOpaquePointer(unwrap(self)); + mlir::Block *entry = funcOp.addEntryBlock(); return wrap(entry); } throw std::runtime_error("Only FuncOp can call add_entry_block"); @@ -716,6 +733,10 @@ void init_triton_ir(py::module &&m) { .def(py::init()) // // getters // .def_property_readonly("context", &ir::builder::get_context, ret::reference); + .def("create_module", [](mlir::OpBuilder &self) -> MlirModule { + auto loc = self.getUnknownLoc(); + return wrap(self.create(loc)); + }) // // control flow // .def("br", &ir::builder::create_br, ret::reference) // .def("cond_br", &ir::builder::create_cond_br, ret::reference) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 170b71a09..e61eef451 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -26,7 +26,7 @@ from .tools.disasm import extract class CodeGenerator(ast.NodeVisitor): def __init__(self, context, prototype, gscope, attributes, constants, kwargs): self.builder = _triton.ir.builder(context) - self.module = _triton.ir.module('', self.builder) + self.module = self.builder.create_module() self.prototype = prototype self.gscope = gscope self.lscope = dict() diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py index d684106f1..2845eb7e5 100644 --- a/python/tutorials/01-vector-add.py +++ b/python/tutorials/01-vector-add.py @@ -128,4 +128,4 @@ def benchmark(size, provider): # %% # We can now run the decorated function above. Pass `print_data=True` to see the performance number, `show_plots=True` to plot them, and/or # `save_path='/path/to/results/' to save them to disk along with raw CSV data -benchmark.run(print_data=True, show_plots=True) +# benchmark.run(print_data=True, show_plots=True) diff --git a/rewrite-test/test_ir.py b/rewrite-test/test_ir.py index f14ddfb7f..2008f9c4e 100644 --- a/rewrite-test/test_ir.py +++ b/rewrite-test/test_ir.py @@ -6,7 +6,7 @@ ctx.load_triton() # TODO builder = ir.builder(ctx) -# module = builder.create_module() +module = builder.create_module() i1_ty = builder.get_int1_ty() @@ -49,7 +49,7 @@ c_ptrs = builder.create_broadcast(entry.arg(2), [128]) c_ptrs = builder.create_gep(c_ptrs, offsets) builder.create_store(c_ptrs, c) -func.dump() +# func.dump() -# module.push_back(func) -# module.dump() +module.push_back(func) +module.dump()