#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BlockAndValueMapping.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" //===----------------------------------------------------------------------===// // // This file implements loop software pipelining // The implementation here is inspired by the pipeline pass in Triton (-v2.0) // and SCF's LoopPipelining. // //===----------------------------------------------------------------------===// using namespace mlir; namespace ttg = triton::gpu; #define GEN_PASS_CLASSES #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" static Type getI1SameShape(Value v) { Type vType = v.getType(); auto i1Type = IntegerType::get(vType.getContext(), 1); auto tensorType = vType.cast(); return RankedTensorType::get(tensorType.getShape(), i1Type, tensorType.getEncoding()); } namespace { class LoopPipeliner { /// cache forOp we are working on scf::ForOp forOp; /// cache YieldOp for this forOp scf::YieldOp yieldOp; /// loads to be pipelined SetVector loads; /// the value that each load will be mapped to (after layout conversion) DenseMap loadsMapping; /// load => buffer DenseMap loadsBuffer; /// load => buffer type (with shared layout after swizzling) DenseMap loadsBufferType; /// load => buffer at stage N DenseMap> loadStageBuffer; /// load => after extract DenseMap loadsExtract; /// Value pipelineIterIdx; /// Value loopIterIdx; /// comments on numStages: /// [0, numStages-1) are in the prologue /// numStages-1 is appended after the loop body int numStages; /// value (in loop) => value at stage N DenseMap> valueMapping; /// Block arguments that loads depend on DenseSet depArgs; /// Operations (inside the loop body) that loads depend on DenseSet depOps; /// collect values that v depends on and are defined inside the loop void collectDeps(Value v, int stages, DenseSet &deps); void setValueMapping(Value origin, Value newValue, int stage); Value lookupOrDefault(Value origin, int stage); /// returns a empty buffer of size ttg::AllocTensorOp allocateEmptyBuffer(Operation *op, OpBuilder &builder); /// compute type of shared buffers (with swizzled shared layouts) RankedTensorType getSwizzleType(ttg::DotOperandEncodingAttr dotOpEnc, RankedTensorType tensorType); public: LoopPipeliner(scf::ForOp forOp, int numStages) : forOp(forOp), numStages(numStages) { // cache yieldOp yieldOp = cast(forOp.getBody()->getTerminator()); } /// Collect loads to pipeline. Return success if we can pipeline this loop LogicalResult initialize(); /// emit pipelined loads (before loop body) void emitPrologue(); /// emit pipelined loads (after loop body) void emitEpilogue(); /// create the new ForOp (add new args & insert prefetched ops) scf::ForOp createNewForOp(); friend struct PipelinePass; }; // helpers void LoopPipeliner::setValueMapping(Value origin, Value newValue, int stage) { if (valueMapping.find(origin) == valueMapping.end()) valueMapping[origin] = SmallVector(numStages); valueMapping[origin][stage] = newValue; } Value LoopPipeliner::lookupOrDefault(Value origin, int stage) { if (valueMapping.find(origin) == valueMapping.end()) return origin; return valueMapping[origin][stage]; } void LoopPipeliner::collectDeps(Value v, int stages, DenseSet &deps) { // Loop-invarant value. 
void LoopPipeliner::collectDeps(Value v, int stages, DenseSet<Value> &deps) {
  // Loop-invariant value, skip
  if (v.getParentRegion() != &forOp.getLoopBody())
    return;

  // Since we only need to peel the loop numStages-1 times, don't worry about
  // dependencies that are too far away
  if (stages < 0)
    return;

  if (auto arg = v.dyn_cast<BlockArgument>()) {
    deps.insert(v);
    // Note: we have iv as the first arg, so the op idx is arg.getArgNumber()-1
    collectDeps(yieldOp->getOperand(arg.getArgNumber() - 1), stages - 1, deps);
  } else { // value
    // v might be in deps, but we still need to visit v.
    // This is because v might depend on values from previous iterations
    deps.insert(v);
    for (Value op : v.getDefiningOp()->getOperands())
      collectDeps(op, stages, deps);
  }
}

ttg::AllocTensorOp LoopPipeliner::allocateEmptyBuffer(Operation *op,
                                                      OpBuilder &builder) {
  // allocate a buffer for each pipelined tensor
  // shape: e.g. (numStages==4), <32x64xbf16> -> <4x32x64xbf16>
  Value convertLayout = loadsMapping[op->getResult(0)];
  if (auto tensorType = convertLayout.getType().dyn_cast<RankedTensorType>()) {
    return builder.create<ttg::AllocTensorOp>(
        convertLayout.getLoc(), loadsBufferType[op->getResult(0)]);
  }
  llvm_unreachable("Async copy's return should be of RankedTensorType");
}

// TODO: This code was copied from Swizzle.cpp; we should find a way to unify
// the two code paths.
// Swizzling has to be performed before pipelining for now. If we swizzled
// after pipelining, we would need to propagate the swizzled layout to all
// operands that alias the swizzled tensor. The alias analysis component may
// be helpful for this purpose.
RankedTensorType
LoopPipeliner::getSwizzleType(ttg::DotOperandEncodingAttr dotOpEnc,
                              RankedTensorType ty) {
  int opIdx = dotOpEnc.getOpIdx();
  int vec = 1;
  int maxPhase = 1;
  int perPhase = 1;
  llvm::SmallVector<unsigned> order;
  if (auto mmaEnc = dotOpEnc.getParent().dyn_cast<ttg::MmaEncodingAttr>()) {
    // Only support row major for now
    // TODO(Keren): check why column major code crashes
    order = {1, 0};
    int version = mmaEnc.getVersion();
    auto tyEncoding = ty.getEncoding().cast<ttg::BlockedEncodingAttr>();
    // number of rows per phase
    perPhase = 128 / (ty.getShape()[order[0]] *
                      (ty.getElementType().getIntOrFloatBitWidth() / 8));
    perPhase = std::max(perPhase, 1);

    // index of the inner dimension in `order`
    unsigned inner = (opIdx == 0) ? 0 : 1;
    if (version == 1) {
      maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase;
      // TODO: handle rep (see
      // https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L209)
    } else if (version == 2) {
      auto eltTy = ty.getElementType();
      std::vector<int> matShape = {8, 8,
                                   2 * 64 / eltTy.getIntOrFloatBitWidth()};
      // for now, disable swizzle when using transposed int8 tensor cores
      if (ty.getElementType().isInteger(8) && order[0] == inner)
        perPhase = 1;
      else {
        if (opIdx == 0) { // compute swizzling for A operand
          vec = order[0] == 1 ? matShape[2] : matShape[0]; // k : m
          int mmaStride = order[0] == 1 ? matShape[0] : matShape[2];
          maxPhase = mmaStride / perPhase;
        } else if (opIdx == 1) { // compute swizzling for B operand
          vec = order[0] == 1 ? matShape[1] : matShape[2]; // n : k
          int mmaStride = order[0] == 1 ? matShape[2] : matShape[1];
          maxPhase = mmaStride / perPhase;
        } else
          llvm_unreachable("invalid operand index");
      }
    } else // version not in [1, 2]
      llvm_unreachable("unsupported swizzling for provided MMA version");
  } else {
    // If the layout of the dot is not mma, we don't need to swizzle
    auto blockedEnc = dotOpEnc.getParent().cast<ttg::BlockedEncodingAttr>();
    order = llvm::SmallVector<unsigned>(blockedEnc.getOrder().begin(),
                                        blockedEnc.getOrder().end());
  }
  auto newEncoding = ttg::SharedEncodingAttr::get(ty.getContext(), vec,
                                                  perPhase, maxPhase, order);
  SmallVector<int64_t> bufferShape(ty.getShape().begin(), ty.getShape().end());
  bufferShape.insert(bufferShape.begin(), numStages);
  return RankedTensorType::get(bufferShape, ty.getElementType(), newEncoding);
}
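// Worked example (illustrative, assuming an MMA v2 dot whose A operand
// (opIdx == 0) is a 32x64xf16 tile with order = {1, 0}): the formulas above
// give
//   perPhase = max(128 / (64 * 2), 1)   = 1
//   matShape = {8, 8, 2 * 64 / 16}      = {8, 8, 8}
//   vec      = matShape[2]              = 8
//   maxPhase = matShape[0] / perPhase   = 8
// so with numStages == 4 the buffer type returned is roughly
//   tensor<4x32x64xf16, #triton_gpu.shared<{vec = 8, perPhase = 1,
//                                            maxPhase = 8, order = [1, 0]}>>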
/// A load instruction can be pipelined if:
///   - the load doesn't depend on any other loads (after loop peeling)
///   - (?) this load is not a loop-invariant value (we should run LICM before
///     this pass?)
LogicalResult LoopPipeliner::initialize() {
  Block *loop = forOp.getBody();

  // can we use forOp.walk(...) here?
  SmallVector<triton::LoadOp> allLoads;
  for (Operation &op : *loop)
    if (auto loadOp = dyn_cast<triton::LoadOp>(&op))
      allLoads.push_back(loadOp);

  // Early stop: no need to continue if there is no load in the loop.
  if (allLoads.empty())
    return failure();

  // load => values that it depends on
  DenseMap<Value, DenseSet<Value>> loadDeps;
  for (triton::LoadOp loadOp : allLoads) {
    DenseSet<Value> deps;
    for (Value op : loadOp->getOperands())
      collectDeps(op, numStages - 1, deps);
    loadDeps[loadOp] = deps;
  }

  // Don't pipeline loads that depend on other loads
  // (because if a load depends on another load, it needs to wait on the
  //  other load in the prologue, which defeats the point of the pipeline
  //  pass)
  for (triton::LoadOp loadOp : allLoads) {
    bool isCandidate = true;
    for (triton::LoadOp other : allLoads) {
      if (loadDeps[loadOp].contains(other)) {
        isCandidate = false;
        break;
      }
    }

    // We only pipeline loads that have a single convert_layout (to dot_op) use
    // TODO: lift this constraint in the future
    if (isCandidate && loadOp.getResult().hasOneUse()) {
      isCandidate = false;
      Operation *use = *loadOp.getResult().getUsers().begin();
      if (auto convertLayout = llvm::dyn_cast<ttg::ConvertLayoutOp>(use)) {
        if (auto tensorType = convertLayout.getResult()
                                  .getType()
                                  .dyn_cast<RankedTensorType>()) {
          if (auto dotOpEnc = tensorType.getEncoding()
                                  .dyn_cast<ttg::DotOperandEncodingAttr>()) {
            isCandidate = true;
            loadsMapping[loadOp] = convertLayout;
            loadsBufferType[loadOp] = getSwizzleType(
                dotOpEnc, loadOp.getType().cast<RankedTensorType>());
          }
        }
      }
    } else
      isCandidate = false;

    if (isCandidate)
      loads.insert(loadOp);
  }

  // we have some loads to pipeline
  if (!loads.empty()) {
    // update depArgs & depOps
    for (Value loadOp : loads) {
      for (Value dep : loadDeps[loadOp]) {
        // TODO: we should record the stage that the value is depended on
        if (auto arg = dep.dyn_cast<BlockArgument>())
          depArgs.insert(arg);
        else
          depOps.insert(dep.getDefiningOp());
      }
    }
    return success();
  }

  return failure();
}
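// Prologue sketch (illustrative, schematic IR; assumes numStages == 3 and a
// single pipelined load): emitPrologue() below peels two iterations in front
// of the loop, roughly
//
//   %buf0   = triton_gpu.alloc_tensor : tensor<3xMxKxf16, #shared>
//   %buf1   = triton_gpu.insert_slice_async %ptr0, %buf0, %idx0, %mask0, ...
//   %buf2   = triton_gpu.insert_slice_async %ptr1, %buf1, %idx1, %mask1, ...
//             triton_gpu.async_wait {num = 1}
//   %slice0 = tensor.extract_slice %buf2[0, 0, 0] [1, M, K] [1, 1, 1]
//
// where each mask is the original load mask AND-ed with (or, if absent, a
// splat of) the peeled loop condition iv < upperBound, so peeled iterations
// past the trip count are masked off.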
void LoopPipeliner::emitPrologue() {
  // llvm::errs() << "loads to pipeline...:\n";
  // for (Value load : loads)
  //   llvm::errs() << load << "\n";

  OpBuilder builder(forOp);
  for (BlockArgument &arg : forOp.getRegionIterArgs()) {
    OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg);
    setValueMapping(arg, operand.get(), 0);
  }

  // helper to construct int attribute
  auto intAttr = [&](int64_t val) { return builder.getI64IntegerAttr(val); };

  // prologue from [0, numStages-1)
  Value iv = forOp.getLowerBound();
  pipelineIterIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
  for (int stage = 0; stage < numStages - 1; ++stage) {
    // special handling for the induction variable as its increment is implicit
    if (stage != 0)
      iv = builder.create<arith::AddIOp>(iv.getLoc(), iv, forOp.getStep());
    setValueMapping(forOp.getInductionVar(), iv, stage);

    // special handling for the loop condition as there is no condition in
    // ForOp
    Value loopCond = builder.create<arith::CmpIOp>(
        iv.getLoc(), arith::CmpIPredicate::slt, iv, forOp.getUpperBound());

    // rematerialize peeled values
    SmallVector<Operation *> orderedDeps;
    for (Operation &op : forOp.getLoopBody().front()) {
      if (depOps.contains(&op))
        orderedDeps.push_back(&op);
      else if (loads.contains(op.getResult(0)))
        orderedDeps.push_back(&op);
    }
    assert(depOps.size() + loads.size() == orderedDeps.size() &&
           "depOps contains invalid values");
    for (Operation *op : orderedDeps) {
      Operation *newOp = nullptr;
      if (loads.contains(op->getResult(0))) {
        // Allocate empty buffer
        if (stage == 0) {
          loadsBuffer[op->getResult(0)] = allocateEmptyBuffer(op, builder);
          loadStageBuffer[op->getResult(0)] = {loadsBuffer[op->getResult(0)]};
        }
        // load => copy async
        if (auto loadOp = llvm::dyn_cast<triton::LoadOp>(op)) {
          Value mask = lookupOrDefault(loadOp.mask(), stage);
          Value newMask;
          if (mask) {
            Value splatCond = builder.create<triton::SplatOp>(
                mask.getLoc(), mask.getType(), loopCond);
            newMask =
                builder.create<arith::AndIOp>(mask.getLoc(), mask, splatCond);
          } else {
            newMask = builder.create<triton::SplatOp>(
                loopCond.getLoc(), getI1SameShape(loadOp), loopCond);
          }
          // TODO: check if the hardware supports async copy
          newOp = builder.create<ttg::InsertSliceAsyncOp>(
              op->getLoc(), loadsBuffer[loadOp].getType(),
              lookupOrDefault(loadOp.ptr(), stage),
              loadStageBuffer[loadOp][stage], pipelineIterIdx, newMask,
              lookupOrDefault(loadOp.other(), stage), loadOp.cache(),
              loadOp.evict(), loadOp.isVolatile(), /*axis*/ 0);
          loadStageBuffer[loadOp].push_back(newOp->getResult(0));
        } else
          llvm_unreachable("This should be LoadOp");
      } else {
        newOp = builder.clone(*op);
        // Update loop-carried uses
        for (unsigned opIdx = 0; opIdx < op->getNumOperands(); ++opIdx) {
          auto it = valueMapping.find(op->getOperand(opIdx));
          if (it != valueMapping.end()) {
            Value v = it->second[stage];
            assert(v);
            newOp->setOperand(opIdx, v);
          } // else, the operand at opIdx is a loop-invariant value
        }
      }

      // update mapping of results
      for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
        Value originalResult = op->getResult(dstIdx);
        // copy_async will update the value of its only use
        // TODO: the load should not be used in the preheader?
        if (loads.contains(originalResult)) {
          break;
          // originalResult = loadsMapping[originalResult];
        }
        setValueMapping(originalResult, newOp->getResult(dstIdx), stage);
        // update mapping for loop-carried values (args)
        for (OpOperand &operand : yieldOp->getOpOperands()) {
          if (operand.get() == op->getResult(dstIdx))
            setValueMapping(
                forOp.getRegionIterArgs()[operand.getOperandNumber()],
                newOp->getResult(dstIdx), stage + 1);
        }
      }
    } // for (Operation *op : orderedDeps)

    pipelineIterIdx = builder.create<arith::AddIOp>(
        iv.getLoc(), pipelineIterIdx,
        builder.create<arith::ConstantIntOp>(iv.getLoc(), 1, 32));
  } // for (int stage = 0; stage < numStages - 1; ++stage)

  // async.wait & extract_slice
  builder.create<ttg::AsyncWaitOp>(loads[0].getLoc(),
                                   loads.size() * (numStages - 2));
  loopIterIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
  for (Value loadOp : loads) {
    auto sliceType = loadsMapping[loadOp].getType().cast<RankedTensorType>();
    sliceType = RankedTensorType::get(sliceType.getShape(),
                                      sliceType.getElementType(),
                                      loadsBufferType[loadOp].getEncoding());
    Value extractSlice = builder.create<tensor::ExtractSliceOp>(
        loadOp.getLoc(), sliceType, loadStageBuffer[loadOp][numStages - 1],
        SmallVector<OpFoldResult>{intAttr(0), intAttr(0), intAttr(0)},
        SmallVector<OpFoldResult>{intAttr(1), intAttr(sliceType.getShape()[0]),
                                  intAttr(sliceType.getShape()[1])},
        SmallVector<OpFoldResult>{intAttr(1), intAttr(1), intAttr(1)});
    loadsExtract[loadOp] = extractSlice;
  }
  // bump up loopIterIdx; this is used for getting the correct slice for the
  // *next* iteration
  loopIterIdx = builder.create<arith::AddIOp>(
      loopIterIdx.getLoc(), loopIterIdx,
      builder.create<arith::ConstantIntOp>(loopIterIdx.getLoc(), 1, 32));
}
void LoopPipeliner::emitEpilogue() {
  // If there are any outstanding async copies, we need to wait for them.
  // TODO(Keren): We may want to completely avoid the async copies in the last
  // few iterations by setting the is_masked attribute to true. We don't want
  // to use the mask operand because it's a tensor, not a scalar.
  OpBuilder builder(forOp);
  OpBuilder::InsertionGuard g(builder);
  builder.setInsertionPointAfter(forOp);
  builder.create<ttg::AsyncWaitOp>(forOp.getLoc(), 0);
}

scf::ForOp LoopPipeliner::createNewForOp() {
  OpBuilder builder(forOp);
  auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); };

  // order of new args:
  //   (original args),
  //   (insertSliceAsync buffer at stage numStages - 1) for each load
  //   (extracted tensor) for each load
  //   (depArgs at stage numStages - 1)
  //   (iv at stage numStages - 2)
  //   (pipeline iteration index)
  //   (loop iteration index)
  SmallVector<Value> newLoopArgs;
  // We need this to update operands for yield
  // original block arg => new arg's idx
  DenseMap<BlockArgument, size_t> depArgsIdx;
  for (auto v : forOp.getIterOperands())
    newLoopArgs.push_back(v);

  size_t bufferIdx = newLoopArgs.size();
  for (Value loadOp : loads)
    newLoopArgs.push_back(loadStageBuffer[loadOp].back());

  size_t loadIdx = newLoopArgs.size();
  for (Value loadOp : loads)
    newLoopArgs.push_back(loadsExtract[loadOp]);

  size_t depArgsBeginIdx = newLoopArgs.size();
  for (BlockArgument depArg : depArgs) {
    depArgsIdx[depArg] = newLoopArgs.size();
    newLoopArgs.push_back(valueMapping[depArg][numStages - 1]);
  }

  size_t nextIVIdx = newLoopArgs.size();
  newLoopArgs.push_back(valueMapping[forOp.getInductionVar()][numStages - 2]);
  newLoopArgs.push_back(pipelineIterIdx);
  newLoopArgs.push_back(loopIterIdx);

  for (size_t i = 0; i < newLoopArgs.size(); ++i)
    assert(newLoopArgs[i]);
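  // Illustrative layout (hypothetical example with two pipelined loads A and
  // B and one dependent block argument): newLoopArgs ends up as
  //
  //   [ original iter args...,
  //     bufA, bufB,                    // starting at bufferIdx
  //     sliceA, sliceB,                // starting at loadIdx
  //     depArg0 @ stage numStages-1,   // starting at depArgsBeginIdx
  //     iv @ stage numStages-2,        // nextIVIdx
  //     pipelineIterIdx, loopIterIdx ]
  //
  // The yield built at the end of createNewForOp() must supply its operands
  // in exactly this order.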
  // 1. signature of the new ForOp
  auto newForOp = builder.create<scf::ForOp>(
      forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
      forOp.getStep(), newLoopArgs);

  // 2. body of the new ForOp
  builder.setInsertionPointToStart(newForOp.getBody());
  BlockAndValueMapping mapping;
  for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
    mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);

  // 2.1 clone the loop body, replace original args with args of the new ForOp
  // Insert async wait if necessary.
  for (Operation &op : forOp.getBody()->without_terminator()) {
    Operation *newOp = builder.clone(op, mapping);
    // update mapping of results
    for (unsigned dstIdx : llvm::seq(unsigned(0), op.getNumResults()))
      mapping.map(op.getResult(dstIdx), newOp->getResult(dstIdx));
  }

  // 3. replace loads with block args (from the prologue)
  for (size_t idx = 0; idx < loads.size(); ++idx) {
    Value load = loads[idx];
    assert(load.hasOneUse() &&
           "we assume that this load has one use (ConvertLayout)");
    Value loadUse = load.getUsers().begin()->getResult(0);
    mapping.lookup(loadUse).replaceAllUsesWith(
        newForOp.getRegionIterArgs()[loadIdx + idx]);
    // delete the old load and the layout conversion
    mapping.lookup(loadUse).getDefiningOp()->erase();
    mapping.lookup(load).getDefiningOp()->erase();
  }

  // 4. prefetch the next iteration
  SmallVector<Operation *> orderedDeps;
  for (Operation &op : forOp.getLoopBody().front()) {
    if (depOps.contains(&op))
      orderedDeps.push_back(&op);
    else if (loads.contains(op.getResult(0)))
      orderedDeps.push_back(&op);
  }
  assert(depOps.size() + loads.size() == orderedDeps.size() &&
         "depOps contains invalid values");
  BlockAndValueMapping nextMapping;
  DenseMap<BlockArgument, Value> depArgsMapping;
  size_t argIdx = 0;
  for (BlockArgument arg : depArgs) {
    nextMapping.map(arg,
                    newForOp.getRegionIterArgs()[argIdx + depArgsBeginIdx]);
    ++argIdx;
  }
  // special handling for iv & loop condition
  Value nextIV = builder.create<arith::AddIOp>(
      newForOp.getInductionVar().getLoc(),
      newForOp.getRegionIterArgs()[nextIVIdx], newForOp.getStep());
  Value nextLoopCond =
      builder.create<arith::CmpIOp>(nextIV.getLoc(), arith::CmpIPredicate::slt,
                                    nextIV, newForOp.getUpperBound());

  // slice index
  SmallVector<Value> nextBuffers;
  SmallVector<Value> extractSlices;

  pipelineIterIdx = newForOp.getRegionIterArgs()[nextIVIdx + 1];
  Value insertSliceIndex = builder.create<arith::RemSIOp>(
      nextIV.getLoc(), pipelineIterIdx,
      builder.create<arith::ConstantIntOp>(nextIV.getLoc(), numStages, 32));
  loopIterIdx = newForOp.getRegionIterArgs()[nextIVIdx + 2];
  Value extractSliceIndex = builder.create<arith::RemSIOp>(
      nextIV.getLoc(), loopIterIdx,
      builder.create<arith::ConstantIntOp>(nextIV.getLoc(), numStages, 32));
  extractSliceIndex = builder.create<arith::IndexCastOp>(
      extractSliceIndex.getLoc(), builder.getIndexType(), extractSliceIndex);

  for (Operation *op : orderedDeps) {
    Operation *nextOp = nullptr;
    // update the loading mask
    if (loads.contains(op->getResult(0))) {
      auto loadOp = llvm::cast<triton::LoadOp>(op);
      Value mask = loadOp.mask();
      Value newMask;
      if (mask) {
        Value splatCond = builder.create<triton::SplatOp>(
            mask.getLoc(), mask.getType(), nextLoopCond);
        newMask = builder.create<arith::AndIOp>(
            mask.getLoc(), splatCond, nextMapping.lookupOrDefault(mask));
        // if the mask is defined outside the loop, don't update the map more
        // than once
        if (!(forOp.isDefinedOutsideOfLoop(mask) &&
              nextMapping.contains(mask)))
          nextMapping.map(mask, newMask);
        newMask = nextMapping.lookupOrDefault(loadOp.mask());
      } else
        newMask = builder.create<triton::SplatOp>(
            loadOp.getLoc(), getI1SameShape(loadOp), nextLoopCond);
      Value insertAsyncOp = builder.create<ttg::InsertSliceAsyncOp>(
          op->getLoc(), loadsBuffer[loadOp].getType(),
          nextMapping.lookupOrDefault(loadOp.ptr()),
          newForOp.getRegionIterArgs()[bufferIdx + nextBuffers.size()],
          insertSliceIndex, newMask,
          nextMapping.lookupOrDefault(loadOp.other()), loadOp.cache(),
          loadOp.evict(), loadOp.isVolatile(), /*axis*/ 0);
      nextBuffers.push_back(insertAsyncOp);
      auto sliceType = loadsMapping[loadOp].getType().cast<RankedTensorType>();
      sliceType = RankedTensorType::get(sliceType.getShape(),
                                        sliceType.getElementType(),
                                        loadsBufferType[loadOp].getEncoding());
      nextOp = builder.create<tensor::ExtractSliceOp>(
          op->getLoc(), sliceType, insertAsyncOp,
          SmallVector<OpFoldResult>{extractSliceIndex, intAttr(0), intAttr(0)},
          SmallVector<OpFoldResult>{intAttr(1),
                                    intAttr(sliceType.getShape()[0]),
                                    intAttr(sliceType.getShape()[1])},
          SmallVector<OpFoldResult>{intAttr(1), intAttr(1), intAttr(1)});
      extractSlices.push_back(nextOp->getResult(0));
    } else
      nextOp = builder.clone(*op, nextMapping);
    // update mapping of results
    for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
      nextMapping.map(op->getResult(dstIdx), nextOp->getResult(dstIdx));
      // if this is a loop-carried value, update the mapping for yield
      auto originYield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
      for (OpOperand &operand : originYield->getOpOperands()) {
        if (operand.get() == op->getResult(dstIdx)) {
          size_t originIdx = operand.getOperandNumber();
          size_t newArgIdx = depArgsIdx[forOp.getRegionIterArgs()[originIdx]];
          BlockArgument newArg = newForOp.getRegionIterArgs()[newArgIdx];
          depArgsMapping[newArg] = nextOp->getResult(dstIdx);
        }
      }
    }
  }
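  // Fix up dot operands in the cloned body: the dots now consume tensors
  // extracted from the shared-memory buffers, so make sure both operands
  // carry a dot_operand encoding, inserting a convert_layout where one is
  // missing.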
  {
    OpBuilder::InsertionGuard guard(builder);
    for (Operation &op : *newForOp.getBody()) {
      if (auto dotOp = llvm::dyn_cast<triton::DotOp>(&op)) {
        builder.setInsertionPoint(&op);
        auto dotType = dotOp.getType().cast<RankedTensorType>();
        Value a = dotOp.a();
        Value b = dotOp.b();
        auto layoutCast = [&](Value dotOperand, int opIdx) -> Value {
          auto tensorType = dotOperand.getType().cast<RankedTensorType>();
          if (!tensorType.getEncoding().isa<ttg::DotOperandEncodingAttr>()) {
            auto newEncoding = ttg::DotOperandEncodingAttr::get(
                tensorType.getContext(), opIdx, dotType.getEncoding());
            auto newType = RankedTensorType::get(tensorType.getShape(),
                                                 tensorType.getElementType(),
                                                 newEncoding);
            return builder.create<ttg::ConvertLayoutOp>(dotOperand.getLoc(),
                                                        newType, dotOperand);
          }
          return dotOperand;
        };
        a = layoutCast(a, 0);
        b = layoutCast(b, 1);
        dotOp->setOperand(0, a);
        dotOp->setOperand(1, b);
      }
    }
  }

  // async.wait & extract_slice
  Operation *asyncWait = builder.create<ttg::AsyncWaitOp>(
      loads[0].getLoc(), loads.size() * (numStages - 2));
  for (auto it = extractSlices.rbegin(); it != extractSlices.rend(); ++it) {
    // move the extract_slice ops after asyncWait
    it->getDefiningOp()->moveAfter(asyncWait);
  }

  // bump the iteration counts
  pipelineIterIdx = builder.create<arith::AddIOp>(
      nextIV.getLoc(), pipelineIterIdx,
      builder.create<arith::ConstantIntOp>(nextIV.getLoc(), 1, 32));
  loopIterIdx = builder.create<arith::AddIOp>(
      nextIV.getLoc(), loopIterIdx,
      builder.create<arith::ConstantIntOp>(nextIV.getLoc(), 1, 32));

  // Finally, the YieldOp; its operands need to match the order of newLoopArgs
  SmallVector<Value> yieldValues;
  for (Value v : forOp.getBody()->getTerminator()->getOperands())
    yieldValues.push_back(mapping.lookup(v));
  for (Value nextBuffer : nextBuffers)
    yieldValues.push_back(nextBuffer);
  for (Value nextSlice : extractSlices)
    yieldValues.push_back(nextSlice);

  for (size_t i = depArgsBeginIdx; i < nextIVIdx; ++i)
    yieldValues.push_back(
        depArgsMapping.lookup(newForOp.getRegionIterArgs()[i]));
  yieldValues.push_back(nextIV);
  yieldValues.push_back(pipelineIterIdx);
  yieldValues.push_back(loopIterIdx);

  builder.setInsertionPointToEnd(newForOp.getBody());
  builder.create<scf::YieldOp>(forOp.getBody()->getTerminator()->getLoc(),
                               yieldValues);
  return newForOp;
}
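// Usage sketch (illustrative, assuming a ModuleOp `module` already lowered to
// the TritonGPU dialect): the pass is added to a pass manager like any other
// MLIR pass, e.g.
//
//   mlir::PassManager pm(module.getContext());
//   pm.addPass(mlir::createTritonGPUPipelinePass(/*numStages=*/3));
//   if (failed(pm.run(module)))
//     return failure();
//
// Loops are only rewritten when numStages > 1 and initialize() finds at least
// one eligible load.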
// ref: mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
  PipelinePass() = default;
  PipelinePass(int numStages) { this->numStages = numStages; }

  void runOnOperation() override {
    int numStages = this->numStages;

    if (numStages <= 1)
      return;

    getOperation()->walk([&](scf::ForOp forOp) -> void {
      LoopPipeliner pipeliner(forOp, numStages);

      if (pipeliner.initialize().failed())
        return;

      pipeliner.emitPrologue();

      scf::ForOp newForOp = pipeliner.createNewForOp();
      pipeliner.emitEpilogue();

      // replace the original loop
      for (unsigned i = 0; i < forOp->getNumResults(); ++i)
        forOp->getResult(i).replaceAllUsesWith(newForOp->getResult(i));
      forOp->erase();
    });
  }
};
} // anonymous namespace

std::unique_ptr<Pass> mlir::createTritonGPUPipelinePass(int numStages) {
  return std::make_unique<PipelinePass>(numStages);
}