better masking

This commit is contained in:
Philippe Tillet
2019-02-28 23:46:11 -05:00
parent 017702590b
commit 36acf22fd3
9 changed files with 203 additions and 86 deletions

View File

@@ -472,7 +472,7 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
distributed_tile *T = new distributed_tile(ty, shapes2, axes, builder, vectorize);
tmap_.insert({v, T});
// constant range
if(dynamic_cast<ir::constant*>(v)){
if(dynamic_cast<ir::constant*>(v) && !dynamic_cast<ir::undef_value*>(v)){
T->for_each([&](indices_t idx){
assert(idx.size() == 1);
T->set_value(idx, idx[0]);
@@ -494,15 +494,21 @@ void selection::init_grids(ir::function *fn, IRBuilder<> &builder, Value *sh_mem
std::vector<ir::value*> grids;
std::map<unsigned*, ir::value*> references;
create_grids(grids, references, fn);
for(ir::value* i: grids)
init_axes(i, builder, u_thread_warp_id, u_warp_id);
for(ir::value* i: grids){
if(auto *instr = dynamic_cast<ir::instruction*>(i))
for(unsigned r = 0; r < instr->get_num_results(); r++)
init_axes(instr->get_result(r), builder, u_thread_warp_id, u_warp_id);
else
init_axes(i, builder, u_thread_warp_id, u_warp_id);
}
// create tile
std::set<ir::value*> seen;
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i: block->get_inst_list()){
if(!i->get_type()->is_tile_ty())
continue;
create_tile(i, builder, references, seen, sh_mem_ptr);
for(unsigned r = 0; r < i->get_num_results(); r++)
create_tile(i->get_result(r), builder, references, seen, sh_mem_ptr);
}
}
@@ -510,46 +516,43 @@ void selection::init_grids(ir::function *fn, IRBuilder<> &builder, Value *sh_mem
void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &builder) {
BasicBlock *block = builder.GetInsertBlock();
Module *module = block->getModule();
Function *function = block->getParent();
ir::value* mask_pred = ins->get_mask_pred();
LLVMContext &ctx = builder.getContext();
// helper to handle masks
auto insert_masked = [&](indices_t idx, std::function<Value*()> insert_value) {
BasicBlock *block = builder.GetInsertBlock();
Value *result;
if(mask_pred){
// if(mask.else_value)
// std::cout << mask.else_value << std::endl;
Value *llvm_mask = tmap_.at(mask_pred)->get_value(idx);
BasicBlock *then_bb = BasicBlock::Create(ctx, "", function);
BasicBlock *done_bb = BasicBlock::Create(ctx, "", function);
builder.CreateCondBr(llvm_mask, then_bb, done_bb);
builder.SetInsertPoint(then_bb);
result = insert_value();
builder.CreateBr(done_bb);
builder.SetInsertPoint(done_bb);
if(!ins->get_type()->is_void_ty()){
Type *ty = result->getType();
PHINode *phi = builder.CreatePHI(ty, 2);
// if(mask.else_value)
// phi->addIncoming(tmap_.at(mask.else_value)->get_value(idx), block);
// // helper to handle masks
// auto insert_masked = [&](indices_t idx, std::function<Value*()> insert_value) {
// BasicBlock *block = builder.GetInsertBlock();
// Value *result;
// if(mask_pred){
// Value *llvm_mask = tmap_.at(mask_pred)->get_value(idx);
// BasicBlock *then_bb = BasicBlock::Create(ctx, "", function);
// BasicBlock *done_bb = BasicBlock::Create(ctx, "", function);
// builder.CreateCondBr(llvm_mask, then_bb, done_bb);
// builder.SetInsertPoint(then_bb);
// result = insert_value();
// builder.CreateBr(done_bb);
// builder.SetInsertPoint(done_bb);
// if(!ins->get_type()->is_void_ty()){
// Type *ty = result->getType();
// PHINode *phi = builder.CreatePHI(ty, 2);
// if(mask_else)
// phi->addIncoming(tmap_.at(mask_else)->get_value(idx), block);
// else
phi->addIncoming(llvm::UndefValue::get(ty), block);
phi->addIncoming(result, then_bb);
return (Value*)phi;
}
}
else
result = insert_value();
return result;
};
// phi->addIncoming(llvm::UndefValue::get(ty), block);
// phi->addIncoming(result, then_bb);
// return (Value*)phi;
// }
// }
// else
// result = insert_value();
// return result;
// };
std::cout << ins->get_name() << " " << typeid(*ins).name() << std::endl;
// store
if(auto *x = dynamic_cast<ir::store_inst*>(ins)) {
distributed_tile* ptr = (distributed_tile*)tmap_.at(x->get_pointer_operand());
tile *value = tmap_.at(x->get_value_operand());
ptr->for_each([&](indices_t idx){
insert_masked(idx, [&]{ return builder.CreateStore(value->get_value(idx), ptr->get_value(idx)); });
builder.CreateStore(value->get_value(idx), ptr->get_value(idx));
});
}
else {
@@ -570,9 +573,30 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
Value *offset = builder.CreateMul(builder.getInt32(shapes[0]->get_value()), group_id);
result->for_each([&](indices_t idx){
BinaryOperator *bin = static_cast<BinaryOperator*>(idx[0]);
result->set_value(idx, insert_masked(idx, [&]{ return builder.CreateAdd(bin, offset); }));
result->set_value(idx, builder.CreateAdd(bin, offset));
});
}
// mask
else if(dynamic_cast<ir::mask_inst*>(ins)) {
// distributed_tile* pred = (distributed_tile*)ins->get_operand(0);
// BasicBlock *mask_done_bb = BasicBlock::Create(ctx, "mask_done");
// pred->for_each([&](indices_t idx){
// BasicBlock *mask_if_bb = BasicBlock::Create(ctx, "mask_if");
// BasicBlock* mask_else_bb = BasicBlock::Create(ctx, "mask_else");
// builder.CreateCondBr(pred->get_value(idx), mask_if_bb, mask_else_bb);
// builder.SetInsertPoint(mask_if_bb);
// builder.CreateBr(mask_done_bb);
// builder.SetInsertPoint(mask_else_bb);
// builder.CreateBr(mask_done_bb);
// });
// builder.SetInsertPoint(mask_done_bb);
}
// merge
else if(dynamic_cast<ir::merge_inst*>(ins)) {
// result->for_each([&](indices_t idx){
// std::cout << "merge" << std::endl;
// });
}
// reshape
else if(dynamic_cast<ir::reshape_inst*>(ins)) {
ir::value* in = ins->get_operand(0);
@@ -589,7 +613,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
// splat
else if(dynamic_cast<ir::splat_inst*>(ins)) {
result->for_each([&](indices_t idx) {
result->set_value(idx, insert_masked(idx, [&]{ return llvm_value(ins->get_operand(0), builder); }));
result->set_value(idx, llvm_value(ins->get_operand(0), builder));
});
}
// broadcast
@@ -667,7 +691,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
else
return llvm_value(x, builder);
};
result->set_value(idx, insert_masked(idx, [&]() { return llvm_inst(ins, value, builder); }));
result->set_value(idx, llvm_inst(ins, value, builder));
});
}
}