[Backend] Fix a bug in emitIndicesForBlocked (#740)
@@ -548,6 +548,9 @@ public:
     auto warpsPerCTA = blockedLayout.getWarpsPerCTA();
     unsigned rank = shape.size();
     SmallVector<unsigned> shapePerCTA = getShapePerCTA(blockedLayout);
+    SmallVector<unsigned> tilesPerDim(rank);
+    for (unsigned k = 0; k < rank; ++k)
+      tilesPerDim[k] = ceil<unsigned>(shape[k], shapePerCTA[k]);
 
     // step 1, delinearize threadId to get the base index
     auto multiDimBase =
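The hunk above precomputes tilesPerDim, the number of CTA-sized tiles needed to cover the tensor along each dimension, so that the later hunks can reuse it instead of recomputing it (or, in the buggy path, substituting an unrelated quantity). A minimal standalone sketch of the same computation; the concrete shapes and the ceilDiv helper are illustrative assumptions, not values taken from the commit:

#include <cstdio>
#include <vector>

// Standalone stand-in for the ceil<unsigned>(a, b) helper used in the diff.
static unsigned ceilDiv(unsigned a, unsigned b) { return (a + b - 1) / b; }

int main() {
  // Illustrative values (assumptions): a 32x32 tensor whose blocked layout
  // covers 16x32 elements per CTA tile.
  std::vector<unsigned> shape = {32, 32};
  std::vector<unsigned> shapePerCTA = {16, 32};
  std::vector<unsigned> tilesPerDim(shape.size());
  for (unsigned k = 0; k < shape.size(); ++k)
    tilesPerDim[k] = ceilDiv(shape[k], shapePerCTA[k]);
  // Prints [2, 1]: two CTA-sized tiles along dim 0, one along dim 1.
  std::printf("tilesPerDim = [%u, %u]\n", tilesPerDim[0], tilesPerDim[1]);
  return 0;
}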
@@ -558,8 +561,7 @@ public:
     SmallVector<SmallVector<unsigned>> offset(rank);
     for (unsigned k = 0; k < rank; ++k) {
       // 1 block in minimum if shape[k] is less than shapePerCTA[k]
-      for (unsigned blockOffset = 0;
-           blockOffset < ceil<unsigned>(shape[k], shapePerCTA[k]);
+      for (unsigned blockOffset = 0; blockOffset < tilesPerDim[k];
            ++blockOffset)
         for (unsigned warpOffset = 0; warpOffset < warpsPerCTA[k]; ++warpOffset)
           for (unsigned threadOffset = 0; threadOffset < threadsPerWarp[k];
@@ -577,15 +579,12 @@ public:
     SmallVector<SmallVector<Value>> multiDimIdx(elemsPerThread,
                                                 SmallVector<Value>(rank));
     unsigned totalSizePerThread = product<unsigned>(sizePerThread);
-    SmallVector<unsigned> threadsPerDim(rank);
-    for (unsigned k = 0; k < rank; ++k)
-      threadsPerDim[k] = ceil<unsigned>(shape[k], sizePerThread[k]);
 
     for (unsigned n = 0; n < elemsPerThread; ++n) {
       unsigned linearNanoTileId = n / totalSizePerThread;
       unsigned linearNanoTileElemId = n % totalSizePerThread;
       SmallVector<unsigned> multiDimNanoTileId =
-          getMultiDimIndex<unsigned>(linearNanoTileId, threadsPerDim);
+          getMultiDimIndex<unsigned>(linearNanoTileId, tilesPerDim);
       SmallVector<unsigned> multiDimNanoTileElemId =
           getMultiDimIndex<unsigned>(linearNanoTileElemId, sizePerThread);
       for (unsigned k = 0; k < rank; ++k) {
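This hunk is the actual bug fix: the linear nano-tile id was delinearized over threadsPerDim = ceil(shape[k] / sizePerThread[k]), a vector unrelated to how many times the CTA tile repeats across the tensor, instead of tilesPerDim = ceil(shape[k] / shapePerCTA[k]). Whenever the tensor spans more than one CTA tile, the old bounds send the repeated tiles to the wrong multi-dimensional coordinates. A minimal sketch of the difference, assuming a simple row-major delinearization in place of getMultiDimIndex (the real helper's ordering may differ) and illustrative dimension vectors:

#include <cstdio>
#include <vector>

// Hypothetical row-major delinearization standing in for
// getMultiDimIndex<unsigned>; the actual helper's ordering is an assumption.
static std::vector<unsigned> delinearize(unsigned linear,
                                         const std::vector<unsigned> &dims) {
  std::vector<unsigned> multi(dims.size());
  for (int k = (int)dims.size() - 1; k >= 0; --k) {
    multi[k] = linear % dims[k];
    linear /= dims[k];
  }
  return multi;
}

int main() {
  // Illustrative setup (assumed, not from the diff): the layout repeats twice
  // along dim 0 and once along dim 1, so the valid nano-tile ids 0 and 1
  // should map to tile coordinates (0, 0) and (1, 0).
  std::vector<unsigned> tilesPerDim = {2, 1};
  // The old code delinearized over a much larger, unrelated index space.
  std::vector<unsigned> threadsPerDim = {32, 8};
  unsigned linearNanoTileId = 1;
  auto fixed = delinearize(linearNanoTileId, tilesPerDim);   // (1, 0)
  auto buggy = delinearize(linearNanoTileId, threadsPerDim); // (0, 1)
  std::printf("fixed: (%u, %u)  buggy: (%u, %u)\n",
              fixed[0], fixed[1], buggy[0], buggy[1]);
  return 0;
}

With the old vector, the second nano-tile lands at a column offset instead of the next row of CTA tiles, so the emitted element indices point at the wrong part of the tensor.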
@@ -426,6 +426,56 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 
 // -----
 
+#block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}>
+#block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
+#block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 8], warpsPerCTA = [1, 4], order = [1, 0]}>
+#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#A = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0]}>
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+  // CHECK-LABEL: basic_insert_slice_async_v1_multictas
+  func @basic_insert_slice_async_v1_multictas(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) {
+    %off0_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #block0>
+    %off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #block0>
+    %off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<32xi32, #block0>) -> tensor<32x1xi32, #block2>
+    %off1 = tt.expand_dims %off1_ {axis = 0 : i32} : (tensor<32xi32, #block0>) -> tensor<1x32xi32, #block3>
+    %broadcast_off0_scalar = tt.broadcast %off0 : (tensor<32x1xi32, #block2>) -> tensor<32x32xi32, #block2>
+    %cst_scalar = arith.constant 32 : i32
+    %cst = tt.splat %cst_scalar : (i32) -> tensor<32x32xi32, #block2>
+    %broadcast_off0_ = arith.muli %broadcast_off0_scalar, %cst : tensor<32x32xi32, #block2>
+    %broadcast_off1_ = tt.broadcast %off1 : (tensor<1x32xi32, #block3>) -> tensor<32x32xi32, #block3>
+    %broadcast_off0 = triton_gpu.convert_layout %broadcast_off0_ : (tensor<32x32xi32, #block2>) -> tensor<32x32xi32, #AL>
+    %broadcast_off1 = triton_gpu.convert_layout %broadcast_off1_ : (tensor<32x32xi32, #block3>) -> tensor<32x32xi32, #AL>
+    %off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<32x32xi32, #AL>
+    %a_init = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<32x32x!tt.ptr<f32>, #AL>
+    %a_ptr = tt.addptr %a_init, %off : tensor<32x32x!tt.ptr<f32>, #AL>
+    %tensor = triton_gpu.alloc_tensor : tensor<2x32x32xf32, #A>
+    %index = arith.constant 1 : i32
+
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 512 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 512 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 512 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 512 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: cp.async.commit_group
+    %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32x!tt.ptr<f32>, #AL> -> tensor<2x32x32xf32, #A>
+    return
+  }
+}
+
+// -----
+
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
   // CHECK: basic_splat
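The new lit test above exercises exactly the multi-tile ("multictas") case: under the #AL layout of the 32x32 operand, dimension 0 needs more than one CTA tile, which is the situation where the old delinearization produced wrong indices. A small sketch of that arithmetic, assuming shapePerCTA of a blocked layout is the elementwise product of sizePerThread, threadsPerWarp, and warpsPerCTA (as getShapePerCTA computes it); the sketch is illustrative, not part of the commit:

#include <cstdio>

int main() {
  // #AL from the test: sizePerThread = [1, 4], threadsPerWarp = [4, 8],
  // warpsPerCTA = [4, 1]; the operand tensor is 32x32.
  unsigned sizePerThread[2] = {1, 4};
  unsigned threadsPerWarp[2] = {4, 8};
  unsigned warpsPerCTA[2] = {4, 1};
  unsigned shape[2] = {32, 32};
  for (unsigned k = 0; k < 2; ++k) {
    // Assumption: shapePerCTA is the elementwise product of the three
    // layout attributes, as for blocked layouts in getShapePerCTA.
    unsigned shapePerCTA = sizePerThread[k] * threadsPerWarp[k] * warpsPerCTA[k];
    unsigned tilesPerDim = (shape[k] + shapePerCTA - 1) / shapePerCTA;
    std::printf("dim %u: shapePerCTA = %u, tilesPerDim = %u\n",
                k, shapePerCTA, tilesPerDim);
  }
  // Dim 0 gives shapePerCTA = 16 and tilesPerDim = 2, so the 32x32 tensor
  // spans two CTA tiles along dim 0 -- the repeated-tile case the fix targets.
  return 0;
}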