|
|
|
@@ -571,145 +571,154 @@ inline void to_warps(const std::vector<unsigned> &bs, const std::vector<unsigned
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
|
|
|
|
|
void selection::init_strided_scan_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
|
|
|
|
|
auto order = reorder_->get_order(v);
|
|
|
|
|
const auto& shapes = v->get_type()->get_tile_shapes();
|
|
|
|
|
size_t dim = shapes.size();
|
|
|
|
|
if(params_->get_fragment(v, 0) == analysis::grids::STRIDED_SCAN){
|
|
|
|
|
std::vector<unsigned> contiguous(dim);
|
|
|
|
|
std::vector<unsigned> block_size(dim);
|
|
|
|
|
std::vector<unsigned> warp_size(dim);
|
|
|
|
|
std::vector<unsigned> n_warps(dim);
|
|
|
|
|
for(unsigned i = 0; i < shapes.size(); i++){
|
|
|
|
|
contiguous[i] = params_->get_nts(v, i);
|
|
|
|
|
block_size[i] = params_->get_mts(v, i);
|
|
|
|
|
}
|
|
|
|
|
to_warps(block_size, order, n_warps, warp_size);
|
|
|
|
|
std::vector<Value*> thread_id_in_warp = delinearize(u_thread_id, order, warp_size, builder);
|
|
|
|
|
std::vector<Value*> warp_id = delinearize(u_warp_id, order, n_warps, builder);
|
|
|
|
|
// Create axes
|
|
|
|
|
for(unsigned k = 0; k < dim; k++) {
|
|
|
|
|
std::string str_k = std::to_string(k);
|
|
|
|
|
Value *warp_size_k = builder.getInt32(warp_size[k]);
|
|
|
|
|
Value *contiguous_k = builder.getInt32(contiguous[k]);
|
|
|
|
|
Value *thread_id = builder.CreateAdd(thread_id_in_warp[k], builder.CreateMul(warp_id[k], warp_size_k));
|
|
|
|
|
Value *scaled_thread_id = builder.CreateMul(thread_id, contiguous_k);
|
|
|
|
|
unsigned per_block = contiguous[k] * warp_size[k] * n_warps[k];
|
|
|
|
|
unsigned per_thread = contiguous[k] * shapes[k] / per_block;
|
|
|
|
|
std::vector<Value*> idx_list(per_thread);
|
|
|
|
|
for(unsigned n = 0 ; n < per_thread; n++){
|
|
|
|
|
unsigned offset = n / contiguous[k] * per_block + n % contiguous[k];
|
|
|
|
|
idx_list[n] = builder.CreateAdd(scaled_thread_id, builder.getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
|
|
|
|
|
}
|
|
|
|
|
axes_[params_->get_param_group(v, k)] = distributed_axis{contiguous[k], idx_list, thread_id};
|
|
|
|
|
}
|
|
|
|
|
std::vector<unsigned> contiguous(dim);
|
|
|
|
|
std::vector<unsigned> block_size(dim);
|
|
|
|
|
std::vector<unsigned> warp_size(dim);
|
|
|
|
|
std::vector<unsigned> n_warps(dim);
|
|
|
|
|
for(unsigned i = 0; i < shapes.size(); i++){
|
|
|
|
|
contiguous[i] = params_->nts(v, i);
|
|
|
|
|
block_size[i] = params_->mts(v, i);
|
|
|
|
|
}
|
|
|
|
|
else {
|
|
|
|
|
if(shapes.size() > 3)
|
|
|
|
|
throw std::runtime_error("unsupported");
|
|
|
|
|
bool is_batched = shapes.size() >= 3;
|
|
|
|
|
|
|
|
|
|
Value *_1 = builder.getInt32(1);
|
|
|
|
|
Value *_2 = builder.getInt32(2);
|
|
|
|
|
Value *_3 = builder.getInt32(3);
|
|
|
|
|
Value *_4 = builder.getInt32(4);
|
|
|
|
|
Value *_16 = builder.getInt32(16);
|
|
|
|
|
|
|
|
|
|
// fragments per warp
|
|
|
|
|
unsigned fpw_0 = params_->get_fpw(v, 0);
|
|
|
|
|
unsigned fpw_1 = params_->get_fpw(v, 1);
|
|
|
|
|
unsigned fpw_2 = is_batched ? params_->get_fpw(v, 2) : 1;
|
|
|
|
|
// warps per tile
|
|
|
|
|
unsigned wpt_0 = params_->get_wpt(v, 0);
|
|
|
|
|
unsigned wpt_1 = params_->get_wpt(v, 1);
|
|
|
|
|
unsigned wpt_2 = is_batched ? params_->get_wpt(v, 2) : 1;
|
|
|
|
|
// hmma warp tile size
|
|
|
|
|
unsigned hmma_wts_0 = fpw_0 * 8;
|
|
|
|
|
unsigned hmma_wts_1 = fpw_1 * 8;
|
|
|
|
|
unsigned hmma_wts_2 = is_batched ? fpw_2 : 1;
|
|
|
|
|
// hmma block tile size
|
|
|
|
|
unsigned hmma_bts_0 = hmma_wts_0 * wpt_0;
|
|
|
|
|
unsigned hmma_bts_1 = hmma_wts_1 * wpt_1;
|
|
|
|
|
unsigned hmma_bts_2 = is_batched ? hmma_wts_2 * wpt_2 : 1;
|
|
|
|
|
// number of repetition
|
|
|
|
|
unsigned num_rep_0 = shapes[0] / hmma_bts_0;
|
|
|
|
|
unsigned num_rep_1 = shapes[1] / hmma_bts_1;
|
|
|
|
|
unsigned num_rep_2 = is_batched ? shapes[2] / hmma_bts_2 : 1;
|
|
|
|
|
// size of each pack (interleaving)
|
|
|
|
|
pack_size_0_ = std::min<unsigned>(num_rep_0, 1);
|
|
|
|
|
pack_size_1_ = std::min<unsigned>(num_rep_1, 1);
|
|
|
|
|
// number of packs (interleaving)
|
|
|
|
|
num_packs_0_ = num_rep_0 / pack_size_0_;
|
|
|
|
|
num_packs_1_ = num_rep_1 / pack_size_1_;
|
|
|
|
|
|
|
|
|
|
/* intra warp offset */
|
|
|
|
|
// offset of quad in pair
|
|
|
|
|
Value *in_pair_off_a = builder.CreateMul(builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), builder.getInt32(4)),
|
|
|
|
|
builder.getInt32(fpw_0 * pack_size_0_));
|
|
|
|
|
Value *in_pair_off_b = builder.CreateMul(builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), builder.getInt32(4)),
|
|
|
|
|
builder.getInt32(fpw_1 * pack_size_1_));
|
|
|
|
|
|
|
|
|
|
// Quad pair id
|
|
|
|
|
Value *pair_a_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4);
|
|
|
|
|
Value *pair_b_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4);
|
|
|
|
|
pair_a_id = builder.CreateURem(pair_a_id, builder.getInt32(fpw_0));
|
|
|
|
|
pair_b_id = builder.CreateUDiv(pair_b_id, builder.getInt32(fpw_0));
|
|
|
|
|
pair_b_id = builder.CreateURem(pair_b_id, builder.getInt32(fpw_1));
|
|
|
|
|
// Quad pair offset
|
|
|
|
|
Value *pair_a_off = builder.CreateMul(pair_a_id, builder.getInt32(4 * pack_size_0_));
|
|
|
|
|
Value *pair_b_off = builder.CreateMul(pair_b_id, builder.getInt32(4 * pack_size_1_));
|
|
|
|
|
|
|
|
|
|
/* inter warp offset */
|
|
|
|
|
Value *warp_id_0 = builder.CreateURem(u_warp_id, builder.getInt32(wpt_0));
|
|
|
|
|
Value *warp_id_12 = builder.CreateUDiv(u_warp_id, builder.getInt32(wpt_0));
|
|
|
|
|
Value *warp_id_1 = builder.CreateURem(warp_id_12, builder.getInt32(wpt_1));
|
|
|
|
|
Value *warp_id_2 = builder.CreateUDiv(warp_id_12, builder.getInt32(wpt_1));
|
|
|
|
|
Value *warp_offset_i = builder.CreateMul(warp_id_0, builder.getInt32(hmma_wts_0 * pack_size_0_));
|
|
|
|
|
Value *warp_offset_j = builder.CreateMul(warp_id_1, builder.getInt32(hmma_wts_1 * pack_size_1_));
|
|
|
|
|
|
|
|
|
|
/* offsets */
|
|
|
|
|
// a offset
|
|
|
|
|
offset_a_i_ = builder.CreateAdd(warp_offset_i, builder.CreateAdd(pair_a_off, in_pair_off_a));
|
|
|
|
|
offset_a_k_ = builder.CreateAnd(u_thread_id, _3);
|
|
|
|
|
// b offsets
|
|
|
|
|
offset_b_j_ = builder.CreateAdd(warp_offset_j, builder.CreateAdd(pair_b_off, in_pair_off_b));
|
|
|
|
|
offset_b_k_ = builder.CreateAnd(u_thread_id, _3);
|
|
|
|
|
|
|
|
|
|
// c offsets
|
|
|
|
|
Value *offset_c_i = builder.CreateAdd(builder.CreateAnd(u_thread_id, _1), offset_a_i_);
|
|
|
|
|
Value *offset_c_j = builder.CreateAdd(builder.CreateAnd(u_thread_id, _2),
|
|
|
|
|
builder.CreateAdd(warp_offset_j, pair_b_off));
|
|
|
|
|
|
|
|
|
|
/* indices */
|
|
|
|
|
// i indices
|
|
|
|
|
std::vector<Value*> idx_i;
|
|
|
|
|
for(unsigned pack = 0; pack < num_packs_0_; pack++)
|
|
|
|
|
for(unsigned ii = 0; ii < pack_size_0_; ii++)
|
|
|
|
|
for(unsigned i = 0; i < 2; i++){
|
|
|
|
|
idx_i.push_back(builder.CreateAdd(offset_c_i, builder.getInt32(pack*hmma_bts_0*pack_size_0_ + ii*4 + i*2)));
|
|
|
|
|
to_warps(block_size, order, n_warps, warp_size);
|
|
|
|
|
std::vector<Value*> thread_id_in_warp = delinearize(u_thread_id, order, warp_size, builder);
|
|
|
|
|
std::vector<Value*> warp_id = delinearize(u_warp_id, order, n_warps, builder);
|
|
|
|
|
// Create axes
|
|
|
|
|
for(unsigned k = 0; k < dim; k++) {
|
|
|
|
|
std::string str_k = std::to_string(k);
|
|
|
|
|
Value *warp_size_k = builder.getInt32(warp_size[k]);
|
|
|
|
|
Value *contiguous_k = builder.getInt32(contiguous[k]);
|
|
|
|
|
Value *thread_id = builder.CreateAdd(thread_id_in_warp[k], builder.CreateMul(warp_id[k], warp_size_k));
|
|
|
|
|
Value *scaled_thread_id = builder.CreateMul(thread_id, contiguous_k);
|
|
|
|
|
unsigned per_block = contiguous[k] * warp_size[k] * n_warps[k];
|
|
|
|
|
unsigned per_thread = contiguous[k] * shapes[k] / per_block;
|
|
|
|
|
std::vector<Value*> idx_list(per_thread);
|
|
|
|
|
for(unsigned n = 0 ; n < per_thread; n++){
|
|
|
|
|
unsigned offset = n / contiguous[k] * per_block + n % contiguous[k];
|
|
|
|
|
idx_list[n] = builder.CreateAdd(scaled_thread_id, builder.getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
|
|
|
|
|
}
|
|
|
|
|
// j indices
|
|
|
|
|
std::vector<Value*> idx_j;
|
|
|
|
|
for(unsigned pack = 0; pack < num_packs_1_; pack++)
|
|
|
|
|
for(unsigned jj = 0; jj < pack_size_1_; jj++)
|
|
|
|
|
for(unsigned j = 0; j < 2; j++){
|
|
|
|
|
idx_j.push_back(builder.CreateAdd(offset_c_j, builder.getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_)));
|
|
|
|
|
idx_j.push_back(builder.CreateAdd(offset_c_j, builder.getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_ + 1)));
|
|
|
|
|
}
|
|
|
|
|
// z indices
|
|
|
|
|
std::vector<Value*> idx_z;
|
|
|
|
|
for(unsigned pack = 0; pack < num_rep_2; pack++)
|
|
|
|
|
idx_z.push_back(builder.CreateAdd(warp_id_2, builder.getInt32(pack*hmma_bts_2)));
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/* axes */
|
|
|
|
|
axes_[params_->get_param_group(v, 0)] = distributed_axis{1, idx_i, warp_id_0};
|
|
|
|
|
axes_[params_->get_param_group(v, 1)] = distributed_axis{1, idx_j, warp_id_1};
|
|
|
|
|
if(is_batched)
|
|
|
|
|
axes_[params_->get_param_group(v, 2)] = distributed_axis{1, idx_z, warp_id_2};
|
|
|
|
|
axes_[params_->get_param_group(v, k)] = distributed_axis{contiguous[k], idx_list, thread_id};
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Builds the distributed axes and the per-thread A/B operand offsets for a
// tile that uses the HMMA layout (the branch selection::init_axes() takes
// when fragment 0 of `v` is not STRIDED_SCAN).
//
// v            tile value whose axes are initialised; must have 2 or 3 dims
// builder      IR builder used to emit the index arithmetic
// u_thread_id  i32 thread index the intra-warp offsets are derived from
//              (presumably the lane index within the warp -- confirm with caller)
// u_warp_id    i32 linear warp index the inter-warp offsets are derived from
//
// Side effects: overwrites pack_size_{0,1}_, num_packs_{0,1}_, offset_a_i_,
// offset_a_k_, offset_b_j_, offset_b_k_, and the axes_ entries of v's
// parameter groups.
//
// Throws std::runtime_error if the tile rank is not 2 or 3, if a block tile
// size evaluates to zero, or if the tile is smaller than one block tile
// along dimension 0 or 1 (each of which would otherwise read out of bounds
// or divide by zero below).
void selection::init_hmma_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
  const auto& shapes = v->get_type()->get_tile_shapes();
  // shapes[0] and shapes[1] are read unconditionally below
  if(shapes.size() < 2 || shapes.size() > 3)
    throw std::runtime_error("unsupported");
  bool is_batched = shapes.size() >= 3;

  Value *_1 = builder.getInt32(1);
  Value *_2 = builder.getInt32(2);
  Value *_3 = builder.getInt32(3);
  Value *_4 = builder.getInt32(4);
  Value *_16 = builder.getInt32(16);

  // fragments per warp
  unsigned fpw_0 = params_->fpw(v, 0);
  unsigned fpw_1 = params_->fpw(v, 1);
  unsigned fpw_2 = is_batched ? params_->fpw(v, 2) : 1;
  // warps per tile
  unsigned wpt_0 = params_->wpt(v, 0);
  unsigned wpt_1 = params_->wpt(v, 1);
  unsigned wpt_2 = is_batched ? params_->wpt(v, 2) : 1;
  // hmma warp tile size
  unsigned hmma_wts_0 = fpw_0 * 8;
  unsigned hmma_wts_1 = fpw_1 * 8;
  unsigned hmma_wts_2 = is_batched ? fpw_2 : 1;
  // hmma block tile size
  unsigned hmma_bts_0 = hmma_wts_0 * wpt_0;
  unsigned hmma_bts_1 = hmma_wts_1 * wpt_1;
  unsigned hmma_bts_2 = is_batched ? hmma_wts_2 * wpt_2 : 1;
  // a zero fpw/wpt parameter would make the divisions below undefined
  if(hmma_bts_0 == 0 || hmma_bts_1 == 0 || hmma_bts_2 == 0)
    throw std::runtime_error("hmma block tile size must be non-zero");
  // number of repetition
  unsigned num_rep_0 = shapes[0] / hmma_bts_0;
  unsigned num_rep_1 = shapes[1] / hmma_bts_1;
  unsigned num_rep_2 = is_batched ? shapes[2] / hmma_bts_2 : 1;
  // num_rep == 0 would give pack_size == 0 and a division by zero below
  if(num_rep_0 == 0 || num_rep_1 == 0)
    throw std::runtime_error("tile shape is smaller than the hmma block tile");
  // size of each pack (interleaving)
  pack_size_0_ = std::min<unsigned>(num_rep_0, 1);
  pack_size_1_ = std::min<unsigned>(num_rep_1, 1);
  // number of packs (interleaving)
  num_packs_0_ = num_rep_0 / pack_size_0_;
  num_packs_1_ = num_rep_1 / pack_size_1_;

  /* intra warp offset */
  // offset of quad in pair: (thread_id & 16) / 4 is 0 or 4
  Value *in_pair_off_a = builder.CreateMul(builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), builder.getInt32(4)),
                                           builder.getInt32(fpw_0 * pack_size_0_));
  Value *in_pair_off_b = builder.CreateMul(builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), builder.getInt32(4)),
                                           builder.getInt32(fpw_1 * pack_size_1_));

  // Quad pair id: (thread_id % 16) / 4, split into a part along 0 and a part along 1
  Value *pair_a_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4);
  Value *pair_b_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4);
  pair_a_id = builder.CreateURem(pair_a_id, builder.getInt32(fpw_0));
  pair_b_id = builder.CreateUDiv(pair_b_id, builder.getInt32(fpw_0));
  pair_b_id = builder.CreateURem(pair_b_id, builder.getInt32(fpw_1));
  // Quad pair offset
  Value *pair_a_off = builder.CreateMul(pair_a_id, builder.getInt32(4 * pack_size_0_));
  Value *pair_b_off = builder.CreateMul(pair_b_id, builder.getInt32(4 * pack_size_1_));

  /* inter warp offset */
  // delinearize the warp index over (wpt_0, wpt_1, rest)
  Value *warp_id_0 = builder.CreateURem(u_warp_id, builder.getInt32(wpt_0));
  Value *warp_id_12 = builder.CreateUDiv(u_warp_id, builder.getInt32(wpt_0));
  Value *warp_id_1 = builder.CreateURem(warp_id_12, builder.getInt32(wpt_1));
  Value *warp_id_2 = builder.CreateUDiv(warp_id_12, builder.getInt32(wpt_1));
  Value *warp_offset_i = builder.CreateMul(warp_id_0, builder.getInt32(hmma_wts_0 * pack_size_0_));
  Value *warp_offset_j = builder.CreateMul(warp_id_1, builder.getInt32(hmma_wts_1 * pack_size_1_));

  /* offsets */
  // a offset
  offset_a_i_ = builder.CreateAdd(warp_offset_i, builder.CreateAdd(pair_a_off, in_pair_off_a));
  offset_a_k_ = builder.CreateAnd(u_thread_id, _3);
  // b offsets
  offset_b_j_ = builder.CreateAdd(warp_offset_j, builder.CreateAdd(pair_b_off, in_pair_off_b));
  offset_b_k_ = builder.CreateAnd(u_thread_id, _3);

  // c offsets
  Value *offset_c_i = builder.CreateAdd(builder.CreateAnd(u_thread_id, _1), offset_a_i_);
  Value *offset_c_j = builder.CreateAdd(builder.CreateAnd(u_thread_id, _2),
                                        builder.CreateAdd(warp_offset_j, pair_b_off));

  /* indices */
  // i indices: two values per (pack, ii), 2 apart
  std::vector<Value*> idx_i;
  for(unsigned pack = 0; pack < num_packs_0_; pack++)
  for(unsigned ii = 0; ii < pack_size_0_; ii++)
  for(unsigned i = 0; i < 2; i++){
    idx_i.push_back(builder.CreateAdd(offset_c_i, builder.getInt32(pack*hmma_bts_0*pack_size_0_ + ii*4 + i*2)));
  }
  // j indices: two adjacent values per (pack, jj, j)
  std::vector<Value*> idx_j;
  for(unsigned pack = 0; pack < num_packs_1_; pack++)
  for(unsigned jj = 0; jj < pack_size_1_; jj++)
  for(unsigned j = 0; j < 2; j++){
    idx_j.push_back(builder.CreateAdd(offset_c_j, builder.getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_)));
    idx_j.push_back(builder.CreateAdd(offset_c_j, builder.getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_ + 1)));
  }
  // z indices: one value per repetition along the batch dimension
  std::vector<Value*> idx_z;
  for(unsigned pack = 0; pack < num_rep_2; pack++)
    idx_z.push_back(builder.CreateAdd(warp_id_2, builder.getInt32(pack*hmma_bts_2)));

  /* axes */
  axes_[params_->get_param_group(v, 0)] = distributed_axis{1, idx_i, warp_id_0};
  axes_[params_->get_param_group(v, 1)] = distributed_axis{1, idx_j, warp_id_1};
  if(is_batched)
    axes_[params_->get_param_group(v, 2)] = distributed_axis{1, idx_z, warp_id_2};
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Initialises the axes of tile `v` by dispatching on its fragment kind:
// STRIDED_SCAN tiles go through init_strided_scan_axes(), every other kind
// goes through init_hmma_axes(). All arguments are forwarded unchanged.
void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
  const bool is_strided_scan = params_->get_fragment(v, 0) == analysis::grids::STRIDED_SCAN;
  if(!is_strided_scan) {
    init_hmma_axes(v, builder, u_thread_id, u_warp_id);
    return;
  }
  init_strided_scan_axes(v, builder, u_thread_id, u_warp_id);
}
|
|
|
|
|
|
|
|
|
|
bool static inline has_phi_user(ir::value *v) {
|
|
|
|
|
for(ir::user *usr: v->get_users()){
|
|
|
|
|
if(dynamic_cast<ir::phi_node*>(usr))
|
|
|
|
@@ -717,94 +726,97 @@ bool static inline has_phi_user(ir::value *v) {
|
|
|
|
|
}
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Creates the shared_tile object(s) for tile value `v` and records them in tmap_.
//
// v           tile value to materialise
// builder     IR builder used to emit the address computations
// sh_mem_ptr  base pointer the allocation offsets are applied to; its address
//             space is reused for the tile pointer type
//
// Three cases:
//  * v is a phi node (double-buffering): a pointer PHI and an i32 offset PHI
//    are created at the top of the phi's block, the phi itself is mapped to a
//    tile over the pointer PHI, and each incoming value is mapped to a tile
//    over either `next_ptr` (loop-latch edge) or `pre_ptr` (any other edge).
//  * v is not a phi and has no phi user: mapped to a tile at its allocated offset.
//  * v is not a phi but has a phi user: nothing is created here.
//
// NOTE(review): the PHIs are created with no incoming values in this function;
// presumably they are filled in elsewhere -- confirm.
// NOTE(review): tiles are raw `new` allocations; if tmap_.insert() finds an
// existing key the freshly allocated tile is not released here.
void selection::create_shared_tile(ir::value *v, IRBuilder<> &builder, Value *sh_mem_ptr) {
  auto shapes = v->get_type()->get_tile_shapes();
  // the leading dimension may be padded by the allocator
  unsigned pad = alloc_->is_ld_padded(v);
  if(pad > 0)
    shapes[0] += pad;
  Type* ty = llvm_type(v->get_type()->get_scalar_ty(), builder.getContext());
  // shared copy
  PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr->getType()->getPointerAddressSpace());
  // phi-node (double-buffering)
  if(auto *phi = dynamic_cast<ir::phi_node*>(v)) {
    BasicBlock *parent = (BasicBlock*)vmap_[phi->get_parent()];
    // PHIs go at the first legal insertion point of the phi's block
    if(parent->empty())
      builder.SetInsertPoint(parent);
    else
      builder.SetInsertPoint(&*parent->getFirstInsertionPt());
    PHINode *ptr = builder.CreatePHI(ptr_ty, 2);
    PHINode *offset = builder.CreatePHI(builder.getInt32Ty(), 2);
    // next pointer
    Value *pre_ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(alloc_->get_offset(phi)));
    pre_ptr = builder.CreateBitCast(pre_ptr, ptr->getType());
    Value *next_ptr = builder.CreateGEP(ptr, offset, "next_ptr");
    tmap_.insert({phi, new shared_tile(ty, shapes, ptr, builder, offset)});
    for(unsigned i = 0; i < phi->get_num_incoming(); i++) {
      ir::basic_block* inc_block = phi->get_incoming_block(i);
      ir::value* inc_value = phi->get_incoming_value(i);
      ir::instruction* terminator = inc_block->get_inst_list().back();
      // values arriving over the loop latch live in the "next" buffer
      bool is_loop_latch = buffer_info_->is_loop_latch(phi, terminator);
      tmap_.insert({inc_value, new shared_tile(ty, shapes, is_loop_latch?next_ptr:pre_ptr, builder)});
    }
    return;
  }
  // values feeding a phi get their tile when that phi is processed (see above)
  if(has_phi_user(v))
    return;
  size_t offset = alloc_->get_offset(v);
  Value *ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(offset));
  ptr = builder.CreateBitCast(ptr, ptr_ty);
  tmap_.insert({v, new shared_tile(ty, shapes, ptr, builder)});
}
|
|
|
|
|
|
|
|
|
|
// Creates the distributed_tile for tile value `v` and records it in tmap_.
//
// Each dimension of extent > 1 reuses the axis registered in axes_ for the
// parameter group of (v, d); dimensions of extent <= 1 get a trivial axis
// holding the single index 0. When the tile was newly inserted and `v` is a
// make_range / make_range_sta, its values are filled in from the indices.
//
// NOTE(review): if tmap_ already holds `v`, the freshly allocated tile is not
// released here.
void selection::create_distributed_tile(ir::value *v, IRBuilder<> &builder) {
  Type* scalar_ty = llvm_type(v->get_type()->get_scalar_ty(), builder.getContext());
  const auto &shapes = v->get_type()->get_tile_shapes();
  const size_t rank = shapes.size();
  std::vector<distributed_axis> axes(rank);
  for(size_t d = 0; d < rank; d++){
    if(shapes[d] <= 1){
      // degenerate dimension: a single element at index 0
      axes[d].contiguous = 1;
      axes[d].values = {builder.getInt32(0)};
      continue;
    }
    unsigned group = params_->get_param_group(v, d);
    axes[d] = axes_.at(group);
  }
  bool vectorize = dynamic_cast<ir::vectorize_inst*>(v) != nullptr;
  distributed_tile *T = new distributed_tile(scalar_ty, shapes, axes, builder, vectorize);
  const bool is_inserted = tmap_.insert({v, T}).second;
  if(!is_inserted)
    return;
  // constant range: every element equals its own index
  if(dynamic_cast<ir::make_range*>(v)){
    T->for_each([&](indices_t idx){
      assert(idx.size() == 1);
      T->set_value(idx, idx[0]);
    });
  }
  // static range: every element is operand 1 of the add that produced its
  // index, which is asserted to be a constant
  if(dynamic_cast<ir::make_range_sta*>(v)){
    T->for_each([&](indices_t idx){
      assert(idx.size() == 1);
      BinaryOperator *add_inst = dyn_cast<BinaryOperator>(idx[0]);
      assert(add_inst);
      Value *static_part = add_inst->getOperand(1);
      assert(isa<Constant>(static_part));
      T->set_value(idx, static_part);
    });
  }
}
|
|
|
|
|
|
|
|
|
|
// Creates the tile object for `v` after first creating the tiles of all of
// its operands (depth-first).
//
// v           value to materialise; ignored unless its type is a tile type
// builder     IR builder forwarded to the tile constructors
// seen        values already visited; `v` is added on first visit so that each
//             value is processed at most once
// sh_mem_ptr  base pointer forwarded to create_shared_tile()
//
// Values the buffer analysis marks as shared (except reduce instructions)
// become shared tiles; everything else becomes a distributed tile.
//
// The tile is built exactly once, through the dedicated helpers: the previous
// text both inlined the helper bodies and called the helpers, and nested the
// operand loop inside an identical loop.
void selection::create_tile(ir::value *v, IRBuilder<> &builder,
                            std::set<ir::value*> &seen, Value *sh_mem_ptr) {
  if(!v->get_type()->is_tile_ty() || !seen.insert(v).second)
    return;
  // operands first, so their tiles exist before v's tile is built
  if(auto *user = dynamic_cast<ir::user*>(v))
    for(ir::value *op: user->ops())
      create_tile(op, builder, seen, sh_mem_ptr);
  if(buffer_info_->is_shared(v) && !dynamic_cast<ir::reduce_inst*>(v))
    create_shared_tile(v, builder, sh_mem_ptr);
  else
    create_distributed_tile(v, builder);
}
|
|
|
|
|
|
|
|
|
|
void selection::init_grids(ir::function *fn, IRBuilder<> &builder, Value *sh_mem_ptr){
|
|
|
|
@@ -908,7 +920,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn,
|
|
|
|
|
tgt_->add_barrier(module, builder);
|
|
|
|
|
builder.CreateStore(result, write_ptr);
|
|
|
|
|
// build result
|
|
|
|
|
unsigned depth = params_->get_wpt(op, axis);
|
|
|
|
|
unsigned depth = params_->wpt(op, axis);
|
|
|
|
|
for(unsigned i = depth/2; i > 0; i >>= 1){
|
|
|
|
|
// current indices
|
|
|
|
|
indices_t current(write_idx.size(), builder.getInt32(0));
|
|
|
|
@@ -1075,12 +1087,12 @@ void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn
|
|
|
|
|
"{$10, $11}, "
|
|
|
|
|
"{$0, $1, $2, $3, $4, $5, $6, $7};", "=f,=f,=f,=f,=f,=f,=f,=f,r,r,r,r,0,1,2,3,4,5,6,7", false);
|
|
|
|
|
|
|
|
|
|
unsigned fpw_0 = params_->get_fpw(dot, 0);
|
|
|
|
|
unsigned fpw_1 = params_->get_fpw(dot, 1);
|
|
|
|
|
unsigned fpw_0 = params_->fpw(dot, 0);
|
|
|
|
|
unsigned fpw_1 = params_->fpw(dot, 1);
|
|
|
|
|
unsigned wts_0 = fpw_0 * 8;
|
|
|
|
|
unsigned wts_1 = fpw_1 * 8;
|
|
|
|
|
unsigned wpt_0 = params_->get_wpt(dot, 0);
|
|
|
|
|
unsigned wpt_1 = params_->get_wpt(dot, 1);
|
|
|
|
|
unsigned wpt_0 = params_->wpt(dot, 0);
|
|
|
|
|
unsigned wpt_1 = params_->wpt(dot, 1);
|
|
|
|
|
unsigned stride_rep_i = wpt_0 * wts_0;
|
|
|
|
|
unsigned stride_rep_j = wpt_1 * wts_1;
|
|
|
|
|
unsigned num_rep_i = shapes[0] / stride_rep_i;
|
|
|
|
@@ -1457,7 +1469,7 @@ void selection::run(ir::module &src, Module &dst) {
|
|
|
|
|
Metadata *md_args[] = {
|
|
|
|
|
ValueAsMetadata::get(dst_fn),
|
|
|
|
|
MDString::get(dst_ctx, "maxntidx"),
|
|
|
|
|
ValueAsMetadata::get(dst_builder.getInt32(params_->get_num_threads()))
|
|
|
|
|
ValueAsMetadata::get(dst_builder.getInt32(num_warps_*32))
|
|
|
|
|
};
|
|
|
|
|
dst.getOrInsertNamedMetadata("nvvm.annotations")->addOperand(MDNode::get(dst_ctx, md_args));
|
|
|
|
|
|
|
|
|
|