[BACKEND] Making the warp-level tile "more square" to increase data-reuse for tl.dot. (#442)
* Increase smem data-reuse for some layouts * tweak * Keep the original tiling logic for sm < 80 Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
@@ -217,16 +217,36 @@ mma_layout::mma_layout(size_t num_warps,
|
||||
order_ = {0, 1};
|
||||
|
||||
/* warps per tile */
|
||||
// try to make things as square as possible to maximize data re-use
|
||||
wpt_ = {1, 1, 1};
|
||||
std::vector<int> wpt_nm1;
|
||||
do{
|
||||
wpt_nm1 = wpt_;
|
||||
if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
|
||||
wpt_[0] = clamp(wpt_[0]*2, 1, shape_[0] / spw_[0]);
|
||||
if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
|
||||
wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / spw_[1]);
|
||||
}while(wpt_nm1 != wpt_);
|
||||
// try to make warp-level tiles as square as possible to maximize data re-use
|
||||
if (tgt->as_nvidia()->sm() < 80) {
|
||||
std::vector<int> wpt_nm1;
|
||||
do{
|
||||
wpt_nm1 = wpt_;
|
||||
if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
|
||||
wpt_[0] = clamp(wpt_[0]*2, 1, shape_[0] / spw_[0]);
|
||||
if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
|
||||
wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / spw_[1]);
|
||||
}while(wpt_nm1 != wpt_);
|
||||
} else {
|
||||
bool changed = false;
|
||||
do {
|
||||
changed = false;
|
||||
if (wpt_[0] * wpt_[1] * wpt_[2] >= num_warps)
|
||||
break;
|
||||
if (shape_[0] / spw_[0] / wpt_[0] >= shape_[1] / (spw_[1]*2) / wpt_[1]) {
|
||||
if (wpt_[0] < shape_[0] / spw_[0]) {
|
||||
wpt_[0] *= 2;
|
||||
changed = true;
|
||||
}
|
||||
} else {
|
||||
if (wpt_[1] < shape_[1] / (spw_[1]*2)) {
|
||||
wpt_[1] *= 2;
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
} while (changed);
|
||||
}
|
||||
|
||||
/* shape per block */
|
||||
shape_per_cta_ = {spw_[0]*wpt_[0], spw_[1]*wpt_[1], 1};
|
||||
|
Reference in New Issue
Block a user