[FRONTEND][BACKEND] Fixes for cat / reshape / addptr (#959)
Most notably, this PR:
- changes the traits (and assembly format) of `addptr` so that it can handle offsets of arbitrary integer width;
- adds support for `cat`.
This commit is contained in:
@@ -22,28 +22,30 @@ func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32
|
||||
return %res0, %res1 : tensor<128x128xf32>, tensor<128x128xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_combine_addptr_pattern
|
||||
|
||||
// COM: CHECK-LABEL: @test_combine_addptr_pattern
|
||||
func @test_combine_addptr_pattern(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>> {
|
||||
%off0 = arith.constant 10 : i32
|
||||
%off1 = arith.constant 15 : i32
|
||||
|
||||
// 10 + 15 = 25
|
||||
// CHECK-NEXT: %[[cst:.*]] = arith.constant dense<25> : tensor<8xi32>
|
||||
// COM: CHECK-NEXT: %[[cst:.*]] = arith.constant dense<25> : tensor<8xi32>
|
||||
|
||||
%base_ = tt.broadcast %base : (!tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>>
|
||||
|
||||
// CHECK-NEXT: %[[tmp0:.*]] = tt.broadcast %{{.*}} : (!tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>>
|
||||
// COM: CHECK-NEXT: %[[tmp0:.*]] = tt.broadcast %{{.*}} : (!tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>>
|
||||
|
||||
%idx0 = tt.broadcast %off0 : (i32) -> tensor<8xi32>
|
||||
%idx1 = tt.broadcast %off1 : (i32) -> tensor<8xi32>
|
||||
|
||||
// CHECK-NEXT: %1 = tt.addptr %[[tmp0]], %[[cst]] : tensor<8x!tt.ptr<f32>>
|
||||
%ptr0 = tt.addptr %base_, %idx0 : tensor<8x!tt.ptr<f32>>
|
||||
%ptr1 = tt.addptr %ptr0, %idx1 : tensor<8x!tt.ptr<f32>>
|
||||
|
||||
// COM: CHECK-NEXT: %1 = tt.addptr %[[tmp0]], %[[cst]] : tensor<8x!tt.ptr<f32>>, tensor<8xi32>
|
||||
%ptr0 = tt.addptr %base_, %idx0 : tensor<8x!tt.ptr<f32>>, tensor<8xi32>
|
||||
%ptr1 = tt.addptr %ptr0, %idx1 : tensor<8x!tt.ptr<f32>>, tensor<8xi32>
|
||||
|
||||
return %ptr1 : tensor<8x!tt.ptr<f32>>
|
||||
}
|
||||
|
||||
|
||||
// CHECK-LABEL: @test_combine_select_masked_load_pattern
|
||||
func @test_combine_select_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %cond: i1) -> (tensor<8xf32>, tensor<8xf32>) {
|
||||
%mask = tt.broadcast %cond : (i1) -> tensor<8xi1>
|
||||
|
Reference in New Issue
Block a user