the pipeline pass now generates and accepts valid IR

This commit is contained in:
Yan Da
2022-06-07 19:34:59 +08:00
parent 560e29229b
commit 7b09b5f9e9
4 changed files with 82 additions and 37 deletions

View File

@@ -32,6 +32,8 @@ class LoopPipeliner {
/// loads to be pipelined
SetVector<Value> loads;
/// the value that each load will be mapped to (after layout conversion)
DenseMap<Value, Value> loadsMapping;
/// value (in loop) => value at stage N
DenseMap<Value, SmallVector<Value>> valueMapping;
@@ -139,6 +141,23 @@ LogicalResult LoopPipeliner::initialize() {
break;
}
}
// For now, we only pipeline loads that have one covert_layout (to smem) use
// TODO: lift this constraint in the future
if (isCandiate && loadOp.getResult().hasOneUse()) {
isCandiate = false;
Operation *use = *loadOp.getResult().getUsers().begin();
if (auto convertLayout = llvm::dyn_cast<triton::gpu::ConvertLayoutOp>(use)) {
if (auto tensorType = convertLayout.getResult().getType().dyn_cast<RankedTensorType>()) {
if (tensorType.getEncoding().isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
isCandiate = true;
loadsMapping[loadOp] = convertLayout;
}
}
}
} else
isCandiate = false;
if (isCandiate)
loads.insert(loadOp);
}
@@ -202,7 +221,7 @@ void LoopPipeliner::emitPrologue() {
// TODO: check if the hardware supports copyasync
if (auto loadOp = llvm::dyn_cast<triton::LoadOp>(op)) {
newOp = builder.create<triton::gpu::CopyAsyncOp>(
op->getLoc(), op->getResult(0).getType(),
op->getLoc(), loadsMapping[loadOp].getType(),
loadOp.ptr(), loadOp.mask(), loadOp.other(),
loadOp.cache(), loadOp.evict(), loadOp.isVolatile()
);
@@ -237,7 +256,11 @@ void LoopPipeliner::emitPrologue() {
// update mapping of results
for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
setValueMapping(op->getResult(dstIdx), newOp->getResult(dstIdx), stage);
Value originalResult = op->getResult(dstIdx);
// copy_async will update the value of its only use
if (loads.contains(originalResult))
originalResult = loadsMapping[originalResult];
setValueMapping(originalResult, newOp->getResult(dstIdx), stage);
// update mapping for loop-carried values (args)
for (OpOperand &operand : yieldOp->getOpOperands()) {
if (operand.get() == op->getResult(dstIdx))
@@ -254,7 +277,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
// order of new args:
// (original args),
// for each load result x:
// for each load result (after layout conversion) x:
// (x at stage[0, numStages-1))
// (depArgs at stage numStages-1)
// (iv at stage numStages-1)
@@ -268,7 +291,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
size_t loadIdx = newLoopArgs.size();
for (Value loadOp : loads)
for (int i = 0; i < numStages - 1; ++i)
newLoopArgs.push_back(valueMapping[loadOp][i]);
newLoopArgs.push_back(valueMapping[loadsMapping[loadOp]][i]);
size_t depArgsBeginIdx = newLoopArgs.size();
for (BlockArgument depArg : depArgs) {
@@ -295,6 +318,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
// 2.1 clone the loop body, replace original args with args of the new ForOp
for (Operation &op : forOp.getBody()->without_terminator()) {
Operation *newOp = builder.clone(op, mapping);
// update mapping of results
@@ -305,7 +329,9 @@ scf::ForOp LoopPipeliner::createNewForOp() {
// 3. replace loads with block args (from prologue)
for (size_t idx = 0; idx < loads.size(); ++idx) {
Value load = loads[idx];
mapping.lookup(load).replaceAllUsesWith(
assert(load.hasOneUse() && "we assume that this load has one use (ConvertLayout)");
Value loadUse = load.getUsers().begin()->getResult(0);
mapping.lookup(loadUse).replaceAllUsesWith(
newForOp.getRegionIterArgs()[loadIdx + idx*(numStages-1)]);
}
@@ -351,7 +377,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
nextMapping.map(mask, newMask);
// TODO: more elegant way to do this?
nextOp = builder.create<triton::gpu::CopyAsyncOp>(
op->getLoc(), op->getResult(0).getType(),
op->getLoc(), loadsMapping[op->getResult(0)].getType(),
nextMapping.lookupOrDefault(loadOp.ptr()),
nextMapping.lookupOrDefault(loadOp.mask()),
nextMapping.lookupOrDefault(loadOp.other()),