[dnn] better specification of recompilation key

This commit is contained in:
Philippe Tillet
2019-08-02 17:42:48 -07:00
parent 3b92ddf7e6
commit d9945692a9
31 changed files with 418 additions and 428 deletions

View File

@@ -74,6 +74,11 @@ unsigned distributed_tile::get_linear_index(indices_t idx) {
return indices_[idx];
}
// Returns the multi-dimensional index stored at position `id` in
// ordered_indices_ (the same sequence that for_each walks by position).
// The lookup goes through .at(); NOTE(review): if ordered_indices_ is a
// standard container this is bounds-checked and throws std::out_of_range
// for an invalid `id` -- confirm against the member's declaration.
indices_t distributed_tile::get_ordered_indices(unsigned id) {
  const auto &entry = ordered_indices_.at(id);
  return entry;
}
void distributed_tile::for_each(std::function<void (indices_t)> fn) {
for(unsigned i = 0; i < ordered_indices_.size(); i++)
if(i % vector_size_ == 0)
@@ -779,13 +784,21 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
// store
if(auto *x = dynamic_cast<ir::masked_store_inst*>(ins)){
distributed_tile* ptrs = (distributed_tile*)tmap_.at(x->get_pointer_operand());
tile *scalars = tmap_.at(x->get_value_operand());
distributed_tile* scalars = (distributed_tile*)tmap_.at(x->get_value_operand());
ir::value *mask = x->get_mask_operand();
distributed_tile* preds = (distributed_tile*)tmap_.at(mask);
ptrs->for_each([&](indices_t idx){
Value *scalar = scalars->get_value(idx);
Value *ptr = ptrs->get_value(idx);
Value *pred = preds->get_value(idx);
BasicBlock *mask_then_bb = BasicBlock::Create(ctx, "mask_then", fn);
BasicBlock *mask_done_bb = BasicBlock::Create(ctx, "mask_done", fn);
builder.CreateCondBr(pred, mask_then_bb, mask_done_bb);
builder.SetInsertPoint(mask_then_bb);
builder.CreateStore(scalar, ptr);
builder.CreateBr(mask_done_bb);
builder.SetInsertPoint(mask_done_bb);
// std::string offset = "";
// if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(ptr))
// if(gep->getNumIndices() == 1)
@@ -796,14 +809,6 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
// std::string asm_str = "@$0 st.global.b32 [$1" + offset + "], $2;";
// InlineAsm *iasm = InlineAsm::get(ty, asm_str, "b,l,f", true);
// builder.CreateCall(iasm, {pred, ptr, scalar});
BasicBlock *mask_then_bb = BasicBlock::Create(ctx, "mask_then", fn);
BasicBlock *mask_done_bb = BasicBlock::Create(ctx, "mask_done", fn);
builder.CreateCondBr(pred, mask_then_bb, mask_done_bb);
builder.SetInsertPoint(mask_then_bb);
builder.CreateStore(scalar, ptr);
builder.CreateBr(mask_done_bb);
builder.SetInsertPoint(mask_done_bb);
});
}
else if(auto *x = dynamic_cast<ir::store_inst*>(ins)) {
@@ -893,11 +898,8 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
ir::value* in = ins->get_operand(0);
distributed_tile *in_tile = (distributed_tile*)tmap_.at(in);
result->for_each([&](indices_t out_idx){
indices_t in_idx;
for(size_t k = 0; k < shapes.size(); k++){
if(shapes[k]->get_value() > 1)
in_idx.push_back(out_idx[k]);
}
unsigned pos = result->get_linear_index(out_idx);
indices_t in_idx = in_tile->get_ordered_indices(pos);
result->set_value(out_idx, in_tile->get_value(in_idx));
});
}