From e843257295961634cc4adb7fe25fc820eb083d66 Mon Sep 17 00:00:00 2001
From: goostavz <109190422+goostavz@users.noreply.github.com>
Date: Wed, 5 Oct 2022 12:29:59 +0800
Subject: [PATCH] [Backend] Fix a bug in emitIndicesForBlocked (#740)

---
 .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 11 ++--
 test/Conversion/tritongpu_to_llvm.mlir  | 50 +++++++++++++++++++
 2 files changed, 55 insertions(+), 6 deletions(-)

diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index d70763bbe..3501a4c09 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -548,6 +548,9 @@ public:
     auto warpsPerCTA = blockedLayout.getWarpsPerCTA();
     unsigned rank = shape.size();
     SmallVector<unsigned> shapePerCTA = getShapePerCTA(blockedLayout);
+    SmallVector<unsigned> tilesPerDim(rank);
+    for (unsigned k = 0; k < rank; ++k)
+      tilesPerDim[k] = ceil<unsigned>(shape[k], shapePerCTA[k]);
 
     // step 1, delinearize threadId to get the base index
     auto multiDimBase =
@@ -558,8 +561,7 @@
     SmallVector<SmallVector<unsigned>> offset(rank);
     for (unsigned k = 0; k < rank; ++k) {
       // 1 block in minimum if shape[k] is less than shapePerCTA[k]
-      for (unsigned blockOffset = 0;
-           blockOffset < ceil<unsigned>(shape[k], shapePerCTA[k]);
+      for (unsigned blockOffset = 0; blockOffset < tilesPerDim[k];
            ++blockOffset)
         for (unsigned warpOffset = 0; warpOffset < warpsPerCTA[k]; ++warpOffset)
           for (unsigned threadOffset = 0; threadOffset < threadsPerWarp[k];
@@ -577,15 +579,12 @@
     SmallVector<SmallVector<Value>> multiDimIdx(elemsPerThread,
                                                 SmallVector<Value>(rank));
     unsigned totalSizePerThread = product<unsigned>(sizePerThread);
-    SmallVector<unsigned> threadsPerDim(rank);
-    for (unsigned k = 0; k < rank; ++k)
-      threadsPerDim[k] = ceil<unsigned>(shape[k], sizePerThread[k]);
 
     for (unsigned n = 0; n < elemsPerThread; ++n) {
       unsigned linearNanoTileId = n / totalSizePerThread;
       unsigned linearNanoTileElemId = n % totalSizePerThread;
       SmallVector<unsigned> multiDimNanoTileId =
-          getMultiDimIndex<unsigned>(linearNanoTileId, threadsPerDim);
+          getMultiDimIndex<unsigned>(linearNanoTileId, tilesPerDim);
       SmallVector<unsigned> multiDimNanoTileElemId =
           getMultiDimIndex<unsigned>(linearNanoTileElemId, sizePerThread);
       for (unsigned k = 0; k < rank; ++k) {
diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir
index f22371ee8..2c8583fed 100644
--- a/test/Conversion/tritongpu_to_llvm.mlir
+++ b/test/Conversion/tritongpu_to_llvm.mlir
@@ -426,6 +426,56 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 
 // -----
 
+#block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}>
+#block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
+#block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 8], warpsPerCTA = [1, 4], order = [1, 0]}>
+#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#A = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0]}>
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+  // CHECK-LABEL: basic_insert_slice_async_v1_multictas
+  func @basic_insert_slice_async_v1_multictas(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) {
+    %off0_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #block0>
+    %off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #block0>
+    %off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<32xi32, #block0>) -> tensor<32x1xi32, #block2>
+    %off1 = tt.expand_dims %off1_ {axis = 0 : i32} : (tensor<32xi32, #block0>) -> tensor<1x32xi32, #block3>
+    %broadcast_off0_scalar = tt.broadcast %off0 : (tensor<32x1xi32, #block2>) -> tensor<32x32xi32, #block2>
+    %cst_scalar = arith.constant 32 : i32
+    %cst = tt.splat %cst_scalar : (i32) -> tensor<32x32xi32, #block2>
+    %broadcast_off0_ = arith.muli %broadcast_off0_scalar, %cst : tensor<32x32xi32, #block2>
+    %broadcast_off1_ = tt.broadcast %off1 : (tensor<1x32xi32, #block3>) -> tensor<32x32xi32, #block3>
+    %broadcast_off0 = triton_gpu.convert_layout %broadcast_off0_ : (tensor<32x32xi32, #block2>) -> tensor<32x32xi32, #AL>
+    %broadcast_off1 = triton_gpu.convert_layout %broadcast_off1_ : (tensor<32x32xi32, #block3>) -> tensor<32x32xi32, #AL>
+    %off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<32x32xi32, #AL>
+    %a_init = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<32x32x!tt.ptr<f32>, #AL>
+    %a_ptr = tt.addptr %a_init, %off : tensor<32x32x!tt.ptr<f32>, #AL>
+    %tensor = triton_gpu.alloc_tensor : tensor<2x32x32xf32, #A>
+    %index = arith.constant 1 : i32
+
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 512 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 512 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 512 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 512 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: cp.async.commit_group
+    %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32x!tt.ptr<f32>, #AL> -> tensor<2x32x32xf32, #A>
+    return
+  }
+}
+
+// -----
+
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
   // CHECK: basic_splat
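Note (not part of the patch): the fix concerns step 3 of emitIndicesForBlocked. linearNanoTileId counts whole repetitions of the CTA tile across the tensor, so it must be delinearized over tilesPerDim = ceil(shape[k] / shapePerCTA[k]); the old code delinearized it over threadsPerDim = ceil(shape[k] / sizePerThread[k]), a shape that also counts the threads inside a single tile, so indices went wrong whenever the tensor spans more than one CTA tile. The standalone sketch below reproduces the index math with made-up layout parameters; ceilDiv and the row-major getMultiDimIndex are simplified stand-ins for Triton's ceil<unsigned> and getMultiDimIndex<unsigned> helpers, not the library's exact code.

// sketch.cpp -- illustrates the tilesPerDim vs. threadsPerDim bug fixed above.
#include <cstdio>
#include <vector>

// Integer ceiling division, standing in for Triton's ceil<unsigned>.
static unsigned ceilDiv(unsigned x, unsigned y) { return (x + y - 1) / y; }

// Row-major delinearization (fastest-varying dimension last), standing in
// for Triton's getMultiDimIndex<unsigned>.
static std::vector<unsigned>
getMultiDimIndex(unsigned linear, const std::vector<unsigned> &shape) {
  std::vector<unsigned> multiDim(shape.size());
  for (int k = static_cast<int>(shape.size()) - 1; k >= 0; --k) {
    multiDim[k] = linear % shape[k];
    linear /= shape[k];
  }
  return multiDim;
}

int main() {
  // Made-up blocked layout, chosen only so that the two shapes differ.
  std::vector<unsigned> shape = {64, 64};
  std::vector<unsigned> sizePerThread = {1, 4};
  std::vector<unsigned> threadsPerWarp = {4, 8};
  std::vector<unsigned> warpsPerCTA = {4, 1};

  unsigned rank = shape.size();
  std::vector<unsigned> shapePerCTA(rank), tilesPerDim(rank),
      threadsPerDim(rank);
  for (unsigned k = 0; k < rank; ++k) {
    // One CTA tile covers sizePerThread * threadsPerWarp * warpsPerCTA
    // elements along dimension k.
    shapePerCTA[k] = sizePerThread[k] * threadsPerWarp[k] * warpsPerCTA[k];
    // After the fix: how many times the CTA tile repeats along dimension k.
    tilesPerDim[k] = ceilDiv(shape[k], shapePerCTA[k]);
    // Before the fix: also counts the threads inside a single tile.
    threadsPerDim[k] = ceilDiv(shape[k], sizePerThread[k]);
  }

  // shapePerCTA = [16, 32], so tilesPerDim = [4, 2] but threadsPerDim = [64, 16].
  unsigned linearNanoTileId = 5;
  std::vector<unsigned> fixed = getMultiDimIndex(linearNanoTileId, tilesPerDim);
  std::vector<unsigned> buggy = getMultiDimIndex(linearNanoTileId, threadsPerDim);
  std::printf("fixed: (%u, %u)  buggy: (%u, %u)\n", fixed[0], fixed[1],
              buggy[0], buggy[1]);
  return 0;
}

With these numbers the fixed delinearization places nano-tile 5 at (2, 1), while the old one places it at (0, 5); the two agree only while every dimension fits in a single CTA tile. That is presumably why the new lit test uses a 32x32 tensor under a layout whose CTA tile is 16x32, forcing a second tile along dimension 0.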