[CODEGEN] Fixed various issues in alignment inference pass
This commit is contained in:
committed by
Philippe Tillet
parent
da6008128e
commit
e18f169a39
@@ -11,6 +11,7 @@ namespace ir {
|
||||
class module;
|
||||
class phi_node;
|
||||
class splat_inst;
|
||||
class cast_inst;
|
||||
class reshape_inst;
|
||||
class broadcast_inst;
|
||||
class binary_operator;
|
||||
@@ -44,6 +45,7 @@ private:
|
||||
std::vector<unsigned> populate_max_contiguous_broadcast(ir::broadcast_inst* x);
|
||||
std::vector<unsigned> populate_max_contiguous_binop(ir::binary_operator* x);
|
||||
std::vector<unsigned> populate_max_contiguous_gep(ir::getelementptr_inst* x);
|
||||
std::vector<unsigned> populate_max_contiguous_cast(ir::cast_inst* x);
|
||||
std::vector<unsigned> populate_max_contiguous_default(ir::value* v);
|
||||
std::vector<unsigned> populate_max_contiguous(ir::value *v);
|
||||
// populate starting_multiple
|
||||
@@ -53,6 +55,7 @@ private:
|
||||
std::vector<unsigned> populate_starting_multiple_broadcast(ir::broadcast_inst* x);
|
||||
std::vector<unsigned> populate_starting_multiple_binop(ir::binary_operator* x);
|
||||
std::vector<unsigned> populate_starting_multiple_gep(ir::getelementptr_inst* x);
|
||||
std::vector<unsigned> populate_starting_multiple_cast(ir::cast_inst* x);
|
||||
std::vector<unsigned> populate_starting_multiple_default(ir::value* v);
|
||||
std::vector<unsigned> populate_starting_multiple(ir::value *v);
|
||||
// populate all maps
|
||||
|
@@ -327,9 +327,16 @@ std::vector<unsigned> align::populate_max_contiguous_default(ir::value* v) {
|
||||
return add_to_cache(v, std::vector<unsigned>(shapes.size(), 1), max_contiguous_);
|
||||
}
|
||||
|
||||
std::vector<unsigned> align::populate_max_contiguous_cast(ir::cast_inst* v){
|
||||
auto result = populate_max_contiguous(v->get_operand(0));
|
||||
return add_to_cache(v, result, max_contiguous_);
|
||||
}
|
||||
|
||||
std::vector<unsigned> align::populate_max_contiguous(ir::value *v){
|
||||
if(max_contiguous_.find(v) != max_contiguous_.end())
|
||||
return max_contiguous_.at(v);
|
||||
if(auto *x = dynamic_cast<ir::cast_inst*>(v))
|
||||
return populate_max_contiguous_cast(x);
|
||||
if(auto *x = dynamic_cast<ir::splat_inst*>(v))
|
||||
return populate_max_contiguous_splat(x);
|
||||
if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
|
||||
@@ -434,16 +441,16 @@ std::vector<unsigned> align::populate_starting_multiple_phi(ir::phi_node* x){
|
||||
}
|
||||
|
||||
|
||||
std::vector<unsigned> align::populate_starting_multiple_cast(ir::cast_inst* x){
|
||||
auto result = populate_starting_multiple(x->get_operand(0));
|
||||
return add_to_cache(x, result, starting_multiple_);
|
||||
}
|
||||
|
||||
std::vector<unsigned> align::populate_starting_multiple_default(ir::value* v) {
|
||||
ir::type* ty = v->get_type();
|
||||
if(ty->is_tile_ty()) {
|
||||
return add_to_cache(v, ty->get_tile_shapes(), starting_multiple_);
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::instruction*>(v)){
|
||||
unsigned multiple_of = x->get_metadata(ir::metadata::multiple_of);
|
||||
if(multiple_of > 0)
|
||||
return add_to_cache(x, {multiple_of}, starting_multiple_);
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::argument*>(v)){
|
||||
std::set<ir::attribute> attributes = x->get_parent()->get_attributes(x);
|
||||
for(auto attr: attributes){
|
||||
@@ -464,6 +471,13 @@ std::vector<unsigned> align::populate_starting_multiple_default(ir::value* v) {
|
||||
std::vector<unsigned> align::populate_starting_multiple(ir::value *v){
|
||||
if(starting_multiple_.find(v) != starting_multiple_.end())
|
||||
return starting_multiple_.at(v);
|
||||
if(auto *x = dynamic_cast<ir::instruction*>(v)){
|
||||
unsigned multiple_of = x->get_metadata(ir::metadata::multiple_of);
|
||||
if(multiple_of > 0)
|
||||
return add_to_cache(x, {multiple_of}, starting_multiple_);
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::cast_inst*>(v))
|
||||
return populate_starting_multiple_cast(x);
|
||||
if(auto *x = dynamic_cast<ir::binary_operator*>(v))
|
||||
return populate_starting_multiple_binop(x);
|
||||
if(auto *x = dynamic_cast<ir::constant_int*>(v))
|
||||
@@ -508,6 +522,11 @@ void align::populate(ir::value *v) {
|
||||
|
||||
void align::run(ir::module &mod) {
|
||||
ir::for_each_value(mod, [this](ir::value* v) { populate(v); } );
|
||||
// ir::for_each_value(mod, [this](ir::value* v) {
|
||||
// if(dynamic_cast<ir::cast_inst*>(v) || dynamic_cast<ir::getelementptr_inst*>(v))
|
||||
// std::cout << "ALIGN: " << v->get_name() << " " << starting_multiple_.at(v)[0] << " " << max_contiguous_.at(v)[0]
|
||||
// << " " << starting_multiple_.at(v)[1] << " " << max_contiguous_.at(v)[1] << std::endl;
|
||||
// });
|
||||
}
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user