[CODEGEN] Now padding shared memory for layout conversion (#468)
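Layout conversions between distributed layouts go through a temporary shared-memory buffer. Previously that buffer was sized to exactly the per-CTA shape, and each element was stored and reloaded one scalar at a time. This commit (1) adds a contig_per_thread(k) accessor to distributed layouts, reporting how many elements a thread owns contiguously along axis k (nts(k) for scanline layouts, 1 for the mma layouts touched here); (2) pads the buffer's fastest-varying output axis by max(in_vec, out_vec) to avoid shared-memory bank conflicts; and (3) vectorizes the stores and loads in generator::visit_layout_convert to match.

A minimal sketch of why the padding helps, using the standard model of 32 four-byte shared-memory banks (the bank constants are general CUDA facts, not part of this patch):

    // bank(r, c) = (r*ld + c) % 32 for a row-major float buffer with
    // leading dimension ld. With ld a multiple of 32, threads walking a
    // column all hit bank 0; a small pad staggers them across banks.
    #include <cstdio>
    int main() {
      for (int ld : {32, 36}) {          // 36 = 32 plus a pad of 4 floats
        printf("ld=%2d:", ld);
        for (int r = 0; r < 8; r++)      // 8 threads reading column 0
          printf(" %2d", (r * ld) % 32);
        printf("\n");                    // ld=32 -> 0 0 0 ...; ld=36 -> 0 4 8 ...
      }
      return 0;
    }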
@@ -103,6 +103,7 @@ public:
   int shape_per_cta(size_t k) { return shape_per_cta_.at(k); }
   int rep_per_cta(size_t k) { return shape_[k] / shape_per_cta_[k]; }
+  virtual int contig_per_thread(size_t k) = 0;
 
 protected:
   std::vector<int> shape_per_cta_;
@@ -181,6 +182,7 @@ public:
   int wpt(size_t k) { return wpt_.at(k); }
   int spw(size_t k) { return spw_.at(k); }
   int rep(size_t k) { return rep_.at(k); }
+  int contig_per_thread(size_t k) { return contig_per_thread_.at(k); }
 
   // helpers for generator.cc
   std::string get_ptx_instr() const { return mma_instr_ptx_.at(tensor_core_type_); }
@@ -203,6 +205,8 @@ private:
   std::vector<int> spt_;
   // repetitions
   std::vector<int> rep_;
+  // contiguous per thread
+  std::vector<int> contig_per_thread_;
 
   TensorCoreType tensor_core_type_ = FP32_FP16_FP16_FP32;
 };
@@ -218,6 +222,7 @@ struct scanline_layout: public distributed_layout {
   // accessor
   int mts(size_t k) { return mts_.at(k); }
   int nts(size_t k) { return nts_.at(k); }
+  int contig_per_thread(size_t k) { return nts_.at(k); }
 
 public:
   // micro tile size. The size of a tile held by a thread block.
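Note: contig_per_thread is the uniform hook the padding logic relies on. A scanline layout reports its nano-tile size nts(k), the number of consecutive elements a single thread holds along axis k, which is exactly the safe vector width for that axis; the mma layouts below report 1 per axis. A toy model of the scanline case (the numbers are illustrative, not from the patch):

    // 32 threads tile a 128-element axis, each owning nts = 4 consecutive
    // elements, so thread t owns [4*t, 4*t + 4).
    #include <cassert>
    int owner_thread(int elem, int nts) { return elem / nts; }
    int main() {
      int nts = 4;                      // == contig_per_thread(k) for this axis
      assert(owner_thread(3, nts) == 0);
      assert(owner_thread(4, nts) == 1);
      return 0;
    }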
@@ -208,10 +208,12 @@ mma_layout::mma_layout(size_t num_warps,
     int pack_size_1 = (is_b_row && !is_b_vec4) ? 2 : 1;
     rep_ = {2*pack_size_0, 2*pack_size_1, 1};
     spw_ = {fpw_[0]*4*rep_[0], fpw_[1]*4*rep_[1], 1};
+    contig_per_thread_ = {1, 1};
   }
   else{
     // fpw_ = {1, 1, 1};
     spw_ = mma_instr_shape_.at(tensor_core_type_); // e.g., {16, 8, 16} for f32.f16.f16.f32
+    contig_per_thread_ = {1, 1};
     // rep_ = {2, 2, 1};
   }
   order_ = {0, 1};
@@ -628,6 +630,12 @@ void layouts::run(ir::module &mod) {
       shape[k] = std::max(in_layout->shape_per_cta(k),
                           out_layout->shape_per_cta(k));
     }
+    auto in_ord = in_layout->get_order();
+    auto out_ord = out_layout->get_order();
+    int in_vec = in_layout->contig_per_thread(in_ord[0]);
+    int out_vec = out_layout->contig_per_thread(out_ord[0]);
+    int pad = std::max(in_vec, out_vec);
+    shape[out_ord[0]] += pad;
     layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {val}, val->get_type()->get_scalar_ty(), align_, tgt_);
     tmp_[val] = id;
   }
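Note: this is the analysis-side half of the change: the shared buffer allocated for a conversion grows by max(in_vec, out_vec) along the output's fastest-varying axis, so the generator can rely on the extra room. A condensed restatement of that arithmetic, with plain structs standing in for the real layout classes (a simplified sketch, not the actual API):

    #include <algorithm>
    #include <vector>
    struct LayoutInfo {
      std::vector<int> order;   // order[0] = fastest-varying axis
      std::vector<int> contig;  // contig_per_thread for each axis
    };
    // Mirrors the padding now done in layouts::run.
    void pad_shape(std::vector<int>& shape,
                   const LayoutInfo& in, const LayoutInfo& out) {
      int in_vec  = in.contig[in.order[0]];
      int out_vec = out.contig[out.order[0]];
      shape[out.order[0]] += std::max(in_vec, out_vec);
    }
    int main() {
      LayoutInfo in{{0, 1}, {4, 1}}, out{{1, 0}, {2, 1}};
      std::vector<int> shape{128, 64};
      pad_shape(shape, in, out);    // shape[1] += max(4, 1) -> 68
      return shape[1] == 68 ? 0 : 1;
    }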
@@ -14,41 +14,42 @@ void swizzle::run(ir::module &) {
   max_phase_.clear();
 
   for(auto &x: layouts_->get_all()){
-    shared_layout* layout = dynamic_cast<shared_layout*>(x.second);
-    if(!layout)
+    shared_layout* out_layout = dynamic_cast<shared_layout*>(x.second);
+    if(!out_layout)
       continue;
-    ir::value* mma_dot_a = layout->hmma_dot_a();
-    ir::value* mma_dot_b = layout->hmma_dot_b();
-
-    if(!mma_dot_a && !mma_dot_b){
-      per_phase_[layout] = 1;
-      max_phase_[layout] = 1;
-      vec_[layout] = 1;
-      continue;
-    }
-    auto ord = layout->get_order();
-    scanline_layout* in_layout = dynamic_cast<scanline_layout*>(layout->get_arg_layout());
+    scanline_layout* in_layout = dynamic_cast<scanline_layout*>(out_layout->get_arg_layout());
     if(!in_layout)
       continue;
-    int dtsize = layout->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
+
+    ir::value* mma_dot_a = out_layout->hmma_dot_a();
+    ir::value* mma_dot_b = out_layout->hmma_dot_b();
+
+    if(!mma_dot_a && !mma_dot_b){
+      per_phase_[out_layout] = 1;
+      max_phase_[out_layout] = 1;
+      vec_[out_layout] = 1;
+      continue;
+    }
+    auto ord = out_layout->get_order();
+    int dtsize = out_layout->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
     if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80){
       int inner = mma_dot_a ? 0 : 1;
-      per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
-      max_phase_[layout] = (ord[inner] == 1 ? 8 : 4) / per_phase_[layout];
+      per_phase_[out_layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
+      max_phase_[out_layout] = (ord[inner] == 1 ? 8 : 4) / per_phase_[out_layout];
       if(mma_dot_a)
-        vec_[layout] = 2*layouts_->get(mma_dot_a)->to_mma()->rep(0);
+        vec_[out_layout] = 2*layouts_->get(mma_dot_a)->to_mma()->rep(0);
       else
-        vec_[layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1);
+        vec_[out_layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1);
     }
     else {
-      if (!layout->allow_swizzle()) {
-        per_phase_[layout] = 1;
-        max_phase_[layout] = 1;
-        vec_[layout] = 1;
+      if (!out_layout->allow_swizzle()) {
+        per_phase_[out_layout] = 1;
+        max_phase_[out_layout] = 1;
+        vec_[out_layout] = 1;
       } else {
-        per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
-        max_phase_[layout] = layout->get_mma_strided() / per_phase_[layout];
-        vec_[layout] = layout->get_mma_vec();
+        per_phase_[out_layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
+        max_phase_[out_layout] = out_layout->get_mma_strided() / per_phase_[out_layout];
+        vec_[out_layout] = out_layout->get_mma_vec();
      }
    }
  }
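Note: the swizzle hunk is mostly a rename (layout -> out_layout) plus hoisting the in_layout check ahead of the mma-operand checks, so shared layouts without a scanline argument are skipped before any swizzle parameters are computed; the formulas themselves are unchanged. Restated standalone for reference (a sketch of the same arithmetic, with dtsize in bytes):

    #include <algorithm>
    // Rows per 128-byte phase, where one contiguous run is
    // mts*nts*dtsize bytes (as in the per_phase_ lines above).
    int per_phase(int mts, int nts, int dtsize) {
      return std::max(128 / (mts * nts * dtsize), 1);
    }
    // Pre-sm80 path: phases before the swizzle pattern repeats.
    int max_phase_pre_sm80(bool inner_is_dim1, int pp) {
      return (inner_is_dim1 ? 8 : 4) / pp;
    }
    int main() {
      int pp = per_phase(/*mts=*/32, /*nts=*/1, /*dtsize=*/2); // 128/64 = 2
      return max_phase_pre_sm80(true, pp) == 4 ? 0 : 1;        // 8/2 = 4
    }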
@@ -2377,8 +2377,11 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){
   }
   in_ord = in_layout->to_mma() ? out_ord : in_ord;
   out_ord = out_layout->to_mma() ? in_ord : out_ord;
-  Value *in_ld = i32(shape[in_ord[0]]);
-  Value *out_ld = i32(shape[out_ord[0]]);
+  int in_vec = out_ord[0] == 0 ? 1 : in_layout->contig_per_thread(in_ord[0]);
+  int out_vec = out_ord[0] == 0 ? 1 : out_layout->contig_per_thread(out_ord[0]);
+  int pad = std::max(in_vec, out_vec);
+  Value *in_ld = i32(shape[in_ord[0]] + pad);
+  Value *out_ld = i32(shape[out_ord[0]] + pad);
   for(int i = 0; i < n_reps[0]; i++)
   for(int j = 0; j < n_reps[1]; j++){
     int max_ii, max_jj;
@@ -2386,29 +2389,39 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){
       max_ii = in_ax[0].size()/n_reps[0];
       max_jj = in_ax[1].size()/n_reps[1];
       for(int ii = 0; ii < max_ii; ii++)
-      for(int jj = 0; jj < max_jj; jj++){
+      for(int jj = 0; jj < max_jj; jj+=in_vec){
         // shared mem pointer
         indices_t offs = {in_ax[0][ii], in_ax[1][jj]};
         Value *off = add(offs[out_ord[0]], mul(out_ld, offs[out_ord[1]]));
         Value *ptr = gep(base, off);
         // stash value to shared mem
-        indices_t idxs = {in_ax[0][i*max_ii + ii],
-                          in_ax[1][j*max_jj + jj]};
-        store(bit_cast(vals_[in][idxs], ty), ptr);
+        Value* vals = UndefValue::get(vec_ty(ty, in_vec));
+        for(int jjj = 0; jjj < in_vec; jjj++){
+          indices_t idxs = {in_ax[0][i*max_ii + ii],
+                            in_ax[1][j*max_jj + jj + jjj]};
+          Value* val = bit_cast(vals_[in][idxs], ty);
+          vals = insert_elt(vals, val, jjj);
+        }
+        ptr = bit_cast(ptr, ptr_ty(vals->getType(), ptr->getType()->getPointerAddressSpace()));
+        store(vals, ptr);
       }
       add_barrier();
       max_ii = out_ax[0].size()/n_reps[0];
       max_jj = out_ax[1].size()/n_reps[1];
       for(int ii = 0; ii < max_ii; ii++)
-      for(int jj = 0; jj < max_jj; jj++){
+      for(int jj = 0; jj < max_jj; jj+=out_vec){
         // shared mem pointer
         indices_t offs = {out_ax[0][ii], out_ax[1][jj]};
         Value *off = add(offs[out_ord[0]], mul(out_ld, offs[out_ord[1]]));
         Value *ptr = gep(base, off);
+        ptr = bit_cast(ptr, ptr_ty(vec_ty(ty, out_vec), ptr->getType()->getPointerAddressSpace()));
        // load value from shared mem
-        indices_t idxs = {out_ax[0][i*max_ii + ii],
-                          out_ax[1][j*max_jj + jj]};
-        vals_[out][idxs] = load(ptr);
+        Value* vals = load(ptr);
+        for(int jjj = 0; jjj < out_vec; jjj++){
+          indices_t idxs = {out_ax[0][i*max_ii + ii],
+                            out_ax[1][j*max_jj + jj + jjj]};
+          vals_[out][idxs] = extract_elt(vals, jjj);
+        }
      }
 
    }
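Note: with the padded leading dimension in place, the conversion loops advance by in_vec/out_vec and move one vector per shared-memory transaction instead of one scalar. A stripped-down model of the store side, with a flat array standing in for the IR builder calls (illustrative only):

    #include <cstring>
    // One vectorized stash: in_vec scalars land in consecutive slots of
    // the padded buffer, like the insert_elt + bitcast + store above.
    void stash_vec(float* smem, int ld /* padded leading dim */,
                   int row, int col, const float* vals, int in_vec) {
      std::memcpy(&smem[row * ld + col], vals, in_vec * sizeof(float));
    }
    int main() {
      float smem[8 * 36] = {};          // 8 rows, padded ld = 36
      float v[4] = {1, 2, 3, 4};
      stash_vec(smem, 36, 2, 0, v, 4);  // writes smem[72..75]
      return smem[72] == 1.0f ? 0 : 1;
    }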