[CI] run clang-format (#24)

This commit is contained in:
Philippe Tillet
2022-07-26 17:25:03 -07:00
committed by GitHub
parent 25357083e6
commit 6d62d88d4f
62 changed files with 13673 additions and 11367 deletions

View File

@@ -6,12 +6,11 @@
//===----------------------------------------------------------------------===//
//
// This file implements loop software pipelining
// The implementation here is inspired by the pipeline pass in Triton (-v2.0)
// The implementation here is inspired by the pipeline pass in Triton (-v2.0)
// and SCF's LoopPipelining.
//
//===----------------------------------------------------------------------===//
using namespace mlir;
#define GEN_PASS_CLASSES
@@ -41,14 +40,15 @@ class LoopPipeliner {
/// Block arguments that loads depend on
DenseSet<BlockArgument> depArgs;
/// Operations (inside the loop body) that loads depend on
DenseSet<Operation*> depOps;
DenseSet<Operation *> depOps;
/// collect values that v depends on and are defined inside the loop
void collectDeps(Value v, int stages, DenseSet<Value> &deps);
void setValueMapping(Value origin, Value newValue, int stage);
public:
LoopPipeliner(scf::ForOp forOp, int numStages)
LoopPipeliner(scf::ForOp forOp, int numStages)
: forOp(forOp), numStages(numStages) {
// cache yieldOp
yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
@@ -86,7 +86,7 @@ void LoopPipeliner::collectDeps(Value v, int stages, DenseSet<Value> &deps) {
if (auto arg = v.dyn_cast<BlockArgument>()) {
deps.insert(v);
// Note: we have iv as the first arg, so the op idx is arg.getArgNumber()-1
collectDeps(yieldOp->getOperand(arg.getArgNumber() - 1), stages-1, deps);
collectDeps(yieldOp->getOperand(arg.getArgNumber() - 1), stages - 1, deps);
} else { // value
// v might be in deps, but we still need to visit v.
// This is because v might depends on value in previous iterations
@@ -123,8 +123,8 @@ LogicalResult LoopPipeliner::initialize() {
}
// for (triton::LoadOp loadOp : allLoads) {
// llvm::errs() << loadOp << " depends on: #" << loadDeps[loadOp].size() << " values\n";
// for (Value dep : loadDeps[loadOp])
// llvm::errs() << loadOp << " depends on: #" << loadDeps[loadOp].size() <<
// " values\n"; for (Value dep : loadDeps[loadOp])
// llvm::errs() << dep << "\n";
// llvm::errs() << "\n";
// }
@@ -147,9 +147,13 @@ LogicalResult LoopPipeliner::initialize() {
if (isCandiate && loadOp.getResult().hasOneUse()) {
isCandiate = false;
Operation *use = *loadOp.getResult().getUsers().begin();
if (auto convertLayout = llvm::dyn_cast<triton::gpu::ConvertLayoutOp>(use)) {
if (auto tensorType = convertLayout.getResult().getType().dyn_cast<RankedTensorType>()) {
if (tensorType.getEncoding().isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
if (auto convertLayout =
llvm::dyn_cast<triton::gpu::ConvertLayoutOp>(use)) {
if (auto tensorType = convertLayout.getResult()
.getType()
.dyn_cast<RankedTensorType>()) {
if (tensorType.getEncoding()
.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
isCandiate = true;
loadsMapping[loadOp] = convertLayout;
}
@@ -162,7 +166,6 @@ LogicalResult LoopPipeliner::initialize() {
loads.insert(loadOp);
}
// we have some loads to pipeline
if (!loads.empty()) {
// update depArgs & depOps
@@ -202,10 +205,10 @@ void LoopPipeliner::emitPrologue() {
// special handling for loop condition as there is no condition in ForOp
Value loopCond = builder.create<arith::CmpIOp>(
iv.getLoc(), arith::CmpIPredicate::slt, iv, forOp.getUpperBound());
iv.getLoc(), arith::CmpIPredicate::slt, iv, forOp.getUpperBound());
// rematerialize peeled values
SmallVector<Operation*> orderedDeps;
SmallVector<Operation *> orderedDeps;
for (Operation &op : forOp.getLoopBody().front()) {
if (depOps.contains(&op))
orderedDeps.push_back(&op);
@@ -221,10 +224,9 @@ void LoopPipeliner::emitPrologue() {
// TODO: check if the hardware supports copyasync
if (auto loadOp = llvm::dyn_cast<triton::LoadOp>(op)) {
newOp = builder.create<triton::gpu::CopyAsyncOp>(
op->getLoc(), loadsMapping[loadOp].getType(),
loadOp.ptr(), loadOp.mask(), loadOp.other(),
loadOp.cache(), loadOp.evict(), loadOp.isVolatile()
);
op->getLoc(), loadsMapping[loadOp].getType(), loadOp.ptr(),
loadOp.mask(), loadOp.other(), loadOp.cache(), loadOp.evict(),
loadOp.isVolatile());
} else
llvm_unreachable("This should be LoadOp");
} else
@@ -245,12 +247,10 @@ void LoopPipeliner::emitPrologue() {
// assert(I1 or TensorOf<[I1]>);
OpBuilder::InsertionGuard g(builder);
builder.setInsertionPoint(newOp);
Value splatCond = builder.create<triton::BroadcastOp>(mask.getLoc(),
mask.getType(),
loopCond);
Value newMask = builder.create<arith::AndIOp>(mask.getLoc(),
mask,
splatCond);
Value splatCond = builder.create<triton::BroadcastOp>(
mask.getLoc(), mask.getType(), loopCond);
Value newMask =
builder.create<arith::AndIOp>(mask.getLoc(), mask, splatCond);
newOp->setOperand(1, newMask);
}
@@ -264,8 +264,9 @@ void LoopPipeliner::emitPrologue() {
// update mapping for loop-carried values (args)
for (OpOperand &operand : yieldOp->getOpOperands()) {
if (operand.get() == op->getResult(dstIdx))
setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
newOp->getResult(dstIdx), stage + 1);
setValueMapping(
forOp.getRegionIterArgs()[operand.getOperandNumber()],
newOp->getResult(dstIdx), stage + 1);
}
}
}
@@ -296,21 +297,19 @@ scf::ForOp LoopPipeliner::createNewForOp() {
size_t depArgsBeginIdx = newLoopArgs.size();
for (BlockArgument depArg : depArgs) {
depArgsIdx[depArg] = newLoopArgs.size();
newLoopArgs.push_back(valueMapping[depArg][numStages-1]);
newLoopArgs.push_back(valueMapping[depArg][numStages - 1]);
}
size_t nextIVIdx = newLoopArgs.size();
newLoopArgs.push_back(valueMapping[forOp.getInductionVar()][numStages-2]);
newLoopArgs.push_back(valueMapping[forOp.getInductionVar()][numStages - 2]);
for (size_t i = 0; i < newLoopArgs.size(); ++i)
assert(newLoopArgs[i]);
// 1. signature of the new ForOp
auto newForOp = builder.create<scf::ForOp>(forOp.getLoc(),
forOp.getLowerBound(),
forOp.getUpperBound(),
forOp.getStep(),
newLoopArgs);
auto newForOp = builder.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
forOp.getStep(), newLoopArgs);
// 2. body of the new ForOp
builder.setInsertionPointToStart(newForOp.getBody());
@@ -329,15 +328,15 @@ scf::ForOp LoopPipeliner::createNewForOp() {
// 3. replace loads with block args (from prologue)
for (size_t idx = 0; idx < loads.size(); ++idx) {
Value load = loads[idx];
assert(load.hasOneUse() && "we assume that this load has one use (ConvertLayout)");
assert(load.hasOneUse() &&
"we assume that this load has one use (ConvertLayout)");
Value loadUse = load.getUsers().begin()->getResult(0);
mapping.lookup(loadUse).replaceAllUsesWith(
newForOp.getRegionIterArgs()[loadIdx + idx*(numStages-1)]);
newForOp.getRegionIterArgs()[loadIdx + idx * (numStages - 1)]);
}
// 4. prefetch the next iteration
SmallVector<Operation*> orderedDeps;
SmallVector<Operation *> orderedDeps;
for (Operation &op : forOp.getLoopBody().front()) {
if (depOps.contains(&op))
orderedDeps.push_back(&op);
@@ -350,41 +349,39 @@ scf::ForOp LoopPipeliner::createNewForOp() {
DenseMap<BlockArgument, Value> depArgsMapping;
size_t argIdx = 0;
for (BlockArgument arg : depArgs) {
nextMapping.map(arg, newForOp.getRegionIterArgs()[argIdx + depArgsBeginIdx]);
nextMapping.map(arg,
newForOp.getRegionIterArgs()[argIdx + depArgsBeginIdx]);
++argIdx;
}
// special handling for iv & loop condition
Value nextIV = builder.create<arith::AddIOp>(newForOp.getInductionVar().getLoc(),
newForOp.getRegionIterArgs()[nextIVIdx],
newForOp.getStep());
Value nextLoopCond = builder.create<arith::CmpIOp>(
nextIV.getLoc(), arith::CmpIPredicate::slt,
nextIV, newForOp.getUpperBound());
Value nextIV = builder.create<arith::AddIOp>(
newForOp.getInductionVar().getLoc(),
newForOp.getRegionIterArgs()[nextIVIdx], newForOp.getStep());
Value nextLoopCond =
builder.create<arith::CmpIOp>(nextIV.getLoc(), arith::CmpIPredicate::slt,
nextIV, newForOp.getUpperBound());
for (Operation *op : orderedDeps) {
Operation *nextOp = nullptr;
// update loading mask
if (loads.contains(op->getResult(0))) {
auto loadOp = llvm::cast<triton::LoadOp>(op);
Value mask = loadOp.mask();
Value splatCond = builder.create<triton::BroadcastOp>(mask.getLoc(),
mask.getType(),
nextLoopCond);
Value newMask = builder.create<arith::AndIOp>(mask.getLoc(),
splatCond,
nextMapping.lookupOrDefault(mask));
// if mask is defined outside the loop, don't update the map more than once
Value splatCond = builder.create<triton::BroadcastOp>(
mask.getLoc(), mask.getType(), nextLoopCond);
Value newMask = builder.create<arith::AndIOp>(
mask.getLoc(), splatCond, nextMapping.lookupOrDefault(mask));
// if mask is defined outside the loop, don't update the map more than
// once
if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask)))
nextMapping.map(mask, newMask);
// TODO: more elegant way to do this?
nextOp = builder.create<triton::gpu::CopyAsyncOp>(
op->getLoc(), loadsMapping[op->getResult(0)].getType(),
nextMapping.lookupOrDefault(loadOp.ptr()),
nextMapping.lookupOrDefault(loadOp.mask()),
nextMapping.lookupOrDefault(loadOp.other()),
loadOp.cache(), loadOp.evict(), loadOp.isVolatile()
);
}
else
op->getLoc(), loadsMapping[op->getResult(0)].getType(),
nextMapping.lookupOrDefault(loadOp.ptr()),
nextMapping.lookupOrDefault(loadOp.mask()),
nextMapping.lookupOrDefault(loadOp.other()), loadOp.cache(),
loadOp.evict(), loadOp.isVolatile());
} else
nextOp = builder.clone(*op, nextMapping);
// llvm::errs() << "epilogue cloning...: " << *op << "\n";
// update mapping of results
@@ -411,15 +408,16 @@ scf::ForOp LoopPipeliner::createNewForOp() {
for (size_t idx = 0; idx < loads.size(); ++idx) {
Value load = loads[idx];
for (int stage = 1; stage < numStages - 1; ++stage) {
yieldValues.push_back(newForOp.getRegionIterArgs()[
loadIdx + idx*(numStages-1) + stage
]);
yieldValues.push_back(
newForOp
.getRegionIterArgs()[loadIdx + idx * (numStages - 1) + stage]);
}
yieldValues.push_back(nextMapping.lookup(load));
}
for (size_t i = depArgsBeginIdx; i < nextIVIdx; ++i)
yieldValues.push_back(depArgsMapping.lookup(newForOp.getRegionIterArgs()[i]));
yieldValues.push_back(
depArgsMapping.lookup(newForOp.getRegionIterArgs()[i]));
yieldValues.push_back(nextIV);
builder.setInsertionPointToEnd(newForOp.getBody());
builder.create<scf::YieldOp>(forOp.getBody()->getTerminator()->getLoc(),
@@ -430,9 +428,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
// ref: mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
PipelinePass() = default;
PipelinePass(int numStages) {
this->numStages = numStages;
}
PipelinePass(int numStages) { this->numStages = numStages; }
void runOnOperation() override {
int numStages = this->numStages;