[BACKEND][CODEGEN] vectorization bugfix (#502)
This commit is contained in:
@@ -129,6 +129,33 @@ std::vector<align::cst_info> align::populate_is_constant_broadcast(ir::broadcast
|
||||
return add_to_cache(x, result, is_constant_);
|
||||
}
|
||||
|
||||
std::vector<align::cst_info> align::populate_is_constant_cmp(ir::cmp_inst* x) {
|
||||
auto x_shapes = get_shapes(x);
|
||||
std::vector<cst_info> result;
|
||||
ir::value* lhs_op = x->get_operand(0);
|
||||
ir::value* rhs_op = x->get_operand(1);
|
||||
auto lhs = populate_is_constant(lhs_op);
|
||||
auto rhs = populate_is_constant(rhs_op);
|
||||
auto lhs_max_contiguous = populate_max_contiguous(lhs_op);
|
||||
auto rhs_max_contiguous = populate_max_contiguous(rhs_op);
|
||||
auto lhs_multiple_of = populate_starting_multiple(lhs_op);
|
||||
auto rhs_multiple_of = populate_starting_multiple(rhs_op);
|
||||
for(size_t d = 0; d < x_shapes.size(); d++) {
|
||||
cst_info ax = {1, 0};
|
||||
// if lhs (resp. rhs) is a range of M value starting at a multiple of N
|
||||
// and rhs (resp. lhs) is made of M constants that are multiples of N
|
||||
// then comparisons have M constants
|
||||
int min_multiple = std::min(lhs_multiple_of[d], rhs_multiple_of[d]);
|
||||
if(rhs[d].num_cst % lhs_max_contiguous[d] == 0)
|
||||
ax = {std::min<int>(min_multiple, lhs_max_contiguous[d]), 0};
|
||||
else if(lhs[d].num_cst % rhs_max_contiguous[d] == 0)
|
||||
ax = {std::min<int>(min_multiple, rhs_max_contiguous[d]), 0};
|
||||
result.push_back(ax);
|
||||
}
|
||||
return add_to_cache(x, result, is_constant_);
|
||||
}
|
||||
|
||||
|
||||
std::vector<align::cst_info> align::populate_is_constant_binop(ir::binary_operator* x) {
|
||||
auto x_shapes = get_shapes(x);
|
||||
std::vector<cst_info> result;
|
||||
@@ -136,12 +163,15 @@ std::vector<align::cst_info> align::populate_is_constant_binop(ir::binary_operat
|
||||
ir::value* rhs_op = x->get_operand(1);
|
||||
auto lhs = populate_is_constant(lhs_op);
|
||||
auto rhs = populate_is_constant(rhs_op);
|
||||
auto max_contiguous = populate_max_contiguous(lhs_op);
|
||||
auto lhs_max_contiguous = populate_max_contiguous(lhs_op);
|
||||
auto rhs_max_contiguous = populate_max_contiguous(rhs_op);
|
||||
auto lhs_multiple_of = populate_starting_multiple(lhs_op);
|
||||
auto rhs_multiple_of = populate_starting_multiple(rhs_op);
|
||||
for(size_t d = 0; d < x_shapes.size(); d++) {
|
||||
cst_info ax;
|
||||
if(lhs[d].num_cst==0 && rhs[d].value && x->is_int_div()){
|
||||
// todo might not be entirely true
|
||||
unsigned num_constants = gcd(max_contiguous[d], rhs[d].value);
|
||||
unsigned num_constants = gcd(lhs_max_contiguous[d], rhs[d].value);
|
||||
ax = {num_constants, 0};
|
||||
}
|
||||
else
|
||||
@@ -184,6 +214,8 @@ std::vector<align::cst_info> align::populate_is_constant(ir::value *v) {
|
||||
return populate_is_constant_broadcast(x);
|
||||
if(auto *x = dynamic_cast<ir::binary_operator*>(v))
|
||||
return populate_is_constant_binop(x);
|
||||
if(auto *x = dynamic_cast<ir::cmp_inst*>(v))
|
||||
return populate_is_constant_cmp(x);
|
||||
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
|
||||
return populate_is_constant_gep(x);
|
||||
return populate_is_constant_default(v);
|
||||
@@ -511,12 +543,15 @@ std::vector<unsigned> align::contiguous(ir::value* v) const {
|
||||
return max_contiguous_.at(v);
|
||||
}
|
||||
|
||||
std::vector<align::cst_info> align::get_cst_info(ir::value* v) const {
|
||||
return is_constant_.at(v);
|
||||
}
|
||||
|
||||
|
||||
void align::populate(ir::value *v) {
|
||||
populate_is_constant(v);
|
||||
populate_starting_multiple(v);
|
||||
populate_max_contiguous(v);
|
||||
|
||||
}
|
||||
|
||||
void align::run(ir::module &mod) {
|
||||
|
@@ -744,6 +744,11 @@ void generator::visit_load_inst(ir::load_inst* x){
|
||||
if(op->get_type()->is_block_ty()){
|
||||
auto ord = ords_.at(op);
|
||||
size_t aln = alignment_->get(op, ord[0]);
|
||||
if(mx){
|
||||
size_t max_eq = alignment_->get_cst_info(mx->get_mask_operand())[ord[0]].num_cst;
|
||||
max_eq = std::max<size_t>(max_eq, 1);
|
||||
aln = std::min(aln, max_eq);
|
||||
}
|
||||
auto layout = layouts_->get(x)->to_scanline();
|
||||
if(layout){
|
||||
size_t nts = layout->nts(ord[0]);
|
||||
@@ -912,6 +917,11 @@ void generator::visit_store_inst(ir::store_inst * x){
|
||||
auto ord = ords_.at(x->get_pointer_operand());
|
||||
size_t aln = alignment_->get(ptr_op, ord[0]);
|
||||
size_t nts = axes_.at(a_axes_->get(x->get_pointer_operand(), ord[0])).contiguous;
|
||||
if(mx){
|
||||
size_t max_eq = alignment_->get_cst_info(mx->get_mask_operand())[ord[0]].num_cst;
|
||||
max_eq = std::max<size_t>(max_eq, 1);
|
||||
aln = std::min(aln, max_eq);
|
||||
}
|
||||
vec = std::min(nts, aln);
|
||||
}
|
||||
auto idxs = idxs_.at(val_op);
|
||||
|
Reference in New Issue
Block a user