[Triton-MLIR][BACKEND] Add argmin / argmax implementation for ReduceOp (#918)
This commit is contained in:
@@ -20,8 +20,6 @@ SmallVector<unsigned>
|
||||
getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
|
||||
unsigned &outVec);
|
||||
|
||||
SmallVector<unsigned> getScratchConfigForReduce(triton::ReduceOp op);
|
||||
|
||||
} // namespace triton
|
||||
|
||||
/// Modified from llvm-15.0: llvm/ADT/AddressRanges.h
|
||||
|
@@ -26,6 +26,12 @@ public:
|
||||
|
||||
unsigned getThreadsReductionAxis();
|
||||
|
||||
SmallVector<unsigned> getScratchConfigBasic();
|
||||
|
||||
SmallVector<SmallVector<unsigned>> getScratchConfigsFast();
|
||||
|
||||
unsigned getScratchSizeInBytes();
|
||||
|
||||
private:
|
||||
triton::ReduceOp op;
|
||||
RankedTensorType srcTy{};
|
||||
@@ -39,6 +45,14 @@ bool maybeAliasOp(Operation *op);
|
||||
|
||||
std::string getValueOperandName(Value value, AsmState &state);
|
||||
|
||||
template <typename T_OUT, typename T_IN>
|
||||
inline SmallVector<T_OUT> convertType(ArrayRef<T_IN> in) {
|
||||
SmallVector<T_OUT> out;
|
||||
for (const T_IN &i : in)
|
||||
out.push_back(T_OUT(i));
|
||||
return out;
|
||||
}
|
||||
|
||||
template <typename Int> Int product(llvm::ArrayRef<Int> arr) {
|
||||
return std::accumulate(arr.begin(), arr.end(), 1, std::multiplies{});
|
||||
}
|
||||
|
@@ -351,6 +351,11 @@ def TT_ReduceOp : TT_Op<"reduce", [NoSideEffect,
|
||||
|
||||
let assemblyFormat = "$operand attr-dict `:` type($operand) `->` type($result)";
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// This member function is marked static because we need to call it before the ReduceOp
|
||||
// is constructed, see the implementation of create_reduce in triton.cc.
|
||||
static bool withIndex(mlir::triton::RedOp redOp);
|
||||
}];
|
||||
}
|
||||
|
||||
//
|
||||
|
Reference in New Issue
Block a user