From 36c45ec68725889a216dd6e08b19630eb507a0f4 Mon Sep 17 00:00:00 2001 From: Yan Da Date: Mon, 23 May 2022 12:47:55 +0800 Subject: [PATCH] make numStages an option in PipelinePass --- include/triton/Dialect/TritonGPU/Transforms/Passes.h | 2 +- .../triton/Dialect/TritonGPU/Transforms/Passes.td | 6 ++++++ lib/Dialect/TritonGPU/Transforms/Pipeline.cpp | 12 ++++++++---- python/src/triton.cc | 4 ++-- python/triton/code_gen.py | 5 ++++- rewrite-test/jit/matmul/matmul.py | 3 ++- 6 files changed, 23 insertions(+), 9 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.h b/include/triton/Dialect/TritonGPU/Transforms/Passes.h index 1fa150a60..e82a3fd67 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Passes.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.h @@ -4,7 +4,7 @@ #include "mlir/Pass/Pass.h" namespace mlir { -std::unique_ptr<Pass> createTritonGPUPipelinePass(); +std::unique_ptr<Pass> createTritonGPUPipelinePass(int numStages); namespace triton { namespace gpu { diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/include/triton/Dialect/TritonGPU/Transforms/Passes.td index 540bedd72..e038a4d89 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Passes.td +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.td @@ -24,6 +24,12 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> { let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", "mlir::scf::SCFDialect", "mlir::arith::ArithmeticDialect"]; + + let options = [ + Option<"numStages", "num-stages", + "int32_t", /*default*/"2", + "number of pipeline stages"> + ]; } def TritonGPUCombineOps : Pass<"tritongpu-combine", "mlir::ModuleOp"> { diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp index 32dc88e25..33217c200 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp @@ -311,9 +311,13 @@ scf::ForOp LoopPipeliner::createNewForOp() 
{ // ref: mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> { + PipelinePass() = default; + PipelinePass(int numStages) { + this->numStages = numStages; + } + void runOnOperation() override { - // TODO: collect numStages from ModuleOp - int numStages = 2; + int numStages = this->numStages; if (numStages <= 1) return; @@ -337,6 +341,6 @@ struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> { }; } // anonymous namespace -std::unique_ptr<Pass> mlir::createTritonGPUPipelinePass() { - return std::make_unique<PipelinePass>(); +std::unique_ptr<Pass> mlir::createTritonGPUPipelinePass(int numStages) { + return std::make_unique<PipelinePass>(numStages); } diff --git a/python/src/triton.cc b/python/src/triton.cc index 3b7c0d7ae..85ed9566a 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1341,8 +1341,8 @@ void init_triton_ir(py::module &&m) { .def("add_convert_triton_to_tritongpu_pass", [](mlir::PassManager &self) { self.addPass(mlir::triton::createConvertTritonToTritonGPUPass()); }) - .def("add_tritongpu_pipeline_pass", [](mlir::PassManager &self) { - self.addPass(mlir::createTritonGPUPipelinePass()); + .def("add_tritongpu_pipeline_pass", [](mlir::PassManager &self, int numStages) { + self.addPass(mlir::createTritonGPUPipelinePass(numStages)); }) .def("add_triton_gpu_combine_pass", [](mlir::PassManager &self) { self.addPass(mlir::triton::gpu::createCombineOpsPass()); diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 67c9b8b0c..1516b841e 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -1305,16 +1305,19 @@ class JITFunction: if node is None or isinstance(e, (NotImplementedError, CompilationError)): raise e raise CompilationError(self.src, node) from e + # cache num_warps & num_stages + self.num_warps, self.num_stages = num_warps, num_stages # FIXME: now we need to return context, otherwise it will be deleted return generator.module, context def compile_ttir_to_llir(self, mod, ctx): + num_warps, 
num_stages = self.num_warps, self.num_stages pm = _triton.ir.pass_manager(ctx) pm.add_inliner_pass() pm.add_triton_combine_pass() pm.add_canonicalizer_pass() pm.add_convert_triton_to_tritongpu_pass() - pm.add_tritongpu_pipeline_pass() + pm.add_tritongpu_pipeline_pass(num_stages) pm.add_canonicalizer_pass() pm.add_triton_gpu_combine_pass() pm.run(mod) diff --git a/rewrite-test/jit/matmul/matmul.py b/rewrite-test/jit/matmul/matmul.py index 1e909dd50..9bd54fa81 100644 --- a/rewrite-test/jit/matmul/matmul.py +++ b/rewrite-test/jit/matmul/matmul.py @@ -91,7 +91,8 @@ mod, ctx = matmul_kernel.compile_to_ttir( b.stride(0), b.stride(1), c.stride(0), c.stride(1), 128, 128, 128, - 8, grid=(2,) + 8, grid=(2,), + num_stages=4 ) assert mod.verify()