[code generation] fixed bugs in tile instructions lowering

This commit is contained in:
Philippe Tillet
2019-02-06 17:30:33 -05:00
parent 4490061950
commit 53aca3fa89
2 changed files with 19 additions and 19 deletions

View File

@@ -78,8 +78,8 @@ class selection{
private: private:
// LLVM conversions // LLVM conversions
llvm::Type* llvm_type(ir::type *ty, llvm::LLVMContext &ctx); llvm::Type* llvm_type(ir::type *ty, llvm::LLVMContext &ctx);
llvm::Value* llvm_value(ir::value *v, llvm:: LLVMContext &ctx, llvm::IRBuilder<> &builder); llvm::Value* llvm_value(ir::value *v, llvm::IRBuilder<> &builder);
llvm::Instruction* llvm_inst(ir::instruction *inst, std::function<llvm::Value*(ir::value*)> value, llvm::LLVMContext &ctx, llvm::IRBuilder<> &builder); llvm::Instruction* llvm_inst(ir::instruction *inst, std::function<llvm::Value*(ir::value*)> value, llvm::IRBuilder<> &builder);
llvm::Constant* llvm_constant(ir::constant *cst, llvm::LLVMContext &ctx); llvm::Constant* llvm_constant(ir::constant *cst, llvm::LLVMContext &ctx);
// grid construction // grid construction

View File

@@ -105,27 +105,28 @@ Constant *selection::llvm_constant(ir::constant *cst, LLVMContext &ctx) {
/* convert ir::instruction to llvm::Instruction */ /* convert ir::instruction to llvm::Instruction */
Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir::value*)> value, LLVMContext & ctx, IRBuilder<> &builder) { Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir::value*)> value, IRBuilder<> &builder) {
LLVMContext & ctx = builder.getContext();
auto block = [&](ir::basic_block *x) { return bmap_.at(x); }; auto block = [&](ir::basic_block *x) { return bmap_.at(x); };
auto type = [&](ir::type *x) { return llvm_type(x, ctx); }; auto type = [&](ir::type *x) { return llvm_type(x, ctx); };
if(auto* ii = dynamic_cast<ir::cond_branch_inst*>(inst)){ if(auto* ii = dynamic_cast<ir::cond_branch_inst*>(inst)){
BasicBlock *true_dest = block(ii->get_true_dest()); BasicBlock *true_dest = block(ii->get_true_dest());
BasicBlock *false_dest = block(ii->get_false_dest()); BasicBlock *false_dest = block(ii->get_false_dest());
Value *cond = value(ii->get_cond()); Value *cond = value(ii->get_cond());
return builder.CreateCondBr(cond, true_dest, false_dest); return builder.Insert(BranchInst::Create(true_dest, false_dest, cond));
} }
if(auto* ii = dynamic_cast<ir::uncond_branch_inst*>(inst)){ if(auto* ii = dynamic_cast<ir::uncond_branch_inst*>(inst)){
BasicBlock *dest = block(ii->get_dest()); BasicBlock *dest = block(ii->get_dest());
return builder.CreateBr(dest); return builder.Insert(BranchInst::Create(dest));
} }
if(auto* ii = dynamic_cast<ir::phi_node*>(inst)){ if(auto* ii = dynamic_cast<ir::phi_node*>(inst)){
Type *ty = type(ii->get_type()->get_scalar_ty()); Type *ty = type(ii->get_type()->get_scalar_ty());
unsigned num_ops = ii->get_num_operands(); unsigned num_ops = ii->get_num_operands();
return builder.CreatePHI(ty, num_ops); return builder.Insert(PHINode::Create(ty, num_ops));
} }
if(auto* ii = dynamic_cast<ir::return_inst*>(inst)){ if(auto* ii = dynamic_cast<ir::return_inst*>(inst)){
ir::value *ret_val = ii->get_return_value(); ir::value *ret_val = ii->get_return_value();
return builder.CreateRet(ret_val?value(ret_val):nullptr); return builder.Insert(ReturnInst::Create(ctx, ret_val?value(ret_val):nullptr));
} }
if(auto* ii = dynamic_cast<ir::binary_operator*>(inst)){ if(auto* ii = dynamic_cast<ir::binary_operator*>(inst)){
Value *lhs = value(ii->get_operand(0)); Value *lhs = value(ii->get_operand(0));
@@ -159,27 +160,28 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir
} }
if(ir::load_inst* ii = dynamic_cast<ir::load_inst*>(inst)){ if(ir::load_inst* ii = dynamic_cast<ir::load_inst*>(inst)){
Value *ptr = value(ii->get_pointer_operand()); Value *ptr = value(ii->get_pointer_operand());
return builder.CreateLoad(ptr); return builder.Insert(new LoadInst(ptr));
} }
// unknown instruction // unknown instruction
throw std::runtime_error("unknown conversion from ir::type to Type"); throw std::runtime_error("unknown conversion from ir::type to Type");
} }
/* convert ir::value to llvm::Value */ /* convert ir::value to llvm::Value */
Value* selection::llvm_value(ir::value *v, LLVMContext &ctx, IRBuilder<> &builder) { Value* selection::llvm_value(ir::value *v, IRBuilder<> &builder) {
assert(!v->get_type()->is_tile_ty()); assert(!v->get_type()->is_tile_ty());
LLVMContext &ctx = builder.getContext();
if(vmap_.find(v) != vmap_.end()) if(vmap_.find(v) != vmap_.end())
return vmap_.at(v); return vmap_.at(v);
// create operands // create operands
if(auto *uu = dynamic_cast<ir::user*>(v)) if(auto *uu = dynamic_cast<ir::user*>(v))
for(ir::value* u: uu->ops()) for(ir::value* u: uu->ops())
vmap_.insert({u, llvm_value(u, ctx, builder)}); vmap_.insert({u, llvm_value(u, builder)});
if(auto *cc = dynamic_cast<ir::constant*>(v)) if(auto *cc = dynamic_cast<ir::constant*>(v))
return llvm_constant(cc, ctx); return llvm_constant(cc, ctx);
// instruction // instruction
if(auto *ii = dynamic_cast<ir::instruction*>(v)){ if(auto *ii = dynamic_cast<ir::instruction*>(v)){
auto value = [&](ir::value *x) { return llvm_value(x, ctx, builder); }; auto value = [&](ir::value *x) { return llvm_value(x, builder); };
return llvm_inst(ii, value, ctx, builder); return llvm_inst(ii, value, builder);
} }
// unknown value // unknown value
throw std::runtime_error("unknown conversion from ir::value to Value"); throw std::runtime_error("unknown conversion from ir::value to Value");
@@ -393,7 +395,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
// splat // splat
else if(dynamic_cast<ir::splat_inst*>(ins)) { else if(dynamic_cast<ir::splat_inst*>(ins)) {
result->for_each([&](indices_t idx) { result->for_each([&](indices_t idx) {
result->set_value(idx, llvm_value(ins->get_operand(0), ctx, builder)); result->set_value(idx, llvm_value(ins->get_operand(0), builder));
}); });
} }
// broadcast // broadcast
@@ -421,9 +423,9 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
if(x->get_type()->is_tile_ty()) if(x->get_type()->is_tile_ty())
return tmap_.at(x)->get_value(idx); return tmap_.at(x)->get_value(idx);
else else
return llvm_value(x, ctx, builder); return llvm_value(x, builder);
}; };
result->set_value(idx, llvm_inst(ins, value, ctx, builder)); result->set_value(idx, llvm_inst(ins, value, builder));
}); });
} }
} }
@@ -432,14 +434,12 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
} }
void selection::lower_instruction(ir::instruction *src, IRBuilder<> &builder) { void selection::lower_instruction(ir::instruction *src, IRBuilder<> &builder) {
LLVMContext &ctx = builder.getContext();
if(src->has_tile_result_or_op()) { if(src->has_tile_result_or_op()) {
lower_tile_instruction(src, builder); lower_tile_instruction(src, builder);
} }
else { else {
Instruction *i = (Instruction*)llvm_value(src, ctx, builder); Instruction *i = (Instruction*)llvm_value(src, builder);
vmap_[src] = i; vmap_[src] = i;
builder.Insert(i);
} }
} }
@@ -478,7 +478,7 @@ void selection::run(ir::module &src, Module &dst){
for(unsigned i = 0; i < phi->get_num_incoming(); i++){ for(unsigned i = 0; i < phi->get_num_incoming(); i++){
ir::value *inc_val = phi->get_incoming_value(i); ir::value *inc_val = phi->get_incoming_value(i);
ir::basic_block *inc_block = phi->get_incoming_block(i); ir::basic_block *inc_block = phi->get_incoming_block(i);
Value *llvm_inc_val = llvm_value(inc_val, dst_ctx, dst_builder); Value *llvm_inc_val = llvm_value(inc_val, dst_builder);
BasicBlock *llvm_block = bmap_[inc_block]; BasicBlock *llvm_block = bmap_[inc_block];
dst_phi->addIncoming(llvm_inc_val, llvm_block); dst_phi->addIncoming(llvm_inc_val, llvm_block);
} }