more progress on the pipeline pass
This commit is contained in:
@@ -1,11 +1,13 @@
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
||||
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements loop software pipelining
|
||||
// The implementation here is inspired by the pipeline pass in Triton 2.0 and
|
||||
// SCF's LoopPipelining.
|
||||
// The implementation here is inspired by the pipeline pass in Triton (-v2.0)
|
||||
// and SCF's LoopPipelining.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -53,6 +55,9 @@ public:
|
||||
///
|
||||
void emitPrologue();
|
||||
|
||||
/// create the new ForOp (add new args & insert prefetched ops)
|
||||
scf::ForOp createNewForOp();
|
||||
|
||||
friend class PipelinePass;
|
||||
};
|
||||
|
||||
@@ -66,10 +71,18 @@ void LoopPipeliner::setValueMapping(Value origin, Value newValue, int stage) {
|
||||
void LoopPipeliner::collectDeps(Value v) {
|
||||
if (v.getParentRegion() != &forOp.getLoopBody())
|
||||
return;
|
||||
if (auto arg = v.dyn_cast<BlockArgument>())
|
||||
if (auto arg = v.dyn_cast<BlockArgument>()) {
|
||||
if (depArgs.contains(arg))
|
||||
return;
|
||||
depArgs.insert(arg);
|
||||
else { // value
|
||||
// we also need to rematerialize this arg
|
||||
auto yield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
|
||||
// Note: we have iv as the first arg, so the op idx is arg.getArgNumber()-1
|
||||
collectDeps(yield->getOperand(arg.getArgNumber() - 1));
|
||||
} else { // value
|
||||
Operation *defOp = v.getDefiningOp();
|
||||
if (depOps.contains(defOp))
|
||||
return;
|
||||
depOps.insert(defOp);
|
||||
for (Value op : defOp->getOperands())
|
||||
collectDeps(op);
|
||||
@@ -80,13 +93,11 @@ void LoopPipeliner::collectDeps(Value v) {
|
||||
/// - the pointer is a block argument (redefined inside the loop)
|
||||
/// - the load has only a single use in a dot instruction
|
||||
LogicalResult LoopPipeliner::initialize() {
|
||||
Region &bodyRegion = forOp.getLoopBody();
|
||||
assert(bodyRegion.hasOneBlock());
|
||||
Block &loop = bodyRegion.front();
|
||||
Block *loop = forOp.getBody();
|
||||
|
||||
// TODO: can we use forOp.walk(...) here?
|
||||
SmallVector<triton::DotOp, 2> dots;
|
||||
for (Operation &op : loop) {
|
||||
for (Operation &op : *loop) {
|
||||
if (auto dotOp = dyn_cast<triton::DotOp>(&op)) {
|
||||
dots.push_back(dotOp);
|
||||
}
|
||||
@@ -129,7 +140,8 @@ void LoopPipeliner::emitPrologue() {
|
||||
setValueMapping(arg, operand.get(), 0);
|
||||
}
|
||||
|
||||
// pro
|
||||
// prologue from [0, numStage-1)
|
||||
auto yield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
|
||||
Value iv = forOp.getInductionVar();
|
||||
for (int stage = 0; stage < numStages - 1; ++stage) {
|
||||
// special handling for induction variable as the increment is implicit
|
||||
@@ -162,12 +174,117 @@ void LoopPipeliner::emitPrologue() {
|
||||
// update mapping of results
|
||||
for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
|
||||
setValueMapping(op->getResult(dstIdx), newOp->getResult(dstIdx), stage);
|
||||
// TODO: update mapping for loop-carried values (args)
|
||||
// update mapping for loop-carried values (args)
|
||||
for (OpOperand &operand : yield->getOpOperands()) {
|
||||
if (operand.get() == op->getResult(dstIdx))
|
||||
setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
|
||||
newOp->getResult(dstIdx), stage + 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
OpBuilder builder(forOp);
|
||||
|
||||
// order of new args:
|
||||
// (original args),
|
||||
// (a at stage[0, numStages-1)), (b at stage[0, numStages-1))
|
||||
// (depArgs at stage numStages-1)
|
||||
// (iv at stage numStages-1), (loopCond at stage numStages-1)
|
||||
SmallVector<Value> newLoopArgs;
|
||||
for (auto v : forOp.getIterOperands())
|
||||
newLoopArgs.push_back(v);
|
||||
size_t aArgIdx = newLoopArgs.size();
|
||||
for (int i = 0; i < numStages - 1; ++i)
|
||||
newLoopArgs.push_back(valueMapping[info.dotOp.a()][i]);
|
||||
size_t bArgIdx = newLoopArgs.size();
|
||||
for (int i = 0; i < numStages - 1; ++i)
|
||||
newLoopArgs.push_back(valueMapping[info.dotOp.b()][i]);
|
||||
size_t depArgsBeginIdx = newLoopArgs.size();
|
||||
for (BlockArgument depArg : depArgs)
|
||||
newLoopArgs.push_back(valueMapping[depArg][numStages-1]);
|
||||
size_t nextIVIdx = newLoopArgs.size();
|
||||
newLoopArgs.push_back(valueMapping[forOp.getInductionVar()][numStages-1]);
|
||||
newLoopArgs.push_back(loopConds[numStages-1]);
|
||||
|
||||
// signature of the new ForOp
|
||||
auto newForOp = builder.create<scf::ForOp>(forOp.getLoc(),
|
||||
forOp.getLowerBound(),
|
||||
forOp.getUpperBound(),
|
||||
forOp.getStep(),
|
||||
newLoopArgs);
|
||||
|
||||
// body of the new ForOp
|
||||
builder.setInsertionPointToStart(newForOp.getBody());
|
||||
BlockAndValueMapping mapping;
|
||||
for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
|
||||
mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
|
||||
mapping.map(info.dotOp.a(), newForOp.getRegionIterArgs()[aArgIdx]);
|
||||
mapping.map(info.dotOp.b(), newForOp.getRegionIterArgs()[bArgIdx]);
|
||||
for (Operation &op : forOp.getBody()->without_terminator()) {
|
||||
Operation *newOp = builder.clone(op, mapping);
|
||||
// update mapping of results
|
||||
for (unsigned dstIdx : llvm::seq(unsigned(0), op.getNumResults()))
|
||||
mapping.map(op.getResult(dstIdx), newOp->getResult(dstIdx));
|
||||
}
|
||||
// prefetch next iteration
|
||||
SmallVector<Operation*> orderedDeps;
|
||||
for (Operation &op : forOp.getLoopBody().front())
|
||||
if (depOps.contains(&op))
|
||||
orderedDeps.push_back(&op);
|
||||
assert(depOps.size() == orderedDeps.size() && "depOps contains invalid values");
|
||||
BlockAndValueMapping nextMapping;
|
||||
BlockAndValueMapping depArgsMapping;
|
||||
size_t argIdx = 0;
|
||||
for (BlockArgument arg : depArgs) {
|
||||
nextMapping.map(arg, newForOp.getRegionIterArgs()[argIdx + depArgsBeginIdx]);
|
||||
++argIdx;
|
||||
}
|
||||
// special handling for iv & loop condition
|
||||
Value nextIV = builder.create<arith::AddIOp>(newForOp.getInductionVar().getLoc(),
|
||||
newForOp.getRegionIterArgs()[nextIVIdx],
|
||||
newForOp.getStep());
|
||||
Value nextLoopCond = builder.create<arith::CmpIOp>(
|
||||
nextIV.getLoc(), arith::CmpIPredicate::slt,
|
||||
nextIV, newForOp.getUpperBound());
|
||||
for (Operation *op : orderedDeps) {
|
||||
// update loading mask
|
||||
if (op == info.aLoadOp.getOperation() || op == info.bLoadOp.getOperation()) {
|
||||
auto loadOp = llvm::cast<triton::LoadOp>(op);
|
||||
Value mask = loadOp.mask();
|
||||
Value splatCond = builder.create<triton::BroadcastOp>(mask.getLoc(),
|
||||
mask.getType(),
|
||||
nextLoopCond);
|
||||
Value newMask = builder.create<arith::AndIOp>(mask.getLoc(),
|
||||
splatCond,
|
||||
nextMapping.lookupOrDefault(mask));
|
||||
nextMapping.map(mask, newMask);
|
||||
}
|
||||
Operation *nextOp = builder.clone(*op, nextMapping);
|
||||
// update mapping of results
|
||||
for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults()))
|
||||
nextMapping.map(op->getResult(dstIdx), nextOp->getResult(dstIdx));
|
||||
}
|
||||
|
||||
// Finally, the YieldOp, need to sync with the order of newLoopArgs
|
||||
SmallVector<Value> yieldValues;
|
||||
for (Value v : forOp.getBody()->getTerminator()->getOperands())
|
||||
yieldValues.push_back(mapping.lookup(v));
|
||||
for (int i = 1; i < numStages - 1; ++i)
|
||||
yieldValues.push_back(newForOp.getRegionIterArgs()[aArgIdx + i]);
|
||||
yieldValues.push_back(nextMapping.lookup(info.aLoadOp.getResult()));
|
||||
for (int i = 1; i < numStages - 1; ++i)
|
||||
yieldValues.push_back(newForOp.getRegionIterArgs()[bArgIdx + i]);
|
||||
yieldValues.push_back(nextMapping.lookup(info.bLoadOp.getResult()));
|
||||
// TODO: deps
|
||||
//
|
||||
yieldValues.push_back(nextIV);
|
||||
yieldValues.push_back(nextLoopCond);
|
||||
return newForOp;
|
||||
}
|
||||
|
||||
// ref: mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
|
||||
struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
|
||||
void runOnOperation() override {
|
||||
@@ -186,15 +303,13 @@ struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
|
||||
llvm::errs() << "candidate for pipelining: " << pipeliner.info.dotOp
|
||||
<< "\n";
|
||||
|
||||
// pipeliner.emitPrologue();
|
||||
pipeliner.emitPrologue();
|
||||
|
||||
// scf::ForOp newForOp = pipeliner.createNewForOp();
|
||||
scf::ForOp newForOp = pipeliner.createNewForOp();
|
||||
|
||||
// // replace the original loop
|
||||
// if (forOp->getNumResults() > 0)
|
||||
// rewriter.replaceOp(forOp, newForOp->getResults());
|
||||
// else
|
||||
// rewriter.eraseOp(forOp);
|
||||
// forOp->replaceAllUsesWith(newForOp->getResults());
|
||||
// forOp->erase();
|
||||
});
|
||||
}
|
||||
};
|
||||
|
Reference in New Issue
Block a user