diff --git a/lib/codegen/transform/pipeline.cc b/lib/codegen/transform/pipeline.cc index 39b058ae7..8b9fe1a5d 100644 --- a/lib/codegen/transform/pipeline.cc +++ b/lib/codegen/transform/pipeline.cc @@ -23,6 +23,18 @@ void recursive_deps(ir::value* v, ir::basic_block* block, std::vector& phis) { + auto instr = dynamic_cast(cond); + for (auto op : instr->ops()) { + if (auto phi_op = dynamic_cast(op)) { + phis.insert(phi_op); + return; + } + if (dynamic_cast(op)) + get_induction_vars(op, phis); + } +} + /// assume incoming block is 1 ir::value* rematerialize_vals(ir::builder& builder, ir::value* v, std::map& prev_phi_vals) { @@ -46,37 +58,6 @@ ir::value* rematerialize_vals(ir::builder& builder, ir::value* v, return ret; } -void get_induction_vars(ir::value* cond, std::set& phis) { - auto instr = dynamic_cast(cond); - for (auto op : instr->ops()) { - if (auto phi_op = dynamic_cast(op)) { - phis.insert(phi_op); - return; - } - if (dynamic_cast(op)) - get_induction_vars(op, phis); - } -} - -/// Returns phi_val if sees a phi node -ir::value* rematerialize_val(ir::builder& builder, ir::value* v, ir::value* phi_val) { - ir::instruction* i = dynamic_cast(v); - if(!i) - return v; - if(ir::phi_node* phi = dynamic_cast(v)) - return phi_val; - - std::vector new_ops; - for(ir::value* op: i->ops()){ - new_ops.push_back(rematerialize_val(builder, op, phi_val)); - } - ir::instruction* ret = i->clone(); - for(size_t k = 0; k < new_ops.size(); k++) - ret->set_operand(k, new_ops[k]); - builder.insert(ret); - return ret; -} - ir::value* rematerialize(ir::builder& builder, ir::value* v, size_t phi_idx){ ir::instruction* i = dynamic_cast(v); if(!i) @@ -96,11 +77,13 @@ ir::value* rematerialize(ir::builder& builder, ir::value* v, size_t phi_idx){ } /// moving the prev phi vals to the next iteration -void update_prev_phi_vals(ir::builder& builder, std::map& prev_phi_vals) { - for (auto& [phi, val] : prev_phi_vals) { - // TODO: handling nested phis - val = rematerialize_val(builder, phi->get_incoming_value(1), val); +std::map update_prev_phi_vals( + ir::builder& builder, std::map& prev_phi_vals) { + std::map next_phi_vals; + for (auto &[phi, val] : prev_phi_vals) { + next_phi_vals[phi] = rematerialize_vals(builder, phi->get_incoming_value(1), prev_phi_vals); } + return next_phi_vals; } void finalize_iv_vals(ir::builder& builder, std::map& load_ivs, @@ -163,12 +146,10 @@ void pipeline::run(ir::module &mod) { std::map prev_phi_vals; // initialize prev_phi_vals - // note: we assume that ptr & other values only depend on ptr & iv (phis) - // TODO: can we just add all phis here? - prev_phi_vals[ptr] = ptr->get_value_for_block(header); - for (ir::phi_node* iv : induction_vars) - prev_phi_vals[iv] = iv->get_value_for_block(header); - prev_phi_vals[ptr] = ptr->get_value_for_block(header); + // Add all phi nodes. The following DCE pass will delete dead ones. + for (ir::instruction *instr : block->get_inst_list()) + if (auto *phi = dynamic_cast(instr)) + prev_phi_vals[phi] = phi->get_value_for_block(header); builder.set_insert_point(header->get_inst_list().back()); first_ptrs[0] = ptr->get_value_for_block(header); @@ -188,7 +169,7 @@ void pipeline::run(ir::module &mod) { for (int stage = 1; stage < num_stages-1; ++stage) { // mask is the loop condition of the previous iteration loop_conds[stage] = rematerialize_vals(builder, block_cond, prev_phi_vals); - update_prev_phi_vals(builder, prev_phi_vals); + prev_phi_vals = update_prev_phi_vals(builder, prev_phi_vals); first_ptrs[stage] = rematerialize_vals(builder, ptr, prev_phi_vals); first_masks[stage] = builder.create_splat(loop_conds[stage], ty->get_block_shapes()); if (auto* masked_load = dynamic_cast(load)) { @@ -330,4 +311,4 @@ void pipeline::run(ir::module &mod) { } } -} +} \ No newline at end of file