[OPTIMIZER] Added isRow attribute for dotOperandEncoding of MMAv1 layout
@@ -416,15 +416,36 @@ In TritonGPU dialect, considering `d = tt.dot a, b, c`
 tt.dot's operands a and b must be of DotOperandEncodingAttr layout.
 a's opIdx is 0, b's opIdx is 1.
 The parent field in DotOperandEncodingAttr is the layout of d.
+
+For MMA v1, an additional attribute `isMMAv1Row` determines whether e.g. the a operand is used
+in the context of an mma.884.row.col or an mma.884.col.col operation. See the PTX ISA documentation
+section 9.7.13.4.1 for more details.
 }];

 let parameters = (
   ins
   "unsigned":$opIdx,
-  "Attribute":$parent
+  "Attribute":$parent,
+  "Attribute":$isMMAv1Row
 );
+
+let builders = [
+  AttrBuilder<(ins "unsigned":$opIdx,
+                   "Attribute":$parent), [{
+    if(parent.isa<MmaEncodingAttr>() &&
+       parent.cast<MmaEncodingAttr>().getVersion() == 1){
+      llvm::errs() << "DotOperand for MMAv1 must have isMMAv1Row field\n";
+      return {};
+    }
+    Attribute none;
+    return $_get(context, opIdx, parent, none);
+  }]>
+
+];
+
 let extraClassDeclaration = extraBaseClassDeclaration;
 }

+
 #endif
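The new `isMMAv1Row` parameter is a plain `Attribute` rather than a `BoolAttr`, so it can stay null whenever the parent is not an MMA v1 layout; the two-argument builder keeps existing call sites compiling while refusing MMA v1 parents, which must supply the flag through the full parameter list. A minimal sketch of the two construction paths, assuming an MLIRContext `ctx` and layout attributes `blockedLayout`/`mmaV1Layout` (those names are illustrative):

    // Non-MMAv1 parent: legacy two-argument builder, isMMAv1Row stays null.
    auto plain = triton::gpu::DotOperandEncodingAttr::get(ctx, /*opIdx=*/0, blockedLayout);
    // MMAv1 parent: the orientation must be passed explicitly as a BoolAttr.
    auto rowMajor = triton::gpu::DotOperandEncodingAttr::get(
        ctx, /*opIdx=*/0, mmaV1Layout, BoolAttr::get(ctx, true));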
@@ -3428,6 +3428,16 @@ Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
   } else if (!isOuter && mmaLayout.getVersion() == 1 &&
              isHMMA) { // tensor core v1
     DotOpMmaV1ConversionHelper helper(mmaLayout);
+    bool isMMAv1Row = dotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
+    auto srcSharedLayout = src.getType().cast<RankedTensorType>().getEncoding().cast<SharedEncodingAttr>();
+
+    // Can only convert [1, 0] to row or [0, 1] to col for now
+    if((srcSharedLayout.getOrder()[0] == 1 && !isMMAv1Row) ||
+       (srcSharedLayout.getOrder()[0] == 0 && isMMAv1Row)){
+      llvm::errs() << "Unsupported Shared -> DotOperand[MMAv1] conversion\n";
+      return Value();
+    }
+
     if (dotOperandLayout.getOpIdx() == 0) { // operand $a
       // TODO[Superjomn]: transA is not available here.
       bool transA = false;
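The guard above spells out the only pairings the v1 path handles so far: a shared tensor with order [1, 0] (fastest-varying index last, i.e. row-major) may only feed a row operand, and one with order [0, 1] may only feed a col operand. Restated as a standalone predicate, with `order` and `isRow` meaning the same as in the hunk (the helper name is illustrative):

    #include "llvm/ADT/ArrayRef.h"

    // True iff the Shared -> DotOperand[MMAv1] conversion above is supported.
    static bool isSupportedMMAv1Conversion(llvm::ArrayRef<unsigned> order, bool isRow) {
      return (order[0] == 1 && isRow) || (order[0] == 0 && !isRow);
    }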
@@ -589,15 +589,22 @@ Attribute DotOperandEncodingAttr::parse(AsmParser &parser, Type type) {
     return {};
   unsigned opIdx = attrs.get("opIdx").cast<IntegerAttr>().getInt();
   Attribute parent = attrs.get("parent");
+  Attribute isMMAv1Row;
+  if(parent.isa<MmaEncodingAttr>() &&
+     parent.cast<MmaEncodingAttr>().getVersion() == 1){
+    isMMAv1Row = attrs.get("isMMAv1Row");
+  }
   return parser.getChecked<DotOperandEncodingAttr>(parser.getContext(), opIdx,
-                                                   parent);
+                                                   parent, isMMAv1Row);
 }

 void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const {
   printer << "<{"
           << "opIdx = " << getOpIdx() << ", "
-          << "parent = " << getParent() << "}>";
+          << "parent = " << getParent();
+  if(getIsMMAv1Row())
+    printer << ", isMMAv1Row = " << getIsMMAv1Row();
+  printer << "}>";
 }

 //===----------------------------------------------------------------------===//
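Since the printer emits the field only when it is non-null, attributes whose parent is not an MMA v1 layout keep their old textual form, and the parser accepts both spellings. Illustrative round-trips (the `#mma` alias and exact rendering are shown for intuition, not taken verbatim from the diff):

    #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>                      // non-v1 parent: unchanged
    #triton_gpu.dot_op<{opIdx = 0, parent = #mma, isMMAv1Row = true}>   // v1 parent: new field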
@@ -713,6 +713,46 @@ public:
   }
 };

+class OptimizeBlockedToDotOperand : public mlir::RewritePattern {
+public:
+  OptimizeBlockedToDotOperand(mlir::MLIRContext *context)
+      : RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
+                       context) {}
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::Operation *op,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
+    auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
+    auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
+    auto srcBlockedLayout =
+        srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
+    auto dstDotOperandLayout =
+        dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
+    if (!srcBlockedLayout || !dstDotOperandLayout)
+      return failure();
+    unsigned opIdx = dstDotOperandLayout.getOpIdx();
+    if(!dstDotOperandLayout.getIsMMAv1Row())
+      return failure();
+    bool isMMAv1Row = dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
+    if((srcBlockedLayout.getOrder()[0] == 1 && isMMAv1Row) ||
+       (srcBlockedLayout.getOrder()[0] == 0 && !isMMAv1Row))
+      return failure();
+    auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row);
+    auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
+        op->getContext(), dstDotOperandLayout.getOpIdx(), dstDotOperandLayout.getParent(),
+        newIsRow);
+    auto newDstType = RankedTensorType::get(
+        dstType.getShape(),
+        dstType.getElementType(), newDstEncoding);
+    auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
+        op->getLoc(), newDstType, cvt.getOperand());
+    rewriter.replaceOp(op, newCvt.getResult());
+    return success();
+  }
+};
+
+
 class BlockedToMMA : public mlir::RewritePattern {
   int computeCapability;

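The pattern fires only when the destination's `isMMAv1Row` disagrees with the source's blocked order (it bails out on the matched combinations) and then flips the flag, so the conversion lands in the case `lowerSharedToDotOperandMMA` supports rather than the rejected transposing one. Schematically, under made-up layout aliases:

    // before: order [0, 1] feeding a row operand -- the unsupported pairing
    %d = convert_layout %x : blocked<order = [0, 1]> -> dot_op<{..., isMMAv1Row = true}>
    // after: the flag is flipped to agree with the producer's order
    %d = convert_layout %x : blocked<order = [0, 1]> -> dot_op<{..., isMMAv1Row = false}>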
@@ -770,14 +810,28 @@ public:
     Value b = dotOp.b();
     auto oldAType = a.getType().cast<RankedTensorType>();
     auto oldBType = b.getType().cast<RankedTensorType>();
+    auto oldAOrder = oldAType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>()
+        .getParent().cast<triton::gpu::BlockedEncodingAttr>().getOrder();
+    auto oldBOrder = oldBType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>()
+        .getParent().cast<triton::gpu::BlockedEncodingAttr>().getOrder();
+    Attribute isMMAv1RowA;
+    Attribute isMMAv1RowB;
+    if(version == 1){
+      isMMAv1RowA = BoolAttr::get(getContext(), oldAOrder[0] == 1);
+      isMMAv1RowB = BoolAttr::get(getContext(), oldBOrder[0] == 1);
+    }
+
     auto newAType = RankedTensorType::get(
         oldAType.getShape(), oldAType.getElementType(),
         triton::gpu::DotOperandEncodingAttr::get(oldAType.getContext(), 0,
-                                                 newRetType.getEncoding()));
+                                                 newRetType.getEncoding(),
+                                                 isMMAv1RowA));
     auto newBType = RankedTensorType::get(
         oldBType.getShape(), oldBType.getElementType(),
         triton::gpu::DotOperandEncodingAttr::get(oldBType.getContext(), 1,
-                                                 newRetType.getEncoding()));
+                                                 newRetType.getEncoding(),
+                                                 isMMAv1RowB));
+
     a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
     b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
     auto newDot = rewriter.create<triton::DotOp>(
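Because the operands reaching BlockedToMMA already carry a DotOperandEncodingAttr whose parent is the original blocked layout (the casts above rely on this), the seed value simply mirrors that layout's order, so the conversions created just below start in the supported configuration; OptimizeBlockedToDotOperand can still flip them later. The seed rule on its own, with `order` as above:

    // Seed: minor (fastest-varying) axis is dim 1, i.e. row-major => row operand.
    bool isRow = order[0] == 1;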
@@ -808,6 +862,7 @@ public:
     mlir::RewritePatternSet patterns(context);

     patterns.add<OptimizeBlockedToShared>(context);
+    patterns.add<OptimizeBlockedToDotOperand>(context);
     patterns.add<SimplifyConversion>(context);
     patterns.add<DecomposeDotOperand>(context);
     patterns.add<RematerializeBackward>(context);
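`OptimizeBlockedToDotOperand` is registered with benefit 1, as its constructor above shows, alongside the existing patterns. The hunk does not show the driver, but a pass like this one conventionally hands the set to MLIR's greedy rewriter, along these lines (assumed context, not part of this commit):

    // Typical MLIR driver code; assumed, not shown in the diff.
    ModuleOp m = getOperation();
    if (mlir::applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
      signalPassFailure();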