[Triton-MLIR] Fix threadsPerWarp derivation in BlockedEncodingAttr (#722)
Example: ``` auto encoding = triton::gpu::BlockedEncodingAttr::get( &getContext(), {8, 32}, {2, 2}, {1, 0}, 2); //shape = [32 x 8], order = [1, 0], sizePerThread=[2, 2], numWarps=2 ``` Expected output: ``` //#triton_gpu.blocked_layout<{ // sizePerThread = {2, 2} // threadsPerWarp = {8, 4} // warpsPerCTA = {2, 1} //}> ``` Incorrect output by the current branch ``` //#triton_gpu.blocked_layout<{ // sizePerThread = {2, 2} // threadsPerWarp = {16, 2} // warpsPerCTA = {2, 1} //}> ```
This commit is contained in:
@@ -169,6 +169,7 @@ for
|
||||
int dim = order[_dim];
|
||||
int maxNumThreads = int(shape[dim]) / sizePerThread[dim];
|
||||
warpsPerCTA[dim] = std::clamp(remainingWarps, 1, maxNumThreads);
|
||||
maxNumThreads = maxNumThreads / warpsPerCTA[dim];
|
||||
threadsPerWarp[dim] = std::clamp(remainingLanes, 1, maxNumThreads);
|
||||
remainingWarps /= warpsPerCTA[dim];
|
||||
remainingLanes /= threadsPerWarp[dim];
|
||||
|
@@ -7,8 +7,8 @@
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
|
||||
// CHECK: [[row_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
|
||||
// CHECK: [[col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}>
|
||||
// CHECK: [[row_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [1, 0]}>
|
||||
// CHECK: [[col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [0, 1]}>
|
||||
// CHECK: [[load_ptr:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64x!tt.ptr<f32>, [[row_layout]]>
|
||||
// CHECK: [[load_mask:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xi1, [[row_layout]]>
|
||||
// CHECK: [[load_other:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xf32, [[row_layout]]>
|
||||
|
Reference in New Issue
Block a user