[TritonIR] simplify Load/StoreOps when mask is true/false (#79)
* [TritonIR] fix Load/Store/CopyAsyncOp's parsers * [TritonIR] simplify Load/StoreOps when mask is true/false * [TEST] adds tests to check load/store simplification
This commit is contained in:
@@ -70,3 +70,38 @@ func @test_combine_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> {
|
||||
// CHECK-NEXT: return %[[cst]] : tensor<8x2xf32>
|
||||
return %bst_out : tensor<8x2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_canonicalize_masked_load_pattern
|
||||
func @test_canonicalize_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) {
|
||||
%true_mask = arith.constant dense<true> : tensor<8xi1>
|
||||
%false_mask = arith.constant dense<false> : tensor<8xi1>
|
||||
%other_val = arith.constant dense<0.0> : tensor<8xf32>
|
||||
|
||||
// true_mask without other: the constant-true mask operand should be dropped
|
||||
// CHECK: %[[res1:.*]] = tt.load %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
|
||||
%x = tt.load %ptr, %true_mask {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
|
||||
|
||||
// true_mask with other: both the constant-true mask and the "other" operand should be dropped
|
||||
// CHECK: %[[res2:.*]] = tt.load %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
|
||||
%y = tt.load %ptr, %true_mask, %other_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
|
||||
|
||||
// false_mask with other. It should become "other" (i.e., %y)
|
||||
%z = tt.load %ptr, %false_mask, %y {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
|
||||
|
||||
// CHECK: return %[[res1]], %[[res2]], %[[res2]] : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>
|
||||
return %x, %y, %z: tensor<8xf32>, tensor<8xf32>, tensor<8xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_canonicalize_masked_store_pattern
|
||||
func @test_canonicalize_masked_store_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val: tensor<8xf32>) {
|
||||
%true_mask = arith.constant dense<true> : tensor<8xi1>
|
||||
%false_mask = arith.constant dense<false> : tensor<8xi1>
|
||||
|
||||
// CHECK: tt.store %{{.*}}, %{{.*}} : tensor<8xf32>
|
||||
tt.store %ptr, %val, %true_mask : tensor<8xf32>
|
||||
|
||||
// The following store should disappear.
|
||||
// CHECK-NEXT: return
|
||||
tt.store %ptr, %val, %false_mask : tensor<8xf32>
|
||||
return
|
||||
}
|
||||
|
Reference in New Issue
Block a user