[BACKEND] Support of ConvertLayoutOp from blocked to blocked and SliceLayout with blocked parent (#658)

This commit is contained in:
goostavz
2022-09-18 05:58:42 +08:00
committed by GitHub
parent 13669b46a6
commit 15bfd0cb79
17 changed files with 1025 additions and 191 deletions

View File

@@ -1,37 +0,0 @@
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu=num-warps=2 -convert-triton-gpu-to-llvm | FileCheck %s
func @test_splat(%ptr: !tt.ptr<f32>) {
// Here, 128 elements, 64(2*32) threads, so each need to process 2 elements
//
// CHECK: %0 = llvm.bitcast %arg0 : !llvm.ptr<f32, 1> to !llvm.ptr<f32, 1>
// CHECK: %1 = llvm.mlir.undef : !llvm.struct<(ptr<f32, 1>, ptr<f32, 1>)>
// CHECK: %2 = llvm.insertvalue %0, %1[0] : !llvm.struct<(ptr<f32, 1>, ptr<f32, 1>)>
// CHECK: %3 = llvm.insertvalue %0, %2[1] : !llvm.struct<(ptr<f32, 1>, ptr<f32, 1>)>
%ptrs = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>>
%a = arith.constant 1.0 : f32
%true = arith.constant 1 : i1
%b = tt.splat %a : (f32) -> tensor<128xf32>
// Here, each thread process only 1 element
// CHECK: %{{.*}} = llvm.mlir.undef : !llvm.struct<(i1)>
%mask = tt.splat %true : (i1) -> tensor<64xi1>
return
}
// -----
func @test_store_splat(%ptr: !tt.ptr<f32>) {
%ptrs = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>>
%a = arith.constant 1.0 : f32
%true = arith.constant 1 : i1
%vs = tt.splat %a : (f32) -> tensor<128xf32>
%mask = tt.splat %true : (i1) -> tensor<128xi1>
// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };",
// CHECK-SAME: "r,l,b" %{{.*}}, %{{.*}}, %{{.*}} : (i32, !llvm.ptr<f32, 1>, i1) -> !llvm.void
tt.store %ptrs, %vs, %mask : tensor<128xf32>
return
}

View File

@@ -1,16 +1,13 @@
// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm | FileCheck %s
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK: llvm.func @test_empty_kernel(%arg0: i32, %arg1: !llvm.ptr<f16, 1>)
// Here the 128 comes from the 4 in module attribute multiples 32
// CHECK: attributes {nvvm.kernel = 1 : ui1, nvvm.maxntid = 128 : si32} {{.*}}
func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
// CHECK: llvm.return
return
}
// CHECK: llvm.func @test_empty_kernel(%arg0: i32, %arg1: !llvm.ptr<f16, 1>)
// Here the 128 comes from the 4 in module attribute multiples 32
// CHECK: attributes {nvvm.kernel = 1 : ui1, nvvm.maxntid = 128 : si32} {{.*}}
func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
// CHECK: llvm.return
return
}
} // end module
// -----
@@ -58,7 +55,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
// -----
// TODO: Pending on the support of isSplat constant
// TODO: masked load with vectorization is pending on TODO
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: masked_load_const_other
@@ -71,10 +68,23 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// -----
// TODO: masked load with vectorization is pending on TODO
#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: masked_load_const_other_vec
func @masked_load_const_other_vec(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>) {
%cst_0 = arith.constant dense<0.000000e+00> : tensor<256xf32, #blocked0>
%1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
return
}
}
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>
module attributes {"triton_gpu.num-warps" = 2 : i32} {
// CHECK-LABEL: kernel__Pfp32_Pfp32_Pfp32_i32__3c256
func @kernel__Pfp32_Pfp32_Pfp32_i32__3c256(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg3: i32) {
// CHECK-LABEL: global_load_store_no_vec
func @global_load_store_no_vec(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg3: i32) {
%c256_i32 = arith.constant 256 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.muli %0, %c256_i32 : i32
@@ -86,22 +96,107 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
%8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
// CHECK: ld.global.v4.b32
// Load 4 elements from vector0
// CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
// CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
// CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
// CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
// Load 4 elements from vector1
// CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
// CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
// CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
// CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
%9 = tt.load %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
// CHECK: ld.global.v4.b32
%10 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
%11 = arith.addf %9, %10 : tensor<256xf32, #blocked0>
%12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
%13 = tt.addptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
// Store 4 elements to global
// CHECK: st.global.b32.v4
// CHECK: @${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
// CHECK: @${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
// CHECK: @${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
// CHECK: @${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
tt.store %13, %11 : tensor<256xf32, #blocked0>
return
}
}
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>
module attributes {"triton_gpu.num-warps" = 2 : i32} {
// CHECK-LABEL: global_load_store_vec4
func @global_load_store_vec4(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg3: i32) {
%c256_i32 = arith.constant 256 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.muli %0, %c256_i32 : i32
%2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
%3 = tt.splat %1 : (i32) -> tensor<256xi32, #blocked0>
%4 = arith.addi %3, %2 : tensor<256xi32, #blocked0>
%5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
%6 = tt.addptr %5, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
%8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
// Load 4 elements from A with single one vectorized load instruction
// CHECK: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
// Load 4 elements from B with single one vectorized load instruction
// CHECK: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
%9 = tt.load %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
%10 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
%11 = arith.addf %9, %10 : tensor<256xf32, #blocked0>
%12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
%13 = tt.addptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
// Store 4 elements to global with single one vectorized store instruction
// CHECK: @$5 st.global.b32.v4 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} };
tt.store %13, %11 : tensor<256xf32, #blocked0>
return
}
}
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: global_load_store_vec8
func @global_load_store_vec8(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg3: i32) {
%c256_i32 = arith.constant 256 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.muli %0, %c256_i32 : i32
%2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
%3 = tt.splat %1 : (i32) -> tensor<256xi32, #blocked0>
%4 = arith.addi %3, %2 : tensor<256xi32, #blocked0>
%5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
%6 = tt.addptr %5, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
%8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
// Load 8 elements from A with two vectorized load instruction
// CHECK: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
// CHECK: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
// Load 8 elements from B with two vectorized load instruction
// CHECK: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
// CHECK: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
%9 = tt.load %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
%10 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
%11 = arith.addf %9, %10 : tensor<256xf32, #blocked0>
%12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
%13 = tt.addptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
// Store 8 elements to global with two vectorized store instruction
// CHECK: @$5 st.global.b32.v4 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} };
// CHECK: @$5 st.global.b32.v4 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} };
tt.store %13, %11 : tensor<256xf32, #blocked0>
return
}
}
// TODO: Add a testcase to verify the optimization when ptr of the LoadOp
// is from an addptr with const idx
@@ -217,10 +312,121 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_store
func @basic_store(%ptrs: tensor<256x!tt.ptr<f32>, #blocked0>, %vals: tensor<256xf32, #blocked0>, %mask: tensor<256xi1, #blocked0>) {
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
// CHECK-SAME: @${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };", "r,l,b" %{{.*}}, %{{.*}}, %{{.*}} : (i32, !llvm.ptr<f32, 1>, i1) -> !llvm.void
// CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
// CHECK-SAME: @${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };", "r,l,b" %{{.*}}, %{{.*}}, %{{.*}} : (i32, !llvm.ptr<f32, 1>, i1) -> !llvm.void
// CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
tt.store %ptrs, %vals, %mask : tensor<256xf32, #blocked0>
return
}
}
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [0, 1]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.mlir.global internal @global_smem() {addr_space = 3 : i32} : !llvm.array<1088 x i8>
// CHECK-LABEL: convert_layout_blocked_blocked
func @convert_layout_blocked_blocked(%arg0: tensor<16x16xf32, #blocked0>) {
// CHECK: llvm.mlir.addressof @global_smem
// CHECK: nvvm.barrier0
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
// CHECK: nvvm.barrier0
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
%0 = triton_gpu.convert_layout %arg0 : (tensor<16x16xf32, #blocked0>) -> tensor<16x16xf32, #blocked1>
return
}
}
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 2], warpsPerCTA = [1, 1], order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.mlir.global internal @global_smem() {addr_space = 3 : i32} : !llvm.array<1280 x i8>
// CHECK-LABEL: convert_layout_blocked_blocked_vec
func @convert_layout_blocked_blocked_vec(%arg0: tensor<16x16xf32, #blocked0>) {
// CHECK: llvm.mlir.addressof @global_smem
// CHECK: nvvm.barrier0
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
// CHECK: nvvm.barrier0
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
%0 = triton_gpu.convert_layout %arg0 : (tensor<16x16xf32, #blocked0>) -> tensor<16x16xf32, #blocked1>
return
}
}
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.mlir.global internal @global_smem() {addr_space = 3 : i32} : !llvm.array<640 x i8>
// CHECK-LABEL: convert_layout_blocked_blocked_multi_rep
func @convert_layout_blocked_blocked_multi_rep(%arg0: tensor<16x16xf32, #blocked0>) {
// CHECK: llvm.mlir.addressof @global_smem
// CHECK: nvvm.barrier0
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
// CHECK: nvvm.barrier0
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
// CHECK: nvvm.barrier0
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
// CHECK: nvvm.barrier0
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
%0 = triton_gpu.convert_layout %arg0 : (tensor<16x16xf32, #blocked0>) -> tensor<16x16xf32, #blocked1>
return
}
}
// TODO: problems in MLIR's parser on slice layout
// #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
// module attributes {"triton_gpu.num-warps" = 1 : i32} {
// func @make_range_sliced_layout() {
// %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
// return
// }
// }