[Triton-MLIR] Support FP8 (#864)

Co-authored-by: Superjomn <yanchunwei@outlook.com>
This commit is contained in:
Chenggang Zhao
2022-11-10 15:53:06 +08:00
committed by GitHub
parent 4946167241
commit 57fd1864a7
18 changed files with 571 additions and 160 deletions

View File

@@ -6,8 +6,8 @@ func @cast_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_f32: f32, %scalar_i64: i64) {
%0 = tt.int_to_ptr %scalar_i64 : i64 -> !tt.ptr<f32>
// CHECK: !tt.ptr<f32> -> i64
%1 = tt.ptr_to_int %scalar_ptr : !tt.ptr<f32> -> i64
// CHECK: f32 -> f16
%2 = tt.fp_to_fp %scalar_f32 : f32 -> f16
// CHECK: f32 to f16
%2 = arith.truncf %scalar_f32 : f32 to f16
// 0D tensor -> 0D tensor
%tensor_ptr_0d = tt.splat %scalar_ptr : (!tt.ptr<f32>) -> tensor<!tt.ptr<f32>>
@@ -18,8 +18,8 @@ func @cast_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_f32: f32, %scalar_i64: i64) {
%3 = tt.int_to_ptr %tensor_i64_0d : tensor<i64> -> tensor<!tt.ptr<f32>>
// CHECK: tensor<!tt.ptr<f32>> -> tensor<i64>
%4 = tt.ptr_to_int %tensor_ptr_0d : tensor<!tt.ptr<f32>> -> tensor<i64>
// CHECK: tensor<f32> -> tensor<f16>
%5 = tt.fp_to_fp %tensor_f32_0d : tensor<f32> -> tensor<f16>
// CHECK: tensor<f32> to tensor<f16>
%5 = arith.truncf %tensor_f32_0d : tensor<f32> to tensor<f16>
// 1D tensor -> 1D tensor
%tensor_ptr_1d = tt.splat %scalar_ptr : (!tt.ptr<f32>) -> tensor<16x!tt.ptr<f32>>
@@ -30,8 +30,8 @@ func @cast_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_f32: f32, %scalar_i64: i64) {
%6 = tt.int_to_ptr %tensor_i64_1d : tensor<16xi64> -> tensor<16x!tt.ptr<f32>>
// CHECK: tensor<16x!tt.ptr<f32>> -> tensor<16xi64>
%7 = tt.ptr_to_int %tensor_ptr_1d : tensor<16x!tt.ptr<f32>> -> tensor<16xi64>
// CHECK: tensor<16xf32> -> tensor<16xf16>
%8 = tt.fp_to_fp %tensor_f32_1d : tensor<16xf32> -> tensor<16xf16>
// CHECK: tensor<16xf32> to tensor<16xf16>
%8 = arith.truncf %tensor_f32_1d : tensor<16xf32> to tensor<16xf16>
return
}