[BACKEND] Fix dependency analysis in pipeline (#946)

commit 16e973edf2
parent 9490252261
Author: Keren Zhou
Date:   2022-12-06 09:08:55 -08:00
Committed by: GitHub
2 changed files with 20 additions and 12 deletions

View File

@@ -123,9 +123,13 @@ void LoopPipeliner::collectDeps(Value v, int stages, DenseSet<Value> &deps) {
     return;
   if (auto arg = v.dyn_cast<BlockArgument>()) {
-    deps.insert(v);
-    // Note: we have iv as the first arg, so the op idx is arg.getArgNumber()-1
-    collectDeps(yieldOp->getOperand(arg.getArgNumber() - 1), stages - 1, deps);
+    if (arg.getArgNumber() > 0) {
+      // Skip the first arg (loop induction variable)
+      // Otherwise the op idx is arg.getArgNumber()-1
+      deps.insert(v);
+      collectDeps(yieldOp->getOperand(arg.getArgNumber() - 1), stages - 1,
+                  deps);
+    }
   } else { // value
     // v might be in deps, but we still need to visit v.
     // This is because v might depend on value in previous iterations
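
Note: in an scf.for region, block argument 0 is the induction variable, and yield operand i feeds block argument i + 1 on the next iteration; the old code subtracted 1 from every argument number, so reaching the induction variable meant indexing yield operand -1. A minimal sketch of the corrected recursion, using a hypothetical plain-C++ toy encoding (ToyLoop, integer value ids) rather than the pass's real MLIR types:

    #include <cstdio>
    #include <set>
    #include <vector>

    // Toy model of an scf.for region: block arg 0 is the induction
    // variable; block arg i (i > 0) is fed by yield operand i - 1.
    struct ToyLoop {
      int numArgs;                    // induction variable + loop-carried args
      std::vector<int> yieldOperands; // yieldOperands[i] feeds block arg i + 1
    };

    // Collect every block argument that `v` transitively depends on,
    // walking back through at most `stages` iterations. Value ids below
    // numArgs denote block arguments (a toy convention, not MLIR's).
    void collectDeps(const ToyLoop &loop, int v, int stages,
                     std::set<int> &deps) {
      if (stages < 0)
        return;
      if (v < loop.numArgs) { // v is a block argument
        if (v > 0) {          // skip arg 0: the loop induction variable
          deps.insert(v);
          // The yield operand index is the arg number minus one.
          collectDeps(loop, loop.yieldOperands[v - 1], stages - 1, deps);
        }
      }
      // A real implementation would also recurse into the defining op's
      // operands when v is an ordinary SSA value.
    }

    int main() {
      // Two loop-carried args (ids 1 and 2) that feed each other.
      ToyLoop loop{3, {2, 1}};
      std::set<int> deps;
      collectDeps(loop, 1, 2, deps);
      for (int d : deps)
        printf("dep: arg %d\n", d); // prints args 1 and 2, never arg 0
    }

Without the guard, calling this on the induction variable would index yieldOperands[-1], which is the out-of-bounds access the hunk above fixes.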
@@ -376,11 +380,11 @@ scf::ForOp LoopPipeliner::createNewForOp() {
   OpBuilder builder(forOp);
   // Order of new args:
-  // (original args),
+  // (original args)
   // (insertSliceAsync buffer at stage numStages - 1) for each load
   // (extracted tensor) for each load
-  // (depArgs at stage numStages-1)
-  // (iv at stage numStages-1)
+  // (depArgs at stage numStages - 1)
+  // (iv at stage numStages - 2)
   // (pipeline iteration index)
   // (loop iteration index)
   SmallVector<Value> newLoopArgs;
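
The argument groups in this comment translate into contiguous index ranges; depArgsBeginIdx and nextIVIdx in the yield hunk further down are bounds of exactly this kind. A sketch of the arithmetic with made-up sizes (numOrigArgs, numLoads, numDepArgs are illustrative inputs, not values taken from the pass):

    #include <cstdio>

    int main() {
      int numOrigArgs = 4; // (original args)
      int numLoads = 2;    // one buffer + one extracted tensor per load
      int numDepArgs = 3;  // (depArgs at stage numStages - 1)

      int bufferBeginIdx = numOrigArgs;
      int sliceBeginIdx = bufferBeginIdx + numLoads;
      int depArgsBeginIdx = sliceBeginIdx + numLoads;
      int nextIVIdx = depArgsBeginIdx + numDepArgs; // (iv)
      int pipelineIterIdx = nextIVIdx + 1; // (pipeline iteration index)
      int loopIterIdx = pipelineIterIdx + 1; // (loop iteration index)

      printf("buffers: [%d, %d), slices: [%d, %d), depArgs: [%d, %d)\n",
             bufferBeginIdx, sliceBeginIdx, sliceBeginIdx, depArgsBeginIdx,
             depArgsBeginIdx, nextIVIdx);
      printf("iv: %d, pipelineIter: %d, loopIter: %d\n", nextIVIdx,
             pipelineIterIdx, loopIterIdx);
    }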
@@ -421,6 +425,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
   BlockAndValueMapping mapping;
   for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
     mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
+  mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
 
   // 2.1 clone the loop body, replace original args with args of the new ForOp
   // Insert async wait if necessary.
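
The added map entry matters because cloning through a BlockAndValueMapping resolves each operand with lookup-or-default semantics: any value without a mapping entry silently stays as the original. A toy stand-in (std::unordered_map in place of BlockAndValueMapping, string literals in place of mlir::Value) showing why the induction variable must be mapped explicitly:

    #include <cstdio>
    #include <unordered_map>

    using Value = const char *;

    // Mirrors the fall-back-to-original behavior used when cloning.
    Value lookupOrDefault(const std::unordered_map<Value, Value> &mapping,
                          Value v) {
      auto it = mapping.find(v);
      return it == mapping.end() ? v : it->second;
    }

    int main() {
      Value oldIv = "old_iv", newIv = "new_iv";
      Value oldArg = "old_arg0", newArg = "new_arg0";

      std::unordered_map<Value, Value> mapping;
      mapping.insert({oldArg, newArg});
      // Without this line, a cloned op that uses the induction variable
      // would silently keep referring to the old loop's iv:
      mapping.insert({oldIv, newIv});

      printf("arg operand -> %s\n", lookupOrDefault(mapping, oldArg));
      printf("iv operand  -> %s\n", lookupOrDefault(mapping, oldIv));
    }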
@@ -469,6 +474,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
   Value nextLoopCond =
       builder.create<arith::CmpIOp>(nextIV.getLoc(), arith::CmpIPredicate::slt,
                                     nextIV, newForOp.getUpperBound());
+  nextMapping.map(forOp.getInductionVar(), nextIV);
 
   // Slice index
   SmallVector<Value> nextBuffers;
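
For the prefetched stage, nextIV is iv plus the step, and the guard is a signed slt compare against the upper bound; the added nextMapping entry then redirects induction-variable uses inside the prefetched ops to nextIV. A small sketch of that guard arithmetic (plain C++ with made-up bounds, standing in for the arith ops built above):

    #include <cstdio>

    int main() {
      int lowerBound = 0, upperBound = 10, step = 3;
      for (int iv = lowerBound; iv < upperBound; iv += step) {
        int nextIV = iv + step;
        bool nextLoopCond = nextIV < upperBound; // arith::CmpIPredicate::slt
        printf("iv=%d nextIV=%d prefetch=%s\n", iv, nextIV,
               nextLoopCond ? "yes" : "no");
      }
    }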
@@ -598,9 +604,11 @@ scf::ForOp LoopPipeliner::createNewForOp() {
   for (Value nextSlice : extractSlices)
     yieldValues.push_back(nextSlice);
-  for (size_t i = depArgsBeginIdx; i < nextIVIdx; ++i)
-    yieldValues.push_back(
-        depArgsMapping.lookup(newForOp.getRegionIterArgs()[i]));
+  for (size_t i = depArgsBeginIdx; i < nextIVIdx; ++i) {
+    auto arg = newForOp.getRegionIterArgs()[i];
+    assert(depArgsMapping.count(arg) && "Missing loop-carried value");
+    yieldValues.push_back(depArgsMapping[arg]);
+  }
   yieldValues.push_back(nextIV);
   yieldValues.push_back(pipelineIterIdx);
   yieldValues.push_back(loopIterIdx);
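
The rewritten loop also trades DenseMap::lookup, which returns a default-constructed (here: null) Value on a miss, for an assert plus operator[], so a missing loop-carried mapping now fails at its source instead of producing a null yield operand. A toy illustration, with std::map and ints standing in for llvm::DenseMap and mlir::Value:

    #include <cassert>
    #include <cstdio>
    #include <map>

    // A lookup that silently returns a default on a miss, like
    // DenseMap::lookup returning a null Value.
    int lookup(const std::map<int, int> &m, int k) {
      auto it = m.find(k);
      return it == m.end() ? 0 : it->second; // 0 plays the null Value
    }

    int main() {
      std::map<int, int> depArgsMapping = {{1, 11}, {2, 22}};

      // Old pattern: a missing key yields a "null" value that only
      // fails later, when the new scf.yield is verified.
      printf("lookup(3) = %d\n", lookup(depArgsMapping, 3));

      // New pattern: the miss is caught at the point of the bug.
      int k = 2;
      assert(depArgsMapping.count(k) && "Missing loop-carried value");
      printf("mapping[%d] = %d\n", k, depArgsMapping[k]);
    }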

View File

@@ -257,5 +257,5 @@ def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='c
                                                  grad_to_none=[x], rep=500)
     return gbps(ms), gbps(max_ms), gbps(min_ms)
 
-
+# test_layer_norm(1151, 8192, torch.float16)
 bench_layer_norm.run(save_path='.', print_data=True)