Make numStages an option in PipelinePass
This commit is contained in:
@@ -4,7 +4,7 @@
|
|||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
std::unique_ptr<Pass> createTritonGPUPipelinePass();
|
std::unique_ptr<Pass> createTritonGPUPipelinePass(int numStages);
|
||||||
|
|
||||||
namespace triton {
|
namespace triton {
|
||||||
namespace gpu {
|
namespace gpu {
|
||||||
|
@@ -24,6 +24,12 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> {
|
|||||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
|
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
|
||||||
"mlir::scf::SCFDialect",
|
"mlir::scf::SCFDialect",
|
||||||
"mlir::arith::ArithmeticDialect"];
|
"mlir::arith::ArithmeticDialect"];
|
||||||
|
|
||||||
|
let options = [
|
||||||
|
Option<"numStages", "num-stages",
|
||||||
|
"int32_t", /*default*/"2",
|
||||||
|
"number of pipeline stages">
|
||||||
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
def TritonGPUCombineOps : Pass<"tritongpu-combine", "mlir::ModuleOp"> {
|
def TritonGPUCombineOps : Pass<"tritongpu-combine", "mlir::ModuleOp"> {
|
||||||
|
@@ -311,9 +311,13 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
|||||||
|
|
||||||
// ref: mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
|
// ref: mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
|
||||||
struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
|
struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
|
||||||
|
PipelinePass() = default;
|
||||||
|
PipelinePass(int numStages) {
|
||||||
|
this->numStages = numStages;
|
||||||
|
}
|
||||||
|
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
// TODO: collect numStages from ModuleOp
|
int numStages = this->numStages;
|
||||||
int numStages = 2;
|
|
||||||
|
|
||||||
if (numStages <= 1)
|
if (numStages <= 1)
|
||||||
return;
|
return;
|
||||||
@@ -337,6 +341,6 @@ struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
|
|||||||
};
|
};
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
||||||
std::unique_ptr<Pass> mlir::createTritonGPUPipelinePass() {
|
std::unique_ptr<Pass> mlir::createTritonGPUPipelinePass(int numStages) {
|
||||||
return std::make_unique<PipelinePass>();
|
return std::make_unique<PipelinePass>(numStages);
|
||||||
}
|
}
|
||||||
|
@@ -1341,8 +1341,8 @@ void init_triton_ir(py::module &&m) {
|
|||||||
.def("add_convert_triton_to_tritongpu_pass", [](mlir::PassManager &self) {
|
.def("add_convert_triton_to_tritongpu_pass", [](mlir::PassManager &self) {
|
||||||
self.addPass(mlir::triton::createConvertTritonToTritonGPUPass());
|
self.addPass(mlir::triton::createConvertTritonToTritonGPUPass());
|
||||||
})
|
})
|
||||||
.def("add_tritongpu_pipeline_pass", [](mlir::PassManager &self) {
|
.def("add_tritongpu_pipeline_pass", [](mlir::PassManager &self, int numStages) {
|
||||||
self.addPass(mlir::createTritonGPUPipelinePass());
|
self.addPass(mlir::createTritonGPUPipelinePass(numStages));
|
||||||
})
|
})
|
||||||
.def("add_triton_gpu_combine_pass", [](mlir::PassManager &self) {
|
.def("add_triton_gpu_combine_pass", [](mlir::PassManager &self) {
|
||||||
self.addPass(mlir::triton::gpu::createCombineOpsPass());
|
self.addPass(mlir::triton::gpu::createCombineOpsPass());
|
||||||
|
@@ -1305,16 +1305,19 @@ class JITFunction:
|
|||||||
if node is None or isinstance(e, (NotImplementedError, CompilationError)):
|
if node is None or isinstance(e, (NotImplementedError, CompilationError)):
|
||||||
raise e
|
raise e
|
||||||
raise CompilationError(self.src, node) from e
|
raise CompilationError(self.src, node) from e
|
||||||
|
# cache num_warps & num_stages
|
||||||
|
self.num_warps, self.num_stages = num_warps, num_stages
|
||||||
# FIXME: now we need to return context, otherwise it will be deleted
|
# FIXME: now we need to return context, otherwise it will be deleted
|
||||||
return generator.module, context
|
return generator.module, context
|
||||||
|
|
||||||
def compile_ttir_to_llir(self, mod, ctx):
|
def compile_ttir_to_llir(self, mod, ctx):
|
||||||
|
num_warps, num_stages = self.num_warps, self.num_stages
|
||||||
pm = _triton.ir.pass_manager(ctx)
|
pm = _triton.ir.pass_manager(ctx)
|
||||||
pm.add_inliner_pass()
|
pm.add_inliner_pass()
|
||||||
pm.add_triton_combine_pass()
|
pm.add_triton_combine_pass()
|
||||||
pm.add_canonicalizer_pass()
|
pm.add_canonicalizer_pass()
|
||||||
pm.add_convert_triton_to_tritongpu_pass()
|
pm.add_convert_triton_to_tritongpu_pass()
|
||||||
pm.add_tritongpu_pipeline_pass()
|
pm.add_tritongpu_pipeline_pass(num_stages)
|
||||||
pm.add_canonicalizer_pass()
|
pm.add_canonicalizer_pass()
|
||||||
pm.add_triton_gpu_combine_pass()
|
pm.add_triton_gpu_combine_pass()
|
||||||
pm.run(mod)
|
pm.run(mod)
|
||||||
|
@@ -91,7 +91,8 @@ mod, ctx = matmul_kernel.compile_to_ttir(
|
|||||||
b.stride(0), b.stride(1),
|
b.stride(0), b.stride(1),
|
||||||
c.stride(0), c.stride(1),
|
c.stride(0), c.stride(1),
|
||||||
128, 128, 128,
|
128, 128, 128,
|
||||||
8, grid=(2,)
|
8, grid=(2,),
|
||||||
|
num_stages=4
|
||||||
)
|
)
|
||||||
|
|
||||||
assert mod.verify()
|
assert mod.verify()
|
||||||
|
Reference in New Issue
Block a user