Add the CSE pass to the pipeline and pass num-warps as an argument to the Triton-to-TritonGPU conversion pass

This commit is contained in:
Yan Da
2022-06-10 17:31:48 +08:00
parent 117a402c1b
commit 0ee6e486f8
2 changed files with 8 additions and 3 deletions

View File

@@ -1343,11 +1343,14 @@ void init_triton_ir(py::module &&m) {
.def("add_canonicalizer_pass", [](mlir::PassManager &self) { .def("add_canonicalizer_pass", [](mlir::PassManager &self) {
self.addPass(mlir::createCanonicalizerPass()); self.addPass(mlir::createCanonicalizerPass());
}) })
// Python-visible `add_cse_pass`: appends the CSE (common-subexpression
// elimination) pass returned by mlir::createCSEPass() to this PassManager.
.def("add_cse_pass", [](mlir::PassManager &self) {
self.addPass(mlir::createCSEPass());
})
.def("add_triton_combine_pass", [](mlir::PassManager &self) { .def("add_triton_combine_pass", [](mlir::PassManager &self) {
self.addPass(mlir::triton::createCombineOpsPass()); self.addPass(mlir::triton::createCombineOpsPass());
}) })
.def("add_convert_triton_to_tritongpu_pass", [](mlir::PassManager &self) { .def("add_convert_triton_to_tritongpu_pass", [](mlir::PassManager &self, int numWarps) {
self.addPass(mlir::triton::createConvertTritonToTritonGPUPass()); self.addPass(mlir::triton::createConvertTritonToTritonGPUPass(numWarps));
}) })
.def("add_tritongpu_pipeline_pass", [](mlir::PassManager &self, int numStages) { .def("add_tritongpu_pipeline_pass", [](mlir::PassManager &self, int numStages) {
self.addPass(mlir::createTritonGPUPipelinePass(numStages)); self.addPass(mlir::createTritonGPUPipelinePass(numStages));

View File

@@ -1316,9 +1316,11 @@ class JITFunction:
pm.add_inliner_pass() pm.add_inliner_pass()
pm.add_triton_combine_pass() pm.add_triton_combine_pass()
pm.add_canonicalizer_pass() pm.add_canonicalizer_pass()
pm.add_convert_triton_to_tritongpu_pass() pm.add_cse_pass()
pm.add_convert_triton_to_tritongpu_pass(num_warps)
pm.add_tritongpu_pipeline_pass(num_stages) pm.add_tritongpu_pipeline_pass(num_stages)
pm.add_canonicalizer_pass() pm.add_canonicalizer_pass()
pm.add_cse_pass()
pm.add_triton_gpu_combine_pass() pm.add_triton_gpu_combine_pass()
pm.add_triton_gpu_verifier_pass() pm.add_triton_gpu_verifier_pass()
return pm.run(mod) return pm.run(mod)