From f2fcaeabf3de72d4ecd3a117c316a071483d5426 Mon Sep 17 00:00:00 2001
From: Keren Zhou
Date: Sat, 3 Dec 2022 11:14:12 -0800
Subject: [PATCH] [BACKEND] Support dot op when the output is mma encoding and
 allowtf32 is true (#937)

---
 lib/Analysis/Membar.cpp                     |  8 +-
 lib/Conversion/TritonGPUToLLVM/DotHelpers.h | 16 ++--
 .../TritonGPUToLLVM/TritonGPUToLLVM.cpp     | 78 +++++++-------
 python/tests/test_gemm.py                   | 36 +++++----
 test/Conversion/tritongpu_to_llvm.mlir      | 39 ++++++++++
 5 files changed, 105 insertions(+), 72 deletions(-)

diff --git a/lib/Analysis/Membar.cpp b/lib/Analysis/Membar.cpp
index 68aebdbd1..88dd9165d 100644
--- a/lib/Analysis/Membar.cpp
+++ b/lib/Analysis/Membar.cpp
@@ -81,9 +81,11 @@ void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo,
     return;
   }
 
-  if (isa<triton::gpu::AsyncWaitOp>(op)) {
-    // If the current op is an async wait, we insert a barrier op and sync
-    // previous reads and writes.
+  if (isa<triton::gpu::AsyncWaitOp>(op) &&
+      !isa<gpu::BarrierOp>(op->getNextNode())) {
+    // If the current op is an async wait and the next op is not a barrier, we
+    // insert a barrier op and sync.
+    regionInfo->sync();
     OpBuilder::InsertionGuard g(*builder);
     builder->setInsertionPointAfter(op);
     builder->create<gpu::BarrierOp>(op->getLoc());
diff --git a/lib/Conversion/TritonGPUToLLVM/DotHelpers.h b/lib/Conversion/TritonGPUToLLVM/DotHelpers.h
index d08a83918..728857d39 100644
--- a/lib/Conversion/TritonGPUToLLVM/DotHelpers.h
+++ b/lib/Conversion/TritonGPUToLLVM/DotHelpers.h
@@ -708,19 +708,19 @@ public:
     Type elemTy = type::f32Ty(ctx);
     Type elemPtrTy = ptr_ty(elemTy);
     if (kOrder == 1) {
-      elems[0] = load(gep(elemPtrTy, ptr, i32_val(sOffsetElem)));
-      elems[1] = load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem)));
+      elems[0] = load(gep(elemPtrTy, ptr, sOffsetElemVal));
+      elems[1] = load(gep(elemPtrTy, ptr2, sOffsetElemVal));
       elems[2] =
-          load(gep(elemPtrTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
+          load(gep(elemPtrTy, ptr, sOffsetArrElemVal));
       elems[3] =
-          load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
+          load(gep(elemPtrTy, ptr2, sOffsetArrElemVal));
     } else {
-      elems[0] = load(gep(elemPtrTy, ptr, i32_val(sOffsetElem)));
-      elems[2] = load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem)));
+      elems[0] = load(gep(elemPtrTy, ptr, sOffsetElemVal));
+      elems[2] = load(gep(elemPtrTy, ptr2, sOffsetElemVal));
       elems[1] =
-          load(gep(elemPtrTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
+          load(gep(elemPtrTy, ptr, sOffsetArrElemVal));
       elems[3] =
-          load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
+          load(gep(elemPtrTy, ptr2, sOffsetArrElemVal));
     }
 
     return {elems[0], elems[1], elems[2], elems[3]};
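A note on TF32, which motivates the changes in this patch: TF32 keeps float32's 8 exponent bits but truncates the mantissa from 23 to 10 explicit bits, which is why an f32 dot may take the tensor-core mma path on mma version >= 2 at reduced precision. A minimal Python sketch, not part of this patch, that emulates TF32 by truncation (the hardware rounds to nearest, but truncation shows the same scale of error):

    import struct

    def to_tf32(x):
        # Emulate TF32 by zeroing the 13 low mantissa bits of a float32.
        bits = struct.unpack('<I', struct.pack('<f', x))[0]
        bits &= ~((1 << 13) - 1)
        return struct.unpack('<f', struct.pack('<I', bits))[0]

    print(to_tf32(1.0000001))  # prints 1.0: the increment is below TF32 resolution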
diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index 84315f963..0a290fff1 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -3327,10 +3327,10 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
   // We cannot get both operand types here (in the TypeConverter), so we assume
   // the types of the two operands are identical.
   // TODO[Superjomn]: Find a better way to implement it.
-  static bool isDotHMMA(TensorType operand, bool allowTF32, int mmaVersion) {
+  static bool isDotHMMA(TensorType operand, int mmaVersion) {
     auto elemTy = operand.getElementType();
     return elemTy.isF16() || elemTy.isBF16() ||
-           (elemTy.isF32() && allowTF32 && mmaVersion >= 2) ||
+           (elemTy.isF32() && mmaVersion >= 2) ||
            (elemTy.isInteger(8) && mmaVersion >= 2);
   }
 
@@ -3354,11 +3354,7 @@ Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
   Value src = op.src();
   Value dst = op.result();
   auto dstTensorTy = dst.getType().cast<RankedTensorType>();
-  // TODO[Superjomn]: allowTF32 is not accessible here for it is an attribute of
-  // an Op instance.
-  bool allowTF32 = false;
-  bool isHMMA = DotOpConversion::isDotHMMA(dstTensorTy, allowTF32,
-                                           mmaLayout.getVersion());
+  bool isHMMA = DotOpConversion::isDotHMMA(dstTensorTy, mmaLayout.getVersion());
 
   auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.src(), rewriter);
   Value res;
@@ -3421,25 +3417,16 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
   } else if (auto blockedLayout =
                  dotOperandLayout.getParent()
                      .dyn_cast_or_null<BlockedEncodingAttr>()) {
-    // TODO[Superjomn]: the allowTF32 is not available in ConvertLayoutOp for it
-    // is an attribute of DotOp.
-    bool allowTF32 = false;
-    bool isFMADot = dstTensorTy.getElementType().isF32() && !allowTF32;
-    if (isFMADot) {
-      auto dotOpLayout =
-          dstTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
-      auto blockedLayout = dotOpLayout.getParent().cast<BlockedEncodingAttr>();
-      DotOpFMAConversionHelper helper(blockedLayout);
-      auto thread = getThreadId(rewriter, loc);
-      if (dotOpLayout.getOpIdx() == 0) { // $a
-        res = helper.loadA(src, adaptor.src(), blockedLayout, thread, loc,
-                           rewriter);
-      } else { // $b
-        res = helper.loadB(src, adaptor.src(), blockedLayout, thread, loc,
-                           rewriter);
-      }
-    } else
-      assert(false && "Unsupported dot operand layout found");
+    auto dotOpLayout = dstTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
+    DotOpFMAConversionHelper helper(blockedLayout);
+    auto thread = getThreadId(rewriter, loc);
+    if (dotOpLayout.getOpIdx() == 0) { // $a
+      res = helper.loadA(src, adaptor.src(), blockedLayout, thread, loc,
+                         rewriter);
+    } else { // $b
+      res = helper.loadB(src, adaptor.src(), blockedLayout, thread, loc,
+                         rewriter);
+    }
   } else {
    assert(false && "Unsupported dot operand layout found");
  }
@@ -3805,13 +3792,6 @@ public:
     auto ctx = type.getContext();
     Attribute layout = type.getEncoding();
     auto shape = type.getShape();
-
-    // TODO[Keren, Superjomn]: fix it, allowTF32 is not accessible here for it
-    // is bound to an Op instance.
-    bool allowTF32 = false;
-    bool isFMADot = type.getElementType().isF32() && !allowTF32 &&
-                    layout.dyn_cast_or_null<DotOperandEncodingAttr>();
-
     if (layout &&
         (layout.isa<BlockedEncodingAttr>() || layout.isa<SliceEncodingAttr>() ||
          layout.isa<MmaEncodingAttr>())) {
@@ -3835,37 +3815,39 @@
       return LLVM::LLVMStructType::getLiteral(ctx, types);
     } else if (auto dotOpLayout =
                    layout.dyn_cast_or_null<DotOperandEncodingAttr>()) {
-      if (isFMADot) { // for parent is blocked layout
+      if (dotOpLayout.getParent()
+              .isa<BlockedEncodingAttr>()) { // for parent is blocked layout
         int numElemsPerThread =
             DotOpFMAConversionHelper::getNumElemsPerThread(shape, dotOpLayout);
         return LLVM::LLVMStructType::getLiteral(
             ctx, SmallVector<Type>(numElemsPerThread, type::f32Ty(ctx)));
-
       } else { // for parent is MMA layout
         auto mmaLayout = dotOpLayout.getParent().cast<MmaEncodingAttr>();
         auto wpt = mmaLayout.getWarpsPerCTA();
         Type elemTy = convertType(type.getElementType());
-        auto vecSize = 1;
-        if (elemTy.getIntOrFloatBitWidth() == 16) {
-          vecSize = 2;
-        } else if (elemTy.getIntOrFloatBitWidth() == 8) {
-          vecSize = 4;
-        } else {
-          assert(false && "Unsupported element type");
-        }
-        Type vecTy = vec_ty(elemTy, vecSize);
         if (mmaLayout.getVersion() == 2) {
+          const llvm::DenseMap<int, Type> targetTyMap = {
+              {32, elemTy},
+              {16, vec_ty(elemTy, 2)},
+              {8, vec_ty(elemTy, 4)},
+          };
+          Type targetTy;
+          if (targetTyMap.count(elemTy.getIntOrFloatBitWidth())) {
+            targetTy = targetTyMap.lookup(elemTy.getIntOrFloatBitWidth());
+          } else {
+            assert(false && "Unsupported element type");
+          }
           if (dotOpLayout.getOpIdx() == 0) { // $a
             int elems =
                 MMA16816ConversionHelper::getANumElemsPerThread(type, wpt[0]);
             return LLVM::LLVMStructType::getLiteral(
-                ctx, SmallVector<Type>(elems, vecTy));
+                ctx, SmallVector<Type>(elems, targetTy));
           }
           if (dotOpLayout.getOpIdx() == 1) { // $b
             int elems =
                 MMA16816ConversionHelper::getBNumElemsPerThread(type, wpt[1]);
-            return struct_ty(SmallVector<Type>(elems, vecTy));
+            return struct_ty(SmallVector<Type>(elems, targetTy));
           }
         }
@@ -3995,10 +3977,10 @@ struct InsertSliceAsyncOpConversion
     // %other
     SmallVector<Value> otherElems;
     if (llOther) {
-      // TODO(Keren): support "other" tensor.
+      // FIXME(Keren): always assume other is 0 for now
       // It's not necessary for now because the pipeline pass will skip
       // generating insert_slice_async if the load op has any "other" tensor.
-      assert(false && "insert_slice_async: Other value not supported yet");
+      // assert(false && "insert_slice_async: Other value not supported yet");
       otherElems = getLLVMElems(other, llOther, rewriter, loc);
       assert(srcElems.size() == otherElems.size());
     }
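The targetTyMap introduced above encodes the operand packing rule for mma v2: each operand register is 32 bits wide, so f32/tf32 elements occupy one register each, f16/bf16 elements pack in pairs, and i8 elements in quadruples. A hypothetical stand-alone sketch of the same rule (the helper name is illustrative, not from the codebase):

    def elems_per_mma_register(bitwidth):
        # 32-bit registers: tf32 fills one, f16/bf16 pack two, i8 packs four.
        packing = {32: 1, 16: 2, 8: 4}
        if bitwidth not in packing:
            raise ValueError(f"unsupported element bitwidth: {bitwidth}")
        return packing[bitwidth]

    assert elems_per_mma_register(32) == 1  # the tf32 case this patch enables
    assert elems_per_mma_register(16) == 2
    assert elems_per_mma_register(8) == 4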
- assert(false && "insert_slice_async: Other value not supported yet"); + // assert(false && "insert_slice_async: Other value not supported yet"); otherElems = getLLVMElems(other, llOther, rewriter, loc); assert(srcElems.size() == otherElems.size()); } diff --git a/python/tests/test_gemm.py b/python/tests/test_gemm.py index e37c8490c..7c8c4226b 100644 --- a/python/tests/test_gemm.py +++ b/python/tests/test_gemm.py @@ -220,14 +220,17 @@ def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLO assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err), check_dtype=False) -@pytest.mark.parametrize('M,N,K,num_warps,block_M,block_N,block_K', [ - [32, 32, 16, 4, 32, 32, 16], - [32, 16, 16, 4, 32, 32, 16], - [128, 8, 8, 4, 32, 32, 16], - # TODO[Superjomn]: fix it later - # [127, 41, 43, 4, 32, 32, 16], +@pytest.mark.parametrize('M,N,K,num_warps,block_M,block_N,block_K,allow_tf32', [ + [32, 32, 16, 4, 32, 32, 16, False], + [32, 32, 16, 4, 32, 32, 16, True], + [32, 16, 16, 4, 32, 32, 16, False], + [32, 16, 16, 4, 32, 32, 16, True], + [127, 41, 43, 4, 32, 32, 16, False], + [127, 41, 43, 4, 32, 32, 16, True], + [128, 8, 8, 4, 32, 32, 16, False], + [128, 8, 8, 4, 32, 32, 16, True] ]) -def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K): +def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32): @triton.jit def matmul_kernel( a_ptr, b_ptr, c_ptr, @@ -236,6 +239,7 @@ def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K): stride_bk, stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + ALLOW_TF32: tl.constexpr ): pid = tl.program_id(axis=0) # num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) @@ -253,10 +257,9 @@ def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K): for k in range(0, K, BLOCK_SIZE_K): a_mask = (offs_am[:, None] < M) & (offs_k[None, :] < K) b_mask = (offs_k[:, None] < K) & (offs_bn[None, :] < N) - a = tl.load(a_ptrs, a_mask) - b = tl.load(b_ptrs, b_mask) - # NOTE the allow_tf32 should be false to force the dot op to do fmadot lowering - accumulator += tl.dot(a, b, allow_tf32=False) + a = tl.load(a_ptrs, a_mask, other=0.0) + b = tl.load(b_ptrs, b_mask, other=0.0) + accumulator += tl.dot(a, b, allow_tf32=ALLOW_TF32) a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk offs_k += BLOCK_SIZE_K @@ -267,6 +270,9 @@ def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K): c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) tl.store(c_ptrs, accumulator, c_mask) + # Configure the pytorch counterpart + torch.backends.cuda.matmul.allow_tf32 = allow_tf32 + a = torch.randn((M, K), device='cuda', dtype=torch.float32) b = torch.randn((K, N), device='cuda', dtype=torch.float32) c = torch.empty((M, N), device=a.device, dtype=torch.float32) @@ -277,8 +283,12 @@ def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K): stride_am=a.stride(0), stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1), stride_cm=c.stride(0), stride_cn=c.stride(1), - BLOCK_SIZE_M=block_M, BLOCK_SIZE_N=block_N, BLOCK_SIZE_K=block_K) + BLOCK_SIZE_M=block_M, BLOCK_SIZE_N=block_N, BLOCK_SIZE_K=block_K, ALLOW_TF32=allow_tf32) golden = torch.matmul(a, b) golden_abs_err, golden_rel_err = get_proper_err(a, b, golden) - torch.testing.assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err)) + if allow_tf32: + # TF32 is not accurate enough + 
diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir
index 8ec45a385..e2bd31df0 100644
--- a/test/Conversion/tritongpu_to_llvm.mlir
+++ b/test/Conversion/tritongpu_to_llvm.mlir
@@ -923,6 +923,45 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 
 // -----
 
+#mma = #triton_gpu.mma<{version=2, warpsPerCTA=[2, 2]}>
+#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
+#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
+#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+  // CHECK-LABEL: matmul_tf32dot
+  func @matmul_tf32dot(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32},
+                       %a: tensor<32x16xf32, #shared>, %b: tensor<16x32xf32, #shared>) {
+    %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4.shared.b16
+    // CHECK-SAME: (f32, f32, f32, f32)
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4.shared.b16
+    // CHECK-SAME: (f32, f32, f32, f32)
+    %a_mat = triton_gpu.convert_layout %a : (tensor<32x16xf32, #shared>) -> tensor<32x16xf32, #dot_operand_a>
+    %b_mat = triton_gpu.convert_layout %b : (tensor<16x32xf32, #shared>) -> tensor<16x32xf32, #dot_operand_b>
+
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32
+    %28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, transA = false, transB = false} : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #mma>
+    %38 = triton_gpu.convert_layout %28 : (tensor<32x32xf32, #mma>) -> tensor<32x32xf32, #blocked>
+
+    %30 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<32x1x!tt.ptr<f32>, #blocked>
+    %36 = tt.broadcast %30 : (tensor<32x1x!tt.ptr<f32>, #blocked>) -> tensor<32x32x!tt.ptr<f32>, #blocked>
+    tt.store %36, %38 : tensor<32x32xf32, #blocked>
+    return
+  }
+}
+
+// -----
+
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
   // CHECK-LABEL: atomic_add_f32
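For reference, a user-level kernel that exercises the lowering checked by matmul_tf32dot above. A minimal sketch assuming an Ampere-class GPU (mma version 2), using only the API surface already appearing in this patch's test:

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def tf32_dot_kernel(a_ptr, b_ptr, c_ptr, ALLOW_TF32: tl.constexpr):
        offs = tl.arange(0, 32)
        a = tl.load(a_ptr + offs[:, None] * 32 + offs[None, :])
        b = tl.load(b_ptr + offs[:, None] * 32 + offs[None, :])
        # f32 inputs with allow_tf32=True take the mma.sync tf32 path above
        c = tl.dot(a, b, allow_tf32=ALLOW_TF32)
        tl.store(c_ptr + offs[:, None] * 32 + offs[None, :], c)

    a = torch.randn(32, 32, device='cuda', dtype=torch.float32)
    b = torch.randn(32, 32, device='cuda', dtype=torch.float32)
    c = torch.empty(32, 32, device='cuda', dtype=torch.float32)
    tf32_dot_kernel[(1,)](a, b, c, ALLOW_TF32=True)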