diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/include/triton/Dialect/TritonGPU/Transforms/Passes.td index 490df2730..4e50bcbd4 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Passes.td +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.td @@ -3,7 +3,7 @@ include "mlir/Pass/PassBase.td" -def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::FuncOp"> { +def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> { let summary = "pipeline"; let description = [{ diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp index ecb9f8b9a..d3b2f899f 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp @@ -7,10 +7,102 @@ using namespace mlir; #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" namespace { +class LoopPipeliner { + struct PipelineInfo { + triton::DotOp dotOp; + triton::LoadOp aLoadOp; + triton::LoadOp bLoadOp; + }; + + int numStages; + /// cache forOp we are working on + scf::ForOp forOp; + /// dot & loads + PipelineInfo info; + /// value (in loop) => value at stage N + DenseMap> valueMapping; + + void setStageValueMapping(Value origin, Value prefetched, int idx); +public: + LoopPipeliner(scf::ForOp forOp, int numStages) + : forOp(forOp), numStages(numStages) {} + + /// Collect loop info. Return success if we can pipeline this loop + LogicalResult initialize(); + + /// + void emitPrologue(); + + friend class PipelinePass; +}; + +/// A load instruction can be pipelined if: +/// - the pointer is a block argument (redefined inside the loop) +/// - the load has only a single use in a dot instruction +LogicalResult LoopPipeliner::initialize() { + Region &bodyRegion = forOp.getLoopBody(); + assert(bodyRegion.hasOneBlock()); + Block &loop = bodyRegion.front(); + + // TODO: can we use forOp.walk(...) here? + SmallVector dots; + for (Operation &op : loop) { + if (auto dotOp = dyn_cast(&op)) { + dots.push_back(dotOp); + } + } + + // Don't know what to do if we have more than 1 dots inside the loop + if (dots.size() != 1) + return failure(); + + triton::DotOp dotOp = dots[0]; + // dot (cvt (load %ptr0)), (cvt (load %ptr1)) + auto getDefinintLoad = [&](Value v) -> triton::LoadOp { + auto cvt = v.getDefiningOp(); + if (cvt) { + return cvt.src().getDefiningOp(); + } + return nullptr; + }; + auto aLoad = getDefinintLoad(dotOp.a()); + auto bLoad = getDefinintLoad(dotOp.b()); + + // ptrs must be block args (phi nodes) + if (aLoad && bLoad) { + if (aLoad.ptr().isa() && bLoad.ptr().isa()) { + info.dotOp = dotOp; info.aLoadOp = aLoad; info.bLoadOp = bLoad; + return success(); + } + } + + return failure(); +} + +void LoopPipeliner::emitPrologue() { + OpBuilder builder(forOp); + // +} + +// ref: mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp struct PipelinePass : public TritonGPUPipelineBase { void runOnOperation() override { - getOperation()->walk([&](scf::ForOp forOp) { + // TODO: collect numStages from ModuleOp + int numStages = 2; + if (numStages <= 1) + return; + + getOperation()->walk([&](scf::ForOp forOp) -> void { + LoopPipeliner pipeliner(forOp, numStages); + + if (pipeliner.initialize().failed()) + return; + + llvm::errs() << "candidate for pipelining: " << pipeliner.info.dotOp + << "\n"; + + // pipeliner.emitPrologue(); }); } }; diff --git a/python/src/triton.cc b/python/src/triton.cc index 82142dfb5..4c16bf6d7 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -16,6 +16,7 @@ #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" #include "triton/Dialect/Triton/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "llvm/IR/Module.h" #include "llvm/IR/LegacyPassManager.h" @@ -1340,6 +1341,9 @@ void init_triton_ir(py::module &&m) { .def("add_convert_triton_to_tritongpu_pass", [](mlir::PassManager &self) { self.addPass(mlir::triton::createConvertTritonToTritonGPUPass()); }) + .def("add_tritongpu_pipeline_pass", [](mlir::PassManager &self) { + self.addPass(mlir::createTritonGPUPipelinePass()); + }) ; } diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 45399c5ba..b5999648b 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -1314,6 +1314,7 @@ class JITFunction: pm.add_triton_combine_pass() pm.add_canonicalizer_pass() pm.add_convert_triton_to_tritongpu_pass() + pm.add_tritongpu_pipeline_pass() pm.run(mod) return mod