The pipeline pass is now functional
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
||||
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include <llvm-6.0/llvm/Support/raw_ostream.h>
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
@@ -35,8 +36,6 @@ class LoopPipeliner {
|
||||
PipelineInfo info;
|
||||
/// value (in loop) => value at stage N
|
||||
DenseMap<Value, SmallVector<Value>> valueMapping;
|
||||
/// stage => loop condition
|
||||
DenseMap<int, Value> loopConds;
|
||||
|
||||
DenseSet<BlockArgument> depArgs;
|
||||
DenseSet<Operation*> depOps;
|
||||
@@ -142,7 +141,7 @@ void LoopPipeliner::emitPrologue() {
|
||||
|
||||
// prologue from [0, numStage-1)
|
||||
auto yield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
|
||||
Value iv = forOp.getInductionVar();
|
||||
Value iv = forOp.getLowerBound();
|
||||
for (int stage = 0; stage < numStages - 1; ++stage) {
|
||||
// special handling for induction variable as the increment is implicit
|
||||
if (stage != 0)
|
||||
@@ -152,7 +151,6 @@ void LoopPipeliner::emitPrologue() {
|
||||
// special handling for loop condition as there is no condition in ForOp
|
||||
Value loopCond = builder.create<arith::CmpIOp>(
|
||||
iv.getLoc(), arith::CmpIPredicate::slt, iv, forOp.getUpperBound());
|
||||
loopConds[stage] = loopCond;
|
||||
|
||||
// rematerialize peeled values
|
||||
SmallVector<Operation*> orderedDeps;
|
||||
@@ -192,8 +190,11 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
// (original args),
|
||||
// (a at stage[0, numStages-1)), (b at stage[0, numStages-1))
|
||||
// (depArgs at stage numStages-1)
|
||||
// (iv at stage numStages-1), (loopCond at stage numStages-1)
|
||||
// (iv at stage numStages-1)
|
||||
SmallVector<Value> newLoopArgs;
|
||||
// We need this to update operands for yield
|
||||
// original block arg => new arg's idx
|
||||
DenseMap<BlockArgument, size_t> depArgsIdx;
|
||||
for (auto v : forOp.getIterOperands())
|
||||
newLoopArgs.push_back(v);
|
||||
size_t aArgIdx = newLoopArgs.size();
|
||||
@@ -203,11 +204,15 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
for (int i = 0; i < numStages - 1; ++i)
|
||||
newLoopArgs.push_back(valueMapping[info.dotOp.b()][i]);
|
||||
size_t depArgsBeginIdx = newLoopArgs.size();
|
||||
for (BlockArgument depArg : depArgs)
|
||||
for (BlockArgument depArg : depArgs) {
|
||||
depArgsIdx[depArg] = newLoopArgs.size();
|
||||
newLoopArgs.push_back(valueMapping[depArg][numStages-1]);
|
||||
}
|
||||
size_t nextIVIdx = newLoopArgs.size();
|
||||
newLoopArgs.push_back(valueMapping[forOp.getInductionVar()][numStages-1]);
|
||||
newLoopArgs.push_back(loopConds[numStages-1]);
|
||||
newLoopArgs.push_back(valueMapping[forOp.getInductionVar()][numStages-2]);
|
||||
|
||||
for (size_t i = 0; i < newLoopArgs.size(); ++i)
|
||||
assert(newLoopArgs[i]);
|
||||
|
||||
// signature of the new ForOp
|
||||
auto newForOp = builder.create<scf::ForOp>(forOp.getLoc(),
|
||||
@@ -221,13 +226,18 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
BlockAndValueMapping mapping;
|
||||
for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
|
||||
mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
|
||||
mapping.map(info.dotOp.a(), newForOp.getRegionIterArgs()[aArgIdx]);
|
||||
mapping.map(info.dotOp.b(), newForOp.getRegionIterArgs()[bArgIdx]);
|
||||
// mapping.map(info.dotOp.a(), newForOp.getRegionIterArgs()[aArgIdx]);
|
||||
// mapping.map(info.dotOp.b(), newForOp.getRegionIterArgs()[bArgIdx]);
|
||||
for (Operation &op : forOp.getBody()->without_terminator()) {
|
||||
Operation *newOp = builder.clone(op, mapping);
|
||||
// update mapping of results
|
||||
for (unsigned dstIdx : llvm::seq(unsigned(0), op.getNumResults()))
|
||||
mapping.map(op.getResult(dstIdx), newOp->getResult(dstIdx));
|
||||
// TODO: why doesn't mapping work?
|
||||
if (&op == info.dotOp.getOperation()) {
|
||||
newOp->setOperand(0, newForOp.getRegionIterArgs()[aArgIdx]);
|
||||
newOp->setOperand(1, newForOp.getRegionIterArgs()[bArgIdx]);
|
||||
}
|
||||
}
|
||||
// prefetch next iteration
|
||||
SmallVector<Operation*> orderedDeps;
|
||||
@@ -236,7 +246,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
orderedDeps.push_back(&op);
|
||||
assert(depOps.size() == orderedDeps.size() && "depOps contains invalid values");
|
||||
BlockAndValueMapping nextMapping;
|
||||
BlockAndValueMapping depArgsMapping;
|
||||
DenseMap<BlockArgument, Value> depArgsMapping;
|
||||
size_t argIdx = 0;
|
||||
for (BlockArgument arg : depArgs) {
|
||||
nextMapping.map(arg, newForOp.getRegionIterArgs()[argIdx + depArgsBeginIdx]);
|
||||
@@ -264,8 +274,19 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
}
|
||||
Operation *nextOp = builder.clone(*op, nextMapping);
|
||||
// update mapping of results
|
||||
for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults()))
|
||||
for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
|
||||
nextMapping.map(op->getResult(dstIdx), nextOp->getResult(dstIdx));
|
||||
// if this is a loop-carried value, update the mapping for yield
|
||||
auto originYield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
|
||||
for (OpOperand &operand : originYield->getOpOperands()) {
|
||||
if (operand.get() == op->getResult(dstIdx)) {
|
||||
size_t originIdx = operand.getOperandNumber();
|
||||
size_t newArgIdx = depArgsIdx[forOp.getRegionIterArgs()[originIdx]];
|
||||
BlockArgument newArg = newForOp.getRegionIterArgs()[newArgIdx];
|
||||
depArgsMapping[newArg] = nextOp->getResult(dstIdx);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Finally, the YieldOp, need to sync with the order of newLoopArgs
|
||||
@@ -274,14 +295,16 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
yieldValues.push_back(mapping.lookup(v));
|
||||
for (int i = 1; i < numStages - 1; ++i)
|
||||
yieldValues.push_back(newForOp.getRegionIterArgs()[aArgIdx + i]);
|
||||
yieldValues.push_back(nextMapping.lookup(info.aLoadOp.getResult()));
|
||||
yieldValues.push_back(nextMapping.lookup(info.dotOp.a()));
|
||||
for (int i = 1; i < numStages - 1; ++i)
|
||||
yieldValues.push_back(newForOp.getRegionIterArgs()[bArgIdx + i]);
|
||||
yieldValues.push_back(nextMapping.lookup(info.bLoadOp.getResult()));
|
||||
// TODO: deps
|
||||
//
|
||||
yieldValues.push_back(nextMapping.lookup(info.dotOp.b()));
|
||||
for (size_t i = depArgsBeginIdx; i < nextIVIdx; ++i)
|
||||
yieldValues.push_back(depArgsMapping.lookup(newForOp.getRegionIterArgs()[i]));
|
||||
yieldValues.push_back(nextIV);
|
||||
yieldValues.push_back(nextLoopCond);
|
||||
builder.setInsertionPointToEnd(newForOp.getBody());
|
||||
builder.create<scf::YieldOp>(forOp.getBody()->getTerminator()->getLoc(),
|
||||
yieldValues);
|
||||
return newForOp;
|
||||
}
|
||||
|
||||
@@ -300,16 +323,14 @@ struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
|
||||
if (pipeliner.initialize().failed())
|
||||
return;
|
||||
|
||||
llvm::errs() << "candidate for pipelining: " << pipeliner.info.dotOp
|
||||
<< "\n";
|
||||
|
||||
pipeliner.emitPrologue();
|
||||
|
||||
scf::ForOp newForOp = pipeliner.createNewForOp();
|
||||
|
||||
// // replace the original loop
|
||||
// forOp->replaceAllUsesWith(newForOp->getResults());
|
||||
// forOp->erase();
|
||||
// replace the original loop
|
||||
for (unsigned i = 0; i < forOp->getNumResults(); ++i)
|
||||
forOp->getResult(i).replaceAllUsesWith(newForOp->getResult(i));
|
||||
forOp->erase();
|
||||
});
|
||||
}
|
||||
};
|
||||
|
Reference in New Issue
Block a user