[BACKEND] Add isRow attribute for DotOp tensors whose parent is mmav1 (#970)
Co-authored-by: Yan Chunwei <yanchunwei@outlook.com>
@@ -589,15 +589,24 @@ Attribute DotOperandEncodingAttr::parse(AsmParser &parser, Type type) {
     return {};
   unsigned opIdx = attrs.get("opIdx").cast<IntegerAttr>().getInt();
   Attribute parent = attrs.get("parent");
 
+  Attribute isMMAv1Row;
+  if(parent.isa<MmaEncodingAttr>() &&
+     parent.cast<MmaEncodingAttr>().getVersion() == 1){
+    isMMAv1Row = attrs.get("isMMAv1Row");
+    if(!isMMAv1Row)
+      llvm::report_fatal_error("isMMAv1Row attribute is missing");
+  }
   return parser.getChecked<DotOperandEncodingAttr>(parser.getContext(), opIdx,
-                                                   parent);
+                                                   parent, isMMAv1Row);
 }
 
 void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const {
   printer << "<{"
           << "opIdx = " << getOpIdx() << ", "
-          << "parent = " << getParent() << "}>";
+          << "parent = " << getParent();
+  if(getIsMMAv1Row())
+    printer << ", isMMAv1Row = " << getIsMMAv1Row();
+  printer << "}>";
 }
 
 //===----------------------------------------------------------------------===//
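Note: after this hunk, a dot-operand encoding with an MMAv1 parent carries the extra flag through parse/print round trips. A minimal sketch of building one in C++, using the four-argument get() shown in the diff (the header path and the #mma alias in the comment are assumptions):

    #include "triton/Dialect/TritonGPU/IR/Dialect.h" // assumed header path

    // Builds a dot-operand encoding for operand A (opIdx = 0) of an MMAv1
    // dot. With the printer above, it renders as something like:
    //   #triton_gpu.dot_op<{opIdx = 0, parent = #mma, isMMAv1Row = true}>
    mlir::Attribute makeMMAv1DotOperand(mlir::MLIRContext *ctx,
                                        mlir::Attribute mmaV1Parent) {
      auto isRow = mlir::BoolAttr::get(ctx, /*value=*/true); // row-major
      return triton::gpu::DotOperandEncodingAttr::get(ctx, /*opIdx=*/0,
                                                      mmaV1Parent, isRow);
    }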
@@ -715,6 +715,55 @@ public:
   }
 };
 
+class OptimizeConvertToDotOperand : public mlir::RewritePattern {
+public:
+  OptimizeConvertToDotOperand(mlir::MLIRContext *context)
+      : RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
+                       context) {}
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::Operation *op,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
+    auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
+    auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
+    // order
+    ArrayRef<unsigned> order;
+    if(auto srcBlockedLayout =
+           srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>())
+      order = srcBlockedLayout.getOrder();
+    else if(auto srcSharedLayout =
+                srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>())
+      order = srcSharedLayout.getOrder();
+    else
+      return failure();
+    // dot operand output
+    auto dstDotOperandLayout =
+        dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
+    if (!dstDotOperandLayout)
+      return failure();
+    unsigned opIdx = dstDotOperandLayout.getOpIdx();
+    if(!dstDotOperandLayout.getIsMMAv1Row())
+      return failure();
+    bool isMMAv1Row = dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
+    if((order[0] == 1 && isMMAv1Row) ||
+       (order[0] == 0 && !isMMAv1Row))
+      return failure();
+    auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row);
+    auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
+        op->getContext(), dstDotOperandLayout.getOpIdx(), dstDotOperandLayout.getParent(),
+        newIsRow);
+    auto newDstType = RankedTensorType::get(
+        dstType.getShape(),
+        dstType.getElementType(), newDstEncoding);
+    auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
+        op->getLoc(), newDstType, cvt.getOperand());
+    rewriter.replaceOp(op, newCvt.getResult());
+    return success();
+  }
+};
+
 
 class BlockedToMMA : public mlir::RewritePattern {
   int computeCapability;
 
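Note: this pattern only fires when the flag is present and the source layout's order contradicts it; order[0] names the fastest-varying axis, so order[0] == 1 means rows are stored contiguously (row-major). A standalone sketch of just that predicate (the helper name is hypothetical, the logic is taken from the hunk above):

    #include "llvm/ADT/ArrayRef.h"

    // True when the convert's destination should have its isMMAv1Row flag
    // flipped: the source is row-major iff axis 1 varies fastest, and the
    // pattern rewrites only when that disagrees with the current flag.
    static bool needsRowFlagFlip(llvm::ArrayRef<unsigned> order,
                                 bool isMMAv1Row) {
      bool srcIsRowMajor = (order[0] == 1);
      return srcIsRowMajor != isMMAv1Row;
    }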
@@ -772,14 +821,28 @@ public:
     Value b = dotOp.b();
     auto oldAType = a.getType().cast<RankedTensorType>();
     auto oldBType = b.getType().cast<RankedTensorType>();
+    auto oldAOrder = oldAType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>()
+        .getParent().cast<triton::gpu::BlockedEncodingAttr>().getOrder();
+    auto oldBOrder = oldBType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>()
+        .getParent().cast<triton::gpu::BlockedEncodingAttr>().getOrder();
+    Attribute isMMAv1RowA;
+    Attribute isMMAv1RowB;
+    if(version == 1){
+      isMMAv1RowA = BoolAttr::get(getContext(), oldAOrder[0] == 1);
+      isMMAv1RowB = BoolAttr::get(getContext(), oldBOrder[0] == 1);
+    }
+
     auto newAType = RankedTensorType::get(
         oldAType.getShape(), oldAType.getElementType(),
         triton::gpu::DotOperandEncodingAttr::get(oldAType.getContext(), 0,
-                                                 newRetType.getEncoding()));
+                                                 newRetType.getEncoding(),
+                                                 isMMAv1RowA));
     auto newBType = RankedTensorType::get(
         oldBType.getShape(), oldBType.getElementType(),
         triton::gpu::DotOperandEncodingAttr::get(oldBType.getContext(), 1,
-                                                 newRetType.getEncoding()));
+                                                 newRetType.getEncoding(),
+                                                 isMMAv1RowB));
 
     a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
     b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
     auto newDot = rewriter.create<triton::DotOp>(dotOp.getLoc(), newRetType, a,
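Note: BlockedToMMA sets the flag only when targeting MMA version 1 and otherwise leaves it as a null Attribute, which is what keeps the parameter optional in parse/print above. A sketch of that initialization factored into a helper (name hypothetical, logic from the hunk above):

    // Returns the isMMAv1Row flag for one dot operand: null (omitted)
    // unless the target MMA version is 1, in which case it records whether
    // the old blocked layout was row-major (order[0] == 1).
    static mlir::Attribute rowFlagFor(mlir::MLIRContext *ctx, int version,
                                      llvm::ArrayRef<unsigned> oldOrder) {
      if (version != 1)
        return {}; // null attribute: isMMAv1Row is simply not printed
      return mlir::BoolAttr::get(ctx, oldOrder[0] == 1);
    }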
@@ -791,6 +854,51 @@ public:
   }
 };
 
+class FixupLoop : public mlir::RewritePattern {
+
+public:
+  FixupLoop(mlir::MLIRContext *context)
+      : mlir::RewritePattern(scf::ForOp::getOperationName(), 2,
+                             context) {}
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::Operation *op,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto forOp = cast<scf::ForOp>(op);
+
+    // Rewrite init argument
+    SmallVector<Value, 4> newInitArgs = forOp.getInitArgs();
+    bool shouldRematerialize = false;
+    for(size_t i = 0; i < newInitArgs.size(); i++){
+      auto initArg = newInitArgs[i];
+      auto regionArg = forOp.getRegionIterArgs()[i];
+      if(newInitArgs[i].getType() != forOp.getRegionIterArgs()[i].getType()){
+        shouldRematerialize = true;
+        break;
+      }
+    }
+    if(!shouldRematerialize)
+      return failure();
+
+    scf::ForOp newForOp = rewriter.create<scf::ForOp>(
+        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
+        forOp.getStep(), newInitArgs);
+    newForOp->moveBefore(forOp);
+    rewriter.setInsertionPointToStart(newForOp.getBody());
+    BlockAndValueMapping mapping;
+    for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
+      mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
+
+    for (Operation &op : forOp.getBody()->getOperations()) {
+      Operation *newOp = rewriter.clone(op, mapping);
+    }
+    rewriter.replaceOp(forOp, newForOp.getResults());
+    return success();
+  }
+};
+
 } // namespace
 
 #define GEN_PASS_CLASSES
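Note: FixupLoop exists because the rewrites above can retype an scf.for init operand (new encoding) while the loop's region iter_arg keeps the old type; scf.for requires the two to match, so the loop is rebuilt with the body cloned through a BlockAndValueMapping. A sketch of the trigger condition, assuming the scf::ForOp accessors used in the diff (the header path is an assumption):

    #include "mlir/Dialect/SCF/IR/SCF.h" // assumed header path
    #include "llvm/ADT/STLExtras.h"

    // True when any init operand's type no longer matches the
    // corresponding region iter_arg, i.e. when the loop must be recreated
    // with fresh block-argument types.
    static bool loopNeedsRematerialization(mlir::scf::ForOp forOp) {
      for (auto [init, iterArg] :
           llvm::zip(forOp.getInitArgs(), forOp.getRegionIterArgs()))
        if (init.getType() != iterArg.getType())
          return true;
      return false;
    }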
@@ -810,6 +918,7 @@ public:
     mlir::RewritePatternSet patterns(context);
 
     patterns.add<OptimizeBlockedToShared>(context);
+    patterns.add<OptimizeConvertToDotOperand>(context);
     patterns.add<SimplifyConversion>(context);
     patterns.add<DecomposeDotOperand>(context);
     patterns.add<RematerializeBackward>(context);
@@ -820,6 +929,13 @@ public:
     if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
       signalPassFailure();
     }
+
+    // llvm::outs() << m << "\n";
+    mlir::RewritePatternSet loopFixup(context);
+    loopFixup.add<FixupLoop>(context);
+    if (applyPatternsAndFoldGreedily(m, std::move(loopFixup)).failed()) {
+      signalPassFailure();
+    }
   }
 };
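Note: the loop fixup runs in a second, separate pattern set so it only observes the types left behind once the main combine patterns have converged. A condensed sketch of this two-phase driver (function name hypothetical; pattern classes as defined above):

    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    // Two-phase greedy application: combine patterns first, then
    // FixupLoop, so the loop rewrite sees the final operand encodings.
    void runCombineThenFixup(mlir::ModuleOp m, mlir::MLIRContext *context) {
      mlir::RewritePatternSet patterns(context);
      patterns.add<OptimizeConvertToDotOperand>(context);
      // ... the remaining combine patterns ...
      if (mlir::applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
        return; // the real pass calls signalPassFailure() here

      mlir::RewritePatternSet loopFixup(context);
      loopFixup.add<FixupLoop>(context);
      if (mlir::applyPatternsAndFoldGreedily(m, std::move(loopFixup)).failed())
        return; // likewise
    }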