[OPTIMIZER] Let the pipeline pass insert async wait. (#63)
This commit is contained in:
@@ -34,6 +34,8 @@ def TTG_AsyncWaitOp : TTG_Op<"async_wait"> {
|
|||||||
let summary = "async wait";
|
let summary = "async wait";
|
||||||
|
|
||||||
let arguments = (ins I32Attr:$num);
|
let arguments = (ins I32Attr:$num);
|
||||||
|
|
||||||
|
let assemblyFormat = "attr-dict";
|
||||||
}
|
}
|
||||||
|
|
||||||
def TTG_CopyAsyncOp : TTG_Op<"copy_async",
|
def TTG_CopyAsyncOp : TTG_Op<"copy_async",
|
||||||
|
@@ -47,6 +47,9 @@ class LoopPipeliner {
|
|||||||
|
|
||||||
void setValueMapping(Value origin, Value newValue, int stage);
|
void setValueMapping(Value origin, Value newValue, int stage);
|
||||||
|
|
||||||
|
/// return true if this op uses any of `loads`
|
||||||
|
bool isDirectUserOfAsyncLoad(Operation &op);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
LoopPipeliner(scf::ForOp forOp, int numStages)
|
LoopPipeliner(scf::ForOp forOp, int numStages)
|
||||||
: forOp(forOp), numStages(numStages) {
|
: forOp(forOp), numStages(numStages) {
|
||||||
@@ -96,6 +99,19 @@ void LoopPipeliner::collectDeps(Value v, int stages, DenseSet<Value> &deps) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool LoopPipeliner::isDirectUserOfAsyncLoad(Operation &op) {
|
||||||
|
for (Value loadOp : loads) {
|
||||||
|
assert(loadOp.hasOneUse() &&
|
||||||
|
"load should only have one use (ConvertLayout)");
|
||||||
|
Value loadUseResult = loadOp.getUsers().begin()->getResult(0);
|
||||||
|
for (Value opOperand : op.getOperands()) {
|
||||||
|
if (opOperand == loadUseResult)
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
/// A load instruction can be pipelined if:
|
/// A load instruction can be pipelined if:
|
||||||
/// - the load doesn't depend on any other loads (after loop peeling)
|
/// - the load doesn't depend on any other loads (after loop peeling)
|
||||||
/// - (?) this load is not a loop-invariant value (we should run LICM before
|
/// - (?) this load is not a loop-invariant value (we should run LICM before
|
||||||
@@ -318,7 +334,14 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
|||||||
mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
|
mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
|
||||||
|
|
||||||
// 2.1 clone the loop body, replace original args with args of the new ForOp
|
// 2.1 clone the loop body, replace original args with args of the new ForOp
|
||||||
|
// Insert async wait if necessary.
|
||||||
|
bool asyncWaitInserted = false;
|
||||||
for (Operation &op : forOp.getBody()->without_terminator()) {
|
for (Operation &op : forOp.getBody()->without_terminator()) {
|
||||||
|
if (!asyncWaitInserted && isDirectUserOfAsyncLoad(op)) {
|
||||||
|
asyncWaitInserted = true;
|
||||||
|
builder.create<triton::gpu::AsyncWaitOp>(op.getLoc(),
|
||||||
|
loads.size() * (numStages - 1));
|
||||||
|
}
|
||||||
Operation *newOp = builder.clone(op, mapping);
|
Operation *newOp = builder.clone(op, mapping);
|
||||||
// update mapping of results
|
// update mapping of results
|
||||||
for (unsigned dstIdx : llvm::seq(unsigned(0), op.getNumResults()))
|
for (unsigned dstIdx : llvm::seq(unsigned(0), op.getNumResults()))
|
||||||
|
@@ -14,6 +14,7 @@
|
|||||||
// CHECK: %[[A1:.*]] = triton_gpu.copy_async
|
// CHECK: %[[A1:.*]] = triton_gpu.copy_async
|
||||||
// CHECK: %[[B1:.*]] = triton_gpu.copy_async
|
// CHECK: %[[B1:.*]] = triton_gpu.copy_async
|
||||||
// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_a1:.*]] = %[[A1]], %[[arg_b0:.*]] = %[[B0]], %[[arg_b1:.*]] = %[[B1]], {{.*}})
|
// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_a1:.*]] = %[[A1]], %[[arg_b0:.*]] = %[[B0]], %[[arg_b1:.*]] = %[[B1]], {{.*}})
|
||||||
|
// CHECK: triton_gpu.async_wait {num = 4 : i32}
|
||||||
// CHECK: tt.dot %[[arg_a0]], %[[arg_b0]], {{.*}}
|
// CHECK: tt.dot %[[arg_a0]], %[[arg_b0]], {{.*}}
|
||||||
// CHECK: %[[NEXT_A:.*]] = triton_gpu.copy_async
|
// CHECK: %[[NEXT_A:.*]] = triton_gpu.copy_async
|
||||||
// CHECK: %[[NEXT_B:.*]] = triton_gpu.copy_async
|
// CHECK: %[[NEXT_B:.*]] = triton_gpu.copy_async
|
||||||
@@ -54,6 +55,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
|
|||||||
// CHECK: %[[A1:.*]] = triton_gpu.copy_async
|
// CHECK: %[[A1:.*]] = triton_gpu.copy_async
|
||||||
// CHECK: %[[B1:.*]] = triton_gpu.copy_async
|
// CHECK: %[[B1:.*]] = triton_gpu.copy_async
|
||||||
// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_a1:.*]] = %[[A1]], %[[arg_b0:.*]] = %[[B0]], %[[arg_b1:.*]] = %[[B1]], {{.*}})
|
// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_a1:.*]] = %[[A1]], %[[arg_b0:.*]] = %[[B0]], %[[arg_b1:.*]] = %[[B1]], {{.*}})
|
||||||
|
// CHECK: triton_gpu.async_wait {num = 4 : i32}
|
||||||
// CHECK: tt.dot %[[arg_a0]], %[[arg_b0]], {{.*}}
|
// CHECK: tt.dot %[[arg_a0]], %[[arg_b0]], {{.*}}
|
||||||
// CHECK: %[[NEXT_A:.*]] = triton_gpu.copy_async
|
// CHECK: %[[NEXT_A:.*]] = triton_gpu.copy_async
|
||||||
// CHECK: %[[NEXT_B:.*]] = triton_gpu.copy_async
|
// CHECK: %[[NEXT_B:.*]] = triton_gpu.copy_async
|
||||||
@@ -93,6 +95,7 @@ func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f
|
|||||||
// CHECK: %[[B0:.*]] = triton_gpu.copy_async
|
// CHECK: %[[B0:.*]] = triton_gpu.copy_async
|
||||||
// CHECK: %[[B1:.*]] = triton_gpu.copy_async
|
// CHECK: %[[B1:.*]] = triton_gpu.copy_async
|
||||||
// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, %[[arg_b0:.*]] = %[[B0]], %[[arg_b1:.*]] = %[[B1]], {{.*}})
|
// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, %[[arg_b0:.*]] = %[[B0]], %[[arg_b1:.*]] = %[[B1]], {{.*}})
|
||||||
|
// CHECK: triton_gpu.async_wait {num = 2 : i32}
|
||||||
// CHECK: tt.dot {{.*}}, %[[arg_b0]], {{.*}}
|
// CHECK: tt.dot {{.*}}, %[[arg_b0]], {{.*}}
|
||||||
// CHECK: %[[NEXT_B:.*]] = triton_gpu.copy_async
|
// CHECK: %[[NEXT_B:.*]] = triton_gpu.copy_async
|
||||||
// CHECK: scf.yield {{.*}}, {{.*}}, %[[arg_b1]], %[[NEXT_B]]
|
// CHECK: scf.yield {{.*}}, {{.*}}, %[[arg_b1]], %[[NEXT_B]]
|
||||||
|
Reference in New Issue
Block a user