[OPTIMIZER] Pipeline async buffer (#110)

Philippe Tillet
2022-09-09 11:01:14 -07:00
committed by GitHub
parent 2a852044d9
commit 9bd5a3dcd2
2 changed files with 90 additions and 57 deletions


@@ -34,6 +34,14 @@ class LoopPipeliner {
DenseMap<Value, Value> loadsMapping;
/// load => buffer
DenseMap<Value, Value> loadsBuffer;
/// load => buffer at stage N
DenseMap<Value, SmallVector<Value>> loadStageBuffer;
/// load => result of extract_slice
DenseMap<Value, Value> loadsExtract;
/// Number of async copies issued so far; modulo numStages, it selects the
/// buffer slot written by insert_slice_async
Value pipelineIterIdx;
/// Number of loop iterations completed so far; modulo numStages, it selects
/// the buffer slot read by extract_slice
Value loopIterIdx;
/// value (in loop) => value at stage N
DenseMap<Value, SmallVector<Value>> valueMapping;
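
The new fields implement an N-slot ring buffer per load. A minimal standalone C++ model of their shape, with a hypothetical Value = int standing in for mlir::Value (illustrative only, no MLIR dependencies):

#include <map>
#include <vector>

using Value = int; // hypothetical stand-in for mlir::Value

struct PipelinerState {
  std::map<Value, Value> loadsBuffer;                  // load => shared buffer
  std::map<Value, std::vector<Value>> loadStageBuffer; // load => buffer version per stage
  std::map<Value, Value> loadsExtract;                 // load => extract_slice result
  Value pipelineIterIdx; // write cursor: async copies issued so far
  Value loopIterIdx;     // read cursor: loop iterations completed so far
};
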
@@ -237,6 +245,7 @@ void LoopPipeliner::emitPrologue() {
// prologue from [0, numStages-1)
Value iv = forOp.getLowerBound();
pipelineIterIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
for (int stage = 0; stage < numStages - 1; ++stage) {
// special handling for induction variable as the increment is implicit
if (stage != 0)
@@ -261,22 +270,21 @@ void LoopPipeliner::emitPrologue() {
Operation *newOp = nullptr;
if (loads.contains(op->getResult(0))) {
// Allocate empty buffer
if (stage == 0)
if (stage == 0) {
loadsBuffer[op->getResult(0)] = allocateEmptyBuffer(op, builder);
loadStageBuffer[op->getResult(0)] = {loadsBuffer[op->getResult(0)]};
}
// load => copy async
// TODO: check whether the hardware supports copy_async
if (auto loadOp = llvm::dyn_cast<triton::LoadOp>(op)) {
Value sliceIndex = builder.create<arith::IndexCastOp>(
iv.getLoc(), builder.getI32Type(), iv);
Value insertAsyncOp = builder.create<triton::gpu::InsertSliceAsyncOp>(
newOp = builder.create<triton::gpu::InsertSliceAsyncOp>(
op->getLoc(), loadsBuffer[loadOp].getType(),
lookupOrDefault(loadOp.ptr(), stage), loadsBuffer[loadOp],
sliceIndex, lookupOrDefault(loadOp.mask(), stage),
lookupOrDefault(loadOp.ptr(), stage),
loadStageBuffer[loadOp][stage], pipelineIterIdx,
lookupOrDefault(loadOp.mask(), stage),
lookupOrDefault(loadOp.other(), stage), loadOp.cache(),
loadOp.evict(), loadOp.isVolatile(), /*axis*/ 0);
newOp = builder.create<triton::gpu::ExtractSliceOp>(
op->getLoc(), loadsMapping[loadOp].getType(), insertAsyncOp,
sliceIndex, /*axis*/ 0);
loadStageBuffer[loadOp].push_back(newOp->getResult(0));
} else
llvm_unreachable("This should be LoadOp");
} else {
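
The prologue no longer materializes one tensor per stage up front: each stage issues an insert_slice_async into the shared buffer and threads the updated buffer value forward via loadStageBuffer. A standalone sketch of that flow, using plain ints as stand-ins for SSA buffer versions (illustrative only):

#include <cstdio>
#include <vector>

int main() {
  const int numStages = 3;
  std::vector<int> loadStageBuffer = {0}; // version 0: the empty allocation
  int pipelineIterIdx = 0;
  for (int stage = 0; stage < numStages - 1; ++stage) {
    // insert_slice_async consumes buffer version `stage`, writes slot
    // pipelineIterIdx, and yields buffer version `stage + 1`.
    std::printf("stage %d: async copy into slot %d\n", stage, pipelineIterIdx);
    loadStageBuffer.push_back(loadStageBuffer.back() + 1);
    ++pipelineIterIdx;
  }
  std::printf("buffer versions recorded: %zu\n", loadStageBuffer.size());
}
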
@@ -294,9 +302,11 @@ void LoopPipeliner::emitPrologue() {
// If this is a load/async_copy, we need to update the mask
if (llvm::isa<triton::LoadOp, triton::gpu::InsertSliceAsyncOp>(newOp)) {
Value mask = newOp->getOperand(1);
Value mask = llvm::isa<triton::LoadOp>(newOp) ? newOp->getOperand(1)
: newOp->getOperand(3);
// assert(I1 or TensorOf<[I1]>);
OpBuilder::InsertionGuard g(builder);
// TODO: move this out of the loop
builder.setInsertionPoint(newOp);
Value splatCond = builder.create<triton::SplatOp>(
mask.getLoc(), mask.getType(), loopCond);
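
The mask sits at operand 1 of triton::LoadOp but at operand 3 of InsertSliceAsyncOp (after the pointer, the buffer, and the slice index), hence the branch above. The predication itself is unchanged: the original mask is AND-ed with a splat of the scalar loop condition, so prologue copies for stages past the trip count are disabled. A standalone model with illustrative values:

#include <array>
#include <cstdio>

int main() {
  std::array<bool, 4> mask = {true, true, false, true}; // original per-element mask
  bool loopCond = false; // this prologue stage is past the trip count
  for (auto &m : mask)
    m = m && loopCond; // splatCond AND mask: everything predicated off
  for (bool m : mask)
    std::printf("%d ", m);
  std::printf("\n");
}
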
@@ -313,8 +323,11 @@ void LoopPipeliner::emitPrologue() {
for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
Value originalResult = op->getResult(dstIdx);
// copy_async will update the value of its only use
if (loads.contains(originalResult))
originalResult = loadsMapping[originalResult];
// TODO: load should not be used in the preheader?
if (loads.contains(originalResult)) {
break;
// originalResult = loadsMapping[originalResult];
}
setValueMapping(originalResult, newOp->getResult(dstIdx), stage);
// update mapping for loop-carried values (args)
for (OpOperand &operand : yieldOp->getOpOperands()) {
@@ -325,18 +338,25 @@ void LoopPipeliner::emitPrologue() {
}
}
}
}
pipelineIterIdx = builder.create<arith::AddIOp>(
iv.getLoc(), pipelineIterIdx,
builder.create<arith::ConstantIntOp>(iv.getLoc(), 1, 32));
} // for (int stage = 0; stage < numStages - 1; ++stage)
// async.wait & extract_slice
Operation *asyncWait = builder.create<triton::gpu::AsyncWaitOp>(
loads[0].getLoc(), loads.size() * (numStages - 2));
for (int i = numStages - 2; i >= 0; --i) {
for (auto it = loads.rbegin(); it != loads.rend(); ++it) {
// move extract_slice after asyncWait
Value load = *it;
valueMapping[loadsMapping[load]][i].getDefiningOp()->moveAfter(asyncWait);
}
for (Value loadOp : loads) {
Value extractSlice = builder.create<triton::gpu::ExtractSliceOp>(
loadOp.getLoc(), loadsMapping[loadOp].getType(),
loadStageBuffer[loadOp][numStages - 1],
builder.create<arith::ConstantIntOp>(loadOp.getLoc(), 0, 32),
/*axis*/ 0);
loadsExtract[loadOp] = extractSlice;
}
loopIterIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
}
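
The prologue issues loads.size() * (numStages - 1) async copies in total, so waiting until at most loads.size() * (numStages - 2) remain in flight guarantees that the first batch (slot 0) has landed before the extract_slice ops read it. A sketch of that arithmetic, assuming the async_wait operand is the number of copies still allowed in flight (cp.async.wait_group style):

#include <cstdio>

int main() {
  const int numLoads = 2, numStages = 3;
  int issued = numLoads * (numStages - 1);         // copies issued by prologue
  int allowedPending = numLoads * (numStages - 2); // async_wait operand
  std::printf("issued=%d, pending<=%d => first %d copies complete\n",
              issued, allowedPending, issued - allowedPending);
}
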
scf::ForOp LoopPipeliner::createNewForOp() {
@@ -344,10 +364,12 @@ scf::ForOp LoopPipeliner::createNewForOp() {
// order of new args:
// (original args),
// for each load result (after layout conversion) x:
// (x at stage[0, numStages-1))
// (insertSliceAsync buffer at stage numStages - 1) for each load
// (extracted tensor) for each load
// (depArgs at stage numStages-1)
// (iv at stage numStages-1)
// (pipeline iteration index)
// (loop iteration index)
SmallVector<Value> newLoopArgs;
// We need this to update operands for yield
// original block arg => new arg's idx
@@ -355,10 +377,12 @@ scf::ForOp LoopPipeliner::createNewForOp() {
for (auto v : forOp.getIterOperands())
newLoopArgs.push_back(v);
size_t bufferIdx = newLoopArgs.size();
for (Value loadOp : loads)
newLoopArgs.push_back(loadStageBuffer[loadOp].back());
size_t loadIdx = newLoopArgs.size();
for (Value loadOp : loads)
for (int i = 0; i < numStages - 1; ++i)
newLoopArgs.push_back(valueMapping[loadsMapping[loadOp]][i]);
newLoopArgs.push_back(loadsExtract[loadOp]);
size_t depArgsBeginIdx = newLoopArgs.size();
for (BlockArgument depArg : depArgs) {
@@ -368,6 +392,8 @@ scf::ForOp LoopPipeliner::createNewForOp() {
size_t nextIVIdx = newLoopArgs.size();
newLoopArgs.push_back(valueMapping[forOp.getInductionVar()][numStages - 2]);
newLoopArgs.push_back(pipelineIterIdx);
newLoopArgs.push_back(loopIterIdx);
for (size_t i = 0; i < newLoopArgs.size(); ++i)
assert(newLoopArgs[i]);
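
The offsets into the new iter-arg list are easy to get wrong, so here is a standalone sketch that recomputes them for illustrative sizes (all counts hypothetical):

#include <cstdio>

int main() {
  const int numOrigArgs = 2, numLoads = 2, numDeps = 3;
  int bufferIdx = numOrigArgs;               // one shared buffer per load
  int loadIdx = bufferIdx + numLoads;        // one extracted tensor per load
  int depArgsBeginIdx = loadIdx + numLoads;  // loop-carried dependencies
  int nextIVIdx = depArgsBeginIdx + numDeps; // induction variable
  // nextIVIdx + 1 holds pipelineIterIdx, nextIVIdx + 2 holds loopIterIdx
  std::printf("buffers@%d extracts@%d deps@%d iv@%d counters@%d,%d\n",
              bufferIdx, loadIdx, depArgsBeginIdx, nextIVIdx, nextIVIdx + 1,
              nextIVIdx + 2);
}
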
@@ -399,7 +425,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
"we assume that this load has one use (ConvertLayout)");
Value loadUse = load.getUsers().begin()->getResult(0);
mapping.lookup(loadUse).replaceAllUsesWith(
newForOp.getRegionIterArgs()[loadIdx + idx * (numStages - 1)]);
newForOp.getRegionIterArgs()[loadIdx + idx]);
// delete old load and layout conversion
mapping.lookup(loadUse).getDefiningOp()->erase();
mapping.lookup(load).getDefiningOp()->erase();
@@ -432,12 +458,18 @@ scf::ForOp LoopPipeliner::createNewForOp() {
nextIV, newForOp.getUpperBound());
// slice index
SmallVector<Value> nextBuffers;
SmallVector<Value> extractSlices;
Value sliceIndex = builder.create<arith::IndexCastOp>(
nextIV.getLoc(), builder.getI32Type(), nextIV);
sliceIndex = builder.create<arith::RemSIOp>(
nextIV.getLoc(), sliceIndex,
pipelineIterIdx = newForOp.getRegionIterArgs()[nextIVIdx + 1];
Value insertSliceIndex = builder.create<arith::RemSIOp>(
nextIV.getLoc(), pipelineIterIdx,
builder.create<arith::ConstantIntOp>(nextIV.getLoc(), numStages, 32));
loopIterIdx = newForOp.getRegionIterArgs()[nextIVIdx + 2];
Value extractSliceIndex = builder.create<arith::RemSIOp>(
nextIV.getLoc(), loopIterIdx,
builder.create<arith::ConstantIntOp>(nextIV.getLoc(), numStages, 32));
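
These two RemSIOps are the core of the ring buffer: pipelineIterIdx (the write cursor) runs numStages - 1 iterations ahead of loopIterIdx (the read cursor), so taken modulo numStages the insert slot and the extract slot never collide for numStages >= 2. A runnable standalone check of that invariant:

#include <cassert>
#include <cstdio>

int main() {
  const int numStages = 3;
  int pipelineIterIdx = numStages - 1; // prologue already issued these copies
  int loopIterIdx = 0;
  for (int iter = 0; iter < 8; ++iter) {
    int insertSlot = pipelineIterIdx % numStages;
    int extractSlot = loopIterIdx % numStages;
    assert(insertSlot != extractSlot && "write slot never equals read slot");
    std::printf("iter %d: write slot %d, read slot %d\n", iter, insertSlot,
                extractSlot);
    ++pipelineIterIdx;
    ++loopIterIdx;
  }
}
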
for (Operation *op : orderedDeps) {
Operation *nextOp = nullptr;
// TODO(da): does this work if loadOp has no mask?
@@ -457,13 +489,15 @@ scf::ForOp LoopPipeliner::createNewForOp() {
}
Value insertAsyncOp = builder.create<triton::gpu::InsertSliceAsyncOp>(
op->getLoc(), loadsBuffer[loadOp].getType(),
nextMapping.lookupOrDefault(loadOp.ptr()), loadsBuffer[loadOp],
sliceIndex, nextMapping.lookupOrDefault(loadOp.mask()),
nextMapping.lookupOrDefault(loadOp.ptr()),
newForOp.getRegionIterArgs()[bufferIdx + nextBuffers.size()],
insertSliceIndex, nextMapping.lookupOrDefault(loadOp.mask()),
nextMapping.lookupOrDefault(loadOp.other()), loadOp.cache(),
loadOp.evict(), loadOp.isVolatile(), /*axis*/ 0);
nextBuffers.push_back(insertAsyncOp);
nextOp = builder.create<triton::gpu::ExtractSliceOp>(
op->getLoc(), loadsMapping[loadOp].getType(), insertAsyncOp,
sliceIndex, /*axis*/ 0);
extractSliceIndex, /*axis*/ 0);
extractSlices.push_back(nextOp->getResult(0));
} else
nextOp = builder.clone(*op, nextMapping);
@@ -491,25 +525,29 @@ scf::ForOp LoopPipeliner::createNewForOp() {
it->getDefiningOp()->moveAfter(asyncWait);
}
// bump iteration count
pipelineIterIdx = builder.create<arith::AddIOp>(
nextIV.getLoc(), pipelineIterIdx,
builder.create<arith::ConstantIntOp>(nextIV.getLoc(), 1, 32));
loopIterIdx = builder.create<arith::AddIOp>(
nextIV.getLoc(), loopIterIdx,
builder.create<arith::ConstantIntOp>(nextIV.getLoc(), 1, 32));
// Finally, create the YieldOp; its operand order must match newLoopArgs
SmallVector<Value> yieldValues;
for (Value v : forOp.getBody()->getTerminator()->getOperands())
yieldValues.push_back(mapping.lookup(v));
// shift pipelined args by 1
for (size_t idx = 0; idx < loads.size(); ++idx) {
Value load = loads[idx];
for (int stage = 1; stage < numStages - 1; ++stage) {
yieldValues.push_back(
newForOp
.getRegionIterArgs()[loadIdx + idx * (numStages - 1) + stage]);
}
yieldValues.push_back(nextMapping.lookup(load));
}
for (Value nextBuffer : nextBuffers)
yieldValues.push_back(nextBuffer);
for (Value nextSlice : extractSlices)
yieldValues.push_back(nextSlice);
for (size_t i = depArgsBeginIdx; i < nextIVIdx; ++i)
yieldValues.push_back(
depArgsMapping.lookup(newForOp.getRegionIterArgs()[i]));
yieldValues.push_back(nextIV);
yieldValues.push_back(pipelineIterIdx);
yieldValues.push_back(loopIterIdx);
builder.setInsertionPointToEnd(newForOp.getBody());
auto test = builder.create<scf::YieldOp>(
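
To summarize the contract enforced here: the yield operands must line up one-to-one with the iter args assembled at the top of createNewForOp. A compact standalone sketch of the resulting order (group names hypothetical):

#include <cassert>
#include <vector>

int main() {
  std::vector<const char *> yieldOrder = {
      "original yields...", "nextBuffers (one per load)...",
      "extractSlices (one per load)...", "depArgs...",
      "nextIV", "pipelineIterIdx", "loopIterIdx"};
  assert(yieldOrder.size() == 7); // seven groups, same order as newLoopArgs
}
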