[Triton-MLIR][BACKEND] Add argmin / argmax implementation for ReduceOp (#918)
This commit is contained in:
@@ -1195,10 +1195,11 @@ void init_triton_ir(py::module &&m) {
 operand.getType().dyn_cast<mlir::RankedTensorType>();
 std::vector<int64_t> shape = inputTensorType.getShape();
 shape.erase(shape.begin() + axis);
-mlir::Type resType = inputTensorType.getElementType();
+bool withIndex = mlir::triton::ReduceOp::withIndex(redOp);
+mlir::Type resType = withIndex ? self.getI32Type()
+: inputTensorType.getElementType();
 if (!shape.empty()) {
-resType = mlir::RankedTensorType::get(
-shape, inputTensorType.getElementType());
+resType = mlir::RankedTensorType::get(shape, resType);
 }
 return self.create<mlir::triton::ReduceOp>(loc, resType, redOp,
 operand, axis);
Reference in New Issue
Block a user