From bb7008651a913b0115cd847644edcb8cb1dda307 Mon Sep 17 00:00:00 2001 From: goostavz <109190422+goostavz@users.noreply.github.com> Date: Thu, 27 Oct 2022 04:39:38 +0800 Subject: [PATCH] [Backend] Hacky fix of missing barrier in ConvertLayout blocked->shared (#803) Barrier should be set by a separate pass, but it seems like there may be some bugs --- .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 19 ++++++------ python/tests/test_gemm.py | 31 +++++++++++++++++-- 2 files changed, 38 insertions(+), 12 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index 74461ebe7..2fc0505c9 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -91,7 +91,7 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width, #define store(val, ptr) rewriter.create(loc, val, ptr) #define select(...) rewriter.create(loc, __VA_ARGS__) #define address_of(...) rewriter.create(loc, __VA_ARGS__) -#define barrier rewriter.create(loc) +#define barrier() rewriter.create(loc) #define undef(...) rewriter.create(loc, __VA_ARGS__) #define i32_ty rewriter.getIntegerType(32) #define vec_ty(type, num) VectorType::get(num, type) @@ -1877,7 +1877,7 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed( for (unsigned repId = 0; repId < accumNumReplicates; ++repId) { auto multiDimRepId = getMultiDimIndex(repId, numReplicates); - barrier; + barrier(); if (srcLayout.isa() || srcLayout.isa()) { processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep, @@ -1887,7 +1887,7 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed( assert(0 && "ConvertLayout with input layout not implemented"); return failure(); } - barrier; + barrier(); if (dstLayout.isa() || dstLayout.isa()) { processReplica(loc, rewriter, /*stNotRd*/ false, dstTy, outNumCTAsEachRep, @@ -1963,6 +1963,11 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared( smemBase = bitcast(elemPtrTy, smemBase); unsigned numWordsEachRep = product(wordsInEachRep); SmallVector wordVecs(numWordsEachRep); + // TODO: We should get less barriers if it is handled by membar pass + // instead of the backend, since the later can only handle it in + // the most conservative way. However just keep for now and revisit + // in the future in case necessary. + barrier(); for (unsigned i = 0; i < numElems; ++i) { if (i % srcAccumSizeInThreads == 0) { // start of a replication @@ -2016,8 +2021,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared( } } } - // TODO: double confirm if the Barrier is necessary here - barrier; + barrier(); rewriter.replaceOp(op, smemBase); return success(); } @@ -3057,11 +3061,6 @@ struct MMA16816ConversionHelper { for (unsigned n = 0; n < numRepN; ++n) callMma(2 * m, n, 2 * k); - // NOTE, the barrier here is a temporary trick making the gemm with a - // k-forloop pass the precision test, or it will fail. - // TODO[Superjomn]: Fix with a more general and performance-friendly way. - barrier; - // replace with new packed result Type structTy = LLVM::LLVMStructType::getLiteral( ctx, SmallVector(fc.size(), type::f32Ty(ctx))); diff --git a/python/tests/test_gemm.py b/python/tests/test_gemm.py index 12644cbd3..e1a27ec74 100644 --- a/python/tests/test_gemm.py +++ b/python/tests/test_gemm.py @@ -83,6 +83,23 @@ def matmul_kernel( # TODO: DotConversion in TritonGPUToLLVM cannot support non-splat C for the moment +def get_variant_golden(a, b): + SIZE_M = a.shape[0] + SIZE_K = a.shape[1] + SIZE_N = b.shape[1] + assert a.shape[1] == b.shape[0] + zero_M_K = torch.zeros((SIZE_M, SIZE_K)).cuda() + zero_3M_K = torch.zeros((3 * SIZE_M, SIZE_K)).cuda() + zero_K_N = torch.zeros((SIZE_K, SIZE_N)).cuda() + zero_3K_N = torch.zeros((3 * SIZE_K, SIZE_N)).cuda() + a_padded = torch.cat((a, zero_M_K, zero_M_K), 0) + a_padded = torch.cat((a_padded, zero_3M_K, zero_3M_K), 1) + b_padded = torch.cat((b, zero_K_N, zero_K_N), 0) + b_padded = torch.cat((b_padded, zero_3K_N, zero_3K_N), 1) + c_padded = torch.matmul(a_padded, b_padded) + return c_padded[:SIZE_M, :SIZE_N] + + @pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K', [ # Non-forloop [64, 32, 64, 4, 64, 32, 64], @@ -94,8 +111,8 @@ def matmul_kernel( [32, 64, 128, 4, 32, 64, 32], [32, 128, 256, 4, 32, 128, 64], [64, 128, 64, 4, 64, 128, 32], - [128, 128, 64, 4, 128, 128, 32], [64, 64, 128, 4, 64, 64, 32], + [128, 128, 64, 4, 128, 128, 32], [128, 128, 128, 4, 128, 128, 32], [128, 128, 256, 4, 128, 128, 64], [128, 256, 128, 4, 128, 256, 32], @@ -115,5 +132,15 @@ def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLO BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, num_warps=NUM_WARPS) golden = torch.matmul(a, b) + + # It's not easy to get a proper error threshold in different size + # Here the gemm calculation is padded to a different size in order to get + # a variant version of the golden result. And the error between golden and + # golden_variant provide reference on selecting the proper rtol / atol. + golden_variant = get_variant_golden(a, b) + golden_diff = golden - golden_variant + golden_abs_err = torch.max(torch.abs(golden_diff)).item() + golden_rel_err = torch.max(torch.abs(golden_diff / golden)).item() + torch.set_printoptions(profile="full") - assert_close(c, golden, rtol=1e-3, atol=1e-3, check_dtype=False) + assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err), check_dtype=False)