Reparameterized in terms of micro- and nano- tiles

This commit is contained in:
Philippe Tillet
2019-03-10 23:10:17 -04:00
parent c96a263896
commit 94e315ea8a
7 changed files with 62 additions and 53 deletions

View File

@@ -379,9 +379,9 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
std::vector<unsigned> n_warps(dim);
for(unsigned i = 0; i < shapes.size(); i++){
std::string str_i = std::to_string(i);
contiguous[i] = params_->get_param(v, "p0.d" + str_i)->get_value();
warp_size[i] = params_->get_param(v, "p1.d" + str_i)->get_value();
n_warps[i] = params_->get_param(v, "p2.d" + str_i)->get_value();
contiguous[i] = params_->get_param(v, "nts.d" + str_i)->get_value();
warp_size[i] = params_->get_param(v, "mts.d" + str_i)->get_value();
n_warps[i] = shapes[i]->get_value() / (contiguous[i] * warp_size[i]);
}
std::vector<Value*> thread_id_in_warp = delinearize(u_thread_id, warp_size, builder);
std::vector<Value*> warp_id = delinearize(u_warp_id, n_warps, builder);
@@ -399,7 +399,7 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
unsigned offset = n / contiguous[k] * per_block + n % contiguous[k];
idx_list[n] = builder.CreateAdd(thread_id, builder.getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
}
axes_[params_->get_param(v, "p0.d" + str_k)] = distributed_axis{contiguous[k], idx_list};
axes_[params_->get_param(v, "nts.d" + str_k)] = distributed_axis{contiguous[k], idx_list};
}
}
@@ -432,7 +432,7 @@ void selection::create_grids(std::vector<ir::value*> &grids,
for(size_t d = 0; d < shapes.size(); d++){
if(shapes[d]->get_value() == 1)
continue;
ir::metaparameter *x = params_->get_param(v, "p0.d" + std::to_string(d));
ir::metaparameter *x = params_->get_param(v, "nts.d" + std::to_string(d));
ir::value *&r = references[x];
if(!r || get_tile_gt1_dim(v) > get_tile_gt1_dim(r))
r = v;
@@ -517,7 +517,7 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
std::vector<distributed_axis> axes(shapes.size());
for(size_t d = 0; d < shapes.size(); d++){
if(shapes[d]->get_value() > 1){
ir::metaparameter *x = params_->get_param(v, "p0.d" + std::to_string(d));
ir::metaparameter *x = params_->get_param(v, "nts.d" + std::to_string(d));
axes[d] = axes_.at(x);
}
else{