#include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "mlir/IR/BlockAndValueMapping.h" //===----------------------------------------------------------------------===// // // This file implements loop software pipelining // The implementation here is inspired by the pipeline pass in Triton (-v2.0) // and SCF's LoopPipelining. // //===----------------------------------------------------------------------===// using namespace mlir; #define GEN_PASS_CLASSES #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" namespace { class LoopPipeliner { struct PipelineInfo { triton::DotOp dotOp; triton::LoadOp aLoadOp; triton::LoadOp bLoadOp; }; /// comments on numStages: /// [0, numStages-1) are in the prologue /// numStages-1 is appended after the loop body int numStages; /// cache forOp we are working on scf::ForOp forOp; /// dot & loads PipelineInfo info; /// value (in loop) => value at stage N DenseMap> valueMapping; DenseSet depArgs; DenseSet depOps; void setValueMapping(Value origin, Value newValue, int stage); /// collect values that v depends on and are defined inside the loop void collectDeps(Value v); public: LoopPipeliner(scf::ForOp forOp, int numStages) : forOp(forOp), numStages(numStages) {} /// Collect loop info. Return success if we can pipeline this loop LogicalResult initialize(); /// void emitPrologue(); /// create the new ForOp (add new args & insert prefetched ops) scf::ForOp createNewForOp(); friend class PipelinePass; }; // helpers void LoopPipeliner::setValueMapping(Value origin, Value newValue, int stage) { if (valueMapping.find(origin) == valueMapping.end()) valueMapping[origin] = SmallVector(numStages); valueMapping[origin][stage] = newValue; } void LoopPipeliner::collectDeps(Value v) { if (v.getParentRegion() != &forOp.getLoopBody()) return; if (auto arg = v.dyn_cast()) { if (depArgs.contains(arg)) return; depArgs.insert(arg); // we also need to rematerialize this arg auto yield = cast(forOp.getBody()->getTerminator()); // Note: we have iv as the first arg, so the op idx is arg.getArgNumber()-1 collectDeps(yield->getOperand(arg.getArgNumber() - 1)); } else { // value Operation *defOp = v.getDefiningOp(); if (depOps.contains(defOp)) return; depOps.insert(defOp); for (Value op : defOp->getOperands()) collectDeps(op); } } /// A load instruction can be pipelined if: /// - the pointer is a block argument (redefined inside the loop) /// - the load has only a single use in a dot instruction LogicalResult LoopPipeliner::initialize() { Block *loop = forOp.getBody(); // TODO: can we use forOp.walk(...) here? SmallVector dots; for (Operation &op : *loop) { if (auto dotOp = dyn_cast(&op)) { dots.push_back(dotOp); } } // Don't know what to do if we have more than 1 dots inside the loop if (dots.size() != 1) return failure(); triton::DotOp dotOp = dots[0]; // dot (cvt (load %ptr0)), (cvt (load %ptr1)) auto getDefinintLoad = [&](Value v) -> triton::LoadOp { auto cvt = v.getDefiningOp(); if (cvt) { return cvt.src().getDefiningOp(); } return nullptr; }; auto aLoad = getDefinintLoad(dotOp.a()); auto bLoad = getDefinintLoad(dotOp.b()); // ptrs must be block args (phi nodes) if (aLoad && bLoad) { if (aLoad.ptr().isa() && bLoad.ptr().isa()) { info.dotOp = dotOp; info.aLoadOp = aLoad; info.bLoadOp = bLoad; collectDeps(dotOp.a()); collectDeps(dotOp.b()); return success(); } } return failure(); } void LoopPipeliner::emitPrologue() { // TODO: should we use rewriter here? OpBuilder builder(forOp); for (BlockArgument &arg : forOp.getRegionIterArgs()) { OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg); setValueMapping(arg, operand.get(), 0); } // prologue from [0, numStage-1) auto yield = cast(forOp.getBody()->getTerminator()); Value iv = forOp.getLowerBound(); for (int stage = 0; stage < numStages - 1; ++stage) { // special handling for induction variable as the increment is implicit if (stage != 0) iv = builder.create(iv.getLoc(), iv, forOp.getStep()); setValueMapping(forOp.getInductionVar(), iv, stage); // special handling for loop condition as there is no condition in ForOp Value loopCond = builder.create( iv.getLoc(), arith::CmpIPredicate::slt, iv, forOp.getUpperBound()); // rematerialize peeled values SmallVector orderedDeps; for (Operation &op : forOp.getLoopBody().front()) if (depOps.contains(&op)) orderedDeps.push_back(&op); assert(depOps.size() == orderedDeps.size() && "depOps contains invalid values"); for (Operation *op : orderedDeps) { Operation *newOp = builder.clone(*op); for (unsigned opIdx = 0; opIdx < op->getNumOperands(); ++opIdx) { auto it = valueMapping.find(op->getOperand(opIdx)); if (it != valueMapping.end()) { Value v = it->second[stage]; assert(v); newOp->setOperand(opIdx, v); } // else, op at opIdx is a loop-invariant value } // update mapping of results for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) { setValueMapping(op->getResult(dstIdx), newOp->getResult(dstIdx), stage); // update mapping for loop-carried values (args) for (OpOperand &operand : yield->getOpOperands()) { if (operand.get() == op->getResult(dstIdx)) setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()], newOp->getResult(dstIdx), stage + 1); } } } } } scf::ForOp LoopPipeliner::createNewForOp() { OpBuilder builder(forOp); // order of new args: // (original args), // (a at stage[0, numStages-1)), (b at stage[0, numStages-1)) // (depArgs at stage numStages-1) // (iv at stage numStages-1) SmallVector newLoopArgs; // We need this to update operands for yield // original block arg => new arg's idx DenseMap depArgsIdx; for (auto v : forOp.getIterOperands()) newLoopArgs.push_back(v); size_t aArgIdx = newLoopArgs.size(); for (int i = 0; i < numStages - 1; ++i) newLoopArgs.push_back(valueMapping[info.dotOp.a()][i]); size_t bArgIdx = newLoopArgs.size(); for (int i = 0; i < numStages - 1; ++i) newLoopArgs.push_back(valueMapping[info.dotOp.b()][i]); size_t depArgsBeginIdx = newLoopArgs.size(); for (BlockArgument depArg : depArgs) { depArgsIdx[depArg] = newLoopArgs.size(); newLoopArgs.push_back(valueMapping[depArg][numStages-1]); } size_t nextIVIdx = newLoopArgs.size(); newLoopArgs.push_back(valueMapping[forOp.getInductionVar()][numStages-2]); for (size_t i = 0; i < newLoopArgs.size(); ++i) assert(newLoopArgs[i]); // signature of the new ForOp auto newForOp = builder.create(forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(), newLoopArgs); // body of the new ForOp builder.setInsertionPointToStart(newForOp.getBody()); BlockAndValueMapping mapping; for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]); // mapping.map(info.dotOp.a(), newForOp.getRegionIterArgs()[aArgIdx]); // mapping.map(info.dotOp.b(), newForOp.getRegionIterArgs()[bArgIdx]); for (Operation &op : forOp.getBody()->without_terminator()) { Operation *newOp = builder.clone(op, mapping); // update mapping of results for (unsigned dstIdx : llvm::seq(unsigned(0), op.getNumResults())) mapping.map(op.getResult(dstIdx), newOp->getResult(dstIdx)); // TODO: why doesn't mapping work? if (&op == info.dotOp.getOperation()) { newOp->setOperand(0, newForOp.getRegionIterArgs()[aArgIdx]); newOp->setOperand(1, newForOp.getRegionIterArgs()[bArgIdx]); } } // prefetch next iteration SmallVector orderedDeps; for (Operation &op : forOp.getLoopBody().front()) if (depOps.contains(&op)) orderedDeps.push_back(&op); assert(depOps.size() == orderedDeps.size() && "depOps contains invalid values"); BlockAndValueMapping nextMapping; DenseMap depArgsMapping; size_t argIdx = 0; for (BlockArgument arg : depArgs) { nextMapping.map(arg, newForOp.getRegionIterArgs()[argIdx + depArgsBeginIdx]); ++argIdx; } // special handling for iv & loop condition Value nextIV = builder.create(newForOp.getInductionVar().getLoc(), newForOp.getRegionIterArgs()[nextIVIdx], newForOp.getStep()); Value nextLoopCond = builder.create( nextIV.getLoc(), arith::CmpIPredicate::slt, nextIV, newForOp.getUpperBound()); for (Operation *op : orderedDeps) { // update loading mask if (op == info.aLoadOp.getOperation() || op == info.bLoadOp.getOperation()) { auto loadOp = llvm::cast(op); Value mask = loadOp.mask(); Value splatCond = builder.create(mask.getLoc(), mask.getType(), nextLoopCond); Value newMask = builder.create(mask.getLoc(), splatCond, nextMapping.lookupOrDefault(mask)); // if mask is defined outside the loop, don't update the map more than once if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask))) nextMapping.map(mask, newMask); } Operation *nextOp = builder.clone(*op, nextMapping); // update mapping of results for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) { nextMapping.map(op->getResult(dstIdx), nextOp->getResult(dstIdx)); // if this is a loop-carried value, update the mapping for yield auto originYield = cast(forOp.getBody()->getTerminator()); for (OpOperand &operand : originYield->getOpOperands()) { if (operand.get() == op->getResult(dstIdx)) { size_t originIdx = operand.getOperandNumber(); size_t newArgIdx = depArgsIdx[forOp.getRegionIterArgs()[originIdx]]; BlockArgument newArg = newForOp.getRegionIterArgs()[newArgIdx]; depArgsMapping[newArg] = nextOp->getResult(dstIdx); } } } } // Finally, the YieldOp, need to sync with the order of newLoopArgs SmallVector yieldValues; for (Value v : forOp.getBody()->getTerminator()->getOperands()) yieldValues.push_back(mapping.lookup(v)); for (int i = 1; i < numStages - 1; ++i) yieldValues.push_back(newForOp.getRegionIterArgs()[aArgIdx + i]); yieldValues.push_back(nextMapping.lookup(info.dotOp.a())); for (int i = 1; i < numStages - 1; ++i) yieldValues.push_back(newForOp.getRegionIterArgs()[bArgIdx + i]); yieldValues.push_back(nextMapping.lookup(info.dotOp.b())); for (size_t i = depArgsBeginIdx; i < nextIVIdx; ++i) yieldValues.push_back(depArgsMapping.lookup(newForOp.getRegionIterArgs()[i])); yieldValues.push_back(nextIV); builder.setInsertionPointToEnd(newForOp.getBody()); builder.create(forOp.getBody()->getTerminator()->getLoc(), yieldValues); return newForOp; } // ref: mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp struct PipelinePass : public TritonGPUPipelineBase { void runOnOperation() override { // TODO: collect numStages from ModuleOp int numStages = 2; if (numStages <= 1) return; getOperation()->walk([&](scf::ForOp forOp) -> void { LoopPipeliner pipeliner(forOp, numStages); if (pipeliner.initialize().failed()) return; pipeliner.emitPrologue(); scf::ForOp newForOp = pipeliner.createNewForOp(); // replace the original loop for (unsigned i = 0; i < forOp->getNumResults(); ++i) forOp->getResult(i).replaceAllUsesWith(newForOp->getResult(i)); forOp->erase(); }); } }; } // anonymous namespace std::unique_ptr mlir::createTritonGPUPipelinePass() { return std::make_unique(); }