[FRONTEND] Add scalar type support for some ops (#661)
This PR adds basic support for scalar-type inputs to some ops (cast and pointer arithmetics) for Triton-MLIR. Also renames getelementptr -> addptr
This commit is contained in:
@@ -46,7 +46,7 @@ func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c6
|
||||
%31 = tt.broadcast %29 : (tensor<1x64xi32>) -> tensor<64x64xi32>
|
||||
%32 = arith.addi %30, %31 : tensor<64x64xi32>
|
||||
%33 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x64x!tt.ptr<f32>>
|
||||
%34 = tt.getelementptr %33, %32 : tensor<64x64x!tt.ptr<f32>>
|
||||
%34 = tt.addptr %33, %32 : tensor<64x64x!tt.ptr<f32>>
|
||||
%35 = tt.expand_dims %23 {axis = 1 : i32} : (tensor<64xi32>) -> tensor<64x1xi32>
|
||||
%36 = tt.splat %arg8 : (i32) -> tensor<64x1xi32>
|
||||
%37 = arith.muli %35, %36 : tensor<64x1xi32>
|
||||
@@ -57,7 +57,7 @@ func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c6
|
||||
%42 = tt.broadcast %40 : (tensor<1x64xi32>) -> tensor<64x64xi32>
|
||||
%43 = arith.addi %41, %42 : tensor<64x64xi32>
|
||||
%44 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<64x64x!tt.ptr<f32>>
|
||||
%45 = tt.getelementptr %44, %43 : tensor<64x64x!tt.ptr<f32>>
|
||||
%45 = tt.addptr %44, %43 : tensor<64x64x!tt.ptr<f32>>
|
||||
%46 = arith.index_cast %arg5 : i32 to index
|
||||
%47:3 = scf.for %arg12 = %c0 to %46 step %c64 iter_args(%arg13 = %cst_0, %arg14 = %34, %arg15 = %45) -> (tensor<64x64xf32>, tensor<64x64x!tt.ptr<f32>>, tensor<64x64x!tt.ptr<f32>>) {
|
||||
%76 = tt.load %arg14, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32>
|
||||
@@ -66,10 +66,10 @@ func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c6
|
||||
%79 = arith.addf %arg13, %78 : tensor<64x64xf32>
|
||||
%80 = arith.muli %arg7, %c64_i32 : i32
|
||||
%81 = tt.splat %80 : (i32) -> tensor<64x64xi32>
|
||||
%82 = tt.getelementptr %arg14, %81 : tensor<64x64x!tt.ptr<f32>>
|
||||
%82 = tt.addptr %arg14, %81 : tensor<64x64x!tt.ptr<f32>>
|
||||
%83 = arith.muli %arg8, %c64_i32 : i32
|
||||
%84 = tt.splat %83 : (i32) -> tensor<64x64xi32>
|
||||
%85 = tt.getelementptr %arg15, %84 : tensor<64x64x!tt.ptr<f32>>
|
||||
%85 = tt.addptr %arg15, %84 : tensor<64x64x!tt.ptr<f32>>
|
||||
scf.yield %79, %82, %85 : tensor<64x64xf32>, tensor<64x64x!tt.ptr<f32>>, tensor<64x64x!tt.ptr<f32>>
|
||||
}
|
||||
%48 = arith.muli %12, %c64_i32 : i32
|
||||
@@ -90,7 +90,7 @@ func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c6
|
||||
%63 = tt.broadcast %61 : (tensor<1x64xi32>) -> tensor<64x64xi32>
|
||||
%64 = arith.addi %62, %63 : tensor<64x64xi32>
|
||||
%65 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x64x!tt.ptr<f32>>
|
||||
%66 = tt.getelementptr %65, %64 : tensor<64x64x!tt.ptr<f32>>
|
||||
%66 = tt.addptr %65, %64 : tensor<64x64x!tt.ptr<f32>>
|
||||
%67 = tt.expand_dims %51 {axis = 1 : i32} : (tensor<64xi32>) -> tensor<64x1xi32>
|
||||
%68 = tt.splat %arg3 : (i32) -> tensor<64x1xi32>
|
||||
%69 = arith.cmpi slt, %67, %68 : tensor<64x1xi32>
|
||||
@@ -103,4 +103,4 @@ func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c6
|
||||
tt.store %66, %47#0, %75 : tensor<64x64xf32>
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user