make numStages an option in PipelinePass

This commit is contained in:
Yan Da
2022-05-23 12:47:55 +08:00
parent 39b1235082
commit 36c45ec687
6 changed files with 23 additions and 9 deletions

View File

@@ -1341,8 +1341,8 @@ void init_triton_ir(py::module &&m) {
.def("add_convert_triton_to_tritongpu_pass", [](mlir::PassManager &self) {
self.addPass(mlir::triton::createConvertTritonToTritonGPUPass());
})
.def("add_tritongpu_pipeline_pass", [](mlir::PassManager &self) {
self.addPass(mlir::createTritonGPUPipelinePass());
.def("add_tritongpu_pipeline_pass", [](mlir::PassManager &self, int numStages) {
self.addPass(mlir::createTritonGPUPipelinePass(numStages));
})
.def("add_triton_gpu_combine_pass", [](mlir::PassManager &self) {
self.addPass(mlir::triton::gpu::createCombineOpsPass());

View File

@@ -1305,16 +1305,19 @@ class JITFunction:
if node is None or isinstance(e, (NotImplementedError, CompilationError)):
raise e
raise CompilationError(self.src, node) from e
# cache num_warps & num_stages
self.num_warps, self.num_stages = num_warps, num_stages
# FIXME: now we need to return context, otherwise it will be deleted
return generator.module, context
def compile_ttir_to_llir(self, mod, ctx):
num_warps, num_stages = self.num_warps, self.num_stages
pm = _triton.ir.pass_manager(ctx)
pm.add_inliner_pass()
pm.add_triton_combine_pass()
pm.add_canonicalizer_pass()
pm.add_convert_triton_to_tritongpu_pass()
pm.add_tritongpu_pipeline_pass()
pm.add_tritongpu_pipeline_pass(num_stages)
pm.add_canonicalizer_pass()
pm.add_triton_gpu_combine_pass()
pm.run(mod)