more progress

This commit is contained in:
Philippe Tillet
2019-09-08 21:36:54 -04:00
parent 3d78810d5e
commit 3daef1726d
3 changed files with 18 additions and 9 deletions

View File

@@ -1,5 +1,6 @@
#include <algorithm>
#include <cstdlib>
#include "triton/codegen/transform/reorder.h"
#include "triton/codegen/analysis/grid.h"
#include "triton/ir/instructions.h"
#include "triton/ir/type.h"
@@ -15,7 +16,7 @@ namespace triton{
namespace codegen{
namespace analysis{
grids::grids(size_t num_warps): num_warps_(num_warps)
grids::grids(size_t num_warps, transform::reorder *reorder): num_warps_(num_warps), reorder_(reorder)
{ }
bool is_hmma(ir::value *v){
@@ -157,7 +158,6 @@ grids::fragment_t grids::get_fragmentation_type(node_t x, graph_t &graph){
}
void grids::connected_components(node_t x, const std::vector<ir::metaparameter *> mps, const std::vector<std::string> prefixes, std::set<node_t> &nodes, graph_t &graph, unsigned group_id) {
std::cout << "connected component: " << x.first->get_name() << " " << x.second << std::endl;
groups_[x.first].insert({x.second, group_id});
if(nodes.find(x) != nodes.end()){
nodes.erase(x);
@@ -225,7 +225,7 @@ void grids::run(ir::module &mod) {
}
for(ir::function *fn: mod.get_function_list()){
std::map<unsigned, ir::value*> references;
std::map<std::pair<unsigned, std::vector<unsigned>>, ir::value*> references;
create_grids(grids_, references, fn);
}
@@ -317,7 +317,8 @@ void grids::run(ir::module &mod) {
void grids::create_grids(std::vector<ir::value*> &grids,
std::map<unsigned, ir::value*> &references,
std::map<std::pair<unsigned,
std::vector<unsigned>>, ir::value*> &references,
ir::function *fn) {
// get number of dimensions greater than 1
auto get_tile_gt1_dim = [&](ir::value *v){
@@ -331,6 +332,7 @@ void grids::create_grids(std::vector<ir::value*> &grids,
std::set<ir::value*> seen;
std::function<void(ir::value*)> bind_references = [&](ir::value *v)
{
auto order = reorder_->get_order(v);
// skip
if(!v->get_type()->is_tile_ty() || !seen.insert(v).second)
return;
@@ -344,7 +346,7 @@ void grids::create_grids(std::vector<ir::value*> &grids,
if(shapes[d] == 1)
continue;
unsigned x = get_param_group(v, d);
ir::value *&r = references[x];
ir::value *&r = references[{x, order}];
if(!r || get_tile_gt1_dim(v) > get_tile_gt1_dim(r))
r = v;
}