[OPTIMIZER] Added isRow attribute for dotOperandEncoding of MMAv1 layout

This commit is contained in:
Phil Tillet
2022-12-08 18:31:16 -08:00
parent 3ed36dcb4d
commit 0d27912554
4 changed files with 99 additions and 6 deletions

View File

@@ -589,15 +589,22 @@ Attribute DotOperandEncodingAttr::parse(AsmParser &parser, Type type) {
return {};
unsigned opIdx = attrs.get("opIdx").cast<IntegerAttr>().getInt();
Attribute parent = attrs.get("parent");
Attribute isMMAv1Row;
if(parent.isa<MmaEncodingAttr>() &&
parent.cast<MmaEncodingAttr>().getVersion() == 1){
isMMAv1Row = attrs.get("isMMAv1Row");
}
return parser.getChecked<DotOperandEncodingAttr>(parser.getContext(), opIdx,
parent);
parent, isMMAv1Row);
}
void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const {
printer << "<{"
<< "opIdx = " << getOpIdx() << ", "
<< "parent = " << getParent() << "}>";
<< "parent = " << getParent();
if(getIsMMAv1Row())
printer << ", isMMAv1Row = " << getIsMMAv1Row();
printer << "}>";
}
//===----------------------------------------------------------------------===//

View File

@@ -713,6 +713,46 @@ public:
}
};
class OptimizeBlockedToDotOperand : public mlir::RewritePattern {
public:
OptimizeBlockedToDotOperand(mlir::MLIRContext *context)
: RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
context) {}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
auto srcBlockedLayout =
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
auto dstDotOperandLayout =
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
if (!srcBlockedLayout || !dstDotOperandLayout)
return failure();
unsigned opIdx = dstDotOperandLayout.getOpIdx();
if(!dstDotOperandLayout.getIsMMAv1Row())
return failure();
bool isMMAv1Row = dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
if((srcBlockedLayout.getOrder()[0] == 1 && isMMAv1Row) ||
(srcBlockedLayout.getOrder()[0] == 0 && !isMMAv1Row))
return failure();
auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row);
auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
op->getContext(), dstDotOperandLayout.getOpIdx(), dstDotOperandLayout.getParent(),
newIsRow);
auto newDstType = RankedTensorType::get(
dstType.getShape(),
dstType.getElementType(), newDstEncoding);
auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), newDstType, cvt.getOperand());
rewriter.replaceOp(op, newCvt.getResult());
return success();
}
};
class BlockedToMMA : public mlir::RewritePattern {
int computeCapability;
@@ -770,14 +810,28 @@ public:
Value b = dotOp.b();
auto oldAType = a.getType().cast<RankedTensorType>();
auto oldBType = b.getType().cast<RankedTensorType>();
auto oldAOrder = oldAType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>()
.getParent().cast<triton::gpu::BlockedEncodingAttr>().getOrder();
auto oldBOrder = oldBType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>()
.getParent().cast<triton::gpu::BlockedEncodingAttr>().getOrder();
Attribute isMMAv1RowA;
Attribute isMMAv1RowB;
if(version == 1){
isMMAv1RowA = BoolAttr::get(getContext(), oldAOrder[0] == 1);
isMMAv1RowB = BoolAttr::get(getContext(), oldBOrder[0] == 1);
}
auto newAType = RankedTensorType::get(
oldAType.getShape(), oldAType.getElementType(),
triton::gpu::DotOperandEncodingAttr::get(oldAType.getContext(), 0,
newRetType.getEncoding()));
newRetType.getEncoding(),
isMMAv1RowA));
auto newBType = RankedTensorType::get(
oldBType.getShape(), oldBType.getElementType(),
triton::gpu::DotOperandEncodingAttr::get(oldBType.getContext(), 1,
newRetType.getEncoding()));
newRetType.getEncoding(),
isMMAv1RowB));
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
auto newDot = rewriter.create<triton::DotOp>(
@@ -808,6 +862,7 @@ public:
mlir::RewritePatternSet patterns(context);
patterns.add<OptimizeBlockedToShared>(context);
patterns.add<OptimizeBlockedToDotOperand>(context);
patterns.add<SimplifyConversion>(context);
patterns.add<DecomposeDotOperand>(context);
patterns.add<RematerializeBackward>(context);