the pipeline pass now generates and accepts valid IR
This commit is contained in:
@@ -32,6 +32,8 @@ class LoopPipeliner {
|
||||
|
||||
/// loads to be pipelined
|
||||
SetVector<Value> loads;
|
||||
/// the value that each load will be mapped to (after layout conversion)
|
||||
DenseMap<Value, Value> loadsMapping;
|
||||
|
||||
/// value (in loop) => value at stage N
|
||||
DenseMap<Value, SmallVector<Value>> valueMapping;
|
||||
@@ -139,6 +141,23 @@ LogicalResult LoopPipeliner::initialize() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// For now, we only pipeline loads that have one convert_layout (to smem) use
|
||||
// TODO: lift this constraint in the future
|
||||
if (isCandiate && loadOp.getResult().hasOneUse()) {
|
||||
isCandiate = false;
|
||||
Operation *use = *loadOp.getResult().getUsers().begin();
|
||||
if (auto convertLayout = llvm::dyn_cast<triton::gpu::ConvertLayoutOp>(use)) {
|
||||
if (auto tensorType = convertLayout.getResult().getType().dyn_cast<RankedTensorType>()) {
|
||||
if (tensorType.getEncoding().isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
|
||||
isCandiate = true;
|
||||
loadsMapping[loadOp] = convertLayout;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else
|
||||
isCandiate = false;
|
||||
|
||||
if (isCandiate)
|
||||
loads.insert(loadOp);
|
||||
}
|
||||
@@ -202,7 +221,7 @@ void LoopPipeliner::emitPrologue() {
|
||||
// TODO: check if the hardware supports copy_async
|
||||
if (auto loadOp = llvm::dyn_cast<triton::LoadOp>(op)) {
|
||||
newOp = builder.create<triton::gpu::CopyAsyncOp>(
|
||||
op->getLoc(), op->getResult(0).getType(),
|
||||
op->getLoc(), loadsMapping[loadOp].getType(),
|
||||
loadOp.ptr(), loadOp.mask(), loadOp.other(),
|
||||
loadOp.cache(), loadOp.evict(), loadOp.isVolatile()
|
||||
);
|
||||
@@ -237,7 +256,11 @@ void LoopPipeliner::emitPrologue() {
|
||||
|
||||
// update mapping of results
|
||||
for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
|
||||
setValueMapping(op->getResult(dstIdx), newOp->getResult(dstIdx), stage);
|
||||
Value originalResult = op->getResult(dstIdx);
|
||||
// copy_async will update the value of its only use
|
||||
if (loads.contains(originalResult))
|
||||
originalResult = loadsMapping[originalResult];
|
||||
setValueMapping(originalResult, newOp->getResult(dstIdx), stage);
|
||||
// update mapping for loop-carried values (args)
|
||||
for (OpOperand &operand : yieldOp->getOpOperands()) {
|
||||
if (operand.get() == op->getResult(dstIdx))
|
||||
@@ -254,7 +277,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
|
||||
// order of new args:
|
||||
// (original args),
|
||||
// for each load result x:
|
||||
// for each load result (after layout conversion) x:
|
||||
// (x at stage[0, numStages-1))
|
||||
// (depArgs at stage numStages-1)
|
||||
// (iv at stage numStages-1)
|
||||
@@ -268,7 +291,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
size_t loadIdx = newLoopArgs.size();
|
||||
for (Value loadOp : loads)
|
||||
for (int i = 0; i < numStages - 1; ++i)
|
||||
newLoopArgs.push_back(valueMapping[loadOp][i]);
|
||||
newLoopArgs.push_back(valueMapping[loadsMapping[loadOp]][i]);
|
||||
|
||||
size_t depArgsBeginIdx = newLoopArgs.size();
|
||||
for (BlockArgument depArg : depArgs) {
|
||||
@@ -295,6 +318,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
|
||||
mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
|
||||
|
||||
// 2.1 clone the loop body, replace original args with args of the new ForOp
|
||||
for (Operation &op : forOp.getBody()->without_terminator()) {
|
||||
Operation *newOp = builder.clone(op, mapping);
|
||||
// update mapping of results
|
||||
@@ -305,7 +329,9 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
// 3. replace loads with block args (from prologue)
|
||||
for (size_t idx = 0; idx < loads.size(); ++idx) {
|
||||
Value load = loads[idx];
|
||||
mapping.lookup(load).replaceAllUsesWith(
|
||||
assert(load.hasOneUse() && "we assume that this load has one use (ConvertLayout)");
|
||||
Value loadUse = load.getUsers().begin()->getResult(0);
|
||||
mapping.lookup(loadUse).replaceAllUsesWith(
|
||||
newForOp.getRegionIterArgs()[loadIdx + idx*(numStages-1)]);
|
||||
}
|
||||
|
||||
@@ -351,7 +377,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
nextMapping.map(mask, newMask);
|
||||
// TODO: more elegant way to do this?
|
||||
nextOp = builder.create<triton::gpu::CopyAsyncOp>(
|
||||
op->getLoc(), op->getResult(0).getType(),
|
||||
op->getLoc(), loadsMapping[op->getResult(0)].getType(),
|
||||
nextMapping.lookupOrDefault(loadOp.ptr()),
|
||||
nextMapping.lookupOrDefault(loadOp.mask()),
|
||||
nextMapping.lookupOrDefault(loadOp.other()),
|
||||
|
Reference in New Issue
Block a user