From 16e973edf2a92f792ff84c6fdf5a50ceeab3bf9f Mon Sep 17 00:00:00 2001
From: Keren Zhou
Date: Tue, 6 Dec 2022 09:08:55 -0800
Subject: [PATCH] [BACKEND] Fix dependency analysis in pipeline (#946)

---
 lib/Dialect/TritonGPU/Transforms/Pipeline.cpp | 30 ++++++++++++-------
 python/tutorials/05-layer-norm.py             |  2 +-
 2 files changed, 20 insertions(+), 12 deletions(-)

diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
index b2292d901..dd5aa9d2a 100644
--- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
@@ -123,9 +123,13 @@ void LoopPipeliner::collectDeps(Value v, int stages, DenseSet<Value> &deps) {
     return;
 
   if (auto arg = v.dyn_cast<BlockArgument>()) {
-    deps.insert(v);
-    // Note: we have iv as the first arg, so the op idx is arg.getArgNumber()-1
-    collectDeps(yieldOp->getOperand(arg.getArgNumber() - 1), stages - 1, deps);
+    if (arg.getArgNumber() > 0) {
+      // Skip the first arg (loop induction variable)
+      // Otherwise the op idx is arg.getArgNumber()-1
+      deps.insert(v);
+      collectDeps(yieldOp->getOperand(arg.getArgNumber() - 1), stages - 1,
+                  deps);
+    }
   } else { // value
     // v might be in deps, but we still need to visit v.
     // This is because v might depend on value in previous iterations
@@ -376,11 +380,11 @@ scf::ForOp LoopPipeliner::createNewForOp() {
   OpBuilder builder(forOp);
 
   // Order of new args:
-  // (original args),
-  // (insertSliceAsync buffer at stage numStages - 1) for each load
-  // (extracted tensor) for each load
-  // (depArgs at stage numStages-1)
-  // (iv at stage numStages-1)
+  // (original args)
+  // (insertSliceAsync buffer at stage numStages - 1) for each load
+  // (extracted tensor) for each load
+  // (depArgs at stage numStages - 1)
+  // (iv at stage numStages - 2)
   // (pipeline iteration index)
   // (loop iteration index)
   SmallVector<Value> newLoopArgs;
@@ -421,6 +425,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
   BlockAndValueMapping mapping;
   for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
     mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
+  mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
 
   // 2.1 clone the loop body, replace original args with args of the new ForOp
   // Insert async wait if necessary.
@@ -469,6 +474,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
     Value nextLoopCond = builder.create<arith::CmpIOp>(
         nextIV.getLoc(), arith::CmpIPredicate::slt, nextIV,
         newForOp.getUpperBound());
+    nextMapping.map(forOp.getInductionVar(), nextIV);
 
     // Slice index
     SmallVector<Value> nextBuffers;
@@ -598,9 +604,11 @@ scf::ForOp LoopPipeliner::createNewForOp() {
   for (Value nextSlice : extractSlices)
     yieldValues.push_back(nextSlice);
 
-  for (size_t i = depArgsBeginIdx; i < nextIVIdx; ++i)
-    yieldValues.push_back(
-        depArgsMapping.lookup(newForOp.getRegionIterArgs()[i]));
+  for (size_t i = depArgsBeginIdx; i < nextIVIdx; ++i) {
+    auto arg = newForOp.getRegionIterArgs()[i];
+    assert(depArgsMapping.count(arg) && "Missing loop-carried value");
+    yieldValues.push_back(depArgsMapping[arg]);
+  }
   yieldValues.push_back(nextIV);
   yieldValues.push_back(pipelineIterIdx);
   yieldValues.push_back(loopIterIdx);
diff --git a/python/tutorials/05-layer-norm.py b/python/tutorials/05-layer-norm.py
index 1cefc60b9..110351af5 100644
--- a/python/tutorials/05-layer-norm.py
+++ b/python/tutorials/05-layer-norm.py
@@ -257,5 +257,5 @@ def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='c
                                                  grad_to_none=[x], rep=500)
     return gbps(ms), gbps(max_ms), gbps(min_ms)
 
-
+# test_layer_norm(1151, 8192, torch.float16)
 bench_layer_norm.run(save_path='.', print_data=True)
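
Note (not part of the patch): the crux of the first hunk is that the first
block argument of an scf.for body is the loop induction variable, which has
no corresponding scf.yield operand, so yieldOp->getOperand(arg.getArgNumber()
- 1) mis-indexed whenever the dependency walk reached it. Below is a minimal,
self-contained C++ sketch of the corrected walk over a hypothetical ToyLoop
stand-in; all names are illustrative, not the real MLIR API, and the
loop-body-value case is simplified to a single insert.

// Toy model of the fixed LoopPipeliner::collectDeps traversal.
#include <cassert>
#include <set>
#include <string>
#include <vector>

struct ToyLoop {
  // blockArgs[0] models the induction variable; blockArgs[i] (i > 0) is the
  // i-th loop-carried value, fed by yieldOperands[i - 1] on each iteration.
  std::vector<std::string> blockArgs;
  std::vector<std::string> yieldOperands;
};

// Collect the values that `v` transitively depends on across iterations,
// looking back at most `stages` iterations.
void collectDeps(const ToyLoop &loop, const std::string &v, int stages,
                 std::set<std::string> &deps) {
  if (stages < 0)
    return;
  for (size_t i = 0; i < loop.blockArgs.size(); ++i) {
    if (loop.blockArgs[i] != v)
      continue;
    if (i == 0)
      return; // induction variable: no yield operand feeds it, stop here
    deps.insert(v);
    // Yield operand i-1 produces block arg i on the next iteration.
    collectDeps(loop, loop.yieldOperands[i - 1], stages - 1, deps);
    return;
  }
  deps.insert(v); // a value defined in the loop body (simplified)
}

int main() {
  ToyLoop loop{{"iv", "acc"}, {"acc_next"}};
  std::set<std::string> deps;
  collectDeps(loop, "acc", 2, deps);
  assert(deps.count("acc") && deps.count("acc_next"));
  collectDeps(loop, "iv", 2, deps); // previously mis-indexed; now a no-op
  assert(!deps.count("iv"));
  return 0;
}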