diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index 9f7a1cd62..2d79611f6 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -3201,10 +3201,7 @@ private:
                               ConversionPatternRewriter &rewriter) const;
 
   LogicalResult convertFMADot(triton::DotOp op, OpAdaptor adaptor,
-                              ConversionPatternRewriter &rewriter) const {
-    assert(false && "Not implemented yet.");
-    return failure();
-  }
+                              ConversionPatternRewriter &rewriter) const;
 };
 
 // Helper for conversion of DotOp with mma<version=1>, that is sm<80
@@ -4497,6 +4494,155 @@ DotOpMmaV1ConversionHelper::extractLoadedOperand(
   return rcds;
 }
 
+LogicalResult
+DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
+                               ConversionPatternRewriter &rewriter) const {
+  auto *ctx = rewriter.getContext();
+  auto loc = op.getLoc();
+  auto threadId = getThreadId(rewriter, loc);
+
+  using ValueTable = std::map<std::pair<int, int>, Value>;
+
+  auto A = op.a();
+  auto B = op.b();
+  auto C = op.c();
+  auto D = op.getResult();
+
+  auto aTensorTy = A.getType().cast<RankedTensorType>();
+  auto bTensorTy = B.getType().cast<RankedTensorType>();
+  auto cTensorTy = C.getType().cast<RankedTensorType>();
+  auto dTensorTy = D.getType().cast<RankedTensorType>();
+
+  auto aShape = aTensorTy.getShape();
+  auto bShape = bTensorTy.getShape();
+  auto cShape = cTensorTy.getShape();
+
+  auto aLayout = aTensorTy.getEncoding().cast<SharedEncodingAttr>();
+  auto bLayout = bTensorTy.getEncoding().cast<SharedEncodingAttr>();
+  auto cLayout = cTensorTy.getEncoding().cast<BlockedEncodingAttr>();
+  auto dLayout = dTensorTy.getEncoding().cast<BlockedEncodingAttr>();
+
+  auto aOrder = aLayout.getOrder();
+  auto bOrder = bLayout.getOrder();
+
+  auto order = dLayout.getOrder();
+
+  bool isARow = aOrder[0] == 1;
+  bool isBRow = bOrder[0] == 1;
+
+  int strideAM = isARow ? aShape[1] : 1;
+  int strideAK = isARow ? 1 : aShape[0];
+  int strideBN = isBRow ? 1 : bShape[0];
+  int strideBK = isBRow ? bShape[1] : 1;
+  int strideA0 = isARow ? strideAK : strideAM;
+  int strideA1 = isARow ? strideAM : strideAK;
+  int strideB0 = isBRow ? strideBN : strideBK;
+  int strideB1 = isBRow ? strideBK : strideBN;
+  int lda = isARow ? strideAM : strideAK;
+  int ldb = isBRow ? strideBK : strideBN;
+  int aPerPhase = aLayout.getPerPhase();
+  int aMaxPhase = aLayout.getMaxPhase();
+  int bPerPhase = bLayout.getPerPhase();
+  int bMaxPhase = bLayout.getMaxPhase();
+  int aNumPtr = 8;
+  int bNumPtr = 8;
+  int NK = aShape[1];
+
+  auto shapePerCTA = getShapePerCTA(dLayout);
+
+  auto sizePerThread = getSizePerThread(dLayout);
+
+  Value _0 = i32_val(0);
+
+  Value mContig = i32_val(sizePerThread[order[1]]);
+  Value nContig = i32_val(sizePerThread[order[0]]);
+
+  // threadId in blocked layout
+  SmallVector<Value> threadIds;
+  {
+    int dim = cShape.size();
+    threadIds.resize(dim);
+    for (unsigned k = 0; k < dim - 1; k++) {
+      Value dimK = i32_val(shapePerCTA[order[k]]);
+      Value rem = urem(threadId, dimK);
+      threadId = udiv(threadId, dimK);
+      threadIds[order[k]] = rem;
+    }
+    Value dimK = i32_val(shapePerCTA[order[dim - 1]]);
+    threadIds[order[dim - 1]] = urem(threadId, dimK);
+  }
+
+  Value threadIdM = threadIds[0];
+  Value threadIdN = threadIds[1];
+
+  Value offA0 = isARow ? _0 : mul(threadIdM, mContig);
+  Value offA1 = isARow ? mul(threadIdM, mContig) : _0;
+  SmallVector<Value> aOff(aNumPtr);
+  for (int i = 0; i < aNumPtr; ++i) {
+    aOff[i] = add(mul(offA0, i32_val(strideA0)), mul(offA1, i32_val(strideA1)));
+  }
+
+  Value offB0 = isBRow ? mul(threadIdN, nContig) : _0;
+  Value offB1 = isBRow ? _0 : mul(threadIdN, nContig);
+  SmallVector<Value> bOff(bNumPtr);
+  for (int i = 0; i < bNumPtr; ++i) {
+    bOff[i] = add(mul(offB0, i32_val(strideB0)), mul(offB1, i32_val(strideB1)));
+  }
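+  // A and B are read directly from shared memory: the pointers built below are
+  // GEPs off the shared-memory bases, and the FMA loop caches loaded elements
+  // in the `has`/`hbs` tables so each (index, k) element is loaded only once.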
+
+  auto aSmem = getSharedMemoryObjectFromStruct(loc, adaptor.a(), rewriter);
+  auto bSmem = getSharedMemoryObjectFromStruct(loc, adaptor.b(), rewriter);
+
+  Type f32PtrTy = ptr_ty(f32_ty);
+  SmallVector<Value> aPtrs(aNumPtr);
+  for (int i = 0; i < aNumPtr; ++i)
+    aPtrs[i] = gep(f32PtrTy, aSmem.base, aOff[i]);
+
+  SmallVector<Value> bPtrs(bNumPtr);
+  for (int i = 0; i < bNumPtr; ++i)
+    bPtrs[i] = gep(f32PtrTy, bSmem.base, bOff[i]);
+
+  ValueTable has, hbs;
+  auto cc = getElementsFromStruct(loc, adaptor.c(), rewriter);
+  SmallVector<Value> ret = cc;
+  // is this compatible with blocked layout?
+
+  for (unsigned k = 0; k < NK; k++) {
+    int z = 0;
+    for (unsigned i = 0; i < cShape[order[1]]; i += shapePerCTA[order[1]])
+      for (unsigned j = 0; j < cShape[order[0]]; j += shapePerCTA[order[0]])
+        for (unsigned ii = 0; ii < sizePerThread[order[1]]; ++ii)
+          for (unsigned jj = 0; jj < sizePerThread[order[0]]; ++jj) {
+            unsigned m = order[0] == 1 ? i : j;
+            unsigned n = order[0] == 1 ? j : i;
+            unsigned mm = order[0] == 1 ? ii : jj;
+            unsigned nn = order[0] == 1 ? jj : ii;
+            if (!has.count({m + mm, k})) {
+              Value pa = gep(f32PtrTy, aPtrs[0],
+                             i32_val((m + mm) * strideAM + k * strideAK));
+              Value va = load(pa);
+              has[{m + mm, k}] = va;
+            }
+            if (!hbs.count({n + nn, k})) {
+              Value pb = gep(f32PtrTy, bPtrs[0],
+                             i32_val((n + nn) * strideBN + k * strideBK));
+              Value vb = load(pb);
+              hbs[{n + nn, k}] = vb;
+            }
+
+            ret[z] = rewriter.create<LLVM::FMulAddOp>(loc, has[{m + mm, k}],
+                                                      hbs[{n + nn, k}], ret[z]);
+            ++z;
+          }
+  }
+
+  auto res = getStructFromElements(
+      loc, ret, rewriter,
+      struct_ty(SmallVector<Type>(ret.size(), ret[0].getType())));
+  rewriter.replaceOp(op, res);
+
+  return success();
+}
+
 /// ====================== mma codegen end ============================
 
 Value convertSplatLikeOpWithMmaLayout(const MmaEncodingAttr &layout,
diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp
index 416c82671..d7c622e50 100644
--- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp
@@ -576,6 +576,14 @@ public:
     auto oldRetType = dotOp.getResult().getType().cast<RankedTensorType>();
     if (oldRetType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
       return failure();
+
+    auto A = dotOp.getOperand(0).getType().cast<RankedTensorType>();
+    auto B = dotOp.getOperand(1).getType().cast<RankedTensorType>();
+    // for FMA, should retain the blocked layout.
+    if (A.getElementType().isF32() && B.getElementType().isF32() &&
+        !dotOp.allowTF32())
+      return failure();
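+    // An f32 x f32 dot with TF32 disabled has no tensor-core path, so its
+    // result keeps the blocked layout and is lowered via the FMA path.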
+
     // get MMA encoding for the given number of warps
     auto retShape = oldRetType.getShape();
     auto mod = op->getParentOfType<ModuleOp>();
@@ -629,4 +637,4 @@ public:
 
 std::unique_ptr<Pass> mlir::createTritonGPUCombineOpsPass() {
   return std::make_unique<TritonGPUCombineOpsPass>();
-}
\ No newline at end of file
+}
diff --git a/python/tests/test_gemm.py b/python/tests/test_gemm.py
index e8326c078..e264fe086 100644
--- a/python/tests/test_gemm.py
+++ b/python/tests/test_gemm.py
@@ -169,3 +169,65 @@ def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLO
     torch.set_printoptions(profile="full")
     assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err),
                  atol=max(1e-4, 1.5 * golden_abs_err), check_dtype=False)
+
+
+@pytest.mark.parametrize('M,N,K,num_warps,block_M,block_N,block_K', [
+    [32, 32, 16, 4, 32, 32, 16],
+    [32, 16, 16, 4, 32, 32, 16],
+    [128, 8, 8, 4, 32, 32, 16],
+    [127, 41, 43, 4, 32, 32, 16],
+])
+def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K):
+    @triton.jit
+    def matmul_kernel(
+        a_ptr, b_ptr, c_ptr,
+        M, N, K,
+        stride_am, stride_ak,
+        stride_bk, stride_bn,
+        stride_cm, stride_cn,
+        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
+    ):
+        pid = tl.program_id(axis=0)
+        # num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+        pid_m = pid // num_pid_n
+        pid_n = pid % num_pid_n
+
+        offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+        offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+        offs_k = tl.arange(0, BLOCK_SIZE_K)
+        a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+        b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+
+        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+        for k in range(0, K, BLOCK_SIZE_K):
+            a_mask = (offs_am[:, None] < M) & (offs_k[None, :] < K)
+            b_mask = (offs_k[:, None] < K) & (offs_bn[None, :] < N)
+            a = tl.load(a_ptrs, a_mask)
+            b = tl.load(b_ptrs, b_mask)
+            # NOTE the allow_tf32 should be false to force the dot op to do fmadot lowering
+            accumulator += tl.dot(a, b, allow_tf32=False)
+            a_ptrs += BLOCK_SIZE_K * stride_ak
+            b_ptrs += BLOCK_SIZE_K * stride_bk
+            offs_k += BLOCK_SIZE_K
+
+        offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+        c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
+        c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+        tl.store(c_ptrs, accumulator, c_mask)
+
+    a = torch.randn((M, K), device='cuda', dtype=torch.float32)
+    b = torch.randn((K, N), device='cuda', dtype=torch.float32)
+    c = torch.empty((M, N), device=a.device, dtype=torch.float32)
+
+    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),)
+    matmul_kernel[grid](a, b, c,
+                        M, N, K,
+                        stride_am=a.stride(0), stride_ak=a.stride(1),
+                        stride_bk=b.stride(0), stride_bn=b.stride(1),
+                        stride_cm=c.stride(0), stride_cn=c.stride(1),
+                        BLOCK_SIZE_M=block_M, BLOCK_SIZE_N=block_N, BLOCK_SIZE_K=block_K)
+
+    golden = torch.matmul(a, b)
+    torch.testing.assert_close(c, golden)
diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir
index 0acd64e29..30da6ad52 100644
--- a/test/Conversion/tritongpu_to_llvm.mlir
+++ b/test/Conversion/tritongpu_to_llvm.mlir
@@ -811,3 +811,22 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
     return
   }
 }
+
+// -----
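+// An f32 x f32 dot with allowTF32 = false keeps its blocked result layout and
+// is expected to lower through the FMA path, i.e. to llvm.intr.fmuladd.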
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
+#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#mma = #triton_gpu.mma<{version = 2, warpsPerCTA = [2, 2]}>
+#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
+#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+  func @matmul_fmadot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
+                      %a:tensor<32x16xf32, #shared>, %b:tensor<16x32xf32, #shared>) {
+    %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked>
+    // CHECK: llvm.intr.fmuladd
+    %28 = tt.dot %a, %b, %cst {allowTF32 = false, transA = false, transB = false} : tensor<32x16xf32, #shared> * tensor<16x32xf32, #shared> -> tensor<32x32xf32, #blocked>
+    %30 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<32x1x!tt.ptr<f32>, #blocked>
+    %36 = tt.broadcast %30 : (tensor<32x1x!tt.ptr<f32>, #blocked>) -> tensor<32x32x!tt.ptr<f32>, #blocked>
+    tt.store %36, %28 : tensor<32x32xf32, #blocked>
+    return
+  }
+}