Merge remote-tracking branch 'origin/master' into phil/fused-attention-perf-fixup
@@ -23,6 +23,10 @@
 using namespace mlir;
 namespace {
 #include "TritonGPUCombine.inc"
+using triton::DotOp;
+using triton::gpu::ConvertLayoutOp;
+using triton::gpu::DotOperandEncodingAttr;
+using triton::gpu::MmaEncodingAttr;

 // -----------------------------------------------------------------------------
 //
@@ -1020,6 +1024,7 @@ public:
         dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
     if ((order[0] == 1 && isMMAv1Row) || (order[0] == 0 && !isMMAv1Row))
       return failure();

+    auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row);
     auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
         op->getContext(), dstDotOperandLayout.getOpIdx(),
@@ -1061,7 +1066,8 @@ public:
     auto dotOp = cast<triton::DotOp>(op);
     // TODO: Check data-types and SM compatibility
     auto oldRetType = dotOp.getResult().getType().cast<RankedTensorType>();
-    if (oldRetType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
+    if (!oldRetType.getEncoding() ||
+        oldRetType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
       return failure();

     auto AType = dotOp.getOperand(0).getType().cast<RankedTensorType>();
@@ -1171,7 +1177,8 @@ public:
     for (size_t i = 0; i < newInitArgs.size(); i++) {
       auto initArg = newInitArgs[i];
       auto regionArg = forOp.getRegionIterArgs()[i];
-      if (newInitArgs[i].getType() != forOp.getRegionIterArgs()[i].getType()) {
+      if (newInitArgs[i].getType() != forOp.getRegionIterArgs()[i].getType() ||
+          newInitArgs[i].getType() != forOp.getResultTypes()[i]) {
         shouldRematerialize = true;
         break;
       }
@@ -1187,15 +1194,207 @@ public:
     BlockAndValueMapping mapping;
     for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
       mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
     mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());

     for (Operation &op : forOp.getBody()->getOperations()) {
-      Operation *newOp = rewriter.clone(op, mapping);
+      rewriter.clone(op, mapping);
     }
     rewriter.replaceOp(forOp, newForOp.getResults());
     return success();
   }
 };
+
+// This pattern collects the MMA layouts that need updating and creates the
+// corrected layout for each.
+class CollectMmaToUpdateForVolta : public mlir::RewritePattern {
+  DenseMap<MmaEncodingAttr, MmaEncodingAttr> &mmaToUpdate;
+
+public:
+  CollectMmaToUpdateForVolta(
+      mlir::MLIRContext *ctx,
+      DenseMap<MmaEncodingAttr, MmaEncodingAttr> &mmaToUpdate)
+      : mlir::RewritePattern(triton::DotOp::getOperationName(), 1, ctx),
+        mmaToUpdate(mmaToUpdate) {}
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::Operation *op,
+                  mlir::PatternRewriter &rewriter) const override {
+
+    auto dotOp = cast<triton::DotOp>(op);
+    auto *ctx = dotOp->getContext();
+    auto AT = dotOp.a().getType().cast<RankedTensorType>();
+    auto BT = dotOp.b().getType().cast<RankedTensorType>();
+    auto DT = dotOp.d().getType().cast<RankedTensorType>();
+    if (!DT.getEncoding())
+      return failure();
+    auto mmaLayout = DT.getEncoding().dyn_cast<MmaEncodingAttr>();
+    if (!(mmaLayout && mmaLayout.isVolta()))
+      return failure();
+
+    // Already processed.
+    if (mmaToUpdate.count(mmaLayout))
+      return failure();
+
+    auto dotOperandA = AT.getEncoding().cast<DotOperandEncodingAttr>();
+    auto dotOperandB = BT.getEncoding().cast<DotOperandEncodingAttr>();
+    bool isARow = dotOperandA.getIsMMAv1Row().cast<BoolAttr>().getValue();
+    bool isBRow = dotOperandB.getIsMMAv1Row().cast<BoolAttr>().getValue();
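+    // Compare the row-ness recorded on the operand encodings against what the
+    // MMA layout's versionMinor actually encodes; a mismatch means the layout
+    // must be rebuilt.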
+    auto [isARow_, isBRow_, isAVec4, isBVec4] =
+        mmaLayout.decodeVoltaLayoutStates();
+    if (isARow_ == isARow && isBRow_ == isBRow) {
+      return failure(); // No need to update
+    }
+
+    auto newMmaLayout = MmaEncodingAttr::get(
+        ctx, mmaLayout.getVersionMajor(), mmaLayout.getWarpsPerCTA(),
+        AT.getShape(), BT.getShape(), isARow, isBRow);
+
+    // Record the wrong MMA layout and the corrected layout it maps to.
+    mmaToUpdate.try_emplace(mmaLayout, newMmaLayout);
+
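+    // Deliberately report "no match": this pattern only records the mapping
+    // as a side effect and never mutates the IR.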
+    return failure();
+  }
+};
+
+// Correct the versionMinor field in MmaEncodingAttr for Volta.
+class UpdateMMAVersionMinorForVolta : public mlir::RewritePattern {
+  const DenseMap<MmaEncodingAttr, MmaEncodingAttr> &mmaToUpdate;
+  enum class Kind {
+    kUnk,
+    kCvtToMma,
+    kCvtToDotOp,
+    kDot,
+    kConstant,
+  };
+  mutable Kind rewriteKind{Kind::kUnk};
+
+public:
+  UpdateMMAVersionMinorForVolta(
+      mlir::MLIRContext *ctx, llvm::StringRef opName,
+      const DenseMap<MmaEncodingAttr, MmaEncodingAttr> &mmaToUpdate)
+      : RewritePattern(opName, 1 /*benefit*/, ctx), mmaToUpdate(mmaToUpdate) {}
+
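+  // match() classifies the op and caches the verdict in the mutable
+  // rewriteKind member; rewrite() then dispatches on that cached kind.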
+  LogicalResult match(Operation *op) const override {
+    MmaEncodingAttr mma;
+    if (mmaToUpdate.empty())
+      return failure();
+    if (op->getNumResults() != 1)
+      return failure();
+    auto tensorTy = op->getResult(0).getType().dyn_cast<RankedTensorType>();
+    if (!tensorTy)
+      return failure();
+
+    // ConvertLayoutOp
+    if (auto cvt = llvm::dyn_cast<ConvertLayoutOp>(op)) {
+      // cvt X -> dot_operand
+      if (auto dotOperand =
+              tensorTy.getEncoding().dyn_cast<DotOperandEncodingAttr>()) {
+        mma = dotOperand.getParent().dyn_cast<MmaEncodingAttr>();
+        rewriteKind = Kind::kCvtToDotOp;
+        if (mma && mmaToUpdate.count(mma))
+          return success();
+      }
+      if ((mma = tensorTy.getEncoding().dyn_cast<MmaEncodingAttr>())) {
+        // cvt X -> mma
+        rewriteKind = Kind::kCvtToMma;
+        if (mma && mmaToUpdate.count(mma))
+          return success();
+      }
+    } else if (auto dot = llvm::dyn_cast<DotOp>(op)) {
+      // DotOp
+      mma = dot.d()
+                .getType()
+                .cast<RankedTensorType>()
+                .getEncoding()
+                .dyn_cast<MmaEncodingAttr>();
+      rewriteKind = Kind::kDot;
+    } else if (auto constant = llvm::dyn_cast<arith::ConstantOp>(op)) {
+      // ConstantOp
+      mma = tensorTy.getEncoding().dyn_cast<MmaEncodingAttr>();
+      rewriteKind = Kind::kConstant;
+    }
+
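+    // Fallthrough check for the DotOp/ConstantOp branches (and any cvt that
+    // did not match above): succeed only if the result's MMA layout is one
+    // scheduled for update.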
+    return success(mma && mmaToUpdate.count(mma));
+  }
+
+  void rewrite(Operation *op, PatternRewriter &rewriter) const override {
+    switch (rewriteKind) {
+    case Kind::kDot:
+      rewriteDot(op, rewriter);
+      break;
+    case Kind::kConstant:
+      rewriteConstant(op, rewriter);
+      break;
+    case Kind::kCvtToDotOp:
+      rewriteCvtDotOp(op, rewriter);
+      break;
+    case Kind::kCvtToMma:
+      rewriteCvtToMma(op, rewriter);
+      break;
+    default:
+      llvm::report_fatal_error("Unsupported rewrite kind");
+    }
+  }
+
+private:
+  void rewriteCvtDotOp(Operation *op, PatternRewriter &rewriter) const {
+    auto *ctx = op->getContext();
+    auto cvt = llvm::cast<ConvertLayoutOp>(op);
+    auto tensorTy = cvt.result().getType().cast<RankedTensorType>();
+    auto dotOperand = tensorTy.getEncoding().cast<DotOperandEncodingAttr>();
+    MmaEncodingAttr newMma =
+        mmaToUpdate.lookup(dotOperand.getParent().cast<MmaEncodingAttr>());
+    auto newDotOperand = DotOperandEncodingAttr::get(
+        ctx, dotOperand.getOpIdx(), newMma, dotOperand.getIsMMAv1Row());
+    auto newTensorTy = RankedTensorType::get(
+        tensorTy.getShape(), tensorTy.getElementType(), newDotOperand);
+    rewriter.replaceOpWithNewOp<ConvertLayoutOp>(op, newTensorTy,
+                                                 cvt.getOperand());
+  }
+
+  void rewriteDot(Operation *op, PatternRewriter &rewriter) const {
+    auto *ctx = op->getContext();
+    auto dot = llvm::cast<DotOp>(op);
+    auto tensorTy = dot.d().getType().cast<RankedTensorType>();
+    auto mma = tensorTy.getEncoding().cast<MmaEncodingAttr>();
+    auto newMma = mmaToUpdate.lookup(mma);
+    auto newTensorTy = RankedTensorType::get(tensorTy.getShape(),
+                                             tensorTy.getElementType(), newMma);
+    rewriter.replaceOpWithNewOp<DotOp>(op, newTensorTy, dot.a(), dot.b(),
+                                       dot.c(), dot.allowTF32());
+  }
+
+  void rewriteCvtToMma(Operation *op, PatternRewriter &rewriter) const {
+    auto *ctx = op->getContext();
+    auto cvt = llvm::cast<ConvertLayoutOp>(op);
+    auto tensorTy = cvt.result().getType().cast<RankedTensorType>();
+    auto mma = tensorTy.getEncoding().cast<MmaEncodingAttr>();
+    auto newMma = mmaToUpdate.lookup(mma);
+    auto newTensorTy = RankedTensorType::get(tensorTy.getShape(),
+                                             tensorTy.getElementType(), newMma);
+    rewriter.replaceOpWithNewOp<ConvertLayoutOp>(op, newTensorTy,
+                                                 cvt.getOperand());
+  }
+
+  void rewriteConstant(Operation *op, PatternRewriter &rewriter) const {
+    auto *ctx = op->getContext();
+    auto constant = llvm::cast<arith::ConstantOp>(op);
+    auto tensorTy = constant.getResult().getType().dyn_cast<RankedTensorType>();
+    auto mma = tensorTy.getEncoding().cast<MmaEncodingAttr>();
+    auto newMma = mmaToUpdate.lookup(mma);
+    auto newTensorTy = RankedTensorType::get(tensorTy.getShape(),
+                                             tensorTy.getElementType(), newMma);
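+    // Only splat constants are handled: rebuild the splat attribute with the
+    // new tensor type so the constant's encoding matches its users.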
+    if (auto attr = constant.getValue().dyn_cast<SplatElementsAttr>()) {
+      auto newRet =
+          SplatElementsAttr::get(newTensorTy, attr.getSplatValue<Attribute>());
+      rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newTensorTy, newRet);
+      return;
+    }
+
+    assert(false && "Unsupported ConstantOp value type");
+  }
+};
+
 } // namespace

 #define GEN_PASS_CLASSES
@@ -1230,6 +1429,28 @@ public:
         signalPassFailure();
     }

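+    // Two-phase fixup: first collect the Volta MMA layouts whose versionMinor
+    // is wrong, then rewrite every op whose type refers to one of them.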
+    llvm::DenseMap<MmaEncodingAttr, MmaEncodingAttr> mmaToUpdate;
+    {
+      mlir::RewritePatternSet patterns(context);
+      patterns.add<CollectMmaToUpdateForVolta>(context, mmaToUpdate);
+      if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
+        signalPassFailure();
+    }
+    {
+      mlir::RewritePatternSet patterns(context);
+      patterns.add<UpdateMMAVersionMinorForVolta>(
+          context, DotOp::getOperationName(), mmaToUpdate);
+      patterns.add<UpdateMMAVersionMinorForVolta>(
+          context, ConvertLayoutOp::getOperationName(), mmaToUpdate);
+      patterns.add<UpdateMMAVersionMinorForVolta>(
+          context, arith::ConstantOp::getOperationName(), mmaToUpdate);
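+      // Top-down traversal processes ops in order of appearance, so producers
+      // are visited before their consumers.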
+      mlir::GreedyRewriteConfig config;
+      config.useTopDownTraversal = true;
+
+      if (applyPatternsAndFoldGreedily(m, std::move(patterns), config).failed())
+        signalPassFailure();
+    }
+
     mlir::RewritePatternSet loopFixup(context);
     loopFixup.add<FixupLoop>(context);
     if (applyPatternsAndFoldGreedily(m, std::move(loopFixup)).failed()) {