[CODEGEN] Now padding shared memory for layout conversion (#468)

Authored by Philippe Tillet, 2022-03-03 22:19:05 -08:00; committed by GitHub
parent d9dd97492f, commit bb5765df5c
4 changed files with 62 additions and 35 deletions

lib/codegen/analysis/layout.cc

@@ -208,10 +208,12 @@ mma_layout::mma_layout(size_t num_warps,
     int pack_size_1 = (is_b_row && !is_b_vec4) ? 2 : 1;
     rep_ = {2*pack_size_0, 2*pack_size_1, 1};
     spw_ = {fpw_[0]*4*rep_[0], fpw_[1]*4*rep_[1], 1};
+    contig_per_thread_ = {1, 1};
   }
   else{
     // fpw_ = {1, 1, 1};
     spw_ = mma_instr_shape_.at(tensor_core_type_); // e.g., {16, 8, 16} for f32.f16.f16.f32
+    contig_per_thread_ = {1, 1};
     // rep_ = {2, 2, 1};
   }
   order_ = {0, 1};
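
The two added lines initialize contig_per_thread_, the number of contiguous tensor elements a single thread owns along each dimension (one per dimension for both MMA paths here). The padding pass in the next hunk queries this quantity on both sides of a layout conversion. A minimal sketch of why it matters, with a hypothetical helper name that is not part of this commit:

#include <algorithm>

// Hedged sketch: the widest vectorized shared-memory access a thread can
// issue during a layout conversion is bounded by the number of contiguous
// elements it owns along the fastest-varying axis (what contig_per_thread_
// records) and by the 16-byte limit of one vector load/store on NVIDIA GPUs.
static int max_vector_width(int contig_per_thread, int dtsize_bytes) {
  return std::min(contig_per_thread, 16 / dtsize_bytes);
}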
@@ -628,6 +630,12 @@ void layouts::run(ir::module &mod) {
       shape[k] = std::max(in_layout->shape_per_cta(k),
                           out_layout->shape_per_cta(k));
     }
+    auto in_ord = in_layout->get_order();
+    auto out_ord = out_layout->get_order();
+    int in_vec = in_layout->contig_per_thread(in_ord[0]);
+    int out_vec = out_layout->contig_per_thread(out_ord[0]);
+    int pad = std::max(in_vec, out_vec);
+    shape[out_ord[0]] += pad;
     layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {val}, val->get_type()->get_scalar_ty(), align_, tgt_);
     tmp_[val] = id;
   }
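
This hunk is the commit's core change: the temporary shared buffer used to convert between two distributed layouts is padded along its fastest-varying dimension (out_ord[0]) by the larger of the two per-thread contiguities, so that the producer's vectorized writes and the consumer's vectorized reads of consecutive rows start at staggered shared-memory bank offsets instead of conflicting. The same rule as a standalone function, a sketch under assumed parameter names rather than the commit's code:

#include <algorithm>
#include <vector>

// Hedged sketch of the padding rule above: grow the buffer's
// fastest-varying dimension by the larger vector width of the two
// layouts, which skews each row's starting bank.
static std::vector<int> padded_tmp_shape(std::vector<int> shape,
                                         int out_fastest_dim, // out_ord[0]
                                         int in_vec,          // producer contiguity
                                         int out_vec)         // consumer contiguity
{
  shape[out_fastest_dim] += std::max(in_vec, out_vec);
  return shape;
}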

lib/codegen/analysis/swizzle.cc

@@ -14,41 +14,42 @@ void swizzle::run(ir::module &) {
   max_phase_.clear();
   for(auto &x: layouts_->get_all()){
-    shared_layout* layout = dynamic_cast<shared_layout*>(x.second);
-    if(!layout)
+    shared_layout* out_layout = dynamic_cast<shared_layout*>(x.second);
+    if(!out_layout)
       continue;
-    ir::value* mma_dot_a = layout->hmma_dot_a();
-    ir::value* mma_dot_b = layout->hmma_dot_b();
-    if(!mma_dot_a && !mma_dot_b){
-      per_phase_[layout] = 1;
-      max_phase_[layout] = 1;
-      vec_[layout] = 1;
-      continue;
-    }
-    auto ord = layout->get_order();
-    scanline_layout* in_layout = dynamic_cast<scanline_layout*>(layout->get_arg_layout());
+    scanline_layout* in_layout = dynamic_cast<scanline_layout*>(out_layout->get_arg_layout());
     if(!in_layout)
       continue;
-    int dtsize = layout->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
+    ir::value* mma_dot_a = out_layout->hmma_dot_a();
+    ir::value* mma_dot_b = out_layout->hmma_dot_b();
+    if(!mma_dot_a && !mma_dot_b){
+      per_phase_[out_layout] = 1;
+      max_phase_[out_layout] = 1;
+      vec_[out_layout] = 1;
+      continue;
+    }
+    auto ord = out_layout->get_order();
+    int dtsize = out_layout->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
     if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80){
       int inner = mma_dot_a ? 0 : 1;
-      per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
-      max_phase_[layout] = (ord[inner] == 1 ? 8 : 4) / per_phase_[layout];
+      per_phase_[out_layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
+      max_phase_[out_layout] = (ord[inner] == 1 ? 8 : 4) / per_phase_[out_layout];
       if(mma_dot_a)
-        vec_[layout] = 2*layouts_->get(mma_dot_a)->to_mma()->rep(0);
+        vec_[out_layout] = 2*layouts_->get(mma_dot_a)->to_mma()->rep(0);
       else
-        vec_[layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1);
+        vec_[out_layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1);
     }
     else {
-      if (!layout->allow_swizzle()) {
-        per_phase_[layout] = 1;
-        max_phase_[layout] = 1;
-        vec_[layout] = 1;
+      if (!out_layout->allow_swizzle()) {
+        per_phase_[out_layout] = 1;
+        max_phase_[out_layout] = 1;
+        vec_[out_layout] = 1;
       } else {
-        per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
-        max_phase_[layout] = layout->get_mma_strided() / per_phase_[layout];
-        vec_[layout] = layout->get_mma_vec();
+        per_phase_[out_layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
+        max_phase_[out_layout] = out_layout->get_mma_strided() / per_phase_[out_layout];
+        vec_[out_layout] = out_layout->get_mma_vec();
       }
     }
   }
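
Besides renaming layout to out_layout (freeing the name in_layout for the argument's scanline layout), this hunk hoists the in_layout check ahead of the MMA queries, so shared layouts whose argument is not a scanline layout are skipped before any per-phase bookkeeping. The quantities computed here, per_phase_ (rows per phase), max_phase_ (phases before wrapping), and vec_ (width of the permuted column groups), parameterize shared-memory swizzling; how they are consumed is not shown in this diff, but a hedged sketch of the usual XOR-based addressing they drive (an assumption about the consumer, not code from this commit):

// Hedged sketch: groups of `vec` consecutive columns are XOR-ed with a
// phase that advances every `per_phase` rows and wraps after `max_phase`
// phases, so rows that the raw layout maps to the same banks land in
// different banks after swizzling.
static int swizzled_col(int row, int col, int per_phase, int max_phase, int vec) {
  int phase = (row / per_phase) % max_phase;
  int group = col / vec;   // which vec-wide column group
  int lane  = col % vec;   // position within the group
  return (group ^ phase) * vec + lane;
}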