[Triton-MLIR][Backend] Fix number of warps and threads per warp when matrices are small (#917)

This commit is contained in:
Keren Zhou
2022-11-26 12:30:38 -08:00
committed by GitHub
parent f63be0e9b5
commit 35c9ec1103
7 changed files with 116 additions and 29 deletions

View File

@@ -8,6 +8,29 @@
namespace mlir {
/// Helper wrapped around a single `triton::ReduceOp`.
///
/// On construction the type of the op's operand is cast to
/// `RankedTensorType` and cached, so the accessors below can answer
/// shape/layout queries without re-deriving the type each time.
class ReduceOpHelper {
public:
  /// Stores `op` and caches the type of its operand as a ranked tensor type.
  explicit ReduceOpHelper(triton::ReduceOp op)
      : op(op), srcTy(op.operand().getType().cast<RankedTensorType>()) {}

  /// Shape reported by the cached operand type.
  ArrayRef<int64_t> getSrcShape() { return srcTy.getShape(); }

  /// Encoding attribute reported by the cached operand type.
  Attribute getSrcLayout() { return srcTy.getEncoding(); }

  // The queries below are only declared here; their definitions live outside
  // this view. NOTE(review): judging by the names they describe how the
  // reduction is split within and across warps -- confirm against the
  // implementation before relying on that.
  bool isFastReduction();
  unsigned getInterWarpSize();
  unsigned getIntraWarpSize();
  unsigned getThreadsReductionAxis();

private:
  // Declaration order matters: `op` is initialized before `srcTy`, matching
  // the order used in the constructor's initializer list.
  triton::ReduceOp op;
  RankedTensorType srcTy{};
};
bool isSharedEncoding(Value value);
bool maybeSharedAllocationOp(Operation *op);

View File

@@ -228,9 +228,11 @@ for
unsigned remainingLanes = 32;
unsigned remainingThreads = numWarps*32;
unsigned remainingWarps = numWarps;
unsigned prevLanes = 1;
unsigned prevWarps = 1;
SmallVector<unsigned, 4> threadsPerWarp(rank);
SmallVector<unsigned, 4> warpsPerCTA(rank);
for (int _dim = 0; _dim < rank; ++_dim) {
for (int _dim = 0; _dim < rank - 1; ++_dim) {
int i = order[_dim];
unsigned threadsPerCTA = std::clamp<unsigned>(remainingThreads, 1, shape[i] / sizePerThread[i]);
threadsPerWarp[i] = std::clamp<unsigned>(threadsPerCTA, 1, remainingLanes);
@@ -238,7 +240,12 @@ for
remainingWarps /= warpsPerCTA[i];
remainingLanes /= threadsPerWarp[i];
remainingThreads /= threadsPerCTA;
prevLanes *= threadsPerWarp[i];
prevWarps *= warpsPerCTA[i];
}
// Expand the last dimension to fill the remaining lanes and warps
threadsPerWarp[order[rank-1]] = 32 / prevLanes;
warpsPerCTA[order[rank-1]] = numWarps / prevWarps;
return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order);