[OPTIMIZER] Fixed memory coalescing (#847)

This commit is contained in:
Philippe Tillet
2022-11-07 06:22:18 -08:00
committed by GitHub
parent b6f15e214b
commit 976cf12af1
7 changed files with 25 additions and 29 deletions

View File

@@ -163,18 +163,19 @@ for
"ArrayRef<unsigned>":$order,
"unsigned":$numWarps), [{
int rank = sizePerThread.size();
int remainingWarps = numWarps;
int remainingLanes = 32;
unsigned remainingLanes = 32;
unsigned remainingThreads = numWarps*32;
unsigned remainingWarps = numWarps;
SmallVector<unsigned, 4> threadsPerWarp(rank);
SmallVector<unsigned, 4> warpsPerCTA(rank);
for (int _dim = 0; _dim < rank; ++_dim) {
int dim = order[_dim];
int maxNumThreads = int(shape[dim]) / sizePerThread[dim];
warpsPerCTA[dim] = std::clamp(remainingWarps, 1, maxNumThreads);
maxNumThreads = maxNumThreads / warpsPerCTA[dim];
threadsPerWarp[dim] = std::clamp(remainingLanes, 1, maxNumThreads);
remainingWarps /= warpsPerCTA[dim];
remainingLanes /= threadsPerWarp[dim];
int i = order[_dim];
unsigned threadsPerCTA = std::clamp<unsigned>(remainingThreads, 1, shape[i] / sizePerThread[i]);
threadsPerWarp[i] = std::clamp<unsigned>(threadsPerCTA, 1, remainingLanes);
warpsPerCTA[i] = std::clamp<unsigned>(threadsPerCTA / threadsPerWarp[i], 1, remainingWarps);
remainingWarps /= warpsPerCTA[i];
remainingLanes /= threadsPerWarp[i];
remainingThreads /= threadsPerCTA;
}
return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order);