[Triton-MLIR][BACKEND] Add argmin / argmax implementation for ReduceOp (#918)
This commit is contained in:
@@ -37,6 +37,55 @@ unsigned ReduceOpHelper::getThreadsReductionAxis() {
|
||||
triton::gpu::getWarpsPerCTA(srcLayout)[axis];
|
||||
}
|
||||
|
||||
SmallVector<unsigned> ReduceOpHelper::getScratchConfigBasic() {
  /// Scratch-buffer shape for the basic (non-fast) reduction path.
  /// It mirrors the source tensor shape, except along the reduction axis,
  /// where no more than one slot per participating thread is ever needed.
  auto reduceAxis = op.axis();
  SmallVector<unsigned> scratchShape = convertType<unsigned>(getSrcShape());
  unsigned threadsAlongAxis = getThreadsReductionAxis();
  if (scratchShape[reduceAxis] > threadsAlongAxis)
    scratchShape[reduceAxis] = threadsAlongAxis;
  return scratchShape;
}
|
||||
|
||||
SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
  /// Scratch-buffer shapes for the fast reduction path.
  /// Three candidate shared-memory blocks are returned; the caller decides
  /// how they are used / sized against each other.
  auto reduceAxis = op.axis();
  SmallVector<SmallVector<unsigned>> shapes(3);

  /// shared memory block0: source shape with the reduction axis shrunk to
  /// the inter-warp reduction width.
  shapes[0] = convertType<unsigned>(getSrcShape());
  shapes[0][reduceAxis] = getInterWarpSize();

  /// FIXME(Qingyi): This size is actually larger than required.
  /// shared memory block1: one flat buffer of numWarps * 32 elements.
  auto mod = op.getOperation()->getParentOfType<ModuleOp>();
  unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
  shapes[1] = {numWarps * 32};

  /// FIXME(Qingyi): This requirement is actually not necessary, because it is
  /// always smaller than block0.
  /// shared memory block2: source shape with the reduction axis removed.
  shapes[2] = convertType<unsigned>(getSrcShape());
  shapes[2].erase(shapes[2].begin() + reduceAxis);

  return shapes;
}
|
||||
|
||||
unsigned ReduceOpHelper::getScratchSizeInBytes() {
|
||||
unsigned elems = 0;
|
||||
if (isFastReduction()) {
|
||||
auto smemShapes = getScratchConfigsFast();
|
||||
for (const auto &smemShape : smemShapes)
|
||||
elems = std::max(elems, product<unsigned>(smemShape));
|
||||
} else {
|
||||
auto smemShape = getScratchConfigBasic();
|
||||
elems = product<unsigned>(smemShape);
|
||||
}
|
||||
|
||||
auto tensorType = op.operand().getType().cast<RankedTensorType>();
|
||||
unsigned bytes = elems * tensorType.getElementTypeBitWidth() / 8;
|
||||
|
||||
if (triton::ReduceOp::withIndex(op.redOp()))
|
||||
bytes += elems * sizeof(int32_t);
|
||||
|
||||
return bytes;
|
||||
}
|
||||
|
||||
bool isSharedEncoding(Value value) {
|
||||
auto type = value.getType();
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
|
||||
|
Reference in New Issue
Block a user