[BACKEND] Fix dependency analysis in pipeline (#946)

This commit is contained in:
Keren Zhou
2022-12-06 09:08:55 -08:00
committed by GitHub
parent 9490252261
commit 16e973edf2
2 changed files with 20 additions and 12 deletions

View File

@@ -123,9 +123,13 @@ void LoopPipeliner::collectDeps(Value v, int stages, DenseSet<Value> &deps) {
return;
if (auto arg = v.dyn_cast<BlockArgument>()) {
deps.insert(v);
// Note: we have iv as the first arg, so the op idx is arg.getArgNumber()-1
collectDeps(yieldOp->getOperand(arg.getArgNumber() - 1), stages - 1, deps);
if (arg.getArgNumber() > 0) {
// Skip the first arg (loop induction variable)
// Otherwise the op idx is arg.getArgNumber()-1
deps.insert(v);
collectDeps(yieldOp->getOperand(arg.getArgNumber() - 1), stages - 1,
deps);
}
} else { // value
// v might be in deps, but we still need to visit v.
// This is because v might depend on value in previous iterations
@@ -376,11 +380,11 @@ scf::ForOp LoopPipeliner::createNewForOp() {
OpBuilder builder(forOp);
// Order of new args:
// (original args),
// (insertSliceAsync buffer at stage numStages - 1) for each load
// (extracted tensor) for each load
// (depArgs at stage numStages-1)
// (iv at stage numStages-1)
// (original args)
// (insertSliceAsync buffer at stage numStages - 1) for each load
// (extracted tensor) for each load
// (depArgs at stage numStages - 1)
// (iv at stage numStages - 2)
// (pipeline iteration index)
// (loop iteration index)
SmallVector<Value> newLoopArgs;
@@ -421,6 +425,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
BlockAndValueMapping mapping;
for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
// 2.1 clone the loop body, replace original args with args of the new ForOp
// Insert async wait if necessary.
@@ -469,6 +474,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
Value nextLoopCond =
builder.create<arith::CmpIOp>(nextIV.getLoc(), arith::CmpIPredicate::slt,
nextIV, newForOp.getUpperBound());
nextMapping.map(forOp.getInductionVar(), nextIV);
// Slice index
SmallVector<Value> nextBuffers;
@@ -598,9 +604,11 @@ scf::ForOp LoopPipeliner::createNewForOp() {
for (Value nextSlice : extractSlices)
yieldValues.push_back(nextSlice);
for (size_t i = depArgsBeginIdx; i < nextIVIdx; ++i)
yieldValues.push_back(
depArgsMapping.lookup(newForOp.getRegionIterArgs()[i]));
for (size_t i = depArgsBeginIdx; i < nextIVIdx; ++i) {
auto arg = newForOp.getRegionIterArgs()[i];
assert(depArgsMapping.count(arg) && "Missing loop-carried value");
yieldValues.push_back(depArgsMapping[arg]);
}
yieldValues.push_back(nextIV);
yieldValues.push_back(pipelineIterIdx);
yieldValues.push_back(loopIterIdx);

View File

@@ -257,5 +257,5 @@ def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='c
grad_to_none=[x], rep=500)
return gbps(ms), gbps(max_ms), gbps(min_ms)
# test_layer_norm(1151, 8192, torch.float16)
bench_layer_norm.run(save_path='.', print_data=True)