diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp index 456ce1200..0135a1dca 100644 --- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp @@ -200,7 +200,7 @@ inline bool expensive_to_remat(Operation *op) { return true; if (isa(op)) + triton::DotOp, triton::ExtElemwiseOp>(op)) return true; if (isa(op)) return true; @@ -623,4 +623,4 @@ public: std::unique_ptr mlir::createTritonGPUCombineOpsPass() { return std::make_unique(); -} \ No newline at end of file +}