[code generation] better predication
This commit is contained in:
@@ -517,41 +517,21 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
BasicBlock *block = builder.GetInsertBlock();
|
||||
Module *module = block->getModule();
|
||||
LLVMContext &ctx = builder.getContext();
|
||||
// // helper to handle masks
|
||||
// auto insert_masked = [&](indices_t idx, std::function<Value*()> insert_value) {
|
||||
// BasicBlock *block = builder.GetInsertBlock();
|
||||
// Value *result;
|
||||
// if(mask_pred){
|
||||
// Value *llvm_mask = tmap_.at(mask_pred)->get_value(idx);
|
||||
// BasicBlock *then_bb = BasicBlock::Create(ctx, "", function);
|
||||
// BasicBlock *done_bb = BasicBlock::Create(ctx, "", function);
|
||||
// builder.CreateCondBr(llvm_mask, then_bb, done_bb);
|
||||
// builder.SetInsertPoint(then_bb);
|
||||
// result = insert_value();
|
||||
// builder.CreateBr(done_bb);
|
||||
// builder.SetInsertPoint(done_bb);
|
||||
// if(!ins->get_type()->is_void_ty()){
|
||||
// Type *ty = result->getType();
|
||||
// PHINode *phi = builder.CreatePHI(ty, 2);
|
||||
// if(mask_else)
|
||||
// phi->addIncoming(tmap_.at(mask_else)->get_value(idx), block);
|
||||
// else
|
||||
// phi->addIncoming(llvm::UndefValue::get(ty), block);
|
||||
// phi->addIncoming(result, then_bb);
|
||||
// return (Value*)phi;
|
||||
// }
|
||||
// }
|
||||
// else
|
||||
// result = insert_value();
|
||||
// return result;
|
||||
// };
|
||||
|
||||
std::cout << ins->get_name() << " " << typeid(*ins).name() << std::endl;
|
||||
Function *fn = block->getParent();
|
||||
ir::value *mask = ins->get_mask_pred();
|
||||
auto set_mask_insert_pt = [&](indices_t idx){
|
||||
if(mask){
|
||||
distributed_tile *mask_tile = (distributed_tile*)tmap_.at(ins->get_mask_pred());
|
||||
BasicBlock *block = pmap_.at({mask_tile, idx});
|
||||
builder.SetInsertPoint(block->getTerminator());
|
||||
}
|
||||
};
|
||||
// store
|
||||
if(auto *x = dynamic_cast<ir::store_inst*>(ins)) {
|
||||
distributed_tile* ptr = (distributed_tile*)tmap_.at(x->get_pointer_operand());
|
||||
tile *value = tmap_.at(x->get_value_operand());
|
||||
ptr->for_each([&](indices_t idx){
|
||||
set_mask_insert_pt(idx);
|
||||
builder.CreateStore(value->get_value(idx), ptr->get_value(idx));
|
||||
});
|
||||
}
|
||||
@@ -578,24 +558,46 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
}
|
||||
// mask
|
||||
else if(dynamic_cast<ir::mask_inst*>(ins)) {
|
||||
// distributed_tile* pred = (distributed_tile*)ins->get_operand(0);
|
||||
// BasicBlock *mask_done_bb = BasicBlock::Create(ctx, "mask_done");
|
||||
// pred->for_each([&](indices_t idx){
|
||||
// BasicBlock *mask_if_bb = BasicBlock::Create(ctx, "mask_if");
|
||||
// BasicBlock* mask_else_bb = BasicBlock::Create(ctx, "mask_else");
|
||||
// builder.CreateCondBr(pred->get_value(idx), mask_if_bb, mask_else_bb);
|
||||
// builder.SetInsertPoint(mask_if_bb);
|
||||
// builder.CreateBr(mask_done_bb);
|
||||
// builder.SetInsertPoint(mask_else_bb);
|
||||
// builder.CreateBr(mask_done_bb);
|
||||
// });
|
||||
// builder.SetInsertPoint(mask_done_bb);
|
||||
distributed_tile* pred = (distributed_tile*)tmap_.at(ins->get_operand(0));
|
||||
distributed_tile* mask_tile_true = (distributed_tile*)tmap_.at(ins->get_result(0));
|
||||
distributed_tile* mask_tile_false = (distributed_tile*)tmap_.at(ins->get_result(1));
|
||||
pred->for_each([&](indices_t idx){
|
||||
BasicBlock *mask_then_bb = BasicBlock::Create(ctx, "mask_then", fn);
|
||||
BasicBlock* mask_else_bb = BasicBlock::Create(ctx, "mask_else", fn);
|
||||
BasicBlock *mask_done_bb = BasicBlock::Create(ctx, "mask_done", fn);
|
||||
builder.CreateCondBr(pred->get_value(idx), mask_then_bb, mask_else_bb);
|
||||
builder.SetInsertPoint(mask_then_bb);
|
||||
builder.CreateBr(mask_done_bb);
|
||||
builder.SetInsertPoint(mask_else_bb);
|
||||
builder.CreateBr(mask_done_bb);
|
||||
builder.SetInsertPoint(mask_done_bb);
|
||||
pmap_.insert({{mask_tile_true, idx}, mask_then_bb});
|
||||
pmap_.insert({{mask_tile_false, idx}, mask_else_bb});
|
||||
last_block_.insert({{mask_tile_true, idx}, mask_done_bb});
|
||||
last_block_.insert({{mask_tile_false, idx}, mask_done_bb});
|
||||
});
|
||||
}
|
||||
// merge
|
||||
else if(dynamic_cast<ir::merge_inst*>(ins)) {
|
||||
// result->for_each([&](indices_t idx){
|
||||
// std::cout << "merge" << std::endl;
|
||||
// });
|
||||
else if(auto *merge = dynamic_cast<ir::merge_inst*>(ins)) {
|
||||
distributed_tile* mask_tile_true = (distributed_tile*)tmap_.at(merge->get_mask_true());
|
||||
distributed_tile *value_tile_true = (distributed_tile*)tmap_.at(merge->get_value_true());
|
||||
distributed_tile* mask_tile_false = (distributed_tile*)tmap_.at(merge->get_mask_false());
|
||||
distributed_tile *value_tile_false = (distributed_tile*)tmap_.at(merge->get_value_false());
|
||||
result->for_each([&](indices_t idx){
|
||||
BasicBlock *block_true = pmap_.at({mask_tile_true, idx});
|
||||
Value *value_true = value_tile_true->get_value(idx);
|
||||
BasicBlock *block_false = pmap_.at({mask_tile_false, idx});
|
||||
Value *value_false = value_tile_false->get_value(idx);
|
||||
BasicBlock *block_done = last_block_.at({mask_tile_true, idx});
|
||||
if(block_done->empty())
|
||||
builder.SetInsertPoint(block_done);
|
||||
else
|
||||
builder.SetInsertPoint(block_done->getTerminator());
|
||||
PHINode *phi = builder.CreatePHI(value_true->getType(), 2);
|
||||
phi->addIncoming(value_true, block_true);
|
||||
phi->addIncoming(value_false,block_false);
|
||||
result->set_value(idx, phi);
|
||||
});
|
||||
}
|
||||
// reshape
|
||||
else if(dynamic_cast<ir::reshape_inst*>(ins)) {
|
||||
@@ -691,12 +693,13 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
else
|
||||
return llvm_value(x, builder);
|
||||
};
|
||||
set_mask_insert_pt(idx);
|
||||
result->set_value(idx, llvm_inst(ins, value, builder));
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if(mask)
|
||||
builder.SetInsertPoint(block);
|
||||
}
|
||||
|
||||
void selection::lower_instruction(ir::instruction *src, IRBuilder<> &builder) {
|
||||
@@ -722,7 +725,6 @@ void selection::run(ir::module &src, Module &dst){
|
||||
vmap_.clear();
|
||||
LLVMContext &dst_ctx = dst.getContext();
|
||||
IRBuilder<> dst_builder(dst_ctx);
|
||||
std::map<ir::value*, llvm::BasicBlock*> block_of;
|
||||
|
||||
// iterate over functions
|
||||
for(ir::function *fn: src.get_function_list()) {
|
||||
@@ -771,12 +773,13 @@ void selection::run(ir::module &src, Module &dst){
|
||||
BasicBlock *parent = (BasicBlock*)vmap_[block];
|
||||
dst_builder.SetInsertPoint(parent);
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
if(dynamic_cast<ir::phi_node*>(i) && !parent->empty()){
|
||||
dst_builder.SetInsertPoint(&*parent->getFirstInsertionPt());
|
||||
}
|
||||
BasicBlock *current = dst_builder.GetInsertBlock();
|
||||
bool phi_inserted = (dynamic_cast<ir::phi_node*>(i) || dynamic_cast<ir::merge_inst*>(i)) && !current->empty();
|
||||
if(phi_inserted)
|
||||
dst_builder.SetInsertPoint(&*current->getFirstInsertionPt());
|
||||
lower_instruction(i, dst_builder);
|
||||
if(dynamic_cast<ir::phi_node*>(i) && !parent->empty())
|
||||
dst_builder.SetInsertPoint(parent);
|
||||
if(phi_inserted)
|
||||
dst_builder.SetInsertPoint(current);
|
||||
last_block[block] = dst_builder.GetInsertBlock();
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user