2022-11-10 13:57:27 +08:00
|
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
2022-08-18 12:49:37 -07:00
|
|
|
#include "mlir/IR/BlockAndValueMapping.h"
|
2022-05-11 16:13:53 +08:00
|
|
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
|
|
|
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
|
|
|
|
2022-05-13 21:32:35 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
//
|
|
|
|
// This file implements loop software pipelining
|
2022-07-26 17:25:03 -07:00
|
|
|
// The implementation here is inspired by the pipeline pass in Triton (-v2.0)
|
2022-05-14 22:04:36 +08:00
|
|
|
// and SCF's LoopPipelining.
|
2022-05-13 21:32:35 +08:00
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-05-11 16:13:53 +08:00
|
|
|
using namespace mlir;
|
2022-11-10 13:57:27 +08:00
|
|
|
namespace ttg = triton::gpu;
|
2022-05-11 16:13:53 +08:00
|
|
|
|
|
|
|
#define GEN_PASS_CLASSES
|
|
|
|
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
|
|
|
|
|
2022-11-09 01:29:53 +08:00
|
|
|
/// Build a tensor type with the same shape and encoding as `v`'s type but
/// with i1 elements (used to materialize per-element load masks).
/// `v` must have a RankedTensorType.
static Type getI1SameShape(Value v) {
  auto srcTy = v.getType().cast<RankedTensorType>();
  auto boolTy = IntegerType::get(srcTy.getContext(), 1);
  return RankedTensorType::get(srcTy.getShape(), boolTy, srcTy.getEncoding());
}
|
|
|
|
|
2022-11-30 10:07:34 -08:00
|
|
|
#define int_attr(num) builder.getI64IntegerAttr(num)
|
|
|
|
|
2022-05-11 16:13:53 +08:00
|
|
|
namespace {
|
2022-11-10 13:57:27 +08:00
|
|
|
|
2022-05-11 20:31:08 +08:00
|
|
|
/// Rewrites an scf.for loop so that qualifying triton.load ops are turned
/// into multi-stage asynchronous copies into a shared-memory ring buffer
/// (software pipelining). Driven by PipelinePass below.
class LoopPipeliner {
  /// Cache forOp we are working on
  scf::ForOp forOp;

  /// Cache YieldOp for this forOp
  scf::YieldOp yieldOp;

  /// Loads to be pipelined
  SetVector<Value> loads;
  /// The value that each load will be mapped to (after layout conversion)
  DenseMap<Value, Value> loadsMapping;
  /// load => buffer
  DenseMap<Value, Value> loadsBuffer;
  /// load => buffer type (with shared layout after swizzling)
  DenseMap<Value, RankedTensorType> loadsBufferType;
  /// load => buffer at stage N
  DenseMap<Value, SmallVector<Value>> loadStageBuffer;
  /// load => after extract
  DenseMap<Value, Value> loadsExtract;
  /// Iteration counter used to compute the buffer slot that insert_slice_async
  /// writes into (taken modulo numStages inside the new loop)
  Value pipelineIterIdx;
  /// Iteration counter used to compute the buffer slot that extract_slice
  /// reads from (taken modulo numStages inside the new loop)
  Value loopIterIdx;

  /// Comments on numStages:
  ///   [0, numStages-1) are in the prologue
  ///   numStages-1 is appended after the loop body
  int numStages;

  /// value (in loop) => value at stage N
  DenseMap<Value, SmallVector<Value>> valueMapping;

  /// Block arguments that loads depend on
  DenseSet<BlockArgument> depArgs;

  /// Operations (inside the loop body) that loads depend on
  DenseSet<Operation *> depOps;

  /// collect values that v depends on and are defined inside the loop
  void collectDeps(Value v, int stages, DenseSet<Value> &deps);

  /// Record that `origin` materializes as `newValue` at pipeline `stage`
  void setValueMapping(Value origin, Value newValue, int stage);

  /// Look up `origin` at `stage`; falls back to `origin` for loop-invariant
  /// values that were never remapped
  Value lookupOrDefault(Value origin, int stage);

  /// Returns an empty buffer of size <numStages, ...>
  ttg::AllocTensorOp allocateEmptyBuffer(Operation *op, OpBuilder &builder);

public:
  LoopPipeliner(scf::ForOp forOp, int numStages)
      : forOp(forOp), numStages(numStages) {
    // cache yieldOp
    yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
  }

  /// Collect loads to pipeline. Return success if we can pipeline this loop
  LogicalResult initialize();

  /// Emit pipelined loads (before loop body)
  void emitPrologue();

  /// emit pipelined loads (after loop body)
  void emitEpilogue();

  /// create the new ForOp (add new args & insert prefetched ops)
  scf::ForOp createNewForOp();

  friend struct PipelinePass;
};
|
|
|
|
|
2022-05-13 21:32:35 +08:00
|
|
|
// helpers
|
|
|
|
/// Record that `origin` is materialized as `newValue` at pipeline `stage`.
/// Lazily allocates one slot per stage on the first mapping of `origin`.
void LoopPipeliner::setValueMapping(Value origin, Value newValue, int stage) {
  auto it = valueMapping.find(origin);
  if (it == valueMapping.end())
    it = valueMapping.insert({origin, SmallVector<Value>(numStages)}).first;
  it->second[stage] = newValue;
}
|
|
|
|
|
2022-09-06 23:31:13 +08:00
|
|
|
/// Return the stage-`stage` materialization of `origin`, or `origin` itself
/// when it was never remapped (i.e. it is loop-invariant).
Value LoopPipeliner::lookupOrDefault(Value origin, int stage) {
  auto it = valueMapping.find(origin);
  if (it == valueMapping.end())
    return origin;
  return it->second[stage];
}
|
|
|
|
|
2022-05-25 21:52:51 +08:00
|
|
|
/// Transitively collect into `deps` every value defined inside the loop that
/// `v` depends on, following loop-carried (yield -> iter-arg) edges for up to
/// `stages` prior iterations.
void LoopPipeliner::collectDeps(Value v, int stages, DenseSet<Value> &deps) {
  // Values defined outside the loop body are invariant; nothing to collect.
  if (v.getParentRegion() != &forOp.getLoopBody())
    return;

  // We only peel the loop numStages-1 times, so dependencies further back
  // through the loop-carried chain can be ignored.
  if (stages < 0)
    return;

  if (auto arg = v.dyn_cast<BlockArgument>()) {
    // Block arg 0 is the induction variable; skip it.
    // Iter-arg i (i > 0) is fed by yield operand i-1 of the previous
    // iteration, so recurse there with one fewer stage.
    if (arg.getArgNumber() == 0)
      return;
    deps.insert(v);
    collectDeps(yieldOp->getOperand(arg.getArgNumber() - 1), stages - 1, deps);
    return;
  }

  // A regular SSA value. Insert it even if it is already in `deps`, and keep
  // walking: v may additionally depend on values from previous iterations.
  deps.insert(v);
  for (Value operand : v.getDefiningOp()->getOperands())
    collectDeps(operand, stages, deps);
}
|
|
|
|
|
2022-11-10 13:57:27 +08:00
|
|
|
/// Allocate the shared-memory ring buffer for a pipelined load `op`.
/// The buffer type (computed in initialize()) prepends a numStages dimension,
/// e.g. (numStages==4): <32x64xbf16> -> <4x32x64xbf16>.
ttg::AllocTensorOp LoopPipeliner::allocateEmptyBuffer(Operation *op,
                                                      OpBuilder &builder) {
  Value mapped = loadsMapping[op->getResult(0)];
  if (!mapped.getType().isa<RankedTensorType>())
    llvm_unreachable("Async copy's return should be of RankedTensorType");
  return builder.create<ttg::AllocTensorOp>(mapped.getLoc(),
                                            loadsBufferType[op->getResult(0)]);
}
|
|
|
|
|
2022-05-11 20:31:08 +08:00
|
|
|
/// A load instruction can be pipelined if:
///   - the load doesn't depend on any other loads (after loop peeling)
///   - (?) this load is not a loop-invariant value (we should run LICM before
///     this pass?)
LogicalResult LoopPipeliner::initialize() {
  Block *loop = forOp.getBody();

  // Gather every triton.load directly in the loop body.
  // can we use forOp.walk(...) here?
  SmallVector<triton::LoadOp, 2> allLoads;
  for (Operation &op : *loop)
    if (auto loadOp = dyn_cast<triton::LoadOp>(&op))
      allLoads.push_back(loadOp);

  // Early stop: no need to continue if there is no load in the loop.
  if (allLoads.empty())
    return failure();

  // load => values that it depends on (within the last numStages-1 iterations)
  DenseMap<Value, DenseSet<Value>> loadDeps;
  for (triton::LoadOp loadOp : allLoads) {
    DenseSet<Value> deps;
    for (Value op : loadOp->getOperands())
      collectDeps(op, numStages - 1, deps);
    loadDeps[loadOp] = deps;
  }

  // Don't pipeline loads that depend on other loads
  // (Because if a load depends on another load, this load needs to wait on the
  //  other load in the prologue, which is against the point of the pipeline
  //  pass)
  for (triton::LoadOp loadOp : allLoads) {
    bool isCandidate = true;
    for (triton::LoadOp other : allLoads) {
      if (loadDeps[loadOp].contains(other)) {
        isCandidate = false;
        break;
      }
    }

    // We only pipeline loads whose single use is a convert_layout to a
    // dot-operand layout.
    // TODO: lift this constraint in the future
    if (isCandidate && loadOp.getResult().hasOneUse()) {
      // Assume not a candidate until the use chain below checks out.
      isCandidate = false;
      Operation *use = *loadOp.getResult().getUsers().begin();
      if (auto convertLayout = llvm::dyn_cast<ttg::ConvertLayoutOp>(use)) {
        if (auto tensorType = convertLayout.getResult()
                                  .getType()
                                  .dyn_cast<RankedTensorType>()) {
          if (auto dotOpEnc = tensorType.getEncoding()
                                  .dyn_cast<ttg::DotOperandEncodingAttr>()) {
            isCandidate = true;
            loadsMapping[loadOp] = convertLayout;
            // Buffer shape prepends the stage dimension:
            // <M x N x elt> -> <numStages x M x N x elt>.
            auto ty = loadOp.getType().cast<RankedTensorType>();
            SmallVector<int64_t> bufferShape(ty.getShape().begin(),
                                             ty.getShape().end());
            bufferShape.insert(bufferShape.begin(), numStages);
            // Swizzled shared-memory encoding derived from the dot-operand
            // consumer, so later extract_slice feeds the dot efficiently.
            auto sharedEnc = ttg::SharedEncodingAttr::get(
                ty.getContext(), dotOpEnc, ty.getShape(),
                triton::gpu::getOrder(ty.getEncoding()), ty.getElementType());
            loadsBufferType[loadOp] = RankedTensorType::get(
                bufferShape, ty.getElementType(), sharedEnc);
          }
        }
      }
    } else
      isCandidate = false;

    if (isCandidate)
      loads.insert(loadOp);
  }

  // We have some loads to pipeline
  if (!loads.empty()) {
    // Update depArgs & depOps
    for (Value loadOp : loads) {
      for (Value dep : loadDeps[loadOp]) {
        // TODO: we should record the stage that the value is depended on
        if (auto arg = dep.dyn_cast<BlockArgument>())
          depArgs.insert(arg);
        else
          depOps.insert(dep.getDefiningOp());
      }
    }
    return success();
  }

  return failure();
}
|
|
|
|
|
|
|
|
void LoopPipeliner::emitPrologue() {
|
2022-05-26 13:57:01 +08:00
|
|
|
// llvm::errs() << "loads to pipeline...:\n";
|
|
|
|
// for (Value load : loads)
|
|
|
|
// llvm::errs() << load << "\n";
|
|
|
|
|
2022-05-11 20:31:08 +08:00
|
|
|
OpBuilder builder(forOp);
|
2022-05-13 21:32:35 +08:00
|
|
|
for (BlockArgument &arg : forOp.getRegionIterArgs()) {
|
|
|
|
OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg);
|
|
|
|
setValueMapping(arg, operand.get(), 0);
|
|
|
|
}
|
|
|
|
|
2022-05-14 22:04:36 +08:00
|
|
|
// prologue from [0, numStage-1)
|
2022-05-15 22:29:27 +08:00
|
|
|
Value iv = forOp.getLowerBound();
|
2022-09-09 11:01:14 -07:00
|
|
|
pipelineIterIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
|
2022-05-13 21:32:35 +08:00
|
|
|
for (int stage = 0; stage < numStages - 1; ++stage) {
|
2022-11-30 10:07:34 -08:00
|
|
|
// Special handling for induction variable as the increment is implicit
|
2022-05-13 21:32:35 +08:00
|
|
|
if (stage != 0)
|
|
|
|
iv = builder.create<arith::AddIOp>(iv.getLoc(), iv, forOp.getStep());
|
|
|
|
setValueMapping(forOp.getInductionVar(), iv, stage);
|
|
|
|
|
2022-11-30 10:07:34 -08:00
|
|
|
// Special handling for loop condition as there is no condition in ForOp
|
2022-05-13 21:32:35 +08:00
|
|
|
Value loopCond = builder.create<arith::CmpIOp>(
|
2022-07-26 17:25:03 -07:00
|
|
|
iv.getLoc(), arith::CmpIPredicate::slt, iv, forOp.getUpperBound());
|
2022-05-13 21:32:35 +08:00
|
|
|
|
2022-11-30 10:07:34 -08:00
|
|
|
// Rematerialize peeled values
|
2022-07-26 17:25:03 -07:00
|
|
|
SmallVector<Operation *> orderedDeps;
|
2022-05-25 21:52:51 +08:00
|
|
|
for (Operation &op : forOp.getLoopBody().front()) {
|
2022-05-13 21:32:35 +08:00
|
|
|
if (depOps.contains(&op))
|
|
|
|
orderedDeps.push_back(&op);
|
2022-05-25 21:52:51 +08:00
|
|
|
else if (loads.contains(op.getResult(0)))
|
|
|
|
orderedDeps.push_back(&op);
|
|
|
|
}
|
|
|
|
assert(depOps.size() + loads.size() == orderedDeps.size() &&
|
|
|
|
"depOps contains invalid values");
|
2022-05-13 21:32:35 +08:00
|
|
|
for (Operation *op : orderedDeps) {
|
2022-05-25 21:52:51 +08:00
|
|
|
Operation *newOp = nullptr;
|
|
|
|
if (loads.contains(op->getResult(0))) {
|
2022-09-06 23:31:13 +08:00
|
|
|
// Allocate empty buffer
|
2022-09-09 11:01:14 -07:00
|
|
|
if (stage == 0) {
|
2022-09-06 23:31:13 +08:00
|
|
|
loadsBuffer[op->getResult(0)] = allocateEmptyBuffer(op, builder);
|
2022-09-09 11:01:14 -07:00
|
|
|
loadStageBuffer[op->getResult(0)] = {loadsBuffer[op->getResult(0)]};
|
|
|
|
}
|
2022-05-25 21:52:51 +08:00
|
|
|
// load => copy async
|
|
|
|
if (auto loadOp = llvm::dyn_cast<triton::LoadOp>(op)) {
|
2022-11-09 01:29:53 +08:00
|
|
|
Value mask = lookupOrDefault(loadOp.mask(), stage);
|
|
|
|
Value newMask;
|
|
|
|
if (mask) {
|
|
|
|
Value splatCond = builder.create<triton::SplatOp>(
|
|
|
|
mask.getLoc(), mask.getType(), loopCond);
|
|
|
|
newMask =
|
|
|
|
builder.create<arith::AndIOp>(mask.getLoc(), mask, splatCond);
|
|
|
|
} else {
|
|
|
|
newMask = builder.create<triton::SplatOp>(
|
|
|
|
loopCond.getLoc(), getI1SameShape(loadOp), loopCond);
|
|
|
|
}
|
|
|
|
// TODO: check if the hardware supports async copy
|
2022-09-09 11:01:14 -07:00
|
|
|
newOp = builder.create<triton::gpu::InsertSliceAsyncOp>(
|
2022-09-06 23:31:13 +08:00
|
|
|
op->getLoc(), loadsBuffer[loadOp].getType(),
|
2022-09-09 11:01:14 -07:00
|
|
|
lookupOrDefault(loadOp.ptr(), stage),
|
2022-11-09 01:29:53 +08:00
|
|
|
loadStageBuffer[loadOp][stage], pipelineIterIdx, newMask,
|
2022-09-06 23:31:13 +08:00
|
|
|
lookupOrDefault(loadOp.other(), stage), loadOp.cache(),
|
|
|
|
loadOp.evict(), loadOp.isVolatile(), /*axis*/ 0);
|
2022-09-09 11:01:14 -07:00
|
|
|
loadStageBuffer[loadOp].push_back(newOp->getResult(0));
|
2022-05-25 21:52:51 +08:00
|
|
|
} else
|
|
|
|
llvm_unreachable("This should be LoadOp");
|
2022-09-06 23:31:13 +08:00
|
|
|
} else {
|
2022-05-25 21:52:51 +08:00
|
|
|
newOp = builder.clone(*op);
|
2022-09-06 23:31:13 +08:00
|
|
|
// Update loop-carried uses
|
|
|
|
for (unsigned opIdx = 0; opIdx < op->getNumOperands(); ++opIdx) {
|
|
|
|
auto it = valueMapping.find(op->getOperand(opIdx));
|
|
|
|
if (it != valueMapping.end()) {
|
|
|
|
Value v = it->second[stage];
|
|
|
|
assert(v);
|
|
|
|
newOp->setOperand(opIdx, v);
|
|
|
|
} // else, op at opIdx is a loop-invariant value
|
|
|
|
}
|
2022-05-13 21:32:35 +08:00
|
|
|
}
|
|
|
|
|
2022-11-30 10:07:34 -08:00
|
|
|
// Update mapping of results
|
2022-05-13 21:32:35 +08:00
|
|
|
for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
|
2022-06-07 19:34:59 +08:00
|
|
|
Value originalResult = op->getResult(dstIdx);
|
|
|
|
// copy_async will update the value of its only use
|
2022-11-14 10:15:53 +08:00
|
|
|
// TODO: load should not be used in the preheader?
|
2022-09-09 11:01:14 -07:00
|
|
|
if (loads.contains(originalResult)) {
|
|
|
|
break;
|
|
|
|
// originalResult = loadsMapping[originalResult];
|
|
|
|
}
|
2022-06-07 19:34:59 +08:00
|
|
|
setValueMapping(originalResult, newOp->getResult(dstIdx), stage);
|
2022-05-14 22:04:36 +08:00
|
|
|
// update mapping for loop-carried values (args)
|
2022-05-25 21:52:51 +08:00
|
|
|
for (OpOperand &operand : yieldOp->getOpOperands()) {
|
2022-05-14 22:04:36 +08:00
|
|
|
if (operand.get() == op->getResult(dstIdx))
|
2022-07-26 17:25:03 -07:00
|
|
|
setValueMapping(
|
|
|
|
forOp.getRegionIterArgs()[operand.getOperandNumber()],
|
|
|
|
newOp->getResult(dstIdx), stage + 1);
|
2022-05-14 22:04:36 +08:00
|
|
|
}
|
2022-05-13 21:32:35 +08:00
|
|
|
}
|
2022-11-09 01:29:53 +08:00
|
|
|
} // for (Operation *op : orderedDeps)
|
2022-09-09 11:01:14 -07:00
|
|
|
|
|
|
|
pipelineIterIdx = builder.create<arith::AddIOp>(
|
|
|
|
iv.getLoc(), pipelineIterIdx,
|
|
|
|
builder.create<arith::ConstantIntOp>(iv.getLoc(), 1, 32));
|
|
|
|
} // for (int stage = 0; stage < numStages - 1; ++stage)
|
2022-09-06 23:31:13 +08:00
|
|
|
|
|
|
|
// async.wait & extract_slice
|
2022-11-10 13:57:27 +08:00
|
|
|
builder.create<ttg::AsyncWaitOp>(loads[0].getLoc(),
|
|
|
|
loads.size() * (numStages - 2));
|
2022-09-15 14:26:40 +08:00
|
|
|
loopIterIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
|
2022-09-09 11:01:14 -07:00
|
|
|
for (Value loadOp : loads) {
|
2022-11-06 22:59:03 -08:00
|
|
|
auto sliceType = loadsMapping[loadOp].getType().cast<RankedTensorType>();
|
2022-11-10 13:57:27 +08:00
|
|
|
sliceType =
|
|
|
|
RankedTensorType::get(sliceType.getShape(), sliceType.getElementType(),
|
|
|
|
loadsBufferType[loadOp].getEncoding());
|
2022-11-06 22:59:03 -08:00
|
|
|
Value extractSlice = builder.create<tensor::ExtractSliceOp>(
|
|
|
|
loadOp.getLoc(), sliceType, loadStageBuffer[loadOp][numStages - 1],
|
2022-11-30 10:07:34 -08:00
|
|
|
SmallVector<OpFoldResult>{int_attr(0), int_attr(0), int_attr(0)},
|
|
|
|
SmallVector<OpFoldResult>{int_attr(1),
|
|
|
|
int_attr(sliceType.getShape()[0]),
|
|
|
|
int_attr(sliceType.getShape()[1])},
|
|
|
|
SmallVector<OpFoldResult>{int_attr(1), int_attr(1), int_attr(1)});
|
2022-09-09 11:01:14 -07:00
|
|
|
loadsExtract[loadOp] = extractSlice;
|
2022-09-06 23:31:13 +08:00
|
|
|
}
|
2022-11-30 10:07:34 -08:00
|
|
|
// Bump up loopIterIdx, this is used for getting the correct slice for the
|
2022-10-27 22:09:06 -07:00
|
|
|
// *next* iteration
|
|
|
|
loopIterIdx = builder.create<arith::AddIOp>(
|
|
|
|
loopIterIdx.getLoc(), loopIterIdx,
|
|
|
|
builder.create<arith::ConstantIntOp>(loopIterIdx.getLoc(), 1, 32));
|
|
|
|
}
|
|
|
|
|
|
|
|
void LoopPipeliner::emitEpilogue() {
|
|
|
|
// If there's any outstanding async copies, we need to wait for them.
|
|
|
|
OpBuilder builder(forOp);
|
|
|
|
OpBuilder::InsertionGuard g(builder);
|
|
|
|
builder.setInsertionPointAfter(forOp);
|
2022-10-28 12:36:09 -07:00
|
|
|
builder.create<triton::gpu::AsyncWaitOp>(forOp.getLoc(), 0);
|
2022-05-11 20:31:08 +08:00
|
|
|
}
|
|
|
|
|
2022-05-14 22:04:36 +08:00
|
|
|
scf::ForOp LoopPipeliner::createNewForOp() {
|
|
|
|
OpBuilder builder(forOp);
|
|
|
|
|
2022-11-30 10:07:34 -08:00
|
|
|
// Order of new args:
|
2022-12-06 09:08:55 -08:00
|
|
|
// (original args)
|
|
|
|
// (insertSliceAsync buffer at stage numStages - 1) for each load
|
|
|
|
// (extracted tensor) for each load
|
|
|
|
// (depArgs at stage numStages - 1)
|
|
|
|
// (iv at stage numStages - 2)
|
2022-09-09 11:01:14 -07:00
|
|
|
// (pipeline iteration index)
|
|
|
|
// (loop iteration index)
|
2022-05-14 22:04:36 +08:00
|
|
|
SmallVector<Value> newLoopArgs;
|
2022-05-15 22:29:27 +08:00
|
|
|
// We need this to update operands for yield
|
|
|
|
// original block arg => new arg's idx
|
|
|
|
DenseMap<BlockArgument, size_t> depArgsIdx;
|
2022-05-14 22:04:36 +08:00
|
|
|
for (auto v : forOp.getIterOperands())
|
|
|
|
newLoopArgs.push_back(v);
|
2022-05-25 21:52:51 +08:00
|
|
|
|
2022-09-09 11:01:14 -07:00
|
|
|
size_t bufferIdx = newLoopArgs.size();
|
|
|
|
for (Value loadOp : loads)
|
|
|
|
newLoopArgs.push_back(loadStageBuffer[loadOp].back());
|
2022-05-25 21:52:51 +08:00
|
|
|
size_t loadIdx = newLoopArgs.size();
|
|
|
|
for (Value loadOp : loads)
|
2022-09-09 11:01:14 -07:00
|
|
|
newLoopArgs.push_back(loadsExtract[loadOp]);
|
2022-05-25 21:52:51 +08:00
|
|
|
|
2022-05-14 22:04:36 +08:00
|
|
|
size_t depArgsBeginIdx = newLoopArgs.size();
|
2022-05-15 22:29:27 +08:00
|
|
|
for (BlockArgument depArg : depArgs) {
|
|
|
|
depArgsIdx[depArg] = newLoopArgs.size();
|
2022-07-26 17:25:03 -07:00
|
|
|
newLoopArgs.push_back(valueMapping[depArg][numStages - 1]);
|
2022-05-15 22:29:27 +08:00
|
|
|
}
|
2022-05-25 21:52:51 +08:00
|
|
|
|
2022-05-14 22:04:36 +08:00
|
|
|
size_t nextIVIdx = newLoopArgs.size();
|
2022-07-26 17:25:03 -07:00
|
|
|
newLoopArgs.push_back(valueMapping[forOp.getInductionVar()][numStages - 2]);
|
2022-09-09 11:01:14 -07:00
|
|
|
newLoopArgs.push_back(pipelineIterIdx);
|
|
|
|
newLoopArgs.push_back(loopIterIdx);
|
2022-05-15 22:29:27 +08:00
|
|
|
|
|
|
|
for (size_t i = 0; i < newLoopArgs.size(); ++i)
|
|
|
|
assert(newLoopArgs[i]);
|
2022-05-14 22:04:36 +08:00
|
|
|
|
2022-05-25 21:52:51 +08:00
|
|
|
// 1. signature of the new ForOp
|
2022-07-26 17:25:03 -07:00
|
|
|
auto newForOp = builder.create<scf::ForOp>(
|
|
|
|
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
|
|
|
|
forOp.getStep(), newLoopArgs);
|
2022-05-14 22:04:36 +08:00
|
|
|
|
2022-05-25 21:52:51 +08:00
|
|
|
// 2. body of the new ForOp
|
2022-05-14 22:04:36 +08:00
|
|
|
builder.setInsertionPointToStart(newForOp.getBody());
|
|
|
|
BlockAndValueMapping mapping;
|
|
|
|
for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
|
|
|
|
mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
|
2022-12-06 09:08:55 -08:00
|
|
|
mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
|
2022-05-25 21:52:51 +08:00
|
|
|
|
2022-06-07 19:34:59 +08:00
|
|
|
// 2.1 clone the loop body, replace original args with args of the new ForOp
|
2022-08-19 01:31:57 +08:00
|
|
|
// Insert async wait if necessary.
|
2022-05-14 22:04:36 +08:00
|
|
|
for (Operation &op : forOp.getBody()->without_terminator()) {
|
|
|
|
Operation *newOp = builder.clone(op, mapping);
|
|
|
|
// update mapping of results
|
|
|
|
for (unsigned dstIdx : llvm::seq(unsigned(0), op.getNumResults()))
|
|
|
|
mapping.map(op.getResult(dstIdx), newOp->getResult(dstIdx));
|
|
|
|
}
|
2022-05-25 21:52:51 +08:00
|
|
|
|
2022-05-26 13:14:41 +08:00
|
|
|
// 3. replace loads with block args (from prologue)
|
2022-05-25 21:52:51 +08:00
|
|
|
for (size_t idx = 0; idx < loads.size(); ++idx) {
|
|
|
|
Value load = loads[idx];
|
2022-07-26 17:25:03 -07:00
|
|
|
assert(load.hasOneUse() &&
|
|
|
|
"we assume that this load has one use (ConvertLayout)");
|
2022-06-07 19:34:59 +08:00
|
|
|
Value loadUse = load.getUsers().begin()->getResult(0);
|
|
|
|
mapping.lookup(loadUse).replaceAllUsesWith(
|
2022-09-09 11:01:14 -07:00
|
|
|
newForOp.getRegionIterArgs()[loadIdx + idx]);
|
2022-08-18 12:49:37 -07:00
|
|
|
// delete old load and layout conversion
|
|
|
|
mapping.lookup(loadUse).getDefiningOp()->erase();
|
|
|
|
mapping.lookup(load).getDefiningOp()->erase();
|
2022-05-25 21:52:51 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
// 4. prefetch the next iteration
|
2022-07-26 17:25:03 -07:00
|
|
|
SmallVector<Operation *> orderedDeps;
|
2022-05-25 21:52:51 +08:00
|
|
|
for (Operation &op : forOp.getLoopBody().front()) {
|
2022-05-14 22:04:36 +08:00
|
|
|
if (depOps.contains(&op))
|
|
|
|
orderedDeps.push_back(&op);
|
2022-05-25 21:52:51 +08:00
|
|
|
else if (loads.contains(op.getResult(0)))
|
|
|
|
orderedDeps.push_back(&op);
|
|
|
|
}
|
|
|
|
assert(depOps.size() + loads.size() == orderedDeps.size() &&
|
|
|
|
"depOps contains invalid values");
|
2022-05-14 22:04:36 +08:00
|
|
|
BlockAndValueMapping nextMapping;
|
2022-05-15 22:29:27 +08:00
|
|
|
DenseMap<BlockArgument, Value> depArgsMapping;
|
2022-05-14 22:04:36 +08:00
|
|
|
size_t argIdx = 0;
|
|
|
|
for (BlockArgument arg : depArgs) {
|
2022-07-26 17:25:03 -07:00
|
|
|
nextMapping.map(arg,
|
|
|
|
newForOp.getRegionIterArgs()[argIdx + depArgsBeginIdx]);
|
2022-05-14 22:04:36 +08:00
|
|
|
++argIdx;
|
|
|
|
}
|
2022-11-30 10:07:34 -08:00
|
|
|
// Special handling for iv & loop condition
|
2022-07-26 17:25:03 -07:00
|
|
|
Value nextIV = builder.create<arith::AddIOp>(
|
|
|
|
newForOp.getInductionVar().getLoc(),
|
|
|
|
newForOp.getRegionIterArgs()[nextIVIdx], newForOp.getStep());
|
|
|
|
Value nextLoopCond =
|
|
|
|
builder.create<arith::CmpIOp>(nextIV.getLoc(), arith::CmpIPredicate::slt,
|
|
|
|
nextIV, newForOp.getUpperBound());
|
2022-12-06 09:08:55 -08:00
|
|
|
nextMapping.map(forOp.getInductionVar(), nextIV);
|
2022-09-06 23:31:13 +08:00
|
|
|
|
2022-11-30 10:07:34 -08:00
|
|
|
// Slice index
|
2022-09-09 11:01:14 -07:00
|
|
|
SmallVector<Value> nextBuffers;
|
2022-09-06 23:31:13 +08:00
|
|
|
SmallVector<Value> extractSlices;
|
2022-09-09 11:01:14 -07:00
|
|
|
|
|
|
|
pipelineIterIdx = newForOp.getRegionIterArgs()[nextIVIdx + 1];
|
|
|
|
Value insertSliceIndex = builder.create<arith::RemSIOp>(
|
|
|
|
nextIV.getLoc(), pipelineIterIdx,
|
|
|
|
builder.create<arith::ConstantIntOp>(nextIV.getLoc(), numStages, 32));
|
|
|
|
loopIterIdx = newForOp.getRegionIterArgs()[nextIVIdx + 2];
|
|
|
|
Value extractSliceIndex = builder.create<arith::RemSIOp>(
|
|
|
|
nextIV.getLoc(), loopIterIdx,
|
2022-09-06 23:31:13 +08:00
|
|
|
builder.create<arith::ConstantIntOp>(nextIV.getLoc(), numStages, 32));
|
2022-11-06 22:59:03 -08:00
|
|
|
extractSliceIndex = builder.create<arith::IndexCastOp>(
|
|
|
|
extractSliceIndex.getLoc(), builder.getIndexType(), extractSliceIndex);
|
|
|
|
|
2022-05-14 22:04:36 +08:00
|
|
|
for (Operation *op : orderedDeps) {
|
2022-05-25 21:52:51 +08:00
|
|
|
Operation *nextOp = nullptr;
|
2022-11-30 10:07:34 -08:00
|
|
|
// Update loading mask
|
2022-05-25 21:52:51 +08:00
|
|
|
if (loads.contains(op->getResult(0))) {
|
2022-05-14 22:04:36 +08:00
|
|
|
auto loadOp = llvm::cast<triton::LoadOp>(op);
|
|
|
|
Value mask = loadOp.mask();
|
2022-11-09 01:29:53 +08:00
|
|
|
Value newMask;
|
2022-08-22 22:00:17 -07:00
|
|
|
if (mask) {
|
|
|
|
Value splatCond = builder.create<triton::SplatOp>(
|
|
|
|
mask.getLoc(), mask.getType(), nextLoopCond);
|
2022-11-09 01:29:53 +08:00
|
|
|
newMask = builder.create<arith::AndIOp>(
|
2022-08-22 22:00:17 -07:00
|
|
|
mask.getLoc(), splatCond, nextMapping.lookupOrDefault(mask));
|
2022-11-30 10:07:34 -08:00
|
|
|
// If mask is defined outside the loop, don't update the map more than
|
2022-08-22 22:00:17 -07:00
|
|
|
// once
|
|
|
|
if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask)))
|
|
|
|
nextMapping.map(mask, newMask);
|
2022-11-09 01:29:53 +08:00
|
|
|
newMask = nextMapping.lookupOrDefault(loadOp.mask());
|
|
|
|
} else
|
|
|
|
newMask = builder.create<triton::SplatOp>(
|
|
|
|
loadOp.getLoc(), getI1SameShape(loadOp), nextLoopCond);
|
2022-09-06 23:31:13 +08:00
|
|
|
Value insertAsyncOp = builder.create<triton::gpu::InsertSliceAsyncOp>(
|
|
|
|
op->getLoc(), loadsBuffer[loadOp].getType(),
|
2022-09-09 11:01:14 -07:00
|
|
|
nextMapping.lookupOrDefault(loadOp.ptr()),
|
|
|
|
newForOp.getRegionIterArgs()[bufferIdx + nextBuffers.size()],
|
2022-11-09 01:29:53 +08:00
|
|
|
insertSliceIndex, newMask,
|
2022-07-26 17:25:03 -07:00
|
|
|
nextMapping.lookupOrDefault(loadOp.other()), loadOp.cache(),
|
2022-09-06 23:31:13 +08:00
|
|
|
loadOp.evict(), loadOp.isVolatile(), /*axis*/ 0);
|
2022-09-09 11:01:14 -07:00
|
|
|
nextBuffers.push_back(insertAsyncOp);
|
2022-11-06 22:59:03 -08:00
|
|
|
auto sliceType = loadsMapping[loadOp].getType().cast<RankedTensorType>();
|
2022-11-10 13:57:27 +08:00
|
|
|
sliceType = RankedTensorType::get(sliceType.getShape(),
|
|
|
|
sliceType.getElementType(),
|
|
|
|
loadsBufferType[loadOp].getEncoding());
|
2022-11-06 22:59:03 -08:00
|
|
|
nextOp = builder.create<tensor::ExtractSliceOp>(
|
|
|
|
op->getLoc(), sliceType, insertAsyncOp,
|
2022-11-30 10:07:34 -08:00
|
|
|
SmallVector<OpFoldResult>{extractSliceIndex, int_attr(0),
|
|
|
|
int_attr(0)},
|
|
|
|
SmallVector<OpFoldResult>{int_attr(1),
|
|
|
|
int_attr(sliceType.getShape()[0]),
|
|
|
|
int_attr(sliceType.getShape()[1])},
|
|
|
|
SmallVector<OpFoldResult>{int_attr(1), int_attr(1), int_attr(1)});
|
2022-09-06 23:31:13 +08:00
|
|
|
extractSlices.push_back(nextOp->getResult(0));
|
2022-07-26 17:25:03 -07:00
|
|
|
} else
|
2022-05-25 21:52:51 +08:00
|
|
|
nextOp = builder.clone(*op, nextMapping);
|
2022-11-30 10:07:34 -08:00
|
|
|
// Update mapping of results
|
2022-05-15 22:29:27 +08:00
|
|
|
for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
|
2022-05-14 22:04:36 +08:00
|
|
|
nextMapping.map(op->getResult(dstIdx), nextOp->getResult(dstIdx));
|
2022-11-30 10:07:34 -08:00
|
|
|
// If this is a loop-carried value, update the mapping for yield
|
2022-05-15 22:29:27 +08:00
|
|
|
auto originYield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
|
|
|
|
for (OpOperand &operand : originYield->getOpOperands()) {
|
|
|
|
if (operand.get() == op->getResult(dstIdx)) {
|
|
|
|
size_t originIdx = operand.getOperandNumber();
|
|
|
|
size_t newArgIdx = depArgsIdx[forOp.getRegionIterArgs()[originIdx]];
|
|
|
|
BlockArgument newArg = newForOp.getRegionIterArgs()[newArgIdx];
|
|
|
|
depArgsMapping[newArg] = nextOp->getResult(dstIdx);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
2022-05-14 22:04:36 +08:00
|
|
|
}
|
|
|
|
|
2022-11-10 13:57:27 +08:00
|
|
|
{
|
|
|
|
OpBuilder::InsertionGuard guard(builder);
|
|
|
|
for (Operation &op : *newForOp.getBody()) {
|
|
|
|
if (auto dotOp = llvm::dyn_cast<triton::DotOp>(&op)) {
|
|
|
|
builder.setInsertionPoint(&op);
|
|
|
|
auto dotType = dotOp.getType().cast<RankedTensorType>();
|
|
|
|
Value a = dotOp.a();
|
|
|
|
Value b = dotOp.b();
|
|
|
|
auto layoutCast = [&](Value dotOperand, int opIdx) -> Value {
|
|
|
|
auto tensorType = dotOperand.getType().cast<RankedTensorType>();
|
|
|
|
if (!tensorType.getEncoding().isa<ttg::DotOperandEncodingAttr>()) {
|
|
|
|
auto newEncoding = ttg::DotOperandEncodingAttr::get(
|
|
|
|
tensorType.getContext(), opIdx, dotType.getEncoding());
|
|
|
|
auto newType =
|
|
|
|
RankedTensorType::get(tensorType.getShape(),
|
|
|
|
tensorType.getElementType(), newEncoding);
|
|
|
|
return builder.create<ttg::ConvertLayoutOp>(dotOperand.getLoc(),
|
|
|
|
newType, dotOperand);
|
|
|
|
}
|
|
|
|
return dotOperand;
|
|
|
|
};
|
|
|
|
a = layoutCast(a, 0);
|
|
|
|
b = layoutCast(b, 1);
|
|
|
|
dotOp->setOperand(0, a);
|
|
|
|
dotOp->setOperand(1, b);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2022-09-06 23:31:13 +08:00
|
|
|
// async.wait & extract_slice
|
2022-11-10 13:57:27 +08:00
|
|
|
Operation *asyncWait = builder.create<ttg::AsyncWaitOp>(
|
2022-09-06 23:31:13 +08:00
|
|
|
loads[0].getLoc(), loads.size() * (numStages - 2));
|
|
|
|
for (auto it = extractSlices.rbegin(); it != extractSlices.rend(); ++it) {
|
|
|
|
// move extract_slice after asyncWait
|
|
|
|
it->getDefiningOp()->moveAfter(asyncWait);
|
|
|
|
}
|
|
|
|
|
2022-11-30 10:07:34 -08:00
|
|
|
// Bump iteration count
|
2022-09-09 11:01:14 -07:00
|
|
|
pipelineIterIdx = builder.create<arith::AddIOp>(
|
|
|
|
nextIV.getLoc(), pipelineIterIdx,
|
|
|
|
builder.create<arith::ConstantIntOp>(nextIV.getLoc(), 1, 32));
|
|
|
|
loopIterIdx = builder.create<arith::AddIOp>(
|
|
|
|
nextIV.getLoc(), loopIterIdx,
|
|
|
|
builder.create<arith::ConstantIntOp>(nextIV.getLoc(), 1, 32));
|
|
|
|
|
2022-05-14 22:04:36 +08:00
|
|
|
// Finally, the YieldOp, need to sync with the order of newLoopArgs
|
|
|
|
SmallVector<Value> yieldValues;
|
|
|
|
for (Value v : forOp.getBody()->getTerminator()->getOperands())
|
|
|
|
yieldValues.push_back(mapping.lookup(v));
|
2022-09-09 11:01:14 -07:00
|
|
|
for (Value nextBuffer : nextBuffers)
|
|
|
|
yieldValues.push_back(nextBuffer);
|
|
|
|
for (Value nextSlice : extractSlices)
|
|
|
|
yieldValues.push_back(nextSlice);
|
2022-05-25 21:52:51 +08:00
|
|
|
|
2022-12-06 09:08:55 -08:00
|
|
|
for (size_t i = depArgsBeginIdx; i < nextIVIdx; ++i) {
|
|
|
|
auto arg = newForOp.getRegionIterArgs()[i];
|
|
|
|
assert(depArgsMapping.count(arg) && "Missing loop-carried value");
|
|
|
|
yieldValues.push_back(depArgsMapping[arg]);
|
|
|
|
}
|
2022-05-14 22:04:36 +08:00
|
|
|
yieldValues.push_back(nextIV);
|
2022-09-09 11:01:14 -07:00
|
|
|
yieldValues.push_back(pipelineIterIdx);
|
|
|
|
yieldValues.push_back(loopIterIdx);
|
2022-08-18 12:49:37 -07:00
|
|
|
|
2022-05-15 22:29:27 +08:00
|
|
|
builder.setInsertionPointToEnd(newForOp.getBody());
|
2022-10-28 12:36:09 -07:00
|
|
|
builder.create<scf::YieldOp>(forOp.getBody()->getTerminator()->getLoc(),
|
|
|
|
yieldValues);
|
2022-05-14 22:04:36 +08:00
|
|
|
return newForOp;
|
|
|
|
}
|
|
|
|
|
2022-05-11 20:31:08 +08:00
|
|
|
// ref: mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
/// Pass driver: runs LoopPipeliner over every scf.for in the operation.
struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
  PipelinePass() = default;
  PipelinePass(int numStages) { this->numStages = numStages; }

  void runOnOperation() override {
    int numStages = this->numStages;

    // With a single stage there is nothing to overlap; bail out.
    if (numStages <= 1)
      return;

    getOperation()->walk([&](scf::ForOp forOp) -> void {
      LoopPipeliner pipeliner(forOp, numStages);

      // Skip loops with no pipelinable loads.
      if (pipeliner.initialize().failed())
        return;

      pipeliner.emitPrologue();

      scf::ForOp newForOp = pipeliner.createNewForOp();

      pipeliner.emitEpilogue();

      // replace the original loop
      for (unsigned i = 0; i < forOp->getNumResults(); ++i)
        forOp->getResult(i).replaceAllUsesWith(newForOp->getResult(i));
      forOp->erase();
    });
  }
};
|
|
|
|
} // anonymous namespace
|
|
|
|
|
2022-05-23 12:47:55 +08:00
|
|
|
/// Public factory for the TritonGPU software-pipelining pass.
/// `numStages` is the pipeline depth; values <= 1 make the pass a no-op.
std::unique_ptr<Pass> mlir::createTritonGPUPipelinePass(int numStages) {
  return std::make_unique<PipelinePass>(numStages);
}
|