From 0ee6e486f837437519c6946b67ba529e60558175 Mon Sep 17 00:00:00 2001 From: Yan Da Date: Fri, 10 Jun 2022 17:31:48 +0800 Subject: [PATCH] add cse pass to the pipeline & pass num-warps as an argument --- python/src/triton.cc | 7 +++++-- python/triton/code_gen.py | 4 +++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/python/src/triton.cc b/python/src/triton.cc index 1029ea1fd..53bf96d06 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1343,11 +1343,14 @@ void init_triton_ir(py::module &&m) { .def("add_canonicalizer_pass", [](mlir::PassManager &self) { self.addPass(mlir::createCanonicalizerPass()); }) + .def("add_cse_pass", [](mlir::PassManager &self) { + self.addPass(mlir::createCSEPass()); + }) .def("add_triton_combine_pass", [](mlir::PassManager &self) { self.addPass(mlir::triton::createCombineOpsPass()); }) - .def("add_convert_triton_to_tritongpu_pass", [](mlir::PassManager &self) { - self.addPass(mlir::triton::createConvertTritonToTritonGPUPass()); + .def("add_convert_triton_to_tritongpu_pass", [](mlir::PassManager &self, int numWarps) { + self.addPass(mlir::triton::createConvertTritonToTritonGPUPass(numWarps)); }) .def("add_tritongpu_pipeline_pass", [](mlir::PassManager &self, int numStages) { self.addPass(mlir::createTritonGPUPipelinePass(numStages)); diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 87df9f89e..7cac37cee 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -1316,9 +1316,11 @@ class JITFunction: pm.add_inliner_pass() pm.add_triton_combine_pass() pm.add_canonicalizer_pass() - pm.add_convert_triton_to_tritongpu_pass() + pm.add_cse_pass() + pm.add_convert_triton_to_tritongpu_pass(num_warps) pm.add_tritongpu_pipeline_pass(num_stages) pm.add_canonicalizer_pass() + pm.add_cse_pass() pm.add_triton_gpu_combine_pass() pm.add_triton_gpu_verifier_pass() return pm.run(mod)