[Analysis/Allocation] Allocation pass now assumes that slices always alias (#108)

The code in this branch assumes that the `src` operand of `insert_slice_async`
always aliases the result. This does not hold in the general case; it is a
workaround to make the pipeline pass work.

I'm also working on the complete analysis in another
[branch](https://github.com/openai/triton-mlir/tree/keren/analyze-slice).
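
For illustration, here is a minimal, self-contained C++ sketch of what this assumption means for allocation. The `Allocator` class and its `allocate`/`aliasSlice` methods are hypothetical stand-ins, not the actual Triton analysis API; the point is only that a slice result reuses its source's buffer instead of receiving a fresh one, which is why the allocation tests below report a total shared-memory size of 512 bytes for `insert_slice_async` and `extract_slice`.

```cpp
// Standalone model of the "slices always alias" workaround. This is NOT the
// Triton AliasAnalysis/AllocationAnalysis API; it only illustrates the effect
// on shared-memory sizing.
#include <cassert>
#include <cstdint>
#include <map>
#include <string>
#include <utility>

struct Allocator {
  // Buffer backing each SSA value: (offset, size) in shared memory.
  std::map<std::string, std::pair<int64_t, int64_t>> buffers;
  int64_t totalSize = 0;

  // A value that owns fresh shared memory (e.g. alloc_tensor, a constant).
  void allocate(const std::string &value, int64_t sizeInBytes) {
    buffers[value] = {totalSize, sizeInBytes};
    totalSize += sizeInBytes;
  }

  // The workaround: a slice result is assumed to alias its source tensor,
  // so it reuses the source's buffer instead of allocating a new one.
  void aliasSlice(const std::string &result, const std::string &source) {
    buffers[result] = buffers.at(source);
  }
};

int main() {
  Allocator alloc;
  // %tensor = arith.constant ... : tensor<1x16x16xf16>  -> 1*16*16*2 = 512 B
  alloc.allocate("%tensor", 1 * 16 * 16 * 2);
  // %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %mask, %other
  // The result %a is assumed to alias %tensor: no second 512-byte region.
  alloc.aliasSlice("%a", "%tensor");
  assert(alloc.buffers["%a"] == alloc.buffers["%tensor"]);
  assert(alloc.totalSize == 512); // matches "size = 512" in the tests below
  return 0;
}
```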
Author: Keren Zhou
Date: 2022-09-09 12:03:41 -07:00
Committed by: GitHub
Parent: 9bd5a3dcd2
Commit: 16aed94ff5
14 changed files with 299 additions and 195 deletions


@@ -37,7 +37,9 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
func @alloc(%A : !tt.ptr<f16>) {
// CHECK: %cst -> %cst
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A>
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
// CHECK: %0 -> %0
%cst2 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A>
return
}
@@ -49,25 +51,28 @@ func @convert(%A : !tt.ptr<f16>) {
return
}
// CHECK-LABEL: copy_async
func @copy_async(%A : !tt.ptr<f16>, %i1 : i1) {
// CHECK-LABEL: insert_slice_async
func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>
%mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL>
%other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
// CHECK: %2 -> %2
%a = triton_gpu.copy_async %a_ptr, %mask, %other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<16x16xf16, #A>
// CHECK: %cst_0 -> %cst_0
%tensor = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A>
%index = arith.constant 0 : i32
// CHECK: %2 -> %cst_0
%a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %mask, %other {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<1x16x16xf16, #A>
return
}
// COM: Enable the following test once we support view on shared memory tensors
// COM: // CHECK-LABEL: view
// COM: func @view(%A : !tt.ptr<f16>) {
// COM: // CHECK: res0:0 -> 0
// COM: %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A>
// COM: // CHECK-NEXT: res1:0 -> 0
// COM: %cst1 = tt.view %cst0 : (tensor<16x16xf16, #A>) -> tensor<32x8xf16, #A>
// COM: return
// COM: }
// CHECK-LABEL: extract_slice
func @extract_slice(%A : !tt.ptr<f16>) {
// CHECK: %cst -> %cst
%cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A>
%index = arith.constant 0 : i32
// CHECK-NEXT: %0 -> %cst
%cst1 = triton_gpu.extract_slice %cst0, %index { axis = 0 : i32 } : tensor<1x16x16xf16, #A> -> tensor<16x16xf16, #A>
return
}
// CHECK-LABEL: if_cat
func @if_cat(%i1 : i1) {
@@ -123,62 +128,31 @@ func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.p
return
}
// COM: // Enable the following test once we support view on shared memory tensors
// COM: // CHECK-LABEL: for_if
// COM: func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
// COM: // CHECK: res0:0 -> 0
// COM: %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A>
// COM: // CHECK-NEXT: res1:0 -> 1
// COM: %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A>
// COM: // CHECK-NEXT: res2:0 -> 2
// COM: %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A>
// COM: // CHECK-NEXT: arg3:0 -> 0
// COM: // CHECK-NEXT: arg3:1 -> 1
// COM: // CHECK-NEXT: arg3:2 -> 2
// COM: // CHECK-NEXT: res3:0 -> 0,1
// COM: // CHECK-NEXT: res3:1 -> 0,1
// COM: // CHECK-NEXT: res3:2 -> 0,1
// COM: %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) {
// COM: scf.if %i1 {
// COM: // CHECK-NEXT: res5:0 -> 0,1
// COM: %cst0 = tt.view %a_shared : (tensor<128x32xf16, #A>) -> tensor<32x128xf16, #A>
// COM: scf.yield
// COM: }
// COM: scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>
// COM: }
// COM: return
// COM: }
// COM: // Enable the following test once we support view on shared memory tensors
// COM: // CHECK-LABEL: for_if_else
// COM: func @for_if_else(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
// COM: // CHECK: res0:0 -> 0
// COM: %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A>
// COM: // CHECK-NEXT: res1:0 -> 1
// COM: %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A>
// COM: // CHECK-NEXT: res2:0 -> 2
// COM: %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A>
// COM: // CHECK-NEXT: arg3:0 -> 0
// COM: // CHECK-NEXT: arg3:1 -> 1
// COM: // CHECK-NEXT: arg3:2 -> 2
// COM: // CHECK-NEXT: res3:0 -> 0
// COM: // CHECK-NEXT: res3:1 -> 1
// COM: // CHECK-NEXT: res3:2 -> 0,7
// COM: %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) {
// COM: // CHECK-NEXT: res4:0 -> 0,7
// COM: %c_shared_next = scf.if %i1 -> tensor<128x32xf16, #A> {
// COM: // CHECK-NEXT: res5:0 -> 0
// COM: %cst0 = tt.view %a_shared : (tensor<128x32xf16, #A>) -> tensor<128x32xf16, #A>
// COM: scf.yield %cst0 : tensor<128x32xf16, #A>
// COM: } else {
// COM: // CHECK-NEXT: res7:0 -> 7
// COM: %cst0 = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A>
// COM: scf.yield %cst0 : tensor<128x32xf16, #A>
// COM: }
// COM: scf.yield %a_shared, %b_shared, %c_shared_next : tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>
// COM: }
// COM: return
// COM: }
// CHECK-LABEL: for_if
func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
// CHECK: %cst -> %cst
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A>
// CHECK-NEXT: %cst_0 -> %cst_0
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A>
// CHECK-NEXT: %cst_1 -> %cst_1
%c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A>
// CHECK-NEXT: %arg7 -> %cst
// CHECK-NEXT: %arg8 -> %cst_0
// CHECK-NEXT: %arg9 -> %cst_1
// CHECK-NEXT: %0#0 -> %cst,%cst_0
// CHECK-NEXT: %0#1 -> %cst,%cst_0
// CHECK-NEXT: %0#2 -> %cst,%cst_0
%a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) {
scf.if %i1 {
%index = arith.constant 8 : i32
// CHECK-NEXT: %1 -> %cst,%cst_0
%cst0 = triton_gpu.extract_slice %a_shared, %index { axis = 0 : i32 } : tensor<128x32xf16, #A> -> tensor<32xf16, #A>
scf.yield
}
scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>
}
return
}
// CHECK-LABEL: for_if_for
func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {


@@ -149,6 +149,17 @@ func @longlive(%A : !tt.ptr<f16>) {
// CHECK-NEXT: size = 2560
}
// CHECK-LABEL: alloc
func @alloc(%A : !tt.ptr<f16>) {
// CHECK: offset = 0, size = 512
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A>
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
// CHECK-NEXT: offset = 0, size = 512
%cst2 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A>
return
// CHECK-NEXT: size = 512
}
// CHECK-LABEL: scratch
func @scratch() {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
@@ -158,6 +169,29 @@ func @scratch() {
// CHECK-NEXT: size = 512
}
// CHECK-LABEL: insert_slice_async
func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>
%mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL>
%other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
// CHECK: offset = 0, size = 512
%tensor = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A>
%index = arith.constant 0 : i32
%a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %mask, %other {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<1x16x16xf16, #A>
return
// CHECK-NEXT: size = 512
}
// CHECK-LABEL: extract_slice
func @extract_slice(%A : !tt.ptr<f16>) {
// CHECK: offset = 0, size = 512
%cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A>
%index = arith.constant 0 : i32
%cst1 = triton_gpu.extract_slice %cst0, %index { axis = 0 : i32 } : tensor<1x16x16xf16, #A> -> tensor<16x16xf16, #A>
return
// CHECK-NEXT: size = 512
}
// B0 -> (B1) -> B0
// Memory used by B1 can be reused by B0.
// CHECK-LABEL: if
@@ -226,6 +260,26 @@ func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.p
// CHECK-NEXT: size = 24576
}
// CHECK-LABEL: for_if_slice
func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
// CHECK: offset = 0, size = 8192
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A>
// CHECK-NEXT: offset = 8192, size = 8192
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A>
// CHECK-NEXT: offset = 16384, size = 8192
%c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A>
%a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) {
scf.if %i1 {
%index = arith.constant 8 : i32
%cst0 = triton_gpu.extract_slice %a_shared, %index { axis = 0 : i32 } : tensor<128x32xf16, #A> -> tensor<32xf16, #A>
scf.yield
}
scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>
}
return
// CHECK-NEXT: size = 24576
}
// The liveness ranges of a_shared_init, b_shared_init, and c_shared_init span the entire function up to cst2.
// So they cannot be reused by cst0 and cst1, but can be reused by cst2.
// CHECK-LABEL: for_if_for


@@ -56,6 +56,8 @@ func @war_single_block(%A : !tt.ptr<f16>) {
%a1 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
// CHECK: Membar 5
%a2 = triton_gpu.convert_layout %a1 : (tensor<128x32xf16, #A>) -> tensor<128x32xf16, #AL>
// a2's liveness range ends here, and a3 and a2 have the same address range.
// So it makes sense to have a WAR dependency between a2 and a3.
// CHECK-NEXT: Membar 7
%a3 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
return
@@ -82,6 +84,41 @@ func @async_wait() {
return
}
// CHECK-LABEL: alloc
func @alloc() {
%cst0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A>
%a = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A>
// CHECK: Membar 2
%b = triton_gpu.convert_layout %a : (tensor<32x16xf16, #A>) -> tensor<32x16xf16, #AL>
return
}
// CHECK-LABEL: extract_slice
func @extract_slice() {
%cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A>
%index = arith.constant 0 : i32
%cst1 = triton_gpu.extract_slice %cst0, %index { axis = 0 : i32 } : tensor<1x16x16xf16, #A> -> tensor<16x16xf16, #A>
// CHECK: Membar 3
%cst2 = triton_gpu.convert_layout %cst1 : (tensor<16x16xf16, #A>) -> tensor<16x16xf16, #AL>
// CHECK-NEXT: Membar 5
%cst3 = triton_gpu.convert_layout %cst2 : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #A>
return
}
// CHECK-LABEL: insert_slice_async
func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>
%mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL>
%other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
%tensor = triton_gpu.alloc_tensor : tensor<1x16x16xf16, #A>
%index = arith.constant 0 : i32
%a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %mask, %other {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<1x16x16xf16, #A>
%b = tt.cat %a, %a {axis = 0} : (tensor<1x16x16xf16, #A>, tensor<1x16x16xf16, #A>) -> tensor<2x16x16xf16, #A>
// CHECK: Membar 7
%c = tt.cat %b, %b {axis = 0} : (tensor<2x16x16xf16, #A>, tensor<2x16x16xf16, #A>) -> tensor<4x16x16xf16, #A>
return
}
// If the `if` branch inserted a barrier for %cst0 and %cst1 but the `else` branch didn't, then the barrier should be inserted in the parent region
// CHECK-LABEL: multi_blocks
func @multi_blocks(%i1 : i1) {

test/TritonGPU/matmul.mlir (new file, 106 lines)

@@ -0,0 +1,106 @@
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu -tritongpu-pipeline=num-stages=3 -tritongpu-combine -test-print-allocation 2>&1 | FileCheck %s
// CHECK: offset = 0, size = 49152
// CHECK: offset = 49152, size = 49152
// CHECK: size = 98304
module {
func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c64_13c64_14c64_15c8(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32) {
%cst = arith.constant dense<true> : tensor<64x64xi1>
%c64 = arith.constant 64 : index
%c0 = arith.constant 0 : index
%cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32>
%c64_i32 = arith.constant 64 : i32
%c63_i32 = arith.constant 63 : i32
%c8_i32 = arith.constant 8 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.addi %arg3, %c63_i32 : i32
%2 = arith.divsi %1, %c64_i32 : i32
%3 = arith.addi %arg4, %c63_i32 : i32
%4 = arith.divsi %3, %c64_i32 : i32
%5 = arith.muli %4, %c8_i32 : i32
%6 = arith.divsi %0, %5 : i32
%7 = arith.muli %6, %c8_i32 : i32
%8 = arith.subi %2, %7 : i32
%9 = arith.cmpi slt, %8, %c8_i32 : i32
%10 = select %9, %8, %c8_i32 : i32
%11 = arith.remsi %0, %10 : i32
%12 = arith.addi %7, %11 : i32
%13 = arith.remsi %0, %5 : i32
%14 = arith.divsi %13, %10 : i32
%15 = arith.muli %12, %c64_i32 : i32
%16 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%17 = tt.splat %15 : (i32) -> tensor<64xi32>
%18 = arith.addi %17, %16 : tensor<64xi32>
%19 = arith.muli %14, %c64_i32 : i32
%20 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%21 = tt.splat %19 : (i32) -> tensor<64xi32>
%22 = arith.addi %21, %20 : tensor<64xi32>
%23 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%24 = tt.expand_dims %18 {axis = 1 : i32} : (tensor<64xi32>) -> tensor<64x1xi32>
%25 = tt.splat %arg6 : (i32) -> tensor<64x1xi32>
%26 = arith.muli %24, %25 : tensor<64x1xi32>
%27 = tt.expand_dims %23 {axis = 0 : i32} : (tensor<64xi32>) -> tensor<1x64xi32>
%28 = tt.splat %arg7 : (i32) -> tensor<1x64xi32>
%29 = arith.muli %27, %28 : tensor<1x64xi32>
%30 = tt.broadcast %26 : (tensor<64x1xi32>) -> tensor<64x64xi32>
%31 = tt.broadcast %29 : (tensor<1x64xi32>) -> tensor<64x64xi32>
%32 = arith.addi %30, %31 : tensor<64x64xi32>
%33 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x64x!tt.ptr<f32>>
%34 = tt.getelementptr %33, %32 : tensor<64x64x!tt.ptr<f32>>
%35 = tt.expand_dims %23 {axis = 1 : i32} : (tensor<64xi32>) -> tensor<64x1xi32>
%36 = tt.splat %arg8 : (i32) -> tensor<64x1xi32>
%37 = arith.muli %35, %36 : tensor<64x1xi32>
%38 = tt.expand_dims %22 {axis = 0 : i32} : (tensor<64xi32>) -> tensor<1x64xi32>
%39 = tt.splat %arg9 : (i32) -> tensor<1x64xi32>
%40 = arith.muli %38, %39 : tensor<1x64xi32>
%41 = tt.broadcast %37 : (tensor<64x1xi32>) -> tensor<64x64xi32>
%42 = tt.broadcast %40 : (tensor<1x64xi32>) -> tensor<64x64xi32>
%43 = arith.addi %41, %42 : tensor<64x64xi32>
%44 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<64x64x!tt.ptr<f32>>
%45 = tt.getelementptr %44, %43 : tensor<64x64x!tt.ptr<f32>>
%46 = arith.index_cast %arg5 : i32 to index
%47:3 = scf.for %arg12 = %c0 to %46 step %c64 iter_args(%arg13 = %cst_0, %arg14 = %34, %arg15 = %45) -> (tensor<64x64xf32>, tensor<64x64x!tt.ptr<f32>>, tensor<64x64x!tt.ptr<f32>>) {
%76 = tt.load %arg14, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32>
%77 = tt.load %arg15, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32>
%78 = tt.dot %76, %77, %cst_0 {allowTF32 = true} : tensor<64x64xf32> * tensor<64x64xf32> -> tensor<64x64xf32>
%79 = arith.addf %arg13, %78 : tensor<64x64xf32>
%80 = arith.muli %arg7, %c64_i32 : i32
%81 = tt.splat %80 : (i32) -> tensor<64x64xi32>
%82 = tt.getelementptr %arg14, %81 : tensor<64x64x!tt.ptr<f32>>
%83 = arith.muli %arg8, %c64_i32 : i32
%84 = tt.splat %83 : (i32) -> tensor<64x64xi32>
%85 = tt.getelementptr %arg15, %84 : tensor<64x64x!tt.ptr<f32>>
scf.yield %79, %82, %85 : tensor<64x64xf32>, tensor<64x64x!tt.ptr<f32>>, tensor<64x64x!tt.ptr<f32>>
}
%48 = arith.muli %12, %c64_i32 : i32
%49 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%50 = tt.splat %48 : (i32) -> tensor<64xi32>
%51 = arith.addi %50, %49 : tensor<64xi32>
%52 = arith.muli %14, %c64_i32 : i32
%53 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%54 = tt.splat %52 : (i32) -> tensor<64xi32>
%55 = arith.addi %54, %53 : tensor<64xi32>
%56 = tt.expand_dims %51 {axis = 1 : i32} : (tensor<64xi32>) -> tensor<64x1xi32>
%57 = tt.splat %arg10 : (i32) -> tensor<64x1xi32>
%58 = arith.muli %57, %56 : tensor<64x1xi32>
%59 = tt.expand_dims %55 {axis = 0 : i32} : (tensor<64xi32>) -> tensor<1x64xi32>
%60 = tt.splat %arg11 : (i32) -> tensor<1x64xi32>
%61 = arith.muli %59, %60 : tensor<1x64xi32>
%62 = tt.broadcast %58 : (tensor<64x1xi32>) -> tensor<64x64xi32>
%63 = tt.broadcast %61 : (tensor<1x64xi32>) -> tensor<64x64xi32>
%64 = arith.addi %62, %63 : tensor<64x64xi32>
%65 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x64x!tt.ptr<f32>>
%66 = tt.getelementptr %65, %64 : tensor<64x64x!tt.ptr<f32>>
%67 = tt.expand_dims %51 {axis = 1 : i32} : (tensor<64xi32>) -> tensor<64x1xi32>
%68 = tt.splat %arg3 : (i32) -> tensor<64x1xi32>
%69 = arith.cmpi slt, %67, %68 : tensor<64x1xi32>
%70 = tt.expand_dims %55 {axis = 0 : i32} : (tensor<64xi32>) -> tensor<1x64xi32>
%71 = tt.splat %arg4 : (i32) -> tensor<1x64xi32>
%72 = arith.cmpi slt, %70, %71 : tensor<1x64xi32>
%73 = tt.broadcast %69 : (tensor<64x1xi1>) -> tensor<64x64xi1>
%74 = tt.broadcast %72 : (tensor<1x64xi1>) -> tensor<64x64xi1>
%75 = arith.andi %73, %74 : tensor<64x64xi1>
tt.store %66, %47#0, %75 : tensor<64x64xf32>
return
}
}


@@ -29,7 +29,7 @@ struct TestMembarPass
MembarAnalysis analysis(&allocation);
size_t operationId = 0;
operation->walk<WalkOrder::PreOrder>([&](Operation *op) {
if (dyn_cast<gpu::BarrierOp>(op)) {
if (isa<gpu::BarrierOp>(op)) {
os << "Membar " << operationId << "\n";
}
if (op->getNumRegions() == 0) {