[codegen/reassociation] now recursively takes pointer arguments into account as well

This commit is contained in:
Philippe Tillet
2019-07-31 18:41:56 -07:00
parent f7bd976fc7
commit 3b92ddf7e6
5 changed files with 24 additions and 8 deletions

View File

@@ -215,6 +215,21 @@ void reassociate::run(ir::module &mod) {
infos[sta_ptr].dyn_ptr = (ir::getelementptr_inst*)dyn_ptr;
infos[sta_ptr].sta_ptr = (ir::getelementptr_inst*)sta_ptr;
}
// reassociate pointer argument
if(ir::getelementptr_inst* gepy = dynamic_cast<ir::getelementptr_inst*>(py))
if(infos.find(gepy) != infos.end()){
builder.set_insert_point(pz);
ir::getelementptr_inst *sta = infos[gepy].sta_ptr;
ir::getelementptr_inst *dyn = infos[gepy].dyn_ptr;
ir::value *cst = *sta->idx_begin();
ir::value *off = *pz->idx_begin();
ir::value *new_dyn = builder.create_gep(dyn, {off});
ir::value *new_pz = builder.create_gep(new_dyn, {cst}, pz->get_name());
params_->copy(new_dyn, pz);
params_->copy(new_pz, pz);
align_->copy(new_pz, pz);
pz->replace_all_uses_with(new_pz);
}
// reassociate phi-node pointer
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(py)){
// only optimize the case where py = phi pa, pz for now

View File

@@ -228,7 +228,7 @@ void tune::run(ir::module &mod) {
nts->set_value(1);
}
else {
ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 1, 4);
ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 2, 2);
ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 1, 4);
connected_components(node, {fpw, wpt}, {"fpw", "wpt"}, nodes_, dependencies_, group_id++);
}
@@ -247,14 +247,14 @@ void tune::run(ir::module &mod) {
size_t addr_space = ptr_ty->get_pointer_address_space();
if(addr_space < 4){
ir::type *ty = mod.get_builder().get_int32_ty();
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 8));
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 4, 8));
*params_.at(i).at("nts.d0") = *tmp;
}
}
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
ir::type *ty = mod.get_builder().get_int32_ty();
std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 2, 8));
std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 2, 8));
std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 4, 8));
std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 4, 8));
*params_.at(i).at("nts.d0") = *tmp1;
*params_.at(i).at("nts.d1") = *tmp2;
}