[FRONTEND] Add scalar type support for some ops (#661)

This PR adds basic support for scalar-typed inputs to some ops (casts and pointer arithmetic) for Triton-MLIR. It also renames getelementptr -> addptr.
Shintaro Iwasaki
2022-09-15 16:12:52 -07:00
committed by GitHub
parent 2e08450c80
commit 43be75ad42
27 changed files with 203 additions and 129 deletions
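
For context, the snippet below is a minimal, hypothetical sketch of the kind of frontend pattern these scalar paths are meant to support: indexing with a plain program id keeps the pointer arithmetic and the cast at scalar (non-tensor) type, so they lower to a scalar tt.addptr and tt.fp_to_fp rather than their tensor forms. The kernel and launch code are illustrative, not taken from this PR, and assume the frontend accepts scalar loads and stores.

import torch
import triton
import triton.language as tl

@triton.jit
def scalar_copy_cast(src_ptr, dst_ptr):
    pid = tl.program_id(axis=0)
    # Scalar pointer arithmetic: src_ptr + pid stays a plain pointer
    # (no tl.arange, no tensor of pointers), i.e. a scalar tt.addptr.
    x = tl.load(src_ptr + pid)
    # Scalar cast: f32 -> f16 on a single value, i.e. a scalar tt.fp_to_fp.
    y = x.to(tl.float16)
    tl.store(dst_ptr + pid, y)

# Illustrative launch: one program per element.
src = torch.randn(16, device="cuda", dtype=torch.float32)
dst = torch.empty(16, device="cuda", dtype=torch.float16)
scalar_copy_cast[(16,)](src, dst)

With tl.arange-based offsets the same operations would instead take the tensor forms exercised in the new test file below.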

View File

@@ -0,0 +1,55 @@
// RUN: triton-opt %s | FileCheck %s
func @cast_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_f32: f32, %scalar_i64: i64) {
    // scalar -> scalar
    // CHECK: i64 -> !tt.ptr<f32>
    %0 = tt.int_to_ptr %scalar_i64 : i64 -> !tt.ptr<f32>
    // CHECK: !tt.ptr<f32> -> i64
    %1 = tt.ptr_to_int %scalar_ptr : !tt.ptr<f32> -> i64
    // CHECK: f32 -> f16
    %2 = tt.fp_to_fp %scalar_f32 : f32 -> f16

    // 0D tensor -> 0D tensor
    %tensor_ptr_0d = tt.splat %scalar_ptr : (!tt.ptr<f32>) -> tensor<!tt.ptr<f32>>
    %tensor_f32_0d = tt.splat %scalar_f32 : (f32) -> tensor<f32>
    %tensor_i64_0d = tt.splat %scalar_i64 : (i64) -> tensor<i64>
    // CHECK: tensor<i64> -> tensor<!tt.ptr<f32>>
    %3 = tt.int_to_ptr %tensor_i64_0d : tensor<i64> -> tensor<!tt.ptr<f32>>
    // CHECK: tensor<!tt.ptr<f32>> -> tensor<i64>
    %4 = tt.ptr_to_int %tensor_ptr_0d : tensor<!tt.ptr<f32>> -> tensor<i64>
    // CHECK: tensor<f32> -> tensor<f16>
    %5 = tt.fp_to_fp %tensor_f32_0d : tensor<f32> -> tensor<f16>

    // 1D tensor -> 1D tensor
    %tensor_ptr_1d = tt.splat %scalar_ptr : (!tt.ptr<f32>) -> tensor<16x!tt.ptr<f32>>
    %tensor_f32_1d = tt.splat %scalar_f32 : (f32) -> tensor<16xf32>
    %tensor_i64_1d = tt.splat %scalar_i64 : (i64) -> tensor<16xi64>
    // CHECK: tensor<16xi64> -> tensor<16x!tt.ptr<f32>>
    %6 = tt.int_to_ptr %tensor_i64_1d : tensor<16xi64> -> tensor<16x!tt.ptr<f32>>
    // CHECK: tensor<16x!tt.ptr<f32>> -> tensor<16xi64>
    %7 = tt.ptr_to_int %tensor_ptr_1d : tensor<16x!tt.ptr<f32>> -> tensor<16xi64>
    // CHECK: tensor<16xf32> -> tensor<16xf16>
    %8 = tt.fp_to_fp %tensor_f32_1d : tensor<16xf32> -> tensor<16xf16>
    return
}

func @addptr_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_i32: i32) {
    // scalar -> scalar
    // CHECK: !tt.ptr<f32>
    %0 = tt.addptr %scalar_ptr, %scalar_i32 : !tt.ptr<f32>

    // 0D tensor -> 0D tensor
    %tensor_ptr_0d = tt.splat %scalar_ptr : (!tt.ptr<f32>) -> tensor<!tt.ptr<f32>>
    %tensor_i32_0d = tt.splat %scalar_i32 : (i32) -> tensor<i32>
    // CHECK: tensor<!tt.ptr<f32>>
    %1 = tt.addptr %tensor_ptr_0d, %tensor_i32_0d : tensor<!tt.ptr<f32>>

    // 1D tensor -> 1D tensor
    %tensor_ptr_1d = tt.splat %scalar_ptr : (!tt.ptr<f32>) -> tensor<16x!tt.ptr<f32>>
    %tensor_i32_1d = tt.splat %scalar_i32 : (i32) -> tensor<16xi32>
    // CHECK: tensor<16x!tt.ptr<f32>>
    %2 = tt.addptr %tensor_ptr_1d, %tensor_i32_1d : tensor<16x!tt.ptr<f32>>
    return
}

View File

@@ -82,9 +82,9 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
%3 = tt.splat %1 : (i32) -> tensor<256xi32, #blocked0>
%4 = arith.addi %3, %2 : tensor<256xi32, #blocked0>
%5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
-%6 = tt.getelementptr %5, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
+%6 = tt.addptr %5, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
-%8 = tt.getelementptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
+%8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
// CHECK: ld.global.v4.b32
%9 = tt.load %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
@@ -92,7 +92,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
%10 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
%11 = arith.addf %9, %10 : tensor<256xf32, #blocked0>
%12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
-%13 = tt.getelementptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
+%13 = tt.addptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
// Store 4 elements to global
// CHECK: st.global.b32.v4
@@ -104,7 +104,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
// TODO: Add a testcase to verify the optimization when ptr of the LoadOp
-// is from a GEP with const idx
+// is from an addptr with const idx
// -----
@@ -187,11 +187,11 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
-// CHECK-LABEL: basic_gep
-func @basic_gep(%arg0 : tensor<256x!tt.ptr<f32>,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) {
+// CHECK-LABEL: basic_addptr
+func @basic_addptr(%arg0 : tensor<256x!tt.ptr<f32>,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) {
// CHECK: llvm.getelementptr
// CHECK: llvm.getelementptr
-%0 = tt.getelementptr %arg0, %arg1 : tensor<256x!tt.ptr<f32>, #blocked0>
+%0 = tt.addptr %arg0, %arg1 : tensor<256x!tt.ptr<f32>, #blocked0>
return
}
}