// RUN: triton-opt %s -split-input-file -canonicalize -triton-combine
// RUN: triton-opt %s -split-input-file -canonicalize -triton-combine | FileCheck %s

// Expects `addf(dot(a, b, 0), d)` and `addf(d, dot(a, b, 0))` to both be
// rewritten into `dot(a, b, d)`: the zero accumulator is replaced by the addend.
// CHECK-LABEL: @test_combine_dot_add_pattern
func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32>) {
  // CHECK: %[[d:.*]] = arith.constant dense<3.000000e+00> : tensor<128x128xf32>
  // CHECK: %[[b:.*]] = arith.constant dense<2.000000e+00> : tensor<128x128xf32>
  // CHECK: %[[a:.*]] = arith.constant dense<1.000000e+00> : tensor<128x128xf32>
  %a = arith.constant dense<1.0> : tensor<128x128xf32>
  %b = arith.constant dense<2.0> : tensor<128x128xf32>
  %zero = arith.constant dense<0.0> : tensor<128x128xf32>
  %d = arith.constant dense<3.0> : tensor<128x128xf32>

  %dot_out = tt.dot %a, %b, %zero {allowTF32 = true} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>

  // CHECK-NEXT: %[[res0:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
  %res0 = arith.addf %dot_out, %d : tensor<128x128xf32>

  // Same fold with the addf operands swapped.
  // CHECK-NEXT: %[[res1:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
  %res1 = arith.addf %d, %dot_out : tensor<128x128xf32>

  return %res0, %res1 : tensor<128x128xf32>, tensor<128x128xf32>
}


// Expects two chained tt.getelementptr ops with constant offsets to collapse
// into a single tt.getelementptr whose offset is the folded sum.
// CHECK-LABEL: @test_combine_gep_pattern
func @test_combine_gep_pattern(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>> {
  %off0 = arith.constant 10 : i32
  %off1 = arith.constant 15 : i32

  // 10 + 15 = 25
  // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<25> : tensor<8xi32>

  %base_ = tt.broadcast %base : (!tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>>
  // CHECK-NEXT: %[[tmp0:.*]] = tt.broadcast %{{.*}} : (!tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>>

  %idx0 = tt.broadcast %off0 : (i32) -> tensor<8xi32>
  %idx1 = tt.broadcast %off1 : (i32) -> tensor<8xi32>

  // Capture the result instead of hard-coding its SSA number so the test
  // does not depend on how the printer numbers values.
  // CHECK-NEXT: %[[res:.*]] = tt.getelementptr %[[tmp0]], %[[cst]] : tensor<8x!tt.ptr<f32>>
  %ptr0 = tt.getelementptr %base_, %idx0 : tensor<8x!tt.ptr<f32>>
  %ptr1 = tt.getelementptr %ptr0, %idx1 : tensor<8x!tt.ptr<f32>>

  // CHECK-NEXT: return %[[res]] : tensor<8x!tt.ptr<f32>>
  return %ptr1 : tensor<8x!tt.ptr<f32>>
}


// Expects `select %cond, load(%ptr, broadcast(%cond), %false_val), %false_val`
// to be replaced by the masked load itself.
// CHECK-LABEL: @test_combine_select_masked_load_pattern
func @test_combine_select_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %cond: i1) -> tensor<8xf32> {
  %mask = tt.broadcast %cond : (i1) -> tensor<8xi1>
  %false_val = arith.constant dense<0.0> : tensor<8xf32>

  // CHECK: %[[res:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
  %x = tt.load %ptr, %mask, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
  %0 = select %cond, %x, %false_val : tensor<8xf32>

  // CHECK: return %[[res]] : tensor<8xf32>
  return %0 : tensor<8xf32>
}


// Expects tt.broadcast of a splat constant to fold into a splat constant of
// the broadcast result type.
// CHECK-LABEL: @test_combine_broadcast_constant_pattern
func @test_combine_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> {
  // CHECK: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor<8x2xf32>
  %const = arith.constant dense<1.0> : tensor<8xf32>
  %bst_out = tt.broadcast %const : (tensor<8xf32>) -> tensor<8x2xf32>

  // CHECK-NEXT: return %[[cst]] : tensor<8x2xf32>
  return %bst_out : tensor<8x2xf32>
}