[codegen][grid] some cleaning

This commit is contained in:
Philippe Tillet
2019-09-14 13:05:53 -04:00
parent 8ae779206f
commit 66e32b3074
5 changed files with 81 additions and 72 deletions

View File

@@ -84,7 +84,6 @@ void grids::init_c_graph(ir::instruction *v) {
bool is_skewed = false;
for(unsigned i = 0; i < shapes.size(); i ++){
if(shapes[i] == 1){
static_params_.insert({{v, i}, 1});
add_constraint({v, i}, {v, i});
}
else if(!is_skewed &&
@@ -125,8 +124,6 @@ void grids::init_c_graph(ir::instruction *v) {
for(unsigned i = 0; i < shapes.size(); i++)
add_constraint({v, i}, {D, i});
for(unsigned i = 2; i < shapes.size(); i++){
if(shapes[i] == 1)
static_params_.insert({{v, i}, 1});
add_constraint({v, i}, {A, i});
add_constraint({v, i}, {B, i});
}
@@ -159,21 +156,15 @@ grids::fragment_t grids::get_fragmentation_type(node_t x, graph_t &graph){
return STRIDED_SCAN;
}
void grids::connected_components(node_t x, const std::vector<ir::metaparameter *> mps, const std::vector<std::string> prefixes, std::set<node_t> &nodes, graph_t &graph, unsigned group_id) {
void grids::connected_components(node_t x, const std::vector<param_ptr_t>& ptr_vec, const std::vector<param_map_t*>& maps, std::set<node_t> &nodes, graph_t &graph, unsigned group_id)
{
groups_[x.first].insert({x.second, group_id});
if(nodes.find(x) != nodes.end()){
nodes.erase(x);
std::string suffix = ".d" + std::to_string(x.second);
for(unsigned i = 0; i < mps.size(); i++)
params_[x.first].insert({prefixes[i] + suffix, mps[i]});
ir::type *ty = x.first->get_type();
if(static_params_.find(x) != static_params_.end()){
for(ir::metaparameter *mp: mps)
mp->set_value(static_params_.at(x));
}
for(const node_t &y: graph[x]){
connected_components(y, mps, prefixes, nodes, graph, group_id);
}
for(unsigned i = 0; i < ptr_vec.size(); i++)
(*maps[i])[x.first][x.second] = ptr_vec[i];
for(const node_t &y: graph[x])
connected_components(y, ptr_vec, maps, nodes, graph, group_id);
}
}
@@ -189,7 +180,10 @@ grids::fragment_t grids::get_fragment(ir::value *value, unsigned ax) {
//TODO: This shouldn't exist!
void grids::copy(ir::value *dst, ir::value *src) {
params_[dst] = params_[src];
mts_[dst] = mts_[src];
nts_[dst] = nts_[src];
fpw_[dst] = fpw_[src];
wpt_[dst] = wpt_[src];
groups_[dst] = groups_[src];
fragments_[{dst, 0}] = fragments_[{src, 0}];
}
@@ -217,17 +211,16 @@ void grids::run(ir::module &mod) {
for(auto x: nodes_)
fragments_[x] = get_fragmentation_type(x, dependencies_);
while(!nodes_.empty()) {
ir::type *ty = mod.get_builder().get_int32_ty();
node_t node = *nodes_.begin();
if(fragments_[node] == STRIDED_SCAN) {
ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 1, 1);
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 1, 1);
connected_components(node, {nts, mts}, {"nts", "mts"}, nodes_, dependencies_, group_id++);
param_ptr_t nts(new int(-1));
param_ptr_t mts(new int(-1));
connected_components(node, {nts, mts}, {&nts_, &mts_}, nodes_, dependencies_, group_id++);
}
else {
ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 1, 1);
ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 1, 1);
connected_components(node, {fpw, wpt}, {"fpw", "wpt"}, nodes_, dependencies_, group_id++);
param_ptr_t fpw(new int(-1));
param_ptr_t wpt(new int(-1));
connected_components(node, {fpw, wpt}, {&fpw_, &wpt_}, nodes_, dependencies_, group_id++);
}
}
}
@@ -267,7 +260,7 @@ void grids::run(ir::module &mod) {
}while(fpw_nm1 != fpw);
// store parameters
for(unsigned d = 0; d < shapes.size(); d++)
params_.at(i).at("fpw.d" + std::to_string(d))->set_value(fpw[d]);
*fpw_[i][d] = fpw[d];
/* warps per tile */
// try to make things as square as possible to maximize data re-use
@@ -282,14 +275,12 @@ void grids::run(ir::module &mod) {
}while(wpt_nm1 != wpt);
// store parameters
for(unsigned d = 0; d < shapes.size(); d++)
params_.at(i).at("wpt.d" + std::to_string(d))->set_value(wpt[d]);
*wpt_[i][d] = wpt[d];
/* sanity check */
unsigned effective_num_warps = 1;
for(size_t d = 0; d < shapes.size(); d++){
std::string str_d = std::to_string(d);
effective_num_warps *= params_.at(i).at("wpt.d" + str_d)->get_value();
}
for(size_t d = 0; d < shapes.size(); d++)
effective_num_warps *= *wpt_[i][d];
if(num_warps_ != effective_num_warps)
throw std::runtime_error("cannot create a kernel with this amount of warps");
@@ -299,28 +290,20 @@ void grids::run(ir::module &mod) {
/* Scan-line */
else{
unsigned ld = order[0];
std::string s_ld = std::to_string(ld);
unsigned current = num_threads;
std::string nts = "nts.d" + s_ld;
std::string mts = "mts.d" + s_ld;
params_.at(i).at(nts)->set_value(clamp(size / num_threads, 1, 4));
params_.at(i).at(mts)->set_value(clamp(current, 1, shapes[ld] / params_.at(i).at(nts)->get_value()));
current = current / params_.at(i).at(mts)->get_value();
*nts_[i][ld] = clamp(size / num_threads, 1, 4);
*mts_[i][ld] = clamp(current, 1, shapes[ld] / *nts_[i][ld]);
current = current / *mts_[i][ld];
for(size_t d = 1; d < shapes.size(); d++){
ld = order[d];
s_ld = std::to_string(ld);
nts = "nts.d" + s_ld;
mts = "mts.d" + s_ld;
params_.at(i).at(nts)->set_value(1);
params_.at(i).at(mts)->set_value(clamp(current, 1, shapes[ld]));
current = current / params_.at(i).at(mts)->get_value();
*nts_[i][ld] = 1;
*mts_[i][ld] = clamp(current, 1, shapes[ld]);
current = current / *mts_[i][ld];
}
/* sanity check */
unsigned effective_num_threads = 1;
for(size_t d = 0; d < shapes.size(); d++){
std::string str_d = std::to_string(d);
effective_num_threads *= params_.at(i).at("mts.d" + str_d)->get_value();
}
for(size_t d = 0; d < shapes.size(); d++)
effective_num_threads *= *mts_[i][d];
if(num_threads != effective_num_threads)
throw std::runtime_error("cannot create a kernel with this amount of warps");
}
@@ -378,6 +361,21 @@ unsigned grids::get_num_threads() {
return num_warps_*32;
}
int grids::get_mts(ir::value *value, unsigned ax) {
return *mts_.at(value).at(ax);
}
int grids::get_nts(ir::value *value, unsigned ax) {
return *nts_.at(value).at(ax);
}
int grids::get_fpw(ir::value *value, unsigned ax) {
return *fpw_.at(value).at(ax);
}
int grids::get_wpt(ir::value *value, unsigned ax) {
return *wpt_.at(value).at(ax);
}
}
}