diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp index 6fb17950e..3cf54d7eb 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp @@ -337,8 +337,9 @@ scf::ForOp LoopPipeliner::createNewForOp() { for (Operation &op : forOp.getBody()->without_terminator()) { if (!asyncWaitInserted && isDirectUserOfAsyncLoad(op)) { asyncWaitInserted = true; + assert(numStages >= 2); builder.create(op.getLoc(), - loads.size() * (numStages - 1)); + loads.size() * (numStages - 2)); } Operation *newOp = builder.clone(op, mapping); // update mapping of results diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir index e71d20357..d3fb45f39 100644 --- a/test/TritonGPU/loop-pipeline.mlir +++ b/test/TritonGPU/loop-pipeline.mlir @@ -14,7 +14,7 @@ // CHECK: %[[A1:.*]] = triton_gpu.copy_async // CHECK: %[[B1:.*]] = triton_gpu.copy_async // CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_a1:.*]] = %[[A1]], %[[arg_b0:.*]] = %[[B0]], %[[arg_b1:.*]] = %[[B1]], {{.*}}) -// CHECK: triton_gpu.async_wait {num = 4 : i32} +// CHECK: triton_gpu.async_wait {num = 2 : i32} // CHECK: tt.dot %[[arg_a0]], %[[arg_b0]], {{.*}} // CHECK: %[[NEXT_A:.*]] = triton_gpu.copy_async // CHECK: %[[NEXT_B:.*]] = triton_gpu.copy_async @@ -55,7 +55,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B // CHECK: %[[A1:.*]] = triton_gpu.copy_async // CHECK: %[[B1:.*]] = triton_gpu.copy_async // CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_a1:.*]] = %[[A1]], %[[arg_b0:.*]] = %[[B0]], %[[arg_b1:.*]] = %[[B1]], {{.*}}) -// CHECK: triton_gpu.async_wait {num = 4 : i32} +// CHECK: triton_gpu.async_wait {num = 2 : i32} // CHECK: tt.dot %[[arg_a0]], %[[arg_b0]], {{.*}} // CHECK: %[[NEXT_A:.*]] = triton_gpu.copy_async // CHECK: %[[NEXT_B:.*]] = triton_gpu.copy_async @@ -95,7 +95,7 @@ func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr