diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index 928536f6e..61740ae22 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -2603,7 +2603,7 @@ public:
     Value ptr = getPtr(ptrIdx);
 
-    if (canUseLdmatrix) {
+    if (canUseLdmatrix) { // work with fp16
       int sOffset =
          matIdx[order[1]] * sMatStride * sMatShape * sTileStride * elemBytes;
      PTXBuilder builder;
@@ -2626,12 +2626,13 @@ public:
         return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty, v)});
       };
 
-      Type fp16x2Ty = vec_ty(type::f16Ty(ctx), 2);
+      // The struct should have exactly the same element types.
+      Type elemType = resV4.getType().cast<LLVM::LLVMStructType>().getBody()[0];
 
-      return {extract_val(fp16x2Ty, resV4, getIntAttr(0)),
-              extract_val(fp16x2Ty, resV4, getIntAttr(1)),
-              extract_val(fp16x2Ty, resV4, getIntAttr(2)),
-              extract_val(fp16x2Ty, resV4, getIntAttr(3))};
+      return {extract_val(elemType, resV4, getIntAttr(0)),
+              extract_val(elemType, resV4, getIntAttr(1)),
+              extract_val(elemType, resV4, getIntAttr(2)),
+              extract_val(elemType, resV4, getIntAttr(3))};
     } else if (elemBytes == 4 &&
                needTrans) { // Use lds.32 to load tf32 matrices
       Value ptr2 = getPtr(ptrIdx + 1);
@@ -2658,9 +2659,9 @@ public:
         elems[3] =
             load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
       }
-
       return {elems[0], elems[1], elems[2], elems[3]};
-    } else if (elemBytes == 1 && needTrans) {
+
+    } else if (elemBytes == 1 && needTrans) { // work with int8
       std::array<std::array<Value, 4>, 2> ptrs;
       ptrs[0] = {
           getPtr(ptrIdx),
@@ -2688,17 +2689,18 @@ public:
     Value i8Elems[4][4];
     Type elemTy = type::i8Ty(ctx);
+    Type elemPtrTy = ptr_ty(elemTy);
 
     if (kOrder == 1) {
       Value offset = i32_val(sOffsetElem);
 
       for (int i = 0; i < 2; ++i)
         for (int j = 0; j < 4; ++j)
-          i8Elems[i][j] = load(gep(elemTy, ptrs[i][j], offset));
+          i8Elems[i][j] = load(gep(elemPtrTy, ptrs[i][j], offset));
 
       offset = i32_val(sOffsetElem + sOffsetArrElem);
       for (int i = 2; i < 4; ++i)
         for (int j = 0; j < 4; ++j)
-          i8Elems[i][j] = load(gep(elemTy, ptrs[i - 2][j], offset));
+          i8Elems[i][j] = load(gep(elemPtrTy, ptrs[i - 2][j], offset));
 
       for (int m = 0; m < 4; ++m) {
         for (int e = 0; e < 4; ++e)
@@ -2709,14 +2711,14 @@ public:
     } else { // k first
       Value offset = i32_val(sOffsetElem);
       for (int j = 0; j < 4; ++j)
-        i8Elems[0][j] = load(gep(elemTy, ptrs[0][j], offset));
+        i8Elems[0][j] = load(gep(elemPtrTy, ptrs[0][j], offset));
       for (int j = 0; j < 4; ++j)
-        i8Elems[2][j] = load(gep(elemTy, ptrs[1][j], offset));
+        i8Elems[2][j] = load(gep(elemPtrTy, ptrs[1][j], offset));
       offset = i32_val(sOffsetElem + sOffsetArrElem);
       for (int j = 0; j < 4; ++j)
-        i8Elems[1][j] = load(gep(elemTy, ptrs[0][j], offset));
+        i8Elems[1][j] = load(gep(elemPtrTy, ptrs[0][j], offset));
       for (int j = 0; j < 4; ++j)
-        i8Elems[3][j] = load(gep(elemTy, ptrs[1][j], offset));
+        i8Elems[3][j] = load(gep(elemPtrTy, ptrs[1][j], offset));
 
       for (int m = 0; m < 4; ++m) {
         for (int e = 0; e < 4; ++e)
@@ -3501,9 +3503,10 @@ struct MMA16816ConversionHelper {
         return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty, v)});
       };
 
+      Type elemTy = mmaOut.getType().cast<LLVM::LLVMStructType>().getBody()[0];
       for (int i = 0; i < 4; ++i)
         fc[m * colsPerThread + 4 * n + i] =
-            extract_val(type::f32Ty(ctx), mmaOut, getIntAttr(i));
+            extract_val(elemTy, mmaOut, getIntAttr(i));
     };
 
     for (int k = 0; k < numRepK; ++k)
@@ -3511,9 +3514,14 @@ struct MMA16816ConversionHelper {
       for (int n = 0; n < numRepN; ++n)
         callMma(2 * m, n, 2 * k);
 
+    // bitcast to i32 in bulk
+    for (auto &elem : fc) {
+      elem = bitcast(elem, type::i32Ty(ctx));
+    }
+
     // replace with new packed result
     Type structTy = LLVM::LLVMStructType::getLiteral(
-        ctx, SmallVector<Type>(fc.size(), type::f32Ty(ctx)));
+        ctx, SmallVector<Type>(fc.size(), type::i32Ty(ctx)));
     Value res = getStructFromElements(loc, fc, rewriter, structTy);
 
     rewriter.replaceOp(op, res);
@@ -3607,10 +3615,9 @@ private:
 
     assert(!elems.empty());
 
-    Type fp16Ty = type::f16Ty(ctx);
-    Type fp16x2Ty = vec_ty(fp16Ty, 2);
+    Type elemTy = elems[0].getType();
     Type structTy = LLVM::LLVMStructType::getLiteral(
-        ctx, SmallVector<Type>(elems.size(), fp16x2Ty));
+        ctx, SmallVector<Type>(elems.size(), elemTy));
     auto result = getStructFromElements(loc, elems, rewriter, structTy);
     return result;
   }
@@ -3634,161 +3641,6 @@ private:
   }
 };
 
-// Helper for FMADot conversion.
-class DotOpFMAConversionHelper {
-public:
-  MmaEncodingAttr mmaLayout;
-  ArrayRef<unsigned> wpt;
-
-  using ValueTable = std::map<std::pair<unsigned, unsigned>, Value>;
-
-  explicit DotOpFMAConversionHelper(MmaEncodingAttr mmaLayout)
-      : mmaLayout(mmaLayout), wpt(mmaLayout.getWarpsPerCTA()) {}
-
-  // Currently, we can tell whether to use FMAdot only from the operand type,
-  // while in the original code, FMADot requires that both the operand and
-  // result of dot should be fp32.
-  // This method should be safe to use in the cases where tensor core is not
-  // appliable.
-  static bool useFMA(TensorType operand) {
-    return operand.getElementType().isF32();
-  }
-
-  Value loadA(Value tensor, Value llTensor, Value threadId, Location loc,
-              Value smem, ConversionPatternRewriter &rewriter) const {
-
-    auto *ctx = rewriter.getContext();
-    auto tensorTy = tensor.getType().cast<RankedTensorType>();
-    auto aShape = tensorTy.getShape();
-    auto aLayout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
-    auto aOrder = aLayout.getOrder();
-
-    bool isARow = aOrder[0] == 1;
-
-    int strideAM = isARow ? aShape[1] : 1;
-    int strideAK = isARow ? 1 : aShape[0];
-    int strideA0 = isARow ? strideAK : strideAM;
-    int strideA1 = isARow ? strideAM : strideAK;
-    int lda = isARow ? strideAM : strideAK;
-    int aPerPhase = aLayout.getPerPhase();
-    int aMaxPhase = aLayout.getMaxPhase();
-    int aNumPtr = 8;
-    int bNumPtr = 8;
-    int aVec = 2;
-
-    Value _0 = i32_val(0);
-    Value _1 = i32_val(1);
-
-    Value mContig = _1;
-    Value nContig = _1;
-
-    Value offA0 = isARow ? _0 : mul(threadId, mContig);
-    Value offA1 = isARow ? mul(threadId, mContig) : _0;
-    SmallVector<Value> aOff(aNumPtr);
-    for (int i = 0; i < aNumPtr; ++i) {
-      aOff[i] =
-          add(mul(offA0, i32_val(strideA0)), mul(offA1, i32_val(strideA1)));
-    }
-
-    Type f32PtrTy = ptr_ty(f32_ty);
-    SmallVector<Value> aPtrs(aNumPtr);
-    for (int i = 0; i < aNumPtr; ++i)
-      aPtrs[i] = gep(f32PtrTy, llTensor, aOff[i]);
-
-    ValueTable has;
-
-    auto aShapePerCTA = getShapePerCTA(aLayout);
-    auto sizePerThread = getSizePerThread(aLayout);
-    int M = isARow ? aShape[0] : aShape[1];
-    int K = isARow ? aShape[1] : aShape[0];
-
-    for (unsigned k = 0; k < K; k++)
-      for (unsigned m = 0; m < M; m += aShapePerCTA[aOrder[1]])
-        for (unsigned mm = 0; mm < sizePerThread[aOrder[1]]; ++mm) {
-          Value pa = gep(f32PtrTy, aPtrs[0],
-                         i32_val((m + mm) * strideAM + k * strideAK));
-          Value va = load(pa);
-          has[{m + mm, k}] = va;
-        }
-
-    SmallVector<Value> values;
-    for (auto &item : has)
-      values.push_back(item.second);
-    Type structTy =
-        struct_ty(SmallVector<Type>(values.size(), values[0].getType()));
-
-    return getStructFromElements(loc, values, rewriter, structTy);
-  }
-
-  Value loadB(Value tensor, Value llTensor, Value threadId, Location loc,
-              Value smem, ConversionPatternRewriter &rewriter) const {
-
-    auto *ctx = rewriter.getContext();
-    auto tensorTy = tensor.getType().cast<RankedTensorType>();
-    auto bShape = tensorTy.getShape();
-    auto bLayout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
-    auto bOrder = bLayout.getOrder();
-
-    bool isBRow = bOrder[0] == 1;
-
-    int strideBN = isBRow ? 1 : bShape[0];
-    int strideBK = isBRow ? bShape[1] : 1;
-    int strideB0 = isBRow ? strideBN : strideBK;
-    int strideB1 = isBRow ? strideBK : strideBN;
-    int ldb = isBRow ? strideBK : strideBN;
-    int bPerPhase = bLayout.getPerPhase();
-    int bMaxPhase = bLayout.getMaxPhase();
-    int bNumPtr = 8;
-    int bVec = 4;
-
-    auto bShapePerCTA = getShapePerCTA(bLayout);
-    auto sizePerThread = getSizePerThread(bLayout);
-
-    Value _0 = i32_val(0);
-    Value _1 = i32_val(1);
-
-    Value mContig = _1;
-    Value nContig = _1;
-
-    Value offB0 = isBRow ? mul(threadId, nContig) : _0;
-    Value offB1 = isBRow ? _0 : mul(threadId, nContig);
-    SmallVector<Value> bOff(bNumPtr);
-    for (int i = 0; i < bNumPtr; ++i) {
-      bOff[i] =
-          add(mul(offB0, i32_val(strideB0)), mul(offB1, i32_val(strideB1)));
-    }
-
-    Type f32PtrTy = ptr_ty(f32_ty);
-    SmallVector<Value> bPtrs(bNumPtr);
-    for (int i = 0; i < bNumPtr; ++i)
-      bPtrs[i] = gep(f32PtrTy, llTensor, bOff[i]);
-
-    ValueTable hbs;
-
-    int K = isBRow ? bShape[0] : bShape[1];
-    int N = isBRow ? bShape[1] : bShape[0];
-
-    for (int k = 0; k < K; ++k)
-      for (unsigned n = 0; n < N; n += bShapePerCTA[bOrder[0]])
-        for (unsigned nn = 0; nn < sizePerThread[bOrder[0]]; ++nn) {
-          Value pb = gep(f32PtrTy, bPtrs[0],
-                         i32_val((n + nn) * strideBN + k * strideBK));
-          Value vb = load(pb);
-          hbs[{n + nn, k}] = vb;
-        }
-
-    SmallVector<Value> values;
-    for (auto &item : hbs)
-      values.push_back(item.second);
-    Type structTy =
-        struct_ty(SmallVector<Type>(values.size(), values[0].getType()));
-
-    return getStructFromElements(loc, values, rewriter, structTy);
-  }
-
-  ValueTable extractLoadedOperand(Value llTensor) const { return ValueTable{}; }
-};
-
 LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
     triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
@@ -3842,15 +3694,6 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
       res = helper.loadB(src, adaptor.src(), getThreadId(rewriter, loc),
                          adaptor.src(), loc, rewriter);
     }
-  } else if (DotOpFMAConversionHelper::useFMA(dstTensorTy)) { // fmadot
-    DotOpMmaV1ConversionHelper helper(mmaLayout);
-    if (dotOperandLayout.getOpIdx() == 0) { // operand $a
-      res = helper.loadA(src, adaptor.src(), getThreadId(rewriter, loc),
-                         adaptor.src(), loc, rewriter);
-    } else if (dotOperandLayout.getOpIdx() == 1) { // operand $b
-      res = helper.loadB(src, adaptor.src(), getThreadId(rewriter, loc),
-                         adaptor.src(), loc, rewriter);
-    }
   } else {
     assert(false && "Unsupported mma layout found");
   }
@@ -4321,6 +4164,8 @@ DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
   auto loc = op.getLoc();
   auto threadId = getThreadId(rewriter, loc);
 
+  using ValueTable = std::map<std::pair<unsigned, unsigned>, Value>;
+
   auto A = op.a();
   auto B = op.b();
   auto C = op.c();
@@ -4400,8 +4245,7 @@ DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
   for (int i = 0; i < bNumPtr; ++i)
     bPtrs[i] = gep(f32PtrTy, adaptor.b(), bOff[i]);
 
-  // TODO initialize ret with $c.
-  DotOpFMAConversionHelper::ValueTable has, hbs;
+  ValueTable has, hbs;
 
   auto cc = getElementsFromStruct(loc, adaptor.c(), rewriter);
   SmallVector<Value> ret = cc;
diff --git a/python/tests/test_gemm.py b/python/tests/test_gemm.py
index e1a27ec74..df337988c 100644
--- a/python/tests/test_gemm.py
+++ b/python/tests/test_gemm.py
@@ -144,3 +144,50 @@ def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLO
     torch.set_printoptions(profile="full")
     assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err),
                  atol=max(1e-4, 1.5 * golden_abs_err), check_dtype=False)
+
+
+# Precision regression for FMADot is not done yet, due to an issue where the optimizer fails to give a blocked layout to the dot op.
+# TODO[Superjomn]: Uncomment this test and finish the precision regression later.
+# @pytest.mark.parametrize('M,N,K,num_warps,block_M,block_N,block_K', [
+#     [128, 256, 128, 4, 128, 256, 32],
+#     [256, 128, 64, 4, 256, 128, 16],
+#     [128, 64, 128, 4, 128, 64, 32],
+# ])
+# def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K):
+#     @triton.jit
+#     def matmul_kernel(
+#         a_ptr, b_ptr, c_ptr,
+#         stride_am, stride_ak,
+#         stride_bk, stride_bn,
+#         stride_cm, stride_cn,
+#         K: tl.constexpr,
+#         BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
+#     ):
+#         offs_m = tl.arange(0, BLOCK_SIZE_M)
+#         offs_n = tl.arange(0, BLOCK_SIZE_N)
+#         offs_k = tl.arange(0, BLOCK_SIZE_K)
+#         a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
+#         b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
+#         accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+#         for k in range(0, K, BLOCK_SIZE_K):
+#             a = tl.load(a_ptrs)
+#             b = tl.load(b_ptrs)
+#             accumulator += tl.dot(a, b, allow_tf32=True)
+#             a_ptrs += BLOCK_SIZE_K * stride_ak
+#             b_ptrs += BLOCK_SIZE_K * stride_bk
+
+#         c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
+#         tl.store(c_ptrs, accumulator)
+
+#     a = torch.randn((M, K), device='cuda', dtype=torch.float32)
+#     b = torch.randn((K, N), device='cuda', dtype=torch.float)
+#     c = torch.empty((M, N), device=a.device, dtype=torch.float32)
+#     grid = lambda META: (1, )
+#     matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c,
+#                         stride_am=a.stride(0), stride_ak=a.stride(1),
+#                         stride_bk=b.stride(0), stride_bn=b.stride(1),
+#                         stride_cm=c.stride(0), stride_cn=c.stride(1),
+#                         K=a.shape[1], BLOCK_SIZE_M=block_M, BLOCK_SIZE_N=block_N,
+#                         BLOCK_SIZE_K=block_K, num_warps=num_warps)
+#     golden = torch.matmul(a, b)
+#     torch.testing.assert_close(c, golden)
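
Note on re-enabling the precision check: the commented-out test above compares against torch.matmul directly, while test_gemm compares against error bounds (golden_abs_err, golden_rel_err) derived from an fp64 reference. A minimal sketch of how such bounds could be computed for the FMA path is below; this is not part of the patch, and the helper name gemm_error_bounds is hypothetical.

import torch

def gemm_error_bounds(a: torch.Tensor, b: torch.Tensor):
    # Hypothetical helper: error of a same-dtype matmul relative to an fp64
    # reference, i.e. the quantity test_gemm scales by 1.5 in assert_close.
    golden = torch.matmul(a, b)
    reference = torch.matmul(a.double(), b.double())
    abs_err = (golden.double() - reference).abs().max().item()
    rel_err = abs_err / reference.abs().max().item()
    return abs_err, rel_err

# Example use, mirroring the pattern in test_gemm above (c is the Triton output):
# abs_err, rel_err = gemm_error_bounds(a, b)
# torch.testing.assert_close(c, torch.matmul(a, b),
#                            rtol=max(1e-4, 1.5 * rel_err),
#                            atol=max(1e-4, 1.5 * abs_err), check_dtype=False)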