diff --git a/include/triton/codegen/analysis/grid.h b/include/triton/codegen/analysis/grid.h index 26331c786..84fd2168c 100644 --- a/include/triton/codegen/analysis/grid.h +++ b/include/triton/codegen/analysis/grid.h @@ -17,6 +17,11 @@ namespace ir{ } namespace codegen{ + +namespace transform{ +class reorder; +} + namespace analysis{ class grids { @@ -36,12 +41,12 @@ private: fragment_t get_fragmentation_type(node_t x, graph_t &graph); void connected_components(node_t x, const std::vector mps, const std::vector prefixes, std::set &nodes, graph_t &graph, unsigned group_id); void create_grids(std::vector &grids, - std::map &references, + std::map >, triton::ir::value *> &references, ir::function *fn); public: - grids(size_t num_warps); + grids(size_t num_warps, transform::reorder* reorder); ir::metaparameter* get_param(ir::value *value, const std::string &key) { return params_[value][key]; } unsigned get_param_group(ir::value *value, unsigned ax); fragment_t get_fragment(ir::value *value, unsigned ax) { return fragments_.at({value, ax}); } @@ -60,6 +65,8 @@ private: std::vector grids_; std::map> groups_; size_t num_warps_; + transform::reorder* reorder_; + }; diff --git a/lib/codegen/analysis/grid.cc b/lib/codegen/analysis/grid.cc index 29d5c3657..d7b773aaf 100644 --- a/lib/codegen/analysis/grid.cc +++ b/lib/codegen/analysis/grid.cc @@ -1,5 +1,6 @@ #include #include +#include "triton/codegen/transform/reorder.h" #include "triton/codegen/analysis/grid.h" #include "triton/ir/instructions.h" #include "triton/ir/type.h" @@ -15,7 +16,7 @@ namespace triton{ namespace codegen{ namespace analysis{ -grids::grids(size_t num_warps): num_warps_(num_warps) +grids::grids(size_t num_warps, transform::reorder *reorder): num_warps_(num_warps), reorder_(reorder) { } bool is_hmma(ir::value *v){ @@ -157,7 +158,6 @@ grids::fragment_t grids::get_fragmentation_type(node_t x, graph_t &graph){ } void grids::connected_components(node_t x, const std::vector mps, const std::vector prefixes, std::set &nodes, graph_t &graph, unsigned group_id) { - std::cout << "connected component: " << x.first->get_name() << " " << x.second << std::endl; groups_[x.first].insert({x.second, group_id}); if(nodes.find(x) != nodes.end()){ nodes.erase(x); @@ -225,7 +225,7 @@ void grids::run(ir::module &mod) { } for(ir::function *fn: mod.get_function_list()){ - std::map references; + std::map>, ir::value*> references; create_grids(grids_, references, fn); } @@ -317,7 +317,8 @@ void grids::run(ir::module &mod) { void grids::create_grids(std::vector &grids, - std::map &references, + std::map>, ir::value*> &references, ir::function *fn) { // get number of dimensions greater than 1 auto get_tile_gt1_dim = [&](ir::value *v){ @@ -331,6 +332,7 @@ void grids::create_grids(std::vector &grids, std::set seen; std::function bind_references = [&](ir::value *v) { + auto order = reorder_->get_order(v); // skip if(!v->get_type()->is_tile_ty() || !seen.insert(v).second) return; @@ -344,7 +346,7 @@ void grids::create_grids(std::vector &grids, if(shapes[d] == 1) continue; unsigned x = get_param_group(v, d); - ir::value *&r = references[x]; + ir::value *&r = references[{x, order}]; if(!r || get_tile_gt1_dim(v) > get_tile_gt1_dim(r)) r = v; } diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index 3ddcc856f..f2aa8e4db 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -194,12 +194,12 @@ std::unique_ptr function::make_bin(ir::module &module, driver::c std::unique_ptr target = context->device()->make_target(); // create passes - codegen::analysis::grids grids(opt.num_warps); codegen::analysis::meminfo shmem_info; codegen::analysis::liveness shmem_liveness(&shmem_info); - codegen::analysis::memalloc shmem_allocation(&shmem_liveness, &shmem_info, &grids); codegen::analysis::align alignment_info; codegen::transform::reorder reorder(&alignment_info, &shmem_info); + codegen::analysis::grids grids(opt.num_warps, &reorder); + codegen::analysis::memalloc shmem_allocation(&shmem_liveness, &shmem_info, &grids); codegen::transform::membar shmem_barriers(&shmem_allocation, &shmem_info); codegen::transform::vectorize vectorize(&grids); codegen::transform::dce dce;