Files
triton/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp

465 lines
17 KiB
C++
Raw Normal View History

2022-05-11 16:13:53 +08:00
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
2022-05-14 22:04:36 +08:00
#include "mlir/IR/BlockAndValueMapping.h"
2022-05-13 21:32:35 +08:00
//===----------------------------------------------------------------------===//
//
// This file implements software pipelining of scf.for loops.
2022-05-14 22:04:36 +08:00
// The implementation here is inspired by the pipeline pass in Triton (-v2.0)
// and SCF's LoopPipelining.
2022-05-13 21:32:35 +08:00
//
//===----------------------------------------------------------------------===//
2022-05-11 16:13:53 +08:00
using namespace mlir;
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
namespace {
2022-05-11 20:31:08 +08:00
class LoopPipeliner {
2022-05-13 21:32:35 +08:00
/// comments on numStages:
/// [0, numStages-1) are in the prologue
/// numStages-1 is appended after the loop body
2022-05-11 20:31:08 +08:00
int numStages;
2022-05-25 21:52:51 +08:00
2022-05-11 20:31:08 +08:00
/// cache forOp we are working on
scf::ForOp forOp;
2022-05-25 21:52:51 +08:00
/// cahce YieldOp for this forOp
scf::YieldOp yieldOp;
/// loads to be pipelined
SetVector<Value> loads;
/// the value that each load will be mapped to (after layout conversion)
DenseMap<Value, Value> loadsMapping;
2022-05-25 21:52:51 +08:00
2022-05-11 20:31:08 +08:00
/// value (in loop) => value at stage N
DenseMap<Value, SmallVector<Value>> valueMapping;
2022-05-13 21:32:35 +08:00
2022-05-25 21:52:51 +08:00
/// Block arguments that loads depend on
2022-05-13 21:32:35 +08:00
DenseSet<BlockArgument> depArgs;
2022-05-25 21:52:51 +08:00
/// Operations (inside the loop body) that loads depend on
2022-05-13 21:32:35 +08:00
DenseSet<Operation*> depOps;
/// collect values that v depends on and are defined inside the loop
2022-05-25 21:52:51 +08:00
void collectDeps(Value v, int stages, DenseSet<Value> &deps);
void setValueMapping(Value origin, Value newValue, int stage);
2022-05-11 20:31:08 +08:00
public:
LoopPipeliner(scf::ForOp forOp, int numStages)
2022-05-25 21:52:51 +08:00
: forOp(forOp), numStages(numStages) {
// cache yieldOp
yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
}
2022-05-11 20:31:08 +08:00
2022-05-25 21:52:51 +08:00
/// Collect loads to pipeline. Return success if we can pipeline this loop
2022-05-11 20:31:08 +08:00
LogicalResult initialize();
2022-05-25 21:52:51 +08:00
/// emit pipelined loads (before loop body)
2022-05-11 20:31:08 +08:00
void emitPrologue();
2022-05-14 22:04:36 +08:00
/// create the new ForOp (add new args & insert prefetched ops)
scf::ForOp createNewForOp();
2022-05-11 20:31:08 +08:00
friend class PipelinePass;
};
2022-05-13 21:32:35 +08:00
// helpers
/// Record `newValue` as the copy of `origin` that belongs to `stage`.
/// On the first record for `origin`, a vector of numStages empty slots is
/// created; existing slots for other stages are left untouched.
void LoopPipeliner::setValueMapping(Value origin, Value newValue, int stage) {
  auto entry = valueMapping.try_emplace(origin, numStages).first;
  entry->second[stage] = newValue;
}
2022-05-25 21:52:51 +08:00
/// Collect into `deps` the values defined inside the loop body that `v`
/// (transitively) depends on. Loop-carried block arguments are followed back
/// to the scf.yield operand that feeds them, which consumes one stage of the
/// `stages` budget; the walk stops once the budget goes negative.
///
/// The induction variable (block argument #0) is intentionally not recorded:
/// it has no matching yield operand, and the pipeliner rematerializes it
/// explicitly for every stage.
void LoopPipeliner::collectDeps(Value v, int stages, DenseSet<Value> &deps) {
  // Loop-invariant value. skip
  if (v.getParentRegion() != &forOp.getLoopBody())
    return;

  // Since we only need to peel the loop numStages-1 times, don't worry about
  // depends that are too far away
  if (stages < 0)
    return;

  if (auto arg = v.dyn_cast<BlockArgument>()) {
    // Arg #0 is the induction variable. `getArgNumber() - 1` would wrap to
    // UINT_MAX and index past the yield operands, so stop here.
    if (arg.getArgNumber() == 0)
      return;
    deps.insert(v);
    // Note: we have iv as the first arg, so the op idx is arg.getArgNumber()-1
    collectDeps(yieldOp->getOperand(arg.getArgNumber() - 1), stages - 1, deps);
    return;
  }

  // v might be in deps, but we still need to visit v.
  // This is because v might depend on values in previous iterations
  deps.insert(v);
  for (Value operand : v.getDefiningOp()->getOperands())
    collectDeps(operand, stages, deps);
}
2022-05-11 20:31:08 +08:00
/// A load instruction can be pipelined if:
2022-05-25 21:52:51 +08:00
/// - the load doesn't depend on any other loads (after loop peeling)
/// - (?) this load is not a loop-invariant value (we should run LICM before
/// this pass?)
2022-05-11 20:31:08 +08:00
LogicalResult LoopPipeliner::initialize() {
2022-05-14 22:04:36 +08:00
Block *loop = forOp.getBody();
2022-05-11 20:31:08 +08:00
2022-05-25 21:52:51 +08:00
// can we use forOp.walk(...) here?
SmallVector<triton::LoadOp, 2> allLoads;
for (Operation &op : *loop)
if (auto loadOp = dyn_cast<triton::LoadOp>(&op))
allLoads.push_back(loadOp);
2022-05-11 20:31:08 +08:00
2022-05-25 21:52:51 +08:00
// Early stop: no need to continue if there is no load in the loop.
if (allLoads.empty())
2022-05-11 20:31:08 +08:00
return failure();
2022-05-25 21:52:51 +08:00
// load => values that it depends on
DenseMap<Value, DenseSet<Value>> loadDeps;
for (triton::LoadOp loadOp : allLoads) {
DenseSet<Value> deps;
for (Value op : loadOp->getOperands())
collectDeps(op, numStages - 1, deps);
loadDeps[loadOp] = deps;
}
// for (triton::LoadOp loadOp : allLoads) {
// llvm::errs() << loadOp << " depends on: #" << loadDeps[loadOp].size() << " values\n";
// for (Value dep : loadDeps[loadOp])
// llvm::errs() << dep << "\n";
// llvm::errs() << "\n";
// }
// Don't pipeline loads that depend on other loads
// (Because if a load depends on another load, this load needs to wait on the
// other load in the prologue, which is against the point of the pipeline
// pass)
for (triton::LoadOp loadOp : allLoads) {
bool isCandiate = true;
for (triton::LoadOp other : allLoads) {
if (loadDeps[loadOp].contains(other)) {
isCandiate = false;
break;
}
2022-05-11 20:31:08 +08:00
}
// For now, we only pipeline loads that have one covert_layout (to smem) use
// TODO: lift this constraint in the future
if (isCandiate && loadOp.getResult().hasOneUse()) {
isCandiate = false;
Operation *use = *loadOp.getResult().getUsers().begin();
if (auto convertLayout = llvm::dyn_cast<triton::gpu::ConvertLayoutOp>(use)) {
if (auto tensorType = convertLayout.getResult().getType().dyn_cast<RankedTensorType>()) {
if (tensorType.getEncoding().isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
isCandiate = true;
loadsMapping[loadOp] = convertLayout;
}
}
}
} else
isCandiate = false;
2022-05-25 21:52:51 +08:00
if (isCandiate)
loads.insert(loadOp);
}
// we have some loads to pipeline
if (!loads.empty()) {
// update depArgs & depOps
for (Value loadOp : loads) {
for (Value dep : loadDeps[loadOp]) {
// TODO: we should record the stage that the value is depended on
if (auto arg = dep.dyn_cast<BlockArgument>())
depArgs.insert(arg);
else
depOps.insert(dep.getDefiningOp());
}
2022-05-11 20:31:08 +08:00
}
2022-05-25 21:52:51 +08:00
return success();
2022-05-11 20:31:08 +08:00
}
return failure();
}
void LoopPipeliner::emitPrologue() {
2022-05-26 13:57:01 +08:00
// llvm::errs() << "loads to pipeline...:\n";
// for (Value load : loads)
// llvm::errs() << load << "\n";
2022-05-11 20:31:08 +08:00
OpBuilder builder(forOp);
2022-05-13 21:32:35 +08:00
for (BlockArgument &arg : forOp.getRegionIterArgs()) {
OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg);
setValueMapping(arg, operand.get(), 0);
}
2022-05-14 22:04:36 +08:00
// prologue from [0, numStage-1)
2022-05-15 22:29:27 +08:00
Value iv = forOp.getLowerBound();
2022-05-13 21:32:35 +08:00
for (int stage = 0; stage < numStages - 1; ++stage) {
// special handling for induction variable as the increment is implicit
if (stage != 0)
iv = builder.create<arith::AddIOp>(iv.getLoc(), iv, forOp.getStep());
setValueMapping(forOp.getInductionVar(), iv, stage);
// special handling for loop condition as there is no condition in ForOp
Value loopCond = builder.create<arith::CmpIOp>(
iv.getLoc(), arith::CmpIPredicate::slt, iv, forOp.getUpperBound());
// rematerialize peeled values
SmallVector<Operation*> orderedDeps;
2022-05-25 21:52:51 +08:00
for (Operation &op : forOp.getLoopBody().front()) {
2022-05-13 21:32:35 +08:00
if (depOps.contains(&op))
orderedDeps.push_back(&op);
2022-05-25 21:52:51 +08:00
else if (loads.contains(op.getResult(0)))
orderedDeps.push_back(&op);
}
assert(depOps.size() + loads.size() == orderedDeps.size() &&
"depOps contains invalid values");
2022-05-13 21:32:35 +08:00
for (Operation *op : orderedDeps) {
2022-05-25 21:52:51 +08:00
Operation *newOp = nullptr;
if (loads.contains(op->getResult(0))) {
// load => copy async
// TODO: check if the hardware supports copyasync
if (auto loadOp = llvm::dyn_cast<triton::LoadOp>(op)) {
newOp = builder.create<triton::gpu::CopyAsyncOp>(
op->getLoc(), loadsMapping[loadOp].getType(),
2022-05-25 21:52:51 +08:00
loadOp.ptr(), loadOp.mask(), loadOp.other(),
loadOp.cache(), loadOp.evict(), loadOp.isVolatile()
);
} else
llvm_unreachable("This should be LoadOp");
} else
newOp = builder.clone(*op);
2022-05-26 13:14:41 +08:00
2022-05-13 21:32:35 +08:00
for (unsigned opIdx = 0; opIdx < op->getNumOperands(); ++opIdx) {
auto it = valueMapping.find(op->getOperand(opIdx));
if (it != valueMapping.end()) {
Value v = it->second[stage];
assert(v);
newOp->setOperand(opIdx, v);
} // else, op at opIdx is a loop-invariant value
}
2022-05-26 13:14:41 +08:00
// If this is a load/async_copy, we need to update the mask
if (llvm::isa<triton::LoadOp, triton::gpu::CopyAsyncOp>(newOp)) {
Value mask = newOp->getOperand(1);
// assert(I1 or TensorOf<[I1]>);
OpBuilder::InsertionGuard g(builder);
builder.setInsertionPoint(newOp);
Value splatCond = builder.create<triton::BroadcastOp>(mask.getLoc(),
mask.getType(),
loopCond);
Value newMask = builder.create<arith::AndIOp>(mask.getLoc(),
mask,
splatCond);
newOp->setOperand(1, newMask);
}
2022-05-25 21:52:51 +08:00
2022-05-13 21:32:35 +08:00
// update mapping of results
for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
Value originalResult = op->getResult(dstIdx);
// copy_async will update the value of its only use
if (loads.contains(originalResult))
originalResult = loadsMapping[originalResult];
setValueMapping(originalResult, newOp->getResult(dstIdx), stage);
2022-05-14 22:04:36 +08:00
// update mapping for loop-carried values (args)
2022-05-25 21:52:51 +08:00
for (OpOperand &operand : yieldOp->getOpOperands()) {
2022-05-14 22:04:36 +08:00
if (operand.get() == op->getResult(dstIdx))
setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
newOp->getResult(dstIdx), stage + 1);
}
2022-05-13 21:32:35 +08:00
}
}
}
2022-05-11 20:31:08 +08:00
}
2022-05-14 22:04:36 +08:00
/// Build the pipelined loop that replaces forOp (inserted right before it).
///
/// Order of the new loop-carried args:
///   (original args),
///   for each load result (after layout conversion) x:
///     (x at stage[0, numStages-1)),
///   (depArgs at stage numStages-1),
///   (iv at stage numStages-2; each iteration advances it by one step to get
///    the iv of the iteration being prefetched).
/// The body is a clone of the original body in which each pipelined load's
/// convert_layout result is replaced by its stage-0 argument. It is followed
/// by the ops that prefetch iteration iv + (numStages-1)*step.
scf::ForOp LoopPipeliner::createNewForOp() {
  OpBuilder builder(forOp);

  SmallVector<Value> newLoopArgs;
  // We need this to update operands for yield
  // original block arg => new arg's idx
  DenseMap<BlockArgument, size_t> depArgsIdx;
  for (auto v : forOp.getIterOperands())
    newLoopArgs.push_back(v);

  size_t loadIdx = newLoopArgs.size();
  for (Value loadOp : loads)
    for (int i = 0; i < numStages - 1; ++i)
      newLoopArgs.push_back(valueMapping[loadsMapping[loadOp]][i]);

  size_t depArgsBeginIdx = newLoopArgs.size();
  for (BlockArgument depArg : depArgs) {
    depArgsIdx[depArg] = newLoopArgs.size();
    newLoopArgs.push_back(valueMapping[depArg][numStages - 1]);
  }

  size_t nextIVIdx = newLoopArgs.size();
  newLoopArgs.push_back(valueMapping[forOp.getInductionVar()][numStages - 2]);
  for (size_t i = 0; i < newLoopArgs.size(); ++i)
    assert(newLoopArgs[i]);

  // 1. signature of the new ForOp
  auto newForOp = builder.create<scf::ForOp>(
      forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
      forOp.getStep(), newLoopArgs);

  // 2. body of the new ForOp
  builder.setInsertionPointToStart(newForOp.getBody());
  BlockAndValueMapping mapping;
  for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
    mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
  // Without this, cloned ops keep using the old loop's induction variable,
  // which dangles once forOp is erased.
  mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());

  // 2.1 clone the loop body, replace original args with args of the new ForOp
  for (Operation &op : forOp.getBody()->without_terminator()) {
    Operation *newOp = builder.clone(op, mapping);
    // update mapping of results
    for (unsigned dstIdx : llvm::seq(unsigned(0), op.getNumResults()))
      mapping.map(op.getResult(dstIdx), newOp->getResult(dstIdx));
  }

  // 3. replace loads with block args (from prologue)
  for (size_t idx = 0; idx < loads.size(); ++idx) {
    Value load = loads[idx];
    assert(load.hasOneUse() &&
           "we assume that this load has one use (ConvertLayout)");
    Value loadUse = load.getUsers().begin()->getResult(0);
    Value clonedUse = mapping.lookup(loadUse);
    Value prefetched =
        newForOp.getRegionIterArgs()[loadIdx + idx * (numStages - 1)];
    clonedUse.replaceAllUsesWith(prefetched);
    // A yield of the converted value must now forward the prefetched arg.
    mapping.map(loadUse, prefetched);
    // The synchronous load and its layout conversion are dead now; erase
    // them so the pipelined loop does not re-issue the load every iteration.
    Operation *clonedLoad = mapping.lookup(load).getDefiningOp();
    clonedUse.getDefiningOp()->erase();
    if (clonedLoad->use_empty())
      clonedLoad->erase();
  }

  // 4. prefetch the next iteration
  SmallVector<Operation *> orderedDeps;
  for (Operation &op : forOp.getLoopBody().front()) {
    if (depOps.contains(&op))
      orderedDeps.push_back(&op);
    // Guard getResult(0): the body also holds result-less ops, e.g. the
    // scf.yield terminator.
    else if (op.getNumResults() > 0 && loads.contains(op.getResult(0)))
      orderedDeps.push_back(&op);
  }
  assert(depOps.size() + loads.size() == orderedDeps.size() &&
         "depOps contains invalid values");

  BlockAndValueMapping nextMapping;
  DenseMap<BlockArgument, Value> depArgsMapping;
  for (BlockArgument arg : depArgs)
    nextMapping.map(arg, newForOp.getRegionIterArgs()[depArgsIdx.lookup(arg)]);

  // special handling for iv & loop condition
  Value nextIV = builder.create<arith::AddIOp>(
      newForOp.getInductionVar().getLoc(),
      newForOp.getRegionIterArgs()[nextIVIdx], newForOp.getStep());
  Value nextLoopCond =
      builder.create<arith::CmpIOp>(nextIV.getLoc(), arith::CmpIPredicate::slt,
                                    nextIV, newForOp.getUpperBound());
  // Ops rematerialized for the prefetched iteration must see its iv.
  nextMapping.map(forOp.getInductionVar(), nextIV);

  for (Operation *op : orderedDeps) {
    Operation *nextOp = nullptr;
    if (loads.contains(op->getResult(0))) {
      // update loading mask
      auto loadOp = llvm::cast<triton::LoadOp>(op);
      Value mask = loadOp.mask();
      Value splatCond = builder.create<triton::BroadcastOp>(
          mask.getLoc(), mask.getType(), nextLoopCond);
      Value newMask = builder.create<arith::AndIOp>(
          mask.getLoc(), splatCond, nextMapping.lookupOrDefault(mask));
      // if mask is defined outside the loop, don't update the map more than
      // once
      if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask)))
        nextMapping.map(mask, newMask);
      // TODO: more elegant way to do this?
      nextOp = builder.create<triton::gpu::CopyAsyncOp>(
          op->getLoc(), loadsMapping[op->getResult(0)].getType(),
          nextMapping.lookupOrDefault(loadOp.ptr()),
          nextMapping.lookupOrDefault(loadOp.mask()),
          nextMapping.lookupOrDefault(loadOp.other()), loadOp.cache(),
          loadOp.evict(), loadOp.isVolatile());
    } else
      nextOp = builder.clone(*op, nextMapping);

    // update mapping of results
    for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
      nextMapping.map(op->getResult(dstIdx), nextOp->getResult(dstIdx));
      // if this is a loop-carried value, update the mapping for yield
      for (OpOperand &operand : yieldOp->getOpOperands()) {
        if (operand.get() != op->getResult(dstIdx))
          continue;
        BlockArgument originArg =
            forOp.getRegionIterArgs()[operand.getOperandNumber()];
        // Only args the loads depend on get an extra slot in the new loop;
        // don't default-insert an index for the others.
        auto it = depArgsIdx.find(originArg);
        if (it == depArgsIdx.end())
          continue;
        BlockArgument newArg = newForOp.getRegionIterArgs()[it->second];
        depArgsMapping[newArg] = nextOp->getResult(dstIdx);
      }
    }
  }

  // Finally, the YieldOp, need to sync with the order of newLoopArgs.
  // lookupOrDefault keeps loop-invariant yielded values as they are.
  SmallVector<Value> yieldValues;
  for (Value v : yieldOp->getOperands())
    yieldValues.push_back(mapping.lookupOrDefault(v));
  // shift pipelined args by 1
  for (size_t idx = 0; idx < loads.size(); ++idx) {
    Value load = loads[idx];
    for (int stage = 1; stage < numStages - 1; ++stage)
      yieldValues.push_back(
          newForOp.getRegionIterArgs()[loadIdx + idx * (numStages - 1) + stage]);
    yieldValues.push_back(nextMapping.lookup(load));
  }
  for (size_t i = depArgsBeginIdx; i < nextIVIdx; ++i)
    yieldValues.push_back(
        depArgsMapping.lookup(newForOp.getRegionIterArgs()[i]));
  yieldValues.push_back(nextIV);

  builder.setInsertionPointToEnd(newForOp.getBody());
  builder.create<scf::YieldOp>(yieldOp->getLoc(), yieldValues);
  return newForOp;
}
2022-05-11 20:31:08 +08:00
// ref: mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
2022-05-11 16:13:53 +08:00
struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
PipelinePass() = default;
PipelinePass(int numStages) {
this->numStages = numStages;
}
2022-05-11 16:13:53 +08:00
void runOnOperation() override {
int numStages = this->numStages;
2022-05-11 20:31:08 +08:00
if (numStages <= 1)
return;
getOperation()->walk([&](scf::ForOp forOp) -> void {
LoopPipeliner pipeliner(forOp, numStages);
if (pipeliner.initialize().failed())
return;
2022-05-11 16:13:53 +08:00
2022-05-14 22:04:36 +08:00
pipeliner.emitPrologue();
2022-05-13 21:32:35 +08:00
2022-05-14 22:04:36 +08:00
scf::ForOp newForOp = pipeliner.createNewForOp();
2022-05-13 21:32:35 +08:00
2022-05-15 22:29:27 +08:00
// replace the original loop
for (unsigned i = 0; i < forOp->getNumResults(); ++i)
forOp->getResult(i).replaceAllUsesWith(newForOp->getResult(i));
forOp->erase();
2022-05-11 16:13:53 +08:00
});
}
};
} // anonymous namespace
/// Factory for the TritonGPU loop software-pipelining pass with the given
/// number of stages.
std::unique_ptr<Pass> mlir::createTritonGPUPipelinePass(int numStages) {
  std::unique_ptr<Pass> pass = std::make_unique<PipelinePass>(numStages);
  return pass;
}