From 7e0e7ec365d69cf0b6cbc2848829d1f078bf530d Mon Sep 17 00:00:00 2001 From: Yan Da Date: Sat, 14 May 2022 22:04:36 +0800 Subject: [PATCH] more progress on the pipeline pass --- lib/Dialect/TritonGPU/Transforms/Pipeline.cpp | 147 ++++++++++++++++-- 1 file changed, 131 insertions(+), 16 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp index a95578793..ec621d99f 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp @@ -1,11 +1,13 @@ #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "mlir/IR/BlockAndValueMapping.h" + //===----------------------------------------------------------------------===// // // This file implements loop software pipelining -// The implementation here is inspired by the pipeline pass in Triton 2.0 and -// SCF's LoopPipelining. +// The implementation here is inspired by the pipeline pass in Triton (-v2.0) +// and SCF's LoopPipelining. // //===----------------------------------------------------------------------===// @@ -53,6 +55,9 @@ public: /// void emitPrologue(); + /// create the new ForOp (add new args & insert prefetched ops) + scf::ForOp createNewForOp(); + friend class PipelinePass; }; @@ -66,10 +71,18 @@ void LoopPipeliner::setValueMapping(Value origin, Value newValue, int stage) { void LoopPipeliner::collectDeps(Value v) { if (v.getParentRegion() != &forOp.getLoopBody()) return; - if (auto arg = v.dyn_cast()) + if (auto arg = v.dyn_cast()) { + if (depArgs.contains(arg)) + return; depArgs.insert(arg); - else { // value + // we also need to rematerialize this arg + auto yield = cast(forOp.getBody()->getTerminator()); + // Note: we have iv as the first arg, so the op idx is arg.getArgNumber()-1 + collectDeps(yield->getOperand(arg.getArgNumber() - 1)); + } else { // value Operation *defOp = v.getDefiningOp(); + if (depOps.contains(defOp)) + return; depOps.insert(defOp); for (Value op : defOp->getOperands()) collectDeps(op); @@ -80,13 +93,11 @@ void LoopPipeliner::collectDeps(Value v) { /// - the pointer is a block argument (redefined inside the loop) /// - the load has only a single use in a dot instruction LogicalResult LoopPipeliner::initialize() { - Region &bodyRegion = forOp.getLoopBody(); - assert(bodyRegion.hasOneBlock()); - Block &loop = bodyRegion.front(); + Block *loop = forOp.getBody(); // TODO: can we use forOp.walk(...) here? SmallVector dots; - for (Operation &op : loop) { + for (Operation &op : *loop) { if (auto dotOp = dyn_cast(&op)) { dots.push_back(dotOp); } @@ -129,7 +140,8 @@ void LoopPipeliner::emitPrologue() { setValueMapping(arg, operand.get(), 0); } - // pro + // prologue from [0, numStage-1) + auto yield = cast(forOp.getBody()->getTerminator()); Value iv = forOp.getInductionVar(); for (int stage = 0; stage < numStages - 1; ++stage) { // special handling for induction variable as the increment is implicit @@ -162,12 +174,117 @@ void LoopPipeliner::emitPrologue() { // update mapping of results for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) { setValueMapping(op->getResult(dstIdx), newOp->getResult(dstIdx), stage); - // TODO: update mapping for loop-carried values (args) + // update mapping for loop-carried values (args) + for (OpOperand &operand : yield->getOpOperands()) { + if (operand.get() == op->getResult(dstIdx)) + setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()], + newOp->getResult(dstIdx), stage + 1); + } } } } } +scf::ForOp LoopPipeliner::createNewForOp() { + OpBuilder builder(forOp); + + // order of new args: + // (original args), + // (a at stage[0, numStages-1)), (b at stage[0, numStages-1)) + // (depArgs at stage numStages-1) + // (iv at stage numStages-1), (loopCond at stage numStages-1) + SmallVector newLoopArgs; + for (auto v : forOp.getIterOperands()) + newLoopArgs.push_back(v); + size_t aArgIdx = newLoopArgs.size(); + for (int i = 0; i < numStages - 1; ++i) + newLoopArgs.push_back(valueMapping[info.dotOp.a()][i]); + size_t bArgIdx = newLoopArgs.size(); + for (int i = 0; i < numStages - 1; ++i) + newLoopArgs.push_back(valueMapping[info.dotOp.b()][i]); + size_t depArgsBeginIdx = newLoopArgs.size(); + for (BlockArgument depArg : depArgs) + newLoopArgs.push_back(valueMapping[depArg][numStages-1]); + size_t nextIVIdx = newLoopArgs.size(); + newLoopArgs.push_back(valueMapping[forOp.getInductionVar()][numStages-1]); + newLoopArgs.push_back(loopConds[numStages-1]); + + // signature of the new ForOp + auto newForOp = builder.create(forOp.getLoc(), + forOp.getLowerBound(), + forOp.getUpperBound(), + forOp.getStep(), + newLoopArgs); + + // body of the new ForOp + builder.setInsertionPointToStart(newForOp.getBody()); + BlockAndValueMapping mapping; + for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) + mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]); + mapping.map(info.dotOp.a(), newForOp.getRegionIterArgs()[aArgIdx]); + mapping.map(info.dotOp.b(), newForOp.getRegionIterArgs()[bArgIdx]); + for (Operation &op : forOp.getBody()->without_terminator()) { + Operation *newOp = builder.clone(op, mapping); + // update mapping of results + for (unsigned dstIdx : llvm::seq(unsigned(0), op.getNumResults())) + mapping.map(op.getResult(dstIdx), newOp->getResult(dstIdx)); + } + // prefetch next iteration + SmallVector orderedDeps; + for (Operation &op : forOp.getLoopBody().front()) + if (depOps.contains(&op)) + orderedDeps.push_back(&op); + assert(depOps.size() == orderedDeps.size() && "depOps contains invalid values"); + BlockAndValueMapping nextMapping; + BlockAndValueMapping depArgsMapping; + size_t argIdx = 0; + for (BlockArgument arg : depArgs) { + nextMapping.map(arg, newForOp.getRegionIterArgs()[argIdx + depArgsBeginIdx]); + ++argIdx; + } + // special handling for iv & loop condition + Value nextIV = builder.create(newForOp.getInductionVar().getLoc(), + newForOp.getRegionIterArgs()[nextIVIdx], + newForOp.getStep()); + Value nextLoopCond = builder.create( + nextIV.getLoc(), arith::CmpIPredicate::slt, + nextIV, newForOp.getUpperBound()); + for (Operation *op : orderedDeps) { + // update loading mask + if (op == info.aLoadOp.getOperation() || op == info.bLoadOp.getOperation()) { + auto loadOp = llvm::cast(op); + Value mask = loadOp.mask(); + Value splatCond = builder.create(mask.getLoc(), + mask.getType(), + nextLoopCond); + Value newMask = builder.create(mask.getLoc(), + splatCond, + nextMapping.lookupOrDefault(mask)); + nextMapping.map(mask, newMask); + } + Operation *nextOp = builder.clone(*op, nextMapping); + // update mapping of results + for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) + nextMapping.map(op->getResult(dstIdx), nextOp->getResult(dstIdx)); + } + + // Finally, the YieldOp, need to sync with the order of newLoopArgs + SmallVector yieldValues; + for (Value v : forOp.getBody()->getTerminator()->getOperands()) + yieldValues.push_back(mapping.lookup(v)); + for (int i = 1; i < numStages - 1; ++i) + yieldValues.push_back(newForOp.getRegionIterArgs()[aArgIdx + i]); + yieldValues.push_back(nextMapping.lookup(info.aLoadOp.getResult())); + for (int i = 1; i < numStages - 1; ++i) + yieldValues.push_back(newForOp.getRegionIterArgs()[bArgIdx + i]); + yieldValues.push_back(nextMapping.lookup(info.bLoadOp.getResult())); + // TODO: deps + // + yieldValues.push_back(nextIV); + yieldValues.push_back(nextLoopCond); + return newForOp; +} + // ref: mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp struct PipelinePass : public TritonGPUPipelineBase { void runOnOperation() override { @@ -186,15 +303,13 @@ struct PipelinePass : public TritonGPUPipelineBase { llvm::errs() << "candidate for pipelining: " << pipeliner.info.dotOp << "\n"; - // pipeliner.emitPrologue(); + pipeliner.emitPrologue(); - // scf::ForOp newForOp = pipeliner.createNewForOp(); + scf::ForOp newForOp = pipeliner.createNewForOp(); // // replace the original loop - // if (forOp->getNumResults() > 0) - // rewriter.replaceOp(forOp, newForOp->getResults()); - // else - // rewriter.eraseOp(forOp); + // forOp->replaceAllUsesWith(newForOp->getResults()); + // forOp->erase(); }); } };