bindings for ModuleOp
This commit is contained in:
@@ -641,6 +641,23 @@ void init_triton_ir(py::module &&m) {
|
|||||||
// // py::class_<ir::undef_value, ir::constant>(m, "undef")
|
// // py::class_<ir::undef_value, ir::constant>(m, "undef")
|
||||||
// // .def("get", &ir::undef_value::get, ret::reference);
|
// // .def("get", &ir::undef_value::get, ret::reference);
|
||||||
|
|
||||||
|
py::class_<MlirModule>(m, "module")
|
||||||
|
// .def("set_attr")
|
||||||
|
.def("dump", [](MlirModule &self) -> void {
|
||||||
|
unwrap(self).dump();
|
||||||
|
})
|
||||||
|
.def("push_back", [](MlirModule &self, MlirOperation &funcOperation) {
|
||||||
|
if (auto info = unwrap(funcOperation)->getRegisteredInfo()) {
|
||||||
|
if (mlir::TypeID::get<mlir::FuncOp>() == info->getTypeID()) {
|
||||||
|
auto funcOp = mlir::FuncOp::getFromOpaquePointer(unwrap(funcOperation));
|
||||||
|
unwrap(self).push_back(funcOp);
|
||||||
|
} else
|
||||||
|
throw std::runtime_error("Only FuncOp can call push_back");
|
||||||
|
} else
|
||||||
|
throw std::runtime_error("Unknown error");
|
||||||
|
})
|
||||||
|
;
|
||||||
|
|
||||||
py::class_<MlirType>(m, "type")
|
py::class_<MlirType>(m, "type")
|
||||||
.def("is_integer", [](MlirType &self) -> bool {
|
.def("is_integer", [](MlirType &self) -> bool {
|
||||||
return mlirTypeIsAInteger(self);
|
return mlirTypeIsAInteger(self);
|
||||||
@@ -654,8 +671,8 @@ void init_triton_ir(py::module &&m) {
|
|||||||
.def("add_entry_block", [](MlirOperation &self) -> MlirBlock {
|
.def("add_entry_block", [](MlirOperation &self) -> MlirBlock {
|
||||||
if (auto info = unwrap(self)->getRegisteredInfo()) {
|
if (auto info = unwrap(self)->getRegisteredInfo()) {
|
||||||
if (mlir::TypeID::get<mlir::FuncOp>() == info->getTypeID()) {
|
if (mlir::TypeID::get<mlir::FuncOp>() == info->getTypeID()) {
|
||||||
auto FunctionOp = mlir::FuncOp::getFromOpaquePointer(unwrap(self));
|
auto funcOp = mlir::FuncOp::getFromOpaquePointer(unwrap(self));
|
||||||
mlir::Block *entry = FunctionOp.addEntryBlock();
|
mlir::Block *entry = funcOp.addEntryBlock();
|
||||||
return wrap(entry);
|
return wrap(entry);
|
||||||
}
|
}
|
||||||
throw std::runtime_error("Only FuncOp can call add_entry_block");
|
throw std::runtime_error("Only FuncOp can call add_entry_block");
|
||||||
@@ -716,6 +733,10 @@ void init_triton_ir(py::module &&m) {
|
|||||||
.def(py::init<mlir::MLIRContext *>())
|
.def(py::init<mlir::MLIRContext *>())
|
||||||
// // getters
|
// // getters
|
||||||
// .def_property_readonly("context", &ir::builder::get_context, ret::reference);
|
// .def_property_readonly("context", &ir::builder::get_context, ret::reference);
|
||||||
|
.def("create_module", [](mlir::OpBuilder &self) -> MlirModule {
|
||||||
|
auto loc = self.getUnknownLoc();
|
||||||
|
return wrap(self.create<mlir::ModuleOp>(loc));
|
||||||
|
})
|
||||||
// // control flow
|
// // control flow
|
||||||
// .def("br", &ir::builder::create_br, ret::reference)
|
// .def("br", &ir::builder::create_br, ret::reference)
|
||||||
// .def("cond_br", &ir::builder::create_cond_br, ret::reference)
|
// .def("cond_br", &ir::builder::create_cond_br, ret::reference)
|
||||||
|
@@ -26,7 +26,7 @@ from .tools.disasm import extract
|
|||||||
class CodeGenerator(ast.NodeVisitor):
|
class CodeGenerator(ast.NodeVisitor):
|
||||||
def __init__(self, context, prototype, gscope, attributes, constants, kwargs):
|
def __init__(self, context, prototype, gscope, attributes, constants, kwargs):
|
||||||
self.builder = _triton.ir.builder(context)
|
self.builder = _triton.ir.builder(context)
|
||||||
self.module = _triton.ir.module('', self.builder)
|
self.module = self.builder.create_module()
|
||||||
self.prototype = prototype
|
self.prototype = prototype
|
||||||
self.gscope = gscope
|
self.gscope = gscope
|
||||||
self.lscope = dict()
|
self.lscope = dict()
|
||||||
|
@@ -128,4 +128,4 @@ def benchmark(size, provider):
|
|||||||
# %%
|
# %%
|
||||||
# We can now run the decorated function above. Pass `print_data=True` to see the performance number, `show_plots=True` to plot them, and/or
|
# We can now run the decorated function above. Pass `print_data=True` to see the performance number, `show_plots=True` to plot them, and/or
|
||||||
# `save_path='/path/to/results/' to save them to disk along with raw CSV data
|
# `save_path='/path/to/results/' to save them to disk along with raw CSV data
|
||||||
benchmark.run(print_data=True, show_plots=True)
|
# benchmark.run(print_data=True, show_plots=True)
|
||||||
|
@@ -6,7 +6,7 @@ ctx.load_triton()
|
|||||||
# TODO
|
# TODO
|
||||||
builder = ir.builder(ctx)
|
builder = ir.builder(ctx)
|
||||||
|
|
||||||
# module = builder.create_module()
|
module = builder.create_module()
|
||||||
|
|
||||||
|
|
||||||
i1_ty = builder.get_int1_ty()
|
i1_ty = builder.get_int1_ty()
|
||||||
@@ -49,7 +49,7 @@ c_ptrs = builder.create_broadcast(entry.arg(2), [128])
|
|||||||
c_ptrs = builder.create_gep(c_ptrs, offsets)
|
c_ptrs = builder.create_gep(c_ptrs, offsets)
|
||||||
builder.create_store(c_ptrs, c)
|
builder.create_store(c_ptrs, c)
|
||||||
|
|
||||||
func.dump()
|
# func.dump()
|
||||||
|
|
||||||
# module.push_back(func)
|
module.push_back(func)
|
||||||
# module.dump()
|
module.dump()
|
||||||
|
Reference in New Issue
Block a user