[Triton-MLIR][Backend] Fix number of warps and threads per warp when matrices are small (#917)
This commit is contained in:
@@ -8,6 +8,29 @@
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class ReduceOpHelper {
|
||||
public:
|
||||
explicit ReduceOpHelper(triton::ReduceOp op) : op(op) {
|
||||
srcTy = op.operand().getType().cast<RankedTensorType>();
|
||||
}
|
||||
|
||||
ArrayRef<int64_t> getSrcShape() { return srcTy.getShape(); }
|
||||
|
||||
Attribute getSrcLayout() { return srcTy.getEncoding(); }
|
||||
|
||||
bool isFastReduction();
|
||||
|
||||
unsigned getInterWarpSize();
|
||||
|
||||
unsigned getIntraWarpSize();
|
||||
|
||||
unsigned getThreadsReductionAxis();
|
||||
|
||||
private:
|
||||
triton::ReduceOp op;
|
||||
RankedTensorType srcTy{};
|
||||
};
|
||||
|
||||
bool isSharedEncoding(Value value);
|
||||
|
||||
bool maybeSharedAllocationOp(Operation *op);
|
||||
|
@@ -228,9 +228,11 @@ for
|
||||
unsigned remainingLanes = 32;
|
||||
unsigned remainingThreads = numWarps*32;
|
||||
unsigned remainingWarps = numWarps;
|
||||
unsigned prevLanes = 1;
|
||||
unsigned prevWarps = 1;
|
||||
SmallVector<unsigned, 4> threadsPerWarp(rank);
|
||||
SmallVector<unsigned, 4> warpsPerCTA(rank);
|
||||
for (int _dim = 0; _dim < rank; ++_dim) {
|
||||
for (int _dim = 0; _dim < rank - 1; ++_dim) {
|
||||
int i = order[_dim];
|
||||
unsigned threadsPerCTA = std::clamp<unsigned>(remainingThreads, 1, shape[i] / sizePerThread[i]);
|
||||
threadsPerWarp[i] = std::clamp<unsigned>(threadsPerCTA, 1, remainingLanes);
|
||||
@@ -238,7 +240,12 @@ for
|
||||
remainingWarps /= warpsPerCTA[i];
|
||||
remainingLanes /= threadsPerWarp[i];
|
||||
remainingThreads /= threadsPerCTA;
|
||||
prevLanes *= threadsPerWarp[i];
|
||||
prevWarps *= warpsPerCTA[i];
|
||||
}
|
||||
// Expand the last dimension to fill the remaining lanes and warps
|
||||
threadsPerWarp[order[rank-1]] = 32 / prevLanes;
|
||||
warpsPerCTA[order[rank-1]] = numWarps / prevWarps;
|
||||
|
||||
return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order);
|
||||
|
||||
|
Reference in New Issue
Block a user