[CODEGEN] Fixed bug in pipelining pass and casting semantics analysis (#257)
This commit is contained in:
@@ -36,10 +36,10 @@ void get_induction_vars(ir::value* cond, std::set<ir::phi_node*>& phis) {
|
||||
}
|
||||
|
||||
/// assume incoming block is 1
|
||||
ir::value* rematerialize_vals(ir::builder& builder, ir::value* v,
|
||||
ir::value* rematerialize_vals(ir::builder& builder, ir::basic_block* block, ir::value* v,
|
||||
std::map<ir::phi_node*, ir::value*>& prev_phi_vals) {
|
||||
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
|
||||
if(!i)
|
||||
if(!i || i->get_parent() != block)
|
||||
return v;
|
||||
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v)) {
|
||||
if (prev_phi_vals.find(phi) == prev_phi_vals.end())
|
||||
@@ -49,7 +49,7 @@ ir::value* rematerialize_vals(ir::builder& builder, ir::value* v,
|
||||
|
||||
std::vector<ir::value*> new_ops;
|
||||
for(ir::value* op: i->ops()){
|
||||
new_ops.push_back(rematerialize_vals(builder, op, prev_phi_vals));
|
||||
new_ops.push_back(rematerialize_vals(builder, block, op, prev_phi_vals));
|
||||
}
|
||||
ir::instruction* ret = i->clone();
|
||||
for(size_t k = 0; k < new_ops.size(); k++)
|
||||
@@ -58,16 +58,17 @@ ir::value* rematerialize_vals(ir::builder& builder, ir::value* v,
|
||||
return ret;
|
||||
}
|
||||
|
||||
ir::value* rematerialize(ir::builder& builder, ir::value* v, size_t phi_idx){
|
||||
ir::value* rematerialize(ir::builder& builder, ir::basic_block* block,
|
||||
ir::value* v, size_t phi_idx){
|
||||
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
|
||||
if(!i)
|
||||
if(!i || i->get_parent() != block)
|
||||
return v;
|
||||
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v))
|
||||
return phi->get_incoming_value(phi_idx);
|
||||
|
||||
std::vector<ir::value*> new_ops;
|
||||
for(ir::value* op: i->ops()){
|
||||
new_ops.push_back(rematerialize(builder, op, phi_idx));
|
||||
new_ops.push_back(rematerialize(builder, block, op, phi_idx));
|
||||
}
|
||||
ir::instruction* ret = i->clone();
|
||||
for(size_t k = 0; k < new_ops.size(); k++)
|
||||
@@ -78,19 +79,19 @@ ir::value* rematerialize(ir::builder& builder, ir::value* v, size_t phi_idx){
|
||||
|
||||
/// moving the prev phi vals to the next iteration
|
||||
std::map<ir::phi_node*, ir::value*> update_prev_phi_vals(
|
||||
ir::builder& builder, std::map<ir::phi_node*, ir::value*>& prev_phi_vals) {
|
||||
ir::builder& builder, ir::basic_block* block, std::map<ir::phi_node*, ir::value*>& prev_phi_vals) {
|
||||
std::map<ir::phi_node*, ir::value*> next_phi_vals;
|
||||
for (auto &[phi, val] : prev_phi_vals) {
|
||||
next_phi_vals[phi] = rematerialize_vals(builder, phi->get_incoming_value(1), prev_phi_vals);
|
||||
next_phi_vals[phi] = rematerialize_vals(builder, block, phi->get_incoming_value(1), prev_phi_vals);
|
||||
}
|
||||
return next_phi_vals;
|
||||
}
|
||||
|
||||
void finalize_iv_vals(ir::builder& builder, std::map<ir::phi_node*, ir::value*>& load_ivs,
|
||||
void finalize_iv_vals(ir::builder& builder, ir::basic_block* block, std::map<ir::phi_node*, ir::value*>& load_ivs,
|
||||
std::map<ir::phi_node*, ir::value*>& next_load_ivs) {
|
||||
for (auto& [phi, val] : load_ivs) {
|
||||
if (auto new_phi = dynamic_cast<ir::phi_node*>(val)) {
|
||||
ir::value* next_k = rematerialize_vals(builder, phi->get_incoming_value(1), load_ivs);
|
||||
ir::value* next_k = rematerialize_vals(builder, block, phi->get_incoming_value(1), load_ivs);
|
||||
assert(new_phi->get_num_operands() == 1 && "should be incomplete phi");
|
||||
new_phi->add_incoming(next_k, phi->get_incoming_block(1));
|
||||
// cache next_k (to be used by next_mask)
|
||||
@@ -149,7 +150,8 @@ void pipeline::run(ir::module &mod) {
|
||||
// Add all phi nodes. The following DCE pass will delete dead ones.
|
||||
for (ir::instruction *instr : block->get_inst_list())
|
||||
if (auto *phi = dynamic_cast<ir::phi_node*>(instr))
|
||||
prev_phi_vals[phi] = phi->get_value_for_block(header);
|
||||
if (phi->get_incoming_block(1) == block)
|
||||
prev_phi_vals[phi] = phi->get_value_for_block(header);
|
||||
|
||||
builder.set_insert_point(header->get_inst_list().back());
|
||||
first_ptrs[0] = ptr->get_value_for_block(header);
|
||||
@@ -157,9 +159,9 @@ void pipeline::run(ir::module &mod) {
|
||||
first_masks[0] = builder.create_splat(loop_conds[0], ty->get_block_shapes());
|
||||
ir::value* false_value = nullptr;
|
||||
if (auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)) {
|
||||
ir::value* remat_mask =rematerialize_vals(builder, masked_load->get_mask_operand(), prev_phi_vals) ;
|
||||
ir::value* remat_mask =rematerialize_vals(builder, block, masked_load->get_mask_operand(), prev_phi_vals) ;
|
||||
ir::value* remat_false_value =
|
||||
rematerialize_vals(builder, masked_load->get_false_value_operand(), prev_phi_vals);
|
||||
rematerialize_vals(builder, block, masked_load->get_false_value_operand(), prev_phi_vals);
|
||||
first_masks[0] = builder.create_and(first_masks[0], remat_mask);
|
||||
false_value = remat_false_value;
|
||||
} else
|
||||
@@ -168,14 +170,14 @@ void pipeline::run(ir::module &mod) {
|
||||
|
||||
for (int stage = 1; stage < num_stages-1; ++stage) {
|
||||
// mask is the loop condition of the previous iteration
|
||||
loop_conds[stage] = rematerialize_vals(builder, block_cond, prev_phi_vals);
|
||||
prev_phi_vals = update_prev_phi_vals(builder, prev_phi_vals);
|
||||
first_ptrs[stage] = rematerialize_vals(builder, ptr, prev_phi_vals);
|
||||
loop_conds[stage] = rematerialize_vals(builder, block, block_cond, prev_phi_vals);
|
||||
prev_phi_vals = update_prev_phi_vals(builder, block, prev_phi_vals);
|
||||
first_ptrs[stage] = rematerialize_vals(builder, block, ptr, prev_phi_vals);
|
||||
first_masks[stage] = builder.create_splat(loop_conds[stage], ty->get_block_shapes());
|
||||
if (auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)) {
|
||||
ir::value* remat_mask = rematerialize_vals(builder, masked_load->get_mask_operand(), prev_phi_vals);
|
||||
ir::value* remat_mask = rematerialize_vals(builder, block, masked_load->get_mask_operand(), prev_phi_vals);
|
||||
ir::value* remat_false_value =
|
||||
rematerialize_vals(builder, masked_load->get_false_value_operand(), prev_phi_vals);
|
||||
rematerialize_vals(builder, block, masked_load->get_false_value_operand(), prev_phi_vals);
|
||||
first_masks[stage] = builder.create_and(first_masks[stage], remat_mask);
|
||||
false_value = remat_false_value;
|
||||
}
|
||||
@@ -192,18 +194,18 @@ void pipeline::run(ir::module &mod) {
|
||||
load_ivs[iv] = pn;
|
||||
}
|
||||
// add incoming for phis & update next_load_ivs
|
||||
finalize_iv_vals(builder, load_ivs, next_load_ivs);
|
||||
finalize_iv_vals(builder, block, load_ivs, next_load_ivs);
|
||||
|
||||
// pre-fetch next iteration
|
||||
builder.set_insert_point(block->get_inst_list().back());
|
||||
ir::value* next_ptr = ptr->get_value_for_block(block);
|
||||
ir::value* next_mask = builder.create_splat(
|
||||
rematerialize_vals(builder, block_cond, load_ivs), ty->get_block_shapes());
|
||||
rematerialize_vals(builder, block, block_cond, load_ivs), ty->get_block_shapes());
|
||||
if (auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)) {
|
||||
ir::value* remat_mask = rematerialize_vals(builder, masked_load->get_mask_operand(), next_load_ivs);
|
||||
ir::value* remat_mask = rematerialize_vals(builder, block, masked_load->get_mask_operand(), next_load_ivs);
|
||||
// TODO: false may depends on some other phi nodes
|
||||
ir::value* remat_false_value =
|
||||
rematerialize_vals(builder, masked_load->get_false_value_operand(), next_load_ivs);
|
||||
rematerialize_vals(builder, block, masked_load->get_false_value_operand(), next_load_ivs);
|
||||
next_mask = builder.create_and(next_mask, remat_mask);
|
||||
false_value = remat_false_value;
|
||||
}
|
||||
@@ -235,8 +237,8 @@ void pipeline::run(ir::module &mod) {
|
||||
ir::value* first_mask = builder.create_splat(header_br->get_cond(), ty->get_block_shapes());
|
||||
ir::value* false_value;
|
||||
if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)){
|
||||
ir::value* remat_mask = rematerialize(builder, masked_load->get_mask_operand(), 0);
|
||||
ir::value* remat_false_value = rematerialize(builder, masked_load->get_false_value_operand(), 0);
|
||||
ir::value* remat_mask = rematerialize(builder, block, masked_load->get_mask_operand(), 0);
|
||||
ir::value* remat_false_value = rematerialize(builder, block, masked_load->get_false_value_operand(), 0);
|
||||
first_mask = builder.create_and(first_mask, remat_mask);
|
||||
false_value = remat_false_value;
|
||||
}
|
||||
@@ -248,8 +250,8 @@ void pipeline::run(ir::module &mod) {
|
||||
ir::value* next_ptr = ptr->get_value_for_block(block);
|
||||
ir::value* next_mask = builder.create_splat(block_br->get_cond(), ty->get_block_shapes());
|
||||
if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)){
|
||||
ir::value* remat_mask = rematerialize(builder, masked_load->get_mask_operand(), 1);
|
||||
ir::value* remat_false_value = rematerialize(builder, masked_load->get_false_value_operand(), 1);
|
||||
ir::value* remat_mask = rematerialize(builder, block, masked_load->get_mask_operand(), 1);
|
||||
ir::value* remat_false_value = rematerialize(builder, block, masked_load->get_false_value_operand(), 1);
|
||||
next_mask = builder.create_and(next_mask, remat_mask);
|
||||
false_value = remat_false_value;
|
||||
}
|
||||
|
@@ -425,6 +425,8 @@ ir::value *dispatch::bitcast(ir::value *input, ir::type *dst_ty, ir::builder *bu
|
||||
return input;
|
||||
ir::type *src_sca_ty = src_ty->get_scalar_ty();
|
||||
ir::type *dst_sca_ty = dst_ty->get_scalar_ty();
|
||||
if(src_sca_ty->is_pointer_ty() || dst_sca_ty->is_pointer_ty())
|
||||
return cast(input, dst_ty, builder);
|
||||
// Bitcast
|
||||
int src_bits = src_sca_ty->get_primitive_size_in_bits();
|
||||
int dst_bits = dst_sca_ty->get_primitive_size_in_bits();
|
||||
@@ -472,6 +474,10 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build
|
||||
else
|
||||
return builder->create_si_to_fp(input, dst_ty);
|
||||
}
|
||||
if (src_sca_ty->is_pointer_ty() && !dst_sca_ty->is_pointer_ty())
|
||||
return builder->create_cast(ir::PtrToInt, input, dst_ty);
|
||||
if (!src_sca_ty->is_pointer_ty() && dst_sca_ty->is_pointer_ty())
|
||||
return builder->create_cast(ir::IntToPtr, input, dst_ty);
|
||||
// Ptr -> Ptr
|
||||
if (src_sca_ty->is_pointer_ty() && dst_sca_ty->is_pointer_ty())
|
||||
return builder->create_cast(ir::BitCast, input, dst_ty);
|
||||
|
Reference in New Issue
Block a user