[dnn] better specification of recompilation key
This commit is contained in:
@@ -74,6 +74,11 @@ unsigned distributed_tile::get_linear_index(indices_t idx) {
|
||||
return indices_[idx];
|
||||
}
|
||||
|
||||
indices_t distributed_tile::get_ordered_indices(unsigned id) {
|
||||
return ordered_indices_.at(id);
|
||||
}
|
||||
|
||||
|
||||
void distributed_tile::for_each(std::function<void (indices_t)> fn) {
|
||||
for(unsigned i = 0; i < ordered_indices_.size(); i++)
|
||||
if(i % vector_size_ == 0)
|
||||
@@ -779,13 +784,21 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
// store
|
||||
if(auto *x = dynamic_cast<ir::masked_store_inst*>(ins)){
|
||||
distributed_tile* ptrs = (distributed_tile*)tmap_.at(x->get_pointer_operand());
|
||||
tile *scalars = tmap_.at(x->get_value_operand());
|
||||
distributed_tile* scalars = (distributed_tile*)tmap_.at(x->get_value_operand());
|
||||
ir::value *mask = x->get_mask_operand();
|
||||
distributed_tile* preds = (distributed_tile*)tmap_.at(mask);
|
||||
ptrs->for_each([&](indices_t idx){
|
||||
Value *scalar = scalars->get_value(idx);
|
||||
Value *ptr = ptrs->get_value(idx);
|
||||
Value *pred = preds->get_value(idx);
|
||||
BasicBlock *mask_then_bb = BasicBlock::Create(ctx, "mask_then", fn);
|
||||
BasicBlock *mask_done_bb = BasicBlock::Create(ctx, "mask_done", fn);
|
||||
builder.CreateCondBr(pred, mask_then_bb, mask_done_bb);
|
||||
builder.SetInsertPoint(mask_then_bb);
|
||||
builder.CreateStore(scalar, ptr);
|
||||
builder.CreateBr(mask_done_bb);
|
||||
builder.SetInsertPoint(mask_done_bb);
|
||||
|
||||
// std::string offset = "";
|
||||
// if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(ptr))
|
||||
// if(gep->getNumIndices() == 1)
|
||||
@@ -796,14 +809,6 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
// std::string asm_str = "@$0 st.global.b32 [$1" + offset + "], $2;";
|
||||
// InlineAsm *iasm = InlineAsm::get(ty, asm_str, "b,l,f", true);
|
||||
// builder.CreateCall(iasm, {pred, ptr, scalar});
|
||||
|
||||
BasicBlock *mask_then_bb = BasicBlock::Create(ctx, "mask_then", fn);
|
||||
BasicBlock *mask_done_bb = BasicBlock::Create(ctx, "mask_done", fn);
|
||||
builder.CreateCondBr(pred, mask_then_bb, mask_done_bb);
|
||||
builder.SetInsertPoint(mask_then_bb);
|
||||
builder.CreateStore(scalar, ptr);
|
||||
builder.CreateBr(mask_done_bb);
|
||||
builder.SetInsertPoint(mask_done_bb);
|
||||
});
|
||||
}
|
||||
else if(auto *x = dynamic_cast<ir::store_inst*>(ins)) {
|
||||
@@ -893,11 +898,8 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
ir::value* in = ins->get_operand(0);
|
||||
distributed_tile *in_tile = (distributed_tile*)tmap_.at(in);
|
||||
result->for_each([&](indices_t out_idx){
|
||||
indices_t in_idx;
|
||||
for(size_t k = 0; k < shapes.size(); k++){
|
||||
if(shapes[k]->get_value() > 1)
|
||||
in_idx.push_back(out_idx[k]);
|
||||
}
|
||||
unsigned pos = result->get_linear_index(out_idx);
|
||||
indices_t in_idx = in_tile->get_ordered_indices(pos);
|
||||
result->set_value(out_idx, in_tile->get_value(in_idx));
|
||||
});
|
||||
}
|
||||
|
Reference in New Issue
Block a user