op combine in Triton Dialect: broadcast(cst) -> cst
This commit is contained in:
@@ -1 +1,2 @@
|
||||
# add_subdirectory(TritonGPUToLLVM)
|
||||
add_subdirectory(TritonToTritonGPU)
|
||||
|
@@ -10,7 +10,7 @@
|
||||
|
||||
#include <memory>
|
||||
|
||||
// using namespace mlir;
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
// dot(a, b, 0) + c => dot(a, b, c)
|
||||
@@ -114,6 +114,39 @@ public:
|
||||
return mlir::failure();
|
||||
}
|
||||
};
|
||||
|
||||
// broadcast(cst) => cst
|
||||
// TODO: move this to .td file
|
||||
class CombineBroadcastConstantOp : public mlir::RewritePattern {
|
||||
public:
|
||||
CombineBroadcastConstantOp(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1,
|
||||
context) {}
|
||||
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (auto broadcast = llvm::dyn_cast<triton::BroadcastOp>(op)) {
|
||||
if (auto cst = broadcast.src().getDefiningOp<arith::ConstantOp>()) {
|
||||
Attribute value = cst.getValue();
|
||||
Type resType = broadcast.getResult().getType();
|
||||
if (auto denseValue = value.dyn_cast<DenseElementsAttr>()) {
|
||||
if (!denseValue.isSplat())
|
||||
return failure();
|
||||
value = DenseElementsAttr::get(resType, denseValue.getSplatValue<Attribute>());
|
||||
} else {
|
||||
if (!value.isa<FloatAttr, IntegerAttr>())
|
||||
return failure();
|
||||
value = DenseElementsAttr::get(resType, value);
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
|
||||
op, value, resType
|
||||
);
|
||||
return success();
|
||||
}
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
} // anonymous namespace
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
@@ -129,6 +162,7 @@ public:
|
||||
patterns.add<CombineDotOp>(context);
|
||||
patterns.add<CombineSelectMaskedLoadOp>(context);
|
||||
patterns.add<CombineGEPOp>(context);
|
||||
patterns.add<CombineBroadcastConstantOp>(context);
|
||||
// patterns.add<CombineReduceOp>(context);
|
||||
|
||||
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
|
||||
|
Reference in New Issue
Block a user