diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpHelpers.h b/lib/Conversion/TritonGPUToLLVM/DotOpHelpers.h
index fa94ba3f3..2a360f4b1 100644
--- a/lib/Conversion/TritonGPUToLLVM/DotOpHelpers.h
+++ b/lib/Conversion/TritonGPUToLLVM/DotOpHelpers.h
@@ -204,7 +204,12 @@ struct DotOpMmaV1ConversionHelper {
       offA[i] = add(mul(offA0I, strideA0), mul(offA1, strideA1));
     }

-    Type f16x2Ty = vec_ty(f16_ty, 2);
+    Type elemX2Ty = vec_ty(f16_ty, 2);
+    Type elemPtrTy = ptr_ty(f16_ty);
+    if (tensorTy.getElementType().isBF16()) {
+      elemX2Ty = vec_ty(i16_ty, 2);
+      elemPtrTy = ptr_ty(i16_ty);
+    }

     // prepare arguments
     SmallVector<Value> ptrA(numPtrA);
@@ -213,30 +218,28 @@ struct DotOpMmaV1ConversionHelper {
     for (int i = 0; i < numPtrA; i++)
       ptrA[i] = gep(ptr_ty(f16_ty), smemBase, offA[i]);

-    Type f16PtrTy = ptr_ty(f16_ty);
-
     auto ld = [&](decltype(has) &vals, int m, int k, Value val0, Value val1) {
       vals[{m, k}] = {val0, val1};
     };
     auto loadA = [&](int m, int k) {
       int offidx = (isARow ? k / 4 : m) % numPtrA;
-      Value thePtrA = gep(f16PtrTy, smemBase, offA[offidx]);
+      Value thePtrA = gep(elemPtrTy, smemBase, offA[offidx]);

       int stepAM = isARow ? m : m / numPtrA * numPtrA;
       int stepAK = isARow ? k / (numPtrA * vecA) * (numPtrA * vecA) : k;
       Value offset = add(mul(i32_val(stepAM * strideRepM), strideAM),
                          mul(i32_val(stepAK), strideAK));
-      Value pa = gep(f16PtrTy, thePtrA, offset);
+      Value pa = gep(elemPtrTy, thePtrA, offset);
       Type aPtrTy = ptr_ty(vec_ty(i32_ty, std::max(vecA / 2, 1)), 3);
       Value ha = load(bitcast(pa, aPtrTy));
       // record lds that needs to be moved
-      Value ha00 = bitcast(extract_element(ha, i32_val(0)), f16x2Ty);
-      Value ha01 = bitcast(extract_element(ha, i32_val(1)), f16x2Ty);
+      Value ha00 = bitcast(extract_element(ha, i32_val(0)), elemX2Ty);
+      Value ha01 = bitcast(extract_element(ha, i32_val(1)), elemX2Ty);
       ld(has, m, k, ha00, ha01);
       if (vecA > 4) {
-        Value ha10 = bitcast(extract_element(ha, i32_val(2)), f16x2Ty);
-        Value ha11 = bitcast(extract_element(ha, i32_val(3)), f16x2Ty);
+        Value ha10 = bitcast(extract_element(ha, i32_val(2)), elemX2Ty);
+        Value ha11 = bitcast(extract_element(ha, i32_val(3)), elemX2Ty);
         if (isARow)
           ld(has, m, k + 4, ha10, ha11);
         else
@@ -256,7 +259,7 @@ struct DotOpMmaV1ConversionHelper {
       elems.push_back(item.second.second);
     }

-    Type resTy = struct_ty(SmallVector<Type>(elems.size(), f16x2Ty));
+    Type resTy = struct_ty(SmallVector<Type>(elems.size(), elemX2Ty));
     Value res = getStructFromElements(loc, elems, rewriter, resTy);
     return res;
   }
@@ -319,8 +322,12 @@ struct DotOpMmaV1ConversionHelper {
       offB[i] = add(mul(offB0I, strideB0), mul(offB1, strideB1));
     }

-    Type f16PtrTy = ptr_ty(f16_ty);
-    Type f16x2Ty = vec_ty(f16_ty, 2);
+    Type elemPtrTy = ptr_ty(f16_ty);
+    Type elemX2Ty = vec_ty(f16_ty, 2);
+    if (tensorTy.getElementType().isBF16()) {
+      elemPtrTy = ptr_ty(i16_ty);
+      elemX2Ty = vec_ty(i16_ty, 2);
+    }

     SmallVector<Value> ptrB(numPtrB);
     ValueTable hbs;
@@ -339,17 +346,17 @@ struct DotOpMmaV1ConversionHelper {
       int stepBK = isBRow ? K : K / (numPtrB * vecB) * (numPtrB * vecB);
       Value offset = add(mul(i32_val(stepBN * strideRepN), strideBN),
                          mul(i32_val(stepBK), strideBK));
-      Value pb = gep(f16PtrTy, thePtrB, offset);
+      Value pb = gep(elemPtrTy, thePtrB, offset);
       Value hb =
           load(bitcast(pb, ptr_ty(vec_ty(i32_ty, std::max(vecB / 2, 1)), 3)));
       // record lds that needs to be moved
-      Value hb00 = bitcast(extract_element(hb, i32_val(0)), f16x2Ty);
-      Value hb01 = bitcast(extract_element(hb, i32_val(1)), f16x2Ty);
+      Value hb00 = bitcast(extract_element(hb, i32_val(0)), elemX2Ty);
+      Value hb01 = bitcast(extract_element(hb, i32_val(1)), elemX2Ty);
       ld(hbs, n, K, hb00, hb01);
       if (vecB > 4) {
-        Value hb10 = bitcast(extract_element(hb, i32_val(2)), f16x2Ty);
-        Value hb11 = bitcast(extract_element(hb, i32_val(3)), f16x2Ty);
+        Value hb10 = bitcast(extract_element(hb, i32_val(2)), elemX2Ty);
+        Value hb11 = bitcast(extract_element(hb, i32_val(3)), elemX2Ty);
         if (isBRow)
           ld(hbs, n + 1, K, hb10, hb11);
         else
@@ -369,8 +376,7 @@ struct DotOpMmaV1ConversionHelper {
       elems.push_back(item.second.first);
       elems.push_back(item.second.second);
     }
-    Type fp16x2Ty = vec_ty(type::f16Ty(ctx), 2);
-    Type resTy = struct_ty(SmallVector<Type>(elems.size(), fp16x2Ty));
+    Type resTy = struct_ty(SmallVector<Type>(elems.size(), elemX2Ty));
     Value res = getStructFromElements(loc, elems, rewriter, resTy);
     return res;
   }
diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp
index 7dcdc0162..5e391e3c1 100644
--- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp
@@ -22,6 +22,10 @@ using namespace mlir;

 namespace {
 #include "TritonGPUCombine.inc"
+using triton::DotOp;
+using triton::gpu::ConvertLayoutOp;
+using triton::gpu::DotOperandEncodingAttr;
+using triton::gpu::MmaEncodingAttr;

 // -----------------------------------------------------------------------------
 //
@@ -1019,6 +1023,7 @@ public:
         dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
     if ((order[0] == 1 && isMMAv1Row) || (order[0] == 0 && !isMMAv1Row))
       return failure();
+    auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row);
     auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
         op->getContext(), dstDotOperandLayout.getOpIdx(),
@@ -1060,7 +1065,8 @@ public:
     auto dotOp = cast<triton::DotOp>(op);
     // TODO: Check data-types and SM compatibility
     auto oldRetType = dotOp.getResult().getType().cast<RankedTensorType>();
-    if (oldRetType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
+    if (!oldRetType.getEncoding() ||
+        oldRetType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
       return failure();

     auto AType = dotOp.getOperand(0).getType().cast<RankedTensorType>();
@@ -1170,7 +1176,8 @@ public:
     for (size_t i = 0; i < newInitArgs.size(); i++) {
       auto initArg = newInitArgs[i];
       auto regionArg = forOp.getRegionIterArgs()[i];
-      if (newInitArgs[i].getType() != forOp.getRegionIterArgs()[i].getType()) {
+      if (newInitArgs[i].getType() != forOp.getRegionIterArgs()[i].getType() ||
+          newInitArgs[i].getType() != forOp.getResultTypes()[i]) {
         shouldRematerialize = true;
         break;
       }
@@ -1186,15 +1193,207 @@ public:
     BlockAndValueMapping mapping;
     for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
       mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
+    mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());

     for (Operation &op : forOp.getBody()->getOperations()) {
-      Operation *newOp = rewriter.clone(op, mapping);
+      rewriter.clone(op, mapping);
     }
     rewriter.replaceOp(forOp, newForOp.getResults());

     return success();
   }
 };

+// This pattern collects the MMA layouts that need to be updated and records the
+// corrected layout to use for each.
+class CollectMmaToUpdateForVolta : public mlir::RewritePattern {
+  DenseMap<MmaEncodingAttr, MmaEncodingAttr> &mmaToUpdate;
+
+public:
+  CollectMmaToUpdateForVolta(
+      mlir::MLIRContext *ctx,
+      DenseMap<MmaEncodingAttr, MmaEncodingAttr> &mmaToUpdate)
+      : mlir::RewritePattern(triton::DotOp::getOperationName(), 1, ctx),
+        mmaToUpdate(mmaToUpdate) {}
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::Operation *op,
+                  mlir::PatternRewriter &rewriter) const override {
+
+    auto dotOp = cast<triton::DotOp>(op);
+    auto *ctx = dotOp->getContext();
+    auto AT = dotOp.a().getType().cast<RankedTensorType>();
+    auto BT = dotOp.b().getType().cast<RankedTensorType>();
+    auto DT = dotOp.d().getType().cast<RankedTensorType>();
+    if (!DT.getEncoding())
+      return failure();
+    auto mmaLayout = DT.getEncoding().dyn_cast<MmaEncodingAttr>();
+    if (!(mmaLayout && mmaLayout.isVolta()))
+      return failure();
+
+    // Already processed.
+    if (mmaToUpdate.count(mmaLayout))
+      return failure();
+
+    auto dotOperandA = AT.getEncoding().cast<DotOperandEncodingAttr>();
+    auto dotOperandB = BT.getEncoding().cast<DotOperandEncodingAttr>();
+    bool isARow = dotOperandA.getIsMMAv1Row().cast<BoolAttr>().getValue();
+    bool isBRow = dotOperandB.getIsMMAv1Row().cast<BoolAttr>().getValue();
+    auto [isARow_, isBRow_, isAVec4, isBVec4] =
+        mmaLayout.decodeVoltaLayoutStates();
+    if (isARow_ == isARow && isBRow_ == isBRow) {
+      return failure(); // No need to update.
+    }
+
+    auto newMmaLayout = MmaEncodingAttr::get(
+        ctx, mmaLayout.getVersionMajor(), mmaLayout.getWarpsPerCTA(),
+        AT.getShape(), BT.getShape(), isARow, isBRow);
+
+    // Record the stale MMA layout together with its replacement.
+    mmaToUpdate.try_emplace(mmaLayout, newMmaLayout);
+
+    return failure();
+  }
+};
+
+// Correct the versionMinor field in MmaEncodingAttr for Volta.
+class UpdateMMAVersionMinorForVolta : public mlir::RewritePattern {
+  const DenseMap<MmaEncodingAttr, MmaEncodingAttr> &mmaToUpdate;
+  enum class Kind {
+    kUnk,
+    kCvtToMma,
+    kCvtToDotOp,
+    kDot,
+    kConstant,
+  };
+  mutable Kind rewriteKind{Kind::kUnk};
+
+public:
+  UpdateMMAVersionMinorForVolta(
+      mlir::MLIRContext *ctx, llvm::StringRef opName,
+      const DenseMap<MmaEncodingAttr, MmaEncodingAttr> &mmaToUpdate)
+      : RewritePattern(opName, 1 /*benefit*/, ctx), mmaToUpdate(mmaToUpdate) {}
+
+  LogicalResult match(Operation *op) const override {
+    MmaEncodingAttr mma;
+    if (mmaToUpdate.empty())
+      return failure();
+    if (op->getNumResults() != 1)
+      return failure();
+    auto tensorTy = op->getResult(0).getType().dyn_cast<RankedTensorType>();
+    if (!tensorTy)
+      return failure();
+
+    // ConvertLayoutOp
+    if (auto cvt = llvm::dyn_cast<ConvertLayoutOp>(op)) {
+      // cvt X -> dot_operand
+      if (auto dotOperand =
+              tensorTy.getEncoding().dyn_cast<DotOperandEncodingAttr>()) {
+        mma = dotOperand.getParent().dyn_cast<MmaEncodingAttr>();
+        rewriteKind = Kind::kCvtToDotOp;
+        if (mma && mmaToUpdate.count(mma))
+          return success();
+      }
+      if ((mma = tensorTy.getEncoding().dyn_cast<MmaEncodingAttr>())) {
+        // cvt X -> mma
+        rewriteKind = Kind::kCvtToMma;
+        if (mma && mmaToUpdate.count(mma))
+          return success();
+      }
+    } else if (auto dot = llvm::dyn_cast<DotOp>(op)) {
+      // DotOp
+      mma = dot.d()
+                .getType()
+                .cast<RankedTensorType>()
+                .getEncoding()
+                .dyn_cast<MmaEncodingAttr>();
+      rewriteKind = Kind::kDot;
+    } else if (auto constant = llvm::dyn_cast<arith::ConstantOp>(op)) {
+      // ConstantOp
+      mma = tensorTy.getEncoding().dyn_cast<MmaEncodingAttr>();
+      rewriteKind = Kind::kConstant;
+    }
+
+    return success(mma && mmaToUpdate.count(mma));
+  }
+
+  void rewrite(Operation *op, PatternRewriter &rewriter) const override {
+    switch (rewriteKind) {
+    case Kind::kDot:
+      rewriteDot(op, rewriter);
+      break;
+    case Kind::kConstant:
+      rewriteConstant(op, rewriter);
+      break;
+    case Kind::kCvtToDotOp:
+      rewriteCvtDotOp(op, rewriter);
+      break;
+    case Kind::kCvtToMma:
+      rewriteCvtToMma(op, rewriter);
+      break;
+    default:
+      llvm::report_fatal_error("Not supported rewrite kind");
+    }
+  }
+
+private:
+  void rewriteCvtDotOp(Operation *op, PatternRewriter &rewriter) const {
+    auto *ctx = op->getContext();
+    auto cvt = llvm::cast<ConvertLayoutOp>(op);
+    auto tensorTy = cvt.result().getType().cast<RankedTensorType>();
+    auto dotOperand = tensorTy.getEncoding().cast<DotOperandEncodingAttr>();
+    MmaEncodingAttr newMma =
+        mmaToUpdate.lookup(dotOperand.getParent().cast<MmaEncodingAttr>());
+    auto newDotOperand = DotOperandEncodingAttr::get(
+        ctx, dotOperand.getOpIdx(), newMma, dotOperand.getIsMMAv1Row());
+    auto newTensorTy = RankedTensorType::get(
+        tensorTy.getShape(), tensorTy.getElementType(), newDotOperand);
+    rewriter.replaceOpWithNewOp<ConvertLayoutOp>(op, newTensorTy,
+                                                 cvt.getOperand());
+  }
+
+  void rewriteDot(Operation *op, PatternRewriter &rewriter) const {
+    auto *ctx = op->getContext();
+    auto dot = llvm::cast<DotOp>(op);
+    auto tensorTy = dot.d().getType().cast<RankedTensorType>();
+    auto mma = tensorTy.getEncoding().cast<MmaEncodingAttr>();
+    auto newMma = mmaToUpdate.lookup(mma);
+    auto newTensorTy = RankedTensorType::get(tensorTy.getShape(),
+                                             tensorTy.getElementType(), newMma);
+    rewriter.replaceOpWithNewOp<DotOp>(op, newTensorTy, dot.a(), dot.b(),
+                                       dot.c(), dot.allowTF32());
+  }
+
+  void rewriteCvtToMma(Operation *op, PatternRewriter &rewriter) const {
+    auto *ctx = op->getContext();
+    auto cvt = llvm::cast<ConvertLayoutOp>(op);
+    auto tensorTy = cvt.result().getType().cast<RankedTensorType>();
+    auto mma = tensorTy.getEncoding().cast<MmaEncodingAttr>();
+    auto newMma = mmaToUpdate.lookup(mma);
+    auto newTensorTy = RankedTensorType::get(tensorTy.getShape(),
+                                             tensorTy.getElementType(), newMma);
+    rewriter.replaceOpWithNewOp<ConvertLayoutOp>(op, newTensorTy,
+                                                 cvt.getOperand());
+  }
+
+  void rewriteConstant(Operation *op, PatternRewriter &rewriter) const {
+    auto *ctx = op->getContext();
+    auto constant = llvm::cast<arith::ConstantOp>(op);
+    auto tensorTy = constant.getResult().getType().dyn_cast<RankedTensorType>();
+    auto mma = tensorTy.getEncoding().cast<MmaEncodingAttr>();
+    auto newMma = mmaToUpdate.lookup(mma);
+    auto newTensorTy = RankedTensorType::get(tensorTy.getShape(),
+                                             tensorTy.getElementType(), newMma);
+    if (auto attr = constant.getValue().dyn_cast<SplatElementsAttr>()) {
+      auto newRet =
+          SplatElementsAttr::get(newTensorTy, attr.getSplatValue<Attribute>());
+      rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newTensorTy, newRet);
+      return;
+    }
+
+    assert(false && "Not supported ConstantOp value type");
+  }
+};
+
 } // namespace

 #define GEN_PASS_CLASSES
@@ -1229,6 +1428,28 @@ public:
       signalPassFailure();
     }

+    llvm::DenseMap<MmaEncodingAttr, MmaEncodingAttr> mmaToUpdate;
+    {
+      mlir::RewritePatternSet patterns(context);
+      patterns.add<CollectMmaToUpdateForVolta>(context, mmaToUpdate);
+      if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
+        signalPassFailure();
+    }
+    {
+      mlir::RewritePatternSet patterns(context);
+      patterns.add<UpdateMMAVersionMinorForVolta>(
+          context, DotOp::getOperationName(), mmaToUpdate);
+      patterns.add<UpdateMMAVersionMinorForVolta>(
+          context, ConvertLayoutOp::getOperationName(), mmaToUpdate);
+      patterns.add<UpdateMMAVersionMinorForVolta>(
+          context, arith::ConstantOp::getOperationName(), mmaToUpdate);
+      mlir::GreedyRewriteConfig config;
+      config.useTopDownTraversal = true;
+
+      if (applyPatternsAndFoldGreedily(m, std::move(patterns), config).failed())
+        signalPassFailure();
+    }
+
     mlir::RewritePatternSet loopFixup(context);
     loopFixup.add(context);
     if (applyPatternsAndFoldGreedily(m, std::move(loopFixup)).failed()) {
diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir
index b4d2da376..a040873d8 100644
--- a/test/TritonGPU/combine.mlir
+++ b/test/TritonGPU/combine.mlir
@@ -1,4 +1,4 @@
-// RUN: triton-opt %s -tritongpu-combine 2>&1 | FileCheck %s
+// RUN: triton-opt %s -split-input-file -tritongpu-combine 2>&1 | FileCheck %s

 #layout0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
 #layout1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
@@ -7,7 +7,6 @@
 // CHECK: [[row_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
 // CHECK: [[col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}>
 // CHECK: [[col_layout_novec:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
-
 func @cst() -> tensor<1024xi32, #layout1> {
   %cst = arith.constant dense<0> : tensor<1024xi32, #layout0>
   %1 = triton_gpu.convert_layout %cst : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1>
@@ -62,9 +61,9 @@ func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {
 // CHECK-LABEL: transpose
 func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
   // CHECK-NOT: triton_gpu.convert_layout
-  // CHECK: [[loaded_val:%.*]] = tt.load {{.*}}, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, [[row_layout]]>
+  // CHECK: [[loaded_val:%.*]] = tt.load {{.*}}, {{%cst.*}}, {{%cst.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, [[row_layout]]>
   // CHECK: [[cvt_val:%.*]] = triton_gpu.convert_layout [[loaded_val]] : (tensor<64x64xf32, [[row_layout]]>) -> tensor<64x64xf32, [[col_layout]]>
-  // CHECK: tt.store {{.*}}, [[cvt_val]], %cst_1 : tensor<64x64xf32, [[col_layout]]>
+  // CHECK: tt.store {{.*}}, [[cvt_val]], {{%cst.*}} : tensor<64x64xf32, [[col_layout]]>
   // CHECK: return
   %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1>
   %cst_0 = arith.constant dense<true> : tensor<64x64xi1, #blocked1>
@@ -184,3 +183,32 @@ func @vecadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr
   return
 }
+
+
+// -----
+
+// Check the UpdateMMAVersionMinorForVolta pattern.
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
+#shared0 = #triton_gpu.shared<{vec = 1, perPhase = 2, maxPhase = 8, order = [1, 0]}>
+#mma0 = #triton_gpu.mma<{versionMajor=1, versionMinor=0, warpsPerCTA=[1,1]}>
+// Here, the isMMAv1Row of the $a and $b dot_operands does not match #mma0's
+// versionMinor, and the pattern should update the versionMinor.
+#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0, isMMAv1Row=true}>
+#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0, isMMAv1Row=false}>
+// The pass creates a new MMA layout that fits the dot_operand layouts of $a and $b.
+// CHECK: [[new_mma:#mma.*]] = #triton_gpu.mma<{versionMajor = 1, versionMinor = 11, warpsPerCTA = [1, 1]}>
+module attributes {"triton_gpu.num-warps" = 1 : i32} {
+  // CHECK-LABEL: dot_mmav1
+  func @dot_mmav1(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) -> tensor<16x16xf32, #blocked0> {
+    %C = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #blocked0>
+    %AA = triton_gpu.convert_layout %A : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #dot_operand_a>
+    %BB = triton_gpu.convert_layout %B : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #dot_operand_b>
+    %CC = triton_gpu.convert_layout %C : (tensor<16x16xf32, #blocked0>) -> tensor<16x16xf32, #mma0>
+
+    // CHECK: {{.*}} = tt.dot {{.*}}, {{.*}}, %cst {allowTF32 = true} : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = [[new_mma]], isMMAv1Row = true}>> * tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = [[new_mma]], isMMAv1Row = true}>> -> tensor<16x16xf32, [[new_mma]]>
+    %D = tt.dot %AA, %BB, %CC {allowTF32 = true} : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0>
+    %res = triton_gpu.convert_layout %D : (tensor<16x16xf32, #mma0>) -> tensor<16x16xf32, #blocked0>
+
+    return %res : tensor<16x16xf32, #blocked0>
+  }
+}
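
A note on the DotOpHelpers.h change: for bf16 operands the MMA v1 path keeps the
same shared-memory loads it uses for f16, but types the loaded values as i16x2
vectors behind i16 pointers, so no f16 conversion ever touches the bf16 payload;
only raw 16-bit patterns are moved into the operand struct (hence the struct
element type also switches to elemX2Ty). A minimal standalone sketch of that
reinterpretation idea (plain C++ with illustrative helper names, not Triton's
IR builders):

```cpp
// Standalone illustration: carrying bf16 payloads through 16-bit integer
// storage preserves them bit-for-bit, which is all the load path needs.
#include <cstdint>
#include <cstdio>
#include <cstring>

// Truncate a float to its bf16 bit pattern (simple truncation, for brevity).
static uint16_t toBF16Bits(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);
}

// Expand a bf16 bit pattern back to a float.
static float fromBF16Bits(uint16_t h) {
  uint32_t bits = static_cast<uint32_t>(h) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

int main() {
  // Two bf16 elements packed into one 32-bit word, i.e. an i16x2-style vector.
  uint16_t lane[2] = {toBF16Bits(1.5f), toBF16Bits(-2.25f)};
  uint32_t packed;
  std::memcpy(&packed, lane, sizeof(packed));

  // Unpack again; the payloads are unchanged because no f16 conversion ran.
  uint16_t out[2];
  std::memcpy(out, &packed, sizeof(out));
  std::printf("%g %g\n", fromBF16Bits(out[0]), fromBF16Bits(out[1]));
  return 0;
}
```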
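
The Combine.cpp changes split the Volta versionMinor fix-up into two phases:
CollectMmaToUpdateForVolta only fills the mmaToUpdate map and always returns
failure, so nothing is mutated while collecting, and UpdateMMAVersionMinorForVolta
then rewrites tt.dot, triton_gpu.convert_layout, and arith.constant results whose
encoding appears in the map. A toy sketch of that collect-then-rewrite structure,
using stand-in types rather than MLIR classes:

```cpp
// Sketch of the two-phase structure: an analysis sweep that only records what
// must change, then a rewrite sweep that consults the recorded map.
#include <cstdio>
#include <string>
#include <unordered_map>
#include <vector>

struct Layout { int versionMinor; };            // stand-in for MmaEncodingAttr
struct Op { std::string name; Layout layout; };

int main() {
  std::vector<Op> module = {{"tt.dot", {0}}, {"triton_gpu.convert_layout", {0}}};

  // Phase 1: collect stale layouts keyed by their current value; mutate nothing.
  std::unordered_map<int, int> toUpdate; // old versionMinor -> corrected one
  for (const Op &op : module)
    if (op.name == "tt.dot" && op.layout.versionMinor == 0)
      toUpdate.emplace(0, 11); // corrected value is illustrative only

  // Phase 2: rewrite every op whose layout appears in the map.
  for (Op &op : module) {
    auto it = toUpdate.find(op.layout.versionMinor);
    if (it != toUpdate.end())
      op.layout.versionMinor = it->second;
  }

  for (const Op &op : module)
    std::printf("%s -> versionMinor %d\n", op.name.c_str(),
                op.layout.versionMinor);
  return 0;
}
```

Keeping the collection phase read-only means every op sees the same, complete map
during the rewrite phase; the rewrite phase also sets useTopDownTraversal,
presumably so producers (the converts and constants feeding tt.dot) are visited
before their users.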
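
The new combine.mlir case expects the corrected layout to print
versionMinor = 11 once both operands end up treated as row-major. The diff itself
does not show how MmaEncodingAttr derives that number, but
decodeVoltaLayoutStates() returning (isARow, isBRow, isAVec4, isBVec4) suggests
the minor version folds those flags into a few bits. The sketch below shows only
the general bit-packing technique; the bit positions are made up for illustration
and are not taken from Triton:

```cpp
// Illustrative packing of per-operand layout flags into one small integer.
#include <cstdio>

struct VoltaLayoutState {
  bool isARow, isBRow, isAVec4, isBVec4;
};

// Fold the four flags into one integer (bit assignments are invented here).
static int encodeMinor(VoltaLayoutState s) {
  return (s.isARow ? 1 : 0) | (s.isBRow ? 1 : 0) << 1 |
         (s.isAVec4 ? 1 : 0) << 2 | (s.isBVec4 ? 1 : 0) << 3;
}

// Recover the flags: the analogue of decodeVoltaLayoutStates().
static VoltaLayoutState decodeMinor(int minor) {
  return {(minor & 1) != 0, (minor & 2) != 0, (minor & 4) != 0, (minor & 8) != 0};
}

int main() {
  VoltaLayoutState s{true, true, false, true};
  int minor = encodeMinor(s); // 0b1011 == 11 with these invented bit positions
  VoltaLayoutState back = decodeMinor(minor);
  std::printf("minor=%d isARow=%d isBRow=%d isAVec4=%d isBVec4=%d\n", minor,
              back.isARow, back.isBRow, back.isAVec4, back.isBVec4);
  return 0;
}
```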