[Triton-MLIR][BACKEND] Add argmin / argmax implementation for ReduceOp (#918)
@@ -240,12 +240,16 @@ mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes(
   Value arg = operands[0];
   auto argTy = arg.getType().cast<RankedTensorType>();
   auto argEltTy = argTy.getElementType();
+  auto i32Ty = IntegerType::get(argEltTy.getContext(), 32);
+  auto redOp = attributes.get("redOp").cast<mlir::triton::RedOpAttr>().getValue();
+  bool withIndex = mlir::triton::ReduceOp::withIndex(redOp);
+  auto retEltTy = withIndex ? i32Ty : argEltTy;
   auto retShape = argTy.getShape().vec();
   int axis = attributes.get("axis").cast<IntegerAttr>().getInt();
   retShape.erase(retShape.begin() + axis);
   if (retShape.empty()) {
     // 0d-tensor -> scalar
-    inferredReturnTypes.push_back(argEltTy);
+    inferredReturnTypes.push_back(retEltTy);
   } else {
     // nd-tensor where n >= 1
     // infer encoding
@@ -264,11 +268,20 @@ mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes(
     }
     // create type
     inferredReturnTypes.push_back(
-        RankedTensorType::get(retShape, argEltTy, retEncoding));
+        RankedTensorType::get(retShape, retEltTy, retEncoding));
   }
   return mlir::success();
 }
 
+bool mlir::triton::ReduceOp::withIndex(mlir::triton::RedOp redOp) {
+  return redOp == mlir::triton::RedOp::ARGMIN ||
+         redOp == mlir::triton::RedOp::ARGMAX ||
+         redOp == mlir::triton::RedOp::ARGUMIN ||
+         redOp == mlir::triton::RedOp::ARGUMAX ||
+         redOp == mlir::triton::RedOp::ARGFMIN ||
+         redOp == mlir::triton::RedOp::ARGFMAX;
+}
+
 //-- SplatOp --
 OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
   auto constOperand = src().getDefiningOp<arith::ConstantOp>();
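For orientation, here is a minimal standalone C++ sketch of the inference rule this patch implements: arg-style reductions (argmin/argmax) drop the reduced axis like any other reduction, but their result element type becomes i32 because they return indices rather than values. The ReduceKind enum and inferReduceResult helper below are hypothetical stand-ins with no MLIR dependencies, not the Triton code above.

// Illustrative only: plain C++ mirror of the result-type rule.
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

enum class ReduceKind { Min, Max, ArgMin, ArgMax };

struct TensorDesc {
  std::vector<int64_t> shape; // an empty shape models the 0d-tensor -> scalar case
  std::string eltTy;          // e.g. "f32" or "i32"
};

TensorDesc inferReduceResult(const TensorDesc &arg, int axis, ReduceKind kind) {
  // Arg-style reductions return indices, so the result element type is i32;
  // plain reductions keep the operand's element type.
  bool withIndex = kind == ReduceKind::ArgMin || kind == ReduceKind::ArgMax;
  TensorDesc ret;
  ret.eltTy = withIndex ? "i32" : arg.eltTy;
  // The reduced axis is erased from the shape in both cases.
  ret.shape = arg.shape;
  ret.shape.erase(ret.shape.begin() + axis);
  return ret;
}

int main() {
  TensorDesc arg{{128, 64}, "f32"};

  TensorDesc amax = inferReduceResult(arg, /*axis=*/1, ReduceKind::ArgMax);
  assert(amax.eltTy == "i32");                        // indices, not f32 values
  assert((amax.shape == std::vector<int64_t>{128}));  // axis 1 removed

  TensorDesc m = inferReduceResult(arg, /*axis=*/1, ReduceKind::Max);
  assert(m.eltTy == "f32"); // a plain max keeps the operand element type
  return 0;
}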