[OPTIMIZER] Added isRow attribute for dotOperandEncoding of MMAv1 layout

This commit is contained in:
Phil Tillet
2022-12-08 18:31:16 -08:00
parent 3ed36dcb4d
commit 0d27912554
4 changed files with 99 additions and 6 deletions

View File

@@ -416,15 +416,36 @@ In TritonGPU dialect, considering `d = tt.dot a, b, c`
tt.dot's operands a and b must be of DotOperandEncodingAttr layout.
a's opIdx is 0, b's opIdx is 1.
The parend field in DotOperandEncodingAttr is the layout of d.
For MMA v1, an additional attribute `isMMAv1Row` determines whether e.g. the a operand is used
in the context of an mma.884.row.col or an mma.884.col.col operation. See the PTX ISA documentation
section 9.7.13.4.1 for more details.
}];
let parameters = (
ins
"unsigned":$opIdx,
"Attribute":$parent
"Attribute":$parent,
"Attribute":$isMMAv1Row
);
let builders = [
AttrBuilder<(ins "unsigned":$opIdx,
"Attribute":$parent), [{
if(parent.isa<MmaEncodingAttr>() &&
parent.cast<MmaEncodingAttr>().getVersion() == 1){
llvm::errs() << "DotOperand for MMAv1 must have isMMAv1Row field\n";
return {};
}
Attribute none;
return $_get(context, opIdx, parent, none);
}]>
];
let extraClassDeclaration = extraBaseClassDeclaration;
}
#endif

View File

@@ -3428,6 +3428,16 @@ Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
} else if (!isOuter && mmaLayout.getVersion() == 1 &&
isHMMA) { // tensor core v1
DotOpMmaV1ConversionHelper helper(mmaLayout);
bool isMMAv1Row = dotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
auto srcSharedLayout = src.getType().cast<RankedTensorType>().getEncoding().cast<SharedEncodingAttr>();
// Can only convert [1, 0] to row or [0, 1] to col for now
if((srcSharedLayout.getOrder()[0] == 1 && !isMMAv1Row) ||
(srcSharedLayout.getOrder()[0] == 0 && isMMAv1Row)){
llvm::errs() << "Unsupported Shared -> DotOperand[MMAv1] conversion\n";
return Value();
}
if (dotOperandLayout.getOpIdx() == 0) { // operand $a
// TODO[Superjomn]: transA is not available here.
bool transA = false;

View File

@@ -589,15 +589,22 @@ Attribute DotOperandEncodingAttr::parse(AsmParser &parser, Type type) {
return {};
unsigned opIdx = attrs.get("opIdx").cast<IntegerAttr>().getInt();
Attribute parent = attrs.get("parent");
Attribute isMMAv1Row;
if(parent.isa<MmaEncodingAttr>() &&
parent.cast<MmaEncodingAttr>().getVersion() == 1){
isMMAv1Row = attrs.get("isMMAv1Row");
}
return parser.getChecked<DotOperandEncodingAttr>(parser.getContext(), opIdx,
parent);
parent, isMMAv1Row);
}
void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const {
printer << "<{"
<< "opIdx = " << getOpIdx() << ", "
<< "parent = " << getParent() << "}>";
<< "parent = " << getParent();
if(getIsMMAv1Row())
printer << ", isMMAv1Row = " << getIsMMAv1Row();
printer << "}>";
}
//===----------------------------------------------------------------------===//

View File

@@ -713,6 +713,46 @@ public:
}
};
class OptimizeBlockedToDotOperand : public mlir::RewritePattern {
public:
OptimizeBlockedToDotOperand(mlir::MLIRContext *context)
: RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
context) {}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
auto srcBlockedLayout =
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
auto dstDotOperandLayout =
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
if (!srcBlockedLayout || !dstDotOperandLayout)
return failure();
unsigned opIdx = dstDotOperandLayout.getOpIdx();
if(!dstDotOperandLayout.getIsMMAv1Row())
return failure();
bool isMMAv1Row = dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
if((srcBlockedLayout.getOrder()[0] == 1 && isMMAv1Row) ||
(srcBlockedLayout.getOrder()[0] == 0 && !isMMAv1Row))
return failure();
auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row);
auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
op->getContext(), dstDotOperandLayout.getOpIdx(), dstDotOperandLayout.getParent(),
newIsRow);
auto newDstType = RankedTensorType::get(
dstType.getShape(),
dstType.getElementType(), newDstEncoding);
auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), newDstType, cvt.getOperand());
rewriter.replaceOp(op, newCvt.getResult());
return success();
}
};
class BlockedToMMA : public mlir::RewritePattern {
int computeCapability;
@@ -770,14 +810,28 @@ public:
Value b = dotOp.b();
auto oldAType = a.getType().cast<RankedTensorType>();
auto oldBType = b.getType().cast<RankedTensorType>();
auto oldAOrder = oldAType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>()
.getParent().cast<triton::gpu::BlockedEncodingAttr>().getOrder();
auto oldBOrder = oldBType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>()
.getParent().cast<triton::gpu::BlockedEncodingAttr>().getOrder();
Attribute isMMAv1RowA;
Attribute isMMAv1RowB;
if(version == 1){
isMMAv1RowA = BoolAttr::get(getContext(), oldAOrder[0] == 1);
isMMAv1RowB = BoolAttr::get(getContext(), oldBOrder[0] == 1);
}
auto newAType = RankedTensorType::get(
oldAType.getShape(), oldAType.getElementType(),
triton::gpu::DotOperandEncodingAttr::get(oldAType.getContext(), 0,
newRetType.getEncoding()));
newRetType.getEncoding(),
isMMAv1RowA));
auto newBType = RankedTensorType::get(
oldBType.getShape(), oldBType.getElementType(),
triton::gpu::DotOperandEncodingAttr::get(oldBType.getContext(), 1,
newRetType.getEncoding()));
newRetType.getEncoding(),
isMMAv1RowB));
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
auto newDot = rewriter.create<triton::DotOp>(
@@ -808,6 +862,7 @@ public:
mlir::RewritePatternSet patterns(context);
patterns.add<OptimizeBlockedToShared>(context);
patterns.add<OptimizeBlockedToDotOperand>(context);
patterns.add<SimplifyConversion>(context);
patterns.add<DecomposeDotOperand>(context);
patterns.add<RematerializeBackward>(context);