[Triton-MLIR][Backend] Fix many problems to get the pipeline working (#809)

1. Rewrite code generation of insert_slice_async.
2. Correct the wrong index passed to extract_slice in pipeline.
3. Add a prologue in pipeline to wait for dangling cp.asyncs.  
4. Move scf to cf conversion inside TritonGPUToLLVM because we need to
perform membar before scf to cf. It shouldn't be a technical limitation
and could be improved by a more general membar analysis.
5. Use an attribute to memoize the shared memory size and support
dynamic shared memory.
6. Prevent the combine pass to reorder insert_slice and extract_slice
across async_wait

Co-authored-by: Superjomn <yanchunwei@outlook.com>
This commit is contained in:
Keren Zhou
2022-10-27 22:09:06 -07:00
committed by GitHub
parent 42db3538e4
commit 3b80801dff
10 changed files with 122 additions and 70 deletions

View File

@@ -326,7 +326,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#shared0 = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK: llvm.mlir.global internal @global_smem
// CHECK: llvm.mlir.global external @global_smem
// CHECK-LABEL: basic_alloc_tensor
func @basic_alloc_tensor() {
// CHECK: llvm.mlir.addressof @global_smem
@@ -343,7 +343,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#shared0 = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK: llvm.mlir.global internal @global_smem
// CHECK: llvm.mlir.global external @global_smem
// CHECK-LABEL: basic_extract_slice
func @basic_extract_slice() {
// CHECK: %[[BASE0:.*]] = llvm.mlir.addressof @global_smem
@@ -382,10 +382,10 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#slice2d1 = #triton_gpu.slice<{dim = 1, parent=#block2}>
#slice3d0 = #triton_gpu.slice<{dim = 0, parent=#block3}>
#AL = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#A = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 4, order = [1, 0]}>
#A = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_insert_slice_async_v4
func @basic_insert_slice_async_v4(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) {
func @basic_insert_slice_async_v4(%arg0: !tt.ptr<f32> {tt.divisibility = 8 : i32}) {
%off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1>
%off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<64xi32, #slice3d0>
%off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #slice2d1>) -> tensor<16x1xi32, #block2>
@@ -404,9 +404,9 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
%index = arith.constant 1 : i32
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
// CHECK-SAME: cp.async.cg.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x80, 0x80
// CHECK-SAME: cp.async.cg.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x10, 0x10
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
// CHECK-SAME: cp.async.cg.shared.global.L2::evict_normal [ ${{.*}} + 8 ], [ ${{.*}} + 0 ], 0x80, 0x80
// CHECK-SAME: cp.async.cg.shared.global [ ${{.*}} + 16 ], [ ${{.*}} + 0 ], 0x10, 0x10
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
// CHECK-SAME: cp.async.commit_group
%a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64x!tt.ptr<f32>, #AL> -> tensor<2x16x64xf32, #A>
@@ -445,13 +445,13 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
%index = arith.constant 1 : i32
// CHECK: llvm.inline_asm
// CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20
// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
// CHECK: llvm.inline_asm
// CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20
// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
// CHECK: llvm.inline_asm
// CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20
// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
// CHECK: llvm.inline_asm
// CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20
// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
// CHECK: llvm.inline_asm
// CHECK-SAME: cp.async.commit_group
%a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x32x!tt.ptr<f32>, #AL> -> tensor<2x16x32xf32, #A>
@@ -489,21 +489,21 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
%index = arith.constant 1 : i32
// CHECK: llvm.inline_asm
// CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20
// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
// CHECK: llvm.inline_asm
// CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20
// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
// CHECK: llvm.inline_asm
// CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20
// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
// CHECK: llvm.inline_asm
// CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20
// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
// CHECK: llvm.inline_asm
// CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 512 ], [ ${{.*}} + 0 ], 0x20, 0x20
// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 2048 ], [ ${{.*}} + 0 ], 0x4, 0x4
// CHECK: llvm.inline_asm
// CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 512 ], [ ${{.*}} + 0 ], 0x20, 0x20
// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 2048 ], [ ${{.*}} + 0 ], 0x4, 0x4
// CHECK: llvm.inline_asm
// CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 512 ], [ ${{.*}} + 0 ], 0x20, 0x20
// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 2048 ], [ ${{.*}} + 0 ], 0x4, 0x4
// CHECK: llvm.inline_asm
// CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 512 ], [ ${{.*}} + 0 ], 0x20, 0x20
// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 2048 ], [ ${{.*}} + 0 ], 0x4, 0x4
// CHECK: llvm.inline_asm
// CHECK-SAME: cp.async.commit_group
%a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32x!tt.ptr<f32>, #AL> -> tensor<2x32x32xf32, #A>
@@ -545,7 +545,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [0, 1]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.mlir.global internal @global_smem() {addr_space = 3 : i32} : !llvm.array<1088 x i8>
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
// CHECK-LABEL: convert_layout_blocked_blocked
func @convert_layout_blocked_blocked(%arg0: tensor<16x16xf32, #blocked0>) {
// CHECK: llvm.mlir.addressof @global_smem
@@ -593,7 +593,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 2], warpsPerCTA = [1, 1], order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.mlir.global internal @global_smem() {addr_space = 3 : i32} : !llvm.array<1280 x i8>
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
// CHECK-LABEL: convert_layout_blocked_blocked_vec
func @convert_layout_blocked_blocked_vec(%arg0: tensor<16x16xf32, #blocked0>) {
// CHECK: llvm.mlir.addressof @global_smem
@@ -617,7 +617,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.mlir.global internal @global_smem() {addr_space = 3 : i32} : !llvm.array<640 x i8>
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
// CHECK-LABEL: convert_layout_blocked_blocked_multi_rep
func @convert_layout_blocked_blocked_multi_rep(%arg0: tensor<16x16xf32, #blocked0>) {
// CHECK: llvm.mlir.addressof @global_smem
@@ -682,7 +682,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}>
#mma = #triton_gpu.mma<{version = 2, warpsPerCTA = [2, 2]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.mlir.global internal @global_smem() {addr_space = 3 : i32} : !llvm.array<2560 x i8>
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
// CHECK-LABEL: convert_layout_mma_block
func @convert_layout_mma_blocked(%arg0: tensor<32x16xf32, #mma>) {
// CHECK: nvvm.barrier0
@@ -703,7 +703,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}>
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.mlir.global internal @global_smem() {addr_space = 3 : i32} : !llvm.array<16384 x i8>
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
// CHECK-LABEL: convert_layout_blocked_shared
func @convert_layout_blocked_shared(%arg0: tensor<128x32xf32, #blocked0>) {
// CHECK: llvm.store