Merge `triton-mlir` branch - Complete rewrite of the backend from scratch (#1004)
This PR merges the `triton-mlir` branch, in which we have been quietly rewriting the Triton backend from scratch to increase maintainability, stability and ultimately performance. Changes to the runtime are minimal, and this new version aims to remain backward-compatible with the previous commit. The legacy backend is now officially deprecated, but can still be accessed via the `legacy-backend` tag.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Co-authored-by: Yan Chunwei <yanchunwei@outlook.com>
Co-authored-by: goostavz <109190422+goostavz@users.noreply.github.com>
Co-authored-by: Shintaro Iwasaki <siwasaki@fb.com>
Co-authored-by: Yan Da <dyanab@connect.ust.hk>
Co-authored-by: Jun Yang <yangjunpro@gmail.com>
Co-authored-by: Ian Bearman <ianb@microsoft.com>
Co-authored-by: Jason Ansel <jansel@jansel.net>
Co-authored-by: Qingyi Liu <qingyil@nvidia.com>
Co-authored-by: ben-zhang-609 <110140741+ben-zhang-609@users.noreply.github.com>
Co-authored-by: Chenggang Zhao <lyricz@yeah.net>
Co-authored-by: ben-zhang-609 <benzh609@gmail.com>
Co-authored-by: dongdongl <dongdongl@nvidia.com>
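Note on the backward-compatibility claim: kernels written against the public Python API are expected to run unchanged on the rewritten backend. A minimal sketch of such a user-level kernel (the `add_kernel` below is illustrative only, not part of this diff):

import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance processes one BLOCK_SIZE-wide chunk.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements  # guard the tail of the array
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

x = torch.randn(4096, device="cuda")
y = torch.randn(4096, device="cuda")
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 1024),)
add_kernel[grid](x, y, out, x.numel(), BLOCK_SIZE=1024)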
test/Conversion/triton_ops.mlir (new file, 132 lines)
@@ -0,0 +1,132 @@
// RUN: triton-opt %s | FileCheck %s

func @cast_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_f32: f32, %scalar_i64: i64) {
  // scalar -> scalar
  // CHECK: i64 -> !tt.ptr<f32>
  %0 = tt.int_to_ptr %scalar_i64 : i64 -> !tt.ptr<f32>
  // CHECK: !tt.ptr<f32> -> i64
  %1 = tt.ptr_to_int %scalar_ptr : !tt.ptr<f32> -> i64
  // CHECK: f32 to f16
  %2 = arith.truncf %scalar_f32 : f32 to f16

  // 0D tensor -> 0D tensor
  %tensor_ptr_0d = tt.splat %scalar_ptr : (!tt.ptr<f32>) -> tensor<!tt.ptr<f32>>
  %tensor_f32_0d = tt.splat %scalar_f32 : (f32) -> tensor<f32>
  %tensor_i64_0d = tt.splat %scalar_i64 : (i64) -> tensor<i64>

  // CHECK: tensor<i64> -> tensor<!tt.ptr<f32>>
  %3 = tt.int_to_ptr %tensor_i64_0d : tensor<i64> -> tensor<!tt.ptr<f32>>
  // CHECK: tensor<!tt.ptr<f32>> -> tensor<i64>
  %4 = tt.ptr_to_int %tensor_ptr_0d : tensor<!tt.ptr<f32>> -> tensor<i64>
  // CHECK: tensor<f32> to tensor<f16>
  %5 = arith.truncf %tensor_f32_0d : tensor<f32> to tensor<f16>

  // 1D tensor -> 1D tensor
  %tensor_ptr_1d = tt.splat %scalar_ptr : (!tt.ptr<f32>) -> tensor<16x!tt.ptr<f32>>
  %tensor_f32_1d = tt.splat %scalar_f32 : (f32) -> tensor<16xf32>
  %tensor_i64_1d = tt.splat %scalar_i64 : (i64) -> tensor<16xi64>

  // CHECK: tensor<16xi64> -> tensor<16x!tt.ptr<f32>>
  %6 = tt.int_to_ptr %tensor_i64_1d : tensor<16xi64> -> tensor<16x!tt.ptr<f32>>
  // CHECK: tensor<16x!tt.ptr<f32>> -> tensor<16xi64>
  %7 = tt.ptr_to_int %tensor_ptr_1d : tensor<16x!tt.ptr<f32>> -> tensor<16xi64>
  // CHECK: tensor<16xf32> to tensor<16xf16>
  %8 = arith.truncf %tensor_f32_1d : tensor<16xf32> to tensor<16xf16>
  return
}

func @addptr_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_i32: i32) {
  // scalar -> scalar
  // CHECK: !tt.ptr<f32>
  %0 = tt.addptr %scalar_ptr, %scalar_i32 : !tt.ptr<f32>, i32

  // 0D tensor -> 0D tensor
  %tensor_ptr_0d = tt.splat %scalar_ptr : (!tt.ptr<f32>) -> tensor<!tt.ptr<f32>>
  %tensor_i32_0d = tt.splat %scalar_i32 : (i32) -> tensor<i32>
  // CHECK: tensor<!tt.ptr<f32>>
  %1 = tt.addptr %tensor_ptr_0d, %tensor_i32_0d : tensor<!tt.ptr<f32>>, tensor<i32>

  // 1D tensor -> 1D tensor
  %tensor_ptr_1d = tt.splat %scalar_ptr : (!tt.ptr<f32>) -> tensor<16x!tt.ptr<f32>>
  %tensor_i32_1d = tt.splat %scalar_i32 : (i32) -> tensor<16xi32>
  // CHECK: tensor<16x!tt.ptr<f32>>
  %2 = tt.addptr %tensor_ptr_1d, %tensor_i32_1d : tensor<16x!tt.ptr<f32>>, tensor<16xi32>
  return
}

func @load_store_ops_scalar(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %mask : i1) {
  // Test if Load/Store ops can handle scalar values
  %other = arith.constant 0.0e+0 : f32

  // load scalar
  // CHECK: %[[L0:.*]] = tt.load %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : f32
  %a = tt.load %ptr {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : f32
  // CHECK: %[[L1:.*]] = tt.load %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : f32
  %b = tt.load %ptr, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : f32
  // CHECK: %[[L2:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : f32
  %c = tt.load %ptr, %mask, %other {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : f32

  // store scalar
  // CHECK: tt.store %{{.*}}, %[[L0]] : f32
  tt.store %ptr, %a : f32
  // CHECK: tt.store %{{.*}}, %[[L1]], %{{.*}} : f32
  tt.store %ptr, %b, %mask : f32
  // CHECK: tt.store %{{.*}}, %[[L2]], %{{.*}} : f32
  tt.store %ptr, %c, %mask : f32
  return
}

func @reduce_ops_infer(%ptr: !tt.ptr<f32>, %v : tensor<1x2x4xf32>) {
  // Test if reduce ops infer types correctly

  // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<2x4xf32>
  %a = tt.reduce %v {redOp = 1 : i32, axis = 0 : i32} : tensor<1x2x4xf32> -> tensor<2x4xf32>
  // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<1x4xf32>
  %b = tt.reduce %v {redOp = 1 : i32, axis = 1 : i32} : tensor<1x2x4xf32> -> tensor<1x4xf32>
  // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<1x2xf32>
  %c = tt.reduce %v {redOp = 1 : i32, axis = 2 : i32} : tensor<1x2x4xf32> -> tensor<1x2xf32>
  // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<1xf32>
  %e = tt.reduce %b {redOp = 1 : i32, axis = 1 : i32} : tensor<1x4xf32> -> tensor<1xf32>
  // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<4xf32>
  %f = tt.reduce %a {redOp = 1 : i32, axis = 0 : i32} : tensor<2x4xf32> -> tensor<4xf32>
  // CHECK: %{{.*}} = tt.reduce %{{.*}} -> f32
  %g = tt.reduce %f {redOp = 1 : i32, axis = 0 : i32} : tensor<4xf32> -> f32

  // Avoid optimizations for c, e, and g
  %ptr1x2 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<1x2x!tt.ptr<f32>>
  %ptr1 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<1x!tt.ptr<f32>>
  tt.store %ptr1x2, %c : tensor<1x2xf32>
  tt.store %ptr1, %e : tensor<1xf32>
  tt.store %ptr, %g : f32
  return
}

func @dot_ops_infer(%ptr: !tt.ptr<f32>, %v : f32) {
  // Test if dot ops infer types correctly
  %v128x32 = tt.splat %v : (f32) -> tensor<128x32xf32>
  %v32x128 = tt.splat %v : (f32) -> tensor<32x128xf32>
  %v128x1 = tt.splat %v : (f32) -> tensor<128x1xf32>
  %v1x128 = tt.splat %v : (f32) -> tensor<1x128xf32>

  %zero128x128 = arith.constant dense<0.00e+00> : tensor<128x128xf32>
  %zero32x32 = arith.constant dense<0.00e+00> : tensor<32x32xf32>
  %zero1x1 = arith.constant dense<0.00e+00> : tensor<1x1xf32>

  // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<128x128xf32>
  %r1 = tt.dot %v128x32, %v32x128, %zero128x128 {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf32> * tensor<32x128xf32> -> tensor<128x128xf32>
  // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<32x32xf32>
  %r2 = tt.dot %v32x128, %v128x32, %zero32x32 {allowTF32 = true, transA = false, transB = false} : tensor<32x128xf32> * tensor<128x32xf32> -> tensor<32x32xf32>
  // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<128x128xf32>
  %r3 = tt.dot %v128x1, %v1x128, %zero128x128 {allowTF32 = true, transA = false, transB = false} : tensor<128x1xf32> * tensor<1x128xf32> -> tensor<128x128xf32>
  // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<1x1xf32>
  %r4 = tt.dot %v1x128, %v128x1, %zero1x1 {allowTF32 = true, transA = false, transB = false} : tensor<1x128xf32> * tensor<128x1xf32> -> tensor<1x1xf32>

  %ptr128x128 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x128x!tt.ptr<f32>>
  %ptr32x32 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<32x32x!tt.ptr<f32>>
  %ptr1x1 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<1x1x!tt.ptr<f32>>
  tt.store %ptr128x128, %r1 : tensor<128x128xf32>
  tt.store %ptr32x32, %r2 : tensor<32x32xf32>
  tt.store %ptr128x128, %r3 : tensor<128x128xf32>
  tt.store %ptr1x1, %r4 : tensor<1x1xf32>
  return
}
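For readers less familiar with the `tt` dialect: the load/store, dot, and reduce forms checked above are what the Python frontend emits. A hedged user-level sketch of the same shape rules (kernel name, sizes, and launch are illustrative only, not part of this diff):

import torch
import triton
import triton.language as tl

@triton.jit
def shapes_kernel(ptr, n, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)     # a 1D index block, like tensor<16xi32>
    ptrs = ptr + offs              # splat + addptr: a block of pointers
    mask = offs < n
    a = tl.load(ptrs)                        # plain load
    b = tl.load(ptrs, mask=mask)             # masked load
    c = tl.load(ptrs, mask=mask, other=0.0)  # masked load with fallback value
    x = tl.zeros([128, 32], dtype=tl.float16)
    y = tl.zeros([32, 128], dtype=tl.float16)
    z = tl.dot(x, y)               # (128,32) x (32,128) -> (128,128), f32 accumulator
    r = tl.sum(z, axis=0)          # rank-reducing reduce: (128,128) -> (128,)
    tl.store(ptrs, a + b + c, mask=mask)

buf = torch.zeros(16, device="cuda", dtype=torch.float32)
shapes_kernel[(1,)](buf, buf.numel(), BLOCK=16)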
test/Conversion/triton_to_tritongpu.mlir (new file, 53 lines)
@@ -0,0 +1,53 @@
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu=num-warps=2 | FileCheck %s

func @ops() {
  // CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32} {{.*}}
  %a = arith.constant dense<1.00e+00> : tensor<128x32xf16>
  %b = arith.constant dense<2.00e+00> : tensor<32x128xf16>
  %c = arith.constant dense<3.00e+00> : tensor<128x128xf32>
  %0 = tt.dot %a, %b, %c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16> * tensor<32x128xf16> -> tensor<128x128xf32>
  return
}

// -----

func @load_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
  // Test if LoadOp is lowered properly (see #771)
  %ptrs = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>>
  %mask = arith.constant dense<true> : tensor<128xi1>
  %other = arith.constant dense<0.0e+0> : tensor<128xf32>
  // CHECK: %{{.*}} = tt.load %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : {{.*}}
  %a = tt.load %ptrs {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : tensor<128xf32>
  // CHECK: %{{.*}} = tt.load %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : {{.*}}
  %b = tt.load %ptrs, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : tensor<128xf32>
  // CHECK: %{{.*}} = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : {{.*}}
  %c = tt.load %ptrs, %mask, %other {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : tensor<128xf32>
  tt.store %ptrs, %a : tensor<128xf32>
  tt.store %ptrs, %b : tensor<128xf32>
  tt.store %ptrs, %c : tensor<128xf32>
  return
}

// -----

func @reduce_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
  // Test if the total number of threadsPerWarp is 32
  // Test if the total number of warps is 2
  // CHECK: #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 2], order = [0, 1]}>
  // CHECK: #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 2], order = [0, 1]}>
  // CHECK: #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 2], order = [0, 1]}>
  // CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32} {{.*}}
  %c0 = arith.constant dense<1.00e+00> : tensor<4x4xf32>
  %c1 = arith.constant dense<2.00e+00> : tensor<8x2xf32>
  %c2 = arith.constant dense<3.00e+00> : tensor<16x16xf32>
  // CHECK: tensor<4x4xf32, #blocked0> -> tensor<4xf32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
  %c0_ = tt.reduce %c0 {redOp = 1 : i32, axis = 0 : i32} : tensor<4x4xf32> -> tensor<4xf32>
  // CHECK: tensor<8x2xf32, #blocked1> -> tensor<2xf32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
  %c1_ = tt.reduce %c1 {redOp = 1 : i32, axis = 0 : i32} : tensor<8x2xf32> -> tensor<2xf32>
  // CHECK: tensor<8x2xf32, #blocked1> -> tensor<8xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
  %c2_ = tt.reduce %c1 {redOp = 1 : i32, axis = 1 : i32} : tensor<8x2xf32> -> tensor<8xf32>
  // CHECK: tensor<16x16xf32, #blocked2> -> tensor<16xf32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
  %c3_ = tt.reduce %c2 {redOp = 1 : i32, axis = 0 : i32} : tensor<16x16xf32> -> tensor<16xf32>

  return
}
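A quick sanity check on the layouts asserted above: with sizePerThread = [1, 1], each threadsPerWarp pair must multiply to 32 (a full warp) and each warpsPerCTA pair to the requested num-warps=2. The trivial script below (illustrative only, not part of the test suite) spells out that arithmetic:

# Plain arithmetic check of the blocked layouts in the CHECK lines above.
layouts = {"#blocked0": (4, 8), "#blocked1": (8, 4), "#blocked2": (16, 2)}
for name, (t0, t1) in layouts.items():
    assert t0 * t1 == 32, f"{name}: threadsPerWarp must cover a full warp"
assert 1 * 2 == 2  # warpsPerCTA product equals num-warps=2
print("all blocked layouts consistent")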
test/Conversion/tritongpu_to_llvm.mlir (new file, 1016 lines)
(Diff suppressed because it is too large.)