[FRONTEND] Add scalar type support for some ops (#661)
This PR adds basic support for scalar-type inputs to some ops (cast and pointer arithmetics) for Triton-MLIR. Also renames getelementptr -> addptr
This commit is contained in:
55
test/Conversion/triton_ops.mlir
Normal file
55
test/Conversion/triton_ops.mlir
Normal file
@@ -0,0 +1,55 @@
|
||||
// RUN: triton-opt %s | FileCheck %s
|
||||
|
||||
func @cast_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_f32: f32, %scalar_i64: i64) {
|
||||
// scalar -> scalar
|
||||
// CHECK: i64 -> !tt.ptr<f32>
|
||||
%0 = tt.int_to_ptr %scalar_i64 : i64 -> !tt.ptr<f32>
|
||||
// CHECK: !tt.ptr<f32> -> i64
|
||||
%1 = tt.ptr_to_int %scalar_ptr : !tt.ptr<f32> -> i64
|
||||
// CHECK: f32 -> f16
|
||||
%2 = tt.fp_to_fp %scalar_f32 : f32 -> f16
|
||||
|
||||
// 0D tensor -> 0D tensor
|
||||
%tensor_ptr_0d = tt.splat %scalar_ptr : (!tt.ptr<f32>) -> tensor<!tt.ptr<f32>>
|
||||
%tensor_f32_0d = tt.splat %scalar_f32 : (f32) -> tensor<f32>
|
||||
%tensor_i64_0d = tt.splat %scalar_i64 : (i64) -> tensor<i64>
|
||||
|
||||
// CHECK: tensor<i64> -> tensor<!tt.ptr<f32>>
|
||||
%3 = tt.int_to_ptr %tensor_i64_0d : tensor<i64> -> tensor<!tt.ptr<f32>>
|
||||
// CHECK: tensor<!tt.ptr<f32>> -> tensor<i64>
|
||||
%4 = tt.ptr_to_int %tensor_ptr_0d : tensor<!tt.ptr<f32>> -> tensor<i64>
|
||||
// CHECK: tensor<f32> -> tensor<f16>
|
||||
%5 = tt.fp_to_fp %tensor_f32_0d : tensor<f32> -> tensor<f16>
|
||||
|
||||
// 1D tensor -> 1D tensor
|
||||
%tensor_ptr_1d = tt.splat %scalar_ptr : (!tt.ptr<f32>) -> tensor<16x!tt.ptr<f32>>
|
||||
%tensor_f32_1d = tt.splat %scalar_f32 : (f32) -> tensor<16xf32>
|
||||
%tensor_i64_1d = tt.splat %scalar_i64 : (i64) -> tensor<16xi64>
|
||||
|
||||
// CHECK: tensor<16xi64> -> tensor<16x!tt.ptr<f32>>
|
||||
%6 = tt.int_to_ptr %tensor_i64_1d : tensor<16xi64> -> tensor<16x!tt.ptr<f32>>
|
||||
// CHECK: tensor<16x!tt.ptr<f32>> -> tensor<16xi64>
|
||||
%7 = tt.ptr_to_int %tensor_ptr_1d : tensor<16x!tt.ptr<f32>> -> tensor<16xi64>
|
||||
// CHECK: tensor<16xf32> -> tensor<16xf16>
|
||||
%8 = tt.fp_to_fp %tensor_f32_1d : tensor<16xf32> -> tensor<16xf16>
|
||||
return
|
||||
}
|
||||
|
||||
func @addptr_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_i32: i32) {
|
||||
// scalar -> scalar
|
||||
// CHECK: !tt.ptr<f32>
|
||||
%0 = tt.addptr %scalar_ptr, %scalar_i32 : !tt.ptr<f32>
|
||||
|
||||
// 0D tensor -> 0D tensor
|
||||
%tensor_ptr_0d = tt.splat %scalar_ptr : (!tt.ptr<f32>) -> tensor<!tt.ptr<f32>>
|
||||
%tensor_i32_0d = tt.splat %scalar_i32 : (i32) -> tensor<i32>
|
||||
// CHECK: tensor<!tt.ptr<f32>>
|
||||
%1 = tt.addptr %tensor_ptr_0d, %tensor_i32_0d : tensor<!tt.ptr<f32>>
|
||||
|
||||
// 1D tensor -> 1D tensor
|
||||
%tensor_ptr_1d = tt.splat %scalar_ptr : (!tt.ptr<f32>) -> tensor<16x!tt.ptr<f32>>
|
||||
%tensor_i32_1d = tt.splat %scalar_i32 : (i32) -> tensor<16xi32>
|
||||
// CHECK: tensor<16x!tt.ptr<f32>>
|
||||
%2 = tt.addptr %tensor_ptr_1d, %tensor_i32_1d : tensor<16x!tt.ptr<f32>>
|
||||
return
|
||||
}
|
@@ -82,9 +82,9 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
%3 = tt.splat %1 : (i32) -> tensor<256xi32, #blocked0>
|
||||
%4 = arith.addi %3, %2 : tensor<256xi32, #blocked0>
|
||||
%5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
|
||||
%6 = tt.getelementptr %5, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
|
||||
%6 = tt.addptr %5, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
|
||||
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
|
||||
%8 = tt.getelementptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
|
||||
%8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
|
||||
|
||||
// CHECK: ld.global.v4.b32
|
||||
%9 = tt.load %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
|
||||
@@ -92,7 +92,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
%10 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
|
||||
%11 = arith.addf %9, %10 : tensor<256xf32, #blocked0>
|
||||
%12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
|
||||
%13 = tt.getelementptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
|
||||
%13 = tt.addptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
|
||||
|
||||
// Store 4 elements to global
|
||||
// CHECK: st.global.b32.v4
|
||||
@@ -104,7 +104,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
|
||||
|
||||
// TODO: Add a testcase to verify the optimization when ptr of the LoadOp
|
||||
// is from a GEP with const idx
|
||||
// is from an addptr with const idx
|
||||
|
||||
// -----
|
||||
|
||||
@@ -187,11 +187,11 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: basic_gep
|
||||
func @basic_gep(%arg0 : tensor<256x!tt.ptr<f32>,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) {
|
||||
// CHECK-LABEL: basic_addptr
|
||||
func @basic_addptr(%arg0 : tensor<256x!tt.ptr<f32>,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) {
|
||||
// CHECK: llvm.getelementptr
|
||||
// CHECK: llvm.getelementptr
|
||||
%0 = tt.getelementptr %arg0, %arg1 : tensor<256x!tt.ptr<f32>, #blocked0>
|
||||
%0 = tt.addptr %arg0, %arg1 : tensor<256x!tt.ptr<f32>, #blocked0>
|
||||
return
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user