[OPTIMIZER] Added isRow attribute for dotOperandEncoding of MMAv1 layout

Phil Tillet
2022-12-08 18:31:16 -08:00
parent 3ed36dcb4d
commit 0d27912554
4 changed files with 99 additions and 6 deletions

View File

@@ -416,15 +416,36 @@ In TritonGPU dialect, considering `d = tt.dot a, b, c`
 tt.dot's operands a and b must be of DotOperandEncodingAttr layout.
 a's opIdx is 0, b's opIdx is 1.
 The parent field in DotOperandEncodingAttr is the layout of d.
+
+For MMA v1, an additional attribute `isMMAv1Row` determines whether e.g. the
+a operand is used in the context of an mma.884.row.col or an mma.884.col.col
+operation. See the PTX ISA documentation section 9.7.13.4.1 for more details.
   }];
 
   let parameters = (
     ins
     "unsigned":$opIdx,
-    "Attribute":$parent
+    "Attribute":$parent,
+    "Attribute":$isMMAv1Row
   );
 
+  let builders = [
+    AttrBuilder<(ins "unsigned":$opIdx,
+                     "Attribute":$parent), [{
+      if (parent.isa<MmaEncodingAttr>() &&
+          parent.cast<MmaEncodingAttr>().getVersion() == 1) {
+        llvm::errs() << "DotOperand for MMAv1 must have isMMAv1Row field\n";
+        return {};
+      }
+      Attribute none;
+      return $_get(context, opIdx, parent, none);
+    }]>
+  ];
+
   let extraClassDeclaration = extraBaseClassDeclaration;
 }
 
 #endif
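
As a rough usage sketch of the new parameter (the surrounding setup and the
`mmaV1Layout`/`mmaV2Layout` names are assumed here, not part of this commit):
MMA v2 callers can keep using the two-argument builder above, which leaves the
flag unset, while MMA v1 callers now pass it explicitly.

    // Sketch only; assumes a loaded TritonGPU dialect and existing MMA layouts.
    mlir::MLIRContext *ctx = mmaV1Layout.getContext();
    mlir::Attribute isRow = mlir::BoolAttr::get(ctx, /*value=*/true); // row.col flavour
    auto aEncoding = triton::gpu::DotOperandEncodingAttr::get(
        ctx, /*opIdx=*/0, mmaV1Layout, isRow);
    // For an MMA v2 parent the extra field can simply stay null:
    auto v2Encoding = triton::gpu::DotOperandEncodingAttr::get(
        ctx, /*opIdx=*/0, mmaV2Layout, /*isMMAv1Row=*/mlir::Attribute());

With the print() change later in this commit, the first encoding carries an
explicit "isMMAv1Row = true" in its textual form, while v2 operands print
exactly as before.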

View File

@@ -3428,6 +3428,16 @@ Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
   } else if (!isOuter && mmaLayout.getVersion() == 1 &&
              isHMMA) { // tensor core v1
     DotOpMmaV1ConversionHelper helper(mmaLayout);
+    bool isMMAv1Row =
+        dotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
+    auto srcSharedLayout = src.getType().cast<RankedTensorType>()
+                               .getEncoding().cast<SharedEncodingAttr>();
+    // Can only convert [1, 0] to row or [0, 1] to col for now
+    if ((srcSharedLayout.getOrder()[0] == 1 && !isMMAv1Row) ||
+        (srcSharedLayout.getOrder()[0] == 0 && isMMAv1Row)) {
+      llvm::errs() << "Unsupported Shared -> DotOperand[MMAv1] conversion\n";
+      return Value();
+    }
     if (dotOperandLayout.getOpIdx() == 0) { // operand $a
       // TODO[Superjomn]: transA is not available here.
       bool transA = false;
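
The guard above only admits the pairings listed in the comment. Restated as a
standalone predicate (the helper name is hypothetical, purely illustrative):

    // Mirrors the check above: a shared layout whose fastest-varying dimension
    // is the last axis (order[0] == 1) can only feed a "row" MMAv1 operand, and
    // order[0] == 0 only a "col" operand.
    static bool isSupportedSharedToDotOperandMMAv1(llvm::ArrayRef<unsigned> order,
                                                   bool isMMAv1Row) {
      return (order[0] == 1 && isMMAv1Row) || (order[0] == 0 && !isMMAv1Row);
    }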

View File

@@ -589,15 +589,22 @@ Attribute DotOperandEncodingAttr::parse(AsmParser &parser, Type type) {
     return {};
   unsigned opIdx = attrs.get("opIdx").cast<IntegerAttr>().getInt();
   Attribute parent = attrs.get("parent");
+  Attribute isMMAv1Row;
+  if (parent.isa<MmaEncodingAttr>() &&
+      parent.cast<MmaEncodingAttr>().getVersion() == 1) {
+    isMMAv1Row = attrs.get("isMMAv1Row");
+  }
   return parser.getChecked<DotOperandEncodingAttr>(parser.getContext(), opIdx,
-                                                   parent);
+                                                   parent, isMMAv1Row);
 }
 
 void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const {
   printer << "<{"
           << "opIdx = " << getOpIdx() << ", "
-          << "parent = " << getParent() << "}>";
+          << "parent = " << getParent();
+  if (getIsMMAv1Row())
+    printer << ", isMMAv1Row = " << getIsMMAv1Row();
+  printer << "}>";
 }
 
 //===----------------------------------------------------------------------===//
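
Because parse() only populates isMMAv1Row for MMA v1 parents, the field can
legitimately be a null Attribute on other operands; a small defensive accessor
(hypothetical, not part of this commit) would read it like so:

    // Returns the row/col flag if present, llvm::None for non-MMAv1 operands.
    static llvm::Optional<bool>
    getMMAv1Row(triton::gpu::DotOperandEncodingAttr encoding) {
      if (auto row = encoding.getIsMMAv1Row().dyn_cast_or_null<mlir::BoolAttr>())
        return row.getValue();
      return llvm::None;
    }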

View File

@@ -713,6 +713,46 @@ public:
   }
 };
 
+class OptimizeBlockedToDotOperand : public mlir::RewritePattern {
+public:
+  OptimizeBlockedToDotOperand(mlir::MLIRContext *context)
+      : RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
+                       context) {}
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::Operation *op,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
+    auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
+    auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
+    auto srcBlockedLayout =
+        srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
+    auto dstDotOperandLayout =
+        dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
+    if (!srcBlockedLayout || !dstDotOperandLayout)
+      return failure();
+    unsigned opIdx = dstDotOperandLayout.getOpIdx();
+    if (!dstDotOperandLayout.getIsMMAv1Row())
+      return failure();
+    bool isMMAv1Row =
+        dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
+    if ((srcBlockedLayout.getOrder()[0] == 1 && isMMAv1Row) ||
+        (srcBlockedLayout.getOrder()[0] == 0 && !isMMAv1Row))
+      return failure();
+    auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row);
+    auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
+        op->getContext(), dstDotOperandLayout.getOpIdx(),
+        dstDotOperandLayout.getParent(), newIsRow);
+    auto newDstType = RankedTensorType::get(
+        dstType.getShape(), dstType.getElementType(), newDstEncoding);
+    auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
+        op->getLoc(), newDstType, cvt.getOperand());
+    rewriter.replaceOp(op, newCvt.getResult());
+    return success();
+  }
+};
+
 class BlockedToMMA : public mlir::RewritePattern {
   int computeCapability;
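
In plain terms, the pattern above leaves a blocked -> dot-operand convert alone
when the destination's isMMAv1Row flag already agrees with the source's
fastest-varying dimension, and otherwise retargets the convert at the flipped
flag so the operand flavour matches how the data is already ordered. A compact
sketch of that decision, with a hypothetical helper name:

    // Returns llvm::None when no rewrite is needed, otherwise the flipped flag
    // that the new ConvertLayoutOp destination encoding should carry.
    static llvm::Optional<bool>
    flippedMMAv1Row(llvm::ArrayRef<unsigned> blockedOrder, bool isMMAv1Row) {
      bool srcLooksRowMajor = blockedOrder[0] == 1;
      if (srcLooksRowMajor == isMMAv1Row)
        return llvm::None; // already consistent; the pattern bails out
      return !isMMAv1Row;  // the pattern rewrites the convert to this value
    }
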
@@ -770,14 +810,28 @@ public:
     Value b = dotOp.b();
     auto oldAType = a.getType().cast<RankedTensorType>();
     auto oldBType = b.getType().cast<RankedTensorType>();
+    auto oldAOrder = oldAType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>()
+                         .getParent().cast<triton::gpu::BlockedEncodingAttr>().getOrder();
+    auto oldBOrder = oldBType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>()
+                         .getParent().cast<triton::gpu::BlockedEncodingAttr>().getOrder();
+    Attribute isMMAv1RowA;
+    Attribute isMMAv1RowB;
+    if (version == 1) {
+      isMMAv1RowA = BoolAttr::get(getContext(), oldAOrder[0] == 1);
+      isMMAv1RowB = BoolAttr::get(getContext(), oldBOrder[0] == 1);
+    }
+
     auto newAType = RankedTensorType::get(
         oldAType.getShape(), oldAType.getElementType(),
         triton::gpu::DotOperandEncodingAttr::get(oldAType.getContext(), 0,
-                                                 newRetType.getEncoding()));
+                                                 newRetType.getEncoding(),
+                                                 isMMAv1RowA));
     auto newBType = RankedTensorType::get(
         oldBType.getShape(), oldBType.getElementType(),
         triton::gpu::DotOperandEncodingAttr::get(oldBType.getContext(), 1,
-                                                 newRetType.getEncoding()));
+                                                 newRetType.getEncoding(),
+                                                 isMMAv1RowB));
     a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
     b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
     auto newDot = rewriter.create<triton::DotOp>(
@@ -808,6 +862,7 @@ public:
     mlir::RewritePatternSet patterns(context);
     patterns.add<OptimizeBlockedToShared>(context);
+    patterns.add<OptimizeBlockedToDotOperand>(context);
     patterns.add<SimplifyConversion>(context);
     patterns.add<DecomposeDotOperand>(context);
     patterns.add<RematerializeBackward>(context);