From de5b84c476a755e99d14568d2ccc3180b6fa8f0b Mon Sep 17 00:00:00 2001
From: Yan Chunwei
Date: Wed, 9 Nov 2022 12:23:43 +0800
Subject: [PATCH] [Triton-MLIR][Backend] Fix mma int8 precision error (#850)

Fix mma.16816 s8 precision error

Co-authored-by: ben-zhang-609
---
 .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 146 +++++++++++++-----
 lib/Dialect/TritonGPU/IR/Dialect.cpp    |  11 +-
 python/tests/test_gemm.py               |  29 +++-
 test/Conversion/tritongpu_to_llvm.mlir  |   2 +-
 4 files changed, 146 insertions(+), 42 deletions(-)

diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index ec16e9aef..9f7a1cd62 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -87,7 +87,6 @@ static Value createLLVMIntegerConstant(OpBuilder &builder, Location loc,
 void llPrintf(StringRef msg, ValueRange args,
               ConversionPatternRewriter &rewriter);
 
-// Shortcuts for some commonly used LLVM ops to keep code simple and intuitive//
 // Shortcuts for some commonly used LLVM ops to keep code simple and intuitive
 #define zext(...) rewriter.create<LLVM::ZExtOp>(loc, __VA_ARGS__)
 #define udiv(...) rewriter.create<LLVM::UDivOp>(loc, __VA_ARGS__)
@@ -2923,8 +2922,8 @@ public:
     Value sOffset =
         mul(i32_val(matIdx[order[1]] * sMatStride * sMatShape), sTileStride);
     Value sOffsetPtr = gep(shemPtrTy, ptr, sOffset);
-    PTXBuilder builder;
 
+    PTXBuilder builder;
     // ldmatrix.m8n8.x4 returns 4x2xfp16(that is 4xb32) elements for a
     // thread.
     auto resArgs = builder.newListOperand(4, "=r");
@@ -2943,12 +2942,13 @@
       return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty, v)});
     };
 
-    Type fp16x2Ty = vec_ty(type::f16Ty(ctx), 2);
+    // The struct should have exactly the same element types.
+    Type elemType = resV4.getType().cast<LLVM::LLVMStructType>().getBody()[0];
 
-    return {extract_val(fp16x2Ty, resV4, getIntAttr(0)),
-            extract_val(fp16x2Ty, resV4, getIntAttr(1)),
-            extract_val(fp16x2Ty, resV4, getIntAttr(2)),
-            extract_val(fp16x2Ty, resV4, getIntAttr(3))};
+    return {extract_val(elemType, resV4, getIntAttr(0)),
+            extract_val(elemType, resV4, getIntAttr(1)),
+            extract_val(elemType, resV4, getIntAttr(2)),
+            extract_val(elemType, resV4, getIntAttr(3))};
   } else if (elemBytes == 4 && needTrans) { // Use lds.32 to load tf32 matrices
     Value ptr2 = getPtr(ptrIdx + 1);
@@ -2961,20 +2961,25 @@
 
       Value elems[4];
       Type elemTy = type::f32Ty(ctx);
+      Type elemPtrTy = ptr_ty(elemTy);
       if (kOrder == 1) {
-        elems[0] = load(gep(elemTy, ptr, sOffsetElemVal));
-        elems[1] = load(gep(elemTy, ptr2, sOffsetElemVal));
-        elems[2] = load(gep(elemTy, ptr, sOffsetArrElemVal));
-        elems[3] = load(gep(elemTy, ptr2, sOffsetArrElemVal));
+        elems[0] = load(gep(elemPtrTy, ptr, i32_val(sOffsetElem)));
+        elems[1] = load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem)));
+        elems[2] =
+            load(gep(elemPtrTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
+        elems[3] =
+            load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
       } else {
-        elems[0] = load(gep(elemTy, ptr, sOffsetElemVal));
-        elems[2] = load(gep(elemTy, ptr2, sOffsetElemVal));
-        elems[1] = load(gep(elemTy, ptr, sOffsetArrElemVal));
-        elems[3] = load(gep(elemTy, ptr2, sOffsetArrElemVal));
+        elems[0] = load(gep(elemPtrTy, ptr, i32_val(sOffsetElem)));
+        elems[2] = load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem)));
+        elems[1] =
+            load(gep(elemPtrTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
+        elems[3] =
+            load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
       }
-      return {elems[0], elems[1], elems[2], elems[3]};
-    } else if (elemBytes == 1 && needTrans) {
+
+    } else if (elemBytes == 1 && needTrans) {
       // work with int8
       std::array<std::array<Value, 4>, 2> ptrs;
       ptrs[0] = {
          getPtr(ptrIdx),
@@ -3004,15 +3009,16 @@
       Value i8Elems[4][4];
       Type elemTy = type::i8Ty(ctx);
+      Type elemPtrTy = ptr_ty(elemTy);
 
       if (kOrder == 1) {
         for (int i = 0; i < 2; ++i)
           for (int j = 0; j < 4; ++j)
-            i8Elems[i][j] = load(gep(elemTy, ptrs[i][j], sOffsetElemVal));
+            i8Elems[i][j] = load(gep(elemPtrTy, ptrs[i][j], sOffsetElemVal));
 
         for (int i = 2; i < 4; ++i)
           for (int j = 0; j < 4; ++j)
             i8Elems[i][j] =
-                load(gep(elemTy, ptrs[i - 2][j], sOffsetArrElemVal));
+                load(gep(elemPtrTy, ptrs[i - 2][j], sOffsetArrElemVal));
 
         for (int m = 0; m < 4; ++m) {
           for (int e = 0; e < 4; ++e)
@@ -3022,13 +3028,13 @@
         }
       } else { // k first
         for (int j = 0; j < 4; ++j)
-          i8Elems[0][j] = load(gep(elemTy, ptrs[0][j], sOffsetElemVal));
+          i8Elems[0][j] = load(gep(elemPtrTy, ptrs[0][j], sOffsetElemVal));
         for (int j = 0; j < 4; ++j)
-          i8Elems[2][j] = load(gep(elemTy, ptrs[1][j], sOffsetElemVal));
+          i8Elems[2][j] = load(gep(elemPtrTy, ptrs[1][j], sOffsetElemVal));
         for (int j = 0; j < 4; ++j)
-          i8Elems[1][j] = load(gep(elemTy, ptrs[0][j], sOffsetArrElemVal));
+          i8Elems[1][j] = load(gep(elemPtrTy, ptrs[0][j], sOffsetArrElemVal));
         for (int j = 0; j < 4; ++j)
-          i8Elems[3][j] = load(gep(elemTy, ptrs[1][j], sOffsetArrElemVal));
+          i8Elems[3][j] = load(gep(elemPtrTy, ptrs[1][j], sOffsetArrElemVal));
 
         for (int m = 0; m < 4; ++m) {
           for (int e = 0; e < 4; ++e)
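The transposed int8 path above gathers the i8 elements with scalar loads and repacks them, because mma.m16n8k16 with s8 operands consumes b32 registers that each carry four packed int8 values. A minimal numpy sketch of that byte-level packing (names and values are illustrative, not from the patch):

```python
import numpy as np

# Four int8 elements that one lane contributes to a single mma operand
# register (mma.m16n8k16.s8 packs 4 x i8 into each b32 register).
lane_elems = np.array([1, -2, 3, -4], dtype=np.int8)

# Reinterpret the four bytes as one little-endian 32-bit register.
packed = lane_elems.view(np.int32)[0]

# Unpacking recovers the original bytes exactly; packing loses nothing.
unpacked = np.array([packed], dtype=np.int32).view(np.int8)
assert (unpacked == lane_elems).all()
```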
@@ -3112,6 +3118,7 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
     size_t reduceAxis = 1;
     unsigned K = AShape[reduceAxis];
     bool isOuter = K == 1;
+
     bool isMMA = D.getType().cast<RankedTensorType>()
                      .getEncoding()
                      .isa<MmaEncodingAttr>();
@@ -3123,11 +3130,13 @@
                          .getEncoding()
                          .cast<MmaEncodingAttr>();
 
-    if (!isOuter && isMMA) {
+    bool isHMMA = isDotHMMA(op);
+    if (!isOuter && isMMA && isHMMA) {
       if (mmaLayout.getVersion() == 1)
         return convertMMA884(op, adaptor, rewriter);
       if (mmaLayout.getVersion() == 2)
         return convertMMA16816(op, adaptor, rewriter);
+
       llvm::report_fatal_error(
           "Unsupported MMA kind found when converting DotOp to LLVM.");
     }
@@ -3140,6 +3149,49 @@
         "Unsupported DotOp found when converting TritonGPU to LLVM.");
   }
 
+  // Tell whether a DotOp supports HMMA.
+  // This is ported from the master branch; the original logic is retained.
+  static bool isDotHMMA(DotOp op) {
+    auto a = op.a();
+    auto b = op.b();
+    auto c = op.c();
+    auto d = op.getResult();
+    auto aTensorTy = a.getType().cast<RankedTensorType>();
+    auto bTensorTy = b.getType().cast<RankedTensorType>();
+    auto cTensorTy = c.getType().cast<RankedTensorType>();
+    auto dTensorTy = d.getType().cast<RankedTensorType>();
+
+    if (!dTensorTy.getEncoding().isa<MmaEncodingAttr>())
+      return false;
+
+    auto mmaLayout = dTensorTy.getEncoding().cast<MmaEncodingAttr>();
+    auto aElemTy = aTensorTy.getElementType();
+    auto bElemTy = bTensorTy.getElementType();
+
+    assert((mmaLayout.getVersion() == 1 || mmaLayout.getVersion() == 2) &&
+           "Unexpected MMA layout version found");
+    // Refer to mma section for the data type supported by Volta and Hopper
+    // Tensor Core in
+    // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16
+    return (aElemTy.isF16() && bElemTy.isF16()) ||
+           (aElemTy.isBF16() && bElemTy.isBF16()) ||
+           (aElemTy.isF32() && bElemTy.isF32() && op.allowTF32() &&
+            mmaLayout.getVersion() >= 2) ||
+           (aElemTy.isInteger(8) && bElemTy.isInteger(8) &&
+            mmaLayout.getVersion() >= 2);
+  }
+
+  // Tell whether a DotOp supports HMMA from one operand's type (either $a or
+  // $b). We cannot get both operand types (in TypeConverter), so we assume
+  // the types of both operands are identical.
+  // TODO[Superjomn]: Find a better way to implement it.
+  static bool isDotHMMA(TensorType operand, bool allowTF32, int mmaVersion) {
+    auto elemTy = operand.getElementType();
+    return elemTy.isF16() || elemTy.isBF16() ||
+           (elemTy.isF32() && allowTF32 && mmaVersion >= 2) ||
+           (elemTy.isInteger(8) && mmaVersion >= 2);
+  }
+
 private:
   // Convert to mma.m16n8k16
   LogicalResult convertMMA16816(triton::DotOp a, OpAdaptor adaptor,
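For reference, the eligibility rule that the two isDotHMMA overloads encode is small enough to restate; this standalone sketch mirrors the element-type conditions above (an illustration, not code from the patch):

```python
def is_dot_hmma(a_dtype: str, b_dtype: str, allow_tf32: bool,
                mma_version: int) -> bool:
    """Mirrors DotOpConversion::isDotHMMA's element-type dispatch."""
    assert mma_version in (1, 2), "Unexpected MMA layout version found"
    if a_dtype != b_dtype:
        return False
    if a_dtype in ("f16", "bf16"):
        return True
    if a_dtype == "f32":  # tf32 path
        return allow_tf32 and mma_version >= 2
    if a_dtype == "i8":   # the case this patch routes through mma.16816
        return mma_version >= 2
    return False

assert is_dot_hmma("i8", "i8", False, 2)      # int8 takes the MMA path on v2
assert not is_dot_hmma("i8", "i8", False, 1)  # no s8 mma on mma v1 (Volta)
```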
@@ -3651,6 +3703,7 @@ struct MMA16816ConversionHelper {
     std::function<void(int, int)> loadFn;
     auto [matShapeM, matShapeN, matShapeK] = getMmaMatShape(aTensorTy);
     auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(aTensorTy);
+
     int numRepM = getNumRepM(aTensorTy, shape[0]);
     int numRepK = getNumRepK(aTensorTy, shape[1]);
 
@@ -3766,6 +3819,7 @@
                                       std::to_string(i))); // reuse the output registers
       }
+
       mma(retArgs, aArgs, bArgs, cArgs);
 
       Value mmaOut = builder.launch(rewriter, loc, helper.getMmaRetType());
@@ -3773,9 +3827,10 @@
         return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty, v)});
       };
 
+      Type elemTy = mmaOut.getType().cast<LLVM::LLVMStructType>().getBody()[0];
       for (int i = 0; i < 4; ++i)
         fc[m * colsPerThread + 4 * n + i] =
-            extract_val(type::f32Ty(ctx), mmaOut, getIntAttr(i));
+            extract_val(elemTy, mmaOut, getIntAttr(i));
     };
 
     for (int k = 0; k < numRepK; ++k)
       for (int m = 0; m < numRepM; ++m)
         for (int n = 0; n < numRepN; ++n)
           callMma(2 * m, n, 2 * k);
 
+    Type resElemTy = dTensorTy.getElementType();
+
+    for (auto &elem : fc) {
+      elem = bitcast(elem, resElemTy);
+    }
+
     // replace with new packed result
     Type structTy = LLVM::LLVMStructType::getLiteral(
-        ctx, SmallVector<Type>(fc.size(), type::f32Ty(ctx)));
+        ctx, SmallVector<Type>(fc.size(), resElemTy));
     Value res = getStructFromElements(loc, fc, rewriter, structTy);
 
     rewriter.replaceOp(op, res);
@@ -3821,9 +3882,7 @@ private:
         tensorTy.getShape() /*tileShape*/, instrShape, matShape, perPhase,
         maxPhase, elemBytes, rewriter, typeConverter, loc);
     SmallVector<Value> offs = loader.computeOffsets(warpId, lane);
-
     const int numPtrs = loader.getNumPtr();
-
     SmallVector<Value> ptrs(numPtrs);
 
     Type smemPtrTy = helper.getShemPtrTy();
@@ -3835,6 +3894,7 @@
     auto [ha0, ha1, ha2, ha3] = loader.loadX4(
         (kOrder == 1) ? a : b /*mat0*/, (kOrder == 1) ? b : a /*mat1*/, offs,
         ptrs, helper.getMatType(), helper.getShemPtrTy());
+
     if (!needTrans) {
       ld2(vals, a, b, ha0);
       ld2(vals, a + 1, b, ha1);
@@ -3879,10 +3939,9 @@
 
     assert(!elems.empty());
 
-    Type fp16Ty = type::f16Ty(ctx);
-    Type fp16x2Ty = vec_ty(fp16Ty, 2);
+    Type elemTy = elems[0].getType();
     Type structTy = LLVM::LLVMStructType::getLiteral(
-        ctx, SmallVector<Type>(elems.size(), fp16x2Ty));
+        ctx, SmallVector<Type>(elems.size(), elemTy));
     auto result = getStructFromElements(loc, elems, rewriter, structTy);
     return result;
   }
@@ -3921,9 +3980,25 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
       dotOperandLayout.getParent().dyn_cast_or_null<MmaEncodingAttr>();
   assert(mmaLayout);
 
+  bool isOuter{};
+  {
+    int K{};
+    if (dotOperandLayout.getOpIdx() == 0) // $a
+      K = dstTensorTy.getShape()[1];
+    else // $b
+      K = dstTensorTy.getShape()[0];
+    isOuter = K == 1;
+  }
+
+  // TODO[Superjomn]: allowTF32 is not available in ConvertLayoutOp since it
+  // is an attribute of DotOp.
+  bool allowTF32 = false;
+  bool isHMMA = DotOpConversion::isDotHMMA(dstTensorTy, allowTF32,
+                                           mmaLayout.getVersion());
+
   auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.src(), rewriter);
   Value res;
-  if (mmaLayout.getVersion() == 2) {
+  if (!isOuter && mmaLayout.getVersion() == 2 && isHMMA) { // tensor core v2
     MMA16816ConversionHelper mmaHelper(mmaLayout, getThreadId(rewriter, loc),
                                        rewriter, getTypeConverter(),
                                        op.getLoc());
@@ -3935,7 +4010,8 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
       // operand $b
       res = mmaHelper.loadB(src, smemObj);
     }
-  } else if (mmaLayout.getVersion() == 1) {
+  } else if (!isOuter && mmaLayout.getVersion() == 1 &&
+             isHMMA) { // tensor core v1
     DotOpMmaV1ConversionHelper helper(mmaLayout);
     if (dotOperandLayout.getOpIdx() == 0) {
       // operand $a
@@ -5076,8 +5152,8 @@ void ConvertTritonGPUToLLVM::initSharedMemory(
   OpBuilder b(mod.getBodyRegion());
   auto loc = mod.getLoc();
   auto elemTy = typeConverter.convertType(b.getIntegerType(8));
-  // Set array size 0 and external linkage indicates that we use dynamic shared
-  // allocation to allow a larger shared memory size for each kernel.
+  // An array size of 0 with external linkage indicates that we use dynamic
+  // shared allocation, which allows a larger shared memory size per kernel.
   auto arrayTy = LLVM::LLVMArrayType::get(elemTy, 0);
   auto global = b.create<LLVM::GlobalOp>(
       loc, arrayTy, /*isConstant=*/false, LLVM::Linkage::External,
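Before this change, convertDot extracted every mma result as f32 and packed the result struct with f32 elements, even when the accumulator is i32 (the s8 case); the elemTy/resElemTy plumbing above derives the type from the instruction's actual return struct instead. A small numpy illustration of why reading i32 accumulator bits through an f32 type corrupts the result (an analogy, not the compiler's code path):

```python
import numpy as np

acc = np.array([1000], dtype=np.int32)  # an int8 GEMM accumulates into i32
as_f32 = acc.view(np.float32)           # the same 32 bits, read as f32
assert as_f32[0] != 1000.0              # ~1.4e-42, nothing like 1000
```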
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index c7968c76c..6ffed8df9 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -117,10 +117,13 @@ SmallVector<int64_t> getShapePerCTA(const Attribute &layout) {
           "BlockedEncodingAttr not implemented");
     }
   } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
-    assert(mmaLayout.getVersion() == 2 &&
-           "mmaLayout version = 1 is not implemented yet");
-    return {16 * mmaLayout.getWarpsPerCTA()[0],
-            8 * mmaLayout.getWarpsPerCTA()[1]};
+    if (mmaLayout.getVersion() == 2)
+      return {16 * mmaLayout.getWarpsPerCTA()[0],
+              8 * mmaLayout.getWarpsPerCTA()[1]};
+    if (mmaLayout.getVersion() == 1)
+      return {16 * mmaLayout.getWarpsPerCTA()[0],
+              16 * mmaLayout.getWarpsPerCTA()[1]};
+    assert(0 && "Unexpected MMA layout version found");
   } else {
     assert(0 && "Unimplemented usage of getShapePerCTA");
   }
diff --git a/python/tests/test_gemm.py b/python/tests/test_gemm.py
index e1a27ec74..e8326c078 100644
--- a/python/tests/test_gemm.py
+++ b/python/tests/test_gemm.py
@@ -55,6 +55,33 @@ def test_gemm_no_scf(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS):
     assert_close(c, golden, rtol=1e-3, atol=1e-3, check_dtype=False)
 
 
+@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS', [
+    [64, 128, 128, 1],
+    [128, 128, 128, 4],
+    [16, 8, 32, 1],
+    [32, 16, 64, 2],
+    [32, 16, 64, 4],
+])
+def test_gemm_no_scf_int8(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS):
+    a = torch.randint(-5, 5, (SIZE_M, SIZE_K), device='cuda', dtype=torch.int8)
+    b = torch.randint(-5, 5, (SIZE_K, SIZE_N), device='cuda', dtype=torch.int8)
+    c = torch.empty((SIZE_M, SIZE_N), device=a.device, dtype=torch.int32)
+
+    grid = lambda META: (1, )
+    matmul_no_scf_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c,
+                               stride_am=a.stride(0), stride_ak=a.stride(1),
+                               stride_bk=b.stride(0), stride_bn=b.stride(1),
+                               stride_cm=c.stride(0), stride_cn=c.stride(1),
+                               M=SIZE_M, N=SIZE_N, K=SIZE_K,
+                               num_warps=NUM_WARPS)
+
+    aa = a.cpu()
+    bb = b.cpu()
+    golden = torch.matmul(aa.float(), bb.float()).int()
+    torch.set_printoptions(profile="full")
+    torch.testing.assert_close(c.cpu(), golden, check_dtype=False)
+
+
 @triton.jit
 def matmul_kernel(
         a_ptr, b_ptr, c_ptr,
@@ -80,8 +107,6 @@ def matmul_kernel(
     c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
     tl.store(c_ptrs, accumulator)
 
-# TODO: DotConversion in TritonGPUToLLVM cannot support non-splat C for the moment
-
 
 def get_variant_golden(a, b):
     SIZE_M = a.shape[0]
diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir
index 59ec5a927..0acd64e29 100644
--- a/test/Conversion/tritongpu_to_llvm.mlir
+++ b/test/Conversion/tritongpu_to_llvm.mlir
@@ -347,7 +347,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
   // CHECK-LABEL: basic_extract_slice
   func @basic_extract_slice() {
     // CHECK: llvm.mlir.addressof @global_smem
-    // CHECK: llvm.extractvalue
+    // CHECK: llvm.extractvalue
     // CHECK-NEXT: llvm.extractvalue
     // CHECK-NEXT: llvm.extractvalue
     // CHECK-NEXT: llvm.extractvalue
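The new test relies on matmul_no_scf_kernel, which is defined earlier in test_gemm.py and not shown in this patch. The sketch below is a self-contained approximation of that test, with a minimal single-tile kernel standing in for the repo's version; the kernel body and the default sizes (taken from the parametrize list) are assumptions, not part of the patch:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def matmul_no_scf_kernel(a_ptr, b_ptr, c_ptr,
                         stride_am, stride_ak,
                         stride_bk, stride_bn,
                         stride_cm, stride_cn,
                         M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    # One program computes the whole tile in a single dot: no K loop ("no scf").
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, K)
    a = tl.load(a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b = tl.load(b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)
    c = tl.dot(a, b)  # int8 x int8 accumulates into int32 on mma v2
    tl.store(c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn, c)


def check_int8_gemm(SIZE_M=32, SIZE_N=16, SIZE_K=64, NUM_WARPS=2):
    a = torch.randint(-5, 5, (SIZE_M, SIZE_K), device='cuda', dtype=torch.int8)
    b = torch.randint(-5, 5, (SIZE_K, SIZE_N), device='cuda', dtype=torch.int8)
    c = torch.empty((SIZE_M, SIZE_N), device='cuda', dtype=torch.int32)
    matmul_no_scf_kernel[(1,)](a, b, c,
                               a.stride(0), a.stride(1),
                               b.stride(0), b.stride(1),
                               c.stride(0), c.stride(1),
                               M=SIZE_M, N=SIZE_N, K=SIZE_K,
                               num_warps=NUM_WARPS)
    # Same golden as the test: compute in fp32 on the host, then truncate.
    golden = torch.matmul(a.float(), b.float()).int()
    torch.testing.assert_close(c.cpu(), golden.cpu(), check_dtype=False)


if __name__ == "__main__":
    check_int8_gemm()
    print("int8 GEMM matches the fp32 golden result")
```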