From 0d279125546725e335f23b9d691b09c6cc171d51 Mon Sep 17 00:00:00 2001
From: Phil Tillet
Date: Thu, 8 Dec 2022 18:31:16 -0800
Subject: [PATCH] [OPTIMIZER] Added isRow attribute for dotOperandEncoding of
 MMAv1 layout

---
 .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 23 +++++++-
 .../TritonGPUToLLVM/TritonGPUToLLVM.cpp       | 10 ++++
 lib/Dialect/TritonGPU/IR/Dialect.cpp          | 13 +++-
 lib/Dialect/TritonGPU/Transforms/Combine.cpp  | 59 ++++++++++++++++++-
 4 files changed, 99 insertions(+), 6 deletions(-)

diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
index d4ff8021d..f00c387ce 100644
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -416,15 +416,36 @@ In TritonGPU dialect, considering `d = tt.dot a, b, c`
 tt.dot's operands a and b must be of DotOperandEncodingAttr layout.
 a's opIdx is 0, b's opIdx is 1.
 The parent field in DotOperandEncodingAttr is the layout of d.
+
+For MMA v1, an additional attribute `isMMAv1Row` determines whether e.g. the a operand is used
+in the context of an mma.884.row.col or an mma.884.col.col operation. See the PTX ISA documentation
+section 9.7.13.4.1 for more details.
 }];

   let parameters = (
     ins
     "unsigned":$opIdx,
-    "Attribute":$parent
+    "Attribute":$parent,
+    "Attribute":$isMMAv1Row
   );

+  let builders = [
+    AttrBuilder<(ins "unsigned":$opIdx,
+                     "Attribute":$parent), [{
+      if(parent.isa<MmaEncodingAttr>() &&
+         parent.cast<MmaEncodingAttr>().getVersion() == 1){
+        llvm::errs() << "DotOperand for MMAv1 must have isMMAv1Row field\n";
+        return {};
+      }
+      Attribute none;
+      return $_get(context, opIdx, parent, none);
+    }]>
+
+  ];
+
   let extraClassDeclaration = extraBaseClassDeclaration;
 }

+
 #endif

diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index b3d9e172f..fc8e2f508 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -3428,6 +3428,16 @@ Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
   } else if (!isOuter && mmaLayout.getVersion() == 1 &&
              isHMMA) { // tensor core v1
     DotOpMmaV1ConversionHelper helper(mmaLayout);
+    bool isMMAv1Row = dotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
+    auto srcSharedLayout = src.getType().cast<RankedTensorType>().getEncoding().cast<SharedEncodingAttr>();
+
+    // Can only convert [1, 0] to row or [0, 1] to col for now
+    if((srcSharedLayout.getOrder()[0] == 1 && !isMMAv1Row) ||
+       (srcSharedLayout.getOrder()[0] == 0 && isMMAv1Row)){
+      llvm::errs() << "Unsupported Shared -> DotOperand[MMAv1] conversion\n";
+      return Value();
+    }
+
     if (dotOperandLayout.getOpIdx() == 0) { // operand $a
       // TODO[Superjomn]: transA is not available here.
      bool transA = false;

diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index 2649be1f0..707b036d4 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -589,15 +589,22 @@ Attribute DotOperandEncodingAttr::parse(AsmParser &parser, Type type) {
     return {};
   unsigned opIdx = attrs.get("opIdx").cast<IntegerAttr>().getInt();
   Attribute parent = attrs.get("parent");
-
+  Attribute isMMAv1Row;
+  if(parent.isa<MmaEncodingAttr>() &&
+     parent.cast<MmaEncodingAttr>().getVersion() == 1){
+    isMMAv1Row = attrs.get("isMMAv1Row");
+  }
   return parser.getChecked<DotOperandEncodingAttr>(parser.getContext(), opIdx,
-                                                   parent);
+                                                   parent, isMMAv1Row);
 }

 void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const {
   printer << "<{"
           << "opIdx = " << getOpIdx() << ", "
-          << "parent = " << getParent() << "}>";
+          << "parent = " << getParent();
+  if(getIsMMAv1Row())
+    printer << ", isMMAv1Row = " << getIsMMAv1Row();
+  printer << "}>";
 }

 //===----------------------------------------------------------------------===//

diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp
index e14bae003..23d9c1b80 100644
--- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp
@@ -713,6 +713,46 @@ public:
   }
 };

+class OptimizeBlockedToDotOperand : public mlir::RewritePattern {
+public:
+  OptimizeBlockedToDotOperand(mlir::MLIRContext *context)
+      : RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
+                       context) {}
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::Operation *op,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
+    auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
+    auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
+    auto srcBlockedLayout =
+        srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
+    auto dstDotOperandLayout =
+        dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
+    if (!srcBlockedLayout || !dstDotOperandLayout)
+      return failure();
+    unsigned opIdx = dstDotOperandLayout.getOpIdx();
+    if(!dstDotOperandLayout.getIsMMAv1Row())
+      return failure();
+    bool isMMAv1Row = dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
+    if((srcBlockedLayout.getOrder()[0] == 1 && isMMAv1Row) ||
+       (srcBlockedLayout.getOrder()[0] == 0 && !isMMAv1Row))
+      return failure();
+    auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row);
+    auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
+        op->getContext(), dstDotOperandLayout.getOpIdx(), dstDotOperandLayout.getParent(),
+        newIsRow);
+    auto newDstType = RankedTensorType::get(
+        dstType.getShape(),
+        dstType.getElementType(), newDstEncoding);
+    auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
+        op->getLoc(), newDstType, cvt.getOperand());
+    rewriter.replaceOp(op, newCvt.getResult());
+    return success();
+  }
+};
+
+
 class BlockedToMMA : public mlir::RewritePattern {
   int computeCapability;

@@ -770,14 +810,28 @@ public:
     Value b = dotOp.b();
     auto oldAType = a.getType().cast<RankedTensorType>();
     auto oldBType = b.getType().cast<RankedTensorType>();
+    auto oldAOrder = oldAType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>()
+                         .getParent().cast<triton::gpu::BlockedEncodingAttr>().getOrder();
+    auto oldBOrder = oldBType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>()
+                         .getParent().cast<triton::gpu::BlockedEncodingAttr>().getOrder();
+    Attribute isMMAv1RowA;
+    Attribute isMMAv1RowB;
+    if(version == 1){
+      isMMAv1RowA = BoolAttr::get(getContext(), oldAOrder[0] == 1);
+      isMMAv1RowB = BoolAttr::get(getContext(), oldBOrder[0] == 1);
+    }
+
     auto newAType = RankedTensorType::get(
         oldAType.getShape(), oldAType.getElementType(),
         triton::gpu::DotOperandEncodingAttr::get(oldAType.getContext(), 0,
-                                                 newRetType.getEncoding()));
+                                                 newRetType.getEncoding(),
+                                                 isMMAv1RowA));
     auto newBType = RankedTensorType::get(
         oldBType.getShape(), oldBType.getElementType(),
         triton::gpu::DotOperandEncodingAttr::get(oldBType.getContext(), 1,
-                                                 newRetType.getEncoding()));
+                                                 newRetType.getEncoding(),
+                                                 isMMAv1RowB));
+
     a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
     b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
     auto newDot = rewriter.create<triton::DotOp>(
@@ -808,6 +862,7 @@ public:
     mlir::RewritePatternSet patterns(context);

     patterns.add(context);
+    patterns.add<OptimizeBlockedToDotOperand>(context);
     patterns.add(context);
     patterns.add(context);
     patterns.add(context);
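
Note: after this change, DotOperandEncodingAttr::print/parse round-trip the extra
field whenever the parent is an MMA v1 layout. A rough sketch of the resulting
textual IR for the two operands of an mma.884.row.col-style dot; the attribute
aliases and the body of the #triton_gpu.mma attribute are illustrative placeholders
(its exact field names may differ in the tree), only the opIdx/parent/isMMAv1Row
keys come from the printer above:

  #mma   = #triton_gpu.mma<{version = 1, warpsPerCTA = [4, 1]}>
  #dot_a = #triton_gpu.dot_op<{opIdx = 0, parent = #mma, isMMAv1Row = true}>
  #dot_b = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, isMMAv1Row = false}>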
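
Note: the (opIdx, parent) convenience builder added to the TableGen definition
deliberately rejects MMA v1 parents, so any pass that builds a dot-operand encoding
over an MMA v1 layout now has to supply the row/col flag explicitly, as BlockedToMMA
does above. A minimal sketch of that call pattern, assuming the usual
`using namespace mlir;` of Combine.cpp; the helper name and its arguments are
illustrative and not part of the patch:

  // Derive the MMAv1 row/col flag from the order of the source blocked layout
  // (order[0] == 1 means the source is row-major) and attach it to the
  // dot-operand encoding.
  static Attribute getMMAv1DotOperandEncoding(MLIRContext *ctx, unsigned opIdx,
                                              Attribute mmaV1Parent,
                                              ArrayRef<unsigned> srcOrder) {
    Attribute isRow = BoolAttr::get(ctx, srcOrder[0] == 1);
    return triton::gpu::DotOperandEncodingAttr::get(ctx, opIdx, mmaV1Parent,
                                                    isRow);
  }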