diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp index d3b2f899f..a95578793 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp @@ -1,6 +1,15 @@ #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" +//===----------------------------------------------------------------------===// +// +// This file implements loop software pipelining +// The implementation here is inspired by the pipeline pass in Triton 2.0 and +// SCF's LoopPipelining. +// +//===----------------------------------------------------------------------===// + + using namespace mlir; #define GEN_PASS_CLASSES @@ -14,6 +23,9 @@ class LoopPipeliner { triton::LoadOp bLoadOp; }; + /// comments on numStages: + /// [0, numStages-1) are in the prologue + /// numStages-1 is appended after the loop body int numStages; /// cache forOp we are working on scf::ForOp forOp; @@ -21,8 +33,16 @@ class LoopPipeliner { PipelineInfo info; /// value (in loop) => value at stage N DenseMap> valueMapping; + /// stage => loop condition + DenseMap loopConds; - void setStageValueMapping(Value origin, Value prefetched, int idx); + DenseSet depArgs; + DenseSet depOps; + + void setValueMapping(Value origin, Value newValue, int stage); + + /// collect values that v depends on and are defined inside the loop + void collectDeps(Value v); public: LoopPipeliner(scf::ForOp forOp, int numStages) : forOp(forOp), numStages(numStages) {} @@ -30,12 +50,32 @@ public: /// Collect loop info. Return success if we can pipeline this loop LogicalResult initialize(); - /// + /// void emitPrologue(); friend class PipelinePass; }; +// helpers +void LoopPipeliner::setValueMapping(Value origin, Value newValue, int stage) { + if (valueMapping.find(origin) == valueMapping.end()) + valueMapping[origin] = SmallVector(numStages); + valueMapping[origin][stage] = newValue; +} + +void LoopPipeliner::collectDeps(Value v) { + if (v.getParentRegion() != &forOp.getLoopBody()) + return; + if (auto arg = v.dyn_cast()) + depArgs.insert(arg); + else { // value + Operation *defOp = v.getDefiningOp(); + depOps.insert(defOp); + for (Value op : defOp->getOperands()) + collectDeps(op); + } +} + /// A load instruction can be pipelined if: /// - the pointer is a block argument (redefined inside the loop) /// - the load has only a single use in a dot instruction @@ -72,6 +112,8 @@ LogicalResult LoopPipeliner::initialize() { if (aLoad && bLoad) { if (aLoad.ptr().isa() && bLoad.ptr().isa()) { info.dotOp = dotOp; info.aLoadOp = aLoad; info.bLoadOp = bLoad; + collectDeps(dotOp.a()); + collectDeps(dotOp.b()); return success(); } } @@ -80,8 +122,50 @@ LogicalResult LoopPipeliner::initialize() { } void LoopPipeliner::emitPrologue() { + // TODO: should we use rewriter here? OpBuilder builder(forOp); - // + for (BlockArgument &arg : forOp.getRegionIterArgs()) { + OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg); + setValueMapping(arg, operand.get(), 0); + } + + // pro + Value iv = forOp.getInductionVar(); + for (int stage = 0; stage < numStages - 1; ++stage) { + // special handling for induction variable as the increment is implicit + if (stage != 0) + iv = builder.create(iv.getLoc(), iv, forOp.getStep()); + setValueMapping(forOp.getInductionVar(), iv, stage); + + // special handling for loop condition as there is no condition in ForOp + Value loopCond = builder.create( + iv.getLoc(), arith::CmpIPredicate::slt, iv, forOp.getUpperBound()); + loopConds[stage] = loopCond; + + // rematerialize peeled values + SmallVector orderedDeps; + for (Operation &op : forOp.getLoopBody().front()) + if (depOps.contains(&op)) + orderedDeps.push_back(&op); + assert(depOps.size() == orderedDeps.size() && "depOps contains invalid values"); + for (Operation *op : orderedDeps) { + Operation *newOp = builder.clone(*op); + for (unsigned opIdx = 0; opIdx < op->getNumOperands(); ++opIdx) { + auto it = valueMapping.find(op->getOperand(opIdx)); + if (it != valueMapping.end()) { + Value v = it->second[stage]; + assert(v); + newOp->setOperand(opIdx, v); + } // else, op at opIdx is a loop-invariant value + } + + // update mapping of results + for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) { + setValueMapping(op->getResult(dstIdx), newOp->getResult(dstIdx), stage); + // TODO: update mapping for loop-carried values (args) + } + } + } } // ref: mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp @@ -101,8 +185,16 @@ struct PipelinePass : public TritonGPUPipelineBase { llvm::errs() << "candidate for pipelining: " << pipeliner.info.dotOp << "\n"; - + // pipeliner.emitPrologue(); + + // scf::ForOp newForOp = pipeliner.createNewForOp(); + + // // replace the original loop + // if (forOp->getNumResults() > 0) + // rewriter.replaceOp(forOp, newForOp->getResults()); + // else + // rewriter.eraseOp(forOp); }); } };