diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc index 2206f5b6a..587234863 100644 --- a/lib/codegen/analysis/layout.cc +++ b/lib/codegen/analysis/layout.cc @@ -217,16 +217,36 @@ mma_layout::mma_layout(size_t num_warps, order_ = {0, 1}; /* warps per tile */ - // try to make things as square as possible to maximize data re-use wpt_ = {1, 1, 1}; - std::vector wpt_nm1; - do{ - wpt_nm1 = wpt_; - if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps) - wpt_[0] = clamp(wpt_[0]*2, 1, shape_[0] / spw_[0]); - if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps) - wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / spw_[1]); - }while(wpt_nm1 != wpt_); + // try to make warp-level tiles as square as possible to maximize data re-use + if (tgt->as_nvidia()->sm() < 80) { + std::vector wpt_nm1; + do{ + wpt_nm1 = wpt_; + if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps) + wpt_[0] = clamp(wpt_[0]*2, 1, shape_[0] / spw_[0]); + if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps) + wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / spw_[1]); + }while(wpt_nm1 != wpt_); + } else { + bool changed = false; + do { + changed = false; + if (wpt_[0] * wpt_[1] * wpt_[2] >= num_warps) + break; + if (shape_[0] / spw_[0] / wpt_[0] >= shape_[1] / (spw_[1]*2) / wpt_[1]) { + if (wpt_[0] < shape_[0] / spw_[0]) { + wpt_[0] *= 2; + changed = true; + } + } else { + if (wpt_[1] < shape_[1] / (spw_[1]*2)) { + wpt_[1] *= 2; + changed = true; + } + } + } while (changed); + } /* shape per block */ shape_per_cta_ = {spw_[0]*wpt_[0], spw_[1]*wpt_[1], 1};