This PR merges the `triton-mlir` branch, in which we have been quietly rewriting the Triton backend from scratch to increase maintainability, stability and ultimately performance. Changes to the runtime are minimal, and this new version aims to remain backward-compatible with the previous commit. The legacy backend is now officially deprecated, but can still be accessed via the `legacy-backend` tag. Co-authored-by: Keren Zhou <kerenzhou@openai.com> Co-authored-by: Yan Chunwei <yanchunwei@outlook.com> Co-authored-by: goostavz <109190422+goostavz@users.noreply.github.com> Co-authored-by: Shintaro Iwasaki <siwasaki@fb.com> Co-authored-by: Yan Da <dyanab@connect.ust.hk> Co-authored-by: Jun Yang <yangjunpro@gmail.com> Co-authored-by: Ian Bearman <ianb@microsoft.com> Co-authored-by: Jason Ansel <jansel@jansel.net> Co-authored-by: Qingyi Liu <qingyil@nvidia.com> Co-authored-by: ben-zhang-609 <110140741+ben-zhang-609@users.noreply.github.com> Co-authored-by: Chenggang Zhao <lyricz@yeah.net> Co-authored-by: ben-zhang-609 <benzh609@gmail.com> Co-authored-by: dongdongl <dongdongl@nvidia.com>
147 lines
7.6 KiB
MLIR
147 lines
7.6 KiB
MLIR
// RUN: triton-opt %s -split-input-file -canonicalize -triton-combine
|
|
// RUN: triton-opt %s -split-input-file -canonicalize -triton-combine | FileCheck %s
|
|
|
|
// Test: an arith.addf applied to the result of a tt.dot whose third operand is
// an all-zero constant is expected to be rewritten into a single tt.dot that
// takes the addend as its third operand. Both operand orders of the addition
// (dot + d and d + dot) are exercised.
// CHECK-LABEL: @test_combine_dot_add_pattern
func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32>) {
    // Only three constants are expected to survive (the zero tensor is gone).
    // CHECK: %[[d:.*]] = arith.constant dense<3.000000e+00> : tensor<128x128xf32>
    // CHECK: %[[b:.*]] = arith.constant dense<2.000000e+00> : tensor<128x128xf32>
    // CHECK: %[[a:.*]] = arith.constant dense<1.000000e+00> : tensor<128x128xf32>
    %a = arith.constant dense<1.0> : tensor<128x128xf32>
    %b = arith.constant dense<2.0> : tensor<128x128xf32>
    %zero = arith.constant dense<0.0> : tensor<128x128xf32>
    %d = arith.constant dense<3.0> : tensor<128x128xf32>

    %dot_out = tt.dot %a, %b, %zero {allowTF32 = true, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>

    // NOTE(review): the expected attribute dictionary below lists only
    // allowTF32, while the input op also carries transA/transB. Confirm this
    // matches what the printer emits for the rewritten op.

    // Addend on the right-hand side of the addition.
    // CHECK-NEXT: %[[res0:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
    %res0 = arith.addf %dot_out, %d : tensor<128x128xf32>

    // Addend on the left-hand side of the addition.
    // CHECK-NEXT: %[[res1:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
    %res1 = arith.addf %d, %dot_out : tensor<128x128xf32>

    return %res0, %res1 : tensor<128x128xf32>, tensor<128x128xf32>
}
|
|
|
|
|
|
// Test input: two chained tt.addptr ops with constant offsets (10 and 15)
// applied to a broadcast base pointer. Every expectation in this function is
// prefixed with the FileCheck comment marker, so none of them is currently
// enforced; the function is still parsed and run through the passes.
// COM: CHECK-LABEL: @test_combine_addptr_pattern
func @test_combine_addptr_pattern(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>> {
    %off0 = arith.constant 10 : i32
    %off1 = arith.constant 15 : i32

    // 10 + 15 = 25
    // COM: CHECK-NEXT: %[[cst:.*]] = arith.constant dense<25> : tensor<8xi32>

    %base_ = tt.broadcast %base : (!tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>>

    // COM: CHECK-NEXT: %[[tmp0:.*]] = tt.broadcast %{{.*}} : (!tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>>

    %idx0 = tt.broadcast %off0 : (i32) -> tensor<8xi32>
    %idx1 = tt.broadcast %off1 : (i32) -> tensor<8xi32>

    // The disabled expectation describes a single tt.addptr using the summed
    // offset in place of the two chained ones below.
    // COM: CHECK-NEXT: %1 = tt.addptr %[[tmp0]], %[[cst]] : tensor<8x!tt.ptr<f32>>, tensor<8xi32>
    %ptr0 = tt.addptr %base_, %idx0 : tensor<8x!tt.ptr<f32>>, tensor<8xi32>
    %ptr1 = tt.addptr %ptr0, %idx1 : tensor<8x!tt.ptr<f32>>, tensor<8xi32>

    return %ptr1 : tensor<8x!tt.ptr<f32>>
}
|
|
|
|
|
|
// Test: a select whose condition is the same scalar that was broadcast into
// the mask of a tt.load, and whose false value equals the load's "other"
// operand, is expected to be replaced by the load itself. The expected return
// uses the two load results directly, i.e. no select remains on that path.
// CHECK-LABEL: @test_combine_select_masked_load_pattern
func @test_combine_select_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %cond: i1) -> (tensor<8xf32>, tensor<8xf32>) {
    %mask = tt.broadcast %cond : (i1) -> tensor<8xi1>
    %false_val = arith.constant dense<0.0> : tensor<8xf32>

    // First load/select pair.
    // CHECK: %[[res1:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
    %x = tt.load %ptr, %mask, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
    %0 = select %cond, %x, %false_val : tensor<8xf32>

    // Second, identical load/select pair.
    // CHECK: %[[res2:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
    %y = tt.load %ptr, %mask, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
    %1 = select %cond, %y, %false_val : tensor<8xf32>

    // CHECK: return %[[res1]], %[[res2]] : tensor<8xf32>, tensor<8xf32>
    return %0, %1 : tensor<8xf32>, tensor<8xf32>
}
|
|
|
|
// Negative test for the select/masked-load rewrite: when the values feeding
// the select are function arguments rather than results of the expected ops,
// the select must be left in place.
// CHECK-LABEL: @test_combine_select_masked_load_fail_pattern
func @test_combine_select_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %dummy_load: tensor<8xf32>, %dummy_broadcast: tensor<8xi1>, %cond: i1) -> (tensor<8xf32>, tensor<8xf32>) {
    %false_val = arith.constant dense<0.0> : tensor<8xf32>

    // Case 1: value at the "load" position is not an "op" (it is the block
    // argument %dummy_load). Select should not be canonicalized.
    // CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
    %0 = select %cond, %dummy_load, %false_val : tensor<8xf32>

    // Case 2: value at the "broadcast" position is not an "op" (the load's
    // mask is the block argument %dummy_broadcast). Select should not be
    // canonicalized.
    %real_load = tt.load %ptr, %dummy_broadcast, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
    // CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
    %1 = select %cond, %real_load, %false_val : tensor<8xf32>

    return %0, %1 : tensor<8xf32>, tensor<8xf32>
}
|
|
|
|
// Test: a tt.broadcast of a dense splat constant is expected to be replaced by
// a single constant that already has the broadcast result type.
// The %cst argument is not referenced in the body.
// CHECK-LABEL: @test_combine_broadcast_constant_pattern
func @test_combine_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> {
    // CHECK: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor<8x2xf32>
    %const = arith.constant dense<1.0> : tensor<8xf32>
    %bst_out = tt.broadcast %const : (tensor<8xf32>) -> tensor<8x2xf32>

    // The constant is returned directly; no broadcast op is expected between
    // the constant and the return.
    // CHECK-NEXT: return %[[cst]] : tensor<8x2xf32>
    return %bst_out : tensor<8x2xf32>
}
|
|
|
|
// Test: tt.load ops whose mask is a constant are expected to be simplified.
//   * all-true mask  -> the mask (and "other", if present) is dropped;
//   * all-false mask -> the load is replaced by its "other" operand.
// CHECK-LABEL: @test_canonicalize_masked_load_pattern
func @test_canonicalize_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) {
    %true_mask = arith.constant dense<true> : tensor<8xi1>
    %false_mask = arith.constant dense<false> : tensor<8xi1>
    %other_val = arith.constant dense<0.0> : tensor<8xf32>

    // true_mask without other
    // CHECK: %[[res1:.*]] = tt.load %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
    %x = tt.load %ptr, %true_mask {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>

    // true_mask with other
    // CHECK: %[[res2:.*]] = tt.load %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
    %y = tt.load %ptr, %true_mask, %other_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>

    // false_mask with other. It should become "other" (i.e., %y)
    %z = tt.load %ptr, %false_mask, %y {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>

    // The third returned value is expected to be the second load's result,
    // since %z collapses to its "other" operand.
    // CHECK: return %[[res1]], %[[res2]], %[[res2]] : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>
    return %x, %y, %z: tensor<8xf32>, tensor<8xf32>, tensor<8xf32>
}
|
|
|
|
// Negative test for the masked-load canonicalization: when the mask is a
// function argument rather than a constant, both loads must keep their
// operands unchanged.
// CHECK-LABEL: @test_canonicalize_masked_load_fail_pattern
func @test_canonicalize_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %mask: tensor<8xi1>) -> (tensor<8xf32>, tensor<8xf32>) {
    %other_val = arith.constant dense<0.0> : tensor<8xf32>

    // Case: value at the "mask" position is not an "op". Load should not be canonicalized.
    // Load with a mask only: two operands are expected to remain.
    // CHECK: %[[res1:.*]] = tt.load %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
    %x = tt.load %ptr, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
    // Load with a mask and "other": three operands are expected to remain.
    // CHECK: %[[res2:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
    %y = tt.load %ptr, %mask, %other_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>

    return %x, %y: tensor<8xf32>, tensor<8xf32>
}
|
|
|
|
// Test: tt.store ops whose mask is a constant are expected to be simplified.
//   * all-true mask  -> the mask operand is dropped;
//   * all-false mask -> the store is removed entirely.
// CHECK-LABEL: @test_canonicalize_masked_store_pattern
func @test_canonicalize_masked_store_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val: tensor<8xf32>) {
    %true_mask = arith.constant dense<true> : tensor<8xi1>
    %false_mask = arith.constant dense<false> : tensor<8xi1>

    // Store under an all-true mask: expected to print with two operands only.
    // CHECK: tt.store %{{.*}}, %{{.*}} : tensor<8xf32>
    tt.store %ptr, %val, %true_mask : tensor<8xf32>

    // The following store should disappear, so the return is expected to
    // follow the first store immediately.
    // CHECK-NEXT: return
    tt.store %ptr, %val, %false_mask : tensor<8xf32>
    return
}
|
|
|
|
// Negative test for the masked-store canonicalization: when the mask is a
// function argument rather than a constant, the store must keep all three
// operands.
// CHECK-LABEL: @test_canonicalize_masked_store_fail_pattern
func @test_canonicalize_masked_store_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val: tensor<8xf32>, %mask: tensor<8xi1>) {
    // Case: value at the "mask" position is not an "op". Store should not be canonicalized.
    // CHECK: tt.store %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
    tt.store %ptr, %val, %mask : tensor<8xf32>
    return
}
|