diff --git a/include/triton/codegen/analysis/layout.h b/include/triton/codegen/analysis/layout.h
index b6376d7cc..56fb1e4b9 100644
--- a/include/triton/codegen/analysis/layout.h
+++ b/include/triton/codegen/analysis/layout.h
@@ -103,6 +103,7 @@ public:
   int shape_per_cta(size_t k) { return shape_per_cta_.at(k); }
   int rep_per_cta(size_t k) { return shape_[k] / shape_per_cta_[k]; }
+  virtual int contig_per_thread(size_t k) = 0;
 protected:
   std::vector<int> shape_per_cta_;
@@ -181,6 +182,7 @@ public:
   int wpt(size_t k) { return wpt_.at(k); }
   int spw(size_t k) { return spw_.at(k); }
   int rep(size_t k) { return rep_.at(k); }
+  int contig_per_thread(size_t k) { return contig_per_thread_.at(k); }
   // helpers for generator.cc
   std::string get_ptx_instr() const { return mma_instr_ptx_.at(tensor_core_type_); }
@@ -203,6 +205,8 @@ private:
   std::vector<int> spt_;
   // repetitions
   std::vector<int> rep_;
+  // contiguous per thread
+  std::vector<int> contig_per_thread_;
   TensorCoreType tensor_core_type_ = FP32_FP16_FP16_FP32;
 };
@@ -218,6 +222,7 @@ struct scanline_layout: public distributed_layout {
   // accessor
   int mts(size_t k) { return mts_.at(k); }
   int nts(size_t k) { return nts_.at(k); }
+  int contig_per_thread(size_t k) { return nts_.at(k); }
 public:
   // micro tile size. The size of a tile held by a thread block.
diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc
index 587234863..fd0a7879b 100644
--- a/lib/codegen/analysis/layout.cc
+++ b/lib/codegen/analysis/layout.cc
@@ -208,10 +208,12 @@ mma_layout::mma_layout(size_t num_warps,
     int pack_size_1 = (is_b_row && !is_b_vec4) ? 2 : 1;
     rep_ = {2*pack_size_0, 2*pack_size_1, 1};
     spw_ = {fpw_[0]*4*rep_[0], fpw_[1]*4*rep_[1], 1};
+    contig_per_thread_ = {1, 1};
   }
   else{
     // fpw_ = {1, 1, 1};
     spw_ = mma_instr_shape_.at(tensor_core_type_); // e.g., {16, 8, 16} for f32.f16.f16.f32
+    contig_per_thread_ = {1, 1};
     // rep_ = {2, 2, 1};
   }
   order_ = {0, 1};
@@ -628,6 +630,12 @@ void layouts::run(ir::module &mod) {
       shape[k] = std::max(in_layout->shape_per_cta(k),
                           out_layout->shape_per_cta(k));
     }
+    auto in_ord = in_layout->get_order();
+    auto out_ord = out_layout->get_order();
+    int in_vec = in_layout->contig_per_thread(in_ord[0]);
+    int out_vec = out_layout->contig_per_thread(out_ord[0]);
+    int pad = std::max(in_vec, out_vec);
+    shape[out_ord[0]] += pad;
     layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {val}, val->get_type()->get_scalar_ty(), align_, tgt_);
     tmp_[val] = id;
   }
diff --git a/lib/codegen/analysis/swizzle.cc b/lib/codegen/analysis/swizzle.cc
index 5737f80a0..414b0e0e5 100644
--- a/lib/codegen/analysis/swizzle.cc
+++ b/lib/codegen/analysis/swizzle.cc
@@ -14,41 +14,42 @@ void swizzle::run(ir::module &) {
   max_phase_.clear();
   for(auto &x: layouts_->get_all()){
-    shared_layout* layout = dynamic_cast<shared_layout*>(x.second);
-    if(!layout)
+    shared_layout* out_layout = dynamic_cast<shared_layout*>(x.second);
+    if(!out_layout)
       continue;
-    ir::value* mma_dot_a = layout->hmma_dot_a();
-    ir::value* mma_dot_b = layout->hmma_dot_b();
-
-    if(!mma_dot_a && !mma_dot_b){
-      per_phase_[layout] = 1;
-      max_phase_[layout] = 1;
-      vec_[layout] = 1;
-      continue;
-    }
-    auto ord = layout->get_order();
-    scanline_layout* in_layout = dynamic_cast<scanline_layout*>(layout->get_arg_layout());
+    scanline_layout* in_layout = dynamic_cast<scanline_layout*>(out_layout->get_arg_layout());
     if(!in_layout)
       continue;
-    int dtsize = layout->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
+
+    ir::value* mma_dot_a = out_layout->hmma_dot_a();
+    ir::value* mma_dot_b = out_layout->hmma_dot_b();
+
+    if(!mma_dot_a && !mma_dot_b){
+      per_phase_[out_layout] = 1;
+      max_phase_[out_layout] = 1;
+      vec_[out_layout] = 1;
+      continue;
+    }
+    auto ord = out_layout->get_order();
+    int dtsize = out_layout->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
     if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80){
       int inner = mma_dot_a ? 0 : 1;
-      per_phase_[layout] = std::max(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
-      max_phase_[layout] = (ord[inner] == 1 ? 8 : 4) / per_phase_[layout];
+      per_phase_[out_layout] = std::max(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
+      max_phase_[out_layout] = (ord[inner] == 1 ? 8 : 4) / per_phase_[out_layout];
       if(mma_dot_a)
-        vec_[layout] = 2*layouts_->get(mma_dot_a)->to_mma()->rep(0);
+        vec_[out_layout] = 2*layouts_->get(mma_dot_a)->to_mma()->rep(0);
       else
-        vec_[layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1);
+        vec_[out_layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1);
     }
     else {
-      if (!layout->allow_swizzle()) {
-        per_phase_[layout] = 1;
-        max_phase_[layout] = 1;
-        vec_[layout] = 1;
+      if (!out_layout->allow_swizzle()) {
+        per_phase_[out_layout] = 1;
+        max_phase_[out_layout] = 1;
+        vec_[out_layout] = 1;
       } else {
-        per_phase_[layout] = std::max(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
-        max_phase_[layout] = layout->get_mma_strided() / per_phase_[layout];
-        vec_[layout] = layout->get_mma_vec();
+        per_phase_[out_layout] = std::max(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
+        max_phase_[out_layout] = out_layout->get_mma_strided() / per_phase_[out_layout];
+        vec_[out_layout] = out_layout->get_mma_vec();
       }
     }
   }
diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc
index b4f1dd41e..f8cf08cba 100644
--- a/lib/codegen/selection/generator.cc
+++ b/lib/codegen/selection/generator.cc
@@ -2377,8 +2377,11 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){
   }
   in_ord = in_layout->to_mma() ? out_ord : in_ord;
   out_ord = out_layout->to_mma() ? in_ord : out_ord;
-  Value *in_ld = i32(shape[in_ord[0]]);
-  Value *out_ld = i32(shape[out_ord[0]]);
+  int in_vec = out_ord[0] == 0 ? 1 : in_layout->contig_per_thread(in_ord[0]);
+  int out_vec = out_ord[0] == 0 ? 1 : out_layout->contig_per_thread(out_ord[0]);
+  int pad = std::max(in_vec, out_vec);
+  Value *in_ld = i32(shape[in_ord[0]] + pad);
+  Value *out_ld = i32(shape[out_ord[0]] + pad);
   for(int i = 0; i < n_reps[0]; i++)
   for(int j = 0; j < n_reps[1]; j++){
     int max_ii, max_jj;
@@ -2386,29 +2389,39 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){
     max_ii = in_ax[0].size()/n_reps[0];
     max_jj = in_ax[1].size()/n_reps[1];
     for(int ii = 0; ii < max_ii; ii++)
-    for(int jj = 0; jj < max_jj; jj++){
+    for(int jj = 0; jj < max_jj; jj+=in_vec){
       // shared mem pointer
       indices_t offs = {in_ax[0][ii], in_ax[1][jj]};
       Value *off = add(offs[out_ord[0]], mul(out_ld, offs[out_ord[1]]));
       Value *ptr = gep(base, off);
       // stash value to shared mem
-      indices_t idxs = {in_ax[0][i*max_ii + ii],
-                        in_ax[1][j*max_jj + jj]};
-      store(bit_cast(vals_[in][idxs], ty), ptr);
+      Value* vals = UndefValue::get(vec_ty(ty, in_vec));
+      for(int jjj = 0; jjj < in_vec; jjj++){
+        indices_t idxs = {in_ax[0][i*max_ii + ii],
+                          in_ax[1][j*max_jj + jj + jjj]};
+        Value* val = bit_cast(vals_[in][idxs], ty);
+        vals = insert_elt(vals, val, jjj);
+      }
+      ptr = bit_cast(ptr, ptr_ty(vals->getType(), ptr->getType()->getPointerAddressSpace()));
+      store(vals, ptr);
     }
     add_barrier();
     max_ii = out_ax[0].size()/n_reps[0];
     max_jj = out_ax[1].size()/n_reps[1];
     for(int ii = 0; ii < max_ii; ii++)
-    for(int jj = 0; jj < max_jj; jj++){
+    for(int jj = 0; jj < max_jj; jj+=out_vec){
       // shared mem pointer
       indices_t offs = {out_ax[0][ii], out_ax[1][jj]};
       Value *off = add(offs[out_ord[0]], mul(out_ld, offs[out_ord[1]]));
       Value *ptr = gep(base, off);
+      ptr = bit_cast(ptr, ptr_ty(vec_ty(ty, out_vec), ptr->getType()->getPointerAddressSpace()));
       // load value from shared rem
-      indices_t idxs = {out_ax[0][i*max_ii + ii],
-                        out_ax[1][j*max_jj + jj]};
-      vals_[out][idxs] = load(ptr);
+      Value* vals = load(ptr);
+      for(int jjj = 0; jjj < out_vec; jjj++){
+        indices_t idxs = {out_ax[0][i*max_ii + ii],
+                          out_ax[1][j*max_jj + jj + jjj]};
+        vals_[out][idxs] = extract_elt(vals, jjj);
+      }
     }
   }