[CODEGEN] Fixed various issues in alignment inference pass

This commit is contained in:
Philippe Tillet
2020-06-06 11:28:43 -04:00
committed by Philippe Tillet
parent da6008128e
commit e18f169a39
2 changed files with 27 additions and 5 deletions

View File

@@ -11,6 +11,7 @@ namespace ir {
class module; class module;
class phi_node; class phi_node;
class splat_inst; class splat_inst;
class cast_inst;
class reshape_inst; class reshape_inst;
class broadcast_inst; class broadcast_inst;
class binary_operator; class binary_operator;
@@ -44,6 +45,7 @@ private:
std::vector<unsigned> populate_max_contiguous_broadcast(ir::broadcast_inst* x); std::vector<unsigned> populate_max_contiguous_broadcast(ir::broadcast_inst* x);
std::vector<unsigned> populate_max_contiguous_binop(ir::binary_operator* x); std::vector<unsigned> populate_max_contiguous_binop(ir::binary_operator* x);
std::vector<unsigned> populate_max_contiguous_gep(ir::getelementptr_inst* x); std::vector<unsigned> populate_max_contiguous_gep(ir::getelementptr_inst* x);
std::vector<unsigned> populate_max_contiguous_cast(ir::cast_inst* x);
std::vector<unsigned> populate_max_contiguous_default(ir::value* v); std::vector<unsigned> populate_max_contiguous_default(ir::value* v);
std::vector<unsigned> populate_max_contiguous(ir::value *v); std::vector<unsigned> populate_max_contiguous(ir::value *v);
// populate starting_multiple // populate starting_multiple
@@ -53,6 +55,7 @@ private:
std::vector<unsigned> populate_starting_multiple_broadcast(ir::broadcast_inst* x); std::vector<unsigned> populate_starting_multiple_broadcast(ir::broadcast_inst* x);
std::vector<unsigned> populate_starting_multiple_binop(ir::binary_operator* x); std::vector<unsigned> populate_starting_multiple_binop(ir::binary_operator* x);
std::vector<unsigned> populate_starting_multiple_gep(ir::getelementptr_inst* x); std::vector<unsigned> populate_starting_multiple_gep(ir::getelementptr_inst* x);
std::vector<unsigned> populate_starting_multiple_cast(ir::cast_inst* x);
std::vector<unsigned> populate_starting_multiple_default(ir::value* v); std::vector<unsigned> populate_starting_multiple_default(ir::value* v);
std::vector<unsigned> populate_starting_multiple(ir::value *v); std::vector<unsigned> populate_starting_multiple(ir::value *v);
// populate all maps // populate all maps

View File

@@ -327,9 +327,16 @@ std::vector<unsigned> align::populate_max_contiguous_default(ir::value* v) {
return add_to_cache(v, std::vector<unsigned>(shapes.size(), 1), max_contiguous_); return add_to_cache(v, std::vector<unsigned>(shapes.size(), 1), max_contiguous_);
} }
std::vector<unsigned> align::populate_max_contiguous_cast(ir::cast_inst* v){
auto result = populate_max_contiguous(v->get_operand(0));
return add_to_cache(v, result, max_contiguous_);
}
std::vector<unsigned> align::populate_max_contiguous(ir::value *v){ std::vector<unsigned> align::populate_max_contiguous(ir::value *v){
if(max_contiguous_.find(v) != max_contiguous_.end()) if(max_contiguous_.find(v) != max_contiguous_.end())
return max_contiguous_.at(v); return max_contiguous_.at(v);
if(auto *x = dynamic_cast<ir::cast_inst*>(v))
return populate_max_contiguous_cast(x);
if(auto *x = dynamic_cast<ir::splat_inst*>(v)) if(auto *x = dynamic_cast<ir::splat_inst*>(v))
return populate_max_contiguous_splat(x); return populate_max_contiguous_splat(x);
if(auto *x = dynamic_cast<ir::reshape_inst*>(v)) if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
@@ -434,16 +441,16 @@ std::vector<unsigned> align::populate_starting_multiple_phi(ir::phi_node* x){
} }
std::vector<unsigned> align::populate_starting_multiple_cast(ir::cast_inst* x){
auto result = populate_starting_multiple(x->get_operand(0));
return add_to_cache(x, result, starting_multiple_);
}
std::vector<unsigned> align::populate_starting_multiple_default(ir::value* v) { std::vector<unsigned> align::populate_starting_multiple_default(ir::value* v) {
ir::type* ty = v->get_type(); ir::type* ty = v->get_type();
if(ty->is_tile_ty()) { if(ty->is_tile_ty()) {
return add_to_cache(v, ty->get_tile_shapes(), starting_multiple_); return add_to_cache(v, ty->get_tile_shapes(), starting_multiple_);
} }
if(auto *x = dynamic_cast<ir::instruction*>(v)){
unsigned multiple_of = x->get_metadata(ir::metadata::multiple_of);
if(multiple_of > 0)
return add_to_cache(x, {multiple_of}, starting_multiple_);
}
if(auto *x = dynamic_cast<ir::argument*>(v)){ if(auto *x = dynamic_cast<ir::argument*>(v)){
std::set<ir::attribute> attributes = x->get_parent()->get_attributes(x); std::set<ir::attribute> attributes = x->get_parent()->get_attributes(x);
for(auto attr: attributes){ for(auto attr: attributes){
@@ -464,6 +471,13 @@ std::vector<unsigned> align::populate_starting_multiple_default(ir::value* v) {
std::vector<unsigned> align::populate_starting_multiple(ir::value *v){ std::vector<unsigned> align::populate_starting_multiple(ir::value *v){
if(starting_multiple_.find(v) != starting_multiple_.end()) if(starting_multiple_.find(v) != starting_multiple_.end())
return starting_multiple_.at(v); return starting_multiple_.at(v);
if(auto *x = dynamic_cast<ir::instruction*>(v)){
unsigned multiple_of = x->get_metadata(ir::metadata::multiple_of);
if(multiple_of > 0)
return add_to_cache(x, {multiple_of}, starting_multiple_);
}
if(auto *x = dynamic_cast<ir::cast_inst*>(v))
return populate_starting_multiple_cast(x);
if(auto *x = dynamic_cast<ir::binary_operator*>(v)) if(auto *x = dynamic_cast<ir::binary_operator*>(v))
return populate_starting_multiple_binop(x); return populate_starting_multiple_binop(x);
if(auto *x = dynamic_cast<ir::constant_int*>(v)) if(auto *x = dynamic_cast<ir::constant_int*>(v))
@@ -508,6 +522,11 @@ void align::populate(ir::value *v) {
void align::run(ir::module &mod) { void align::run(ir::module &mod) {
ir::for_each_value(mod, [this](ir::value* v) { populate(v); } ); ir::for_each_value(mod, [this](ir::value* v) { populate(v); } );
// ir::for_each_value(mod, [this](ir::value* v) {
// if(dynamic_cast<ir::cast_inst*>(v) || dynamic_cast<ir::getelementptr_inst*>(v))
// std::cout << "ALIGN: " << v->get_name() << " " << starting_multiple_.at(v)[0] << " " << max_contiguous_.at(v)[0]
// << " " << starting_multiple_.at(v)[1] << " " << max_contiguous_.at(v)[1] << std::endl;
// });
} }