[CODEGEN] Now padding shared memory for layout conversion (#468)
This commit is contained in:
@@ -2377,8 +2377,11 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){
|
||||
}
|
||||
in_ord = in_layout->to_mma() ? out_ord : in_ord;
|
||||
out_ord = out_layout->to_mma() ? in_ord : out_ord;
|
||||
Value *in_ld = i32(shape[in_ord[0]]);
|
||||
Value *out_ld = i32(shape[out_ord[0]]);
|
||||
int in_vec = out_ord[0] == 0 ? 1 : in_layout->contig_per_thread(in_ord[0]);
|
||||
int out_vec = out_ord[0] == 0 ? 1 : out_layout->contig_per_thread(out_ord[0]);
|
||||
int pad = std::max(in_vec, out_vec);
|
||||
Value *in_ld = i32(shape[in_ord[0]] + pad);
|
||||
Value *out_ld = i32(shape[out_ord[0]] + pad);
|
||||
for(int i = 0; i < n_reps[0]; i++)
|
||||
for(int j = 0; j < n_reps[1]; j++){
|
||||
int max_ii, max_jj;
|
||||
@@ -2386,29 +2389,39 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){
|
||||
max_ii = in_ax[0].size()/n_reps[0];
|
||||
max_jj = in_ax[1].size()/n_reps[1];
|
||||
for(int ii = 0; ii < max_ii; ii++)
|
||||
for(int jj = 0; jj < max_jj; jj++){
|
||||
for(int jj = 0; jj < max_jj; jj+=in_vec){
|
||||
// shared mem pointer
|
||||
indices_t offs = {in_ax[0][ii], in_ax[1][jj]};
|
||||
Value *off = add(offs[out_ord[0]], mul(out_ld, offs[out_ord[1]]));
|
||||
Value *ptr = gep(base, off);
|
||||
// stash value to shared mem
|
||||
indices_t idxs = {in_ax[0][i*max_ii + ii],
|
||||
in_ax[1][j*max_jj + jj]};
|
||||
store(bit_cast(vals_[in][idxs], ty), ptr);
|
||||
Value* vals = UndefValue::get(vec_ty(ty, in_vec));
|
||||
for(int jjj = 0; jjj < in_vec; jjj++){
|
||||
indices_t idxs = {in_ax[0][i*max_ii + ii],
|
||||
in_ax[1][j*max_jj + jj + jjj]};
|
||||
Value* val = bit_cast(vals_[in][idxs], ty);
|
||||
vals = insert_elt(vals, val, jjj);
|
||||
}
|
||||
ptr = bit_cast(ptr, ptr_ty(vals->getType(), ptr->getType()->getPointerAddressSpace()));
|
||||
store(vals, ptr);
|
||||
}
|
||||
add_barrier();
|
||||
max_ii = out_ax[0].size()/n_reps[0];
|
||||
max_jj = out_ax[1].size()/n_reps[1];
|
||||
for(int ii = 0; ii < max_ii; ii++)
|
||||
for(int jj = 0; jj < max_jj; jj++){
|
||||
for(int jj = 0; jj < max_jj; jj+=out_vec){
|
||||
// shared mem pointer
|
||||
indices_t offs = {out_ax[0][ii], out_ax[1][jj]};
|
||||
Value *off = add(offs[out_ord[0]], mul(out_ld, offs[out_ord[1]]));
|
||||
Value *ptr = gep(base, off);
|
||||
ptr = bit_cast(ptr, ptr_ty(vec_ty(ty, out_vec), ptr->getType()->getPointerAddressSpace()));
|
||||
// load value from shared mem
|
||||
indices_t idxs = {out_ax[0][i*max_ii + ii],
|
||||
out_ax[1][j*max_jj + jj]};
|
||||
vals_[out][idxs] = load(ptr);
|
||||
Value* vals = load(ptr);
|
||||
for(int jjj = 0; jjj < out_vec; jjj++){
|
||||
indices_t idxs = {out_ax[0][i*max_ii + ii],
|
||||
out_ax[1][j*max_jj + jj + jjj]};
|
||||
vals_[out][idxs] = extract_elt(vals, jjj);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
Reference in New Issue
Block a user