[FRONTEND] Add scalar type support for some ops (#661)

This PR adds basic support for scalar-type inputs to some ops (cast and pointer arithmetics) for Triton-MLIR. Also renames getelementptr -> addptr
This commit is contained in:
Shintaro Iwasaki
2022-09-15 16:12:52 -07:00
committed by GitHub
parent 2e08450c80
commit 43be75ad42
27 changed files with 203 additions and 129 deletions

View File

@@ -22,8 +22,8 @@ func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32
return %res0, %res1 : tensor<128x128xf32>, tensor<128x128xf32>
}
// CHECK-LABEL: @test_combine_gep_pattern
func @test_combine_gep_pattern(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>> {
// CHECK-LABEL: @test_combine_addptr_pattern
func @test_combine_addptr_pattern(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>> {
%off0 = arith.constant 10 : i32
%off1 = arith.constant 15 : i32
@@ -37,9 +37,9 @@ func @test_combine_gep_pattern(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>> {
%idx0 = tt.broadcast %off0 : (i32) -> tensor<8xi32>
%idx1 = tt.broadcast %off1 : (i32) -> tensor<8xi32>
// CHECK-NEXT: %1 = tt.getelementptr %[[tmp0]], %[[cst]] : tensor<8x!tt.ptr<f32>>
%ptr0 = tt.getelementptr %base_, %idx0 : tensor<8x!tt.ptr<f32>>
%ptr1 = tt.getelementptr %ptr0, %idx1 : tensor<8x!tt.ptr<f32>>
// CHECK-NEXT: %1 = tt.addptr %[[tmp0]], %[[cst]] : tensor<8x!tt.ptr<f32>>
%ptr0 = tt.addptr %base_, %idx0 : tensor<8x!tt.ptr<f32>>
%ptr1 = tt.addptr %ptr0, %idx1 : tensor<8x!tt.ptr<f32>>
return %ptr1 : tensor<8x!tt.ptr<f32>>
}