[general] cleaned tensorflow source code generation

This commit is contained in:
Philippe Tillet
2019-08-18 15:39:36 -07:00
parent 457c330f15
commit 0970fe12dd
12 changed files with 162 additions and 152 deletions

View File

@@ -15,7 +15,7 @@ namespace triton{
namespace codegen{
namespace analysis{
tune::tune(size_t num_warps): num_warps_(num_warps){
grids::grids(size_t num_warps): num_warps_(num_warps){
}
bool is_hmma(ir::value *v){
@@ -32,14 +32,14 @@ bool is_hmma(ir::value *v){
return result;
}
void tune::add_constraint(node_t x, node_t y) {
void grids::add_constraint(node_t x, node_t y) {
dependencies_[x].insert(y);
dependencies_[y].insert(x);
nodes_.insert(x);
nodes_.insert(y);
}
void tune::init_c_phi(ir::instruction *v) {
void grids::init_c_phi(ir::instruction *v) {
// Phi Nodes: all the incoming value share the result layout
if(auto *phi = dynamic_cast<ir::phi_node*>(v))
for(ir::value *op: phi->ops())
@@ -50,7 +50,7 @@ void tune::init_c_phi(ir::instruction *v) {
}
}
void tune::init_c_graph(ir::instruction *v) {
void grids::init_c_graph(ir::instruction *v) {
// Reference shape
ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(v->get_parent()->get_context());
ir::type::tile_shapes_t shapes;
@@ -142,7 +142,7 @@ void tune::init_c_graph(ir::instruction *v) {
}
}
tune::fragment_t tune::get_fragmentation_type(node_t x, graph_t &graph){
grids::fragment_t grids::get_fragmentation_type(node_t x, graph_t &graph){
std::list<node_t> work;
std::set<node_t> seen;
work.push_back(x);
@@ -160,7 +160,7 @@ tune::fragment_t tune::get_fragmentation_type(node_t x, graph_t &graph){
return STRIDED_SCAN;
}
void tune::connected_components(node_t x, const std::vector<ir::metaparameter *> mps, const std::vector<std::string> prefixes, std::set<node_t> &nodes, graph_t &graph, unsigned group_id) {
void grids::connected_components(node_t x, const std::vector<ir::metaparameter *> mps, const std::vector<std::string> prefixes, std::set<node_t> &nodes, graph_t &graph, unsigned group_id) {
// std::cout << "connected component: " << x.first->get_name() << " " << x.second << std::endl;
groups_[x.first].insert({x.second, group_id});
if(nodes.find(x) != nodes.end()){
@@ -183,20 +183,20 @@ void tune::connected_components(node_t x, const std::vector<ir::metaparameter *>
}
}
unsigned tune::get_param_group(ir::value *value, unsigned ax) {
unsigned grids::get_param_group(ir::value *value, unsigned ax) {
unsigned result = groups_.at(value).at(ax);
return result;
}
//TODO: This shouldn't exist!
void tune::copy(ir::value *dst, ir::value *src) {
void grids::copy(ir::value *dst, ir::value *src) {
params_[dst] = params_[src];
groups_[dst] = groups_[src];
fragments_[{dst, 0}] = fragments_[{src, 0}];
}
void tune::run(ir::module &mod) {
void grids::run(ir::module &mod) {
ir::context &ctx = mod.get_context();
// Create metaparameters
for(ir::function *fn: mod.get_function_list()){
@@ -318,7 +318,7 @@ void tune::run(ir::module &mod) {
}
void tune::create_grids(std::vector<ir::value*> &grids,
void grids::create_grids(std::vector<ir::value*> &grids,
std::map<unsigned, ir::value*> &references,
ir::function *fn) {
// get number of dimensions greater than 1
@@ -363,11 +363,7 @@ void tune::create_grids(std::vector<ir::value*> &grids,
}
bool tune::check_constraints(std::map<ir::value *, std::vector<std::string>> &errors) {
return errors.empty();
}
unsigned tune::get_num_threads() {
unsigned grids::get_num_threads() {
return num_warps_*32;
}