diff --git a/include/triton/codegen/analysis/align.h b/include/triton/codegen/analysis/align.h index 647db3984..2393603cb 100644 --- a/include/triton/codegen/analysis/align.h +++ b/include/triton/codegen/analysis/align.h @@ -11,6 +11,7 @@ namespace ir { class module; class phi_node; class splat_inst; + class cast_inst; class reshape_inst; class broadcast_inst; class binary_operator; @@ -44,6 +45,7 @@ private: std::vector populate_max_contiguous_broadcast(ir::broadcast_inst* x); std::vector populate_max_contiguous_binop(ir::binary_operator* x); std::vector populate_max_contiguous_gep(ir::getelementptr_inst* x); + std::vector populate_max_contiguous_cast(ir::cast_inst* x); std::vector populate_max_contiguous_default(ir::value* v); std::vector populate_max_contiguous(ir::value *v); // populate starting_multiple @@ -53,6 +55,7 @@ private: std::vector populate_starting_multiple_broadcast(ir::broadcast_inst* x); std::vector populate_starting_multiple_binop(ir::binary_operator* x); std::vector populate_starting_multiple_gep(ir::getelementptr_inst* x); + std::vector populate_starting_multiple_cast(ir::cast_inst* x); std::vector populate_starting_multiple_default(ir::value* v); std::vector populate_starting_multiple(ir::value *v); // populate all maps diff --git a/lib/codegen/analysis/align.cc b/lib/codegen/analysis/align.cc index 95ccb8dc3..8eae89f86 100644 --- a/lib/codegen/analysis/align.cc +++ b/lib/codegen/analysis/align.cc @@ -327,9 +327,16 @@ std::vector align::populate_max_contiguous_default(ir::value* v) { return add_to_cache(v, std::vector(shapes.size(), 1), max_contiguous_); } +std::vector align::populate_max_contiguous_cast(ir::cast_inst* v){ + auto result = populate_max_contiguous(v->get_operand(0)); + return add_to_cache(v, result, max_contiguous_); +} + std::vector align::populate_max_contiguous(ir::value *v){ if(max_contiguous_.find(v) != max_contiguous_.end()) return max_contiguous_.at(v); + if(auto *x = dynamic_cast(v)) + return populate_max_contiguous_cast(x); if(auto *x = dynamic_cast(v)) return populate_max_contiguous_splat(x); if(auto *x = dynamic_cast(v)) @@ -434,16 +441,16 @@ std::vector align::populate_starting_multiple_phi(ir::phi_node* x){ } +std::vector align::populate_starting_multiple_cast(ir::cast_inst* x){ + auto result = populate_starting_multiple(x->get_operand(0)); + return add_to_cache(x, result, starting_multiple_); +} + std::vector align::populate_starting_multiple_default(ir::value* v) { ir::type* ty = v->get_type(); if(ty->is_tile_ty()) { return add_to_cache(v, ty->get_tile_shapes(), starting_multiple_); } - if(auto *x = dynamic_cast(v)){ - unsigned multiple_of = x->get_metadata(ir::metadata::multiple_of); - if(multiple_of > 0) - return add_to_cache(x, {multiple_of}, starting_multiple_); - } if(auto *x = dynamic_cast(v)){ std::set attributes = x->get_parent()->get_attributes(x); for(auto attr: attributes){ @@ -464,6 +471,13 @@ std::vector align::populate_starting_multiple_default(ir::value* v) { std::vector align::populate_starting_multiple(ir::value *v){ if(starting_multiple_.find(v) != starting_multiple_.end()) return starting_multiple_.at(v); + if(auto *x = dynamic_cast(v)){ + unsigned multiple_of = x->get_metadata(ir::metadata::multiple_of); + if(multiple_of > 0) + return add_to_cache(x, {multiple_of}, starting_multiple_); + } + if(auto *x = dynamic_cast(v)) + return populate_starting_multiple_cast(x); if(auto *x = dynamic_cast(v)) return populate_starting_multiple_binop(x); if(auto *x = dynamic_cast(v)) @@ -508,6 +522,11 @@ void align::populate(ir::value *v) { void align::run(ir::module &mod) { ir::for_each_value(mod, [this](ir::value* v) { populate(v); } ); +// ir::for_each_value(mod, [this](ir::value* v) { +// if(dynamic_cast(v) || dynamic_cast(v)) +// std::cout << "ALIGN: " << v->get_name() << " " << starting_multiple_.at(v)[0] << " " << max_contiguous_.at(v)[0] +// << " " << starting_multiple_.at(v)[1] << " " << max_contiguous_.at(v)[1] << std::endl; +// }); }