more progress
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
#include <algorithm>
|
||||
#include <cstdlib>
|
||||
#include "triton/codegen/transform/reorder.h"
|
||||
#include "triton/codegen/analysis/grid.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/type.h"
|
||||
@@ -15,7 +16,7 @@ namespace triton{
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
grids::grids(size_t num_warps): num_warps_(num_warps)
|
||||
grids::grids(size_t num_warps, transform::reorder *reorder): num_warps_(num_warps), reorder_(reorder)
|
||||
{ }
|
||||
|
||||
bool is_hmma(ir::value *v){
|
||||
@@ -157,7 +158,6 @@ grids::fragment_t grids::get_fragmentation_type(node_t x, graph_t &graph){
|
||||
}
|
||||
|
||||
void grids::connected_components(node_t x, const std::vector<ir::metaparameter *> mps, const std::vector<std::string> prefixes, std::set<node_t> &nodes, graph_t &graph, unsigned group_id) {
|
||||
std::cout << "connected component: " << x.first->get_name() << " " << x.second << std::endl;
|
||||
groups_[x.first].insert({x.second, group_id});
|
||||
if(nodes.find(x) != nodes.end()){
|
||||
nodes.erase(x);
|
||||
@@ -225,7 +225,7 @@ void grids::run(ir::module &mod) {
|
||||
}
|
||||
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
std::map<unsigned, ir::value*> references;
|
||||
std::map<std::pair<unsigned, std::vector<unsigned>>, ir::value*> references;
|
||||
create_grids(grids_, references, fn);
|
||||
}
|
||||
|
||||
@@ -317,7 +317,8 @@ void grids::run(ir::module &mod) {
|
||||
|
||||
|
||||
void grids::create_grids(std::vector<ir::value*> &grids,
|
||||
std::map<unsigned, ir::value*> &references,
|
||||
std::map<std::pair<unsigned,
|
||||
std::vector<unsigned>>, ir::value*> &references,
|
||||
ir::function *fn) {
|
||||
// get number of dimensions greater than 1
|
||||
auto get_tile_gt1_dim = [&](ir::value *v){
|
||||
@@ -331,6 +332,7 @@ void grids::create_grids(std::vector<ir::value*> &grids,
|
||||
std::set<ir::value*> seen;
|
||||
std::function<void(ir::value*)> bind_references = [&](ir::value *v)
|
||||
{
|
||||
auto order = reorder_->get_order(v);
|
||||
// skip
|
||||
if(!v->get_type()->is_tile_ty() || !seen.insert(v).second)
|
||||
return;
|
||||
@@ -344,7 +346,7 @@ void grids::create_grids(std::vector<ir::value*> &grids,
|
||||
if(shapes[d] == 1)
|
||||
continue;
|
||||
unsigned x = get_param_group(v, d);
|
||||
ir::value *&r = references[x];
|
||||
ir::value *&r = references[{x, order}];
|
||||
if(!r || get_tile_gt1_dim(v) > get_tile_gt1_dim(r))
|
||||
r = v;
|
||||
}
|
||||
|
Reference in New Issue
Block a user