[Triton-MLIR][BACKEND] Refactor decompose insert_slice_async (#929)
1. Improve the pipeline pass's comments.
2. Decompose `insert_slice_async` when the load vector size is not supported.
3. Add a test that could fail our gemm code.

Copying my comments here: there is a knob that may cause a performance regression when the decomposition has been performed. We should remove this knob once we have a thorough analysis of async waits. Currently, we decompose `insert_slice_async` into `load` and `insert_slice` without knowing which `async_wait` is responsible for the `insert_slice_async`. To guarantee correctness, we blindly set the `async_wait` to wait for all async ops if any `insert_slice_async` has been decomposed. There are two options to improve this:

1. Perform a dataflow analysis in the backend to find the `async_wait` that is responsible for each `insert_slice_async`.
2. Modify the pipeline pass to perform the decomposition before the `async_wait` is inserted. However, this is also risky because we don't yet know the correct vectorized shape in the pipeline pass, and making the pass aware of vectorization could introduce additional dependencies on the AxisInfoAnalysis and the Coalesce analysis.
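The conservative fallback described above can be sketched roughly as follows. This is a minimal illustrative C++ sketch, not the actual backend pass: the `isLoadVectorSizeSupported` helper, the `num` attribute name, and the exact op class names are assumptions, and the rewrite of each copy into `load` + `insert_slice` is elided.

```cpp
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

namespace ttg = mlir::triton::gpu;

// Hypothetical helper: does the backend support the load vector width this
// async copy would need? (Placeholder body; the real check is elided.)
static bool isLoadVectorSizeSupported(ttg::InsertSliceAsyncOp /*op*/) { return true; }

// Illustrative sketch of the conservative fallback, not the real pass.
static void decomposeUnsupportedAsyncCopies(mlir::ModuleOp module) {
  bool decomposed = false;
  module.walk([&](ttg::InsertSliceAsyncOp op) {
    if (isLoadVectorSizeSupported(op))
      return;
    // ... rewrite `op` into a plain load followed by an insert_slice into the
    // staging buffer (elided) ...
    decomposed = true;
  });
  if (!decomposed)
    return;
  // We no longer know which async_wait guarded the copies that were
  // decomposed, so conservatively make every async_wait wait for all
  // outstanding async ops. The attribute name "num" is assumed here;
  // waiting on 0 pending groups means "wait for everything".
  mlir::Builder b(module.getContext());
  module.walk([&](ttg::AsyncWaitOp waitOp) {
    waitOp->setAttr("num", b.getI32IntegerAttr(0));
  });
}
```

This is exactly the knob mentioned above: once an `async_wait` can be matched to the copies it guards (option 1) or the decomposition moves ahead of `async_wait` insertion (option 2), the blanket wait can be dropped.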
@@ -25,18 +25,20 @@ static Type getI1SameShape(Value v) {
                                  tensorType.getEncoding());
 }
 
+#define int_attr(num) builder.getI64IntegerAttr(num)
+
 namespace {
 
 class LoopPipeliner {
-  /// cache forOp we are working on
+  /// Cache forOp we are working on
   scf::ForOp forOp;
 
-  /// cache YieldOp for this forOp
+  /// Cache YieldOp for this forOp
   scf::YieldOp yieldOp;
 
-  /// loads to be pipelined
+  /// Loads to be pipelined
   SetVector<Value> loads;
-  /// the value that each load will be mapped to (after layout conversion)
+  /// The value that each load will be mapped to (after layout conversion)
   DenseMap<Value, Value> loadsMapping;
   /// load => buffer
   DenseMap<Value, Value> loadsBuffer;
@@ -51,7 +53,7 @@ class LoopPipeliner {
   ///
   Value loopIterIdx;
 
-  /// comments on numStages:
+  /// Comments on numStages:
   /// [0, numStages-1) are in the prologue
   /// numStages-1 is appended after the loop body
   int numStages;
@@ -61,6 +63,7 @@ class LoopPipeliner {
 
   /// Block arguments that loads depend on
   DenseSet<BlockArgument> depArgs;
+
   /// Operations (inside the loop body) that loads depend on
   DenseSet<Operation *> depOps;
 
@@ -71,7 +74,7 @@ class LoopPipeliner {
 
   Value lookupOrDefault(Value origin, int stage);
 
-  /// returns a empty buffer of size <numStages, ...>
+  /// Returns a empty buffer of size <numStages, ...>
   ttg::AllocTensorOp allocateEmptyBuffer(Operation *op, OpBuilder &builder);
 
 public:
@@ -84,7 +87,7 @@ public:
   /// Collect loads to pipeline. Return success if we can pipeline this loop
   LogicalResult initialize();
 
-  /// emit pipelined loads (before loop body)
+  /// Emit pipelined loads (before loop body)
   void emitPrologue();
 
   /// emit pipelined loads (after loop body)
@@ -134,7 +137,7 @@ void LoopPipeliner::collectDeps(Value v, int stages, DenseSet<Value> &deps) {
 
 ttg::AllocTensorOp LoopPipeliner::allocateEmptyBuffer(Operation *op,
                                                       OpBuilder &builder) {
-  // allocate a buffer for each pipelined tensor
+  // Allocate a buffer for each pipelined tensor
   // shape: e.g. (numStages==4), <32x64xbf16> -> <4x32x64xbf16>
   Value convertLayout = loadsMapping[op->getResult(0)];
   if (auto tensorType = convertLayout.getType().dyn_cast<RankedTensorType>()) {
@@ -215,9 +218,9 @@ LogicalResult LoopPipeliner::initialize() {
       loads.insert(loadOp);
   }
 
-  // we have some loads to pipeline
+  // We have some loads to pipeline
   if (!loads.empty()) {
-    // update depArgs & depOps
+    // Update depArgs & depOps
     for (Value loadOp : loads) {
       for (Value dep : loadDeps[loadOp]) {
         // TODO: we should record the stage that the value is depended on
@@ -244,23 +247,20 @@ void LoopPipeliner::emitPrologue() {
     setValueMapping(arg, operand.get(), 0);
   }
 
-  // helper to construct int attribute
-  auto intAttr = [&](int64_t val) { return builder.getI64IntegerAttr(val); };
-
   // prologue from [0, numStage-1)
   Value iv = forOp.getLowerBound();
   pipelineIterIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
   for (int stage = 0; stage < numStages - 1; ++stage) {
-    // special handling for induction variable as the increment is implicit
+    // Special handling for induction variable as the increment is implicit
     if (stage != 0)
       iv = builder.create<arith::AddIOp>(iv.getLoc(), iv, forOp.getStep());
     setValueMapping(forOp.getInductionVar(), iv, stage);
 
-    // special handling for loop condition as there is no condition in ForOp
+    // Special handling for loop condition as there is no condition in ForOp
    Value loopCond = builder.create<arith::CmpIOp>(
        iv.getLoc(), arith::CmpIPredicate::slt, iv, forOp.getUpperBound());
 
-    // rematerialize peeled values
+    // Rematerialize peeled values
    SmallVector<Operation *> orderedDeps;
    for (Operation &op : forOp.getLoopBody().front()) {
      if (depOps.contains(&op))
@@ -314,7 +314,7 @@ void LoopPipeliner::emitPrologue() {
         }
       }
 
-      // update mapping of results
+      // Update mapping of results
       for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
         Value originalResult = op->getResult(dstIdx);
         // copy_async will update the value of its only use
@@ -350,13 +350,14 @@ void LoopPipeliner::emitPrologue() {
           loadsBufferType[loadOp].getEncoding());
       Value extractSlice = builder.create<tensor::ExtractSliceOp>(
           loadOp.getLoc(), sliceType, loadStageBuffer[loadOp][numStages - 1],
-          SmallVector<OpFoldResult>{intAttr(0), intAttr(0), intAttr(0)},
-          SmallVector<OpFoldResult>{intAttr(1), intAttr(sliceType.getShape()[0]),
-                                    intAttr(sliceType.getShape()[1])},
-          SmallVector<OpFoldResult>{intAttr(1), intAttr(1), intAttr(1)});
+          SmallVector<OpFoldResult>{int_attr(0), int_attr(0), int_attr(0)},
+          SmallVector<OpFoldResult>{int_attr(1),
+                                    int_attr(sliceType.getShape()[0]),
+                                    int_attr(sliceType.getShape()[1])},
+          SmallVector<OpFoldResult>{int_attr(1), int_attr(1), int_attr(1)});
       loadsExtract[loadOp] = extractSlice;
     }
-    // bump up loopIterIdx, this is used for getting the correct slice for the
+    // Bump up loopIterIdx, this is used for getting the correct slice for the
     // *next* iteration
     loopIterIdx = builder.create<arith::AddIOp>(
         loopIterIdx.getLoc(), loopIterIdx,
@@ -365,9 +366,6 @@ void LoopPipeliner::emitPrologue() {
 
 void LoopPipeliner::emitEpilogue() {
   // If there's any outstanding async copies, we need to wait for them.
-  // TODO(Keren): We may want to completely avoid the async copies in the last
-  // few iterations by setting is_masked attribute to true. We don't want to use
-  // the mask operand because it's a tensor but not a scalar.
   OpBuilder builder(forOp);
   OpBuilder::InsertionGuard g(builder);
   builder.setInsertionPointAfter(forOp);
@@ -376,9 +374,8 @@ void LoopPipeliner::emitEpilogue() {
 
 scf::ForOp LoopPipeliner::createNewForOp() {
   OpBuilder builder(forOp);
-  auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); };
 
-  // order of new args:
+  // Order of new args:
   // (original args),
   // (insertSliceAsync buffer at stage numStages - 1) for each load
   // (extracted tensor) for each load
@@ -465,7 +462,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
                 newForOp.getRegionIterArgs()[argIdx + depArgsBeginIdx]);
     ++argIdx;
   }
-  // special handling for iv & loop condition
+  // Special handling for iv & loop condition
   Value nextIV = builder.create<arith::AddIOp>(
       newForOp.getInductionVar().getLoc(),
       newForOp.getRegionIterArgs()[nextIVIdx], newForOp.getStep());
@@ -473,7 +470,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
       builder.create<arith::CmpIOp>(nextIV.getLoc(), arith::CmpIPredicate::slt,
                                     nextIV, newForOp.getUpperBound());
 
-  // slice index
+  // Slice index
   SmallVector<Value> nextBuffers;
   SmallVector<Value> extractSlices;
 
@@ -490,7 +487,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
 
   for (Operation *op : orderedDeps) {
     Operation *nextOp = nullptr;
-    // update loading mask
+    // Update loading mask
     if (loads.contains(op->getResult(0))) {
       auto loadOp = llvm::cast<triton::LoadOp>(op);
       Value mask = loadOp.mask();
@@ -500,7 +497,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
             mask.getLoc(), mask.getType(), nextLoopCond);
         newMask = builder.create<arith::AndIOp>(
             mask.getLoc(), splatCond, nextMapping.lookupOrDefault(mask));
-        // if mask is defined outside the loop, don't update the map more than
+        // If mask is defined outside the loop, don't update the map more than
         // once
         if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask)))
           nextMapping.map(mask, newMask);
@@ -522,18 +519,19 @@ scf::ForOp LoopPipeliner::createNewForOp() {
           loadsBufferType[loadOp].getEncoding());
       nextOp = builder.create<tensor::ExtractSliceOp>(
           op->getLoc(), sliceType, insertAsyncOp,
-          SmallVector<OpFoldResult>{extractSliceIndex, intAttr(0), intAttr(0)},
-          SmallVector<OpFoldResult>{intAttr(1),
-                                    intAttr(sliceType.getShape()[0]),
-                                    intAttr(sliceType.getShape()[1])},
-          SmallVector<OpFoldResult>{intAttr(1), intAttr(1), intAttr(1)});
+          SmallVector<OpFoldResult>{extractSliceIndex, int_attr(0),
+                                    int_attr(0)},
+          SmallVector<OpFoldResult>{int_attr(1),
+                                    int_attr(sliceType.getShape()[0]),
+                                    int_attr(sliceType.getShape()[1])},
+          SmallVector<OpFoldResult>{int_attr(1), int_attr(1), int_attr(1)});
       extractSlices.push_back(nextOp->getResult(0));
     } else
       nextOp = builder.clone(*op, nextMapping);
-    // update mapping of results
+    // Update mapping of results
    for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
      nextMapping.map(op->getResult(dstIdx), nextOp->getResult(dstIdx));
-      // if this is a loop-carried value, update the mapping for yield
+      // If this is a loop-carried value, update the mapping for yield
      auto originYield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
      for (OpOperand &operand : originYield->getOpOperands()) {
        if (operand.get() == op->getResult(dstIdx)) {
@@ -583,7 +581,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
       it->getDefiningOp()->moveAfter(asyncWait);
   }
 
-  // bump iteration count
+  // Bump iteration count
   pipelineIterIdx = builder.create<arith::AddIOp>(
       nextIV.getLoc(), pipelineIterIdx,
       builder.create<arith::ConstantIntOp>(nextIV.getLoc(), 1, 32));