Add ReduceOp
This commit is contained in:
@@ -1312,7 +1312,15 @@ void init_triton_ir(py::module &&m) {
|
||||
// .def("create_log", &ir::builder::create_log, ret::reference)
|
||||
// .def("create_trans", &ir::builder::create_trans, ret::reference)
|
||||
// .def("create_sqrt", &ir::builder::create_sqrt, ret::reference)
|
||||
// .def("create_reduce", &ir::builder::create_reduce, ret::reference)
|
||||
.def("create_reduce", [](mlir::OpBuilder &self, mlir::Value &operand,
|
||||
mlir::triton::RedOp redOp, int axis) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
auto inputTensorType = operand.getType().dyn_cast<mlir::RankedTensorType>();
|
||||
std::vector<int64_t> shape = inputTensorType.getShape();
|
||||
shape.erase(shape.begin() + axis);
|
||||
auto resType = mlir::RankedTensorType::get(shape, inputTensorType.getElementType());
|
||||
return self.create<mlir::triton::ReduceOp>(loc, resType, redOp, operand, axis);
|
||||
})
|
||||
.def("create_select", [](mlir::OpBuilder &self, mlir::Value &condition,
|
||||
mlir::Value &trueValue, mlir::Value &falseValue) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
|
Reference in New Issue
Block a user