[Triton-MLIR][BACKEND] Add argmin / argmax implementation for ReduceOp (#918)

This commit is contained in:
Qingyi Liu
2022-11-28 14:59:27 +08:00
committed by GitHub
parent 04ec5deb41
commit 9d31998a9d
12 changed files with 341 additions and 75 deletions

View File

@@ -20,8 +20,6 @@ SmallVector<unsigned>
getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
unsigned &outVec);
SmallVector<unsigned> getScratchConfigForReduce(triton::ReduceOp op);
} // namespace triton
/// Modified from llvm-15.0: llvm/ADT/AddressRanges.h

View File

@@ -26,6 +26,12 @@ public:
unsigned getThreadsReductionAxis();
SmallVector<unsigned> getScratchConfigBasic();
SmallVector<SmallVector<unsigned>> getScratchConfigsFast();
unsigned getScratchSizeInBytes();
private:
triton::ReduceOp op;
RankedTensorType srcTy{};
@@ -39,6 +45,14 @@ bool maybeAliasOp(Operation *op);
std::string getValueOperandName(Value value, AsmState &state);
template <typename T_OUT, typename T_IN>
inline SmallVector<T_OUT> convertType(ArrayRef<T_IN> in) {
SmallVector<T_OUT> out;
for (const T_IN &i : in)
out.push_back(T_OUT(i));
return out;
}
template <typename Int> Int product(llvm::ArrayRef<Int> arr) {
return std::accumulate(arr.begin(), arr.end(), 1, std::multiplies{});
}

View File

@@ -351,6 +351,11 @@ def TT_ReduceOp : TT_Op<"reduce", [NoSideEffect,
let assemblyFormat = "$operand attr-dict `:` type($operand) `->` type($result)";
let extraClassDeclaration = [{
// This member function is marked static because we need to call it before the ReduceOp
// is constructed, see the implementation of create_reduce in triton.cc.
static bool withIndex(mlir::triton::RedOp redOp);
}];
}
//