[CODEGEN] Fix bug in auto-pipeline pass when a value depends on multiple phis (#164)

This commit is contained in:
daadaada
2021-08-01 14:40:36 +08:00
committed by GitHub
parent c0bb895d9d
commit c7060eadb2

View File

@@ -23,6 +23,18 @@ void recursive_deps(ir::value* v, ir::basic_block* block, std::vector<ir::instru
recursive_deps(u, block, ret);
}
/// Collects phi nodes found among the (transitive) operands of `cond`.
///
/// Walks the operands of `cond`. The first operand that is a phi node is
/// inserted into `phis` and the walk of that instruction stops there; operands
/// visited before it that are non-phi instructions are searched recursively.
///
/// @param cond  Value whose operands are searched (presumably the loop
///              condition -- confirm at the call site). Null or
///              non-instruction values are ignored.
/// @param phis  Output set; discovered phi nodes are added to it.
void get_induction_vars(ir::value* cond, std::set<ir::phi_node*>& phis) {
  auto* instr = dynamic_cast<ir::instruction*>(cond);
  // A value that is not an instruction (or is null) has no operands to walk.
  // Without this guard the failed cast would be dereferenced below.
  if (!instr)
    return;
  for (ir::value* op : instr->ops()) {
    if (auto* phi_op = dynamic_cast<ir::phi_node*>(op)) {
      phis.insert(phi_op);
      // Stop at the first phi operand: the remaining operands of this
      // instruction are intentionally not visited (original behaviour).
      return;
    }
    if (dynamic_cast<ir::instruction*>(op))
      get_induction_vars(op, phis);
  }
}
/// assume incoming block is 1
ir::value* rematerialize_vals(ir::builder& builder, ir::value* v,
std::map<ir::phi_node*, ir::value*>& prev_phi_vals) {
@@ -46,37 +58,6 @@ ir::value* rematerialize_vals(ir::builder& builder, ir::value* v,
return ret;
}
/// Collects phi nodes found among the (transitive) operands of `cond`.
///
/// Walks the operands of `cond`. The first operand that is a phi node is
/// inserted into `phis` and the walk of that instruction stops there; operands
/// visited before it that are non-phi instructions are searched recursively.
///
/// @param cond  Value whose operands are searched (presumably the loop
///              condition -- confirm at the call site). Null or
///              non-instruction values are ignored.
/// @param phis  Output set; discovered phi nodes are added to it.
void get_induction_vars(ir::value* cond, std::set<ir::phi_node*>& phis) {
  auto* instr = dynamic_cast<ir::instruction*>(cond);
  // A value that is not an instruction (or is null) has no operands to walk.
  // Without this guard the failed cast would be dereferenced below.
  if (!instr)
    return;
  for (ir::value* op : instr->ops()) {
    if (auto* phi_op = dynamic_cast<ir::phi_node*>(op)) {
      phis.insert(phi_op);
      // Stop at the first phi operand: the remaining operands of this
      // instruction are intentionally not visited (original behaviour).
      return;
    }
    if (dynamic_cast<ir::instruction*>(op))
      get_induction_vars(op, phis);
  }
}
/// Rebuilds the expression rooted at `v` through `builder`, substituting
/// `phi_val` for every phi node reached along the way.
///
/// @param builder  Builder that receives each cloned instruction via insert().
/// @param v        Root of the expression to rebuild.
/// @param phi_val  Replacement returned whenever a phi node is encountered.
/// @return `v` itself when it is not an instruction, `phi_val` when it is a
///         phi node, otherwise the freshly inserted clone.
ir::value* rematerialize_val(ir::builder& builder, ir::value* v, ir::value* phi_val) {
  auto* inst = dynamic_cast<ir::instruction*>(v);
  if (inst == nullptr)
    return v;
  if (dynamic_cast<ir::phi_node*>(v) != nullptr)
    return phi_val;

  // Rebuild every operand first, so whatever they insert into the builder
  // precedes the clone of this instruction.
  std::vector<ir::value*> rebuilt;
  for (ir::value* operand : inst->ops())
    rebuilt.push_back(rematerialize_val(builder, operand, phi_val));

  ir::instruction* copy = inst->clone();
  size_t slot = 0;
  for (ir::value* operand : rebuilt)
    copy->set_operand(slot++, operand);

  builder.insert(copy);
  return copy;
}
ir::value* rematerialize(ir::builder& builder, ir::value* v, size_t phi_idx){
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
if(!i)
@@ -96,11 +77,13 @@ ir::value* rematerialize(ir::builder& builder, ir::value* v, size_t phi_idx){
}
/// moving the prev phi vals to the next iteration
void update_prev_phi_vals(ir::builder& builder, std::map<ir::phi_node*, ir::value*>& prev_phi_vals) {
for (auto& [phi, val] : prev_phi_vals) {
// TODO: handling nested phis
val = rematerialize_val(builder, phi->get_incoming_value(1), val);
/// Builds the phi-to-value map for the next iteration.
///
/// For each phi key in `prev_phi_vals`, the phi's incoming value at index 1 is
/// rematerialized against `prev_phi_vals` and stored under the same key in a
/// new map. Results are written only to the returned map, so every entry is
/// computed from the same input snapshot (assuming rematerialize_vals does not
/// modify the map it is given -- confirm).
///
/// @param builder        Builder forwarded to rematerialize_vals.
/// @param prev_phi_vals  Current phi-to-value map.
/// @return The map holding the advanced value for every phi in the input.
std::map<ir::phi_node*, ir::value*> update_prev_phi_vals(
    ir::builder& builder, std::map<ir::phi_node*, ir::value*>& prev_phi_vals) {
  std::map<ir::phi_node*, ir::value*> advanced;
  for (const auto& entry : prev_phi_vals) {
    ir::phi_node* phi = entry.first;
    ir::value* incoming = phi->get_incoming_value(1);
    advanced[phi] = rematerialize_vals(builder, incoming, prev_phi_vals);
  }
  return advanced;
}
void finalize_iv_vals(ir::builder& builder, std::map<ir::phi_node*, ir::value*>& load_ivs,
@@ -163,12 +146,10 @@ void pipeline::run(ir::module &mod) {
std::map<ir::phi_node*, ir::value*> prev_phi_vals;
// initialize prev_phi_vals
// note: we assume that ptr & other values only depend on ptr & iv (phis)
// TODO: can we just add all phis here?
prev_phi_vals[ptr] = ptr->get_value_for_block(header);
for (ir::phi_node* iv : induction_vars)
prev_phi_vals[iv] = iv->get_value_for_block(header);
prev_phi_vals[ptr] = ptr->get_value_for_block(header);
// Add all phi nodes. The following DCE pass will delete dead ones.
for (ir::instruction *instr : block->get_inst_list())
if (auto *phi = dynamic_cast<ir::phi_node*>(instr))
prev_phi_vals[phi] = phi->get_value_for_block(header);
builder.set_insert_point(header->get_inst_list().back());
first_ptrs[0] = ptr->get_value_for_block(header);
@@ -188,7 +169,7 @@ void pipeline::run(ir::module &mod) {
for (int stage = 1; stage < num_stages-1; ++stage) {
// mask is the loop condition of the previous iteration
loop_conds[stage] = rematerialize_vals(builder, block_cond, prev_phi_vals);
update_prev_phi_vals(builder, prev_phi_vals);
prev_phi_vals = update_prev_phi_vals(builder, prev_phi_vals);
first_ptrs[stage] = rematerialize_vals(builder, ptr, prev_phi_vals);
first_masks[stage] = builder.create_splat(loop_conds[stage], ty->get_block_shapes());
if (auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)) {
@@ -330,4 +311,4 @@ void pipeline::run(ir::module &mod) {
}
}
}
}