diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index af231c1ef..5edd5a51e 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -169,6 +169,7 @@ for int dim = order[_dim]; int maxNumThreads = int(shape[dim]) / sizePerThread[dim]; warpsPerCTA[dim] = std::clamp(remainingWarps, 1, maxNumThreads); + maxNumThreads = maxNumThreads / warpsPerCTA[dim]; threadsPerWarp[dim] = std::clamp(remainingLanes, 1, maxNumThreads); remainingWarps /= warpsPerCTA[dim]; remainingLanes /= threadsPerWarp[dim]; diff --git a/test/TritonGPU/coalesce.mlir b/test/TritonGPU/coalesce.mlir index 1dc46a1a6..f34f10003 100644 --- a/test/TritonGPU/coalesce.mlir +++ b/test/TritonGPU/coalesce.mlir @@ -7,8 +7,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { -// CHECK: [[row_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}> -// CHECK: [[col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}> +// CHECK: [[row_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [1, 0]}> +// CHECK: [[col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [0, 1]}> // CHECK: [[load_ptr:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64x!tt.ptr, [[row_layout]]> // CHECK: [[load_mask:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xi1, [[row_layout]]> // CHECK: [[load_other:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xf32, [[row_layout]]>