[OPTIMIZER] Better pipeline pass (#100)
* Use `insert_slice_async` instead of `CopyAsync`
* Move async.wait to loop header

Co-authored-by: Jokeren <kerenzhou@openai.com>
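For intuition, here is a minimal, self-contained C++ model of the multi-stage pipeline this pass builds. It is a conceptual sketch only: the plain array, counter, and loop below are stand-ins for `triton_gpu.alloc_tensor`, `insert_slice_async`, `async_wait`, and `extract_slice`, and the values of `numStages`, `numIters`, and the fake "global memory" are made up for the demo. The prologue prefetches the first `numStages - 1` iterations into a circular staging buffer; each iteration then waits until at most `numStages - 2` copies are still in flight, consumes slice `i % numStages`, and prefetches the slice for iteration `i + numStages - 1`.

// Conceptual sketch only: a scalar stand-in for the numStages-deep software
// pipeline. buffer[] plays the role of the <numStages x ...> tensor created by
// alloc_tensor, assignments into it stand in for insert_slice_async, the while
// loop models async_wait, and reading buffer[i % numStages] models extract_slice.
#include <cstdio>
#include <vector>

int main() {
  constexpr int numStages = 3;
  constexpr int numIters = 8;
  std::vector<int> global(numIters);
  for (int i = 0; i < numIters; ++i)
    global[i] = 10 * i; // pretend "global memory"

  int buffer[numStages]; // circular staging buffer, one slice per stage
  int inFlight = 0;      // async copies issued but not yet waited on

  // Prologue: prefetch the first numStages - 1 iterations.
  for (int stage = 0; stage < numStages - 1; ++stage) {
    buffer[stage % numStages] = global[stage]; // "insert_slice_async"
    ++inFlight;
  }

  int acc = 0;
  for (int i = 0; i < numIters; ++i) {
    // "async_wait": block until at most numStages - 2 copies remain in flight,
    // which guarantees the slice for iteration i has landed.
    while (inFlight > numStages - 2)
      --inFlight;
    acc += buffer[i % numStages]; // "extract_slice" + loop body
    int next = i + numStages - 1; // prefetch distance
    if (next < numIters) {
      buffer[next % numStages] = global[next]; // "insert_slice_async" for next
      ++inFlight;
    }
  }
  std::printf("acc = %d\n", acc);
  return 0;
}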
@@ -32,6 +32,8 @@ class LoopPipeliner {
   SetVector<Value> loads;
   /// the value that each load will be mapped to (after layout conversion)
   DenseMap<Value, Value> loadsMapping;
+  /// load => buffer
+  DenseMap<Value, Value> loadsBuffer;

   /// value (in loop) => value at stage N
   DenseMap<Value, SmallVector<Value>> valueMapping;
@@ -46,9 +48,15 @@ class LoopPipeliner {

   void setValueMapping(Value origin, Value newValue, int stage);

+  Value lookupOrDefault(Value origin, int stage);
+
   /// return true if this op uses any of `loads`
   bool isDirectUserOfAsyncLoad(Operation &op);

+  /// returns an empty buffer of size <numStages, ...>
+  triton::gpu::AllocTensorOp allocateEmptyBuffer(Operation *op,
+                                                 OpBuilder &builder);
+
 public:
   LoopPipeliner(scf::ForOp forOp, int numStages)
       : forOp(forOp), numStages(numStages) {
@@ -75,6 +83,12 @@ void LoopPipeliner::setValueMapping(Value origin, Value newValue, int stage) {
   valueMapping[origin][stage] = newValue;
 }

+Value LoopPipeliner::lookupOrDefault(Value origin, int stage) {
+  if (valueMapping.find(origin) == valueMapping.end())
+    return origin;
+  return valueMapping[origin][stage];
+}
+
 void LoopPipeliner::collectDeps(Value v, int stages, DenseSet<Value> &deps) {
   // Loop-invariant value. skip
   if (v.getParentRegion() != &forOp.getLoopBody())
@@ -111,6 +125,25 @@ bool LoopPipeliner::isDirectUserOfAsyncLoad(Operation &op) {
   return false;
 }

+triton::gpu::AllocTensorOp
+LoopPipeliner::allocateEmptyBuffer(Operation *op, OpBuilder &builder) {
+  // allocate a buffer for each pipelined tensor
+  // shape: e.g. (numStages==4), <32x64xbf16> -> <4x32x64xbf16>
+  Value convertLayout = loadsMapping[op->getResult(0)];
+  if (auto tensorType = convertLayout.getType().dyn_cast<RankedTensorType>()) {
+    SmallVector<int64_t> shape(tensorType.getShape().begin(),
+                               tensorType.getShape().end());
+    shape.insert(shape.begin(), numStages);
+    Type elementType = tensorType.getElementType();
+    // The encoding of the buffer is similar to the original tensor
+    Attribute encoding = tensorType.getEncoding();
+    auto bufferType = RankedTensorType::get(shape, elementType, encoding);
+    return builder.create<triton::gpu::AllocTensorOp>(convertLayout.getLoc(),
+                                                      bufferType);
+  }
+  llvm_unreachable("Async copy's return should be of RankedTensorType");
+}
+
 /// A load instruction can be pipelined if:
 /// - the load doesn't depend on any other loads (after loop peeling)
 /// - (?) this load is not a loop-invariant value (we should run LICM before
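A note on the hunk above: the buffer produced by `allocateEmptyBuffer` is simply the original tile shape with a leading `numStages` dimension prepended, which is what makes the per-stage `insert_slice_async`/`extract_slice` indexing work. A hypothetical stand-alone illustration of that shape computation (`bufferShape` is not part of the pass):

// Hypothetical helper mirroring the shape logic in allocateEmptyBuffer:
// <32x64xbf16> becomes <4x32x64xbf16> when numStages == 4.
#include <cstdint>
#include <cstdio>
#include <vector>

std::vector<int64_t> bufferShape(std::vector<int64_t> shape, int64_t numStages) {
  shape.insert(shape.begin(), numStages); // prepend the stage dimension
  return shape;
}

int main() {
  for (int64_t d : bufferShape({32, 64}, 4))
    std::printf("%lld ", static_cast<long long>(d)); // prints: 4 32 64
  std::printf("\n");
  return 0;
}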
@@ -137,13 +170,6 @@ LogicalResult LoopPipeliner::initialize() {
     loadDeps[loadOp] = deps;
   }

-  // for (triton::LoadOp loadOp : allLoads) {
-  //   llvm::errs() << loadOp << " depends on: #" << loadDeps[loadOp].size() <<
-  //   " values\n"; for (Value dep : loadDeps[loadOp])
-  //     llvm::errs() << dep << "\n";
-  //   llvm::errs() << "\n";
-  // }
-
   // Don't pipeline loads that depend on other loads
   // (Because if a load depends on another load, this load needs to wait on the
   // other load in the prologue, which is against the point of the pipeline
@@ -234,29 +260,40 @@ void LoopPipeliner::emitPrologue() {
     for (Operation *op : orderedDeps) {
       Operation *newOp = nullptr;
       if (loads.contains(op->getResult(0))) {
+        // Allocate empty buffer
+        if (stage == 0)
+          loadsBuffer[op->getResult(0)] = allocateEmptyBuffer(op, builder);
         // load => copy async
         // TODO: check if the hardware supports copyasync
         if (auto loadOp = llvm::dyn_cast<triton::LoadOp>(op)) {
-          newOp = builder.create<triton::gpu::CopyAsyncOp>(
-              op->getLoc(), loadsMapping[loadOp].getType(), loadOp.ptr(),
-              loadOp.mask(), loadOp.other(), loadOp.cache(), loadOp.evict(),
-              loadOp.isVolatile());
+          Value sliceIndex = builder.create<arith::IndexCastOp>(
+              iv.getLoc(), builder.getI32Type(), iv);
+          Value insertAsyncOp = builder.create<triton::gpu::InsertSliceAsyncOp>(
+              op->getLoc(), loadsBuffer[loadOp].getType(),
+              lookupOrDefault(loadOp.ptr(), stage), loadsBuffer[loadOp],
+              sliceIndex, lookupOrDefault(loadOp.mask(), stage),
+              lookupOrDefault(loadOp.other(), stage), loadOp.cache(),
+              loadOp.evict(), loadOp.isVolatile(), /*axis*/ 0);
+          newOp = builder.create<triton::gpu::ExtractSliceOp>(
+              op->getLoc(), loadsMapping[loadOp].getType(), insertAsyncOp,
+              sliceIndex, /*axis*/ 0);
        } else
          llvm_unreachable("This should be LoadOp");
-      } else
+      } else {
        newOp = builder.clone(*op);
-
-      for (unsigned opIdx = 0; opIdx < op->getNumOperands(); ++opIdx) {
-        auto it = valueMapping.find(op->getOperand(opIdx));
-        if (it != valueMapping.end()) {
-          Value v = it->second[stage];
-          assert(v);
-          newOp->setOperand(opIdx, v);
-        } // else, op at opIdx is a loop-invariant value
+        // Update loop-carried uses
+        for (unsigned opIdx = 0; opIdx < op->getNumOperands(); ++opIdx) {
+          auto it = valueMapping.find(op->getOperand(opIdx));
+          if (it != valueMapping.end()) {
+            Value v = it->second[stage];
+            assert(v);
+            newOp->setOperand(opIdx, v);
+          } // else, op at opIdx is a loop-invariant value
+        }
       }

       // If this is a load/async_copy, we need to update the mask
-      if (llvm::isa<triton::LoadOp, triton::gpu::CopyAsyncOp>(newOp)) {
+      if (llvm::isa<triton::LoadOp, triton::gpu::InsertSliceAsyncOp>(newOp)) {
        Value mask = newOp->getOperand(1);
         // assert(I1 or TensorOf<[I1]>);
         OpBuilder::InsertionGuard g(builder);
@@ -265,7 +302,11 @@ void LoopPipeliner::emitPrologue() {
             mask.getLoc(), mask.getType(), loopCond);
         Value newMask =
             builder.create<arith::AndIOp>(mask.getLoc(), mask, splatCond);
-        newOp->setOperand(1, newMask);
+        // TODO: better way to do this?
+        if (llvm::isa<triton::LoadOp>(newOp))
+          newOp->setOperand(1, newMask);
+        else // InsertSliceAsyncOp
+          newOp->setOperand(3, newMask);
       }

       // update mapping of results
@@ -285,6 +326,17 @@ void LoopPipeliner::emitPrologue() {
       }
     }
   }
+
+  // async.wait & extract_slice
+  Operation *asyncWait = builder.create<triton::gpu::AsyncWaitOp>(
+      loads[0].getLoc(), loads.size() * (numStages - 2));
+  for (int i = numStages - 2; i >= 0; --i) {
+    for (auto it = loads.rbegin(); it != loads.rend(); ++it) {
+      // move extract_slice after asyncWait
+      Value load = *it;
+      valueMapping[loadsMapping[load]][i].getDefiningOp()->moveAfter(asyncWait);
+    }
+  }
 }

 scf::ForOp LoopPipeliner::createNewForOp() {
@@ -333,14 +385,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {

   // 2.1 clone the loop body, replace original args with args of the new ForOp
-  // Insert async wait if necessary.
-  bool asyncWaitInserted = false;
   for (Operation &op : forOp.getBody()->without_terminator()) {
-    if (!asyncWaitInserted && isDirectUserOfAsyncLoad(op)) {
-      asyncWaitInserted = true;
-      assert(numStages >= 2);
-      builder.create<triton::gpu::AsyncWaitOp>(op.getLoc(),
-                                               loads.size() * (numStages - 2));
-    }
     Operation *newOp = builder.clone(op, mapping);
     // update mapping of results
     for (unsigned dstIdx : llvm::seq(unsigned(0), op.getNumResults()))
@@ -385,8 +430,17 @@ scf::ForOp LoopPipeliner::createNewForOp() {
   Value nextLoopCond =
       builder.create<arith::CmpIOp>(nextIV.getLoc(), arith::CmpIPredicate::slt,
                                     nextIV, newForOp.getUpperBound());
+
+  // slice index
+  SmallVector<Value> extractSlices;
+  Value sliceIndex = builder.create<arith::IndexCastOp>(
+      nextIV.getLoc(), builder.getI32Type(), nextIV);
+  sliceIndex = builder.create<arith::RemSIOp>(
+      nextIV.getLoc(), sliceIndex,
+      builder.create<arith::ConstantIntOp>(nextIV.getLoc(), numStages, 32));
   for (Operation *op : orderedDeps) {
     Operation *nextOp = nullptr;
+    // TODO(da): does this work if loadOp has no mask?
     // update loading mask
     if (loads.contains(op->getResult(0))) {
       auto loadOp = llvm::cast<triton::LoadOp>(op);
@@ -401,16 +455,18 @@ scf::ForOp LoopPipeliner::createNewForOp() {
         if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask)))
           nextMapping.map(mask, newMask);
       }
-      // TODO: more elegant way to do this?
-      nextOp = builder.create<triton::gpu::CopyAsyncOp>(
-          op->getLoc(), loadsMapping[op->getResult(0)].getType(),
-          nextMapping.lookupOrDefault(loadOp.ptr()),
-          nextMapping.lookupOrDefault(loadOp.mask()),
+      Value insertAsyncOp = builder.create<triton::gpu::InsertSliceAsyncOp>(
+          op->getLoc(), loadsBuffer[loadOp].getType(),
+          nextMapping.lookupOrDefault(loadOp.ptr()), loadsBuffer[loadOp],
+          sliceIndex, nextMapping.lookupOrDefault(loadOp.mask()),
           nextMapping.lookupOrDefault(loadOp.other()), loadOp.cache(),
-          loadOp.evict(), loadOp.isVolatile());
+          loadOp.evict(), loadOp.isVolatile(), /*axis*/ 0);
+      nextOp = builder.create<triton::gpu::ExtractSliceOp>(
+          op->getLoc(), loadsMapping[loadOp].getType(), insertAsyncOp,
+          sliceIndex, /*axis*/ 0);
+      extractSlices.push_back(nextOp->getResult(0));
     } else
       nextOp = builder.clone(*op, nextMapping);
-    // llvm::errs() << "epilogue cloning...: " << *op << "\n";
     // update mapping of results
     for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
       nextMapping.map(op->getResult(dstIdx), nextOp->getResult(dstIdx));
@@ -427,6 +483,14 @@ scf::ForOp LoopPipeliner::createNewForOp() {
     }
   }

+  // async.wait & extract_slice
+  Operation *asyncWait = builder.create<triton::gpu::AsyncWaitOp>(
+      loads[0].getLoc(), loads.size() * (numStages - 2));
+  for (auto it = extractSlices.rbegin(); it != extractSlices.rend(); ++it) {
+    // move extract_slice after asyncWait
+    it->getDefiningOp()->moveAfter(asyncWait);
+  }
+
   // Finally, the YieldOp, need to sync with the order of newLoopArgs
   SmallVector<Value> yieldValues;
   for (Value v : forOp.getBody()->getTerminator()->getOperands())
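The lit test below (the second file in this diff) encodes the expected IR shape. As a quick, hedged sanity check of the `async_wait {num = ...}` values it expects, assuming the tests run with three pipeline stages (an assumption inferred from the two prologue `insert_slice_async` ops per load), the wait count follows the `loads.size() * (numStages - 2)` expression used by the pass:

// Back-of-the-envelope check of the async_wait {num = ...} values in the
// CHECK lines below; numStages == 3 is an assumption inferred from the tests.
#include <cstdio>

int waitCount(int numLoads, int numStages) { return numLoads * (numStages - 2); }

int main() {
  std::printf("matmul_loop / matmul_loop_nested: num = %d\n", waitCount(2, 3)); // 2
  std::printf("matmul_loop_single_pipeline:      num = %d\n", waitCount(1, 3)); // 1
  return 0;
}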
@@ -9,15 +9,24 @@
 #C = #triton_gpu.mma<{version = 2, warpsPerCTA = [4, 1]}>

 // CHECK: func @matmul_loop
-// CHECK: %[[A0:.*]] = triton_gpu.copy_async
-// CHECK: %[[B0:.*]] = triton_gpu.copy_async
-// CHECK: %[[A1:.*]] = triton_gpu.copy_async
-// CHECK: %[[B1:.*]] = triton_gpu.copy_async
-// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_a1:.*]] = %[[A1]], %[[arg_b0:.*]] = %[[B0]], %[[arg_b1:.*]] = %[[B1]], {{.*}})
+// CHECK: %[[ABUFFER:.*]] = triton_gpu.alloc_tensor
+// CHECK: %[[A0BUFFER:.*]] = triton_gpu.insert_slice_async
+// CHECK: %[[BBUFFER:.*]] = triton_gpu.alloc_tensor
+// CHECK: %[[B0BUFFER:.*]] = triton_gpu.insert_slice_async
+// CHECK: %[[A1BUFFER:.*]] = triton_gpu.insert_slice_async
+// CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async
 // CHECK: triton_gpu.async_wait {num = 2 : i32}
+// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A0BUFFER]]
+// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B0BUFFER]]
+// CHECK: %[[A1:.*]] = triton_gpu.extract_slice %[[A1BUFFER]]
+// CHECK: %[[B1:.*]] = triton_gpu.extract_slice %[[B1BUFFER]]
+// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_a1:.*]] = %[[A1]], %[[arg_b0:.*]] = %[[B0]], %[[arg_b1:.*]] = %[[B1]], {{.*}})
 // CHECK: tt.dot %[[arg_a0]], %[[arg_b0]], {{.*}}
-// CHECK: %[[NEXT_A:.*]] = triton_gpu.copy_async
-// CHECK: %[[NEXT_B:.*]] = triton_gpu.copy_async
+// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async
+// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async
+// CHECK: triton_gpu.async_wait {num = 2 : i32}
+// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]]
+// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]]
 // CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[arg_a1]], %[[NEXT_A]], %[[arg_b1]], %[[NEXT_B]]
 func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
   %a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
@@ -50,15 +59,24 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B

 // CHECK: func @matmul_loop_nested
 // CHECK: scf.for
-// CHECK: %[[A0:.*]] = triton_gpu.copy_async
-// CHECK: %[[B0:.*]] = triton_gpu.copy_async
-// CHECK: %[[A1:.*]] = triton_gpu.copy_async
-// CHECK: %[[B1:.*]] = triton_gpu.copy_async
-// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_a1:.*]] = %[[A1]], %[[arg_b0:.*]] = %[[B0]], %[[arg_b1:.*]] = %[[B1]], {{.*}})
+// CHECK: %[[ABUFFER:.*]] = triton_gpu.alloc_tensor
+// CHECK: %[[A0BUFFER:.*]] = triton_gpu.insert_slice_async
+// CHECK: %[[BBUFFER:.*]] = triton_gpu.alloc_tensor
+// CHECK: %[[B0BUFFER:.*]] = triton_gpu.insert_slice_async
+// CHECK: %[[A1BUFFER:.*]] = triton_gpu.insert_slice_async
+// CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async
 // CHECK: triton_gpu.async_wait {num = 2 : i32}
+// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A0BUFFER]]
+// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B0BUFFER]]
+// CHECK: %[[A1:.*]] = triton_gpu.extract_slice %[[A1BUFFER]]
+// CHECK: %[[B1:.*]] = triton_gpu.extract_slice %[[B1BUFFER]]
+// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_a1:.*]] = %[[A1]], %[[arg_b0:.*]] = %[[B0]], %[[arg_b1:.*]] = %[[B1]], {{.*}})
 // CHECK: tt.dot %[[arg_a0]], %[[arg_b0]], {{.*}}
-// CHECK: %[[NEXT_A:.*]] = triton_gpu.copy_async
-// CHECK: %[[NEXT_B:.*]] = triton_gpu.copy_async
+// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async
+// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async
+// CHECK: triton_gpu.async_wait {num = 2 : i32}
+// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]]
+// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]]
 // CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[arg_a1]], %[[NEXT_A]], %[[arg_b1]], %[[NEXT_B]]
 func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
   scf.for %iv0 = %lb to %ub step %step {
@@ -92,12 +110,17 @@ func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f


 // CHECK: func @matmul_loop_single_pipeline
-// CHECK: %[[B0:.*]] = triton_gpu.copy_async
-// CHECK: %[[B1:.*]] = triton_gpu.copy_async
+// CHECK: %[[BBUFFER:.*]] = triton_gpu.alloc_tensor
+// CHECK: %[[B0BUFFER:.*]] = triton_gpu.insert_slice_async
+// CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async
+// CHECK: triton_gpu.async_wait {num = 1 : i32}
+// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B0BUFFER]]
+// CHECK: %[[B1:.*]] = triton_gpu.extract_slice %[[B1BUFFER]]
 // CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, %[[arg_b0:.*]] = %[[B0]], %[[arg_b1:.*]] = %[[B1]], {{.*}})
-// CHECK: triton_gpu.async_wait {num = 1 : i32}
 // CHECK: tt.dot {{.*}}, %[[arg_b0]], {{.*}}
-// CHECK: %[[NEXT_B:.*]] = triton_gpu.copy_async
+// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async
+// CHECK: triton_gpu.async_wait {num = 1 : i32}
+// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]]
 // CHECK: scf.yield {{.*}}, {{.*}}, %[[arg_b1]], %[[NEXT_B]]
 func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
   %a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>