From 14a71dcb6fd1db6bd5e35307da2851cd54bca9d6 Mon Sep 17 00:00:00 2001 From: Yan Da Date: Wed, 23 Mar 2022 13:31:14 +0800 Subject: [PATCH] Replace MlirOperation with MlirValue --- python/src/triton.cc | 236 ++++++++++++++++++++++++++++--------------- 1 file changed, 157 insertions(+), 79 deletions(-) diff --git a/python/src/triton.cc b/python/src/triton.cc index 73b81b8d2..fb88d6f82 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -643,10 +643,10 @@ void init_triton_ir(py::module &&m) { // // .def("get", &ir::undef_value::get, ret::reference); py::class_(m, "type") - .def("is_integer", [](MlirType &self) { + .def("is_integer", [](MlirType &self) -> bool { return mlirTypeIsAInteger(self); }) - .def("is_fp16", [](MlirType &self) { + .def("is_fp16", [](MlirType &self) -> bool { return mlirTypeIsABF16(self); }) ; @@ -669,7 +669,13 @@ void init_triton_ir(py::module &&m) { }) ; + py::class_(m, "value") + ; + py::class_(m, "block") + .def("arg", [](MlirBlock &self, int index) -> MlirValue { + return wrap(unwrap(self)->getArgument(index)); + }) ; // py::class_(m, "float8_type") @@ -756,7 +762,12 @@ void init_triton_ir(py::module &&m) { // Use arith.ConstantOp to create constants // // Constants // .def("get_int1", &ir::builder::get_int1, ret::reference) - // .def("get_int32", [](ir::builder *self, int32_t v) { return self->get_int32((uint32_t)v); }, ret::reference) + .def("get_int32", [](mlir::OpBuilder &self, int64_t v) -> MlirValue { + auto loc = self.getUnknownLoc(); + return wrap(mlir::Value(self.create( + loc, v, self.getI32Type() + ))); + }, ret::reference) // .def("get_uint32", &ir::builder::get_int32, ret::reference) // .def("get_int64", [](ir::builder *self, int64_t v) { return self->get_int64((uint64_t)v); }, ret::reference) // .def("get_uint64", &ir::builder::get_int64, ret::reference) @@ -801,6 +812,11 @@ void init_triton_ir(py::module &&m) { .def("get_double_ty", [](mlir::OpBuilder &self) -> MlirType { return wrap(self.getF64Type()); }, ret::reference) + .def("get_ptr_ty", [](mlir::OpBuilder &self, MlirType &type) -> MlirType { + return wrap( + mlir::triton::PointerType::get(unwrap(type)) + ); + }) .def("get_function_ty", [](mlir::OpBuilder &self, std::vector inTypes, std::vector outTypes) -> MlirType { @@ -829,14 +845,18 @@ void init_triton_ir(py::module &&m) { // .def("create_scf_while") // miscellious - .def("create_make_range", [](mlir::OpBuilder &self, int start, int end){ + .def("create_make_range", [](mlir::OpBuilder &self, int start, int end) -> MlirValue { auto loc = self.getUnknownLoc(); auto retType = mlir::RankedTensorType::get({end-start}, self.getI32Type()); - return wrap(self.create(loc, retType, start, end).getOperation()); + return wrap( + mlir::Value(self.create(loc, retType, start, end)) + ); }, ret::reference) - .def("create_get_program_id", [](mlir::OpBuilder &self, int axis) { + .def("create_get_program_id", [](mlir::OpBuilder &self, int axis) -> MlirValue { auto loc = self.getUnknownLoc(); - return wrap(self.create(loc, self.getI32Type(), axis).getOperation()); + return wrap( + mlir::Value(self.create(loc, self.getI32Type(), axis)) + ); }) // // Cast instructions @@ -853,43 +873,84 @@ void init_triton_ir(py::module &&m) { // .def("create_downcast", &ir::builder::create_downcast, ret::reference) // // Binary instructions // .def("create_insert_nuwnswb_binop", &ir::builder::create_insert_nuwnswb_binop, ret::reference) - // .def("create_fmul", &ir::builder::create_fmul, ret::reference) - // .def("create_fdiv", &ir::builder::create_fdiv, ret::reference) - // .def("create_frem", &ir::builder::create_frem, ret::reference) - // .def("create_fadd", &ir::builder::create_fadd, ret::reference) - // .def("create_fsub", &ir::builder::create_fsub, ret::reference) - .def("create_mul", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation { + .def("create_fmul", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + auto loc = self.getUnknownLoc(); + return wrap(mlir::Value( + self.create(loc, unwrap(lhs), unwrap(rhs)) + )); + }, ret::reference) + .def("create_fdiv", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + auto loc = self.getUnknownLoc(); + return wrap(mlir::Value( + self.create(loc, unwrap(lhs), unwrap(rhs)) + )); + }, ret::reference) + .def("create_frem", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + auto loc = self.getUnknownLoc(); + return wrap(mlir::Value( + self.create(loc, unwrap(lhs), unwrap(rhs)) + )); + }, ret::reference) + .def("create_fadd", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + auto loc = self.getUnknownLoc(); + return wrap(mlir::Value( + self.create(loc, unwrap(lhs), unwrap(rhs)) + )); + }, ret::reference) + .def("create_fsub", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + auto loc = self.getUnknownLoc(); + return wrap(mlir::Value( + self.create(loc, unwrap(lhs), unwrap(rhs)) + )); + }, ret::reference) + .def("create_mul", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { auto loc = self.getUnknownLoc(); // Check lhs & rhs have single result (?) - return wrap(self.create(loc, unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)).getOperation()); + return wrap( + mlir::Value(self.create(loc, unwrap(lhs), unwrap(rhs))) + ); }, ret::reference) - .def("create_sdiv", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation { + .def("create_sdiv", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { auto loc = self.getUnknownLoc(); - return wrap(self.create(loc, unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)).getOperation()); + return wrap( + mlir::Value(self.create(loc, unwrap(lhs), unwrap(rhs))) + ); }, ret::reference) - .def("create_udiv", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation { + .def("create_udiv", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { auto loc = self.getUnknownLoc(); - return wrap(self.create(loc, unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)).getOperation()); + return wrap( + mlir::Value(self.create(loc, unwrap(lhs), unwrap(rhs))) + ); }, ret::reference) - .def("create_srem", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation { + .def("create_srem", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { auto loc = self.getUnknownLoc(); - return wrap(self.create(loc, unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)).getOperation()); + return wrap( + mlir::Value(self.create(loc, unwrap(lhs), unwrap(rhs))) + ); }, ret::reference) - .def("create_urem", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation { + .def("create_urem", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { auto loc = self.getUnknownLoc(); - return wrap(self.create(loc, unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)).getOperation()); + return wrap( + mlir::Value(self.create(loc, unwrap(lhs), unwrap(rhs))) + ); }, ret::reference) - .def("create_add", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation { + .def("create_add", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { auto loc = self.getUnknownLoc(); - return wrap(self.create(loc, unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)).getOperation()); + return wrap( + mlir::Value(self.create(loc, unwrap(lhs), unwrap(rhs))) + ); }, ret::reference) - .def("create_sub", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation { + .def("create_sub", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { auto loc = self.getUnknownLoc(); - return wrap(self.create(loc, unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)).getOperation()); + return wrap( + mlir::Value(self.create(loc, unwrap(lhs), unwrap(rhs))) + ); }, ret::reference) - .def("create_shl", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation { + .def("create_shl", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { auto loc = self.getUnknownLoc(); - return wrap(self.create(loc, unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)).getOperation()); + return wrap( + mlir::Value(self.create(loc, unwrap(lhs), unwrap(rhs))) + ); }, ret::reference) // .def("create_lshr", &ir::builder::create_lshr, ret::reference, // py::arg("lhs"), py::arg("rhs"), @@ -897,88 +958,91 @@ void init_triton_ir(py::module &&m) { // .def("create_ashr", &ir::builder::create_ashr, ret::reference, // py::arg("lhs"), py::arg("rhs"), // py::arg("has_nuw")=false, py::arg("has_nsw")=false) - // // GEP - // .def("create_gep", [](mlir::OpBuilder &self, MlirOperation &ptr, MlirOperation &offset) -> MlirOperation { - // auto loc = self.getUnknownLoc(); - // }, ret::reference) + // GEP + .def("create_gep", [](mlir::OpBuilder &self, MlirValue &ptr, MlirValue &offset) -> MlirValue { + auto loc = self.getUnknownLoc(); + return wrap( + mlir::Value(self.create(loc, unwrap(ptr).getType(), unwrap(ptr), unwrap(offset))) + ); + }, ret::reference) // Comparison (int) - .def("create_icmpSLE", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation { + .def("create_icmpSLE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { auto loc = self.getUnknownLoc(); - return wrap(self.create( + return wrap(mlir::Value(self.create( loc, mlir::arith::CmpIPredicate::sle, - unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0) - ).getOperation()); + unwrap(lhs), unwrap(rhs) + ))); }, ret::reference) - .def("create_icmpSLT", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation { + .def("create_icmpSLT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { auto loc = self.getUnknownLoc(); - return wrap(self.create( + return wrap(mlir::Value(self.create( loc, mlir::arith::CmpIPredicate::slt, - unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0) - ).getOperation()); + unwrap(lhs), unwrap(rhs) + ))); }, ret::reference) - .def("create_icmpSGE", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation { + .def("create_icmpSGE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { auto loc = self.getUnknownLoc(); - return wrap(self.create( + return wrap(mlir::Value(self.create( loc, mlir::arith::CmpIPredicate::sge, - unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0) - ).getOperation()); + unwrap(lhs), unwrap(rhs) + ))); }, ret::reference) - .def("create_icmpSGT", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation { + .def("create_icmpSGT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { auto loc = self.getUnknownLoc(); - return wrap(self.create( + return wrap(mlir::Value(self.create( loc, mlir::arith::CmpIPredicate::sgt, - unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0) - ).getOperation()); + unwrap(lhs), unwrap(rhs) + ))); }, ret::reference) - .def("create_icmpULE", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation { + .def("create_icmpULE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { auto loc = self.getUnknownLoc(); - return wrap(self.create( + return wrap(mlir::Value(self.create( loc, mlir::arith::CmpIPredicate::ule, - unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0) - ).getOperation()); + unwrap(lhs), unwrap(rhs) + ))); }, ret::reference) - .def("create_icmpULT", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation { + .def("create_icmpULT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { auto loc = self.getUnknownLoc(); - return wrap(self.create( + return wrap(mlir::Value(self.create( loc, mlir::arith::CmpIPredicate::ult, - unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0) - ).getOperation()); + unwrap(lhs), unwrap(rhs) + ))); }, ret::reference) - .def("create_icmpUGE", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation { + .def("create_icmpUGE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { auto loc = self.getUnknownLoc(); - return wrap(self.create( + return wrap(mlir::Value(self.create( loc, mlir::arith::CmpIPredicate::uge, - unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0) - ).getOperation()); + unwrap(lhs), unwrap(rhs) + ))); }, ret::reference) - .def("create_icmpUGT", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation { + .def("create_icmpUGT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { auto loc = self.getUnknownLoc(); - return wrap(self.create( + return wrap(mlir::Value(self.create( loc, mlir::arith::CmpIPredicate::ugt, - unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0) - ).getOperation()); + unwrap(lhs), unwrap(rhs) + ))); }, ret::reference) - .def("create_icmpEQ", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation { + .def("create_icmpEQ", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { auto loc = self.getUnknownLoc(); - return wrap(self.create( + return wrap(mlir::Value(self.create( loc, mlir::arith::CmpIPredicate::eq, - unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0) - ).getOperation()); + unwrap(lhs), unwrap(rhs) + ))); }, ret::reference) - .def("create_icmpNE", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation { + .def("create_icmpNE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { auto loc = self.getUnknownLoc(); - return wrap(self.create( + return wrap(mlir::Value(self.create( loc, mlir::arith::CmpIPredicate::ne, - unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0) - ).getOperation()); + unwrap(lhs), unwrap(rhs) + ))); }, ret::reference) // Comparison (float) - .def("create_fcmpOLT", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation { + .def("create_fcmpOLT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { auto loc = self.getUnknownLoc(); - return wrap(self.create( + return wrap(mlir::Value(self.create( loc, mlir::arith::CmpFPredicate::OLT, - unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0) - ).getOperation()); + unwrap(lhs), unwrap(rhs) + ))); }, ret::reference) // .def("create_fcmpOGT", &ir::builder::create_fcmpOGT, ret::reference) // .def("create_fcmpOLE", &ir::builder::create_fcmpOLE, ret::reference) @@ -996,15 +1060,29 @@ void init_triton_ir(py::module &&m) { // .def("create_xor", &ir::builder::create_xor, ret::reference) // .def("create_or", &ir::builder::create_or, ret::reference) // // Input/Output - // .def("create_load", &ir::builder::create_load, ret::reference) - // .def("create_store", &ir::builder::create_store, ret::reference) + .def("create_load", [](mlir::OpBuilder &self, MlirValue &ptrs) -> MlirValue { + auto loc = self.getUnknownLoc(); + return wrap(mlir::Value( + self.create(loc, unwrap(ptrs)) + )); + }, ret::reference) + .def("create_store", [](mlir::OpBuilder &self, MlirValue &ptrs, MlirValue &value) -> void { + auto loc = self.getUnknownLoc(); + self.create(loc, unwrap(ptrs), unwrap(value)); + }, ret::reference) // .def("create_masked_load", &ir::builder::create_masked_load, ret::reference) // .def("create_masked_store", &ir::builder::create_masked_store, ret::reference) // // Block instruction // .def("create_splat", &ir::builder::create_splat, ret::reference) // .def("create_reshape", &ir::builder::create_reshape, ret::reference) // .def("create_cat", &ir::builder::create_cat, ret::reference) - // .def("create_broadcast", &ir::builder::create_broadcast, ret::reference) + .def("create_broadcast", [](mlir::OpBuilder &self, MlirValue &arg, std::vector &shape) -> MlirValue { + auto loc = self.getUnknownLoc(); + auto argType = unwrap(arg).getType(); + return wrap(mlir::Value(self.create( + loc, mlir::RankedTensorType::get(shape, argType), unwrap(arg) + ))); + }, ret::reference) // // atomic // .def("create_atomic_cas", &ir::builder::create_atomic_cas, ret::reference) // .def("create_atomic_rmw", &ir::builder::create_atomic_rmw, ret::reference)