better masking
This commit is contained in:
@@ -472,7 +472,7 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
|
||||
distributed_tile *T = new distributed_tile(ty, shapes2, axes, builder, vectorize);
|
||||
tmap_.insert({v, T});
|
||||
// constant range
|
||||
if(dynamic_cast<ir::constant*>(v)){
|
||||
if(dynamic_cast<ir::constant*>(v) && !dynamic_cast<ir::undef_value*>(v)){
|
||||
T->for_each([&](indices_t idx){
|
||||
assert(idx.size() == 1);
|
||||
T->set_value(idx, idx[0]);
|
||||
@@ -494,15 +494,21 @@ void selection::init_grids(ir::function *fn, IRBuilder<> &builder, Value *sh_mem
|
||||
std::vector<ir::value*> grids;
|
||||
std::map<unsigned*, ir::value*> references;
|
||||
create_grids(grids, references, fn);
|
||||
for(ir::value* i: grids)
|
||||
init_axes(i, builder, u_thread_warp_id, u_warp_id);
|
||||
for(ir::value* i: grids){
|
||||
if(auto *instr = dynamic_cast<ir::instruction*>(i))
|
||||
for(unsigned r = 0; r < instr->get_num_results(); r++)
|
||||
init_axes(instr->get_result(r), builder, u_thread_warp_id, u_warp_id);
|
||||
else
|
||||
init_axes(i, builder, u_thread_warp_id, u_warp_id);
|
||||
}
|
||||
// create tile
|
||||
std::set<ir::value*> seen;
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
if(!i->get_type()->is_tile_ty())
|
||||
continue;
|
||||
create_tile(i, builder, references, seen, sh_mem_ptr);
|
||||
for(unsigned r = 0; r < i->get_num_results(); r++)
|
||||
create_tile(i->get_result(r), builder, references, seen, sh_mem_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -510,46 +516,43 @@ void selection::init_grids(ir::function *fn, IRBuilder<> &builder, Value *sh_mem
|
||||
void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &builder) {
|
||||
BasicBlock *block = builder.GetInsertBlock();
|
||||
Module *module = block->getModule();
|
||||
Function *function = block->getParent();
|
||||
ir::value* mask_pred = ins->get_mask_pred();
|
||||
LLVMContext &ctx = builder.getContext();
|
||||
// helper to handle masks
|
||||
auto insert_masked = [&](indices_t idx, std::function<Value*()> insert_value) {
|
||||
BasicBlock *block = builder.GetInsertBlock();
|
||||
Value *result;
|
||||
if(mask_pred){
|
||||
// if(mask.else_value)
|
||||
// std::cout << mask.else_value << std::endl;
|
||||
Value *llvm_mask = tmap_.at(mask_pred)->get_value(idx);
|
||||
BasicBlock *then_bb = BasicBlock::Create(ctx, "", function);
|
||||
BasicBlock *done_bb = BasicBlock::Create(ctx, "", function);
|
||||
builder.CreateCondBr(llvm_mask, then_bb, done_bb);
|
||||
builder.SetInsertPoint(then_bb);
|
||||
result = insert_value();
|
||||
builder.CreateBr(done_bb);
|
||||
builder.SetInsertPoint(done_bb);
|
||||
if(!ins->get_type()->is_void_ty()){
|
||||
Type *ty = result->getType();
|
||||
PHINode *phi = builder.CreatePHI(ty, 2);
|
||||
// if(mask.else_value)
|
||||
// phi->addIncoming(tmap_.at(mask.else_value)->get_value(idx), block);
|
||||
// // helper to handle masks
|
||||
// auto insert_masked = [&](indices_t idx, std::function<Value*()> insert_value) {
|
||||
// BasicBlock *block = builder.GetInsertBlock();
|
||||
// Value *result;
|
||||
// if(mask_pred){
|
||||
// Value *llvm_mask = tmap_.at(mask_pred)->get_value(idx);
|
||||
// BasicBlock *then_bb = BasicBlock::Create(ctx, "", function);
|
||||
// BasicBlock *done_bb = BasicBlock::Create(ctx, "", function);
|
||||
// builder.CreateCondBr(llvm_mask, then_bb, done_bb);
|
||||
// builder.SetInsertPoint(then_bb);
|
||||
// result = insert_value();
|
||||
// builder.CreateBr(done_bb);
|
||||
// builder.SetInsertPoint(done_bb);
|
||||
// if(!ins->get_type()->is_void_ty()){
|
||||
// Type *ty = result->getType();
|
||||
// PHINode *phi = builder.CreatePHI(ty, 2);
|
||||
// if(mask_else)
|
||||
// phi->addIncoming(tmap_.at(mask_else)->get_value(idx), block);
|
||||
// else
|
||||
phi->addIncoming(llvm::UndefValue::get(ty), block);
|
||||
phi->addIncoming(result, then_bb);
|
||||
return (Value*)phi;
|
||||
}
|
||||
}
|
||||
else
|
||||
result = insert_value();
|
||||
return result;
|
||||
};
|
||||
// phi->addIncoming(llvm::UndefValue::get(ty), block);
|
||||
// phi->addIncoming(result, then_bb);
|
||||
// return (Value*)phi;
|
||||
// }
|
||||
// }
|
||||
// else
|
||||
// result = insert_value();
|
||||
// return result;
|
||||
// };
|
||||
|
||||
std::cout << ins->get_name() << " " << typeid(*ins).name() << std::endl;
|
||||
// store
|
||||
if(auto *x = dynamic_cast<ir::store_inst*>(ins)) {
|
||||
distributed_tile* ptr = (distributed_tile*)tmap_.at(x->get_pointer_operand());
|
||||
tile *value = tmap_.at(x->get_value_operand());
|
||||
ptr->for_each([&](indices_t idx){
|
||||
insert_masked(idx, [&]{ return builder.CreateStore(value->get_value(idx), ptr->get_value(idx)); });
|
||||
builder.CreateStore(value->get_value(idx), ptr->get_value(idx));
|
||||
});
|
||||
}
|
||||
else {
|
||||
@@ -570,9 +573,30 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
Value *offset = builder.CreateMul(builder.getInt32(shapes[0]->get_value()), group_id);
|
||||
result->for_each([&](indices_t idx){
|
||||
BinaryOperator *bin = static_cast<BinaryOperator*>(idx[0]);
|
||||
result->set_value(idx, insert_masked(idx, [&]{ return builder.CreateAdd(bin, offset); }));
|
||||
result->set_value(idx, builder.CreateAdd(bin, offset));
|
||||
});
|
||||
}
|
||||
// mask
|
||||
else if(dynamic_cast<ir::mask_inst*>(ins)) {
|
||||
// distributed_tile* pred = (distributed_tile*)ins->get_operand(0);
|
||||
// BasicBlock *mask_done_bb = BasicBlock::Create(ctx, "mask_done");
|
||||
// pred->for_each([&](indices_t idx){
|
||||
// BasicBlock *mask_if_bb = BasicBlock::Create(ctx, "mask_if");
|
||||
// BasicBlock* mask_else_bb = BasicBlock::Create(ctx, "mask_else");
|
||||
// builder.CreateCondBr(pred->get_value(idx), mask_if_bb, mask_else_bb);
|
||||
// builder.SetInsertPoint(mask_if_bb);
|
||||
// builder.CreateBr(mask_done_bb);
|
||||
// builder.SetInsertPoint(mask_else_bb);
|
||||
// builder.CreateBr(mask_done_bb);
|
||||
// });
|
||||
// builder.SetInsertPoint(mask_done_bb);
|
||||
}
|
||||
// merge
|
||||
else if(dynamic_cast<ir::merge_inst*>(ins)) {
|
||||
// result->for_each([&](indices_t idx){
|
||||
// std::cout << "merge" << std::endl;
|
||||
// });
|
||||
}
|
||||
// reshape
|
||||
else if(dynamic_cast<ir::reshape_inst*>(ins)) {
|
||||
ir::value* in = ins->get_operand(0);
|
||||
@@ -589,7 +613,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
// splat
|
||||
else if(dynamic_cast<ir::splat_inst*>(ins)) {
|
||||
result->for_each([&](indices_t idx) {
|
||||
result->set_value(idx, insert_masked(idx, [&]{ return llvm_value(ins->get_operand(0), builder); }));
|
||||
result->set_value(idx, llvm_value(ins->get_operand(0), builder));
|
||||
});
|
||||
}
|
||||
// broadcast
|
||||
@@ -667,7 +691,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
else
|
||||
return llvm_value(x, builder);
|
||||
};
|
||||
result->set_value(idx, insert_masked(idx, [&]() { return llvm_inst(ins, value, builder); }));
|
||||
result->set_value(idx, llvm_inst(ins, value, builder));
|
||||
});
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user