diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index 69ca5fc22..64dd0d76d 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -144,6 +144,10 @@ void llPrintf(StringRef msg, ValueRange args,
   LLVM::createIndexConstant(rewriter, loc, this->getTypeConverter(),           \
                             __VA_ARGS__)
 
+// Helper function
+#define tid_val() getThreadId(rewriter, loc)
+#define llprintf(fmt, ...) LLVM::llPrintf(fmt, {__VA_ARGS__}, rewriter)
+
 } // namespace LLVM
 } // namespace mlir
 
@@ -4686,9 +4690,9 @@ struct InsertSliceAsyncOpConversion
     auto threadsPerCTA = getThreadsPerCTA(srcBlockedLayout);
     auto inOrder = srcBlockedLayout.getOrder();
 
-    // If perPhase * maxPhase > threadsPerCTA, we need to swizzle over
-    // elements across phases. If perPhase * maxPhase <= threadsPerCTA,
-    // swizzle is not allowd
+    // If perPhase * maxPhase > threadsPerCTA, we will have elements
+    // that share the same tile indices. The index calculation will
+    // be cached.
     auto numSwizzleRows = std::max(
         (perPhase * maxPhase) / threadsPerCTA[inOrder[1]], 1);
     // A sharedLayout encoding has a "vec" parameter.
@@ -4727,12 +4731,14 @@ struct InsertSliceAsyncOpConversion
         // Example1:
         // outVec = 2, inVec = 2, minVec = 2
         // outVec = 2, inVec = 4, minVec = 2
-        // | [1 2] [3 4] ... [15 16] |
-        // | [3 4] [5 6] ... [1 2] |
+        // | [1 2] [3 4] [5 6] ... |
+        // | [3 4] [1 2] [7 8] ... |
+        // | [5 6] [7 8] [1 2] ... |
         // Example2:
         // outVec = 4, inVec = 2, minVec = 2
-        // | [1 2 3 4] [5 6 7 8] ... [13 14 15 16] |
-        // | [5 6 7 8] [9 10 11 12] ... [1 2 3 4] |
+        // | [1 2 3 4] [5 6 7 8] [9 10 11 12] ... |
+        // | [5 6 7 8] [1 2 3 4] [13 14 15 16] ... |
+        // | [9 10 11 12] [13 14 15 16] [1 2 3 4] ... |
         auto srcIdx = srcIndices[tileVecIdxRow * sizePerThread[inOrder[0]]];
         Value phase = urem(udiv(srcIdx[inOrder[1]], i32_val(perPhase)),
                            i32_val(maxPhase));
diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py
index 82a264dc2..f11c3bc09 100644
--- a/python/tutorials/03-matrix-multiplication.py
+++ b/python/tutorials/03-matrix-multiplication.py
@@ -156,7 +156,7 @@ import triton.language as tl
 
 @triton.autotune(
     configs=[
-        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
+        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
     ],
     key=['M', 'N', 'K'],
 )