[Triton-MLIR][Backend] Fix number of warps and threads per warp when matrices are small (#917)

Keren Zhou
2022-11-26 12:30:38 -08:00
committed by GitHub
parent f63be0e9b5
commit 35c9ec1103
7 changed files with 116 additions and 29 deletions
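The changes below route the reduce lowering through a new ReduceOpHelper so that, when the reduced dimension is smaller than the launch configuration, the intra-warp and inter-warp sizes are clamped to the data that actually exists. A minimal standalone sketch of that arithmetic (variable names and concrete numbers are illustrative, not Triton's API):

#include <algorithm>
#include <cstdio>

int main() {
  unsigned reduceDimSize = 16;      // e.g. reducing a 16-wide axis of a small matrix
  unsigned threadsPerWarpAxis = 32; // threads a warp contributes along that axis
  unsigned warpsPerCTAAxis = 4;     // warps the CTA contributes along that axis

  // Before: sizes were taken straight from the layout, overshooting the data.
  unsigned oldIntra = threadsPerWarpAxis; // 32 threads for only 16 elements
  unsigned oldInter = warpsPerCTAAxis;    // 4 warps

  // After: clamp by how much data there is to reduce (mirrors the new helpers).
  unsigned newIntra = std::min(reduceDimSize, threadsPerWarpAxis);         // 16
  unsigned newInter = std::min(reduceDimSize / newIntra, warpsPerCTAAxis); // 1

  std::printf("intra-warp: %u -> %u, inter-warp: %u -> %u\n", oldIntra, newIntra,
              oldInter, newInter);
}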

View File

@@ -89,24 +89,19 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
 }
 SmallVector<unsigned> getScratchConfigForReduce(triton::ReduceOp op) {
-  auto srcTy = op.operand().getType().cast<RankedTensorType>();
-  auto srcLayout = srcTy.getEncoding();
-  auto srcShape = srcTy.getShape();
-  auto axis = op.axis();
-  bool fastReduce = axis == getOrder(srcLayout)[0];
+  ReduceOpHelper helper(op);
   SmallVector<unsigned> smemShape;
+  auto srcShape = helper.getSrcShape();
   for (auto d : srcShape)
     smemShape.push_back(d);
-  if (fastReduce) {
-    unsigned sizeInterWarps = gpu::getWarpsPerCTA(srcLayout)[axis];
-    smemShape[axis] = sizeInterWarps;
+  auto axis = op.axis();
+  if (helper.isFastReduction()) {
+    smemShape[axis] = helper.getInterWarpSize();
   } else {
-    unsigned threadsPerCTAAxis = gpu::getThreadsPerWarp(srcLayout)[axis] *
-                                 gpu::getWarpsPerCTA(srcLayout)[axis];
-    smemShape[axis] = threadsPerCTAAxis;
+    smemShape[axis] =
+        std::min(smemShape[axis], helper.getThreadsReductionAxis());
   }
   return smemShape;
@@ -181,8 +176,7 @@ private:
     // TODO(Keren): Reduce with index is not supported yet.
     auto value = op->getOperand(0);
     if (auto tensorType = value.getType().dyn_cast<RankedTensorType>()) {
-      auto srcLayout = tensorType.getEncoding();
-      bool fastReduce = reduceOp.axis() == getOrder(srcLayout)[0];
+      bool fastReduce = ReduceOpHelper(reduceOp).isFastReduction();
       auto smemShape = getScratchConfigForReduce(reduceOp);
       unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
                                        std::multiplies{});
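With the clamped shape, the scratch element count computed by std::accumulate shrinks accordingly for small tensors. A rough standalone illustration of the non-fast branch, using made-up shapes:

#include <algorithm>
#include <cstdio>
#include <functional>
#include <numeric>
#include <vector>

int main() {
  std::vector<unsigned> smemShape = {4, 8}; // source shape, copied as the starting scratch shape
  unsigned axis = 1;                        // reduced axis
  unsigned threadsReductionAxis = 32 * 4;   // threadsPerWarp[axis] * warpsPerCTA[axis]

  // Previously the scratch extent along `axis` was set to threadsReductionAxis
  // (128); now it is clamped to the data that exists along that axis.
  smemShape[axis] = std::min(smemShape[axis], threadsReductionAxis); // stays 8

  unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1u,
                                   std::multiplies<unsigned>{});
  std::printf("scratch elements with clamping: %u (unclamped would be %u)\n",
              elems, 4u * threadsReductionAxis);
}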

View File

@@ -5,6 +5,38 @@
 namespace mlir {

+bool ReduceOpHelper::isFastReduction() {
+  auto srcLayout = srcTy.getEncoding();
+  auto axis = op.axis();
+  return axis == triton::gpu::getOrder(srcLayout)[0];
+}
+
+unsigned ReduceOpHelper::getInterWarpSize() {
+  auto srcLayout = srcTy.getEncoding();
+  auto srcShape = srcTy.getShape();
+  auto axis = op.axis();
+  auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
+  unsigned sizeIntraWarps = getIntraWarpSize();
+  return std::min(srcReduceDimSize / sizeIntraWarps,
+                  triton::gpu::getWarpsPerCTA(srcLayout)[axis]);
+}
+
+unsigned ReduceOpHelper::getIntraWarpSize() {
+  auto srcLayout = srcTy.getEncoding();
+  auto srcShape = srcTy.getShape();
+  auto axis = op.axis();
+  auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
+  return std::min(srcReduceDimSize,
+                  triton::gpu::getThreadsPerWarp(srcLayout)[axis]);
+}
+
+unsigned ReduceOpHelper::getThreadsReductionAxis() {
+  auto srcLayout = srcTy.getEncoding();
+  auto axis = op.axis();
+  return triton::gpu::getThreadsPerWarp(srcLayout)[axis] *
+         triton::gpu::getWarpsPerCTA(srcLayout)[axis];
+}
+
 bool isSharedEncoding(Value value) {
   auto type = value.getType();
   if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
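isFastReduction only checks whether the reduced axis is the fastest-varying dimension of the layout order; the clamping itself happens in the two size helpers above. A small sketch of the order check, using a hypothetical two-dimensional order:

#include <array>
#include <cstdio>
#include <initializer_list>

int main() {
  // Hypothetical layout order: dimension 1 is the fastest-varying (contiguous) one.
  std::array<unsigned, 2> order = {1, 0};

  for (unsigned axis : {0u, 1u}) {
    bool fast = (axis == order[0]); // mirrors ReduceOpHelper::isFastReduction()
    std::printf("reduce along axis %u -> %s path\n", axis, fast ? "fast" : "basic");
  }
}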

View File

@@ -1563,9 +1563,7 @@ private:
 LogicalResult
 ReduceOpConversion::matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
                                     ConversionPatternRewriter &rewriter) const {
-  auto srcTy = op.operand().getType().cast<RankedTensorType>();
-  auto srcLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
-  if (op.axis() == srcLayout.getOrder()[0])
+  if (ReduceOpHelper(op).isFastReduction())
     return matchAndRewriteFast(op, adaptor, rewriter);
   return matchAndRewriteBasic(op, adaptor, rewriter);
 }
@@ -1763,10 +1761,11 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
   Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
   smemBase = bitcast(smemBase, elemPtrTy);
-  auto order = getOrder(srcLayout);
-  unsigned sizeIntraWarps = threadsPerWarp[axis];
-  unsigned sizeInterWarps = warpsPerCTA[axis];
+  ReduceOpHelper helper(op);
+  unsigned sizeIntraWarps = helper.getIntraWarpSize();
+  unsigned sizeInterWarps = helper.getInterWarpSize();
+  auto order = getOrder(srcLayout);
   unsigned srcElems = getElemsPerThread(srcTy);
   auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
   auto srcValues = getElementsFromStruct(loc, adaptor.operand(), rewriter);
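The clamped sizes bound how far the reduction loops in matchAndRewriteFast iterate. A back-of-the-envelope sketch, assuming a tree-style intra-warp reduction that halves the stride each step (the loop shape here is illustrative, not the exact lowering):

#include <cstdio>

int main() {
  unsigned sizeIntraWarps = 16; // helper.getIntraWarpSize() for a 16-element reduce dim
  unsigned sizeInterWarps = 1;  // helper.getInterWarpSize()

  // Count tree-reduction steps: halve the stride until one value per warp remains.
  unsigned steps = 0;
  for (unsigned n = sizeIntraWarps / 2; n > 0; n >>= 1)
    ++steps;

  std::printf("intra-warp reduction steps: %u, warps combining via smem: %u\n",
              steps, sizeInterWarps);
}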