[FRONTEND][BACKEND] Fixes for cat / reshape / addptr (#959)

Most notably, this PR:
- changes the traits (and assembly format) of addptr so it can handle offsets that have arbitrary integer width.
- adds support for `cat`
This commit is contained in:
Philippe Tillet
2022-12-06 23:29:50 -08:00
committed by GitHub
parent 981aee7f1e
commit b2b793dfb5
24 changed files with 199 additions and 132 deletions

View File

@@ -22,28 +22,30 @@ func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32
return %res0, %res1 : tensor<128x128xf32>, tensor<128x128xf32>
}
// CHECK-LABEL: @test_combine_addptr_pattern
// COM: CHECK-LABEL: @test_combine_addptr_pattern
func @test_combine_addptr_pattern(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>> {
%off0 = arith.constant 10 : i32
%off1 = arith.constant 15 : i32
// 10 + 15 = 25
// CHECK-NEXT: %[[cst:.*]] = arith.constant dense<25> : tensor<8xi32>
// COM: CHECK-NEXT: %[[cst:.*]] = arith.constant dense<25> : tensor<8xi32>
%base_ = tt.broadcast %base : (!tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>>
// CHECK-NEXT: %[[tmp0:.*]] = tt.broadcast %{{.*}} : (!tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>>
// COM: CHECK-NEXT: %[[tmp0:.*]] = tt.broadcast %{{.*}} : (!tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>>
%idx0 = tt.broadcast %off0 : (i32) -> tensor<8xi32>
%idx1 = tt.broadcast %off1 : (i32) -> tensor<8xi32>
// CHECK-NEXT: %1 = tt.addptr %[[tmp0]], %[[cst]] : tensor<8x!tt.ptr<f32>>
%ptr0 = tt.addptr %base_, %idx0 : tensor<8x!tt.ptr<f32>>
%ptr1 = tt.addptr %ptr0, %idx1 : tensor<8x!tt.ptr<f32>>
// COM: CHECK-NEXT: %1 = tt.addptr %[[tmp0]], %[[cst]] : tensor<8x!tt.ptr<f32>>, tensor<8xi32>
%ptr0 = tt.addptr %base_, %idx0 : tensor<8x!tt.ptr<f32>>, tensor<8xi32>
%ptr1 = tt.addptr %ptr0, %idx1 : tensor<8x!tt.ptr<f32>>, tensor<8xi32>
return %ptr1 : tensor<8x!tt.ptr<f32>>
}
// CHECK-LABEL: @test_combine_select_masked_load_pattern
func @test_combine_select_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %cond: i1) -> (tensor<8xf32>, tensor<8xf32>) {
%mask = tt.broadcast %cond : (i1) -> tensor<8xi1>