Use mlir::Block to replace MlirBlock

This commit is contained in:
Yan Da
2022-03-30 16:31:03 +08:00
parent e95d98a886
commit e381dc72c5
4 changed files with 76 additions and 57 deletions

View File

@@ -668,12 +668,12 @@ void init_triton_ir(py::module &&m) {
;
py::class_<MlirOperation>(m, "operation")
.def("add_entry_block", [](MlirOperation &self) -> MlirBlock {
.def("add_entry_block", [](MlirOperation &self) -> mlir::Block {
if (auto info = unwrap(self)->getRegisteredInfo()) {
if (mlir::TypeID::get<mlir::FuncOp>() == info->getTypeID()) {
auto funcOp = mlir::FuncOp::getFromOpaquePointer(unwrap(self));
mlir::Block *entry = funcOp.addEntryBlock();
return wrap(entry);
return *entry;
}
throw std::runtime_error("Only FuncOp can call add_entry_block");
} else
@@ -684,12 +684,14 @@ void init_triton_ir(py::module &&m) {
})
;
py::class_<MlirValue>(m, "value")
py::class_<mlir::Value>(m, "value")
;
py::class_<mlir::BlockArgument, mlir::Value>(m, "block_arguement")
;
py::class_<MlirBlock>(m, "block")
.def("arg", [](MlirBlock &self, int index) -> MlirValue {
return wrap(unwrap(self)->getArgument(index));
py::class_<mlir::Block>(m, "block")
.def("arg", [](mlir::Block &self, int index) -> mlir::BlockArgument {
return self.getArgument(index);
})
;
@@ -741,12 +743,16 @@ void init_triton_ir(py::module &&m) {
// .def("br", &ir::builder::create_br, ret::reference)
// .def("cond_br", &ir::builder::create_cond_br, ret::reference)
// .def("ret_void", &ir::builder::create_ret_void, ret::reference)
// // insertion block/point, insert points are represented as (*bb, *instr)
.def("set_insertion_point_to_start", [](mlir::OpBuilder &self, MlirBlock &block) -> void{
self.setInsertionPointToStart(unwrap(block));
// insertion block/point
.def("set_insertion_point_to_start", [](mlir::OpBuilder &self, mlir::Block &block) -> void {
self.setInsertionPointToStart(&block);
})
// .def("get_insert_block", &ir::builder::get_insert_block, ret::reference)
// .def("set_insert_block", (void (ir::builder::*)(ir::basic_block *)) & ir::builder::set_insert_point)
.def("set_insertion_point_to_end", [](mlir::OpBuilder &self, mlir::Block &block) {
self.setInsertionPointToEnd(&block);
})
.def("get_insertion_block", [](mlir::OpBuilder &self) -> mlir::Block & {
return *self.getInsertionBlock();
}, ret::reference)
// .def("get_insert_point", [](ir::builder *self) {
// ir::basic_block *bb = self->get_insert_block();
// ir::basic_block::iterator it = self->get_insert_point();
@@ -768,11 +774,11 @@ void init_triton_ir(py::module &&m) {
// Use arith.ConstantOp to create constants
// // Constants
// .def("get_int1", &ir::builder::get_int1, ret::reference)
.def("get_int32", [](mlir::OpBuilder &self, int64_t v) -> MlirValue {
.def("get_int32", [](mlir::OpBuilder &self, int64_t v) -> mlir::Value {
auto loc = self.getUnknownLoc();
return wrap(mlir::Value(self.create<mlir::arith::ConstantIntOp>(
return mlir::Value(self.create<mlir::arith::ConstantIntOp>(
loc, v, self.getI32Type()
)));
));
})
// .def("get_uint32", &ir::builder::get_int32, ret::reference)
// .def("get_int64", [](ir::builder *self, int64_t v) { return self->get_int64((uint64_t)v); }, ret::reference)
@@ -818,9 +824,15 @@ void init_triton_ir(py::module &&m) {
.def("get_double_ty", [](mlir::OpBuilder &self) -> MlirType {
return wrap(self.getF64Type());
})
.def("get_ptr_ty", [](mlir::OpBuilder &self, MlirType &type) -> MlirType {
.def("get_ptr_ty", [](mlir::OpBuilder &self, MlirType &type, int addrSpace) -> MlirType {
return wrap(
mlir::triton::PointerType::get(unwrap(type))
mlir::triton::PointerType::get(unwrap(type), addrSpace)
);
})
.def("get_block_ty", [](mlir::OpBuilder &self, MlirType &elementType,
std::vector<int64_t> &shape) -> MlirType {
return wrap(
mlir::RankedTensorType::get(shape, unwrap(elementType))
);
})
.def("get_function_ty", [](mlir::OpBuilder &self,