[CODEGEN] Fixed various issues in alignment inference pass
This commit is contained in:
committed by
Philippe Tillet
parent
da6008128e
commit
e18f169a39
@@ -11,6 +11,7 @@ namespace ir {
|
|||||||
class module;
|
class module;
|
||||||
class phi_node;
|
class phi_node;
|
||||||
class splat_inst;
|
class splat_inst;
|
||||||
|
class cast_inst;
|
||||||
class reshape_inst;
|
class reshape_inst;
|
||||||
class broadcast_inst;
|
class broadcast_inst;
|
||||||
class binary_operator;
|
class binary_operator;
|
||||||
@@ -44,6 +45,7 @@ private:
|
|||||||
std::vector<unsigned> populate_max_contiguous_broadcast(ir::broadcast_inst* x);
|
std::vector<unsigned> populate_max_contiguous_broadcast(ir::broadcast_inst* x);
|
||||||
std::vector<unsigned> populate_max_contiguous_binop(ir::binary_operator* x);
|
std::vector<unsigned> populate_max_contiguous_binop(ir::binary_operator* x);
|
||||||
std::vector<unsigned> populate_max_contiguous_gep(ir::getelementptr_inst* x);
|
std::vector<unsigned> populate_max_contiguous_gep(ir::getelementptr_inst* x);
|
||||||
|
std::vector<unsigned> populate_max_contiguous_cast(ir::cast_inst* x);
|
||||||
std::vector<unsigned> populate_max_contiguous_default(ir::value* v);
|
std::vector<unsigned> populate_max_contiguous_default(ir::value* v);
|
||||||
std::vector<unsigned> populate_max_contiguous(ir::value *v);
|
std::vector<unsigned> populate_max_contiguous(ir::value *v);
|
||||||
// populate starting_multiple
|
// populate starting_multiple
|
||||||
@@ -53,6 +55,7 @@ private:
|
|||||||
std::vector<unsigned> populate_starting_multiple_broadcast(ir::broadcast_inst* x);
|
std::vector<unsigned> populate_starting_multiple_broadcast(ir::broadcast_inst* x);
|
||||||
std::vector<unsigned> populate_starting_multiple_binop(ir::binary_operator* x);
|
std::vector<unsigned> populate_starting_multiple_binop(ir::binary_operator* x);
|
||||||
std::vector<unsigned> populate_starting_multiple_gep(ir::getelementptr_inst* x);
|
std::vector<unsigned> populate_starting_multiple_gep(ir::getelementptr_inst* x);
|
||||||
|
std::vector<unsigned> populate_starting_multiple_cast(ir::cast_inst* x);
|
||||||
std::vector<unsigned> populate_starting_multiple_default(ir::value* v);
|
std::vector<unsigned> populate_starting_multiple_default(ir::value* v);
|
||||||
std::vector<unsigned> populate_starting_multiple(ir::value *v);
|
std::vector<unsigned> populate_starting_multiple(ir::value *v);
|
||||||
// populate all maps
|
// populate all maps
|
||||||
|
@@ -327,9 +327,16 @@ std::vector<unsigned> align::populate_max_contiguous_default(ir::value* v) {
|
|||||||
return add_to_cache(v, std::vector<unsigned>(shapes.size(), 1), max_contiguous_);
|
return add_to_cache(v, std::vector<unsigned>(shapes.size(), 1), max_contiguous_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<unsigned> align::populate_max_contiguous_cast(ir::cast_inst* v){
|
||||||
|
auto result = populate_max_contiguous(v->get_operand(0));
|
||||||
|
return add_to_cache(v, result, max_contiguous_);
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<unsigned> align::populate_max_contiguous(ir::value *v){
|
std::vector<unsigned> align::populate_max_contiguous(ir::value *v){
|
||||||
if(max_contiguous_.find(v) != max_contiguous_.end())
|
if(max_contiguous_.find(v) != max_contiguous_.end())
|
||||||
return max_contiguous_.at(v);
|
return max_contiguous_.at(v);
|
||||||
|
if(auto *x = dynamic_cast<ir::cast_inst*>(v))
|
||||||
|
return populate_max_contiguous_cast(x);
|
||||||
if(auto *x = dynamic_cast<ir::splat_inst*>(v))
|
if(auto *x = dynamic_cast<ir::splat_inst*>(v))
|
||||||
return populate_max_contiguous_splat(x);
|
return populate_max_contiguous_splat(x);
|
||||||
if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
|
if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
|
||||||
@@ -434,16 +441,16 @@ std::vector<unsigned> align::populate_starting_multiple_phi(ir::phi_node* x){
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
std::vector<unsigned> align::populate_starting_multiple_cast(ir::cast_inst* x){
|
||||||
|
auto result = populate_starting_multiple(x->get_operand(0));
|
||||||
|
return add_to_cache(x, result, starting_multiple_);
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<unsigned> align::populate_starting_multiple_default(ir::value* v) {
|
std::vector<unsigned> align::populate_starting_multiple_default(ir::value* v) {
|
||||||
ir::type* ty = v->get_type();
|
ir::type* ty = v->get_type();
|
||||||
if(ty->is_tile_ty()) {
|
if(ty->is_tile_ty()) {
|
||||||
return add_to_cache(v, ty->get_tile_shapes(), starting_multiple_);
|
return add_to_cache(v, ty->get_tile_shapes(), starting_multiple_);
|
||||||
}
|
}
|
||||||
if(auto *x = dynamic_cast<ir::instruction*>(v)){
|
|
||||||
unsigned multiple_of = x->get_metadata(ir::metadata::multiple_of);
|
|
||||||
if(multiple_of > 0)
|
|
||||||
return add_to_cache(x, {multiple_of}, starting_multiple_);
|
|
||||||
}
|
|
||||||
if(auto *x = dynamic_cast<ir::argument*>(v)){
|
if(auto *x = dynamic_cast<ir::argument*>(v)){
|
||||||
std::set<ir::attribute> attributes = x->get_parent()->get_attributes(x);
|
std::set<ir::attribute> attributes = x->get_parent()->get_attributes(x);
|
||||||
for(auto attr: attributes){
|
for(auto attr: attributes){
|
||||||
@@ -464,6 +471,13 @@ std::vector<unsigned> align::populate_starting_multiple_default(ir::value* v) {
|
|||||||
std::vector<unsigned> align::populate_starting_multiple(ir::value *v){
|
std::vector<unsigned> align::populate_starting_multiple(ir::value *v){
|
||||||
if(starting_multiple_.find(v) != starting_multiple_.end())
|
if(starting_multiple_.find(v) != starting_multiple_.end())
|
||||||
return starting_multiple_.at(v);
|
return starting_multiple_.at(v);
|
||||||
|
if(auto *x = dynamic_cast<ir::instruction*>(v)){
|
||||||
|
unsigned multiple_of = x->get_metadata(ir::metadata::multiple_of);
|
||||||
|
if(multiple_of > 0)
|
||||||
|
return add_to_cache(x, {multiple_of}, starting_multiple_);
|
||||||
|
}
|
||||||
|
if(auto *x = dynamic_cast<ir::cast_inst*>(v))
|
||||||
|
return populate_starting_multiple_cast(x);
|
||||||
if(auto *x = dynamic_cast<ir::binary_operator*>(v))
|
if(auto *x = dynamic_cast<ir::binary_operator*>(v))
|
||||||
return populate_starting_multiple_binop(x);
|
return populate_starting_multiple_binop(x);
|
||||||
if(auto *x = dynamic_cast<ir::constant_int*>(v))
|
if(auto *x = dynamic_cast<ir::constant_int*>(v))
|
||||||
@@ -508,6 +522,11 @@ void align::populate(ir::value *v) {
|
|||||||
|
|
||||||
void align::run(ir::module &mod) {
|
void align::run(ir::module &mod) {
|
||||||
ir::for_each_value(mod, [this](ir::value* v) { populate(v); } );
|
ir::for_each_value(mod, [this](ir::value* v) { populate(v); } );
|
||||||
|
// ir::for_each_value(mod, [this](ir::value* v) {
|
||||||
|
// if(dynamic_cast<ir::cast_inst*>(v) || dynamic_cast<ir::getelementptr_inst*>(v))
|
||||||
|
// std::cout << "ALIGN: " << v->get_name() << " " << starting_multiple_.at(v)[0] << " " << max_contiguous_.at(v)[0]
|
||||||
|
// << " " << starting_multiple_.at(v)[1] << " " << max_contiguous_.at(v)[1] << std::endl;
|
||||||
|
// });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user