[Triton-MLIR][BACKEND] Add argmin / argmax implementation for ReduceOp (#918)

This commit is contained in:
Qingyi Liu
2022-11-28 14:59:27 +08:00
committed by GitHub
parent 04ec5deb41
commit 9d31998a9d
12 changed files with 341 additions and 75 deletions

View File

@@ -37,6 +37,55 @@ unsigned ReduceOpHelper::getThreadsReductionAxis() {
triton::gpu::getWarpsPerCTA(srcLayout)[axis];
}
SmallVector<unsigned> ReduceOpHelper::getScratchConfigBasic() {
auto axis = op.axis();
auto smemShape = convertType<unsigned>(getSrcShape());
smemShape[axis] = std::min(smemShape[axis], getThreadsReductionAxis());
return smemShape;
}
SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
auto axis = op.axis();
SmallVector<SmallVector<unsigned>> smemShapes(3);
/// shared memory block0
smemShapes[0] = convertType<unsigned>(getSrcShape());
smemShapes[0][axis] = getInterWarpSize();
/// FIXME(Qingyi): This size is actually larger than required.
/// shared memory block1:
auto mod = op.getOperation()->getParentOfType<ModuleOp>();
unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
smemShapes[1].push_back(numWarps * 32);
/// FIXME(Qingyi): This requirement is actually not necessary, because it is
/// always smaller than smemShapes[0] shared memory block2
smemShapes[2] = convertType<unsigned>(getSrcShape());
smemShapes[2].erase(smemShapes[2].begin() + axis);
return smemShapes;
}
unsigned ReduceOpHelper::getScratchSizeInBytes() {
unsigned elems = 0;
if (isFastReduction()) {
auto smemShapes = getScratchConfigsFast();
for (const auto &smemShape : smemShapes)
elems = std::max(elems, product<unsigned>(smemShape));
} else {
auto smemShape = getScratchConfigBasic();
elems = product<unsigned>(smemShape);
}
auto tensorType = op.operand().getType().cast<RankedTensorType>();
unsigned bytes = elems * tensorType.getElementTypeBitWidth() / 8;
if (triton::ReduceOp::withIndex(op.redOp()))
bytes += elems * sizeof(int32_t);
return bytes;
}
bool isSharedEncoding(Value value) {
auto type = value.getType();
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {