make numStages an option in PipelinePass
This commit is contained in:
@@ -311,9 +311,13 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
|
||||
// ref: mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
|
||||
struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
|
||||
PipelinePass() = default;
|
||||
PipelinePass(int numStages) {
|
||||
this->numStages = numStages;
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
// TODO: collect numStages from ModuleOp
|
||||
int numStages = 2;
|
||||
int numStages = this->numStages;
|
||||
|
||||
if (numStages <= 1)
|
||||
return;
|
||||
@@ -337,6 +341,6 @@ struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
|
||||
};
|
||||
} // anonymous namespace
|
||||
|
||||
std::unique_ptr<Pass> mlir::createTritonGPUPipelinePass() {
|
||||
return std::make_unique<PipelinePass>();
|
||||
std::unique_ptr<Pass> mlir::createTritonGPUPipelinePass(int numStages) {
|
||||
return std::make_unique<PipelinePass>(numStages);
|
||||
}
|
||||
|
Reference in New Issue
Block a user