[codegen] added missing file
This commit is contained in:
67
include/triton/tools/graph.h
Normal file
67
include/triton/tools/graph.h
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#ifndef _TRITON_TOOLS_THREAD_GRAPH_H_
|
||||||
|
#define _TRITON_TOOLS_THREAD_GRAPH_H_
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
#include <set>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace triton {
|
||||||
|
namespace tools{
|
||||||
|
|
||||||
|
template<class node_t>
|
||||||
|
class graph {
|
||||||
|
typedef std::map<node_t, std::set<node_t>> edges_t;
|
||||||
|
|
||||||
|
public:
|
||||||
|
typedef std::map<size_t, std::vector<node_t>> cmap_t;
|
||||||
|
typedef std::map<node_t, size_t> nmap_t;
|
||||||
|
|
||||||
|
private:
|
||||||
|
void connected_components_impl(node_t x, std::set<node_t> &nodes,
|
||||||
|
nmap_t* nmap, cmap_t* cmap, int id) const {
|
||||||
|
if(nmap)
|
||||||
|
(*nmap)[x] = id;
|
||||||
|
if(cmap)
|
||||||
|
(*cmap)[id].push_back(x);
|
||||||
|
if(nodes.find(x) != nodes.end()) {
|
||||||
|
nodes.erase(x);
|
||||||
|
for(const node_t &y: edges_.at(x))
|
||||||
|
connected_components_impl(y, nodes, nmap, cmap, id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
void connected_components(cmap_t *cmap, nmap_t *nmap) const {
|
||||||
|
if(cmap)
|
||||||
|
cmap->clear();
|
||||||
|
if(nmap)
|
||||||
|
nmap->clear();
|
||||||
|
std::set<node_t> nodes = nodes_;
|
||||||
|
unsigned id = 0;
|
||||||
|
while(!nodes.empty())
|
||||||
|
connected_components_impl(*nodes.begin(), nodes, nmap, cmap, id++);
|
||||||
|
}
|
||||||
|
|
||||||
|
void add_edge(node_t x, node_t y) {
|
||||||
|
nodes_.insert(x);
|
||||||
|
nodes_.insert(y);
|
||||||
|
edges_[x].insert(y);
|
||||||
|
edges_[y].insert(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
void clear() {
|
||||||
|
nodes_.clear();
|
||||||
|
edges_.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::set<node_t> nodes_;
|
||||||
|
edges_t edges_;
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
@@ -206,21 +206,20 @@ void tiles::run(ir::module &) {
|
|||||||
size_t num_groups = layout_->num_layouts();
|
size_t num_groups = layout_->num_layouts();
|
||||||
// helpers
|
// helpers
|
||||||
auto rank = [](ir::value* v) {
|
auto rank = [](ir::value* v) {
|
||||||
ir::type *ty = v->get_type();
|
int ret = 0;
|
||||||
size_t ret = 0;
|
for(int s: v->get_type()->get_tile_shapes())
|
||||||
if(ty->is_tile_ty())
|
|
||||||
for(int s: ty->get_tile_shapes())
|
|
||||||
ret += s > 1;
|
ret += s > 1;
|
||||||
return ret;
|
return ret;
|
||||||
};
|
};
|
||||||
|
|
||||||
// find out which groups require hmma layout
|
// find out which groups require hmma layout
|
||||||
for(size_t i = 0; i < num_groups; i++) {
|
for(size_t i = 0; i < num_groups; i++) {
|
||||||
const auto& values = layout_->values_of(i);
|
const auto& values = layout_->values_of(i);
|
||||||
bool hmma_c = std::any_of(values.begin(), values.end(), &is_hmma_c);
|
bool hmma_c = std::any_of(values.begin(), values.end(), &is_hmma_c);
|
||||||
if(hmma_c) hmma_[i] = HMMA_C;
|
if(hmma_c) hmma_[i] = HMMA_C;
|
||||||
else hmma_[i] = SCANLINE;
|
else hmma_[i] = SCANLINE;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// find out which value is the largest in each group
|
// find out which value is the largest in each group
|
||||||
for(size_t i = 0; i < num_groups; i++) {
|
for(size_t i = 0; i < num_groups; i++) {
|
||||||
const auto& values = layout_->values_of(i);
|
const auto& values = layout_->values_of(i);
|
||||||
|
Reference in New Issue
Block a user