#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include using namespace mlir; static bool isSharedLayout(Value v) { if (auto tensorType = v.getType().dyn_cast()) { Attribute encoding = tensorType.getEncoding(); return encoding.isa(); } return false; } namespace { #include "TritonGPUCombine.inc" } #define GEN_PASS_CLASSES #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" class TritonGPUCombineOpsPass : public TritonGPUCombineOpsBase { public: void runOnOperation() override { MLIRContext *context = &getContext(); ModuleOp m = getOperation(); mlir::RewritePatternSet patterns(context); patterns.add(context); patterns.add(context); patterns.add(context); if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) signalPassFailure(); } }; std::unique_ptr mlir::createTritonGPUCombineOpsPass() { return std::make_unique(); }