[Triton-MLIR][BACKEND] Add argmin / argmax implementation for ReduceOp (#918)

This commit is contained in:
Qingyi Liu
2022-11-28 14:59:27 +08:00
committed by GitHub
parent 04ec5deb41
commit 9d31998a9d
12 changed files with 341 additions and 75 deletions

View File

@@ -1195,10 +1195,11 @@ void init_triton_ir(py::module &&m) {
operand.getType().dyn_cast<mlir::RankedTensorType>();
std::vector<int64_t> shape = inputTensorType.getShape();
shape.erase(shape.begin() + axis);
mlir::Type resType = inputTensorType.getElementType();
bool withIndex = mlir::triton::ReduceOp::withIndex(redOp);
mlir::Type resType = withIndex ? self.getI32Type()
: inputTensorType.getElementType();
if (!shape.empty()) {
resType = mlir::RankedTensorType::get(
shape, inputTensorType.getElementType());
resType = mlir::RankedTensorType::get(shape, resType);
}
return self.create<mlir::triton::ReduceOp>(loc, resType, redOp,
operand, axis);