[BACKEND] Fix dependency analysis in pipeline (#946)
@@ -123,9 +123,13 @@ void LoopPipeliner::collectDeps(Value v, int stages, DenseSet<Value> &deps) {
     return;
 
   if (auto arg = v.dyn_cast<BlockArgument>()) {
-    deps.insert(v);
-    // Note: we have iv as the first arg, so the op idx is arg.getArgNumber()-1
-    collectDeps(yieldOp->getOperand(arg.getArgNumber() - 1), stages - 1, deps);
+    if (arg.getArgNumber() > 0) {
+      // Skip the first arg (loop induction variable)
+      // Otherwise the op idx is arg.getArgNumber()-1
+      deps.insert(v);
+      collectDeps(yieldOp->getOperand(arg.getArgNumber() - 1), stages - 1,
+                  deps);
+    }
   } else { // value
     // v might be in deps, but we still need to visit v.
     // This is because v might depend on value in previous iterations
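Note on why the new guard matters: in an scf.for body, block argument 0 is the induction variable and the loop-carried values start at argument 1, so yield operand i pairs with block argument i + 1. The old code recursed unconditionally and computed getArgNumber() - 1 on an unsigned index, which wraps around for the induction variable. A minimal standalone sketch of that failure mode (plain C++, not MLIR; all names are illustrative):

    #include <cstdio>

    int main() {
      // Block argument 0 of an scf.for body is the induction variable;
      // loop-carried args start at 1, so yield operand i pairs with arg i + 1.
      unsigned argNumber = 0; // induction variable (assumed layout, as in MLIR)
      // Old code path: unconditional recursion on getArgNumber() - 1.
      unsigned operandIdx = argNumber - 1; // unsigned wrap: 4294967295
      std::printf("unguarded operand index: %u\n", operandIdx);
      // New code path: only loop-carried args reach the recursion.
      if (argNumber > 0)
        std::printf("recurse on yield operand %u\n", argNumber - 1);
      else
        std::printf("skip induction variable\n");
    }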
@@ -376,11 +380,11 @@ scf::ForOp LoopPipeliner::createNewForOp() {
   OpBuilder builder(forOp);
 
   // Order of new args:
-  // (original args),
-  // (insertSliceAsync buffer at stage numStages - 1) for each load
-  // (extracted tensor) for each load
-  // (depArgs at stage numStages-1)
-  // (iv at stage numStages-1)
+  // (original args)
+  // (insertSliceAsync buffer at stage numStages - 1) for each load
+  // (extracted tensor) for each load
+  // (depArgs at stage numStages - 1)
+  // (iv at stage numStages - 2)
   // (pipeline iteration index)
   // (loop iteration index)
   SmallVector<Value> newLoopArgs;
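The comment block above doubles as the index map that the yield reconstruction further down depends on. A hedged sketch of how those offsets could line up, with made-up counts; only the ordering comes from the comment, while depArgsBeginIdx and nextIVIdx are the names the later hunk actually uses:

    #include <cstdio>

    int main() {
      // Illustrative counts; only the ordering is taken from the comment.
      unsigned numOriginalArgs = 3, numLoads = 2, numDepArgs = 4;
      unsigned bufferBeginIdx  = numOriginalArgs;              // async buffers
      unsigned sliceBeginIdx   = bufferBeginIdx + numLoads;    // extracted tensors
      unsigned depArgsBeginIdx = sliceBeginIdx + numLoads;     // dep args
      unsigned nextIVIdx       = depArgsBeginIdx + numDepArgs; // iv (numStages - 2)
      unsigned pipelineIterIdx = nextIVIdx + 1;
      unsigned loopIterIdx     = pipelineIterIdx + 1;
      std::printf("depArgs [%u, %u)  iv %u  pipelineIter %u  loopIter %u\n",
                  depArgsBeginIdx, nextIVIdx, nextIVIdx, pipelineIterIdx,
                  loopIterIdx);
    }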
@@ -421,6 +425,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
   BlockAndValueMapping mapping;
   for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
     mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
+  mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
 
   // 2.1 clone the loop body, replace original args with args of the new ForOp
   // Insert async wait if necessary.
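The extra map() entry matters because cloning with a mapping only rewrites values that have an entry; an unmapped value passes through unchanged and keeps pointing into the old loop. A toy sketch of that lookupOrDefault-style behavior, with std::map standing in for BlockAndValueMapping (illustrative, not the MLIR API):

    #include <cstdio>
    #include <map>
    #include <string>

    int main() {
      std::map<std::string, std::string> mapping;
      mapping["iterArg0"] = "newIterArg0"; // loop-carried values: always mapped
      // Mimics lookupOrDefault: unmapped values pass through unchanged,
      // which is how stale induction-variable uses survive the clone.
      auto remap = [&](const std::string &v) {
        auto it = mapping.find(v);
        return it == mapping.end() ? v : it->second;
      };
      std::printf("before fix: iv -> %s\n", remap("iv").c_str()); // stale "iv"
      mapping["iv"] = "newIv"; // the added mapping.map(...) call
      std::printf("after fix:  iv -> %s\n", remap("iv").c_str()); // "newIv"
    }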
@@ -469,6 +474,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
   Value nextLoopCond =
       builder.create<arith::CmpIOp>(nextIV.getLoc(), arith::CmpIPredicate::slt,
                                     nextIV, newForOp.getUpperBound());
+  nextMapping.map(forOp.getInductionVar(), nextIV);
 
   // Slice index
   SmallVector<Value> nextBuffers;
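nextMapping plays the same role for the prefetch stage: the next iteration's body is materialized against nextIV, and the slt compare keeps the prefetch from running past the trip count. A sketch of that guard with plain integers standing in for the SSA values (lowerBound, upperBound, and step are illustrative):

    #include <cstdio>

    int main() {
      int lowerBound = 0, upperBound = 10, step = 4;
      for (int iv = lowerBound; iv < upperBound; iv += step) {
        int nextIV = iv + step;
        // Mirrors the arith::CmpIOp with CmpIPredicate::slt above: only
        // prefetch for the next iteration if it actually exists.
        bool nextLoopCond = nextIV < upperBound;
        std::printf("iv=%d nextIV=%d prefetch=%s\n", iv, nextIV,
                    nextLoopCond ? "yes" : "no");
      }
    }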
@@ -598,9 +604,11 @@ scf::ForOp LoopPipeliner::createNewForOp() {
   for (Value nextSlice : extractSlices)
     yieldValues.push_back(nextSlice);
 
-  for (size_t i = depArgsBeginIdx; i < nextIVIdx; ++i)
-    yieldValues.push_back(
-        depArgsMapping.lookup(newForOp.getRegionIterArgs()[i]));
+  for (size_t i = depArgsBeginIdx; i < nextIVIdx; ++i) {
+    auto arg = newForOp.getRegionIterArgs()[i];
+    assert(depArgsMapping.count(arg) && "Missing loop-carried value");
+    yieldValues.push_back(depArgsMapping[arg]);
+  }
   yieldValues.push_back(nextIV);
   yieldValues.push_back(pipelineIterIdx);
   yieldValues.push_back(loopIterIdx);
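The rewritten loop replaces a DenseMap-style lookup(), which silently returns a value-initialized result (a null Value) for a missing key, with an assert that fires at the point of the actual bug. A toy reproduction with std::map instead of the real MLIR types:

    #include <cassert>
    #include <cstdio>
    #include <map>

    int main() {
      std::map<int, int> depArgsMapping{{7, 42}};
      // Mimics DenseMap::lookup: a missing key yields a value-initialized
      // result -- for mlir::Value that is null, and the yield op only
      // explodes much later, far from the real cause.
      auto lookup = [&](int key) {
        auto it = depArgsMapping.find(key);
        return it == depArgsMapping.end() ? 0 : it->second;
      };
      std::printf("old path, missing key: %d\n", lookup(8)); // silent 0
      // New path: fail loudly right where the loop-carried value went missing.
      int key = 7;
      assert(depArgsMapping.count(key) && "Missing loop-carried value");
      std::printf("new path: %d\n", depArgsMapping[key]);
    }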
@@ -257,5 +257,5 @@ def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='c
                              grad_to_none=[x], rep=500)
     return gbps(ms), gbps(max_ms), gbps(min_ms)
 
 
 # test_layer_norm(1151, 8192, torch.float16)
 bench_layer_norm.run(save_path='.', print_data=True)