// RUN: triton-opt %s | FileCheck %s

func @cast_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_f32: f32, %scalar_i64: i64) {
  // scalar -> scalar
  // CHECK: i64 -> !tt.ptr<f32>
  %0 = tt.int_to_ptr %scalar_i64 : i64 -> !tt.ptr<f32>
  // CHECK: !tt.ptr<f32> -> i64
  %1 = tt.ptr_to_int %scalar_ptr : !tt.ptr<f32> -> i64
  // CHECK: f32 to f16
  %2 = arith.truncf %scalar_f32 : f32 to f16

  // 0D tensor -> 0D tensor
  %tensor_ptr_0d = tt.splat %scalar_ptr : (!tt.ptr<f32>) -> tensor<!tt.ptr<f32>>
  %tensor_f32_0d = tt.splat %scalar_f32 : (f32) -> tensor<f32>
  %tensor_i64_0d = tt.splat %scalar_i64 : (i64) -> tensor<i64>
  // CHECK: tensor<i64> -> tensor<!tt.ptr<f32>>
  %3 = tt.int_to_ptr %tensor_i64_0d : tensor<i64> -> tensor<!tt.ptr<f32>>
  // CHECK: tensor<!tt.ptr<f32>> -> tensor<i64>
  %4 = tt.ptr_to_int %tensor_ptr_0d : tensor<!tt.ptr<f32>> -> tensor<i64>
  // CHECK: tensor<f32> to tensor<f16>
  %5 = arith.truncf %tensor_f32_0d : tensor<f32> to tensor<f16>

  // 1D tensor -> 1D tensor
  %tensor_ptr_1d = tt.splat %scalar_ptr : (!tt.ptr<f32>) -> tensor<16x!tt.ptr<f32>>
  %tensor_f32_1d = tt.splat %scalar_f32 : (f32) -> tensor<16xf32>
  %tensor_i64_1d = tt.splat %scalar_i64 : (i64) -> tensor<16xi64>
  // CHECK: tensor<16xi64> -> tensor<16x!tt.ptr<f32>>
  %6 = tt.int_to_ptr %tensor_i64_1d : tensor<16xi64> -> tensor<16x!tt.ptr<f32>>
  // CHECK: tensor<16x!tt.ptr<f32>> -> tensor<16xi64>
  %7 = tt.ptr_to_int %tensor_ptr_1d : tensor<16x!tt.ptr<f32>> -> tensor<16xi64>
  // CHECK: tensor<16xf32> to tensor<16xf16>
  %8 = arith.truncf %tensor_f32_1d : tensor<16xf32> to tensor<16xf16>
  return
}

func @addptr_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_i32: i32) {
  // scalar -> scalar
  // CHECK: !tt.ptr<f32>
  %0 = tt.addptr %scalar_ptr, %scalar_i32 : !tt.ptr<f32>

  // 0D tensor -> 0D tensor
  %tensor_ptr_0d = tt.splat %scalar_ptr : (!tt.ptr<f32>) -> tensor<!tt.ptr<f32>>
  %tensor_i32_0d = tt.splat %scalar_i32 : (i32) -> tensor<i32>
  // CHECK: tensor<!tt.ptr<f32>>
  %1 = tt.addptr %tensor_ptr_0d, %tensor_i32_0d : tensor<!tt.ptr<f32>>

  // 1D tensor -> 1D tensor
  %tensor_ptr_1d = tt.splat %scalar_ptr : (!tt.ptr<f32>) -> tensor<16x!tt.ptr<f32>>
  %tensor_i32_1d = tt.splat %scalar_i32 : (i32) -> tensor<16xi32>
  // CHECK: tensor<16x!tt.ptr<f32>>
  %2 = tt.addptr %tensor_ptr_1d, %tensor_i32_1d : tensor<16x!tt.ptr<f32>>
  return
}

func @load_store_ops_scalar(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %mask : i1) {
  // Test if Load/Store ops can handle scalar values
  %other = arith.constant 0.0e+0 : f32

  // load scalar
  // CHECK: %[[L0:.*]] = tt.load %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : f32
  %a = tt.load %ptr {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : f32
  // CHECK: %[[L1:.*]] = tt.load %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : f32
  %b = tt.load %ptr, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : f32
  // CHECK: %[[L2:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : f32
  %c = tt.load %ptr, %mask, %other {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : f32

  // store scalar
  // CHECK: tt.store %{{.*}}, %[[L0]] : f32
  tt.store %ptr, %a : f32
  // CHECK: tt.store %{{.*}}, %[[L1]], %{{.*}} : f32
  tt.store %ptr, %b, %mask : f32
  // CHECK: tt.store %{{.*}}, %[[L2]], %{{.*}} : f32
  tt.store %ptr, %c, %mask : f32
  return
}

func @reduce_ops_infer(%ptr: !tt.ptr<f32>, %v : tensor<1x2x4xf32>) {
  // Test if reduce ops infer types correctly
  // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<2x4xf32>
  %a = tt.reduce %v {redOp = 1 : i32, axis = 0 : i32} : tensor<1x2x4xf32> -> tensor<2x4xf32>
  // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<1x4xf32>
  %b = tt.reduce %v {redOp = 1 : i32, axis = 1 : i32} : tensor<1x2x4xf32> -> tensor<1x4xf32>
  // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<1x2xf32>
  %c = tt.reduce %v {redOp = 1 : i32, axis = 2 : i32} : tensor<1x2x4xf32> -> tensor<1x2xf32>
  // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<1xf32>
  %e = tt.reduce %b {redOp = 1 : i32, axis = 1 : i32} : tensor<1x4xf32> -> tensor<1xf32>
  // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<4xf32>
  %f = tt.reduce %a {redOp = 1 : i32, axis = 0 : i32} : tensor<2x4xf32> -> tensor<4xf32>
  // CHECK: %{{.*}} = tt.reduce %{{.*}} -> f32
  %g = tt.reduce %f {redOp = 1 : i32, axis = 0 : i32} : tensor<4xf32> -> f32

  // Avoid optimizations for c, e, and g
  %ptr1x2 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<1x2x!tt.ptr<f32>>
  %ptr1 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<1x!tt.ptr<f32>>
  tt.store %ptr1x2, %c : tensor<1x2xf32>
  tt.store %ptr1, %e : tensor<1xf32>
  tt.store %ptr, %g : f32
  return
}

func @dot_ops_infer(%ptr: !tt.ptr<f32>, %v : f32) {
  // Test if dot ops infer types correctly
  %v128x32 = tt.splat %v : (f32) -> tensor<128x32xf32>
  %v32x128 = tt.splat %v : (f32) -> tensor<32x128xf32>
  %v128x1 = tt.splat %v : (f32) -> tensor<128x1xf32>
  %v1x128 = tt.splat %v : (f32) -> tensor<1x128xf32>

  %zero128x128 = arith.constant dense<0.00e+00> : tensor<128x128xf32>
  %zero32x32 = arith.constant dense<0.00e+00> : tensor<32x32xf32>
  %zero1x1 = arith.constant dense<0.00e+00> : tensor<1x1xf32>

  // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<128x128xf32>
  %r1 = tt.dot %v128x32, %v32x128, %zero128x128 {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf32> * tensor<32x128xf32> -> tensor<128x128xf32>
  // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<32x32xf32>
  %r2 = tt.dot %v32x128, %v128x32, %zero32x32 {allowTF32 = true, transA = false, transB = false} : tensor<32x128xf32> * tensor<128x32xf32> -> tensor<32x32xf32>
  // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<128x128xf32>
  %r3 = tt.dot %v128x1, %v1x128, %zero128x128 {allowTF32 = true, transA = false, transB = false} : tensor<128x1xf32> * tensor<1x128xf32> -> tensor<128x128xf32>
  // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<1x1xf32>
  %r4 = tt.dot %v1x128, %v128x1, %zero1x1 {allowTF32 = true, transA = false, transB = false} : tensor<1x128xf32> * tensor<128x1xf32> -> tensor<1x1xf32>

  %ptr128x128 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x128x!tt.ptr<f32>>
  %ptr32x32 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<32x32x!tt.ptr<f32>>
  %ptr1x1 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<1x1x!tt.ptr<f32>>
  tt.store %ptr128x128, %r1 : tensor<128x128xf32>
  tt.store %ptr32x32, %r2 : tensor<32x32xf32>
  tt.store %ptr128x128, %r3 : tensor<128x128xf32>
  tt.store %ptr1x1, %r4 : tensor<1x1xf32>
  return
}
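
// Illustrative sketch, not checked by FileCheck: the 1D-tensor analogue of
// @load_store_ops_scalar, built only from op forms already used in this file
// (tt.splat, tt.load, tt.store). The function name, the 16-element shape, and
// the masked load/store combination are assumptions chosen for illustration.
func @load_store_ops_1d_sketch(%ptr: !tt.ptr<f32>, %scalar_mask : i1) {
  // splat the scalar pointer and mask out to 1D tensors
  %ptrs = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<16x!tt.ptr<f32>>
  %mask = tt.splat %scalar_mask : (i1) -> tensor<16xi1>
  %other = arith.constant dense<0.00e+00> : tensor<16xf32>
  // masked load with an explicit fallback value, mirroring the scalar %c case above
  %x = tt.load %ptrs, %mask, %other {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : tensor<16xf32>
  // masked store of the loaded tensor
  tt.store %ptrs, %x, %mask : tensor<16xf32>
  return
}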