[codegen] added missing file
This commit is contained in:
67
include/triton/tools/graph.h
Normal file
67
include/triton/tools/graph.h
Normal file
@@ -0,0 +1,67 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_TOOLS_THREAD_GRAPH_H_
|
||||
#define _TRITON_TOOLS_THREAD_GRAPH_H_
|
||||
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
|
||||
namespace triton {
|
||||
namespace tools{
|
||||
|
||||
template<class node_t>
|
||||
class graph {
|
||||
typedef std::map<node_t, std::set<node_t>> edges_t;
|
||||
|
||||
public:
|
||||
typedef std::map<size_t, std::vector<node_t>> cmap_t;
|
||||
typedef std::map<node_t, size_t> nmap_t;
|
||||
|
||||
private:
|
||||
void connected_components_impl(node_t x, std::set<node_t> &nodes,
|
||||
nmap_t* nmap, cmap_t* cmap, int id) const {
|
||||
if(nmap)
|
||||
(*nmap)[x] = id;
|
||||
if(cmap)
|
||||
(*cmap)[id].push_back(x);
|
||||
if(nodes.find(x) != nodes.end()) {
|
||||
nodes.erase(x);
|
||||
for(const node_t &y: edges_.at(x))
|
||||
connected_components_impl(y, nodes, nmap, cmap, id);
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
void connected_components(cmap_t *cmap, nmap_t *nmap) const {
|
||||
if(cmap)
|
||||
cmap->clear();
|
||||
if(nmap)
|
||||
nmap->clear();
|
||||
std::set<node_t> nodes = nodes_;
|
||||
unsigned id = 0;
|
||||
while(!nodes.empty())
|
||||
connected_components_impl(*nodes.begin(), nodes, nmap, cmap, id++);
|
||||
}
|
||||
|
||||
void add_edge(node_t x, node_t y) {
|
||||
nodes_.insert(x);
|
||||
nodes_.insert(y);
|
||||
edges_[x].insert(y);
|
||||
edges_[y].insert(x);
|
||||
}
|
||||
|
||||
void clear() {
|
||||
nodes_.clear();
|
||||
edges_.clear();
|
||||
}
|
||||
|
||||
private:
|
||||
std::set<node_t> nodes_;
|
||||
edges_t edges_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -206,21 +206,20 @@ void tiles::run(ir::module &) {
|
||||
size_t num_groups = layout_->num_layouts();
|
||||
// helpers
|
||||
auto rank = [](ir::value* v) {
|
||||
ir::type *ty = v->get_type();
|
||||
size_t ret = 0;
|
||||
if(ty->is_tile_ty())
|
||||
for(int s: ty->get_tile_shapes())
|
||||
ret += s > 1;
|
||||
int ret = 0;
|
||||
for(int s: v->get_type()->get_tile_shapes())
|
||||
ret += s > 1;
|
||||
return ret;
|
||||
};
|
||||
|
||||
// find out which groups require hmma layout
|
||||
for(size_t i = 0; i < num_groups; i++) {
|
||||
const auto& values = layout_->values_of(i);
|
||||
bool hmma_c = std::any_of(values.begin(), values.end(), &is_hmma_c);
|
||||
if(hmma_c) hmma_[i] = HMMA_C;
|
||||
else hmma_[i] = SCANLINE;
|
||||
|
||||
}
|
||||
|
||||
// find out which value is the largest in each group
|
||||
for(size_t i = 0; i < num_groups; i++) {
|
||||
const auto& values = layout_->values_of(i);
|
||||
|
Reference in New Issue
Block a user