Add more Ops

This commit is contained in:
Yan Da
2022-03-28 19:50:23 +08:00
parent 0d139ec460
commit 38e67b4293
2 changed files with 150 additions and 53 deletions

View File

@@ -676,19 +676,6 @@ void init_triton_ir(py::module &&m) {
})
;
// py::class_<mlir::triton::Float8Type>(m, "float8_type")
// .def_static("get", &mlir::triton::Float8Type::get);
// py::class_<mlir::triton::BFloat8Type>(m, "bfloat8_type")
// .def_static("get", &mlir::triton::BFloat8Type::get);
// py::class_<mlir::triton::PointerType>(m, "pointer_type")
// .def_static("get", &mlir::triton::PointerType::get);
// py::class_<mlir::FunctionType>(m, "function_type")
// .def_static("get", &mlir::FunctionType::get);
// py::class_<mlir::IntegerType>(m, "integer_type")
// .def_static("get", &mlir::IntegerType::get);
// py::class_<mlir::RankedTensorType>(m, "block_type")
// .def_static("get", &mlir::RankedTensorType::get);
// py::class_<mlir::ModuleOp>(m, "module")
// .def(py::init<std::string, ir::builder &>())
// .def("set_instr_metadata", [](ir::module *self, const std::string &name, ir::value *value) {
@@ -834,17 +821,22 @@ void init_triton_ir(py::module &&m) {
}
throw std::runtime_error("invalid function type");
})
// Structured control flow
.def("create_for", [](mlir::OpBuilder &self, MlirValue &lb, MlirValue &ub,
MlirValue &step) {
auto loc = self.getUnknownLoc();
return wrap(
self.create<mlir::scf::ForOp>(loc, unwrap(lb), unwrap(ub), unwrap(step))
);
})
// .def("create_yield")
// .def("create_if")
// .def("create_while")
// // Structured control flow
// .def("create_for", [](mlir::OpBuilder &self, MlirValue &lb, MlirValue &ub,
// MlirValue &step, std::vector<MlirValue> &initArgs) -> MlirOperation {
// auto loc = self.getUnknownLoc();
// return wrap(self.create<mlir::scf::ForOp>(
// loc, unwrap(lb), unwrap(ub), unwrap(step)).getOperation());
// })
// .def("create_if", [](mlir::OpBuilder &self, MlirValue &condition) -> MlirOperation {
// auto loc = self.getUnknownLoc();
// return wrap(self.create<mlir::scf::IfOp>(loc, unwrap(condition)).getOperation());
// })
// .def("create_yield", [](mlir::OpBuilder &self) -> MlirOperation {
// auto loc = self.getUnknownLoc();
// return wrap(self.create<mlir::scf::YieldOp>(loc).getOperation());
// })
// // .def("create_while")
// miscellious
.def("create_make_range", [](mlir::OpBuilder &self, int start, int end) -> MlirValue {
@@ -861,20 +853,53 @@ void init_triton_ir(py::module &&m) {
);
})
// // Cast instructions
// .def("create_bitcast", &ir::builder::create_bitcast, ret::reference)
// .def("create_cast", &ir::builder::create_cast, ret::reference)
// .def("create_ptr_to_int", &ir::builder::create_ptr_to_int, ret::reference)
// .def("create_si_to_fp", &ir::builder::create_si_to_fp, ret::reference)
// .def("create_ui_to_fp", &ir::builder::create_ui_to_fp, ret::reference)
// .def("create_fp_to_si", &ir::builder::create_fp_to_si, ret::reference)
// .def("create_fp_to_ui", &ir::builder::create_fp_to_ui, ret::reference)
// .def("create_fp_ext", &ir::builder::create_fp_ext, ret::reference)
// .def("create_fp_trunc", &ir::builder::create_fp_trunc, ret::reference)
// .def("create_int_cast", &ir::builder::create_int_cast, ret::reference)
// .def("create_downcast", &ir::builder::create_downcast, ret::reference)
// // Binary instructions
// .def("create_insert_nuwnswb_binop", &ir::builder::create_insert_nuwnswb_binop, ret::reference)
// Cast instructions
.def("create_bitcast", [](mlir::OpBuilder &self, MlirValue &src, MlirType &dstType) -> MlirValue {
auto loc = self.getUnknownLoc();
return wrap(mlir::Value(
self.create<mlir::arith::BitcastOp>(loc, unwrap(dstType), unwrap(src))
));
})
// .def("create_cast", &ir::builder::create_cast)
// .def("create_ptr_to_int", &ir::builder::create_ptr_to_int)
.def("create_si_to_fp", [](mlir::OpBuilder &self, MlirValue &src, MlirType &dstType) -> MlirValue {
auto loc = self.getUnknownLoc();
return wrap(mlir::Value(
self.create<mlir::arith::SIToFPOp>(loc, unwrap(dstType), unwrap(src))
));
})
.def("create_ui_to_fp", [](mlir::OpBuilder &self, MlirValue &src, MlirType &dstType) -> MlirValue {
auto loc = self.getUnknownLoc();
return wrap(mlir::Value(
self.create<mlir::arith::UIToFPOp>(loc, unwrap(dstType), unwrap(src))
));
})
.def("create_fp_to_si", [](mlir::OpBuilder &self, MlirValue &src, MlirType &dstType) -> MlirValue {
auto loc = self.getUnknownLoc();
return wrap(mlir::Value(
self.create<mlir::arith::FPToSIOp>(loc, unwrap(dstType), unwrap(src))
));
})
.def("create_fp_to_ui", [](mlir::OpBuilder &self, MlirValue &src, MlirType &dstType) -> MlirValue {
auto loc = self.getUnknownLoc();
return wrap(mlir::Value(
self.create<mlir::arith::FPToUIOp>(loc, unwrap(dstType), unwrap(src))
));
})
.def("create_fp_ext", [](mlir::OpBuilder &self, MlirValue &src, MlirType &dstType) -> MlirValue {
auto loc = self.getUnknownLoc();
return wrap(mlir::Value(
self.create<mlir::arith::ExtFOp>(loc, unwrap(dstType), unwrap(src))
));
})
.def("create_fp_trunc", [](mlir::OpBuilder &self, MlirValue &src, MlirType &dstType) -> MlirValue {
auto loc = self.getUnknownLoc();
return wrap(mlir::Value(
self.create<mlir::arith::TruncFOp>(loc, unwrap(dstType), unwrap(src))
));
})
// .def("create_int_cast", &ir::builder::create_int_cast)
// .def("create_downcast", &ir::builder::create_downcast)
.def("create_fmul", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
auto loc = self.getUnknownLoc();
return wrap(mlir::Value(
@@ -1159,12 +1184,38 @@ void init_triton_ir(py::module &&m) {
auto loc = self.getUnknownLoc();
self.create<mlir::triton::StoreOp>(loc, unwrap(ptrs), unwrap(value));
})
// .def("create_masked_load", &ir::builder::create_masked_load, ret::reference)
// .def("create_masked_store", &ir::builder::create_masked_store, ret::reference)
// // Block instruction
// .def("create_splat", &ir::builder::create_splat, ret::reference)
// .def("create_reshape", &ir::builder::create_reshape, ret::reference)
// .def("create_cat", &ir::builder::create_cat, ret::reference)
.def("create_masked_load", [](mlir::OpBuilder &self, MlirValue &ptrs, MlirValue &mask, MlirValue &other) -> MlirValue {
auto loc = self.getUnknownLoc();
auto ptrType = unwrap(ptrs).getType().dyn_cast<mlir::RankedTensorType>();
std::vector<int64_t> shape = ptrType.getShape();
mlir::Type elementType = ptrType.getElementType().dyn_cast<mlir::triton::PointerType>().getPointeeType();
return wrap(mlir::Value(self.create<mlir::triton::LoadOp>(
loc, mlir::RankedTensorType::get(shape, elementType), unwrap(ptrs), unwrap(mask), unwrap(other))
));
})
.def("create_masked_store", [](mlir::OpBuilder &self, MlirValue &ptrs, MlirValue &val, MlirValue &mask) -> void {
auto loc = self.getUnknownLoc();
self.create<mlir::triton::StoreOp>(loc, unwrap(ptrs), unwrap(val), unwrap(mask));
})
// Block instruction
.def("create_reshape", [](mlir::OpBuilder &self, MlirValue &arg, std::vector<int64_t> &shape) -> MlirValue {
auto loc = self.getUnknownLoc();
auto argType = unwrap(arg).getType().dyn_cast<mlir::RankedTensorType>();
return wrap(mlir::Value(self.create<mlir::triton::ReshapeOp>(
loc, mlir::RankedTensorType::get(shape, argType), unwrap(arg), self.getI64ArrayAttr(shape)
)));
})
.def("create_cat", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
auto loc = self.getUnknownLoc();
auto lhsType = unwrap(lhs).getType().dyn_cast<mlir::RankedTensorType>();
auto rhsType = unwrap(rhs).getType().dyn_cast<mlir::RankedTensorType>();
if (!(lhsType.getShape().size() == 1 && rhsType.getShape().size() == 1))
throw std::runtime_error("shape not supported by cat. Expecting rank-1 inputs");
std::vector<int64_t> shape {lhsType.getShape()[0] + rhsType.getShape()[0]};
return wrap(mlir::Value(self.create<mlir::triton::CatOp>(
loc, mlir::RankedTensorType::get(shape, lhsType.getElementType()), unwrap(lhs), unwrap(rhs)
)));
})
.def("create_broadcast", [](mlir::OpBuilder &self, MlirValue &arg, std::vector<int64_t> &shape) -> MlirValue {
auto loc = self.getUnknownLoc();
auto argType = unwrap(arg).getType();
@@ -1173,17 +1224,48 @@ void init_triton_ir(py::module &&m) {
)));
})
// // atomic
// .def("create_atomic_cas", &ir::builder::create_atomic_cas, ret::reference)
// .def("create_atomic_rmw", &ir::builder::create_atomic_rmw, ret::reference)
.def("create_atomic_cas", [](mlir::OpBuilder &self, MlirValue &ptr,
MlirValue &cmp, MlirValue &val) -> MlirValue {
auto loc = self.getUnknownLoc();
auto ptrType = unwrap(ptr).getType().dyn_cast<mlir::triton::PointerType>();
mlir::Type dstType = ptrType.getPointeeType();
return wrap(mlir::Value(self.create<mlir::triton::AtomicCASOp>(
loc, dstType, unwrap(ptr), unwrap(cmp), unwrap(val)
)));
})
.def("create_atomic_rmw", [](mlir::OpBuilder &self, mlir::triton::RMWOp rmwOp,
MlirValue &ptr, MlirValue &val, MlirValue &mask) -> MlirValue {
auto loc = self.getUnknownLoc();
auto ptrType = unwrap(ptr).getType().dyn_cast<mlir::triton::PointerType>();
mlir::Type dstType = ptrType.getPointeeType();
return wrap(mlir::Value(self.create<mlir::triton::AtomicRMWOp>(
loc, dstType, rmwOp, unwrap(ptr), unwrap(val), unwrap(mask)
)));
})
// // Built-in instruction
// .def("create_get_program_id", &ir::builder::create_get_program_id, ret::reference)
// .def("create_get_num_programs", &ir::builder::create_get_num_programs, ret::reference)
// Built-in instruction
.def("create_get_program_id", [](mlir::OpBuilder &self, int axis) -> MlirValue {
auto loc = self.getUnknownLoc();
return wrap(mlir::Value(self.create<mlir::triton::GetProgramIdOp>(
loc, self.getI32Type(), self.getI32IntegerAttr(axis)
)));
})
.def("create_get_num_programs", [](mlir::OpBuilder &self, int axis) -> MlirValue {
auto loc = self.getUnknownLoc();
return wrap(mlir::Value(self.create<mlir::triton::GetNumProgramsOp>(
loc, self.getI32Type(), self.getI32IntegerAttr(axis)
)));
})
.def("create_dot", [](mlir::OpBuilder &self, MlirValue &a, MlirValue &b, MlirValue &c) -> MlirValue {
auto loc = self.getUnknownLoc();
return wrap(mlir::Value(self.create<mlir::triton::DotOp>(
loc, unwrap(c).getType(), unwrap(a), unwrap(b), unwrap(c)
)));
})
// .def("create_exp", &ir::builder::create_exp, ret::reference)
// .def("create_cos", &ir::builder::create_cos, ret::reference)
// .def("create_sin", &ir::builder::create_sin, ret::reference)
// .def("create_log", &ir::builder::create_log, ret::reference)
// .def("create_dot", &ir::builder::create_dot, ret::reference)
// .def("create_log", &ir::builder::create_log, ret::reference)
// .def("create_trans", &ir::builder::create_trans, ret::reference)
// .def("create_sqrt", &ir::builder::create_sqrt, ret::reference)
// .def("create_reduce", &ir::builder::create_reduce, ret::reference)