more progress on the pipeline pass
This commit is contained in:
@@ -1,6 +1,15 @@
|
|||||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This file implements loop software pipelining
|
||||||
|
// The implementation here is inspired by the pipeline pass in Triton 2.0 and
|
||||||
|
// SCF's LoopPipelining.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
#define GEN_PASS_CLASSES
|
#define GEN_PASS_CLASSES
|
||||||
@@ -14,6 +23,9 @@ class LoopPipeliner {
|
|||||||
triton::LoadOp bLoadOp;
|
triton::LoadOp bLoadOp;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// comments on numStages:
|
||||||
|
/// [0, numStages-1) are in the prologue
|
||||||
|
/// numStages-1 is appended after the loop body
|
||||||
int numStages;
|
int numStages;
|
||||||
/// cache forOp we are working on
|
/// cache forOp we are working on
|
||||||
scf::ForOp forOp;
|
scf::ForOp forOp;
|
||||||
@@ -21,8 +33,16 @@ class LoopPipeliner {
|
|||||||
PipelineInfo info;
|
PipelineInfo info;
|
||||||
/// value (in loop) => value at stage N
|
/// value (in loop) => value at stage N
|
||||||
DenseMap<Value, SmallVector<Value>> valueMapping;
|
DenseMap<Value, SmallVector<Value>> valueMapping;
|
||||||
|
/// stage => loop condition
|
||||||
|
DenseMap<int, Value> loopConds;
|
||||||
|
|
||||||
void setStageValueMapping(Value origin, Value prefetched, int idx);
|
DenseSet<BlockArgument> depArgs;
|
||||||
|
DenseSet<Operation*> depOps;
|
||||||
|
|
||||||
|
void setValueMapping(Value origin, Value newValue, int stage);
|
||||||
|
|
||||||
|
/// collect values that v depends on and are defined inside the loop
|
||||||
|
void collectDeps(Value v);
|
||||||
public:
|
public:
|
||||||
LoopPipeliner(scf::ForOp forOp, int numStages)
|
LoopPipeliner(scf::ForOp forOp, int numStages)
|
||||||
: forOp(forOp), numStages(numStages) {}
|
: forOp(forOp), numStages(numStages) {}
|
||||||
@@ -30,12 +50,32 @@ public:
|
|||||||
/// Collect loop info. Return success if we can pipeline this loop
|
/// Collect loop info. Return success if we can pipeline this loop
|
||||||
LogicalResult initialize();
|
LogicalResult initialize();
|
||||||
|
|
||||||
///
|
///
|
||||||
void emitPrologue();
|
void emitPrologue();
|
||||||
|
|
||||||
friend class PipelinePass;
|
friend class PipelinePass;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// helpers
|
||||||
|
void LoopPipeliner::setValueMapping(Value origin, Value newValue, int stage) {
|
||||||
|
if (valueMapping.find(origin) == valueMapping.end())
|
||||||
|
valueMapping[origin] = SmallVector<Value>(numStages);
|
||||||
|
valueMapping[origin][stage] = newValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
void LoopPipeliner::collectDeps(Value v) {
|
||||||
|
if (v.getParentRegion() != &forOp.getLoopBody())
|
||||||
|
return;
|
||||||
|
if (auto arg = v.dyn_cast<BlockArgument>())
|
||||||
|
depArgs.insert(arg);
|
||||||
|
else { // value
|
||||||
|
Operation *defOp = v.getDefiningOp();
|
||||||
|
depOps.insert(defOp);
|
||||||
|
for (Value op : defOp->getOperands())
|
||||||
|
collectDeps(op);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// A load instruction can be pipelined if:
|
/// A load instruction can be pipelined if:
|
||||||
/// - the pointer is a block argument (redefined inside the loop)
|
/// - the pointer is a block argument (redefined inside the loop)
|
||||||
/// - the load has only a single use in a dot instruction
|
/// - the load has only a single use in a dot instruction
|
||||||
@@ -72,6 +112,8 @@ LogicalResult LoopPipeliner::initialize() {
|
|||||||
if (aLoad && bLoad) {
|
if (aLoad && bLoad) {
|
||||||
if (aLoad.ptr().isa<BlockArgument>() && bLoad.ptr().isa<BlockArgument>()) {
|
if (aLoad.ptr().isa<BlockArgument>() && bLoad.ptr().isa<BlockArgument>()) {
|
||||||
info.dotOp = dotOp; info.aLoadOp = aLoad; info.bLoadOp = bLoad;
|
info.dotOp = dotOp; info.aLoadOp = aLoad; info.bLoadOp = bLoad;
|
||||||
|
collectDeps(dotOp.a());
|
||||||
|
collectDeps(dotOp.b());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -80,8 +122,50 @@ LogicalResult LoopPipeliner::initialize() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void LoopPipeliner::emitPrologue() {
|
void LoopPipeliner::emitPrologue() {
|
||||||
|
// TODO: should we use rewriter here?
|
||||||
OpBuilder builder(forOp);
|
OpBuilder builder(forOp);
|
||||||
//
|
for (BlockArgument &arg : forOp.getRegionIterArgs()) {
|
||||||
|
OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg);
|
||||||
|
setValueMapping(arg, operand.get(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// pro
|
||||||
|
Value iv = forOp.getInductionVar();
|
||||||
|
for (int stage = 0; stage < numStages - 1; ++stage) {
|
||||||
|
// special handling for induction variable as the increment is implicit
|
||||||
|
if (stage != 0)
|
||||||
|
iv = builder.create<arith::AddIOp>(iv.getLoc(), iv, forOp.getStep());
|
||||||
|
setValueMapping(forOp.getInductionVar(), iv, stage);
|
||||||
|
|
||||||
|
// special handling for loop condition as there is no condition in ForOp
|
||||||
|
Value loopCond = builder.create<arith::CmpIOp>(
|
||||||
|
iv.getLoc(), arith::CmpIPredicate::slt, iv, forOp.getUpperBound());
|
||||||
|
loopConds[stage] = loopCond;
|
||||||
|
|
||||||
|
// rematerialize peeled values
|
||||||
|
SmallVector<Operation*> orderedDeps;
|
||||||
|
for (Operation &op : forOp.getLoopBody().front())
|
||||||
|
if (depOps.contains(&op))
|
||||||
|
orderedDeps.push_back(&op);
|
||||||
|
assert(depOps.size() == orderedDeps.size() && "depOps contains invalid values");
|
||||||
|
for (Operation *op : orderedDeps) {
|
||||||
|
Operation *newOp = builder.clone(*op);
|
||||||
|
for (unsigned opIdx = 0; opIdx < op->getNumOperands(); ++opIdx) {
|
||||||
|
auto it = valueMapping.find(op->getOperand(opIdx));
|
||||||
|
if (it != valueMapping.end()) {
|
||||||
|
Value v = it->second[stage];
|
||||||
|
assert(v);
|
||||||
|
newOp->setOperand(opIdx, v);
|
||||||
|
} // else, op at opIdx is a loop-invariant value
|
||||||
|
}
|
||||||
|
|
||||||
|
// update mapping of results
|
||||||
|
for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
|
||||||
|
setValueMapping(op->getResult(dstIdx), newOp->getResult(dstIdx), stage);
|
||||||
|
// TODO: update mapping for loop-carried values (args)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ref: mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
|
// ref: mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
|
||||||
@@ -101,8 +185,16 @@ struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
|
|||||||
|
|
||||||
llvm::errs() << "candidate for pipelining: " << pipeliner.info.dotOp
|
llvm::errs() << "candidate for pipelining: " << pipeliner.info.dotOp
|
||||||
<< "\n";
|
<< "\n";
|
||||||
|
|
||||||
// pipeliner.emitPrologue();
|
// pipeliner.emitPrologue();
|
||||||
|
|
||||||
|
// scf::ForOp newForOp = pipeliner.createNewForOp();
|
||||||
|
|
||||||
|
// // replace the original loop
|
||||||
|
// if (forOp->getNumResults() > 0)
|
||||||
|
// rewriter.replaceOp(forOp, newForOp->getResults());
|
||||||
|
// else
|
||||||
|
// rewriter.eraseOp(forOp);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
Reference in New Issue
Block a user