[BACKEND] llvm::dyn_cast -> llvm::dyn_cast_or_null (#689)
This commit is contained in:
@@ -61,6 +61,22 @@ func @test_combine_select_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %con
|
||||
return %0, %1 : tensor<8xf32>, tensor<8xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_combine_select_masked_load_fail_pattern
|
||||
func @test_combine_select_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %dummy_load: tensor<8xf32>, %dummy_broadcast: tensor<8xi1>, %cond: i1) -> (tensor<8xf32>, tensor<8xf32>) {
|
||||
%false_val = arith.constant dense<0.0> : tensor<8xf32>
|
||||
|
||||
// Case 1: value at the "load" position is not an "op". Select should not be canonicalized.
|
||||
// CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
|
||||
%0 = select %cond, %dummy_load, %false_val : tensor<8xf32>
|
||||
|
||||
// Case 2: value at the "broadcast" position is not an "op". Select should not be canonicalized.
|
||||
%real_load = tt.load %ptr, %dummy_broadcast, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
|
||||
// CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
|
||||
%1 = select %cond, %real_load, %false_val : tensor<8xf32>
|
||||
|
||||
return %0, %1 : tensor<8xf32>, tensor<8xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_combine_broadcast_constant_pattern
|
||||
func @test_combine_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> {
|
||||
// CHECK: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor<8x2xf32>
|
||||
@@ -92,6 +108,19 @@ func @test_canonicalize_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>) -> (te
|
||||
return %x, %y, %z: tensor<8xf32>, tensor<8xf32>, tensor<8xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_canonicalize_masked_load_fail_pattern
|
||||
func @test_canonicalize_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %mask: tensor<8xi1>) -> (tensor<8xf32>, tensor<8xf32>) {
|
||||
%other_val = arith.constant dense<0.0> : tensor<8xf32>
|
||||
|
||||
// Case: value at the "mask" position is not an "op". Load should not be canonicalized.
|
||||
// CHECK: %[[res1:.*]] = tt.load %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
|
||||
%x = tt.load %ptr, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
|
||||
// CHECK: %[[res1:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
|
||||
%y = tt.load %ptr, %mask, %other_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
|
||||
|
||||
return %x, %y: tensor<8xf32>, tensor<8xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_canonicalize_masked_store_pattern
|
||||
func @test_canonicalize_masked_store_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val: tensor<8xf32>) {
|
||||
%true_mask = arith.constant dense<true> : tensor<8xi1>
|
||||
@@ -105,3 +134,11 @@ func @test_canonicalize_masked_store_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val:
|
||||
tt.store %ptr, %val, %false_mask : tensor<8xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_canonicalize_masked_store_fail_pattern
|
||||
func @test_canonicalize_masked_store_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val: tensor<8xf32>, %mask: tensor<8xi1>) {
|
||||
// Case: value at the "mask" position is not an "op". Store should not be canonicalized.
|
||||
// CHECK: tt.store %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
|
||||
tt.store %ptr, %val, %mask : tensor<8xf32>
|
||||
return
|
||||
}
|
||||
|
Reference in New Issue
Block a user