[FRONTEND] Add an attribute for masked loads without an explicit `other` value (#55)

This commit is contained in:
Shintaro Iwasaki
2022-08-18 09:51:37 -07:00
committed by GitHub
parent fc58250a06
commit d69ce77b19
16 changed files with 71 additions and 54 deletions

View File

@@ -45,16 +45,20 @@ func @test_combine_gep_pattern(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>> {
}
// CHECK-LABEL: @test_combine_select_masked_load_pattern
func @test_combine_select_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %cond: i1) -> tensor<8xf32> {
func @test_combine_select_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %cond: i1) -> (tensor<8xf32>, tensor<8xf32>) {
%mask = tt.broadcast %cond : (i1) -> tensor<8xi1>
%false_val = arith.constant dense<0.0> : tensor<8xf32>
// CHECK: %[[res:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
%x = tt.load %ptr, %mask, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
// CHECK: %[[res1:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<8xf32>
%x = tt.load %ptr, %mask, %false_val {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<8xf32>
%0 = select %cond, %x, %false_val : tensor<8xf32>
// CHECK: return %[[res]] : tensor<8xf32>
return %0 : tensor<8xf32>
// CHECK: %[[res2:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = true, isVolatile = false} : tensor<8xf32>
%y = tt.load %ptr, %mask, %false_val {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = true, isVolatile = false} : tensor<8xf32>
%1 = select %cond, %y, %false_val : tensor<8xf32>
// CHECK: return %[[res1]], %[[res2]] : tensor<8xf32>, tensor<8xf32>
return %0, %1 : tensor<8xf32>, tensor<8xf32>
}
// CHECK-LABEL: @test_combine_broadcast_constant_pattern

View File

@@ -24,10 +24,10 @@ module {
%15:3 = scf.for %arg6 = %12 to %13 step %14 iter_args(%arg7 = %11, %arg8 = %8, %arg9 = %10) -> (tensor<256xf32>, tensor<256x!tt.ptr<f32>>, tensor<256x!tt.ptr<f32>>) {
%cst_0 = arith.constant 0.000000e+00 : f32
%18 = tt.broadcast %cst_0 : (f32) -> tensor<256xf32>
%19 = tt.load %arg8, %6, %18 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32>
%19 = tt.load %arg8, %6, %18 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<256xf32>
%cst_1 = arith.constant 0.000000e+00 : f32
%20 = tt.broadcast %cst_1 : (f32) -> tensor<256xf32>
%21 = tt.load %arg9, %6, %20 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32>
%21 = tt.load %arg9, %6, %20 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<256xf32>
%22 = arith.addf %19, %21 : tensor<256xf32>
%23 = arith.addf %arg7, %22 : tensor<256xf32>
%24 = tt.broadcast %arg5 : (i32) -> tensor<256xi32>