[FRONTEND][BACKEND] Fixes for cat / reshape / addptr (#959)

Most notably, this PR:
- changes the traits (and assembly format) of `addptr` so it can handle offsets of arbitrary integer width (new syntax sketched below);
- adds support for `cat`.
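For context, `tt.addptr`'s old assembly format printed only the pointer tensor type (visible in the removed lines below), so the offset width was fixed by the op itself; the new format also prints the offset tensor type, letting any integer width through the verifier. A minimal sketch of the new syntax alongside `tt.cat` (a hypothetical kernel: the `i64` width, the value names, and the functional-type spelling of `tt.cat` are illustrative, mirroring the style of the IR in this diff rather than quoting it):

```mlir
// Hypothetical function showing the two ops this PR touches.
func @addptr_cat_sketch(%base: !tt.ptr<f32>, %offs: tensor<128xi64>,
                        %a: tensor<64xf32>, %b: tensor<64xf32>) {
  %ptrs = tt.splat %base : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>>
  // New addptr format: the offset tensor type follows the pointer
  // tensor type, so widths other than i32 (here, i64) are expressible.
  %p = tt.addptr %ptrs, %offs : tensor<128x!tt.ptr<f32>>, tensor<128xi64>
  // cat concatenates its two operands along the leading dimension.
  %c = tt.cat %a, %b : (tensor<64xf32>, tensor<64xf32>) -> tensor<128xf32>
  return
}
```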
Philippe Tillet
2022-12-06 23:29:50 -08:00
committed by GitHub
parent 981aee7f1e
commit b2b793dfb5
24 changed files with 199 additions and 132 deletions

@@ -46,7 +46,7 @@ func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c6
%31 = tt.broadcast %29 : (tensor<1x64xi32>) -> tensor<64x64xi32>
%32 = arith.addi %30, %31 : tensor<64x64xi32>
%33 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x64x!tt.ptr<f32>>
-%34 = tt.addptr %33, %32 : tensor<64x64x!tt.ptr<f32>>
+%34 = tt.addptr %33, %32 : tensor<64x64x!tt.ptr<f32>>, tensor<64x64xi32>
%35 = tt.expand_dims %23 {axis = 1 : i32} : (tensor<64xi32>) -> tensor<64x1xi32>
%36 = tt.splat %arg8 : (i32) -> tensor<64x1xi32>
%37 = arith.muli %35, %36 : tensor<64x1xi32>
@@ -57,7 +57,7 @@ func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c6
%42 = tt.broadcast %40 : (tensor<1x64xi32>) -> tensor<64x64xi32>
%43 = arith.addi %41, %42 : tensor<64x64xi32>
%44 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<64x64x!tt.ptr<f32>>
-%45 = tt.addptr %44, %43 : tensor<64x64x!tt.ptr<f32>>
+%45 = tt.addptr %44, %43 : tensor<64x64x!tt.ptr<f32>>, tensor<64x64xi32>
%46 = arith.index_cast %arg5 : i32 to index
%47:3 = scf.for %arg12 = %c0 to %46 step %c64 iter_args(%arg13 = %cst_0, %arg14 = %34, %arg15 = %45) -> (tensor<64x64xf32>, tensor<64x64x!tt.ptr<f32>>, tensor<64x64x!tt.ptr<f32>>) {
%76 = tt.load %arg14, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false, transA=false, transB=false} : tensor<64x64xf32>
@@ -66,10 +66,10 @@ func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c6
%79 = arith.addf %arg13, %78 : tensor<64x64xf32>
%80 = arith.muli %arg7, %c64_i32 : i32
%81 = tt.splat %80 : (i32) -> tensor<64x64xi32>
-%82 = tt.addptr %arg14, %81 : tensor<64x64x!tt.ptr<f32>>
+%82 = tt.addptr %arg14, %81 : tensor<64x64x!tt.ptr<f32>>, tensor<64x64xi32>
%83 = arith.muli %arg8, %c64_i32 : i32
%84 = tt.splat %83 : (i32) -> tensor<64x64xi32>
-%85 = tt.addptr %arg15, %84 : tensor<64x64x!tt.ptr<f32>>
+%85 = tt.addptr %arg15, %84 : tensor<64x64x!tt.ptr<f32>>, tensor<64x64xi32>
scf.yield %79, %82, %85 : tensor<64x64xf32>, tensor<64x64x!tt.ptr<f32>>, tensor<64x64x!tt.ptr<f32>>
}
%48 = arith.muli %12, %c64_i32 : i32
@@ -90,7 +90,7 @@ func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c6
%63 = tt.broadcast %61 : (tensor<1x64xi32>) -> tensor<64x64xi32>
%64 = arith.addi %62, %63 : tensor<64x64xi32>
%65 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x64x!tt.ptr<f32>>
-%66 = tt.addptr %65, %64 : tensor<64x64x!tt.ptr<f32>>
+%66 = tt.addptr %65, %64 : tensor<64x64x!tt.ptr<f32>>, tensor<64x64xi32>
%67 = tt.expand_dims %51 {axis = 1 : i32} : (tensor<64xi32>) -> tensor<64x1xi32>
%68 = tt.splat %arg3 : (i32) -> tensor<64x1xi32>
%69 = arith.cmpi slt, %67, %68 : tensor<64x1xi32>