[Triton-MLIR] Fix threadsPerWarp derivation in BlockedEncodingAttr (#722)

Example:

```
    auto encoding = triton::gpu::BlockedEncodingAttr::get(
        &getContext(), {8, 32}, {2, 2}, {1, 0}, 2);
     //shape = [32 x 8], order = [1, 0], sizePerThread=[2, 2], numWarps=2
```

Expected output:

```
      //#triton_gpu.blocked_layout<{
      //  sizePerThread = {2, 2}
      //  threadsPerWarp = {8, 4}
      //  warpsPerCTA = {2, 1}
      //}>
```

Incorrect output by the current branch

```
      //#triton_gpu.blocked_layout<{
      //  sizePerThread = {2, 2}
      //  threadsPerWarp = {16, 2}
      //  warpsPerCTA = {2, 1}
      //}>
```
This commit is contained in:
Keren Zhou
2022-09-27 16:41:30 -07:00
committed by GitHub
parent 9ddf0921fb
commit baba98ad69
2 changed files with 3 additions and 2 deletions

View File

@@ -169,6 +169,7 @@ for
int dim = order[_dim];
int maxNumThreads = int(shape[dim]) / sizePerThread[dim];
warpsPerCTA[dim] = std::clamp(remainingWarps, 1, maxNumThreads);
maxNumThreads = maxNumThreads / warpsPerCTA[dim];
threadsPerWarp[dim] = std::clamp(remainingLanes, 1, maxNumThreads);
remainingWarps /= warpsPerCTA[dim];
remainingLanes /= threadsPerWarp[dim];

View File

@@ -7,8 +7,8 @@
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK: [[row_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
// CHECK: [[col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}>
// CHECK: [[row_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [1, 0]}>
// CHECK: [[col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [0, 1]}>
// CHECK: [[load_ptr:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64x!tt.ptr<f32>, [[row_layout]]>
// CHECK: [[load_mask:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xi1, [[row_layout]]>
// CHECK: [[load_other:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xf32, [[row_layout]]>