[Backend] Fix a bug in emitIndicesForBlocked (#740)
This commit is contained in:
@@ -548,6 +548,9 @@ public:
|
||||
auto warpsPerCTA = blockedLayout.getWarpsPerCTA();
|
||||
unsigned rank = shape.size();
|
||||
SmallVector<unsigned> shapePerCTA = getShapePerCTA(blockedLayout);
|
||||
SmallVector<unsigned> tilesPerDim(rank);
|
||||
for (unsigned k = 0; k < rank; ++k)
|
||||
tilesPerDim[k] = ceil<unsigned>(shape[k], shapePerCTA[k]);
|
||||
|
||||
// step 1, delinearize threadId to get the base index
|
||||
auto multiDimBase =
|
||||
@@ -558,8 +561,7 @@ public:
|
||||
SmallVector<SmallVector<unsigned>> offset(rank);
|
||||
for (unsigned k = 0; k < rank; ++k) {
|
||||
// 1 block in minimum if shape[k] is less than shapePerCTA[k]
|
||||
for (unsigned blockOffset = 0;
|
||||
blockOffset < ceil<unsigned>(shape[k], shapePerCTA[k]);
|
||||
for (unsigned blockOffset = 0; blockOffset < tilesPerDim[k];
|
||||
++blockOffset)
|
||||
for (unsigned warpOffset = 0; warpOffset < warpsPerCTA[k]; ++warpOffset)
|
||||
for (unsigned threadOffset = 0; threadOffset < threadsPerWarp[k];
|
||||
@@ -577,15 +579,12 @@ public:
|
||||
SmallVector<SmallVector<Value>> multiDimIdx(elemsPerThread,
|
||||
SmallVector<Value>(rank));
|
||||
unsigned totalSizePerThread = product<unsigned>(sizePerThread);
|
||||
SmallVector<unsigned> threadsPerDim(rank);
|
||||
for (unsigned k = 0; k < rank; ++k)
|
||||
threadsPerDim[k] = ceil<unsigned>(shape[k], sizePerThread[k]);
|
||||
|
||||
for (unsigned n = 0; n < elemsPerThread; ++n) {
|
||||
unsigned linearNanoTileId = n / totalSizePerThread;
|
||||
unsigned linearNanoTileElemId = n % totalSizePerThread;
|
||||
SmallVector<unsigned> multiDimNanoTileId =
|
||||
getMultiDimIndex<unsigned>(linearNanoTileId, threadsPerDim);
|
||||
getMultiDimIndex<unsigned>(linearNanoTileId, tilesPerDim);
|
||||
SmallVector<unsigned> multiDimNanoTileElemId =
|
||||
getMultiDimIndex<unsigned>(linearNanoTileElemId, sizePerThread);
|
||||
for (unsigned k = 0; k < rank; ++k) {
|
||||
|
Reference in New Issue
Block a user