op combine in Triton Dialect: broadcast(cst) -> cst

This commit is contained in:
Yan Da
2022-06-17 16:19:47 +08:00
parent 35736aa44e
commit 9feb256b71
4 changed files with 76 additions and 93 deletions

View File

@@ -10,7 +10,7 @@
#include <memory>
// using namespace mlir;
using namespace mlir;
namespace {
// dot(a, b, 0) + c => dot(a, b, c)
@@ -114,6 +114,39 @@ public:
return mlir::failure();
}
};
// broadcast(cst) => cst
// TODO: move this to .td file
class CombineBroadcastConstantOp : public mlir::RewritePattern {
public:
CombineBroadcastConstantOp(mlir::MLIRContext *context)
: mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1,
context) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
if (auto broadcast = llvm::dyn_cast<triton::BroadcastOp>(op)) {
if (auto cst = broadcast.src().getDefiningOp<arith::ConstantOp>()) {
Attribute value = cst.getValue();
Type resType = broadcast.getResult().getType();
if (auto denseValue = value.dyn_cast<DenseElementsAttr>()) {
if (!denseValue.isSplat())
return failure();
value = DenseElementsAttr::get(resType, denseValue.getSplatValue<Attribute>());
} else {
if (!value.isa<FloatAttr, IntegerAttr>())
return failure();
value = DenseElementsAttr::get(resType, value);
}
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, value, resType
);
return success();
}
}
return failure();
}
};
} // anonymous namespace
#define GEN_PASS_CLASSES
@@ -129,6 +162,7 @@ public:
patterns.add<CombineDotOp>(context);
patterns.add<CombineSelectMaskedLoadOp>(context);
patterns.add<CombineGEPOp>(context);
patterns.add<CombineBroadcastConstantOp>(context);
// patterns.add<CombineReduceOp>(context);
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())