diff --git a/python/src/triton.cc b/python/src/triton.cc index 1029ea1fd..53bf96d06 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1343,11 +1343,14 @@ void init_triton_ir(py::module &&m) { .def("add_canonicalizer_pass", [](mlir::PassManager &self) { self.addPass(mlir::createCanonicalizerPass()); }) + .def("add_cse_pass", [](mlir::PassManager &self) { + self.addPass(mlir::createCSEPass()); + }) .def("add_triton_combine_pass", [](mlir::PassManager &self) { self.addPass(mlir::triton::createCombineOpsPass()); }) - .def("add_convert_triton_to_tritongpu_pass", [](mlir::PassManager &self) { - self.addPass(mlir::triton::createConvertTritonToTritonGPUPass()); + .def("add_convert_triton_to_tritongpu_pass", [](mlir::PassManager &self, int numWarps) { + self.addPass(mlir::triton::createConvertTritonToTritonGPUPass(numWarps)); }) .def("add_tritongpu_pipeline_pass", [](mlir::PassManager &self, int numStages) { self.addPass(mlir::createTritonGPUPipelinePass(numStages)); diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 87df9f89e..7cac37cee 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -1316,9 +1316,11 @@ class JITFunction: pm.add_inliner_pass() pm.add_triton_combine_pass() pm.add_canonicalizer_pass() - pm.add_convert_triton_to_tritongpu_pass() + pm.add_cse_pass() + pm.add_convert_triton_to_tritongpu_pass(num_warps) pm.add_tritongpu_pipeline_pass(num_stages) pm.add_canonicalizer_pass() + pm.add_cse_pass() pm.add_triton_gpu_combine_pass() pm.add_triton_gpu_verifier_pass() return pm.run(mod)