2019-09-15 21:14:14 -04:00
|
|
|
#include <algorithm>
#include <iostream>
#include <set>
#include <vector>
#include "triton/codegen/analysis/axes.h"
#include "triton/codegen/analysis/layout.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
#include "triton/ir/utils.h"
|
2019-09-15 21:14:14 -04:00
|
|
|
|
|
|
|
namespace triton{
|
|
|
|
namespace codegen{
|
|
|
|
namespace analysis{
|
|
|
|
|
|
|
|
|
|
|
|
// constructor
|
|
|
|
layout::layout(analysis::axes *axes)
|
|
|
|
: axes_(axes) { }
|
|
|
|
|
|
|
|
// get group id
|
2019-10-07 18:06:54 -04:00
|
|
|
unsigned layout::layout_of(ir::value *value) const
|
2019-09-15 21:14:14 -04:00
|
|
|
{ return groups_.at(value); }
|
|
|
|
|
|
|
|
// get values
|
2019-10-07 18:06:54 -04:00
|
|
|
const std::vector<ir::value*>& layout::values_of(unsigned id) const
|
2019-09-15 21:14:14 -04:00
|
|
|
{ return values_.at(id); }
|
|
|
|
|
|
|
|
// get number of groups
|
2019-10-07 18:06:54 -04:00
|
|
|
size_t layout::num_layouts() const
|
2019-09-15 21:14:14 -04:00
|
|
|
{ return values_.size(); }
|
|
|
|
|
2019-09-20 16:01:12 -04:00
|
|
|
// connect two values
|
2019-09-17 17:40:03 -04:00
|
|
|
void layout::connect(ir::value *x, ir::value *y) {
|
|
|
|
if(x == y)
|
|
|
|
return;
|
|
|
|
if(!x->get_type()->is_tile_ty())
|
|
|
|
return;
|
|
|
|
if(!y->get_type()->is_tile_ty())
|
|
|
|
return;
|
2019-10-08 11:26:22 -04:00
|
|
|
std::vector<int> x_axes = axes_->get(x);
|
|
|
|
std::vector<int> y_axes = axes_->get(y);
|
|
|
|
std::set<int> sx_axes(x_axes.begin(), x_axes.end());
|
|
|
|
std::set<int> sy_axes(y_axes.begin(), y_axes.end());
|
2019-09-17 17:40:03 -04:00
|
|
|
std::set<int> common;
|
2019-10-08 11:26:22 -04:00
|
|
|
std::set_intersection(sx_axes.begin(), sx_axes.end(),
|
|
|
|
sy_axes.begin(), sy_axes.end(),
|
2019-09-17 17:40:03 -04:00
|
|
|
std::inserter(common, common.begin()));
|
2019-10-07 18:06:54 -04:00
|
|
|
if(!common.empty())
|
|
|
|
graph_.add_edge(x, y);
|
2019-09-17 17:40:03 -04:00
|
|
|
}
|
|
|
|
|
2019-09-20 16:01:12 -04:00
|
|
|
// make graph
|
2019-09-19 16:25:36 -04:00
|
|
|
void layout::make_graph(ir::instruction *i) {
|
|
|
|
for(ir::value* opx: i->ops())
|
|
|
|
for(ir::value* opy: i->ops()){
|
|
|
|
connect(i, opx);
|
|
|
|
connect(opx, opy);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2019-09-15 21:14:14 -04:00
|
|
|
// Run the layout analysis over every instruction of `mod`.
// Builds the layout graph, then splits it into connected components:
//  - values_ receives, per group id, the member values (read by values_of()
//    and num_layouts()),
//  - groups_ receives, per value, its group id (read by layout_of()).
void layout::run(ir::module &mod) {
  // Reset every piece of state from a previous run, not just the graph, so
  // re-running the pass cannot leave stale group entries behind.
  graph_.clear();
  values_.clear();
  groups_.clear();
  // make graph
  ir::for_each_instruction(mod, [this](ir::instruction* i) {
    make_graph(i);
  });
  // connected components
  graph_.connected_components(&values_, &groups_);
}
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|