[BACKEND] Making the warp-level tile "more square" to increase data-reuse for tl.dot. (#442)
* Increase smem data-reuse for some layouts
* tweak
* Keep the original tiling logic for sm < 80

Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
@@ -217,16 +217,36 @@ mma_layout::mma_layout(size_t num_warps,
|
|||||||
order_ = {0, 1};
|
order_ = {0, 1};
|
||||||
|
|
||||||
/* warps per tile */
|
/* warps per tile */
|
||||||
// try to make things as square as possible to maximize data re-use
|
|
||||||
wpt_ = {1, 1, 1};
|
wpt_ = {1, 1, 1};
|
||||||
std::vector<int> wpt_nm1;
|
// try to make warp-level tiles as square as possible to maximize data re-use
|
||||||
do{
|
if (tgt->as_nvidia()->sm() < 80) {
|
||||||
wpt_nm1 = wpt_;
|
std::vector<int> wpt_nm1;
|
||||||
if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
|
do{
|
||||||
wpt_[0] = clamp(wpt_[0]*2, 1, shape_[0] / spw_[0]);
|
wpt_nm1 = wpt_;
|
||||||
if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
|
if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
|
||||||
wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / spw_[1]);
|
wpt_[0] = clamp(wpt_[0]*2, 1, shape_[0] / spw_[0]);
|
||||||
}while(wpt_nm1 != wpt_);
|
if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
|
||||||
|
wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / spw_[1]);
|
||||||
|
}while(wpt_nm1 != wpt_);
|
||||||
|
} else {
|
||||||
|
bool changed = false;
|
||||||
|
do {
|
||||||
|
changed = false;
|
||||||
|
if (wpt_[0] * wpt_[1] * wpt_[2] >= num_warps)
|
||||||
|
break;
|
||||||
|
if (shape_[0] / spw_[0] / wpt_[0] >= shape_[1] / (spw_[1]*2) / wpt_[1]) {
|
||||||
|
if (wpt_[0] < shape_[0] / spw_[0]) {
|
||||||
|
wpt_[0] *= 2;
|
||||||
|
changed = true;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (wpt_[1] < shape_[1] / (spw_[1]*2)) {
|
||||||
|
wpt_[1] *= 2;
|
||||||
|
changed = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} while (changed);
|
||||||
|
}
|
||||||
|
|
||||||
/* shape per block */
|
/* shape per block */
|
||||||
shape_per_cta_ = {spw_[0]*wpt_[0], spw_[1]*wpt_[1], 1};
|
shape_per_cta_ = {spw_[0]*wpt_[0], spw_[1]*wpt_[1], 1};
|
||||||
|
Reference in New Issue
Block a user