[Triton-MLIR] Increase block size K to completely eliminate shared memory bank conflicts (#862)
This commit is contained in:
@@ -144,6 +144,10 @@ void llPrintf(StringRef msg, ValueRange args,
|
|||||||
LLVM::createIndexConstant(rewriter, loc, this->getTypeConverter(), \
|
LLVM::createIndexConstant(rewriter, loc, this->getTypeConverter(), \
|
||||||
__VA_ARGS__)
|
__VA_ARGS__)
|
||||||
|
|
||||||
|
// Helper function
|
||||||
|
#define tid_val() getThreadId(rewriter, loc)
|
||||||
|
#define llprintf(fmt, ...) LLVM::llPrintf(fmt, {__VA_ARGS__}, rewriter)
|
||||||
|
|
||||||
} // namespace LLVM
|
} // namespace LLVM
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
@@ -4686,9 +4690,9 @@ struct InsertSliceAsyncOpConversion
|
|||||||
auto threadsPerCTA = getThreadsPerCTA(srcBlockedLayout);
|
auto threadsPerCTA = getThreadsPerCTA(srcBlockedLayout);
|
||||||
auto inOrder = srcBlockedLayout.getOrder();
|
auto inOrder = srcBlockedLayout.getOrder();
|
||||||
|
|
||||||
// If perPhase * maxPhase > threadsPerCTA, we need to swizzle over
|
// If perPhase * maxPhase > threadsPerCTA, we will have elements
|
||||||
// elements across phases. If perPhase * maxPhase <= threadsPerCTA,
|
// that share the same tile indices. The index calculation will
|
||||||
// swizzle is not allowd
|
// be cached.
|
||||||
auto numSwizzleRows = std::max<unsigned>(
|
auto numSwizzleRows = std::max<unsigned>(
|
||||||
(perPhase * maxPhase) / threadsPerCTA[inOrder[1]], 1);
|
(perPhase * maxPhase) / threadsPerCTA[inOrder[1]], 1);
|
||||||
// A sharedLayout encoding has a "vec" parameter.
|
// A sharedLayout encoding has a "vec" parameter.
|
||||||
@@ -4727,12 +4731,14 @@ struct InsertSliceAsyncOpConversion
|
|||||||
// Example1:
|
// Example1:
|
||||||
// outVec = 2, inVec = 2, minVec = 2
|
// outVec = 2, inVec = 2, minVec = 2
|
||||||
// outVec = 2, inVec = 4, minVec = 2
|
// outVec = 2, inVec = 4, minVec = 2
|
||||||
// | [1 2] [3 4] ... [15 16] |
|
// | [1 2] [3 4] [5 6] ... |
|
||||||
// | [3 4] [5 6] ... [1 2] |
|
// | [3 4] [1 2] [7 8] ... |
|
||||||
|
// | [5 6] [7 8] [1 2] ... |
|
||||||
// Example2:
|
// Example2:
|
||||||
// outVec = 4, inVec = 2, minVec = 2
|
// outVec = 4, inVec = 2, minVec = 2
|
||||||
// | [1 2 3 4] [5 6 7 8] ... [13 14 15 16] |
|
// | [1 2 3 4] [5 6 7 8] [9 10 11 12] ... |
|
||||||
// | [5 6 7 8] [9 10 11 12] ... [1 2 3 4] |
|
// | [5 6 7 8] [1 2 3 4] [13 14 15 16] ... |
|
||||||
|
// | [9 10 11 12] [13 14 15 16] [1 2 3 4] ... |
|
||||||
auto srcIdx = srcIndices[tileVecIdxRow * sizePerThread[inOrder[0]]];
|
auto srcIdx = srcIndices[tileVecIdxRow * sizePerThread[inOrder[0]]];
|
||||||
Value phase = urem(udiv(srcIdx[inOrder[1]], i32_val(perPhase)),
|
Value phase = urem(udiv(srcIdx[inOrder[1]], i32_val(perPhase)),
|
||||||
i32_val(maxPhase));
|
i32_val(maxPhase));
|
||||||
|
@@ -156,7 +156,7 @@ import triton.language as tl
|
|||||||
|
|
||||||
@triton.autotune(
|
@triton.autotune(
|
||||||
configs=[
|
configs=[
|
||||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
|
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
|
||||||
],
|
],
|
||||||
key=['M', 'N', 'K'],
|
key=['M', 'N', 'K'],
|
||||||
)
|
)
|
||||||
|
Reference in New Issue
Block a user