diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc index fd0a7879b..5d30a2f45 100644 --- a/lib/codegen/analysis/layout.cc +++ b/lib/codegen/analysis/layout.cc @@ -213,7 +213,7 @@ mma_layout::mma_layout(size_t num_warps, else{ // fpw_ = {1, 1, 1}; spw_ = mma_instr_shape_.at(tensor_core_type_); // e.g., {16, 8, 16} for f32.f16.f16.f32 - contig_per_thread_ = {1, 1}; + contig_per_thread_ = {1, 2}; // rep_ = {2, 2, 1}; } order_ = {0, 1}; diff --git a/lib/codegen/analysis/swizzle.cc b/lib/codegen/analysis/swizzle.cc index 414b0e0e5..5737f80a0 100644 --- a/lib/codegen/analysis/swizzle.cc +++ b/lib/codegen/analysis/swizzle.cc @@ -14,42 +14,41 @@ void swizzle::run(ir::module &) { max_phase_.clear(); for(auto &x: layouts_->get_all()){ - shared_layout* out_layout = dynamic_cast(x.second); - if(!out_layout) + shared_layout* layout = dynamic_cast(x.second); + if(!layout) continue; - scanline_layout* in_layout = dynamic_cast(out_layout->get_arg_layout()); - if(!in_layout) - continue; - - ir::value* mma_dot_a = out_layout->hmma_dot_a(); - ir::value* mma_dot_b = out_layout->hmma_dot_b(); + ir::value* mma_dot_a = layout->hmma_dot_a(); + ir::value* mma_dot_b = layout->hmma_dot_b(); if(!mma_dot_a && !mma_dot_b){ - per_phase_[out_layout] = 1; - max_phase_[out_layout] = 1; - vec_[out_layout] = 1; + per_phase_[layout] = 1; + max_phase_[layout] = 1; + vec_[layout] = 1; continue; } - auto ord = out_layout->get_order(); - int dtsize = out_layout->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; + auto ord = layout->get_order(); + scanline_layout* in_layout = dynamic_cast(layout->get_arg_layout()); + if(!in_layout) + continue; + int dtsize = layout->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80){ int inner = mma_dot_a ? 0 : 1; - per_phase_[out_layout] = std::max(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1); - max_phase_[out_layout] = (ord[inner] == 1 ? 8 : 4) / per_phase_[out_layout]; + per_phase_[layout] = std::max(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1); + max_phase_[layout] = (ord[inner] == 1 ? 8 : 4) / per_phase_[layout]; if(mma_dot_a) - vec_[out_layout] = 2*layouts_->get(mma_dot_a)->to_mma()->rep(0); + vec_[layout] = 2*layouts_->get(mma_dot_a)->to_mma()->rep(0); else - vec_[out_layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1); + vec_[layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1); } else { - if (!out_layout->allow_swizzle()) { - per_phase_[out_layout] = 1; - max_phase_[out_layout] = 1; - vec_[out_layout] = 1; + if (!layout->allow_swizzle()) { + per_phase_[layout] = 1; + max_phase_[layout] = 1; + vec_[layout] = 1; } else { - per_phase_[out_layout] = std::max(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1); - max_phase_[out_layout] = out_layout->get_mma_strided() / per_phase_[out_layout]; - vec_[out_layout] = out_layout->get_mma_vec(); + per_phase_[layout] = std::max(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1); + max_phase_[layout] = layout->get_mma_strided() / per_phase_[layout]; + vec_[layout] = layout->get_mma_vec(); } } }