[Triton-MLIR] Increase block size K to completely eliminate shared memory bank conflicts (#862)
This commit is contained in:
@@ -144,6 +144,10 @@ void llPrintf(StringRef msg, ValueRange args,
|
||||
LLVM::createIndexConstant(rewriter, loc, this->getTypeConverter(), \
|
||||
__VA_ARGS__)
|
||||
|
||||
// Helper function
|
||||
// Shorthand for getThreadId(rewriter, loc). The names `rewriter` and `loc`
// are not macro parameters: they are resolved in the scope where the macro
// is expanded, so both must be visible at the use site.
#define tid_val() getThreadId(rewriter, loc)
|
||||
// Shorthand for LLVM::llPrintf: passes `fmt` as the message and wraps the
// variadic arguments in a braced list for the argument-range parameter.
// `rewriter` is not a macro parameter; it is resolved in the scope where the
// macro is expanded.
#define llprintf(fmt, ...) LLVM::llPrintf(fmt, {__VA_ARGS__}, rewriter)
|
||||
|
||||
} // namespace LLVM
|
||||
} // namespace mlir
|
||||
|
||||
@@ -4686,9 +4690,9 @@ struct InsertSliceAsyncOpConversion
|
||||
auto threadsPerCTA = getThreadsPerCTA(srcBlockedLayout);
|
||||
auto inOrder = srcBlockedLayout.getOrder();
|
||||
|
||||
// If perPhase * maxPhase > threadsPerCTA, we need to swizzle over
|
||||
// elements across phases. If perPhase * maxPhase <= threadsPerCTA,
|
||||
// swizzle is not allowed
|
||||
// If perPhase * maxPhase > threadsPerCTA, we will have elements
|
||||
// that share the same tile indices. The index calculation will
|
||||
// be cached.
|
||||
auto numSwizzleRows = std::max<unsigned>(
|
||||
(perPhase * maxPhase) / threadsPerCTA[inOrder[1]], 1);
|
||||
// A sharedLayout encoding has a "vec" parameter.
|
||||
@@ -4727,12 +4731,14 @@ struct InsertSliceAsyncOpConversion
|
||||
// Example1:
|
||||
// outVec = 2, inVec = 2, minVec = 2
|
||||
// outVec = 2, inVec = 4, minVec = 2
|
||||
// | [1 2] [3 4] ... [15 16] |
|
||||
// | [3 4] [5 6] ... [1 2] |
|
||||
// | [1 2] [3 4] [5 6] ... |
|
||||
// | [3 4] [1 2] [7 8] ... |
|
||||
// | [5 6] [7 8] [1 2] ... |
|
||||
// Example2:
|
||||
// outVec = 4, inVec = 2, minVec = 2
|
||||
// | [1 2 3 4] [5 6 7 8] ... [13 14 15 16] |
|
||||
// | [5 6 7 8] [9 10 11 12] ... [1 2 3 4] |
|
||||
// | [1 2 3 4] [5 6 7 8] [9 10 11 12] ... |
|
||||
// | [5 6 7 8] [1 2 3 4] [13 14 15 16] ... |
|
||||
// | [9 10 11 12] [13 14 15 16] [1 2 3 4] ... |
|
||||
auto srcIdx = srcIndices[tileVecIdxRow * sizePerThread[inOrder[0]]];
|
||||
Value phase = urem(udiv(srcIdx[inOrder[1]], i32_val(perPhase)),
|
||||
i32_val(maxPhase));
|
||||
|
@@ -156,7 +156,7 @@ import triton.language as tl
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
|
||||
],
|
||||
key=['M', 'N', 'K'],
|
||||
)
|
||||
|
Reference in New Issue
Block a user