[CODEGEN] Now padding shared memory for layout conversion (#468)

This commit is contained in:
Philippe Tillet
2022-03-03 22:19:05 -08:00
committed by GitHub
parent d9dd97492f
commit bb5765df5c
4 changed files with 62 additions and 35 deletions

View File

@@ -2377,8 +2377,11 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){
}
in_ord = in_layout->to_mma() ? out_ord : in_ord;
out_ord = out_layout->to_mma() ? in_ord : out_ord;
Value *in_ld = i32(shape[in_ord[0]]);
Value *out_ld = i32(shape[out_ord[0]]);
int in_vec = out_ord[0] == 0 ? 1 : in_layout->contig_per_thread(in_ord[0]);
int out_vec = out_ord[0] == 0 ? 1 : out_layout->contig_per_thread(out_ord[0]);
int pad = std::max(in_vec, out_vec);
Value *in_ld = i32(shape[in_ord[0]] + pad);
Value *out_ld = i32(shape[out_ord[0]] + pad);
for(int i = 0; i < n_reps[0]; i++)
for(int j = 0; j < n_reps[1]; j++){
int max_ii, max_jj;
@@ -2386,29 +2389,39 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){
max_ii = in_ax[0].size()/n_reps[0];
max_jj = in_ax[1].size()/n_reps[1];
for(int ii = 0; ii < max_ii; ii++)
for(int jj = 0; jj < max_jj; jj++){
for(int jj = 0; jj < max_jj; jj+=in_vec){
// shared mem pointer
indices_t offs = {in_ax[0][ii], in_ax[1][jj]};
Value *off = add(offs[out_ord[0]], mul(out_ld, offs[out_ord[1]]));
Value *ptr = gep(base, off);
// stash value to shared mem
indices_t idxs = {in_ax[0][i*max_ii + ii],
in_ax[1][j*max_jj + jj]};
store(bit_cast(vals_[in][idxs], ty), ptr);
Value* vals = UndefValue::get(vec_ty(ty, in_vec));
for(int jjj = 0; jjj < in_vec; jjj++){
indices_t idxs = {in_ax[0][i*max_ii + ii],
in_ax[1][j*max_jj + jj + jjj]};
Value* val = bit_cast(vals_[in][idxs], ty);
vals = insert_elt(vals, val, jjj);
}
ptr = bit_cast(ptr, ptr_ty(vals->getType(), ptr->getType()->getPointerAddressSpace()));
store(vals, ptr);
}
add_barrier();
max_ii = out_ax[0].size()/n_reps[0];
max_jj = out_ax[1].size()/n_reps[1];
for(int ii = 0; ii < max_ii; ii++)
for(int jj = 0; jj < max_jj; jj++){
for(int jj = 0; jj < max_jj; jj+=out_vec){
// shared mem pointer
indices_t offs = {out_ax[0][ii], out_ax[1][jj]};
Value *off = add(offs[out_ord[0]], mul(out_ld, offs[out_ord[1]]));
Value *ptr = gep(base, off);
ptr = bit_cast(ptr, ptr_ty(vec_ty(ty, out_vec), ptr->getType()->getPointerAddressSpace()));
// load value from shared rem
indices_t idxs = {out_ax[0][i*max_ii + ii],
out_ax[1][j*max_jj + jj]};
vals_[out][idxs] = load(ptr);
Value* vals = load(ptr);
for(int jjj = 0; jjj < out_vec; jjj++){
indices_t idxs = {out_ax[0][i*max_ii + ii],
out_ax[1][j*max_jj + jj + jjj]};
vals_[out][idxs] = extract_elt(vals, jjj);
}
}
}