Reparameterized in terms of micro- and nano- tiles
This commit is contained in:
@@ -379,9 +379,9 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
|
||||
std::vector<unsigned> n_warps(dim);
|
||||
for(unsigned i = 0; i < shapes.size(); i++){
|
||||
std::string str_i = std::to_string(i);
|
||||
contiguous[i] = params_->get_param(v, "p0.d" + str_i)->get_value();
|
||||
warp_size[i] = params_->get_param(v, "p1.d" + str_i)->get_value();
|
||||
n_warps[i] = params_->get_param(v, "p2.d" + str_i)->get_value();
|
||||
contiguous[i] = params_->get_param(v, "nts.d" + str_i)->get_value();
|
||||
warp_size[i] = params_->get_param(v, "mts.d" + str_i)->get_value();
|
||||
n_warps[i] = shapes[i]->get_value() / (contiguous[i] * warp_size[i]);
|
||||
}
|
||||
std::vector<Value*> thread_id_in_warp = delinearize(u_thread_id, warp_size, builder);
|
||||
std::vector<Value*> warp_id = delinearize(u_warp_id, n_warps, builder);
|
||||
@@ -399,7 +399,7 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
|
||||
unsigned offset = n / contiguous[k] * per_block + n % contiguous[k];
|
||||
idx_list[n] = builder.CreateAdd(thread_id, builder.getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
|
||||
}
|
||||
axes_[params_->get_param(v, "p0.d" + str_k)] = distributed_axis{contiguous[k], idx_list};
|
||||
axes_[params_->get_param(v, "nts.d" + str_k)] = distributed_axis{contiguous[k], idx_list};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -432,7 +432,7 @@ void selection::create_grids(std::vector<ir::value*> &grids,
|
||||
for(size_t d = 0; d < shapes.size(); d++){
|
||||
if(shapes[d]->get_value() == 1)
|
||||
continue;
|
||||
ir::metaparameter *x = params_->get_param(v, "p0.d" + std::to_string(d));
|
||||
ir::metaparameter *x = params_->get_param(v, "nts.d" + std::to_string(d));
|
||||
ir::value *&r = references[x];
|
||||
if(!r || get_tile_gt1_dim(v) > get_tile_gt1_dim(r))
|
||||
r = v;
|
||||
@@ -517,7 +517,7 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
|
||||
std::vector<distributed_axis> axes(shapes.size());
|
||||
for(size_t d = 0; d < shapes.size(); d++){
|
||||
if(shapes[d]->get_value() > 1){
|
||||
ir::metaparameter *x = params_->get_param(v, "p0.d" + std::to_string(d));
|
||||
ir::metaparameter *x = params_->get_param(v, "nts.d" + std::to_string(d));
|
||||
axes[d] = axes_.at(x);
|
||||
}
|
||||
else{
|
||||
|
Reference in New Issue
Block a user