Merge remote-tracking branch 'origin/master' into phil/fused-attention-perf-fixup

This commit is contained in:
Phil Tillet
2022-12-30 11:53:49 -08:00
21 changed files with 765 additions and 337 deletions


@@ -23,6 +23,10 @@
using namespace mlir;
namespace {
#include "TritonGPUCombine.inc"
using triton::DotOp;
using triton::gpu::ConvertLayoutOp;
using triton::gpu::DotOperandEncodingAttr;
using triton::gpu::MmaEncodingAttr;
// -----------------------------------------------------------------------------
//
@@ -1020,6 +1024,7 @@ public:
dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
if ((order[0] == 1 && isMMAv1Row) || (order[0] == 0 && !isMMAv1Row))
return failure();
auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row);
auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
op->getContext(), dstDotOperandLayout.getOpIdx(),
@@ -1061,7 +1066,8 @@ public:
auto dotOp = cast<triton::DotOp>(op);
// TODO: Check data-types and SM compatibility
auto oldRetType = dotOp.getResult().getType().cast<RankedTensorType>();
if (oldRetType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
if (!oldRetType.getEncoding() ||
oldRetType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
return failure();
auto AType = dotOp.getOperand(0).getType().cast<RankedTensorType>();
@@ -1171,7 +1177,8 @@ public:
for (size_t i = 0; i < newInitArgs.size(); i++) {
auto initArg = newInitArgs[i];
auto regionArg = forOp.getRegionIterArgs()[i];
if (newInitArgs[i].getType() != forOp.getRegionIterArgs()[i].getType()) {
if (newInitArgs[i].getType() != forOp.getRegionIterArgs()[i].getType() ||
newInitArgs[i].getType() != forOp.getResultTypes()[i]) {
shouldRematerialize = true;
break;
}
@@ -1187,15 +1194,207 @@ public:
BlockAndValueMapping mapping;
for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
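// Clone the loop body into the new for-op, remapping block arguments through `mapping`.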
for (Operation &op : forOp.getBody()->getOperations()) {
Operation *newOp = rewriter.clone(op, mapping);
rewriter.clone(op, mapping);
}
rewriter.replaceOp(forOp, newForOp.getResults());
return success();
}
};
// This pattern collects the MMA layouts that need updating and creates the
// corrected layout for each of them.
class CollectMmaToUpdateForVolta : public mlir::RewritePattern {
DenseMap<MmaEncodingAttr, MmaEncodingAttr> &mmaToUpdate;
public:
CollectMmaToUpdateForVolta(
mlir::MLIRContext *ctx,
DenseMap<MmaEncodingAttr, MmaEncodingAttr> &mmaToUpdate)
: mlir::RewritePattern(triton::DotOp::getOperationName(), 1, ctx),
mmaToUpdate(mmaToUpdate) {}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto dotOp = cast<triton::DotOp>(op);
auto *ctx = dotOp->getContext();
auto AT = dotOp.a().getType().cast<RankedTensorType>();
auto BT = dotOp.b().getType().cast<RankedTensorType>();
auto DT = dotOp.d().getType().cast<RankedTensorType>();
if (!DT.getEncoding())
return failure();
auto mmaLayout = DT.getEncoding().dyn_cast<MmaEncodingAttr>();
if (!(mmaLayout && mmaLayout.isVolta()))
return failure();
// Already processed.
if (mmaToUpdate.count(mmaLayout))
return failure();
auto dotOperandA = AT.getEncoding().cast<DotOperandEncodingAttr>();
auto dotOperandB = BT.getEncoding().cast<DotOperandEncodingAttr>();
bool isARow = dotOperandA.getIsMMAv1Row().cast<BoolAttr>().getValue();
bool isBRow = dotOperandB.getIsMMAv1Row().cast<BoolAttr>().getValue();
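// Unpack the operand row-major/vec4 flags currently encoded in the layout's versionMinor.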
auto [isARow_, isBRow_, isAVec4, isBVec4] =
mmaLayout.decodeVoltaLayoutStates();
if (isARow_ == isARow && isBRow_ == isBRow) {
return failure(); // No need to update
}
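// Rebuild the MMA layout so its encoded flags match the actual dot operand layouts.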
auto newMmaLayout = MmaEncodingAttr::get(
ctx, mmaLayout.getVersionMajor(), mmaLayout.getWarpsPerCTA(),
AT.getShape(), BT.getShape(), isARow, isBRow);
// Record the stale MMA layout together with its replacement.
mmaToUpdate.try_emplace(mmaLayout, newMmaLayout);
return failure();
}
};
// Correct the versionMinor field in MmaEncodingAttr for Volta.
class UpdateMMAVersionMinorForVolta : public mlir::RewritePattern {
const DenseMap<MmaEncodingAttr, MmaEncodingAttr> &mmaToUpdate;
enum class Kind {
kUnk,
kCvtToMma,
kCvtToDotOp,
kDot,
kConstant,
};
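// Set by match() and read by rewrite(); mutable because match() is const.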
mutable Kind rewriteKind{Kind::kUnk};
public:
UpdateMMAVersionMinorForVolta(
mlir::MLIRContext *ctx, llvm::StringRef opName,
const DenseMap<MmaEncodingAttr, MmaEncodingAttr> &mmaToUpdate)
: RewritePattern(opName, 1 /*benefit*/, ctx), mmaToUpdate(mmaToUpdate) {}
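// Match ops whose single tensor result (or its dot-operand parent) carries an MMA layout recorded in mmaToUpdate.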
LogicalResult match(Operation *op) const override {
MmaEncodingAttr mma;
if (mmaToUpdate.empty())
return failure();
if (op->getNumResults() != 1)
return failure();
auto tensorTy = op->getResult(0).getType().dyn_cast<RankedTensorType>();
if (!tensorTy)
return failure();
// ConvertLayoutOp
if (auto cvt = llvm::dyn_cast<ConvertLayoutOp>(op)) {
// cvt X -> dot_operand
if (auto dotOperand =
tensorTy.getEncoding().dyn_cast<DotOperandEncodingAttr>()) {
mma = dotOperand.getParent().dyn_cast<MmaEncodingAttr>();
rewriteKind = Kind::kCvtToDotOp;
if (mma && mmaToUpdate.count(mma))
return success();
}
if ((mma = tensorTy.getEncoding().dyn_cast<MmaEncodingAttr>())) {
// cvt X -> mma
rewriteKind = Kind::kCvtToMma;
if (mma && mmaToUpdate.count(mma))
return success();
}
} else if (auto dot = llvm::dyn_cast<DotOp>(op)) {
// DotOp
mma = dot.d()
.getType()
.cast<RankedTensorType>()
.getEncoding()
.dyn_cast<MmaEncodingAttr>();
rewriteKind = Kind::kDot;
} else if (auto constant = llvm::dyn_cast<arith::ConstantOp>(op)) {
// ConstantOp
mma = tensorTy.getEncoding().dyn_cast<MmaEncodingAttr>();
rewriteKind = Kind::kConstant;
}
return success(mma && mmaToUpdate.count(mma));
}
void rewrite(Operation *op, PatternRewriter &rewriter) const override {
switch (rewriteKind) {
case Kind::kDot:
rewriteDot(op, rewriter);
break;
case Kind::kConstant:
rewriteConstant(op, rewriter);
break;
case Kind::kCvtToDotOp:
rewriteCvtDotOp(op, rewriter);
break;
case Kind::kCvtToMma:
rewriteCvtToMma(op, rewriter);
break;
default:
llvm::report_fatal_error("Unsupported rewrite kind");
}
}
private:
void rewriteCvtDotOp(Operation *op, PatternRewriter &rewriter) const {
auto *ctx = op->getContext();
auto cvt = llvm::cast<ConvertLayoutOp>(op);
auto tensorTy = cvt.result().getType().cast<RankedTensorType>();
auto dotOperand = tensorTy.getEncoding().cast<DotOperandEncodingAttr>();
MmaEncodingAttr newMma =
mmaToUpdate.lookup(dotOperand.getParent().cast<MmaEncodingAttr>());
auto newDotOperand = DotOperandEncodingAttr::get(
ctx, dotOperand.getOpIdx(), newMma, dotOperand.getIsMMAv1Row());
auto newTensorTy = RankedTensorType::get(
tensorTy.getShape(), tensorTy.getElementType(), newDotOperand);
rewriter.replaceOpWithNewOp<ConvertLayoutOp>(op, newTensorTy,
cvt.getOperand());
}
void rewriteDot(Operation *op, PatternRewriter &rewriter) const {
auto *ctx = op->getContext();
auto dot = llvm::cast<DotOp>(op);
auto tensorTy = dot.d().getType().cast<RankedTensorType>();
auto mma = tensorTy.getEncoding().cast<MmaEncodingAttr>();
auto newMma = mmaToUpdate.lookup(mma);
auto newTensorTy = RankedTensorType::get(tensorTy.getShape(),
tensorTy.getElementType(), newMma);
rewriter.replaceOpWithNewOp<DotOp>(op, newTensorTy, dot.a(), dot.b(),
dot.c(), dot.allowTF32());
}
void rewriteCvtToMma(Operation *op, PatternRewriter &rewriter) const {
auto *ctx = op->getContext();
auto cvt = llvm::cast<ConvertLayoutOp>(op);
auto tensorTy = cvt.result().getType().cast<RankedTensorType>();
auto mma = tensorTy.getEncoding().cast<MmaEncodingAttr>();
auto newMma = mmaToUpdate.lookup(mma);
auto newTensorTy = RankedTensorType::get(tensorTy.getShape(),
tensorTy.getElementType(), newMma);
rewriter.replaceOpWithNewOp<ConvertLayoutOp>(op, newTensorTy,
cvt.getOperand());
}
void rewriteConstant(Operation *op, PatternRewriter &rewriter) const {
auto *ctx = op->getContext();
auto constant = llvm::cast<arith::ConstantOp>(op);
auto tensorTy = constant.getResult().getType().dyn_cast<RankedTensorType>();
auto mma = tensorTy.getEncoding().cast<MmaEncodingAttr>();
auto newMma = mmaToUpdate.lookup(mma);
auto newTensorTy = RankedTensorType::get(tensorTy.getShape(),
tensorTy.getElementType(), newMma);
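// Only splat constants are handled; rebuild the splat over the retyped tensor.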
if (auto attr = constant.getValue().dyn_cast<SplatElementsAttr>()) {
auto newRet =
SplatElementsAttr::get(newTensorTy, attr.getSplatValue<Attribute>());
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newTensorTy, newRet);
return;
}
assert(false && "Unsupported ConstantOp value type");
}
};
} // namespace
#define GEN_PASS_CLASSES
@@ -1230,6 +1429,28 @@ public:
signalPassFailure();
}
llvm::DenseMap<MmaEncodingAttr, MmaEncodingAttr> mmaToUpdate;
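// First, collect the stale Volta MMA layouts; this pattern only records them and never modifies the IR.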
{
mlir::RewritePatternSet patterns(context);
patterns.add<CollectMmaToUpdateForVolta>(context, mmaToUpdate);
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
signalPassFailure();
}
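// Then rewrite the dot, convert_layout, and constant ops that reference them;
// top-down traversal updates producers before their consumers.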
{
mlir::RewritePatternSet patterns(context);
patterns.add<UpdateMMAVersionMinorForVolta>(
context, DotOp::getOperationName(), mmaToUpdate);
patterns.add<UpdateMMAVersionMinorForVolta>(
context, ConvertLayoutOp::getOperationName(), mmaToUpdate);
patterns.add<UpdateMMAVersionMinorForVolta>(
context, arith::ConstantOp::getOperationName(), mmaToUpdate);
mlir::GreedyRewriteConfig config;
config.useTopDownTraversal = true;
if (applyPatternsAndFoldGreedily(m, std::move(patterns), config).failed())
signalPassFailure();
}
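// Finally, rematerialize any scf.for whose iter-arg or result types changed above.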
mlir::RewritePatternSet loopFixup(context);
loopFixup.add<FixupLoop>(context);
if (applyPatternsAndFoldGreedily(m, std::move(loopFixup)).failed()) {