add cse pass to the pipeline & pass num-warps as an argument
This commit is contained in:
@@ -1343,11 +1343,14 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("add_canonicalizer_pass", [](mlir::PassManager &self) {
|
||||
self.addPass(mlir::createCanonicalizerPass());
|
||||
})
|
||||
.def("add_cse_pass", [](mlir::PassManager &self) {
|
||||
self.addPass(mlir::createCSEPass());
|
||||
})
|
||||
.def("add_triton_combine_pass", [](mlir::PassManager &self) {
|
||||
self.addPass(mlir::triton::createCombineOpsPass());
|
||||
})
|
||||
.def("add_convert_triton_to_tritongpu_pass", [](mlir::PassManager &self) {
|
||||
self.addPass(mlir::triton::createConvertTritonToTritonGPUPass());
|
||||
.def("add_convert_triton_to_tritongpu_pass", [](mlir::PassManager &self, int numWarps) {
|
||||
self.addPass(mlir::triton::createConvertTritonToTritonGPUPass(numWarps));
|
||||
})
|
||||
.def("add_tritongpu_pipeline_pass", [](mlir::PassManager &self, int numStages) {
|
||||
self.addPass(mlir::createTritonGPUPipelinePass(numStages));
|
||||
|
Reference in New Issue
Block a user