[codegen] more cleaning

This commit is contained in:
Philippe Tillet
2019-10-10 15:52:03 -04:00
parent a3f76b6eb1
commit 4efd0a3c6b
9 changed files with 148 additions and 144 deletions

View File

@@ -559,7 +559,7 @@ Value* selection::llvm_value(ir::value *v, IRBuilder<> &builder) {
* ------------------- */
// Grid construction
std::vector<Value*> delinearize(Value *trailing, const std::vector<int>& order, std::vector<unsigned> &shapes, IRBuilder<> &builder){
std::vector<Value*> delinearize(Value *trailing, const std::vector<int>& order, std::vector<int> &shapes, IRBuilder<> &builder){
size_t dim = shapes.size();
std::vector<Value*> result(dim);
for(unsigned k = 0; k < dim - 1; k++){
@@ -580,12 +580,8 @@ void selection::init_strided_scan_axes(const analysis::layout_t& layout, IRBuild
auto order = layout.order;
const auto& shapes = layout.shapes;
size_t dim = shapes.size();
std::vector<unsigned> nts(dim);
std::vector<unsigned> mts(dim);
for(unsigned i = 0; i < shapes.size(); i++){
nts[i] = layout.nts.at(i);
mts[i] = layout.mts.at(i);
}
std::vector<int> nts = layout.nts;
std::vector<int> mts = layout.mts;
Value* full_thread_id = builder.CreateAdd(builder.CreateMul(u_warp_id, builder.getInt32(32)), u_thread_id);
std::vector<Value*> thread_id = delinearize(full_thread_id, order, mts, builder);
// Create axes
@@ -608,6 +604,7 @@ void selection::init_hmma_axes(const analysis::layout_t& layout, IRBuilder<> &bu
const auto& shapes = layout.shapes;
if(shapes.size() > 3)
throw std::runtime_error("unsupported");
bool is_batched = shapes.size() >= 3;
Value *_1 = builder.getInt32(1);
@@ -725,7 +722,7 @@ void selection::init_axes(const analysis::layout_t& layout, IRBuilder<> &builder
void selection::create_shared_tile(ir::value *v, IRBuilder<> &builder, Value *sh_mem_ptr) {
if(tmap_.find(v) != tmap_.end())
return;
auto order = layouts_->get(v).order;
auto order = layouts_->get(v)->order;
auto shapes = v->get_type()->get_tile_shapes();
unsigned pad = liveness_->get_pad(v);
if(pad > 0)
@@ -775,7 +772,7 @@ void selection::create_distributed_tile(ir::value *v, IRBuilder<> &builder) {
axes[d].values = {builder.getInt32(0)};
}
}
distributed_tile *T = new distributed_tile(ty, shapes, layouts_->get(v).order, axes, builder, false);
distributed_tile *T = new distributed_tile(ty, shapes, layouts_->get(v)->order, axes, builder, false);
bool is_inserted = tmap_.insert({v, T}).second;
// constant range
if(is_inserted && dynamic_cast<ir::make_range*>(v)){
@@ -819,7 +816,7 @@ void selection::init_layouts(ir::function *fn, IRBuilder<> &builder, Value *sh_m
Value *u_warp_id = builder.CreateUDiv(u_thread_id, warp_size);
// create grid
for(auto x: layouts_->get_all())
init_axes(x.second, builder, u_thread_warp_id, u_warp_id);
init_axes(*x.second, builder, u_thread_warp_id, u_warp_id);
// create tile
std::set<ir::value*> seen;
for(ir::basic_block *block: fn->blocks())
@@ -932,7 +929,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn,
tgt_->add_barrier(module, builder);
builder.CreateStore(result, write_ptr);
// build result
unsigned depth = layouts_->get(op).wpt.at(axis);
unsigned depth = layouts_->get(op)->wpt.at(axis);
for(unsigned i = depth/2; i > 0; i >>= 1){
// current indices
indices_t current(write_idx.size(), builder.getInt32(0));
@@ -1013,15 +1010,15 @@ void selection::lower_broadcast(ir::broadcast_inst *x, LLVMContext &ctx, Functio
void selection::lower_copy_to_shared(ir::copy_to_shared_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
unsigned vector_size = 1;
auto x_order = layouts_->get(x).order;
auto x_order = layouts_->get(x)->order;
ir::value *arg = x->get_operand(0);
auto arg_order = layouts_->get(arg).order;
auto arg_order = layouts_->get(arg)->order;
// tiles
shared_tile* result = (shared_tile*)tmap_.at(x);
distributed_tile* in = (distributed_tile*)tmap_.at(arg);
if(x_order == arg_order){
size_t ld = arg_order[0];
vector_size = std::min(layouts_->get(x).nts.at(ld), layouts_->get(arg).nts.at(ld));
vector_size = std::min(layouts_->get(x)->nts.at(ld), layouts_->get(arg)->nts.at(ld));
}
std::map<unsigned, Value*> packets;
@@ -1090,8 +1087,8 @@ void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn
Value* u_thread_id = tgt_->get_local_id(builder.GetInsertBlock()->getModule(), builder, 0);
auto ord_a = layouts_->get(dot->get_operand(0)).order;
auto ord_b = layouts_->get(dot->get_operand(1)).order;
auto ord_a = layouts_->get(dot->get_operand(0))->order;
auto ord_b = layouts_->get(dot->get_operand(1))->order;
bool is_a_trans = is_trans(dot->get_operand(0));
bool is_b_trans = is_trans(dot->get_operand(1));
@@ -1117,12 +1114,12 @@ void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn
"{$10, $11}, "
"{$0, $1, $2, $3, $4, $5, $6, $7};", "=f,=f,=f,=f,=f,=f,=f,=f,r,r,r,r,0,1,2,3,4,5,6,7", false);
unsigned fpw_0 = layouts_->get(dot).fpw.at(0);
unsigned fpw_1 = layouts_->get(dot).fpw.at(1);
unsigned fpw_0 = layouts_->get(dot)->fpw.at(0);
unsigned fpw_1 = layouts_->get(dot)->fpw.at(1);
unsigned wts_0 = fpw_0 * 8;
unsigned wts_1 = fpw_1 * 8;
unsigned wpt_0 = layouts_->get(dot).wpt.at(0);
unsigned wpt_1 = layouts_->get(dot).wpt.at(1);
unsigned wpt_0 = layouts_->get(dot)->wpt.at(0);
unsigned wpt_1 = layouts_->get(dot)->wpt.at(1);
unsigned stride_rep_i = wpt_0 * wts_0;
unsigned stride_rep_j = wpt_1 * wts_1;
unsigned num_rep_i = shapes[0] / stride_rep_i;
@@ -1253,7 +1250,7 @@ void selection::lower_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn, IRB
if(NK != 1) {
shared_tile *TA = (shared_tile*)tmap_.at(A);
shared_tile *TB = (shared_tile*)tmap_.at(B);
if(layouts_->get(dot).type == analysis::HMMA_884)
if(layouts_->get(dot)->type == analysis::HMMA_884)
lower_hmma_dot(dot, ctx, fn, builder, TC, TA, TB, TD, NK);
else
lower_scanline_dot(dot, ctx, fn, builder, TC, TA, TB, TD, NK, c_ty, f_mul_add);
@@ -1269,7 +1266,7 @@ void selection::lower_masked_load(ir::masked_load_inst *x, LLVMContext &ctx, Fun
// find vector size
distributed_tile* result = (distributed_tile*)tmap_.at(x);
ir::value *ptr = x->get_pointer_operand();
size_t ld = layouts_->get(ptr).order[0];
size_t ld = layouts_->get(ptr)->order[0];
unsigned alignment = alignment_->get(ptr, ld);
unsigned vector_size = std::min<unsigned>(result->axis(ld).contiguous, alignment);
distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);
@@ -1341,7 +1338,7 @@ void selection::lower_load(ir::load_inst *x, LLVMContext &ctx, Function *fn, IRB
distributed_tile* result = (distributed_tile*)tmap_.at(x);
// find vector size
ir::value *ptr = x->get_pointer_operand();
size_t ld = layouts_->get(ptr).order[0];
size_t ld = layouts_->get(ptr)->order[0];
unsigned alignment = alignment_->get(ptr, ld);
unsigned vector_size = std::min<unsigned>(result->axis(ld).contiguous, alignment);
distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);