[codegen] more cleaning

This commit is contained in:
Philippe Tillet
2019-10-09 15:05:44 -04:00
parent 10ab94d1c5
commit 9bc6df4fd1
10 changed files with 226 additions and 252 deletions

View File

@@ -577,37 +577,36 @@ inline int32_t ceil(int32_t num, int32_t div){
return (num + div - 1)/div;
}
void selection::init_strided_scan_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
auto order = tiles_->order(v);
const auto& shapes = v->get_type()->get_tile_shapes();
void selection::init_strided_scan_axes(const analysis::layout_t& layout, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
auto order = layout.order;
const auto& shapes = layout.shapes;
size_t dim = shapes.size();
std::vector<unsigned> contiguous(dim);
std::vector<unsigned> block_size(dim);
std::vector<unsigned> nts(dim);
std::vector<unsigned> mts(dim);
for(unsigned i = 0; i < shapes.size(); i++){
contiguous[i] = tiles_->nts(v, i);
block_size[i] = tiles_->mts(v, i);
nts[i] = tiles_->nts(layout.i, i);
mts[i] = tiles_->mts(layout.i, i);
}
Value* full_thread_id = builder.CreateAdd(builder.CreateMul(u_warp_id, builder.getInt32(32)), u_thread_id);
std::vector<Value*> thread_id = delinearize(full_thread_id, order, block_size, builder);
std::vector<Value*> thread_id = delinearize(full_thread_id, order, mts, builder);
// Create axes
for(unsigned k = 0; k < dim; k++) {
std::string str_k = std::to_string(k);
Value *contiguous_k = builder.getInt32(contiguous[k]);
Value *contiguous_k = builder.getInt32(nts[k]);
Value *scaled_thread_id = builder.CreateMul(thread_id[k], contiguous_k);
unsigned per_block = contiguous[k] * block_size[k];
unsigned per_thread = contiguous[k] * shapes[k] / per_block;
unsigned per_block = nts[k] * mts[k];
unsigned per_thread = nts[k] * shapes[k] / per_block;
std::vector<Value*> idx_list(per_thread);
for(unsigned n = 0 ; n < per_thread; n++){
unsigned offset = n / contiguous[k] * per_block + n % contiguous[k];
unsigned offset = n / nts[k] * per_block + n % nts[k];
idx_list[n] = builder.CreateAdd(scaled_thread_id, builder.getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
}
axes_[a_axes_->get(v, k)] = distributed_axis{contiguous[k], idx_list, thread_id[k]};
axes_[layout.axes[k]] = distributed_axis{nts[k], idx_list, thread_id[k]};
}
}
void selection::init_hmma_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
// auto order = reorder_->get_order(v);
const auto& shapes = v->get_type()->get_tile_shapes();
void selection::init_hmma_axes(const analysis::layout_t& layout, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
const auto& shapes = layout.shapes;
if(shapes.size() > 3)
throw std::runtime_error("unsupported");
bool is_batched = shapes.size() >= 3;
@@ -619,13 +618,13 @@ void selection::init_hmma_axes(ir::value *v, IRBuilder<> &builder, Value *u_thre
Value *_16 = builder.getInt32(16);
// fragments per warp
unsigned fpw_0 = tiles_->fpw(v, 0);
unsigned fpw_1 = tiles_->fpw(v, 1);
unsigned fpw_2 = is_batched ? tiles_->fpw(v, 2) : 1;
unsigned fpw_0 = tiles_->fpw(layout.i, 0);
unsigned fpw_1 = tiles_->fpw(layout.i, 1);
unsigned fpw_2 = is_batched ? tiles_->fpw(layout.i, 2) : 1;
// warps per tile
unsigned wpt_0 = tiles_->wpt(v, 0);
unsigned wpt_1 = tiles_->wpt(v, 1);
unsigned wpt_2 = is_batched ? tiles_->wpt(v, 2) : 1;
unsigned wpt_0 = tiles_->wpt(layout.i, 0);
unsigned wpt_1 = tiles_->wpt(layout.i, 1);
unsigned wpt_2 = is_batched ? tiles_->wpt(layout.i, 2) : 1;
// hmma warp tile size
unsigned hmma_wts_0 = fpw_0 * 8;
unsigned hmma_wts_1 = fpw_1 * 8;
@@ -706,18 +705,18 @@ void selection::init_hmma_axes(ir::value *v, IRBuilder<> &builder, Value *u_thre
/* axes */
axes_[a_axes_->get(v, 0)] = distributed_axis{1, idx_i, warp_id_0};
axes_[a_axes_->get(v, 1)] = distributed_axis{1, idx_j, warp_id_1};
axes_[layout.axes[0]] = distributed_axis{1, idx_i, warp_id_0};
axes_[layout.axes[1]] = distributed_axis{1, idx_j, warp_id_1};
if(is_batched)
axes_[a_axes_->get(v, 2)] = distributed_axis{1, idx_z, warp_id_2};
axes_[layout.axes[2]] = distributed_axis{1, idx_z, warp_id_2};
}
void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
if(tiles_->hmma(v) == analysis::HMMA_C)
init_hmma_axes(v, builder, u_thread_id, u_warp_id);
void selection::init_axes(const analysis::layout_t& layout, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
if(layout.type == analysis::HMMA_884)
init_hmma_axes(layout, builder, u_thread_id, u_warp_id);
else
init_strided_scan_axes(v, builder, u_thread_id, u_warp_id);
init_strided_scan_axes(layout, builder, u_thread_id, u_warp_id);
}
/* -------------------
@@ -727,7 +726,7 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
void selection::create_shared_tile(ir::value *v, IRBuilder<> &builder, Value *sh_mem_ptr) {
if(tmap_.find(v) != tmap_.end())
return;
auto order = tiles_->order(v);
auto order = layouts_->get(v).order;
auto shapes = v->get_type()->get_tile_shapes();
unsigned pad = liveness_->get_pad(v);
if(pad > 0)
@@ -777,7 +776,7 @@ void selection::create_distributed_tile(ir::value *v, IRBuilder<> &builder) {
axes[d].values = {builder.getInt32(0)};
}
}
distributed_tile *T = new distributed_tile(ty, shapes, tiles_->order(v), axes, builder, false);
distributed_tile *T = new distributed_tile(ty, shapes, layouts_->get(v).order, axes, builder, false);
bool is_inserted = tmap_.insert({v, T}).second;
// constant range
if(is_inserted && dynamic_cast<ir::make_range*>(v)){
@@ -820,7 +819,7 @@ void selection::init_layouts(ir::function *fn, IRBuilder<> &builder, Value *sh_m
Value *u_thread_warp_id = builder.CreateURem(u_thread_id, warp_size);
Value *u_warp_id = builder.CreateUDiv(u_thread_id, warp_size);
// create grid
for(auto x: tiles_->largest())
for(auto x: layouts_->get_all())
init_axes(x.second, builder, u_thread_warp_id, u_warp_id);
// create tile
std::set<ir::value*> seen;
@@ -868,7 +867,7 @@ void selection::lower_masked_store(ir::masked_store_inst *x, LLVMContext &ctx, F
void selection::lower_store(ir::store_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
distributed_tile* ptrs = (distributed_tile*)tmap_.at(x->get_pointer_operand());
tile *scalars = tmap_.at(x->get_value_operand());
// size_t ld = tiles_->order(x->get_pointer_operand())[0];
// size_t ld = layouts_->order(x->get_pointer_operand())[0];
// unsigned vector_size = 2;
// // vectorize pointers
// std::map<unsigned, Value*> ptr_packets;
@@ -1015,9 +1014,9 @@ void selection::lower_broadcast(ir::broadcast_inst *x, LLVMContext &ctx, Functio
void selection::lower_copy_to_shared(ir::copy_to_shared_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
unsigned vector_size = 1;
auto x_order = tiles_->order(x);
auto x_order = layouts_->get(x).order;
ir::value *arg = x->get_operand(0);
auto arg_order = tiles_->order(arg);
auto arg_order = layouts_->get(arg).order;
// tiles
shared_tile* result = (shared_tile*)tmap_.at(x);
distributed_tile* in = (distributed_tile*)tmap_.at(arg);
@@ -1092,8 +1091,8 @@ void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn
Value* u_thread_id = tgt_->get_local_id(builder.GetInsertBlock()->getModule(), builder, 0);
auto ord_a = tiles_->order(dot->get_operand(0));
auto ord_b = tiles_->order(dot->get_operand(1));
auto ord_a = layouts_->get(dot->get_operand(0)).order;
auto ord_b = layouts_->get(dot->get_operand(1)).order;
bool is_a_trans = is_trans(dot->get_operand(0));
bool is_b_trans = is_trans(dot->get_operand(1));
@@ -1255,7 +1254,7 @@ void selection::lower_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn, IRB
if(NK != 1) {
shared_tile *TA = (shared_tile*)tmap_.at(A);
shared_tile *TB = (shared_tile*)tmap_.at(B);
if(tiles_->hmma(dot) == analysis::HMMA_C)
if(layouts_->get(dot).type == analysis::HMMA_884)
lower_hmma_dot(dot, ctx, fn, builder, TC, TA, TB, TD, NK);
else
lower_scanline_dot(dot, ctx, fn, builder, TC, TA, TB, TD, NK, c_ty, f_mul_add);
@@ -1271,7 +1270,7 @@ void selection::lower_masked_load(ir::masked_load_inst *x, LLVMContext &ctx, Fun
// find vector size
distributed_tile* result = (distributed_tile*)tmap_.at(x);
ir::value *ptr = x->get_pointer_operand();
size_t ld = tiles_->order(ptr)[0];
size_t ld = layouts_->get(ptr).order[0];
unsigned alignment = alignment_->get(ptr, ld);
unsigned vector_size = std::min<unsigned>(result->axis(ld).contiguous, alignment);
distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);
@@ -1343,7 +1342,7 @@ void selection::lower_load(ir::load_inst *x, LLVMContext &ctx, Function *fn, IRB
distributed_tile* result = (distributed_tile*)tmap_.at(x);
// find vector size
ir::value *ptr = x->get_pointer_operand();
size_t ld = tiles_->order(ptr)[0];
size_t ld = layouts_->get(ptr).order[0];
unsigned alignment = alignment_->get(ptr, ld);
unsigned vector_size = std::min<unsigned>(result->axis(ld).contiguous, alignment);
distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);