diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp index a99ae7e4b..ceb7bf4d6 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp @@ -163,6 +163,10 @@ LogicalResult LoopPipeliner::initialize() { } void LoopPipeliner::emitPrologue() { + // llvm::errs() << "loads to pipeline...:\n"; + // for (Value load : loads) + // llvm::errs() << load << "\n"; + OpBuilder builder(forOp); for (BlockArgument &arg : forOp.getRegionIterArgs()) { OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg); @@ -302,7 +306,7 @@ scf::ForOp LoopPipeliner::createNewForOp() { for (size_t idx = 0; idx < loads.size(); ++idx) { Value load = loads[idx]; mapping.lookup(load).replaceAllUsesWith( - newForOp.getRegionIterArgs()[loadIdx+idx]); + newForOp.getRegionIterArgs()[loadIdx + idx*(numStages-1)]); } @@ -377,17 +381,12 @@ scf::ForOp LoopPipeliner::createNewForOp() { SmallVector yieldValues; for (Value v : forOp.getBody()->getTerminator()->getOperands()) yieldValues.push_back(mapping.lookup(v)); - // for (int i = 1; i < numStages - 1; ++i) - // yieldValues.push_back(newForOp.getRegionIterArgs()[aArgIdx + i]); - // yieldValues.push_back(nextMapping.lookup(info.dotOp.a())); - // for (int i = 1; i < numStages - 1; ++i) - // yieldValues.push_back(newForOp.getRegionIterArgs()[bArgIdx + i]); - // yieldValues.push_back(nextMapping.lookup(info.dotOp.b())); + // shift pipelined args by 1 for (size_t idx = 0; idx < loads.size(); ++idx) { Value load = loads[idx]; for (int stage = 1; stage < numStages - 1; ++stage) { yieldValues.push_back(newForOp.getRegionIterArgs()[ - loadIdx + idx*(numStages-1) + stage-1 + loadIdx + idx*(numStages-1) + stage ]); } yieldValues.push_back(nextMapping.lookup(load));