From 4dc2396ca070acd3702010e2e853a1a32a6d686d Mon Sep 17 00:00:00 2001 From: Yan Chunwei Date: Wed, 26 Oct 2022 10:33:04 +0800 Subject: [PATCH] [Triton-MLIR][BACKEND] Support $c from mma layout in dot (#798) This PR does 1. Support the case where $c holding a mma layout, this should be useful in forloop in k-axis in GEMM 2. Fix the `unrealized_conversion_cast` in ConvertLayout[shared->dot_op] Known issue 1. There is some IO conflict in GEMM with a k-forloop, it is temporarily solved by [adding a barrier](https://github.com/openai/triton/pull/798/files#diff-8a9a5a7f4a025fb1299af29d190d5626bd9000406d3ea47c49679272d3d6abe9R3028) in dot conversion, but we are still working on it, will get a more generic fix for it in the following PR. 2. The parallel pass will result in a buggy instruction result type ```mlir %1049 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "cp.async.commit_group ;", "" : () -> !llvm.void %1050 = builtin.unrealized_conversion_cast %1049 : !llvm.void to !llvm.ptr ``` So we temporarily disable it. --- .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 226 ++++++++++++++---- python/tests/test_gemm.py | 60 +++-- python/triton/compiler.py | 4 +- 3 files changed, 226 insertions(+), 64 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index 64f8b409e..74461ebe7 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -622,6 +622,13 @@ protected: Value smem; }; +Value convertSplatLikeOpWithMmaLayout(const MmaEncodingAttr &layout, + Type resType, Type elemType, + Value constVal, + TypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, + Location loc); + // Convert SplatOp or arith::ConstantOp with SplatElementsAttr to a // LLVM::StructType value. // @@ -632,16 +639,26 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal, TypeConverter *typeConverter, ConversionPatternRewriter &rewriter, Location loc) { auto tensorTy = resType.cast(); - auto layout = tensorTy.getEncoding(); - auto srcType = typeConverter->convertType(elemType); - auto llSrc = bitcast(srcType, constVal); - size_t elemsPerThread = getElemsPerThread(layout, tensorTy.getShape()); - llvm::SmallVector elems(elemsPerThread, llSrc); - llvm::SmallVector elemTypes(elems.size(), srcType); - auto structTy = - LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes); + if (tensorTy.getEncoding().isa()) { + auto tensorTy = resType.cast(); + auto layout = tensorTy.getEncoding(); + auto srcType = typeConverter->convertType(elemType); + auto llSrc = bitcast(srcType, constVal); + size_t elemsPerThread = getElemsPerThread(layout, tensorTy.getShape()); + llvm::SmallVector elems(elemsPerThread, llSrc); + llvm::SmallVector elemTypes(elems.size(), srcType); + auto structTy = + LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes); - return getStructFromElements(loc, elems, rewriter, structTy); + return getStructFromElements(loc, elems, rewriter, structTy); + } else if (auto mmaLayout = + tensorTy.getEncoding().dyn_cast()) { + return convertSplatLikeOpWithMmaLayout( + mmaLayout, resType, elemType, constVal, typeConverter, rewriter, loc); + } else + assert(false && "Unsupported layout found in ConvertSplatLikeOp"); + + return Value{}; } struct SplatOpConversion @@ -2436,8 +2453,6 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern { MLIRContext *ctx = op->getContext(); bool allowTF32 = op.allowTF32(); - assert(isSplatLike(C) && "Currently only splat-like C is supported now"); - // Here we assume the DotOp's operands always comes from shared memory. auto AShape = A.getType().cast().getShape(); size_t reduceAxis = 1; @@ -2536,6 +2551,31 @@ struct DotOpConversionHelper { mmaType = getTensorCoreTypeFromOperand(operandTy); } + // Get the M and N of mat instruction shape. + static std::tuple getMatShapeMN() { + // According to DotOpConversionHelper::mmaMatShape, all the matrix shape's + // M,N are {8,8} + return {8, 8}; + } + + // Get the M and N of mma instruction shape. + static std::tuple getInstrShapeMN() { + // According to DotOpConversionHelper::mmaInstrShape, all the M,N are {16,8} + return {16, 8}; + } + + static std::tuple getRepMN(const RankedTensorType &tensorTy) { + auto mmaLayout = tensorTy.getEncoding().cast(); + auto wpt = mmaLayout.getWarpsPerCTA(); + + int M = tensorTy.getShape()[0]; + int N = tensorTy.getShape()[1]; + auto [instrM, instrN] = getInstrShapeMN(); + int repM = std::max(M / (wpt[0] * instrM), 1); + int repN = std::max(N / (wpt[1] * instrN), 1); + return {repM, repN}; + } + Type getShemPtrTy() const { switch (mmaType) { case TensorCoreType::FP32_FP16_FP16_FP32: @@ -2633,15 +2673,20 @@ struct DotOpConversionHelper { return mmaInstrShape.at(mmaType); } + static ArrayRef getMmaInstrShape(TensorCoreType tensorCoreType) { + assert(tensorCoreType != TensorCoreType::NOT_APPLICABLE && + "Unknown mma type found."); + return mmaInstrShape.at(tensorCoreType); + } + ArrayRef getMmaMatShape() const { assert(mmaType != TensorCoreType::NOT_APPLICABLE && "Unknown mma type found."); return mmaMatShape.at(mmaType); } - // Deduce the TensorCoreType from either $a or $b's type. This method is not - // safe, but we cannot get the DotOp in some getmaMatShape usage case. - TensorCoreType getTensorCoreTypeFromOperand(Type operandTy) const { + // Deduce the TensorCoreType from either $a or $b's type. + static TensorCoreType getTensorCoreTypeFromOperand(Type operandTy) { auto tensorTy = operandTy.cast(); auto elemTy = tensorTy.getElementType(); if (elemTy.isF16()) @@ -2814,22 +2859,58 @@ struct MMA16816ConversionHelper { // \param operand is either $a or $b's type. inline int getNumRepM(Type operand, int M) const { - auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(operand); - return std::max(M / (wpt[0] * mmaInstrM), 1); + return getNumRepM(operand, M, wpt[0]); } // \param operand is either $a or $b's type. inline int getNumRepN(Type operand, int N) const { - auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(operand); - return std::max(N / (wpt[1] * mmaInstrN), 1); + return getNumRepN(operand, N, wpt[1]); } // \param operand is either $a or $b's type. inline int getNumRepK(Type operand, int K) const { - auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(operand); + return getNumRepK_(operand, K); + } + + static int getNumRepM(Type operand, int M, int wpt) { + auto tensorCoreType = + DotOpConversionHelper::getTensorCoreTypeFromOperand(operand); + int mmaInstrM = DotOpConversionHelper::getMmaInstrShape(tensorCoreType)[0]; + return std::max(M / (wpt * mmaInstrM), 1); + } + + static int getNumRepN(Type operand, int N, int wpt) { + auto tensorCoreType = + DotOpConversionHelper::getTensorCoreTypeFromOperand(operand); + int mmaInstrN = DotOpConversionHelper::getMmaInstrShape(tensorCoreType)[1]; + return std::max(N / (wpt * mmaInstrN), 1); + } + + static int getNumRepK_(Type operand, int K) { + auto tensorCoreType = + DotOpConversionHelper::getTensorCoreTypeFromOperand(operand); + int mmaInstrK = DotOpConversionHelper::getMmaInstrShape(tensorCoreType)[2]; return std::max(K / mmaInstrK, 1); } + // Get number of elements per thread for $a operand. + static size_t getANumElemsPerThread(RankedTensorType operand, + ArrayRef wpt) { + auto shape = operand.getShape(); + int repM = getNumRepM(operand, shape[0], wpt[0]); + int repK = getNumRepK_(operand, shape[1]); + return 4 * repM * repK; + } + + // Get number of elements per thread for $b operand. + static size_t getBNumElemsPerThread(RankedTensorType operand, + ArrayRef wpt) { + auto shape = operand.getShape(); + int repK = getNumRepK_(operand, shape[0]); + int repN = getNumRepN(operand, shape[1], wpt[1]); + return 4 * std::max(repN / 2, 1) * repK; + } + // Loading $a from smem to registers, returns a LLVM::Struct. Value loadA(Value tensor, Value llTensor) const { auto aTensorTy = tensor.getType().cast(); @@ -2863,9 +2944,6 @@ struct MMA16816ConversionHelper { // step2. Format the values to LLVM::Struct to passing to mma codegen. Value result = composeValuesToDotOperandLayoutStruct(ha, numRepM, numRepK); - - // TODO[Superjomn]: Replace the convert_layout op with the result once the - // DotOperandEncodingAttr is ready. return result; } @@ -2894,15 +2972,21 @@ struct MMA16816ConversionHelper { return result; } - // Loading $c from smem(?) to registers, returns a Value. - // NOTE Only SplatLike tensor is supported now. - Value loadC(Value tensor) const { - // Currently, we only support a SplatLike C. For the other cases, e.g., C in - // shared layout or blocked layout, we will support them by expanding - // convert_layout. - auto hc = helper.loadSplatLikeC(tensor, loc, rewriter); - assert(hc.size() == 4UL && "Only splat-like C is supported now"); - return hc[0]; + // Loading $c to registers, returns a Value. + Value loadC(Value tensor, Value llTensor) const { + auto tensorTy = tensor.getType().cast(); + auto [repM, repN] = DotOpConversionHelper::getRepMN(tensorTy); + size_t fcSize = 4 * repM * repN; + + assert(tensorTy.getEncoding().isa() && + "Currently, we only support $c with a mma layout."); + // Load a normal C tensor with mma layout, that should be a + // LLVM::struct with fcSize elements. + auto structTy = llTensor.getType().cast(); + assert(structTy.getBody().size() == fcSize && + "DotOp's $c operand should pass the same number of values as $d in " + "mma layout."); + return llTensor; } // Conduct the Dot conversion. @@ -2934,9 +3018,8 @@ struct MMA16816ConversionHelper { getValuesFromDotOperandLayoutStruct(loadedA, numRepM, numRepK); ValueTable hb = getValuesFromDotOperandLayoutStruct( loadedB, std::max(numRepN / 2, 1), numRepK); - - const int fcSize = 4 * numRepM * numRepN; - SmallVector fc(fcSize, loadedC); + auto fc = ConvertTritonGPUOpToLLVMPatternBase::getElementsFromStruct( + loc, loadedC, rewriter); auto callMma = [&](unsigned m, unsigned n, unsigned k) { unsigned colsPerThread = numRepN * 2; @@ -2974,6 +3057,11 @@ struct MMA16816ConversionHelper { for (unsigned n = 0; n < numRepN; ++n) callMma(2 * m, n, 2 * k); + // NOTE, the barrier here is a temporary trick making the gemm with a + // k-forloop pass the precision test, or it will fail. + // TODO[Superjomn]: Fix with a more general and performance-friendly way. + barrier; + // replace with new packed result Type structTy = LLVM::LLVMStructType::getLiteral( ctx, SmallVector(fc.size(), type::f32Ty(ctx))); @@ -3123,9 +3211,6 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand( } else if (dotOperandLayout.getOpIdx() == 1) { // operand $b res = mmaHelper.loadB(src, adaptor.src()); - } else if (dotOperandLayout.getOpIdx() == 2) { - // operand $c - res = mmaHelper.loadC(src); } rewriter.replaceOp(op, res); @@ -3163,10 +3248,7 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adaptor, loadedB = mmaHelper.loadB(op.b(), adaptor.b()); } - // TODO[Superjomn]: Process C as a mma layout. - // Currently, C is simply treated as a Splat Op, and the data layout is not - // mattered. - loadedC = mmaHelper.loadC(op.c()); + loadedC = mmaHelper.loadC(op.c(), adaptor.c()); return mmaHelper.convertDot(A, B, C, op.d(), loadedA, loadedB, loadedC, op, adaptor); @@ -3174,6 +3256,26 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adaptor, /// ====================== mma codegen end ============================ +Value convertSplatLikeOpWithMmaLayout(const MmaEncodingAttr &layout, + Type resType, Type elemType, + Value constVal, + TypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, + Location loc) { + if (layout.getVersion() == 2) { + auto tensorTy = resType.cast(); + auto [repM, repN] = DotOpConversionHelper::getRepMN(tensorTy); + size_t fcSize = 4 * repM * repN; + + auto structTy = LLVM::LLVMStructType::getLiteral( + rewriter.getContext(), SmallVector(fcSize, elemType)); + return getStructFromElements(loc, SmallVector(fcSize, constVal), + rewriter, structTy); + } + + assert(false && "Unsupported mma layout found"); +} + class TritonGPUToLLVMTypeConverter : public LLVMTypeConverter { public: using TypeConverter::convertType; @@ -3199,6 +3301,7 @@ public: } llvm::Optional convertTritonTensorType(RankedTensorType type) { + auto ctx = type.getContext(); Attribute layout = type.getEncoding(); if (layout && (layout.isa() || layout.isa() || @@ -3207,11 +3310,50 @@ public: getElemsPerThread(layout, type.getShape()); SmallVector types(numElementsPerThread, convertType(type.getElementType())); - return LLVM::LLVMStructType::getLiteral(&getContext(), types); + return LLVM::LLVMStructType::getLiteral(ctx, types); } else if (auto shared_layout = layout.dyn_cast_or_null()) { return LLVM::LLVMPointerType::get(convertType(type.getElementType()), 3); + } else if (auto mmaLayout = layout.dyn_cast_or_null()) { + if (mmaLayout.getVersion() == 2) { + auto [repM, repN] = DotOpConversionHelper::getRepMN(type); + size_t fcSize = 4 * repM * repN; + return LLVM::LLVMStructType::getLiteral( + ctx, SmallVector(fcSize, type.getElementType())); + } + + llvm::errs() + << "Unexpected mma layout detected in TritonToLLVMTypeConverter"; + return llvm::None; + + } else if (auto dot_op_layout = + layout.dyn_cast_or_null()) { + auto mmaLayout = dot_op_layout.getParent().cast(); + if (mmaLayout.getVersion() == 2) { + auto wpt = mmaLayout.getWarpsPerCTA(); + Type elemTy = type.getElementType(); + + if (dot_op_layout.getOpIdx() == 0) { // $a + int elems = + MMA16816ConversionHelper::getANumElemsPerThread(type, wpt); + Type x2Ty = vec_ty(elemTy, 2); + return LLVM::LLVMStructType::getLiteral( + ctx, SmallVector(elems, x2Ty)); + } + if (dot_op_layout.getOpIdx() == 1) { // $b + int elems = + MMA16816ConversionHelper::getBNumElemsPerThread(type, wpt); + Type x2Ty = vec_ty(elemTy, 2); + return LLVM::LLVMStructType::getLiteral( + ctx, SmallVector(elems, x2Ty)); + } + } + + llvm::errs() << "Unexpected dot operand layout detected in " + "TritonToLLVMTypeConverter"; + return llvm::None; } + return llvm::None; } }; diff --git a/python/tests/test_gemm.py b/python/tests/test_gemm.py index 4f1ff5fdc..12644cbd3 100644 --- a/python/tests/test_gemm.py +++ b/python/tests/test_gemm.py @@ -35,6 +35,9 @@ def matmul_no_scf_kernel( [256, 128, 16, 4], [128, 16, 32, 4], [32, 128, 64, 4], + [128, 128, 64, 4], + [64, 128, 128, 4], + [64, 128, 128, 2], ]) def test_gemm_no_scf(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS): a = torch.randn((SIZE_M, SIZE_K), device='cuda', dtype=torch.float16) @@ -78,24 +81,39 @@ def matmul_kernel( tl.store(c_ptrs, accumulator) # TODO: DotConversion in TritonGPUToLLVM cannot support non-splat C for the moment -# @pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K', [ -# [128, 256, 128, 4, 128, 256, 32], -# # [256, 128, 64, 4, 256, 128, 16], -# # [128, 16, 128, 4, 128, 16, 32], -# # [32, 128, 256, 4, 32, 128, 64], -# ]) -# def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K): -# a = torch.randn((SIZE_M, SIZE_K), device='cuda', dtype=torch.float16) -# b = torch.randn((SIZE_K, SIZE_N), device='cuda', dtype=torch.float16) -# c = torch.empty((SIZE_M, SIZE_N), device=a.device, dtype=torch.float32) -# grid = lambda META: (1, ) -# matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, -# stride_am=a.stride(0), stride_ak=a.stride(1), -# stride_bk=b.stride(0), stride_bn=b.stride(1), -# stride_cm=c.stride(0), stride_cn=c.stride(1), -# M=a.shape[0], N=b.shape[1], K=a.shape[1], -# BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, -# num_warps=NUM_WARPS) -# golden = torch.matmul(a, b) -# torch.set_printoptions(profile="full") -# assert_close(c, golden, rtol=1e-3, atol=1e-3, check_dtype=False) + + +@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K', [ + # Non-forloop + [64, 32, 64, 4, 64, 32, 64], + [128, 64, 128, 4, 128, 64, 128], + # K-Forloop + [64, 32, 128, 4, 64, 32, 64], + [128, 16, 128, 4, 128, 16, 32], + [32, 16, 128, 4, 32, 16, 32], + [32, 64, 128, 4, 32, 64, 32], + [32, 128, 256, 4, 32, 128, 64], + [64, 128, 64, 4, 64, 128, 32], + [128, 128, 64, 4, 128, 128, 32], + [64, 64, 128, 4, 64, 64, 32], + [128, 128, 128, 4, 128, 128, 32], + [128, 128, 256, 4, 128, 128, 64], + [128, 256, 128, 4, 128, 256, 32], + [256, 128, 64, 4, 256, 128, 16], + [128, 64, 128, 4, 128, 64, 32], +]) +def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K): + a = torch.randn((SIZE_M, SIZE_K), device='cuda', dtype=torch.float16) + b = torch.randn((SIZE_K, SIZE_N), device='cuda', dtype=torch.float16) + c = torch.empty((SIZE_M, SIZE_N), device=a.device, dtype=torch.float32) + grid = lambda META: (1, ) + matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, + stride_am=a.stride(0), stride_ak=a.stride(1), + stride_bk=b.stride(0), stride_bn=b.stride(1), + stride_cm=c.stride(0), stride_cn=c.stride(1), + M=a.shape[0], N=b.shape[1], K=a.shape[1], + BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, + num_warps=NUM_WARPS) + golden = torch.matmul(a, b) + torch.set_printoptions(profile="full") + assert_close(c, golden, rtol=1e-3, atol=1e-3, check_dtype=False) diff --git a/python/triton/compiler.py b/python/triton/compiler.py index 7b9f99d8f..a48d1ec25 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -872,7 +872,9 @@ def make_tritongpu_ir(mod, num_warps): def optimize_tritongpu_ir(mod, num_stages): pm = _triton.ir.pass_manager(mod.context) pm.enable_debug() - pm.add_tritongpu_pipeline_pass(num_stages) + # Get error in backend due to wrong conversion in expanding async-related instruction. + # TODO[Superjomn]: Open it when fixed. + # pm.add_tritongpu_pipeline_pass(num_stages) pm.add_canonicalizer_pass() pm.add_cse_pass() pm.add_coalesce_pass()