[Triton-MLIR] Increase block size K to completely eliminate shared memory bank conflicts (#862)

This commit is contained in:
Keren Zhou
2022-11-08 17:39:23 -08:00
committed by GitHub
parent 080b4addf8
commit 2da71b2aaa
2 changed files with 14 additions and 8 deletions

View File

@@ -144,6 +144,10 @@ void llPrintf(StringRef msg, ValueRange args,
LLVM::createIndexConstant(rewriter, loc, this->getTypeConverter(), \
__VA_ARGS__)
// Helper macros. Their expansions name `rewriter` (and, for tid_val, `loc`)
// directly, so those identifiers must be in scope wherever they are used.
// tid_val(): shorthand for getThreadId(rewriter, loc).
#define tid_val() getThreadId(rewriter, loc)
// llprintf(fmt, ...): forwards fmt, the variadic arguments wrapped in a
// braced list, and `rewriter` to LLVM::llPrintf.
#define llprintf(fmt, ...) LLVM::llPrintf(fmt, {__VA_ARGS__}, rewriter)
} // namespace LLVM
} // namespace mlir
@@ -4686,9 +4690,9 @@ struct InsertSliceAsyncOpConversion
auto threadsPerCTA = getThreadsPerCTA(srcBlockedLayout);
auto inOrder = srcBlockedLayout.getOrder();
// If perPhase * maxPhase > threadsPerCTA, we need to swizzle over
// elements across phases. If perPhase * maxPhase <= threadsPerCTA,
// swizzle is not allowed
// If perPhase * maxPhase > threadsPerCTA, we will have elements
// that share the same tile indices. The index calculation will
// be cached.
auto numSwizzleRows = std::max<unsigned>(
(perPhase * maxPhase) / threadsPerCTA[inOrder[1]], 1);
// A sharedLayout encoding has a "vec" parameter.
@@ -4727,12 +4731,14 @@ struct InsertSliceAsyncOpConversion
// Example1:
// outVec = 2, inVec = 2, minVec = 2
// outVec = 2, inVec = 4, minVec = 2
// | [1 2] [3 4] ... [15 16] |
// | [3 4] [5 6] ... [1 2] |
// | [1 2] [3 4] [5 6] ... |
// | [3 4] [1 2] [7 8] ... |
// | [5 6] [7 8] [1 2] ... |
// Example2:
// outVec = 4, inVec = 2, minVec = 2
// | [1 2 3 4] [5 6 7 8] ... [13 14 15 16] |
// | [5 6 7 8] [9 10 11 12] ... [1 2 3 4] |
// | [1 2 3 4] [5 6 7 8] [9 10 11 12] ... |
// | [5 6 7 8] [1 2 3 4] [13 14 15 16] ... |
// | [9 10 11 12] [13 14 15 16] [1 2 3 4] ... |
auto srcIdx = srcIndices[tileVecIdxRow * sizePerThread[inOrder[0]]];
Value phase = urem(udiv(srcIdx[inOrder[1]], i32_val(perPhase)),
i32_val(maxPhase));

View File

@@ -156,7 +156,7 @@ import triton.language as tl
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
],
key=['M', 'N', 'K'],
)