The pipeline pass is now functional

This commit is contained in:
Yan Da
2022-05-15 22:29:27 +08:00
parent 7e0e7ec365
commit 7027af9666
4 changed files with 119 additions and 98 deletions

View File

@@ -2,6 +2,7 @@
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include <llvm-6.0/llvm/Support/raw_ostream.h>
//===----------------------------------------------------------------------===//
//
@@ -35,8 +36,6 @@ class LoopPipeliner {
PipelineInfo info;
/// value (in loop) => value at stage N
DenseMap<Value, SmallVector<Value>> valueMapping;
/// stage => loop condition
DenseMap<int, Value> loopConds;
DenseSet<BlockArgument> depArgs;
DenseSet<Operation*> depOps;
@@ -142,7 +141,7 @@ void LoopPipeliner::emitPrologue() {
// prologue from [0, numStage-1)
auto yield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
Value iv = forOp.getInductionVar();
Value iv = forOp.getLowerBound();
for (int stage = 0; stage < numStages - 1; ++stage) {
// special handling for induction variable as the increment is implicit
if (stage != 0)
@@ -152,7 +151,6 @@ void LoopPipeliner::emitPrologue() {
// special handling for loop condition as there is no condition in ForOp
Value loopCond = builder.create<arith::CmpIOp>(
iv.getLoc(), arith::CmpIPredicate::slt, iv, forOp.getUpperBound());
loopConds[stage] = loopCond;
// rematerialize peeled values
SmallVector<Operation*> orderedDeps;
@@ -192,8 +190,11 @@ scf::ForOp LoopPipeliner::createNewForOp() {
// (original args),
// (a at stage[0, numStages-1)), (b at stage[0, numStages-1))
// (depArgs at stage numStages-1)
// (iv at stage numStages-1), (loopCond at stage numStages-1)
// (iv at stage numStages-1)
SmallVector<Value> newLoopArgs;
// We need this to update operands for yield
// original block arg => new arg's idx
DenseMap<BlockArgument, size_t> depArgsIdx;
for (auto v : forOp.getIterOperands())
newLoopArgs.push_back(v);
size_t aArgIdx = newLoopArgs.size();
@@ -203,11 +204,15 @@ scf::ForOp LoopPipeliner::createNewForOp() {
for (int i = 0; i < numStages - 1; ++i)
newLoopArgs.push_back(valueMapping[info.dotOp.b()][i]);
size_t depArgsBeginIdx = newLoopArgs.size();
for (BlockArgument depArg : depArgs)
for (BlockArgument depArg : depArgs) {
depArgsIdx[depArg] = newLoopArgs.size();
newLoopArgs.push_back(valueMapping[depArg][numStages-1]);
}
size_t nextIVIdx = newLoopArgs.size();
newLoopArgs.push_back(valueMapping[forOp.getInductionVar()][numStages-1]);
newLoopArgs.push_back(loopConds[numStages-1]);
newLoopArgs.push_back(valueMapping[forOp.getInductionVar()][numStages-2]);
for (size_t i = 0; i < newLoopArgs.size(); ++i)
assert(newLoopArgs[i]);
// signature of the new ForOp
auto newForOp = builder.create<scf::ForOp>(forOp.getLoc(),
@@ -221,13 +226,18 @@ scf::ForOp LoopPipeliner::createNewForOp() {
BlockAndValueMapping mapping;
for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
mapping.map(info.dotOp.a(), newForOp.getRegionIterArgs()[aArgIdx]);
mapping.map(info.dotOp.b(), newForOp.getRegionIterArgs()[bArgIdx]);
// mapping.map(info.dotOp.a(), newForOp.getRegionIterArgs()[aArgIdx]);
// mapping.map(info.dotOp.b(), newForOp.getRegionIterArgs()[bArgIdx]);
for (Operation &op : forOp.getBody()->without_terminator()) {
Operation *newOp = builder.clone(op, mapping);
// update mapping of results
for (unsigned dstIdx : llvm::seq(unsigned(0), op.getNumResults()))
mapping.map(op.getResult(dstIdx), newOp->getResult(dstIdx));
// TODO: why doesn't mapping work?
if (&op == info.dotOp.getOperation()) {
newOp->setOperand(0, newForOp.getRegionIterArgs()[aArgIdx]);
newOp->setOperand(1, newForOp.getRegionIterArgs()[bArgIdx]);
}
}
// prefetch next iteration
SmallVector<Operation*> orderedDeps;
@@ -236,7 +246,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
orderedDeps.push_back(&op);
assert(depOps.size() == orderedDeps.size() && "depOps contains invalid values");
BlockAndValueMapping nextMapping;
BlockAndValueMapping depArgsMapping;
DenseMap<BlockArgument, Value> depArgsMapping;
size_t argIdx = 0;
for (BlockArgument arg : depArgs) {
nextMapping.map(arg, newForOp.getRegionIterArgs()[argIdx + depArgsBeginIdx]);
@@ -264,8 +274,19 @@ scf::ForOp LoopPipeliner::createNewForOp() {
}
Operation *nextOp = builder.clone(*op, nextMapping);
// update mapping of results
for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults()))
for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
nextMapping.map(op->getResult(dstIdx), nextOp->getResult(dstIdx));
// if this is a loop-carried value, update the mapping for yield
auto originYield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
for (OpOperand &operand : originYield->getOpOperands()) {
if (operand.get() == op->getResult(dstIdx)) {
size_t originIdx = operand.getOperandNumber();
size_t newArgIdx = depArgsIdx[forOp.getRegionIterArgs()[originIdx]];
BlockArgument newArg = newForOp.getRegionIterArgs()[newArgIdx];
depArgsMapping[newArg] = nextOp->getResult(dstIdx);
}
}
}
}
// Finally, the YieldOp, need to sync with the order of newLoopArgs
@@ -274,14 +295,16 @@ scf::ForOp LoopPipeliner::createNewForOp() {
yieldValues.push_back(mapping.lookup(v));
for (int i = 1; i < numStages - 1; ++i)
yieldValues.push_back(newForOp.getRegionIterArgs()[aArgIdx + i]);
yieldValues.push_back(nextMapping.lookup(info.aLoadOp.getResult()));
yieldValues.push_back(nextMapping.lookup(info.dotOp.a()));
for (int i = 1; i < numStages - 1; ++i)
yieldValues.push_back(newForOp.getRegionIterArgs()[bArgIdx + i]);
yieldValues.push_back(nextMapping.lookup(info.bLoadOp.getResult()));
// TODO: deps
//
yieldValues.push_back(nextMapping.lookup(info.dotOp.b()));
for (size_t i = depArgsBeginIdx; i < nextIVIdx; ++i)
yieldValues.push_back(depArgsMapping.lookup(newForOp.getRegionIterArgs()[i]));
yieldValues.push_back(nextIV);
yieldValues.push_back(nextLoopCond);
builder.setInsertionPointToEnd(newForOp.getBody());
builder.create<scf::YieldOp>(forOp.getBody()->getTerminator()->getLoc(),
yieldValues);
return newForOp;
}
@@ -300,16 +323,14 @@ struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
if (pipeliner.initialize().failed())
return;
llvm::errs() << "candidate for pipelining: " << pipeliner.info.dotOp
<< "\n";
pipeliner.emitPrologue();
scf::ForOp newForOp = pipeliner.createNewForOp();
// // replace the original loop
// forOp->replaceAllUsesWith(newForOp->getResults());
// forOp->erase();
// replace the original loop
for (unsigned i = 0; i < forOp->getNumResults(); ++i)
forOp->getResult(i).replaceAllUsesWith(newForOp->getResult(i));
forOp->erase();
});
}
};