[Triton-MLIR][Backend] Fix number of warps and threads per warp when matrices are small (#917)
This commit is contained in:
@@ -1563,9 +1563,7 @@ private:
|
||||
LogicalResult
|
||||
ReduceOpConversion::matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto srcTy = op.operand().getType().cast<RankedTensorType>();
|
||||
auto srcLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
|
||||
if (op.axis() == srcLayout.getOrder()[0])
|
||||
if (ReduceOpHelper(op).isFastReduction())
|
||||
return matchAndRewriteFast(op, adaptor, rewriter);
|
||||
return matchAndRewriteBasic(op, adaptor, rewriter);
|
||||
}
|
||||
@@ -1763,10 +1761,11 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
|
||||
Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
|
||||
smemBase = bitcast(smemBase, elemPtrTy);
|
||||
|
||||
auto order = getOrder(srcLayout);
|
||||
unsigned sizeIntraWarps = threadsPerWarp[axis];
|
||||
unsigned sizeInterWarps = warpsPerCTA[axis];
|
||||
ReduceOpHelper helper(op);
|
||||
unsigned sizeIntraWarps = helper.getIntraWarpSize();
|
||||
unsigned sizeInterWarps = helper.getInterWarpSize();
|
||||
|
||||
auto order = getOrder(srcLayout);
|
||||
unsigned srcElems = getElemsPerThread(srcTy);
|
||||
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
|
||||
auto srcValues = getElementsFromStruct(loc, adaptor.operand(), rewriter);
|
||||
|
Reference in New Issue
Block a user