[BACKEND] Fix layout convert for non-contiguous input (#564)

This commit is contained in:
Keren Zhou
2022-06-25 23:12:03 -07:00
committed by GitHub
parent d345ddf837
commit 87413bc925
2 changed files with 20 additions and 7 deletions

View File

@@ -2638,8 +2638,6 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){
// Orders
analysis::distributed_layout* in_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(in));
analysis::distributed_layout* out_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(out));
auto in_ord = in_layout->get_order();
auto out_ord = out_layout->get_order();
Value *base;
base = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(out)))));
base = bit_cast(base, ptr_ty(ty, 3));
@@ -2656,9 +2654,16 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){
in_ax.push_back(axes_.at(a_axes_->get(in, d)).values);
out_ax.push_back(axes_.at(a_axes_->get(out, d)).values);
}
in_ord = in_layout->to_mma() ? out_ord : in_ord;
out_ord = out_layout->to_mma() ? in_ord : out_ord;
int in_vec = out_ord[0] == 0 ? 1 : in_layout->contig_per_thread(in_ord[0]);
auto in_ord =
in_layout->to_mma() ? out_layout->get_order() : in_layout->get_order();
auto out_ord =
out_layout->to_mma() ? in_layout->get_order() : out_layout->get_order();
// out_ord[0] == 0 or in_ord[0] == 0 means the first dimension is
// non-contiguous. in_vec can be greater than 1 only if both out_ord[0]
// and in_ord[0] are contiguous.
int in_vec = out_ord[0] == 0 ? 1
: in_ord[0] == 0 ? 1
: in_layout->contig_per_thread(in_ord[0]);
int out_vec = out_ord[0] == 0 ? 1 : out_layout->contig_per_thread(out_ord[0]);
int pad = std::max(in_vec, out_vec);
Value *in_ld = i32(shape[in_ord[0]] + pad);