From 2e08450c80d651e5d6f6820b30ec9fc942ad92fd Mon Sep 17 00:00:00 2001 From: Da Yan Date: Thu, 15 Sep 2022 14:26:40 +0800 Subject: [PATCH] [OPTIMIZER] Better pipeline tests (#660) --- lib/Dialect/TritonGPU/Transforms/Pipeline.cpp | 7 +- test/TritonGPU/loop-pipeline.mlir | 84 ++++++++++++------- 2 files changed, 56 insertions(+), 35 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp index 8021f672e..0b40c9df4 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp @@ -347,16 +347,13 @@ void LoopPipeliner::emitPrologue() { // async.wait & extract_slice Operation *asyncWait = builder.create( loads[0].getLoc(), loads.size() * (numStages - 2)); + loopIterIdx = builder.create(iv.getLoc(), 0, 32); for (Value loadOp : loads) { Value extractSlice = builder.create( loadOp.getLoc(), loadsMapping[loadOp].getType(), - loadStageBuffer[loadOp][numStages - 1], - builder.create(loadOp.getLoc(), 0, 32), - /*axis*/ 0); + loadStageBuffer[loadOp][numStages - 1], loopIterIdx, /*axis*/ 0); loadsExtract[loadOp] = extractSlice; } - - loopIterIdx = builder.create(iv.getLoc(), 0, 32); } scf::ForOp LoopPipeliner::createNewForOp() { diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir index ae5f9191f..3e147f8ef 100644 --- a/test/TritonGPU/loop-pipeline.mlir +++ b/test/TritonGPU/loop-pipeline.mlir @@ -9,23 +9,31 @@ #C = #triton_gpu.mma<{version = 2, warpsPerCTA = [4, 1]}> // CHECK: func @matmul_loop +// CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 +// CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 +// CHECK-DAG: %[[CONSTANT_3:.*]] = arith.constant 3 : i32 // CHECK: %[[ABUFFER:.*]] = triton_gpu.alloc_tensor -// CHECK: %[[A0BUFFER:.*]] = triton_gpu.insert_slice_async +// CHECK: %[[A0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]] // CHECK: %[[BBUFFER:.*]] = triton_gpu.alloc_tensor -// CHECK: %[[B0BUFFER:.*]] = triton_gpu.insert_slice_async -// CHECK: %[[A1BUFFER:.*]] = triton_gpu.insert_slice_async -// CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async +// CHECK: %[[B0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]] +// CHECK: %[[A1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]] +// CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]] // CHECK: triton_gpu.async_wait {num = 2 : i32} -// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]] -// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]] -// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}) +// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]], %[[CONSTANT_0]] +// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]], %[[CONSTANT_0]] +// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]] // CHECK: tt.dot %[[arg_a0]], %[[arg_b0]], {{.*}} -// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async -// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async +// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remsi %[[PIPELINE_IDX]], %[[CONSTANT_3]] +// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]] +// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]] +// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]] // CHECK: triton_gpu.async_wait {num = 2 : i32} -// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]] -// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]] -// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]] +// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]], %[[EXTRACT_IDX]] +// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]], %[[EXTRACT_IDX]] +// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]] +// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]] +// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]] func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { %a_ptr_init = tt.broadcast %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL> %b_ptr_init = tt.broadcast %B : (!tt.ptr) -> tensor<32x128x!tt.ptr, #BL> @@ -56,24 +64,32 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B // CHECK: func @matmul_loop_nested +// CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 +// CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 +// CHECK-DAG: %[[CONSTANT_3:.*]] = arith.constant 3 : i32 // CHECK: scf.for // CHECK: %[[ABUFFER:.*]] = triton_gpu.alloc_tensor -// CHECK: %[[A0BUFFER:.*]] = triton_gpu.insert_slice_async +// CHECK: %[[A0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]] // CHECK: %[[BBUFFER:.*]] = triton_gpu.alloc_tensor -// CHECK: %[[B0BUFFER:.*]] = triton_gpu.insert_slice_async -// CHECK: %[[A1BUFFER:.*]] = triton_gpu.insert_slice_async -// CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async +// CHECK: %[[B0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]] +// CHECK: %[[A1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]] +// CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]] // CHECK: triton_gpu.async_wait {num = 2 : i32} -// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]] -// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]] -// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}) +// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]], %[[CONSTANT_0]] +// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]], %[[CONSTANT_0]] +// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]] // CHECK: tt.dot %[[arg_a0]], %[[arg_b0]], {{.*}} -// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async -// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async +// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remsi %[[PIPELINE_IDX]], %[[CONSTANT_3]] +// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]] +// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]] +// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]] // CHECK: triton_gpu.async_wait {num = 2 : i32} -// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]] -// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]] -// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]] +// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]], %[[EXTRACT_IDX]] +// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]], %[[EXTRACT_IDX]] +// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]] +// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]] +// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]] func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { scf.for %iv0 = %lb to %ub step %step { %a_ptr_init = tt.broadcast %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL> @@ -106,17 +122,25 @@ func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { %a_ptr_init = tt.broadcast %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL> %b_ptr_init = tt.broadcast %B : (!tt.ptr) -> tensor<32x128x!tt.ptr, #BL>