// Triton "combine" pass: registers the rewrite patterns generated into
// TritonCombine.inc and applies them greedily to a module.

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/Transforms/Passes.h"

#include <memory>
#include <utility>

using namespace mlir;

namespace {

// The helpers below have no callers in this file itself; they are presumably
// referenced by the generated patterns in TritonCombine.inc (included further
// down), so their names and signatures must not change.

/// Returns true if `val` is a zero constant (integer or floating-point), or
/// the result of a tt.broadcast whose source is such a zero constant.
bool isZero(mlir::Value val) {
  // Integer zero (m_Zero) or floating-point zero of either sign
  // (m_AnyZeroFloat).
  auto isZeroConstant = [](mlir::Value v) {
    return mlir::matchPattern(v, mlir::m_Zero()) ||
           mlir::matchPattern(v, mlir::m_AnyZeroFloat());
  };

  if (isZeroConstant(val))
    return true;

  // broadcast(constant_0)
  if (auto bc = val.getDefiningOp<mlir::triton::BroadcastOp>())
    return isZeroConstant(bc.src());

  return false;
}

/// Returns true if the constant attribute `value` can be folded through a
/// broadcast. That is the case when `value` is a splat DenseElementsAttr
/// (one repeated element) or a scalar FloatAttr / IntegerAttr.
bool isBroadcastConstantCombinable(Attribute value) {
  if (auto denseValue = value.dyn_cast<DenseElementsAttr>())
    return denseValue.isSplat();
  return value.isa<FloatAttr, IntegerAttr>();
}

/// Builds a DenseElementsAttr with the type of `bcast_res`, filled with the
/// single element carried by `value`.
///
/// `value` is expected to satisfy isBroadcastConstantCombinable. For a dense
/// attribute, getSplatValue requires it to be a splat; this is presumably
/// guaranteed by the generated pattern's constraint — TODO confirm.
/// `builder` is unused here. It is kept because the parameter list is
/// presumably fixed by the generated call site.
DenseElementsAttr getConstantValue(Builder &builder, Attribute value,
                                   Value bcast_res) {
  Type resType = bcast_res.getType();
  // A splat tensor constant is reduced to its scalar element, which is then
  // splatted again to the broadcast result's type.
  if (auto denseValue = value.dyn_cast<DenseElementsAttr>())
    return DenseElementsAttr::get(resType,
                                  denseValue.getSplatValue<Attribute>());
  return DenseElementsAttr::get(resType, value);
}

#include "TritonCombine.inc"

} // anonymous namespace

#define GEN_PASS_CLASSES
#include "triton/Dialect/Triton/Transforms/Passes.h.inc"

/// Collects the combine patterns and applies them greedily to the module.
/// Signals pass failure if the greedy driver reports failure.
class CombineOpsPass : public TritonCombineOpsBase<CombineOpsPass> {
public:
  void runOnOperation() override {
    mlir::MLIRContext *context = &getContext();
    mlir::ModuleOp m = getOperation();

    mlir::RewritePatternSet patterns(context);
    // Dot + Add combines.
    patterns.add<CombineDotAddIPattern, CombineDotAddFPattern,
                 CombineDotAddIRevPattern, CombineDotAddFRevPattern>(context);
    // Remaining combines; registration order matches the original.
    patterns.add<CombineSelectMaskedLoadPattern, CombineGEPPattern,
                 CombineBroadcastConstantPattern>(context);

    if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns))))
      signalPassFailure();
  }
};

std::unique_ptr<mlir::Pass> mlir::triton::createCombineOpsPass() {
  return std::make_unique<CombineOpsPass>();
}