[FRONTEND] Add scalar type support for some ops (#661)
This PR adds basic support for scalar-type inputs to some ops (cast and pointer arithmetics) for Triton-MLIR. Also renames getelementptr -> addptr
This commit is contained in:
@@ -22,8 +22,8 @@ func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32
|
||||
return %res0, %res1 : tensor<128x128xf32>, tensor<128x128xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_combine_gep_pattern
|
||||
func @test_combine_gep_pattern(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>> {
|
||||
// CHECK-LABEL: @test_combine_addptr_pattern
|
||||
func @test_combine_addptr_pattern(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>> {
|
||||
%off0 = arith.constant 10 : i32
|
||||
%off1 = arith.constant 15 : i32
|
||||
|
||||
@@ -37,9 +37,9 @@ func @test_combine_gep_pattern(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>> {
|
||||
%idx0 = tt.broadcast %off0 : (i32) -> tensor<8xi32>
|
||||
%idx1 = tt.broadcast %off1 : (i32) -> tensor<8xi32>
|
||||
|
||||
// CHECK-NEXT: %1 = tt.getelementptr %[[tmp0]], %[[cst]] : tensor<8x!tt.ptr<f32>>
|
||||
%ptr0 = tt.getelementptr %base_, %idx0 : tensor<8x!tt.ptr<f32>>
|
||||
%ptr1 = tt.getelementptr %ptr0, %idx1 : tensor<8x!tt.ptr<f32>>
|
||||
// CHECK-NEXT: %1 = tt.addptr %[[tmp0]], %[[cst]] : tensor<8x!tt.ptr<f32>>
|
||||
%ptr0 = tt.addptr %base_, %idx0 : tensor<8x!tt.ptr<f32>>
|
||||
%ptr1 = tt.addptr %ptr0, %idx1 : tensor<8x!tt.ptr<f32>>
|
||||
|
||||
return %ptr1 : tensor<8x!tt.ptr<f32>>
|
||||
}
|
||||
|
Reference in New Issue
Block a user