Add more Ops
This commit is contained in:
@@ -676,19 +676,6 @@ void init_triton_ir(py::module &&m) {
|
||||
})
|
||||
;
|
||||
|
||||
// py::class_<mlir::triton::Float8Type>(m, "float8_type")
|
||||
// .def_static("get", &mlir::triton::Float8Type::get);
|
||||
// py::class_<mlir::triton::BFloat8Type>(m, "bfloat8_type")
|
||||
// .def_static("get", &mlir::triton::BFloat8Type::get);
|
||||
// py::class_<mlir::triton::PointerType>(m, "pointer_type")
|
||||
// .def_static("get", &mlir::triton::PointerType::get);
|
||||
// py::class_<mlir::FunctionType>(m, "function_type")
|
||||
// .def_static("get", &mlir::FunctionType::get);
|
||||
// py::class_<mlir::IntegerType>(m, "integer_type")
|
||||
// .def_static("get", &mlir::IntegerType::get);
|
||||
// py::class_<mlir::RankedTensorType>(m, "block_type")
|
||||
// .def_static("get", &mlir::RankedTensorType::get);
|
||||
|
||||
// py::class_<mlir::ModuleOp>(m, "module")
|
||||
// .def(py::init<std::string, ir::builder &>())
|
||||
// .def("set_instr_metadata", [](ir::module *self, const std::string &name, ir::value *value) {
|
||||
@@ -834,17 +821,22 @@ void init_triton_ir(py::module &&m) {
|
||||
}
|
||||
throw std::runtime_error("invalid function type");
|
||||
})
|
||||
// Structured control flow
|
||||
.def("create_for", [](mlir::OpBuilder &self, MlirValue &lb, MlirValue &ub,
                      MlirValue &step) {
  // Build an mlir::scf::ForOp from the unwrapped lower bound, upper bound
  // and step, tagged with an unknown location, and hand the result back
  // through wrap().
  auto lowerBound = unwrap(lb);
  auto upperBound = unwrap(ub);
  auto stride = unwrap(step);
  return wrap(
    self.create<mlir::scf::ForOp>(self.getUnknownLoc(), lowerBound, upperBound, stride)
  );
})
|
||||
// .def("create_yield")
|
||||
// .def("create_if")
|
||||
// .def("create_while")
|
||||
// // Structured control flow
|
||||
// .def("create_for", [](mlir::OpBuilder &self, MlirValue &lb, MlirValue &ub,
|
||||
// MlirValue &step, std::vector<MlirValue> &initArgs) -> MlirOperation {
|
||||
// auto loc = self.getUnknownLoc();
|
||||
// return wrap(self.create<mlir::scf::ForOp>(
|
||||
// loc, unwrap(lb), unwrap(ub), unwrap(step)).getOperation());
|
||||
// })
|
||||
// .def("create_if", [](mlir::OpBuilder &self, MlirValue &condition) -> MlirOperation {
|
||||
// auto loc = self.getUnknownLoc();
|
||||
// return wrap(self.create<mlir::scf::IfOp>(loc, unwrap(condition)).getOperation());
|
||||
// })
|
||||
// .def("create_yield", [](mlir::OpBuilder &self) -> MlirOperation {
|
||||
// auto loc = self.getUnknownLoc();
|
||||
// return wrap(self.create<mlir::scf::YieldOp>(loc).getOperation());
|
||||
// })
|
||||
// // .def("create_while")
|
||||
|
||||
// miscellaneous
|
||||
.def("create_make_range", [](mlir::OpBuilder &self, int start, int end) -> MlirValue {
|
||||
@@ -861,20 +853,53 @@ void init_triton_ir(py::module &&m) {
|
||||
);
|
||||
})
|
||||
|
||||
// // Cast instructions
|
||||
// .def("create_bitcast", &ir::builder::create_bitcast, ret::reference)
|
||||
// .def("create_cast", &ir::builder::create_cast, ret::reference)
|
||||
// .def("create_ptr_to_int", &ir::builder::create_ptr_to_int, ret::reference)
|
||||
// .def("create_si_to_fp", &ir::builder::create_si_to_fp, ret::reference)
|
||||
// .def("create_ui_to_fp", &ir::builder::create_ui_to_fp, ret::reference)
|
||||
// .def("create_fp_to_si", &ir::builder::create_fp_to_si, ret::reference)
|
||||
// .def("create_fp_to_ui", &ir::builder::create_fp_to_ui, ret::reference)
|
||||
// .def("create_fp_ext", &ir::builder::create_fp_ext, ret::reference)
|
||||
// .def("create_fp_trunc", &ir::builder::create_fp_trunc, ret::reference)
|
||||
// .def("create_int_cast", &ir::builder::create_int_cast, ret::reference)
|
||||
// .def("create_downcast", &ir::builder::create_downcast, ret::reference)
|
||||
// // Binary instructions
|
||||
// .def("create_insert_nuwnswb_binop", &ir::builder::create_insert_nuwnswb_binop, ret::reference)
|
||||
// Cast instructions
|
||||
.def("create_bitcast", [](mlir::OpBuilder &self, MlirValue &src, MlirType &dstType) -> MlirValue {
  // Emit an mlir::arith::BitcastOp built from the unwrapped destination
  // type and source value, at an unknown location; return its result value.
  auto resultType = unwrap(dstType);
  auto operand = unwrap(src);
  return wrap(mlir::Value(
    self.create<mlir::arith::BitcastOp>(self.getUnknownLoc(), resultType, operand)
  ));
})
|
||||
// .def("create_cast", &ir::builder::create_cast)
|
||||
// .def("create_ptr_to_int", &ir::builder::create_ptr_to_int)
|
||||
.def("create_si_to_fp", [](mlir::OpBuilder &self, MlirValue &src, MlirType &dstType) -> MlirValue {
  // Emit an mlir::arith::SIToFPOp built from the unwrapped destination
  // type and source value, at an unknown location; return its result value.
  auto resultType = unwrap(dstType);
  auto operand = unwrap(src);
  return wrap(mlir::Value(
    self.create<mlir::arith::SIToFPOp>(self.getUnknownLoc(), resultType, operand)
  ));
})
|
||||
.def("create_ui_to_fp", [](mlir::OpBuilder &self, MlirValue &src, MlirType &dstType) -> MlirValue {
  // Emit an mlir::arith::UIToFPOp built from the unwrapped destination
  // type and source value, at an unknown location; return its result value.
  auto resultType = unwrap(dstType);
  auto operand = unwrap(src);
  return wrap(mlir::Value(
    self.create<mlir::arith::UIToFPOp>(self.getUnknownLoc(), resultType, operand)
  ));
})
|
||||
.def("create_fp_to_si", [](mlir::OpBuilder &self, MlirValue &src, MlirType &dstType) -> MlirValue {
  // Emit an mlir::arith::FPToSIOp built from the unwrapped destination
  // type and source value, at an unknown location; return its result value.
  auto resultType = unwrap(dstType);
  auto operand = unwrap(src);
  return wrap(mlir::Value(
    self.create<mlir::arith::FPToSIOp>(self.getUnknownLoc(), resultType, operand)
  ));
})
|
||||
.def("create_fp_to_ui", [](mlir::OpBuilder &self, MlirValue &src, MlirType &dstType) -> MlirValue {
  // Emit an mlir::arith::FPToUIOp built from the unwrapped destination
  // type and source value, at an unknown location; return its result value.
  auto resultType = unwrap(dstType);
  auto operand = unwrap(src);
  return wrap(mlir::Value(
    self.create<mlir::arith::FPToUIOp>(self.getUnknownLoc(), resultType, operand)
  ));
})
|
||||
.def("create_fp_ext", [](mlir::OpBuilder &self, MlirValue &src, MlirType &dstType) -> MlirValue {
  // Emit an mlir::arith::ExtFOp built from the unwrapped destination
  // type and source value, at an unknown location; return its result value.
  auto resultType = unwrap(dstType);
  auto operand = unwrap(src);
  return wrap(mlir::Value(
    self.create<mlir::arith::ExtFOp>(self.getUnknownLoc(), resultType, operand)
  ));
})
|
||||
.def("create_fp_trunc", [](mlir::OpBuilder &self, MlirValue &src, MlirType &dstType) -> MlirValue {
  // Emit an mlir::arith::TruncFOp built from the unwrapped destination
  // type and source value, at an unknown location; return its result value.
  auto resultType = unwrap(dstType);
  auto operand = unwrap(src);
  return wrap(mlir::Value(
    self.create<mlir::arith::TruncFOp>(self.getUnknownLoc(), resultType, operand)
  ));
})
|
||||
// .def("create_int_cast", &ir::builder::create_int_cast)
|
||||
// .def("create_downcast", &ir::builder::create_downcast)
|
||||
.def("create_fmul", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(mlir::Value(
|
||||
@@ -1159,12 +1184,38 @@ void init_triton_ir(py::module &&m) {
|
||||
auto loc = self.getUnknownLoc();
|
||||
self.create<mlir::triton::StoreOp>(loc, unwrap(ptrs), unwrap(value));
|
||||
})
|
||||
// .def("create_masked_load", &ir::builder::create_masked_load, ret::reference)
|
||||
// .def("create_masked_store", &ir::builder::create_masked_store, ret::reference)
|
||||
// // Block instruction
|
||||
// .def("create_splat", &ir::builder::create_splat, ret::reference)
|
||||
// .def("create_reshape", &ir::builder::create_reshape, ret::reference)
|
||||
// .def("create_cat", &ir::builder::create_cat, ret::reference)
|
||||
.def("create_masked_load", [](mlir::OpBuilder &self, MlirValue &ptrs, MlirValue &mask, MlirValue &other) -> MlirValue {
  // Creates a mlir::triton::LoadOp from `ptrs`, `mask` and `other` at an
  // unknown location. The result type is a ranked tensor with the shape of
  // `ptrs` and the pointee type of its pointer elements.
  //
  // Throws std::runtime_error when `ptrs` is not a ranked tensor whose
  // elements are triton pointers; the unchecked dyn_cast results were
  // previously dereferenced, which is a null dereference in that case.
  auto loc = self.getUnknownLoc();
  auto ptrType = unwrap(ptrs).getType().dyn_cast<mlir::RankedTensorType>();
  if (!ptrType)
    throw std::runtime_error("create_masked_load expects a ranked tensor of pointers");
  auto elemPtrType = ptrType.getElementType().dyn_cast<mlir::triton::PointerType>();
  if (!elemPtrType)
    throw std::runtime_error("create_masked_load expects a ranked tensor of pointers");
  mlir::Type elementType = elemPtrType.getPointeeType();
  // The shape is forwarded directly; no intermediate std::vector copy is needed.
  return wrap(mlir::Value(self.create<mlir::triton::LoadOp>(
    loc, mlir::RankedTensorType::get(ptrType.getShape(), elementType),
    unwrap(ptrs), unwrap(mask), unwrap(other)
  )));
})
|
||||
.def("create_masked_store", [](mlir::OpBuilder &self, MlirValue &ptrs, MlirValue &val, MlirValue &mask) -> void {
  // Emit a mlir::triton::StoreOp from the unwrapped pointers, value and
  // mask at an unknown location. Nothing is returned to the caller.
  auto destination = unwrap(ptrs);
  auto stored = unwrap(val);
  auto predicate = unwrap(mask);
  self.create<mlir::triton::StoreOp>(self.getUnknownLoc(), destination, stored, predicate);
})
|
||||
// Block instruction
|
||||
.def("create_reshape", [](mlir::OpBuilder &self, MlirValue &arg, std::vector<int64_t> &shape) -> MlirValue {
  // Creates a mlir::triton::ReshapeOp that gives `arg` the requested
  // `shape`. The result is a ranked tensor of `shape` with the same element
  // type as `arg`; `shape` is also passed as an i64 array attribute.
  //
  // Throws std::runtime_error when `arg` is not a ranked tensor.
  auto loc = self.getUnknownLoc();
  auto argType = unwrap(arg).getType().dyn_cast<mlir::RankedTensorType>();
  if (!argType)
    throw std::runtime_error("create_reshape expects a ranked tensor argument");
  // Use the element type, not the tensor type itself, as the element type
  // of the result tensor.
  return wrap(mlir::Value(self.create<mlir::triton::ReshapeOp>(
    loc, mlir::RankedTensorType::get(shape, argType.getElementType()),
    unwrap(arg), self.getI64ArrayAttr(shape)
  )));
})
|
||||
.def("create_cat", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
  // Creates a mlir::triton::CatOp from two rank-1 tensors. The result is a
  // rank-1 tensor whose length is the sum of both input lengths and whose
  // element type is taken from `lhs`.
  //
  // Throws std::runtime_error when either input is not a rank-1 ranked
  // tensor. The null checks run first so a failed dyn_cast is reported as an
  // error instead of being dereferenced.
  auto loc = self.getUnknownLoc();
  auto lhsType = unwrap(lhs).getType().dyn_cast<mlir::RankedTensorType>();
  auto rhsType = unwrap(rhs).getType().dyn_cast<mlir::RankedTensorType>();
  if (!lhsType || !rhsType ||
      lhsType.getShape().size() != 1 || rhsType.getShape().size() != 1)
    throw std::runtime_error("shape not supported by cat. Expecting rank-1 inputs");
  std::vector<int64_t> shape {lhsType.getShape()[0] + rhsType.getShape()[0]};
  return wrap(mlir::Value(self.create<mlir::triton::CatOp>(
    loc, mlir::RankedTensorType::get(shape, lhsType.getElementType()), unwrap(lhs), unwrap(rhs)
  )));
})
|
||||
.def("create_broadcast", [](mlir::OpBuilder &self, MlirValue &arg, std::vector<int64_t> &shape) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
auto argType = unwrap(arg).getType();
|
||||
@@ -1173,17 +1224,48 @@ void init_triton_ir(py::module &&m) {
|
||||
)));
|
||||
})
|
||||
// // atomic
|
||||
// .def("create_atomic_cas", &ir::builder::create_atomic_cas, ret::reference)
|
||||
// .def("create_atomic_rmw", &ir::builder::create_atomic_rmw, ret::reference)
|
||||
.def("create_atomic_cas", [](mlir::OpBuilder &self, MlirValue &ptr,
                             MlirValue &cmp, MlirValue &val) -> MlirValue {
  // Creates a mlir::triton::AtomicCASOp from `ptr`, `cmp` and `val` at an
  // unknown location. The result type is the pointee type of `ptr`.
  //
  // Throws std::runtime_error when the type of `ptr` is not a triton
  // pointer type; the unchecked dyn_cast result was previously
  // dereferenced in that case.
  auto loc = self.getUnknownLoc();
  auto ptrType = unwrap(ptr).getType().dyn_cast<mlir::triton::PointerType>();
  if (!ptrType)
    throw std::runtime_error("create_atomic_cas expects a pointer-typed value");
  mlir::Type dstType = ptrType.getPointeeType();
  return wrap(mlir::Value(self.create<mlir::triton::AtomicCASOp>(
    loc, dstType, unwrap(ptr), unwrap(cmp), unwrap(val)
  )));
})
|
||||
.def("create_atomic_rmw", [](mlir::OpBuilder &self, mlir::triton::RMWOp rmwOp,
                             MlirValue &ptr, MlirValue &val, MlirValue &mask) -> MlirValue {
  // Creates a mlir::triton::AtomicRMWOp of kind `rmwOp` from `ptr`, `val`
  // and `mask` at an unknown location. The result type is the pointee type
  // of `ptr`.
  //
  // Throws std::runtime_error when the type of `ptr` is not a triton
  // pointer type; the unchecked dyn_cast result was previously
  // dereferenced in that case.
  auto loc = self.getUnknownLoc();
  auto ptrType = unwrap(ptr).getType().dyn_cast<mlir::triton::PointerType>();
  if (!ptrType)
    throw std::runtime_error("create_atomic_rmw expects a pointer-typed value");
  mlir::Type dstType = ptrType.getPointeeType();
  return wrap(mlir::Value(self.create<mlir::triton::AtomicRMWOp>(
    loc, dstType, rmwOp, unwrap(ptr), unwrap(val), unwrap(mask)
  )));
})
|
||||
|
||||
// // Built-in instruction
|
||||
// .def("create_get_program_id", &ir::builder::create_get_program_id, ret::reference)
|
||||
// .def("create_get_num_programs", &ir::builder::create_get_num_programs, ret::reference)
|
||||
// Built-in instruction
|
||||
.def("create_get_program_id", [](mlir::OpBuilder &self, int axis) -> MlirValue {
  // Emit a mlir::triton::GetProgramIdOp with an i32 result type and `axis`
  // encoded as an i32 integer attribute, at an unknown location.
  auto resultType = self.getI32Type();
  auto axisAttr = self.getI32IntegerAttr(axis);
  return wrap(mlir::Value(self.create<mlir::triton::GetProgramIdOp>(
    self.getUnknownLoc(), resultType, axisAttr
  )));
})
|
||||
.def("create_get_num_programs", [](mlir::OpBuilder &self, int axis) -> MlirValue {
  // Emit a mlir::triton::GetNumProgramsOp with an i32 result type and
  // `axis` encoded as an i32 integer attribute, at an unknown location.
  auto resultType = self.getI32Type();
  auto axisAttr = self.getI32IntegerAttr(axis);
  return wrap(mlir::Value(self.create<mlir::triton::GetNumProgramsOp>(
    self.getUnknownLoc(), resultType, axisAttr
  )));
})
|
||||
.def("create_dot", [](mlir::OpBuilder &self, MlirValue &a, MlirValue &b, MlirValue &c) -> MlirValue {
  // Emit a mlir::triton::DotOp over the unwrapped `a`, `b` and `c` at an
  // unknown location; the result type is taken from `c`.
  auto lhsOperand = unwrap(a);
  auto rhsOperand = unwrap(b);
  auto accumulator = unwrap(c);
  return wrap(mlir::Value(self.create<mlir::triton::DotOp>(
    self.getUnknownLoc(), accumulator.getType(), lhsOperand, rhsOperand, accumulator
  )));
})
|
||||
// .def("create_exp", &ir::builder::create_exp, ret::reference)
|
||||
// .def("create_cos", &ir::builder::create_cos, ret::reference)
|
||||
// .def("create_sin", &ir::builder::create_sin, ret::reference)
|
||||
// .def("create_log", &ir::builder::create_log, ret::reference)
|
||||
// .def("create_dot", &ir::builder::create_dot, ret::reference)
|
||||
// .def("create_log", &ir::builder::create_log, ret::reference)
|
||||
// .def("create_trans", &ir::builder::create_trans, ret::reference)
|
||||
// .def("create_sqrt", &ir::builder::create_sqrt, ret::reference)
|
||||
// .def("create_reduce", &ir::builder::create_reduce, ret::reference)
|
||||
|
Reference in New Issue
Block a user