Replace MlirOperation with MlirValue

This commit is contained in:
Yan Da
2022-03-23 13:31:14 +08:00
parent f2ab318614
commit 14a71dcb6f

View File

@@ -643,10 +643,10 @@ void init_triton_ir(py::module &&m) {
// // .def("get", &ir::undef_value::get, ret::reference);
py::class_<MlirType>(m, "type")
.def("is_integer", [](MlirType &self) {
.def("is_integer", [](MlirType &self) -> bool {
return mlirTypeIsAInteger(self);
})
.def("is_fp16", [](MlirType &self) {
.def("is_fp16", [](MlirType &self) -> bool {
return mlirTypeIsABF16(self);
})
;
@@ -669,7 +669,13 @@ void init_triton_ir(py::module &&m) {
})
;
py::class_<MlirValue>(m, "value")
;
py::class_<MlirBlock>(m, "block")
.def("arg", [](MlirBlock &self, int index) -> MlirValue {
return wrap(unwrap(self)->getArgument(index));
})
;
// py::class_<mlir::triton::Float8Type>(m, "float8_type")
@@ -756,7 +762,12 @@ void init_triton_ir(py::module &&m) {
// Use arith.ConstantOp to create constants
// // Constants
// .def("get_int1", &ir::builder::get_int1, ret::reference)
// .def("get_int32", [](ir::builder *self, int32_t v) { return self->get_int32((uint32_t)v); }, ret::reference)
.def("get_int32", [](mlir::OpBuilder &self, int64_t v) -> MlirValue {
auto loc = self.getUnknownLoc();
return wrap(mlir::Value(self.create<mlir::arith::ConstantIntOp>(
loc, v, self.getI32Type()
)));
}, ret::reference)
// .def("get_uint32", &ir::builder::get_int32, ret::reference)
// .def("get_int64", [](ir::builder *self, int64_t v) { return self->get_int64((uint64_t)v); }, ret::reference)
// .def("get_uint64", &ir::builder::get_int64, ret::reference)
@@ -801,6 +812,11 @@ void init_triton_ir(py::module &&m) {
.def("get_double_ty", [](mlir::OpBuilder &self) -> MlirType {
return wrap(self.getF64Type());
}, ret::reference)
.def("get_ptr_ty", [](mlir::OpBuilder &self, MlirType &type) -> MlirType {
return wrap(
mlir::triton::PointerType::get(unwrap(type))
);
})
.def("get_function_ty", [](mlir::OpBuilder &self,
std::vector<MlirType> inTypes,
std::vector<MlirType> outTypes) -> MlirType {
@@ -829,14 +845,18 @@ void init_triton_ir(py::module &&m) {
// .def("create_scf_while")
// miscellious
.def("create_make_range", [](mlir::OpBuilder &self, int start, int end){
.def("create_make_range", [](mlir::OpBuilder &self, int start, int end) -> MlirValue {
auto loc = self.getUnknownLoc();
auto retType = mlir::RankedTensorType::get({end-start}, self.getI32Type());
return wrap(self.create<mlir::triton::MakeRangeOp>(loc, retType, start, end).getOperation());
return wrap(
mlir::Value(self.create<mlir::triton::MakeRangeOp>(loc, retType, start, end))
);
}, ret::reference)
.def("create_get_program_id", [](mlir::OpBuilder &self, int axis) {
.def("create_get_program_id", [](mlir::OpBuilder &self, int axis) -> MlirValue {
auto loc = self.getUnknownLoc();
return wrap(self.create<mlir::triton::GetProgramIdOp>(loc, self.getI32Type(), axis).getOperation());
return wrap(
mlir::Value(self.create<mlir::triton::GetProgramIdOp>(loc, self.getI32Type(), axis))
);
})
// // Cast instructions
@@ -853,43 +873,84 @@ void init_triton_ir(py::module &&m) {
// .def("create_downcast", &ir::builder::create_downcast, ret::reference)
// // Binary instructions
// .def("create_insert_nuwnswb_binop", &ir::builder::create_insert_nuwnswb_binop, ret::reference)
// .def("create_fmul", &ir::builder::create_fmul, ret::reference)
// .def("create_fdiv", &ir::builder::create_fdiv, ret::reference)
// .def("create_frem", &ir::builder::create_frem, ret::reference)
// .def("create_fadd", &ir::builder::create_fadd, ret::reference)
// .def("create_fsub", &ir::builder::create_fsub, ret::reference)
.def("create_mul", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation {
.def("create_fmul", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
auto loc = self.getUnknownLoc();
return wrap(mlir::Value(
self.create<mlir::arith::MulFOp>(loc, unwrap(lhs), unwrap(rhs))
));
}, ret::reference)
.def("create_fdiv", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
auto loc = self.getUnknownLoc();
return wrap(mlir::Value(
self.create<mlir::arith::DivFOp>(loc, unwrap(lhs), unwrap(rhs))
));
}, ret::reference)
.def("create_frem", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
auto loc = self.getUnknownLoc();
return wrap(mlir::Value(
self.create<mlir::arith::RemFOp>(loc, unwrap(lhs), unwrap(rhs))
));
}, ret::reference)
.def("create_fadd", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
auto loc = self.getUnknownLoc();
return wrap(mlir::Value(
self.create<mlir::arith::AddFOp>(loc, unwrap(lhs), unwrap(rhs))
));
}, ret::reference)
.def("create_fsub", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
auto loc = self.getUnknownLoc();
return wrap(mlir::Value(
self.create<mlir::arith::SubFOp>(loc, unwrap(lhs), unwrap(rhs))
));
}, ret::reference)
.def("create_mul", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
auto loc = self.getUnknownLoc();
// Check lhs & rhs have single result (?)
return wrap(self.create<mlir::arith::MulIOp>(loc, unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)).getOperation());
return wrap(
mlir::Value(self.create<mlir::arith::MulIOp>(loc, unwrap(lhs), unwrap(rhs)))
);
}, ret::reference)
.def("create_sdiv", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation {
.def("create_sdiv", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
auto loc = self.getUnknownLoc();
return wrap(self.create<mlir::arith::DivSIOp>(loc, unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)).getOperation());
return wrap(
mlir::Value(self.create<mlir::arith::DivSIOp>(loc, unwrap(lhs), unwrap(rhs)))
);
}, ret::reference)
.def("create_udiv", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation {
.def("create_udiv", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
auto loc = self.getUnknownLoc();
return wrap(self.create<mlir::arith::DivUIOp>(loc, unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)).getOperation());
return wrap(
mlir::Value(self.create<mlir::arith::DivUIOp>(loc, unwrap(lhs), unwrap(rhs)))
);
}, ret::reference)
.def("create_srem", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation {
.def("create_srem", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
auto loc = self.getUnknownLoc();
return wrap(self.create<mlir::arith::RemSIOp>(loc, unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)).getOperation());
return wrap(
mlir::Value(self.create<mlir::arith::RemSIOp>(loc, unwrap(lhs), unwrap(rhs)))
);
}, ret::reference)
.def("create_urem", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation {
.def("create_urem", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
auto loc = self.getUnknownLoc();
return wrap(self.create<mlir::arith::RemUIOp>(loc, unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)).getOperation());
return wrap(
mlir::Value(self.create<mlir::arith::RemUIOp>(loc, unwrap(lhs), unwrap(rhs)))
);
}, ret::reference)
.def("create_add", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation {
.def("create_add", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
auto loc = self.getUnknownLoc();
return wrap(self.create<mlir::arith::AddIOp>(loc, unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)).getOperation());
return wrap(
mlir::Value(self.create<mlir::arith::AddIOp>(loc, unwrap(lhs), unwrap(rhs)))
);
}, ret::reference)
.def("create_sub", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation {
.def("create_sub", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
auto loc = self.getUnknownLoc();
return wrap(self.create<mlir::arith::SubIOp>(loc, unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)).getOperation());
return wrap(
mlir::Value(self.create<mlir::arith::SubIOp>(loc, unwrap(lhs), unwrap(rhs)))
);
}, ret::reference)
.def("create_shl", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation {
.def("create_shl", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
auto loc = self.getUnknownLoc();
return wrap(self.create<mlir::arith::ShLIOp>(loc, unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)).getOperation());
return wrap(
mlir::Value(self.create<mlir::arith::ShLIOp>(loc, unwrap(lhs), unwrap(rhs)))
);
}, ret::reference)
// .def("create_lshr", &ir::builder::create_lshr, ret::reference,
// py::arg("lhs"), py::arg("rhs"),
@@ -897,88 +958,91 @@ void init_triton_ir(py::module &&m) {
// .def("create_ashr", &ir::builder::create_ashr, ret::reference,
// py::arg("lhs"), py::arg("rhs"),
// py::arg("has_nuw")=false, py::arg("has_nsw")=false)
// // GEP
// .def("create_gep", [](mlir::OpBuilder &self, MlirOperation &ptr, MlirOperation &offset) -> MlirOperation {
// auto loc = self.getUnknownLoc();
// }, ret::reference)
// GEP
.def("create_gep", [](mlir::OpBuilder &self, MlirValue &ptr, MlirValue &offset) -> MlirValue {
auto loc = self.getUnknownLoc();
return wrap(
mlir::Value(self.create<mlir::triton::GEPOp>(loc, unwrap(ptr).getType(), unwrap(ptr), unwrap(offset)))
);
}, ret::reference)
// Comparison (int)
.def("create_icmpSLE", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation {
.def("create_icmpSLE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
auto loc = self.getUnknownLoc();
return wrap(self.create<mlir::arith::CmpIOp>(
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::sle,
unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)
).getOperation());
unwrap(lhs), unwrap(rhs)
)));
}, ret::reference)
.def("create_icmpSLT", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation {
.def("create_icmpSLT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
auto loc = self.getUnknownLoc();
return wrap(self.create<mlir::arith::CmpIOp>(
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::slt,
unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)
).getOperation());
unwrap(lhs), unwrap(rhs)
)));
}, ret::reference)
.def("create_icmpSGE", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation {
.def("create_icmpSGE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
auto loc = self.getUnknownLoc();
return wrap(self.create<mlir::arith::CmpIOp>(
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::sge,
unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)
).getOperation());
unwrap(lhs), unwrap(rhs)
)));
}, ret::reference)
.def("create_icmpSGT", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation {
.def("create_icmpSGT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
auto loc = self.getUnknownLoc();
return wrap(self.create<mlir::arith::CmpIOp>(
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::sgt,
unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)
).getOperation());
unwrap(lhs), unwrap(rhs)
)));
}, ret::reference)
.def("create_icmpULE", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation {
.def("create_icmpULE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
auto loc = self.getUnknownLoc();
return wrap(self.create<mlir::arith::CmpIOp>(
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::ule,
unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)
).getOperation());
unwrap(lhs), unwrap(rhs)
)));
}, ret::reference)
.def("create_icmpULT", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation {
.def("create_icmpULT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
auto loc = self.getUnknownLoc();
return wrap(self.create<mlir::arith::CmpIOp>(
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::ult,
unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)
).getOperation());
unwrap(lhs), unwrap(rhs)
)));
}, ret::reference)
.def("create_icmpUGE", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation {
.def("create_icmpUGE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
auto loc = self.getUnknownLoc();
return wrap(self.create<mlir::arith::CmpIOp>(
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::uge,
unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)
).getOperation());
unwrap(lhs), unwrap(rhs)
)));
}, ret::reference)
.def("create_icmpUGT", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation {
.def("create_icmpUGT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
auto loc = self.getUnknownLoc();
return wrap(self.create<mlir::arith::CmpIOp>(
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::ugt,
unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)
).getOperation());
unwrap(lhs), unwrap(rhs)
)));
}, ret::reference)
.def("create_icmpEQ", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation {
.def("create_icmpEQ", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
auto loc = self.getUnknownLoc();
return wrap(self.create<mlir::arith::CmpIOp>(
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::eq,
unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)
).getOperation());
unwrap(lhs), unwrap(rhs)
)));
}, ret::reference)
.def("create_icmpNE", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation {
.def("create_icmpNE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
auto loc = self.getUnknownLoc();
return wrap(self.create<mlir::arith::CmpIOp>(
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::ne,
unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)
).getOperation());
unwrap(lhs), unwrap(rhs)
)));
}, ret::reference)
// Comparison (float)
.def("create_fcmpOLT", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation {
.def("create_fcmpOLT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
auto loc = self.getUnknownLoc();
return wrap(self.create<mlir::arith::CmpFOp>(
return wrap(mlir::Value(self.create<mlir::arith::CmpFOp>(
loc, mlir::arith::CmpFPredicate::OLT,
unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)
).getOperation());
unwrap(lhs), unwrap(rhs)
)));
}, ret::reference)
// .def("create_fcmpOGT", &ir::builder::create_fcmpOGT, ret::reference)
// .def("create_fcmpOLE", &ir::builder::create_fcmpOLE, ret::reference)
@@ -996,15 +1060,29 @@ void init_triton_ir(py::module &&m) {
// .def("create_xor", &ir::builder::create_xor, ret::reference)
// .def("create_or", &ir::builder::create_or, ret::reference)
// // Input/Output
// .def("create_load", &ir::builder::create_load, ret::reference)
// .def("create_store", &ir::builder::create_store, ret::reference)
.def("create_load", [](mlir::OpBuilder &self, MlirValue &ptrs) -> MlirValue {
auto loc = self.getUnknownLoc();
return wrap(mlir::Value(
self.create<mlir::triton::LoadOp>(loc, unwrap(ptrs))
));
}, ret::reference)
.def("create_store", [](mlir::OpBuilder &self, MlirValue &ptrs, MlirValue &value) -> void {
auto loc = self.getUnknownLoc();
self.create<mlir::triton::StoreOp>(loc, unwrap(ptrs), unwrap(value));
}, ret::reference)
// .def("create_masked_load", &ir::builder::create_masked_load, ret::reference)
// .def("create_masked_store", &ir::builder::create_masked_store, ret::reference)
// // Block instruction
// .def("create_splat", &ir::builder::create_splat, ret::reference)
// .def("create_reshape", &ir::builder::create_reshape, ret::reference)
// .def("create_cat", &ir::builder::create_cat, ret::reference)
// .def("create_broadcast", &ir::builder::create_broadcast, ret::reference)
.def("create_broadcast", [](mlir::OpBuilder &self, MlirValue &arg, std::vector<int64_t> &shape) -> MlirValue {
auto loc = self.getUnknownLoc();
auto argType = unwrap(arg).getType();
return wrap(mlir::Value(self.create<mlir::triton::BroadcastOp>(
loc, mlir::RankedTensorType::get(shape, argType), unwrap(arg)
)));
}, ret::reference)
// // atomic
// .def("create_atomic_cas", &ir::builder::create_atomic_cas, ret::reference)
// .def("create_atomic_rmw", &ir::builder::create_atomic_rmw, ret::reference)