[OPTIMIZER] Let the pipeline pass insert async wait. (#63)

This commit is contained in:
Da Yan
2022-08-19 01:31:57 +08:00
committed by GitHub
parent d69ce77b19
commit 8776ad1a0e
3 changed files with 28 additions and 0 deletions

View File

@@ -34,6 +34,8 @@ def TTG_AsyncWaitOp : TTG_Op<"async_wait"> {
let summary = "async wait";
let arguments = (ins I32Attr:$num);
let assemblyFormat = "attr-dict";
}
def TTG_CopyAsyncOp : TTG_Op<"copy_async",

View File

@@ -47,6 +47,9 @@ class LoopPipeliner {
void setValueMapping(Value origin, Value newValue, int stage);
/// return true if this op uses any of `loads`
bool isDirectUserOfAsyncLoad(Operation &op);
public:
LoopPipeliner(scf::ForOp forOp, int numStages)
: forOp(forOp), numStages(numStages) {
@@ -96,6 +99,19 @@ void LoopPipeliner::collectDeps(Value v, int stages, DenseSet<Value> &deps) {
}
}
bool LoopPipeliner::isDirectUserOfAsyncLoad(Operation &op) {
for (Value loadOp : loads) {
assert(loadOp.hasOneUse() &&
"load should only have one use (ConvertLayout)");
Value loadUseResult = loadOp.getUsers().begin()->getResult(0);
for (Value opOperand : op.getOperands()) {
if (opOperand == loadUseResult)
return true;
}
}
return false;
}
/// A load instruction can be pipelined if:
/// - the load doesn't depend on any other loads (after loop peeling)
/// - (?) this load is not a loop-invariant value (we should run LICM before
@@ -318,7 +334,14 @@ scf::ForOp LoopPipeliner::createNewForOp() {
mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
// 2.1 clone the loop body, replace original args with args of the new ForOp
// Insert async wait if necessary.
bool asyncWaitInserted = false;
for (Operation &op : forOp.getBody()->without_terminator()) {
if (!asyncWaitInserted && isDirectUserOfAsyncLoad(op)) {
asyncWaitInserted = true;
builder.create<triton::gpu::AsyncWaitOp>(op.getLoc(),
loads.size() * (numStages - 1));
}
Operation *newOp = builder.clone(op, mapping);
// update mapping of results
for (unsigned dstIdx : llvm::seq(unsigned(0), op.getNumResults()))

View File

@@ -14,6 +14,7 @@
// CHECK: %[[A1:.*]] = triton_gpu.copy_async
// CHECK: %[[B1:.*]] = triton_gpu.copy_async
// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_a1:.*]] = %[[A1]], %[[arg_b0:.*]] = %[[B0]], %[[arg_b1:.*]] = %[[B1]], {{.*}})
// CHECK: triton_gpu.async_wait {num = 4 : i32}
// CHECK: tt.dot %[[arg_a0]], %[[arg_b0]], {{.*}}
// CHECK: %[[NEXT_A:.*]] = triton_gpu.copy_async
// CHECK: %[[NEXT_B:.*]] = triton_gpu.copy_async
@@ -54,6 +55,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
// CHECK: %[[A1:.*]] = triton_gpu.copy_async
// CHECK: %[[B1:.*]] = triton_gpu.copy_async
// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_a1:.*]] = %[[A1]], %[[arg_b0:.*]] = %[[B0]], %[[arg_b1:.*]] = %[[B1]], {{.*}})
// CHECK: triton_gpu.async_wait {num = 4 : i32}
// CHECK: tt.dot %[[arg_a0]], %[[arg_b0]], {{.*}}
// CHECK: %[[NEXT_A:.*]] = triton_gpu.copy_async
// CHECK: %[[NEXT_B:.*]] = triton_gpu.copy_async
@@ -93,6 +95,7 @@ func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f
// CHECK: %[[B0:.*]] = triton_gpu.copy_async
// CHECK: %[[B1:.*]] = triton_gpu.copy_async
// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, %[[arg_b0:.*]] = %[[B0]], %[[arg_b1:.*]] = %[[B1]], {{.*}})
// CHECK: triton_gpu.async_wait {num = 2 : i32}
// CHECK: tt.dot {{.*}}, %[[arg_b0]], {{.*}}
// CHECK: %[[NEXT_B:.*]] = triton_gpu.copy_async
// CHECK: scf.yield {{.*}}, {{.*}}, %[[arg_b1]], %[[NEXT_B]]