[codegen] more cleaning
This commit is contained in:
@@ -559,7 +559,7 @@ Value* selection::llvm_value(ir::value *v, IRBuilder<> &builder) {
|
||||
* ------------------- */
|
||||
|
||||
// Grid construction
|
||||
std::vector<Value*> delinearize(Value *trailing, const std::vector<int>& order, std::vector<unsigned> &shapes, IRBuilder<> &builder){
|
||||
std::vector<Value*> delinearize(Value *trailing, const std::vector<int>& order, std::vector<int> &shapes, IRBuilder<> &builder){
|
||||
size_t dim = shapes.size();
|
||||
std::vector<Value*> result(dim);
|
||||
for(unsigned k = 0; k < dim - 1; k++){
|
||||
@@ -580,12 +580,8 @@ void selection::init_strided_scan_axes(const analysis::layout_t& layout, IRBuild
|
||||
auto order = layout.order;
|
||||
const auto& shapes = layout.shapes;
|
||||
size_t dim = shapes.size();
|
||||
std::vector<unsigned> nts(dim);
|
||||
std::vector<unsigned> mts(dim);
|
||||
for(unsigned i = 0; i < shapes.size(); i++){
|
||||
nts[i] = layout.nts.at(i);
|
||||
mts[i] = layout.mts.at(i);
|
||||
}
|
||||
std::vector<int> nts = layout.nts;
|
||||
std::vector<int> mts = layout.mts;
|
||||
Value* full_thread_id = builder.CreateAdd(builder.CreateMul(u_warp_id, builder.getInt32(32)), u_thread_id);
|
||||
std::vector<Value*> thread_id = delinearize(full_thread_id, order, mts, builder);
|
||||
// Create axes
|
||||
@@ -608,6 +604,7 @@ void selection::init_hmma_axes(const analysis::layout_t& layout, IRBuilder<> &bu
|
||||
const auto& shapes = layout.shapes;
|
||||
if(shapes.size() > 3)
|
||||
throw std::runtime_error("unsupported");
|
||||
|
||||
bool is_batched = shapes.size() >= 3;
|
||||
|
||||
Value *_1 = builder.getInt32(1);
|
||||
@@ -725,7 +722,7 @@ void selection::init_axes(const analysis::layout_t& layout, IRBuilder<> &builder
|
||||
void selection::create_shared_tile(ir::value *v, IRBuilder<> &builder, Value *sh_mem_ptr) {
|
||||
if(tmap_.find(v) != tmap_.end())
|
||||
return;
|
||||
auto order = layouts_->get(v).order;
|
||||
auto order = layouts_->get(v)->order;
|
||||
auto shapes = v->get_type()->get_tile_shapes();
|
||||
unsigned pad = liveness_->get_pad(v);
|
||||
if(pad > 0)
|
||||
@@ -775,7 +772,7 @@ void selection::create_distributed_tile(ir::value *v, IRBuilder<> &builder) {
|
||||
axes[d].values = {builder.getInt32(0)};
|
||||
}
|
||||
}
|
||||
distributed_tile *T = new distributed_tile(ty, shapes, layouts_->get(v).order, axes, builder, false);
|
||||
distributed_tile *T = new distributed_tile(ty, shapes, layouts_->get(v)->order, axes, builder, false);
|
||||
bool is_inserted = tmap_.insert({v, T}).second;
|
||||
// constant range
|
||||
if(is_inserted && dynamic_cast<ir::make_range*>(v)){
|
||||
@@ -819,7 +816,7 @@ void selection::init_layouts(ir::function *fn, IRBuilder<> &builder, Value *sh_m
|
||||
Value *u_warp_id = builder.CreateUDiv(u_thread_id, warp_size);
|
||||
// create grid
|
||||
for(auto x: layouts_->get_all())
|
||||
init_axes(x.second, builder, u_thread_warp_id, u_warp_id);
|
||||
init_axes(*x.second, builder, u_thread_warp_id, u_warp_id);
|
||||
// create tile
|
||||
std::set<ir::value*> seen;
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
@@ -932,7 +929,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn,
|
||||
tgt_->add_barrier(module, builder);
|
||||
builder.CreateStore(result, write_ptr);
|
||||
// build result
|
||||
unsigned depth = layouts_->get(op).wpt.at(axis);
|
||||
unsigned depth = layouts_->get(op)->wpt.at(axis);
|
||||
for(unsigned i = depth/2; i > 0; i >>= 1){
|
||||
// current indices
|
||||
indices_t current(write_idx.size(), builder.getInt32(0));
|
||||
@@ -1013,15 +1010,15 @@ void selection::lower_broadcast(ir::broadcast_inst *x, LLVMContext &ctx, Functio
|
||||
|
||||
void selection::lower_copy_to_shared(ir::copy_to_shared_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
|
||||
unsigned vector_size = 1;
|
||||
auto x_order = layouts_->get(x).order;
|
||||
auto x_order = layouts_->get(x)->order;
|
||||
ir::value *arg = x->get_operand(0);
|
||||
auto arg_order = layouts_->get(arg).order;
|
||||
auto arg_order = layouts_->get(arg)->order;
|
||||
// tiles
|
||||
shared_tile* result = (shared_tile*)tmap_.at(x);
|
||||
distributed_tile* in = (distributed_tile*)tmap_.at(arg);
|
||||
if(x_order == arg_order){
|
||||
size_t ld = arg_order[0];
|
||||
vector_size = std::min(layouts_->get(x).nts.at(ld), layouts_->get(arg).nts.at(ld));
|
||||
vector_size = std::min(layouts_->get(x)->nts.at(ld), layouts_->get(arg)->nts.at(ld));
|
||||
}
|
||||
|
||||
std::map<unsigned, Value*> packets;
|
||||
@@ -1090,8 +1087,8 @@ void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn
|
||||
|
||||
Value* u_thread_id = tgt_->get_local_id(builder.GetInsertBlock()->getModule(), builder, 0);
|
||||
|
||||
auto ord_a = layouts_->get(dot->get_operand(0)).order;
|
||||
auto ord_b = layouts_->get(dot->get_operand(1)).order;
|
||||
auto ord_a = layouts_->get(dot->get_operand(0))->order;
|
||||
auto ord_b = layouts_->get(dot->get_operand(1))->order;
|
||||
|
||||
bool is_a_trans = is_trans(dot->get_operand(0));
|
||||
bool is_b_trans = is_trans(dot->get_operand(1));
|
||||
@@ -1117,12 +1114,12 @@ void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn
|
||||
"{$10, $11}, "
|
||||
"{$0, $1, $2, $3, $4, $5, $6, $7};", "=f,=f,=f,=f,=f,=f,=f,=f,r,r,r,r,0,1,2,3,4,5,6,7", false);
|
||||
|
||||
unsigned fpw_0 = layouts_->get(dot).fpw.at(0);
|
||||
unsigned fpw_1 = layouts_->get(dot).fpw.at(1);
|
||||
unsigned fpw_0 = layouts_->get(dot)->fpw.at(0);
|
||||
unsigned fpw_1 = layouts_->get(dot)->fpw.at(1);
|
||||
unsigned wts_0 = fpw_0 * 8;
|
||||
unsigned wts_1 = fpw_1 * 8;
|
||||
unsigned wpt_0 = layouts_->get(dot).wpt.at(0);
|
||||
unsigned wpt_1 = layouts_->get(dot).wpt.at(1);
|
||||
unsigned wpt_0 = layouts_->get(dot)->wpt.at(0);
|
||||
unsigned wpt_1 = layouts_->get(dot)->wpt.at(1);
|
||||
unsigned stride_rep_i = wpt_0 * wts_0;
|
||||
unsigned stride_rep_j = wpt_1 * wts_1;
|
||||
unsigned num_rep_i = shapes[0] / stride_rep_i;
|
||||
@@ -1253,7 +1250,7 @@ void selection::lower_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn, IRB
|
||||
if(NK != 1) {
|
||||
shared_tile *TA = (shared_tile*)tmap_.at(A);
|
||||
shared_tile *TB = (shared_tile*)tmap_.at(B);
|
||||
if(layouts_->get(dot).type == analysis::HMMA_884)
|
||||
if(layouts_->get(dot)->type == analysis::HMMA_884)
|
||||
lower_hmma_dot(dot, ctx, fn, builder, TC, TA, TB, TD, NK);
|
||||
else
|
||||
lower_scanline_dot(dot, ctx, fn, builder, TC, TA, TB, TD, NK, c_ty, f_mul_add);
|
||||
@@ -1269,7 +1266,7 @@ void selection::lower_masked_load(ir::masked_load_inst *x, LLVMContext &ctx, Fun
|
||||
// find vector size
|
||||
distributed_tile* result = (distributed_tile*)tmap_.at(x);
|
||||
ir::value *ptr = x->get_pointer_operand();
|
||||
size_t ld = layouts_->get(ptr).order[0];
|
||||
size_t ld = layouts_->get(ptr)->order[0];
|
||||
unsigned alignment = alignment_->get(ptr, ld);
|
||||
unsigned vector_size = std::min<unsigned>(result->axis(ld).contiguous, alignment);
|
||||
distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);
|
||||
@@ -1341,7 +1338,7 @@ void selection::lower_load(ir::load_inst *x, LLVMContext &ctx, Function *fn, IRB
|
||||
distributed_tile* result = (distributed_tile*)tmap_.at(x);
|
||||
// find vector size
|
||||
ir::value *ptr = x->get_pointer_operand();
|
||||
size_t ld = layouts_->get(ptr).order[0];
|
||||
size_t ld = layouts_->get(ptr)->order[0];
|
||||
unsigned alignment = alignment_->get(ptr, ld);
|
||||
unsigned vector_size = std::min<unsigned>(result->axis(ld).contiguous, alignment);
|
||||
distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);
|
||||
|
Reference in New Issue
Block a user