diff --git a/include/triton/ir/TritonOps.td b/include/triton/ir/TritonOps.td index 4a6b8d5f1..60b768e32 100644 --- a/include/triton/ir/TritonOps.td +++ b/include/triton/ir/TritonOps.td @@ -26,6 +26,7 @@ def TT_FloatTensor : TensorOf<[TT_Float]>; // IntegerType def TT_Int : AnyTypeOf<[I8, I16, I32, I64], "integer">; def TT_IntegerTensor : TensorOf<[TT_Int]>; +def TT_I1Tensor : TensorOf<[I1]>; // PointerType def TT_IsPtrType : CPred<"$_self.isa<::mlir::triton::PointerType>()">; @@ -145,6 +146,14 @@ def TT_GEPOp : TT_Op<"getelementptr", [NoSideEffect, SameOperandsAndResultShape] // // Shape Manipulation Ops // +def TT_ReshapeOp : TT_Op<"reshape", [SameOperandsAndResultElementType]> { + let summary = "reshape"; + + let arguments = (ins TT_Tensor:$src, I64ArrayAttr:$shape); + + let results = (outs TT_Tensor:$result); +} + def TT_BroadcastOp : TT_Op<"broadcast", [SameOperandsAndResultElementType]> { let summary = "broadcast"; @@ -170,6 +179,12 @@ def TT_GetProgramIdOp : TT_Op<"get_program_id"> { let results = (outs I32:$result); } +def TT_GetNumProgramsOp : TT_Op<"get_num_programs"> { + let arguments = (ins I32Attr:$axis); + + let results = (outs I32:$result); +} + def TT_DotOp : TT_Op<"dot", [NoSideEffect, SameOperandsAndResultShape]> { let summary = "dot"; @@ -227,7 +242,7 @@ def TT_AtomicRMWOp : TT_Op<"atomic_rmw"> { }]; let arguments = (ins TT_AtomicRMWAttr:$atomic_rmw_op, TT_PtrTensor:$ptr, - TT_Type:$val); + TT_Type:$val, TT_I1Tensor:$mask); let results = (outs TT_Type:$result); } @@ -245,7 +260,7 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas"> { return $old }]; - let arguments = (ins TT_PtrTensor:$ptr, TT_Type:$cmp, TT_Type:$val); + let arguments = (ins TT_AnyPtr:$ptr, TT_Type:$cmp, TT_Type:$val); let results = (outs TT_Type:$result); } diff --git a/python/src/triton.cc b/python/src/triton.cc index 19c79f3b8..2bd987d55 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -676,19 +676,6 @@ void init_triton_ir(py::module &&m) { }) ; - // py::class_(m, "float8_type") - // .def_static("get", &mlir::triton::Float8Type::get); - // py::class_(m, "bfloat8_type") - // .def_static("get", &mlir::triton::BFloat8Type::get); - // py::class_(m, "pointer_type") - // .def_static("get", &mlir::triton::PointerType::get); - // py::class_(m, "function_type") - // .def_static("get", &mlir::FunctionType::get); - // py::class_(m, "integer_type") - // .def_static("get", &mlir::IntegerType::get); - // py::class_(m, "block_type") - // .def_static("get", &mlir::RankedTensorType::get); - // py::class_(m, "module") // .def(py::init()) // .def("set_instr_metadata", [](ir::module *self, const std::string &name, ir::value *value) { @@ -834,17 +821,22 @@ void init_triton_ir(py::module &&m) { } throw std::runtime_error("invalid function type"); }) - // Structured control flow - .def("create_for", [](mlir::OpBuilder &self, MlirValue &lb, MlirValue &ub, - MlirValue &step) { - auto loc = self.getUnknownLoc(); - return wrap( - self.create(loc, unwrap(lb), unwrap(ub), unwrap(step)) - ); - }) - // .def("create_yield") - // .def("create_if") - // .def("create_while") + // // Structured control flow + // .def("create_for", [](mlir::OpBuilder &self, MlirValue &lb, MlirValue &ub, + // MlirValue &step, std::vector &initArgs) -> MlirOperation { + // auto loc = self.getUnknownLoc(); + // return wrap(self.create( + // loc, unwrap(lb), unwrap(ub), unwrap(step)).getOperation()); + // }) + // .def("create_if", [](mlir::OpBuilder &self, MlirValue &condition) -> MlirOperation { + // auto loc = self.getUnknownLoc(); + // return wrap(self.create(loc, unwrap(condition)).getOperation()); + // }) + // .def("create_yield", [](mlir::OpBuilder &self) -> MlirOperation { + // auto loc = self.getUnknownLoc(); + // return wrap(self.create(loc).getOperation()); + // }) + // // .def("create_while") // miscellious .def("create_make_range", [](mlir::OpBuilder &self, int start, int end) -> MlirValue { @@ -861,20 +853,53 @@ void init_triton_ir(py::module &&m) { ); }) - // // Cast instructions - // .def("create_bitcast", &ir::builder::create_bitcast, ret::reference) - // .def("create_cast", &ir::builder::create_cast, ret::reference) - // .def("create_ptr_to_int", &ir::builder::create_ptr_to_int, ret::reference) - // .def("create_si_to_fp", &ir::builder::create_si_to_fp, ret::reference) - // .def("create_ui_to_fp", &ir::builder::create_ui_to_fp, ret::reference) - // .def("create_fp_to_si", &ir::builder::create_fp_to_si, ret::reference) - // .def("create_fp_to_ui", &ir::builder::create_fp_to_ui, ret::reference) - // .def("create_fp_ext", &ir::builder::create_fp_ext, ret::reference) - // .def("create_fp_trunc", &ir::builder::create_fp_trunc, ret::reference) - // .def("create_int_cast", &ir::builder::create_int_cast, ret::reference) - // .def("create_downcast", &ir::builder::create_downcast, ret::reference) - // // Binary instructions - // .def("create_insert_nuwnswb_binop", &ir::builder::create_insert_nuwnswb_binop, ret::reference) + // Cast instructions + .def("create_bitcast", [](mlir::OpBuilder &self, MlirValue &src, MlirType &dstType) -> MlirValue { + auto loc = self.getUnknownLoc(); + return wrap(mlir::Value( + self.create(loc, unwrap(dstType), unwrap(src)) + )); + }) + // .def("create_cast", &ir::builder::create_cast) + // .def("create_ptr_to_int", &ir::builder::create_ptr_to_int) + .def("create_si_to_fp", [](mlir::OpBuilder &self, MlirValue &src, MlirType &dstType) -> MlirValue { + auto loc = self.getUnknownLoc(); + return wrap(mlir::Value( + self.create(loc, unwrap(dstType), unwrap(src)) + )); + }) + .def("create_ui_to_fp", [](mlir::OpBuilder &self, MlirValue &src, MlirType &dstType) -> MlirValue { + auto loc = self.getUnknownLoc(); + return wrap(mlir::Value( + self.create(loc, unwrap(dstType), unwrap(src)) + )); + }) + .def("create_fp_to_si", [](mlir::OpBuilder &self, MlirValue &src, MlirType &dstType) -> MlirValue { + auto loc = self.getUnknownLoc(); + return wrap(mlir::Value( + self.create(loc, unwrap(dstType), unwrap(src)) + )); + }) + .def("create_fp_to_ui", [](mlir::OpBuilder &self, MlirValue &src, MlirType &dstType) -> MlirValue { + auto loc = self.getUnknownLoc(); + return wrap(mlir::Value( + self.create(loc, unwrap(dstType), unwrap(src)) + )); + }) + .def("create_fp_ext", [](mlir::OpBuilder &self, MlirValue &src, MlirType &dstType) -> MlirValue { + auto loc = self.getUnknownLoc(); + return wrap(mlir::Value( + self.create(loc, unwrap(dstType), unwrap(src)) + )); + }) + .def("create_fp_trunc", [](mlir::OpBuilder &self, MlirValue &src, MlirType &dstType) -> MlirValue { + auto loc = self.getUnknownLoc(); + return wrap(mlir::Value( + self.create(loc, unwrap(dstType), unwrap(src)) + )); + }) + // .def("create_int_cast", &ir::builder::create_int_cast) + // .def("create_downcast", &ir::builder::create_downcast) .def("create_fmul", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { auto loc = self.getUnknownLoc(); return wrap(mlir::Value( @@ -1159,12 +1184,38 @@ void init_triton_ir(py::module &&m) { auto loc = self.getUnknownLoc(); self.create(loc, unwrap(ptrs), unwrap(value)); }) - // .def("create_masked_load", &ir::builder::create_masked_load, ret::reference) - // .def("create_masked_store", &ir::builder::create_masked_store, ret::reference) - // // Block instruction - // .def("create_splat", &ir::builder::create_splat, ret::reference) - // .def("create_reshape", &ir::builder::create_reshape, ret::reference) - // .def("create_cat", &ir::builder::create_cat, ret::reference) + .def("create_masked_load", [](mlir::OpBuilder &self, MlirValue &ptrs, MlirValue &mask, MlirValue &other) -> MlirValue { + auto loc = self.getUnknownLoc(); + auto ptrType = unwrap(ptrs).getType().dyn_cast(); + std::vector shape = ptrType.getShape(); + mlir::Type elementType = ptrType.getElementType().dyn_cast().getPointeeType(); + return wrap(mlir::Value(self.create( + loc, mlir::RankedTensorType::get(shape, elementType), unwrap(ptrs), unwrap(mask), unwrap(other)) + )); + }) + .def("create_masked_store", [](mlir::OpBuilder &self, MlirValue &ptrs, MlirValue &val, MlirValue &mask) -> void { + auto loc = self.getUnknownLoc(); + self.create(loc, unwrap(ptrs), unwrap(val), unwrap(mask)); + }) + // Block instruction + .def("create_reshape", [](mlir::OpBuilder &self, MlirValue &arg, std::vector &shape) -> MlirValue { + auto loc = self.getUnknownLoc(); + auto argType = unwrap(arg).getType().dyn_cast(); + return wrap(mlir::Value(self.create( + loc, mlir::RankedTensorType::get(shape, argType), unwrap(arg), self.getI64ArrayAttr(shape) + ))); + }) + .def("create_cat", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + auto loc = self.getUnknownLoc(); + auto lhsType = unwrap(lhs).getType().dyn_cast(); + auto rhsType = unwrap(rhs).getType().dyn_cast(); + if (!(lhsType.getShape().size() == 1 && rhsType.getShape().size() == 1)) + throw std::runtime_error("shape not supported by cat. Expecting rank-1 inputs"); + std::vector shape {lhsType.getShape()[0] + rhsType.getShape()[0]}; + return wrap(mlir::Value(self.create( + loc, mlir::RankedTensorType::get(shape, lhsType.getElementType()), unwrap(lhs), unwrap(rhs) + ))); + }) .def("create_broadcast", [](mlir::OpBuilder &self, MlirValue &arg, std::vector &shape) -> MlirValue { auto loc = self.getUnknownLoc(); auto argType = unwrap(arg).getType(); @@ -1173,17 +1224,48 @@ void init_triton_ir(py::module &&m) { ))); }) // // atomic - // .def("create_atomic_cas", &ir::builder::create_atomic_cas, ret::reference) - // .def("create_atomic_rmw", &ir::builder::create_atomic_rmw, ret::reference) + .def("create_atomic_cas", [](mlir::OpBuilder &self, MlirValue &ptr, + MlirValue &cmp, MlirValue &val) -> MlirValue { + auto loc = self.getUnknownLoc(); + auto ptrType = unwrap(ptr).getType().dyn_cast(); + mlir::Type dstType = ptrType.getPointeeType(); + return wrap(mlir::Value(self.create( + loc, dstType, unwrap(ptr), unwrap(cmp), unwrap(val) + ))); + }) + .def("create_atomic_rmw", [](mlir::OpBuilder &self, mlir::triton::RMWOp rmwOp, + MlirValue &ptr, MlirValue &val, MlirValue &mask) -> MlirValue { + auto loc = self.getUnknownLoc(); + auto ptrType = unwrap(ptr).getType().dyn_cast(); + mlir::Type dstType = ptrType.getPointeeType(); + return wrap(mlir::Value(self.create( + loc, dstType, rmwOp, unwrap(ptr), unwrap(val), unwrap(mask) + ))); + }) - // // Built-in instruction - // .def("create_get_program_id", &ir::builder::create_get_program_id, ret::reference) - // .def("create_get_num_programs", &ir::builder::create_get_num_programs, ret::reference) + // Built-in instruction + .def("create_get_program_id", [](mlir::OpBuilder &self, int axis) -> MlirValue { + auto loc = self.getUnknownLoc(); + return wrap(mlir::Value(self.create( + loc, self.getI32Type(), self.getI32IntegerAttr(axis) + ))); + }) + .def("create_get_num_programs", [](mlir::OpBuilder &self, int axis) -> MlirValue { + auto loc = self.getUnknownLoc(); + return wrap(mlir::Value(self.create( + loc, self.getI32Type(), self.getI32IntegerAttr(axis) + ))); + }) + .def("create_dot", [](mlir::OpBuilder &self, MlirValue &a, MlirValue &b, MlirValue &c) -> MlirValue { + auto loc = self.getUnknownLoc(); + return wrap(mlir::Value(self.create( + loc, unwrap(c).getType(), unwrap(a), unwrap(b), unwrap(c) + ))); + }) // .def("create_exp", &ir::builder::create_exp, ret::reference) // .def("create_cos", &ir::builder::create_cos, ret::reference) // .def("create_sin", &ir::builder::create_sin, ret::reference) - // .def("create_log", &ir::builder::create_log, ret::reference) - // .def("create_dot", &ir::builder::create_dot, ret::reference) + // .def("create_log", &ir::builder::create_log, ret::reference) // .def("create_trans", &ir::builder::create_trans, ret::reference) // .def("create_sqrt", &ir::builder::create_sqrt, ret::reference) // .def("create_reduce", &ir::builder::create_reduce, ret::reference)