From 52accd4c2b6bf36d664db218e11d72151af1f2e6 Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Sun, 11 Dec 2022 19:01:57 -0800
Subject: [PATCH] [BACKEND] Add isRow attribute for DotOp tensors whose parent
 is mmav1 (#970)

Co-authored-by: Yan Chunwei
---
 .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td |  22 +++-
 .../TritonGPUToLLVM/TritonGPUToLLVM.cpp       |  30 ++++-
 lib/Dialect/TritonGPU/IR/Dialect.cpp          |  15 ++-
 lib/Dialect/TritonGPU/Transforms/Combine.cpp  | 120 +++++++++++++++++-
 python/tests/test_gemm.py                     |  16 +--
 python/triton/compiler.py                     |   6 +-
 test/Conversion/tritongpu_to_llvm.mlir        |   4 +-
 7 files changed, 186 insertions(+), 27 deletions(-)

diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
index d4ff8021d..c658667ce 100644
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -416,15 +416,35 @@ In TritonGPU dialect, considering `d = tt.dot a, b, c`
 tt.dot's operands a and b must be of DotOperandEncodingAttr layout.
 a's opIdx is 0, b's opIdx is 1.
 The parent field in DotOperandEncodingAttr is the layout of d.
+
+For MMA v1, an additional attribute `isMMAv1Row` determines whether, e.g., operand a is used
+in the context of an mma.884.row.col or an mma.884.col.col operation. See the PTX ISA
+documentation, section 9.7.13.4.1, for more details.
 }];

   let parameters = (
     ins
     "unsigned":$opIdx,
-    "Attribute":$parent
+    "Attribute":$parent,
+    "Attribute":$isMMAv1Row
   );

+  let builders = [
+    AttrBuilder<(ins "unsigned":$opIdx,
+                     "Attribute":$parent), [{
+      Attribute isMMAv1Row;
+      if(parent.isa<MmaEncodingAttr>() &&
+         parent.cast<MmaEncodingAttr>().getVersion() == 1){
+        isMMAv1Row = BoolAttr::get(context, true);
+      }
+      return $_get(context, opIdx, parent, isMMAv1Row);
+    }]>
+
+  ];
+
   let extraClassDeclaration = extraBaseClassDeclaration;
 }

+
+
 #endif
diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index 839d42170..9526e1a5a 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -3432,6 +3432,20 @@ Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
   } else if (!isOuter && mmaLayout.getVersion() == 1 &&
              isHMMA) { // tensor core v1
     DotOpMmaV1ConversionHelper helper(mmaLayout);
+    bool isMMAv1Row =
+        dotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
+    auto srcSharedLayout = src.getType()
+                               .cast<RankedTensorType>()
+                               .getEncoding()
+                               .cast<SharedEncodingAttr>();
+
+    // Can only convert [1, 0] to row or [0, 1] to col for now
+    if ((srcSharedLayout.getOrder()[0] == 1 && !isMMAv1Row) ||
+        (srcSharedLayout.getOrder()[0] == 0 && isMMAv1Row)) {
+      llvm::errs() << "Unsupported Shared -> DotOperand[MMAv1] conversion\n";
+      return Value();
+    }
+
     if (dotOperandLayout.getOpIdx() == 0) { // operand $a
       // TODO[Superjomn]: transA is not available here.
       bool transA = false;
@@ -3544,6 +3558,14 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
                        .cast<RankedTensorType>()
                        .getEncoding()
                        .cast<MmaEncodingAttr>();
+  auto ALayout = A.getType()
+                     .cast<RankedTensorType>()
+                     .getEncoding()
+                     .cast<DotOperandEncodingAttr>();
+  auto BLayout = B.getType()
+                     .cast<RankedTensorType>()
+                     .getEncoding()
+                     .cast<DotOperandEncodingAttr>();

   auto ATensorTy = A.getType().cast<RankedTensorType>();
   auto BTensorTy = B.getType().cast<RankedTensorType>();
@@ -3555,12 +3577,8 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
   auto DShape = DTensorTy.getShape();
   auto wpt = mmaLayout.getWarpsPerCTA();

-  // TODO[Superjomn]: order cannot accessed in DotOp.
-  SmallVector<unsigned> AOrder({1, 0});
-  SmallVector<unsigned> BOrder({1, 0});
-
-  bool isARow = AOrder[0] != 0;
-  bool isBRow = BOrder[0] != 0;
+  bool isARow = ALayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
+  bool isBRow = BLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();

   bool isAVec4 = !isARow && AShape[isARow] <= 16; // fp16*4 = 16bytes
   bool isBVec4 = isBRow && BShape[isBRow] <= 16;
   // TODO[Superjomn]: ld.v4 is not supported.
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index 2649be1f0..248aaced4 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -589,15 +589,24 @@ Attribute DotOperandEncodingAttr::parse(AsmParser &parser, Type type) {
     return {};
   unsigned opIdx = attrs.get("opIdx").cast<IntegerAttr>().getInt();
   Attribute parent = attrs.get("parent");
-
+  Attribute isMMAv1Row;
+  if(parent.isa<MmaEncodingAttr>() &&
+     parent.cast<MmaEncodingAttr>().getVersion() == 1){
+    isMMAv1Row = attrs.get("isMMAv1Row");
+    if(!isMMAv1Row)
+      llvm::report_fatal_error("isMMAv1Row attribute is missing");
+  }
   return parser.getChecked<DotOperandEncodingAttr>(parser.getContext(), opIdx,
-                                                   parent);
+                                                   parent, isMMAv1Row);
 }

 void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const {
   printer << "<{"
           << "opIdx = " << getOpIdx() << ", "
-          << "parent = " << getParent() << "}>";
+          << "parent = " << getParent();
+  if(getIsMMAv1Row())
+    printer << ", isMMAv1Row = " << getIsMMAv1Row();
+  printer << "}>";
 }

 //===----------------------------------------------------------------------===//
diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp
index 3417e36ac..53bab8a7e 100644
--- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp
@@ -715,6 +715,55 @@ public:
   }
 };

+class OptimizeConvertToDotOperand : public mlir::RewritePattern {
+public:
+  OptimizeConvertToDotOperand(mlir::MLIRContext *context)
+      : RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
+                       context) {}
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::Operation *op,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
+    auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
+    auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
+    // order
+    ArrayRef<unsigned> order;
+    if(auto srcBlockedLayout =
+        srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>())
+      order = srcBlockedLayout.getOrder();
+    else if(auto srcSharedLayout =
+        srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>())
+      order = srcSharedLayout.getOrder();
+    else
+      return failure();
+    // dot operand output
+    auto dstDotOperandLayout =
+        dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
+    if (!dstDotOperandLayout)
+      return failure();
+    unsigned opIdx = dstDotOperandLayout.getOpIdx();
+    if(!dstDotOperandLayout.getIsMMAv1Row())
+      return failure();
+    bool isMMAv1Row =
+        dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
+    if((order[0] == 1 && isMMAv1Row) ||
+       (order[0] == 0 && !isMMAv1Row))
+      return failure();
+    auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row);
+    auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
+        op->getContext(), dstDotOperandLayout.getOpIdx(),
+        dstDotOperandLayout.getParent(), newIsRow);
+    auto newDstType = RankedTensorType::get(
+        dstType.getShape(), dstType.getElementType(), newDstEncoding);
+    auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
+        op->getLoc(), newDstType, cvt.getOperand());
+    rewriter.replaceOp(op, newCvt.getResult());
+    return success();
+  }
+};
+
+
 class BlockedToMMA : public mlir::RewritePattern {
   int computeCapability;

@@ -772,14 +821,28 @@ public:
     Value b = dotOp.b();
     auto oldAType = a.getType().cast<RankedTensorType>();
     auto oldBType = b.getType().cast<RankedTensorType>();
+    auto oldAOrder = oldAType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>()
+        .getParent().cast<triton::gpu::BlockedEncodingAttr>().getOrder();
+    auto oldBOrder = oldBType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>()
+        .getParent().cast<triton::gpu::BlockedEncodingAttr>().getOrder();
+    Attribute isMMAv1RowA;
+    Attribute isMMAv1RowB;
+    if(version == 1){
+      isMMAv1RowA = BoolAttr::get(getContext(), oldAOrder[0] == 1);
+      isMMAv1RowB = BoolAttr::get(getContext(), oldBOrder[0] == 1);
+    }
+
     auto newAType = RankedTensorType::get(
         oldAType.getShape(), oldAType.getElementType(),
         triton::gpu::DotOperandEncodingAttr::get(oldAType.getContext(), 0,
-                                                 newRetType.getEncoding()));
+                                                 newRetType.getEncoding(),
+                                                 isMMAv1RowA));
     auto newBType = RankedTensorType::get(
         oldBType.getShape(), oldBType.getElementType(),
         triton::gpu::DotOperandEncodingAttr::get(oldBType.getContext(), 1,
-                                                 newRetType.getEncoding()));
+                                                 newRetType.getEncoding(),
+                                                 isMMAv1RowB));
+
     a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
     b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
     auto newDot = rewriter.create<triton::DotOp>(dotOp.getLoc(), newRetType, a,
@@ -791,6 +854,51 @@ public:
   }
 };

+class FixupLoop : public mlir::RewritePattern {
+
+public:
+  FixupLoop(mlir::MLIRContext *context)
+      : mlir::RewritePattern(scf::ForOp::getOperationName(), 2,
+                             context) {}
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::Operation *op,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto forOp = cast<scf::ForOp>(op);
+
+    // Rewrite init argument
+    SmallVector<Value, 4> newInitArgs = forOp.getInitArgs();
+    bool shouldRematerialize = false;
+    for(size_t i = 0; i < newInitArgs.size(); i++){
+      auto initArg = newInitArgs[i];
+      auto regionArg = forOp.getRegionIterArgs()[i];
+      if(newInitArgs[i].getType() != forOp.getRegionIterArgs()[i].getType()){
+        shouldRematerialize = true;
+        break;
+      }
+    }
+    if(!shouldRematerialize)
+      return failure();
+
+    scf::ForOp newForOp = rewriter.create<scf::ForOp>(
+        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
+        forOp.getStep(), newInitArgs);
+    newForOp->moveBefore(forOp);
+    rewriter.setInsertionPointToStart(newForOp.getBody());
+    BlockAndValueMapping mapping;
+    for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
+      mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
+
+    for (Operation &op : forOp.getBody()->getOperations()) {
+      Operation *newOp = rewriter.clone(op, mapping);
+    }
+    rewriter.replaceOp(forOp, newForOp.getResults());
+    return success();
+  }
+};
+
 } // namespace

 #define GEN_PASS_CLASSES
@@ -810,6 +918,7 @@ public:
     mlir::RewritePatternSet patterns(context);

     patterns.add<OptimizeBlockedToShared>(context);
+    patterns.add<OptimizeConvertToDotOperand>(context);
     patterns.add<SimplifyConversion>(context);
     patterns.add<DecomposeDotOperand>(context);
     patterns.add<RematerializeBackward>(context);
@@ -820,6 +929,13 @@ public:
     if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
       signalPassFailure();
    }
+
+    // llvm::outs() << m << "\n";
+    mlir::RewritePatternSet loopFixup(context);
+    loopFixup.add<FixupLoop>(context);
+    if (applyPatternsAndFoldGreedily(m, std::move(loopFixup)).failed()) {
+      signalPassFailure();
+    }
   }
 };

diff --git a/python/tests/test_gemm.py b/python/tests/test_gemm.py
index 5e6cd421c..7d502328e 100644
--- a/python/tests/test_gemm.py
+++ b/python/tests/test_gemm.py
@@ -32,7 +32,7 @@ def matmul_no_scf_kernel(
     (shape, num_warps, trans_a, trans_b)
     for shape in [
         [128, 256, 32],
-        [256, 128, 16],
+        # [256, 128, 16],
         [128, 16, 32],
         [32, 128, 64],
         [128, 128, 64],
@@ -43,8 +43,6 @@ def matmul_no_scf_kernel(
     for trans_b in [False, True]
 ])
 def test_gemm_no_scf(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
-    guard_for_volta(NUM_WARPS, TRANS_A, TRANS_B)
-
     SIZE_M, SIZE_N, SIZE_K = SHAPE
     if (TRANS_A):
         a = torch.randn((SIZE_K, SIZE_M), device='cuda', dtype=torch.float16).T
@@ -83,7 +81,7 @@ def test_gemm_no_scf(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
     for trans_b in [False, True]
 ])
 def test_gemm_no_scf_int8(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
-    guard_for_volta(NUM_WARPS, TRANS_A, TRANS_B, is_int8=True)
+    guard_for_volta(is_int8=True)

     SIZE_M, SIZE_N, SIZE_K = SHAPE

@@ -199,7 +197,6 @@ def get_proper_err(a, b, golden):
     [128, 64, 128, 4, 128, 64, 32, False, True],
 ])
 def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B):
-    guard_for_volta(NUM_WARPS, TRANS_A, TRANS_B)

     if (TRANS_A):
         a = torch.randn((SIZE_K, SIZE_M), device='cuda', dtype=torch.float16).T
@@ -276,7 +273,7 @@ def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
         c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
         tl.store(c_ptrs, accumulator, c_mask)

-    guard_for_volta(num_warps, trans_a=False, trans_b=False, is_tf32=allow_tf32)
+    guard_for_volta(is_tf32=allow_tf32)

     # Configure the pytorch counterpart
     torch.backends.cuda.matmul.allow_tf32 = allow_tf32
@@ -302,7 +299,7 @@ def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
     torch.testing.assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err))


-def guard_for_volta(num_warps, trans_a, trans_b, is_int8=False, is_tf32=False):
+def guard_for_volta(is_int8=False, is_tf32=False):
     '''
     Tell whether the test case is valid on Volta GPU.
     Some features are WIP, so the corresponding support is missing.
@@ -311,8 +308,7 @@ def guard_for_volta(num_warps, trans_a, trans_b, is_int8=False, is_tf32=False):
     is_on_Volta = capability[0] < 8
     # TODO[Superjomn]: Remove the constraints below when features are ready
     is_feature_supported = not (is_int8 or is_tf32)
-    is_feature_ready = not (trans_a or trans_b)

     if is_on_Volta:
-        if (not is_feature_supported) or (not is_feature_ready):
-            pytest.skip("Not valid on Volta")
+        if (not is_feature_supported):
+            pytest.skip("Not valid on Volta")
\ No newline at end of file
diff --git a/python/triton/compiler.py b/python/triton/compiler.py
index 2e760dabb..91ff732a7 100644
--- a/python/triton/compiler.py
+++ b/python/triton/compiler.py
@@ -1385,6 +1385,8 @@ arg_type_pattern = {
 # def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None):
 def compile(fn, **kwargs):
+    capability = torch.cuda.get_device_capability()
+    capability = capability[0] * 10 + capability[1]
     # we get the kernel, i.e. the first function generated in the module
     # if fn is not a JITFunction, then it
     # has to be a path to a file
@@ -1392,11 +1394,9 @@ def compile(fn, **kwargs):
     asm = dict()
     constants = kwargs.get("constants", dict())
     num_warps = kwargs.get("num_warps", 4)
-    num_stages = kwargs.get("num_stages", 3)
+    num_stages = kwargs.get("num_stages", 3 if capability >= 75 else 2)
     extern_libs = kwargs.get("extern_libs", dict())
     device = kwargs.get("device", torch.cuda.current_device())
-    capability = torch.cuda.get_device_capability()
-    capability = capability[0] * 10 + capability[1]
     # build compilation stages
     stages = {
         "ast": (lambda path: fn, None),
diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir
index e56b20e31..e15d4e6a7 100644
--- a/test/Conversion/tritongpu_to_llvm.mlir
+++ b/test/Conversion/tritongpu_to_llvm.mlir
@@ -879,8 +879,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
 #shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
 #mma = #triton_gpu.mma<{version = 1, warpsPerCTA = [2, 2]}>
-#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
-#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
+#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, isMMAv1Row=true}>
+#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, isMMAv1Row=true}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
   func @matmul884_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
   %a:tensor<128x32xf16, #shared>, %b:tensor<32x256xf16, #shared>) {
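
Illustrative example (not part of the patch): the sketch below shows how the two-argument
builder added in TritonGPUAttrDefs.td and the getIsMMAv1Row() accessor used in
TritonGPUToLLVM.cpp are expected to fit together for an MMA v1 parent layout. The function
name describeOperandA, the variable names, and the include path are assumptions made for
this example only, not code taken from the repository.

    // Sketch under the assumptions stated above.
    #include "triton/Dialect/TritonGPU/IR/Dialect.h" // assumed umbrella header

    using namespace mlir;
    using namespace mlir::triton::gpu;

    // Builds both dot operands of an MMA v1 dot and reports which mma.884
    // variant the A operand implies, mirroring what convertMMA884 now reads
    // from the layout instead of hard-coding {1, 0} operand orders.
    static const char *describeOperandA(MLIRContext *ctx, Attribute mmaV1Parent) {
      // Two-argument builder added by this patch: for an MMA v1 parent it
      // defaults isMMAv1Row to BoolAttr(true), i.e. a row-major operand.
      auto aLayout = DotOperandEncodingAttr::get(ctx, /*opIdx=*/0, mmaV1Parent);
      // The full builder still lets callers request the column variant.
      auto bLayout = DotOperandEncodingAttr::get(ctx, /*opIdx=*/1, mmaV1Parent,
                                                 BoolAttr::get(ctx, false));

      bool isARow = aLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
      bool isBRow = bLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
      (void)isBRow; // read the same way as isARow in convertMMA884
      return isARow ? "row operand (mma.884.row.col)"
                    : "col operand (mma.884.col.col)";
    }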