Add ReduceOp

This commit is contained in:
Yan Da
2022-05-25 14:15:36 +08:00
parent a2c9f919a8
commit 9b670cfb9f
6 changed files with 29 additions and 8 deletions

View File

@@ -1312,7 +1312,15 @@ void init_triton_ir(py::module &&m) {
// .def("create_log", &ir::builder::create_log, ret::reference)
// .def("create_trans", &ir::builder::create_trans, ret::reference)
// .def("create_sqrt", &ir::builder::create_sqrt, ret::reference)
// .def("create_reduce", &ir::builder::create_reduce, ret::reference)
.def("create_reduce", [](mlir::OpBuilder &self, mlir::Value &operand,
mlir::triton::RedOp redOp, int axis) -> mlir::Value {
auto loc = self.getUnknownLoc();
auto inputTensorType = operand.getType().dyn_cast<mlir::RankedTensorType>();
std::vector<int64_t> shape = inputTensorType.getShape();
shape.erase(shape.begin() + axis);
auto resType = mlir::RankedTensorType::get(shape, inputTensorType.getElementType());
return self.create<mlir::triton::ReduceOp>(loc, resType, redOp, operand, axis);
})
.def("create_select", [](mlir::OpBuilder &self, mlir::Value &condition,
mlir::Value &trueValue, mlir::Value &falseValue) -> mlir::Value {
auto loc = self.getUnknownLoc();