Use mlir::Block to replace MlirBlock
This commit is contained in:
@@ -668,12 +668,12 @@ void init_triton_ir(py::module &&m) {
|
||||
;
|
||||
|
||||
py::class_<MlirOperation>(m, "operation")
|
||||
.def("add_entry_block", [](MlirOperation &self) -> MlirBlock {
|
||||
.def("add_entry_block", [](MlirOperation &self) -> mlir::Block {
|
||||
if (auto info = unwrap(self)->getRegisteredInfo()) {
|
||||
if (mlir::TypeID::get<mlir::FuncOp>() == info->getTypeID()) {
|
||||
auto funcOp = mlir::FuncOp::getFromOpaquePointer(unwrap(self));
|
||||
mlir::Block *entry = funcOp.addEntryBlock();
|
||||
return wrap(entry);
|
||||
return *entry;
|
||||
}
|
||||
throw std::runtime_error("Only FuncOp can call add_entry_block");
|
||||
} else
|
||||
@@ -684,12 +684,14 @@ void init_triton_ir(py::module &&m) {
|
||||
})
|
||||
;
|
||||
|
||||
py::class_<MlirValue>(m, "value")
|
||||
py::class_<mlir::Value>(m, "value")
|
||||
;
|
||||
py::class_<mlir::BlockArgument, mlir::Value>(m, "block_arguement")
|
||||
;
|
||||
|
||||
py::class_<MlirBlock>(m, "block")
|
||||
.def("arg", [](MlirBlock &self, int index) -> MlirValue {
|
||||
return wrap(unwrap(self)->getArgument(index));
|
||||
py::class_<mlir::Block>(m, "block")
|
||||
.def("arg", [](mlir::Block &self, int index) -> mlir::BlockArgument {
|
||||
return self.getArgument(index);
|
||||
})
|
||||
;
|
||||
|
||||
@@ -741,12 +743,16 @@ void init_triton_ir(py::module &&m) {
|
||||
// .def("br", &ir::builder::create_br, ret::reference)
|
||||
// .def("cond_br", &ir::builder::create_cond_br, ret::reference)
|
||||
// .def("ret_void", &ir::builder::create_ret_void, ret::reference)
|
||||
// // insertion block/point, insert points are represented as (*bb, *instr)
|
||||
.def("set_insertion_point_to_start", [](mlir::OpBuilder &self, MlirBlock &block) -> void{
|
||||
self.setInsertionPointToStart(unwrap(block));
|
||||
// insertion block/point
|
||||
.def("set_insertion_point_to_start", [](mlir::OpBuilder &self, mlir::Block &block) -> void {
|
||||
self.setInsertionPointToStart(&block);
|
||||
})
|
||||
// .def("get_insert_block", &ir::builder::get_insert_block, ret::reference)
|
||||
// .def("set_insert_block", (void (ir::builder::*)(ir::basic_block *)) & ir::builder::set_insert_point)
|
||||
.def("set_insertion_point_to_end", [](mlir::OpBuilder &self, mlir::Block &block) {
|
||||
self.setInsertionPointToEnd(&block);
|
||||
})
|
||||
.def("get_insertion_block", [](mlir::OpBuilder &self) -> mlir::Block & {
|
||||
return *self.getInsertionBlock();
|
||||
}, ret::reference)
|
||||
// .def("get_insert_point", [](ir::builder *self) {
|
||||
// ir::basic_block *bb = self->get_insert_block();
|
||||
// ir::basic_block::iterator it = self->get_insert_point();
|
||||
@@ -768,11 +774,11 @@ void init_triton_ir(py::module &&m) {
|
||||
// Use arith.ConstantOp to create constants
|
||||
// // Constants
|
||||
// .def("get_int1", &ir::builder::get_int1, ret::reference)
|
||||
.def("get_int32", [](mlir::OpBuilder &self, int64_t v) -> MlirValue {
|
||||
.def("get_int32", [](mlir::OpBuilder &self, int64_t v) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(mlir::Value(self.create<mlir::arith::ConstantIntOp>(
|
||||
return mlir::Value(self.create<mlir::arith::ConstantIntOp>(
|
||||
loc, v, self.getI32Type()
|
||||
)));
|
||||
));
|
||||
})
|
||||
// .def("get_uint32", &ir::builder::get_int32, ret::reference)
|
||||
// .def("get_int64", [](ir::builder *self, int64_t v) { return self->get_int64((uint64_t)v); }, ret::reference)
|
||||
@@ -818,9 +824,15 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("get_double_ty", [](mlir::OpBuilder &self) -> MlirType {
|
||||
return wrap(self.getF64Type());
|
||||
})
|
||||
.def("get_ptr_ty", [](mlir::OpBuilder &self, MlirType &type) -> MlirType {
|
||||
.def("get_ptr_ty", [](mlir::OpBuilder &self, MlirType &type, int addrSpace) -> MlirType {
|
||||
return wrap(
|
||||
mlir::triton::PointerType::get(unwrap(type))
|
||||
mlir::triton::PointerType::get(unwrap(type), addrSpace)
|
||||
);
|
||||
})
|
||||
.def("get_block_ty", [](mlir::OpBuilder &self, MlirType &elementType,
|
||||
std::vector<int64_t> &shape) -> MlirType {
|
||||
return wrap(
|
||||
mlir::RankedTensorType::get(shape, unwrap(elementType))
|
||||
);
|
||||
})
|
||||
.def("get_function_ty", [](mlir::OpBuilder &self,
|
||||
|
Reference in New Issue
Block a user