[CODEGEN] Reverted some changes from previous PR; fixed vectorization characteristics of mma layout (#469)
This commit is contained in:
@@ -213,7 +213,7 @@ mma_layout::mma_layout(size_t num_warps,
|
||||
else{
|
||||
// fpw_ = {1, 1, 1};
|
||||
spw_ = mma_instr_shape_.at(tensor_core_type_); // e.g., {16, 8, 16} for f32.f16.f16.f32
|
||||
contig_per_thread_ = {1, 1};
|
||||
contig_per_thread_ = {1, 2};
|
||||
// rep_ = {2, 2, 1};
|
||||
}
|
||||
order_ = {0, 1};
|
||||
|
@@ -14,42 +14,41 @@ void swizzle::run(ir::module &) {
|
||||
max_phase_.clear();
|
||||
|
||||
for(auto &x: layouts_->get_all()){
|
||||
shared_layout* out_layout = dynamic_cast<shared_layout*>(x.second);
|
||||
if(!out_layout)
|
||||
shared_layout* layout = dynamic_cast<shared_layout*>(x.second);
|
||||
if(!layout)
|
||||
continue;
|
||||
scanline_layout* in_layout = dynamic_cast<scanline_layout*>(out_layout->get_arg_layout());
|
||||
if(!in_layout)
|
||||
continue;
|
||||
|
||||
ir::value* mma_dot_a = out_layout->hmma_dot_a();
|
||||
ir::value* mma_dot_b = out_layout->hmma_dot_b();
|
||||
ir::value* mma_dot_a = layout->hmma_dot_a();
|
||||
ir::value* mma_dot_b = layout->hmma_dot_b();
|
||||
|
||||
if(!mma_dot_a && !mma_dot_b){
|
||||
per_phase_[out_layout] = 1;
|
||||
max_phase_[out_layout] = 1;
|
||||
vec_[out_layout] = 1;
|
||||
per_phase_[layout] = 1;
|
||||
max_phase_[layout] = 1;
|
||||
vec_[layout] = 1;
|
||||
continue;
|
||||
}
|
||||
auto ord = out_layout->get_order();
|
||||
int dtsize = out_layout->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
|
||||
auto ord = layout->get_order();
|
||||
scanline_layout* in_layout = dynamic_cast<scanline_layout*>(layout->get_arg_layout());
|
||||
if(!in_layout)
|
||||
continue;
|
||||
int dtsize = layout->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
|
||||
if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80){
|
||||
int inner = mma_dot_a ? 0 : 1;
|
||||
per_phase_[out_layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
|
||||
max_phase_[out_layout] = (ord[inner] == 1 ? 8 : 4) / per_phase_[out_layout];
|
||||
per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
|
||||
max_phase_[layout] = (ord[inner] == 1 ? 8 : 4) / per_phase_[layout];
|
||||
if(mma_dot_a)
|
||||
vec_[out_layout] = 2*layouts_->get(mma_dot_a)->to_mma()->rep(0);
|
||||
vec_[layout] = 2*layouts_->get(mma_dot_a)->to_mma()->rep(0);
|
||||
else
|
||||
vec_[out_layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1);
|
||||
vec_[layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1);
|
||||
}
|
||||
else {
|
||||
if (!out_layout->allow_swizzle()) {
|
||||
per_phase_[out_layout] = 1;
|
||||
max_phase_[out_layout] = 1;
|
||||
vec_[out_layout] = 1;
|
||||
if (!layout->allow_swizzle()) {
|
||||
per_phase_[layout] = 1;
|
||||
max_phase_[layout] = 1;
|
||||
vec_[layout] = 1;
|
||||
} else {
|
||||
per_phase_[out_layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
|
||||
max_phase_[out_layout] = out_layout->get_mma_strided() / per_phase_[out_layout];
|
||||
vec_[out_layout] = out_layout->get_mma_vec();
|
||||
per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
|
||||
max_phase_[layout] = layout->get_mma_strided() / per_phase_[layout];
|
||||
vec_[layout] = layout->get_mma_vec();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user