[Triton-MLIR][BACKEND] Refactor decompose insert_slice_async (#929)

1. Improve the pipeline pass's comments
2. Decompose `insert_slice_async` when the load vector size is not supported
3. Add a test that exposes a failure in our GEMM code

Copying my comments here:

There's a knob that may cause a performance regression when the decomposition
has been performed. We should remove this knob once we have a thorough
analysis of async waits. Currently, we decompose `insert_slice_async`
into `load` and `insert_slice` without knowing which `async_wait` is
responsible for which `insert_slice_async`. To guarantee correctness, we
conservatively set every `async_wait` to wait for all outstanding async ops
if any `insert_slice_async` has been decomposed.
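
Concretely, the decomposition turns one asynchronous copy into a blocking
load plus a synchronous insert. A schematic sketch follows — the shapes, the
`#blocked`/`#shared` encodings, and the exact operand syntax are simplified
placeholders, not verbatim compiler output:

```mlir
// Before: asynchronously copy one 32x64 tile into stage %idx of the
// multi-stage shared-memory buffer.
%buf1 = triton_gpu.insert_slice_async %ptrs, %buf0, %idx, %mask
        : tensor<32x64x!tt.ptr<f16>, #blocked> -> tensor<4x32x64xf16, #shared>

// After: load the tile synchronously, then insert it at the same index.
%tile = tt.load %ptrs, %mask : tensor<32x64xf16, #blocked>
%buf1 = tensor.insert_slice %tile into %buf0[%idx, 0, 0] [1, 32, 64] [1, 1, 1]
        : tensor<32x64xf16, #blocked> into tensor<4x32x64xf16, #shared>
```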

There are two options to improve this:
1. We can perform a dataflow analysis in the backend to find the `async_wait`
that is responsible for each `insert_slice_async`, so that each wait keeps a
precise count (see the sketch after this list).
2. We can modify the pipeline pass to perform the decomposition before the
`async_wait` is inserted. However, this is also risky because we don't
yet know the correct vectorized shape in the pipeline pass. Making the
pipeline pass aware of vectorization could introduce additional
dependencies on `AxisInfoAnalysis` and the coalescing analysis.
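
To make the trade-off concrete, here is a minimal sketch; the `num` values
are hypothetical, and `num` follows the `cp.async.wait_group` convention of
waiting until at most `num` async ops remain outstanding:

```mlir
// Today, if any insert_slice_async has been decomposed, every wait is
// conservatively forced to drain all outstanding async ops:
triton_gpu.async_wait {num = 0 : i32}

// With option 1, each wait could keep a precise count of only the
// copies it actually guards, e.g. leaving two copies in flight:
triton_gpu.async_wait {num = 2 : i32}
```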
Author: Keren Zhou
Date: 2022-11-30 10:07:34 -08:00
Committed by: GitHub
Parent: 6461254fb5
Commit: 7d90a07d0b
12 changed files with 219 additions and 107 deletions


@@ -25,18 +25,20 @@ static Type getI1SameShape(Value v) {
 tensorType.getEncoding());
 }
+#define int_attr(num) builder.getI64IntegerAttr(num)
 namespace {
 class LoopPipeliner {
-/// cache forOp we are working on
+/// Cache forOp we are working on
 scf::ForOp forOp;
-/// cache YieldOp for this forOp
+/// Cache YieldOp for this forOp
 scf::YieldOp yieldOp;
-/// loads to be pipelined
+/// Loads to be pipelined
 SetVector<Value> loads;
-/// the value that each load will be mapped to (after layout conversion)
+/// The value that each load will be mapped to (after layout conversion)
 DenseMap<Value, Value> loadsMapping;
 /// load => buffer
 DenseMap<Value, Value> loadsBuffer;
@@ -51,7 +53,7 @@ class LoopPipeliner {
 ///
 Value loopIterIdx;
-/// comments on numStages:
+/// Comments on numStages:
 /// [0, numStages-1) are in the prologue
 /// numStages-1 is appended after the loop body
 int numStages;
@@ -61,6 +63,7 @@ class LoopPipeliner {
 /// Block arguments that loads depend on
 DenseSet<BlockArgument> depArgs;
 /// Operations (inside the loop body) that loads depend on
 DenseSet<Operation *> depOps;
@@ -71,7 +74,7 @@ class LoopPipeliner {
 Value lookupOrDefault(Value origin, int stage);
-/// returns a empty buffer of size <numStages, ...>
+/// Returns a empty buffer of size <numStages, ...>
 ttg::AllocTensorOp allocateEmptyBuffer(Operation *op, OpBuilder &builder);
 public:
@@ -84,7 +87,7 @@ public:
 /// Collect loads to pipeline. Return success if we can pipeline this loop
 LogicalResult initialize();
-/// emit pipelined loads (before loop body)
+/// Emit pipelined loads (before loop body)
 void emitPrologue();
 /// emit pipelined loads (after loop body)
@@ -134,7 +137,7 @@ void LoopPipeliner::collectDeps(Value v, int stages, DenseSet<Value> &deps) {
 ttg::AllocTensorOp LoopPipeliner::allocateEmptyBuffer(Operation *op,
 OpBuilder &builder) {
-// allocate a buffer for each pipelined tensor
+// Allocate a buffer for each pipelined tensor
 // shape: e.g. (numStages==4), <32x64xbf16> -> <4x32x64xbf16>
 Value convertLayout = loadsMapping[op->getResult(0)];
 if (auto tensorType = convertLayout.getType().dyn_cast<RankedTensorType>()) {
@@ -215,9 +218,9 @@ LogicalResult LoopPipeliner::initialize() {
 loads.insert(loadOp);
 }
-// we have some loads to pipeline
+// We have some loads to pipeline
 if (!loads.empty()) {
-// update depArgs & depOps
+// Update depArgs & depOps
 for (Value loadOp : loads) {
 for (Value dep : loadDeps[loadOp]) {
 // TODO: we should record the stage that the value is depended on
@@ -244,23 +247,20 @@ void LoopPipeliner::emitPrologue() {
 setValueMapping(arg, operand.get(), 0);
 }
-// helper to construct int attribute
-auto intAttr = [&](int64_t val) { return builder.getI64IntegerAttr(val); };
 // prologue from [0, numStage-1)
 Value iv = forOp.getLowerBound();
 pipelineIterIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
 for (int stage = 0; stage < numStages - 1; ++stage) {
-// special handling for induction variable as the increment is implicit
+// Special handling for induction variable as the increment is implicit
 if (stage != 0)
 iv = builder.create<arith::AddIOp>(iv.getLoc(), iv, forOp.getStep());
 setValueMapping(forOp.getInductionVar(), iv, stage);
-// special handling for loop condition as there is no condition in ForOp
+// Special handling for loop condition as there is no condition in ForOp
 Value loopCond = builder.create<arith::CmpIOp>(
 iv.getLoc(), arith::CmpIPredicate::slt, iv, forOp.getUpperBound());
-// rematerialize peeled values
+// Rematerialize peeled values
 SmallVector<Operation *> orderedDeps;
 for (Operation &op : forOp.getLoopBody().front()) {
 if (depOps.contains(&op))
@@ -314,7 +314,7 @@ void LoopPipeliner::emitPrologue() {
 }
 }
-// update mapping of results
+// Update mapping of results
 for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
 Value originalResult = op->getResult(dstIdx);
 // copy_async will update the value of its only use
@@ -350,13 +350,14 @@ void LoopPipeliner::emitPrologue() {
 loadsBufferType[loadOp].getEncoding());
 Value extractSlice = builder.create<tensor::ExtractSliceOp>(
 loadOp.getLoc(), sliceType, loadStageBuffer[loadOp][numStages - 1],
-SmallVector<OpFoldResult>{intAttr(0), intAttr(0), intAttr(0)},
-SmallVector<OpFoldResult>{intAttr(1), intAttr(sliceType.getShape()[0]),
-intAttr(sliceType.getShape()[1])},
-SmallVector<OpFoldResult>{intAttr(1), intAttr(1), intAttr(1)});
+SmallVector<OpFoldResult>{int_attr(0), int_attr(0), int_attr(0)},
+SmallVector<OpFoldResult>{int_attr(1),
+int_attr(sliceType.getShape()[0]),
+int_attr(sliceType.getShape()[1])},
+SmallVector<OpFoldResult>{int_attr(1), int_attr(1), int_attr(1)});
 loadsExtract[loadOp] = extractSlice;
 }
-// bump up loopIterIdx, this is used for getting the correct slice for the
+// Bump up loopIterIdx, this is used for getting the correct slice for the
 // *next* iteration
 loopIterIdx = builder.create<arith::AddIOp>(
 loopIterIdx.getLoc(), loopIterIdx,
@@ -365,9 +366,6 @@ void LoopPipeliner::emitPrologue() {
 void LoopPipeliner::emitEpilogue() {
 // If there's any outstanding async copies, we need to wait for them.
-// TODO(Keren): We may want to completely avoid the async copies in the last
-// few iterations by setting is_masked attribute to true. We don't want to use
-// the mask operand because it's a tensor but not a scalar.
 OpBuilder builder(forOp);
 OpBuilder::InsertionGuard g(builder);
 builder.setInsertionPointAfter(forOp);
@@ -376,9 +374,8 @@ void LoopPipeliner::emitEpilogue() {
 scf::ForOp LoopPipeliner::createNewForOp() {
 OpBuilder builder(forOp);
-auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); };
-// order of new args:
+// Order of new args:
 // (original args),
 // (insertSliceAsync buffer at stage numStages - 1) for each load
 // (extracted tensor) for each load
@@ -465,7 +462,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
 newForOp.getRegionIterArgs()[argIdx + depArgsBeginIdx]);
 ++argIdx;
 }
-// special handling for iv & loop condition
+// Special handling for iv & loop condition
 Value nextIV = builder.create<arith::AddIOp>(
 newForOp.getInductionVar().getLoc(),
 newForOp.getRegionIterArgs()[nextIVIdx], newForOp.getStep());
@@ -473,7 +470,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
 builder.create<arith::CmpIOp>(nextIV.getLoc(), arith::CmpIPredicate::slt,
 nextIV, newForOp.getUpperBound());
-// slice index
+// Slice index
 SmallVector<Value> nextBuffers;
 SmallVector<Value> extractSlices;
@@ -490,7 +487,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
 for (Operation *op : orderedDeps) {
 Operation *nextOp = nullptr;
-// update loading mask
+// Update loading mask
 if (loads.contains(op->getResult(0))) {
 auto loadOp = llvm::cast<triton::LoadOp>(op);
 Value mask = loadOp.mask();
@@ -500,7 +497,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
 mask.getLoc(), mask.getType(), nextLoopCond);
 newMask = builder.create<arith::AndIOp>(
 mask.getLoc(), splatCond, nextMapping.lookupOrDefault(mask));
-// if mask is defined outside the loop, don't update the map more than
+// If mask is defined outside the loop, don't update the map more than
 // once
 if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask)))
 nextMapping.map(mask, newMask);
@@ -522,18 +519,19 @@ scf::ForOp LoopPipeliner::createNewForOp() {
 loadsBufferType[loadOp].getEncoding());
 nextOp = builder.create<tensor::ExtractSliceOp>(
 op->getLoc(), sliceType, insertAsyncOp,
-SmallVector<OpFoldResult>{extractSliceIndex, intAttr(0), intAttr(0)},
-SmallVector<OpFoldResult>{intAttr(1),
-intAttr(sliceType.getShape()[0]),
-intAttr(sliceType.getShape()[1])},
-SmallVector<OpFoldResult>{intAttr(1), intAttr(1), intAttr(1)});
+SmallVector<OpFoldResult>{extractSliceIndex, int_attr(0),
+int_attr(0)},
+SmallVector<OpFoldResult>{int_attr(1),
+int_attr(sliceType.getShape()[0]),
+int_attr(sliceType.getShape()[1])},
+SmallVector<OpFoldResult>{int_attr(1), int_attr(1), int_attr(1)});
 extractSlices.push_back(nextOp->getResult(0));
 } else
 nextOp = builder.clone(*op, nextMapping);
-// update mapping of results
+// Update mapping of results
 for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
 nextMapping.map(op->getResult(dstIdx), nextOp->getResult(dstIdx));
-// if this is a loop-carried value, update the mapping for yield
+// If this is a loop-carried value, update the mapping for yield
 auto originYield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
 for (OpOperand &operand : originYield->getOpOperands()) {
 if (operand.get() == op->getResult(dstIdx)) {
@@ -583,7 +581,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
 it->getDefiningOp()->moveAfter(asyncWait);
 }
-// bump iteration count
+// Bump iteration count
 pipelineIterIdx = builder.create<arith::AddIOp>(
 nextIV.getLoc(), pipelineIterIdx,
 builder.create<arith::ConstantIntOp>(nextIV.getLoc(), 1, 32));