[CODEGEN] Fixed bug in pipelining pass and casting semantics analysis (#257)

This commit is contained in:
Philippe Tillet
2021-09-01 20:58:47 -07:00
committed by GitHub
parent c0daffc625
commit 768e0ded28
2 changed files with 34 additions and 26 deletions

View File

@@ -36,10 +36,10 @@ void get_induction_vars(ir::value* cond, std::set<ir::phi_node*>& phis) {
}
/// assume incoming block is 1
ir::value* rematerialize_vals(ir::builder& builder, ir::value* v,
ir::value* rematerialize_vals(ir::builder& builder, ir::basic_block* block, ir::value* v,
std::map<ir::phi_node*, ir::value*>& prev_phi_vals) {
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
if(!i)
if(!i || i->get_parent() != block)
return v;
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v)) {
if (prev_phi_vals.find(phi) == prev_phi_vals.end())
@@ -49,7 +49,7 @@ ir::value* rematerialize_vals(ir::builder& builder, ir::value* v,
std::vector<ir::value*> new_ops;
for(ir::value* op: i->ops()){
new_ops.push_back(rematerialize_vals(builder, op, prev_phi_vals));
new_ops.push_back(rematerialize_vals(builder, block, op, prev_phi_vals));
}
ir::instruction* ret = i->clone();
for(size_t k = 0; k < new_ops.size(); k++)
@@ -58,16 +58,17 @@ ir::value* rematerialize_vals(ir::builder& builder, ir::value* v,
return ret;
}
ir::value* rematerialize(ir::builder& builder, ir::value* v, size_t phi_idx){
ir::value* rematerialize(ir::builder& builder, ir::basic_block* block,
ir::value* v, size_t phi_idx){
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
if(!i)
if(!i || i->get_parent() != block)
return v;
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v))
return phi->get_incoming_value(phi_idx);
std::vector<ir::value*> new_ops;
for(ir::value* op: i->ops()){
new_ops.push_back(rematerialize(builder, op, phi_idx));
new_ops.push_back(rematerialize(builder, block, op, phi_idx));
}
ir::instruction* ret = i->clone();
for(size_t k = 0; k < new_ops.size(); k++)
@@ -78,19 +79,19 @@ ir::value* rematerialize(ir::builder& builder, ir::value* v, size_t phi_idx){
/// moving the prev phi vals to the next iteration
std::map<ir::phi_node*, ir::value*> update_prev_phi_vals(
ir::builder& builder, std::map<ir::phi_node*, ir::value*>& prev_phi_vals) {
ir::builder& builder, ir::basic_block* block, std::map<ir::phi_node*, ir::value*>& prev_phi_vals) {
std::map<ir::phi_node*, ir::value*> next_phi_vals;
for (auto &[phi, val] : prev_phi_vals) {
next_phi_vals[phi] = rematerialize_vals(builder, phi->get_incoming_value(1), prev_phi_vals);
next_phi_vals[phi] = rematerialize_vals(builder, block, phi->get_incoming_value(1), prev_phi_vals);
}
return next_phi_vals;
}
void finalize_iv_vals(ir::builder& builder, std::map<ir::phi_node*, ir::value*>& load_ivs,
void finalize_iv_vals(ir::builder& builder, ir::basic_block* block, std::map<ir::phi_node*, ir::value*>& load_ivs,
std::map<ir::phi_node*, ir::value*>& next_load_ivs) {
for (auto& [phi, val] : load_ivs) {
if (auto new_phi = dynamic_cast<ir::phi_node*>(val)) {
ir::value* next_k = rematerialize_vals(builder, phi->get_incoming_value(1), load_ivs);
ir::value* next_k = rematerialize_vals(builder, block, phi->get_incoming_value(1), load_ivs);
assert(new_phi->get_num_operands() == 1 && "should be incomplete phi");
new_phi->add_incoming(next_k, phi->get_incoming_block(1));
// cache next_k (to be used by next_mask)
@@ -149,7 +150,8 @@ void pipeline::run(ir::module &mod) {
// Add all phi nodes. The following DCE pass will delete dead ones.
for (ir::instruction *instr : block->get_inst_list())
if (auto *phi = dynamic_cast<ir::phi_node*>(instr))
prev_phi_vals[phi] = phi->get_value_for_block(header);
if (phi->get_incoming_block(1) == block)
prev_phi_vals[phi] = phi->get_value_for_block(header);
builder.set_insert_point(header->get_inst_list().back());
first_ptrs[0] = ptr->get_value_for_block(header);
@@ -157,9 +159,9 @@ void pipeline::run(ir::module &mod) {
first_masks[0] = builder.create_splat(loop_conds[0], ty->get_block_shapes());
ir::value* false_value = nullptr;
if (auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)) {
ir::value* remat_mask =rematerialize_vals(builder, masked_load->get_mask_operand(), prev_phi_vals) ;
ir::value* remat_mask =rematerialize_vals(builder, block, masked_load->get_mask_operand(), prev_phi_vals) ;
ir::value* remat_false_value =
rematerialize_vals(builder, masked_load->get_false_value_operand(), prev_phi_vals);
rematerialize_vals(builder, block, masked_load->get_false_value_operand(), prev_phi_vals);
first_masks[0] = builder.create_and(first_masks[0], remat_mask);
false_value = remat_false_value;
} else
@@ -168,14 +170,14 @@ void pipeline::run(ir::module &mod) {
for (int stage = 1; stage < num_stages-1; ++stage) {
// mask is the loop condition of the previous iteration
loop_conds[stage] = rematerialize_vals(builder, block_cond, prev_phi_vals);
prev_phi_vals = update_prev_phi_vals(builder, prev_phi_vals);
first_ptrs[stage] = rematerialize_vals(builder, ptr, prev_phi_vals);
loop_conds[stage] = rematerialize_vals(builder, block, block_cond, prev_phi_vals);
prev_phi_vals = update_prev_phi_vals(builder, block, prev_phi_vals);
first_ptrs[stage] = rematerialize_vals(builder, block, ptr, prev_phi_vals);
first_masks[stage] = builder.create_splat(loop_conds[stage], ty->get_block_shapes());
if (auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)) {
ir::value* remat_mask = rematerialize_vals(builder, masked_load->get_mask_operand(), prev_phi_vals);
ir::value* remat_mask = rematerialize_vals(builder, block, masked_load->get_mask_operand(), prev_phi_vals);
ir::value* remat_false_value =
rematerialize_vals(builder, masked_load->get_false_value_operand(), prev_phi_vals);
rematerialize_vals(builder, block, masked_load->get_false_value_operand(), prev_phi_vals);
first_masks[stage] = builder.create_and(first_masks[stage], remat_mask);
false_value = remat_false_value;
}
@@ -192,18 +194,18 @@ void pipeline::run(ir::module &mod) {
load_ivs[iv] = pn;
}
// add incoming for phis & update next_load_ivs
finalize_iv_vals(builder, load_ivs, next_load_ivs);
finalize_iv_vals(builder, block, load_ivs, next_load_ivs);
// pre-fetch next iteration
builder.set_insert_point(block->get_inst_list().back());
ir::value* next_ptr = ptr->get_value_for_block(block);
ir::value* next_mask = builder.create_splat(
rematerialize_vals(builder, block_cond, load_ivs), ty->get_block_shapes());
rematerialize_vals(builder, block, block_cond, load_ivs), ty->get_block_shapes());
if (auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)) {
ir::value* remat_mask = rematerialize_vals(builder, masked_load->get_mask_operand(), next_load_ivs);
ir::value* remat_mask = rematerialize_vals(builder, block, masked_load->get_mask_operand(), next_load_ivs);
// TODO: false may depends on some other phi nodes
ir::value* remat_false_value =
rematerialize_vals(builder, masked_load->get_false_value_operand(), next_load_ivs);
rematerialize_vals(builder, block, masked_load->get_false_value_operand(), next_load_ivs);
next_mask = builder.create_and(next_mask, remat_mask);
false_value = remat_false_value;
}
@@ -235,8 +237,8 @@ void pipeline::run(ir::module &mod) {
ir::value* first_mask = builder.create_splat(header_br->get_cond(), ty->get_block_shapes());
ir::value* false_value;
if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)){
ir::value* remat_mask = rematerialize(builder, masked_load->get_mask_operand(), 0);
ir::value* remat_false_value = rematerialize(builder, masked_load->get_false_value_operand(), 0);
ir::value* remat_mask = rematerialize(builder, block, masked_load->get_mask_operand(), 0);
ir::value* remat_false_value = rematerialize(builder, block, masked_load->get_false_value_operand(), 0);
first_mask = builder.create_and(first_mask, remat_mask);
false_value = remat_false_value;
}
@@ -248,8 +250,8 @@ void pipeline::run(ir::module &mod) {
ir::value* next_ptr = ptr->get_value_for_block(block);
ir::value* next_mask = builder.create_splat(block_br->get_cond(), ty->get_block_shapes());
if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)){
ir::value* remat_mask = rematerialize(builder, masked_load->get_mask_operand(), 1);
ir::value* remat_false_value = rematerialize(builder, masked_load->get_false_value_operand(), 1);
ir::value* remat_mask = rematerialize(builder, block, masked_load->get_mask_operand(), 1);
ir::value* remat_false_value = rematerialize(builder, block, masked_load->get_false_value_operand(), 1);
next_mask = builder.create_and(next_mask, remat_mask);
false_value = remat_false_value;
}

View File

@@ -425,6 +425,8 @@ ir::value *dispatch::bitcast(ir::value *input, ir::type *dst_ty, ir::builder *bu
return input;
ir::type *src_sca_ty = src_ty->get_scalar_ty();
ir::type *dst_sca_ty = dst_ty->get_scalar_ty();
if(src_sca_ty->is_pointer_ty() || dst_sca_ty->is_pointer_ty())
return cast(input, dst_ty, builder);
// Bitcast
int src_bits = src_sca_ty->get_primitive_size_in_bits();
int dst_bits = dst_sca_ty->get_primitive_size_in_bits();
@@ -472,6 +474,10 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build
else
return builder->create_si_to_fp(input, dst_ty);
}
if (src_sca_ty->is_pointer_ty() && !dst_sca_ty->is_pointer_ty())
return builder->create_cast(ir::PtrToInt, input, dst_ty);
if (!src_sca_ty->is_pointer_ty() && dst_sca_ty->is_pointer_ty())
return builder->create_cast(ir::IntToPtr, input, dst_ty);
// Ptr -> Ptr
if (src_sca_ty->is_pointer_ty() && dst_sca_ty->is_pointer_ty())
return builder->create_cast(ir::BitCast, input, dst_ty);