This commit is contained in:
Philippe Tillet
2019-02-06 17:21:07 -05:00
parent 5aec34a094
commit 4490061950
4 changed files with 149 additions and 130 deletions

View File

@@ -82,9 +82,9 @@ int main() {
// std::vector<unsigned*> params = tune.get_params(module); // std::vector<unsigned*> params = tune.get_params(module);
// std::cout << params.size() << std::endl; // std::cout << params.size() << std::endl;
// selection.run(module, llvm_module); // selection.run(module, llvm_module);
// // print LLVM program // print LLVM program
// llvm::PrintModulePass print(llvm::outs()); llvm::PrintModulePass print(llvm::outs());
// llvm::AnalysisManager<llvm::Module> analysis; llvm::AnalysisManager<llvm::Module> analysis;
// print.run(llvm_module, analysis); print.run(llvm_module, analysis);
return 0; return 0;
} }

View File

@@ -78,8 +78,8 @@ class selection{
private: private:
// LLVM conversions // LLVM conversions
llvm::Type* llvm_type(ir::type *ty, llvm::LLVMContext &ctx); llvm::Type* llvm_type(ir::type *ty, llvm::LLVMContext &ctx);
llvm::Value* llvm_value(ir::value *v,llvm:: LLVMContext &ctx); llvm::Value* llvm_value(ir::value *v, llvm:: LLVMContext &ctx, llvm::IRBuilder<> &builder);
llvm::Instruction* llvm_inst(ir::instruction *inst, std::function<llvm::Value*(ir::value*)> value, llvm::LLVMContext &ctx); llvm::Instruction* llvm_inst(ir::instruction *inst, std::function<llvm::Value*(ir::value*)> value, llvm::LLVMContext &ctx, llvm::IRBuilder<> &builder);
llvm::Constant* llvm_constant(ir::constant *cst, llvm::LLVMContext &ctx); llvm::Constant* llvm_constant(ir::constant *cst, llvm::LLVMContext &ctx);
// grid construction // grid construction

View File

@@ -324,7 +324,7 @@ ir::value* initializer::codegen(ir::module * mod) const{
ir::value *value = ir::undef_value::get(ty); ir::value *value = ir::undef_value::get(ty);
if(expr_){ if(expr_){
value = expr_->codegen(mod); value = expr_->codegen(mod);
explicit_cast(mod->get_builder(), value, ty->get_scalar_ty()); value = explicit_cast(mod->get_builder(), value, ty->get_scalar_ty());
implicit_broadcast(mod, value, ty); implicit_broadcast(mod, value, ty);
} }
value->set_name(name); value->set_name(name);
@@ -336,85 +336,85 @@ ir::value* initializer::codegen(ir::module * mod) const{
/* Expression */ /* Expression */
/*------------------*/ /*------------------*/
/* Binary operator */ /* Binary operator */
ir::value *binary_operator::llvm_op(ir::module *mod, ir::builder &builder, ir::value *arg, ir::value *rhs, const std::string &name) const ir::value *binary_operator::llvm_op(ir::module *mod, ir::builder &builder, ir::value *lhs, ir::value *rhs, const std::string &name) const
{ {
bool is_float = false, is_ptr = false, is_int = false, is_signed = false; bool is_float = false, is_ptr = false, is_int = false, is_signed = false;
implicit_cast(builder, arg, rhs, is_float, is_ptr, is_int, is_signed); implicit_cast(builder, lhs, rhs, is_float, is_ptr, is_int, is_signed);
implicit_broadcast(mod, arg, rhs); implicit_broadcast(mod, lhs, rhs);
if(op_==MUL && is_float) if(op_==MUL && is_float)
return builder.create_fmul(arg, rhs, name); return builder.create_fmul(lhs, rhs, name);
if(op_==MUL && is_int) if(op_==MUL && is_int)
return builder.create_mul(arg, rhs, name); return builder.create_mul(lhs, rhs, name);
if(op_==DIV && is_float) if(op_==DIV && is_float)
return builder.create_fdiv(arg, rhs, name); return builder.create_fdiv(lhs, rhs, name);
if(op_==DIV && is_int && is_signed) if(op_==DIV && is_int && is_signed)
return builder.create_sdiv(arg, rhs, name); return builder.create_sdiv(lhs, rhs, name);
if(op_==DIV && is_int && !is_signed) if(op_==DIV && is_int && !is_signed)
return builder.create_udiv(arg, rhs, name); return builder.create_udiv(lhs, rhs, name);
if(op_==MOD && is_float) if(op_==MOD && is_float)
return builder.create_frem(arg, rhs, name); return builder.create_frem(lhs, rhs, name);
if(op_==MOD && is_int && is_signed) if(op_==MOD && is_int && is_signed)
return builder.create_srem(arg, rhs, name); return builder.create_srem(lhs, rhs, name);
if(op_==MOD && is_int && !is_signed) if(op_==MOD && is_int && !is_signed)
return builder.create_urem(arg, rhs, name); return builder.create_urem(lhs, rhs, name);
if(op_==ADD && is_float) if(op_==ADD && is_float)
return builder.create_fadd(arg, rhs, name); return builder.create_fadd(lhs, rhs, name);
if(op_==ADD && is_int) if(op_==ADD && is_int)
return builder.create_add(arg, rhs); return builder.create_add(lhs, rhs);
if(op_==ADD && is_ptr) if(op_==ADD && is_ptr)
return builder.create_gep(arg, {rhs}); return builder.create_gep(lhs, {rhs});
if(op_==SUB && is_float) if(op_==SUB && is_float)
return builder.create_fsub(arg, rhs, name); return builder.create_fsub(lhs, rhs, name);
if(op_==SUB && is_int) if(op_==SUB && is_int)
return builder.create_sub(arg, rhs, name); return builder.create_sub(lhs, rhs, name);
if(op_==SUB && is_ptr) if(op_==SUB && is_ptr)
return builder.create_gep(arg, {builder.create_neg(rhs)}); return builder.create_gep(lhs, {builder.create_neg(rhs)});
if(op_==LEFT_SHIFT) if(op_==LEFT_SHIFT)
return builder.create_shl(arg, rhs, name); return builder.create_shl(lhs, rhs, name);
if(op_==RIGHT_SHIFT) if(op_==RIGHT_SHIFT)
return builder.create_ashr(arg, rhs, name); return builder.create_ashr(lhs, rhs, name);
if(op_ == LT && is_float) if(op_ == LT && is_float)
return builder.create_fcmpOLT(arg, rhs, name); return builder.create_fcmpOLT(lhs, rhs, name);
if(op_ == LT && is_int && is_signed) if(op_ == LT && is_int && is_signed)
return builder.create_icmpSLT(arg, rhs, name); return builder.create_icmpSLT(lhs, rhs, name);
if(op_ == LT && is_int && !is_signed) if(op_ == LT && is_int && !is_signed)
return builder.create_icmpULT(arg, rhs, name); return builder.create_icmpULT(lhs, rhs, name);
if(op_ == GT && is_float) if(op_ == GT && is_float)
return builder.create_fcmpOGT(arg, rhs, name); return builder.create_fcmpOGT(lhs, rhs, name);
if(op_ == GT && is_int && is_signed) if(op_ == GT && is_int && is_signed)
return builder.create_icmpSGT(arg, rhs, name); return builder.create_icmpSGT(lhs, rhs, name);
if(op_ == GT && is_int && !is_signed) if(op_ == GT && is_int && !is_signed)
return builder.create_icmpUGT(arg, rhs, name); return builder.create_icmpUGT(lhs, rhs, name);
if(op_ == LE && is_float) if(op_ == LE && is_float)
return builder.create_fcmpOLE(arg, rhs, name); return builder.create_fcmpOLE(lhs, rhs, name);
if(op_ == LE && is_int && is_signed) if(op_ == LE && is_int && is_signed)
return builder.create_icmpSLE(arg, rhs, name); return builder.create_icmpSLE(lhs, rhs, name);
if(op_ == LE && is_int && !is_signed) if(op_ == LE && is_int && !is_signed)
return builder.create_icmpULE(arg, rhs, name); return builder.create_icmpULE(lhs, rhs, name);
if(op_ == GE && is_float) if(op_ == GE && is_float)
return builder.create_fcmpOGE(arg, rhs, name); return builder.create_fcmpOGE(lhs, rhs, name);
if(op_ == GE && is_int && is_signed) if(op_ == GE && is_int && is_signed)
return builder.create_icmpSGE(arg, rhs, name); return builder.create_icmpSGE(lhs, rhs, name);
if(op_ == GE && is_int && !is_signed) if(op_ == GE && is_int && !is_signed)
return builder.create_icmpUGE(arg, rhs, name); return builder.create_icmpUGE(lhs, rhs, name);
if(op_ == EQ && is_float) if(op_ == EQ && is_float)
return builder.create_fcmpOEQ(arg, rhs, name); return builder.create_fcmpOEQ(lhs, rhs, name);
if(op_ == EQ && is_int) if(op_ == EQ && is_int)
return builder.create_icmpEQ(arg, rhs, name); return builder.create_icmpEQ(lhs, rhs, name);
if(op_ == NE && is_float) if(op_ == NE && is_float)
return builder.create_fcmpONE(arg, rhs, name); return builder.create_fcmpONE(lhs, rhs, name);
if(op_ == NE && is_int) if(op_ == NE && is_int)
return builder.create_icmpNE(arg, rhs, name); return builder.create_icmpNE(lhs, rhs, name);
if(op_ == AND) if(op_ == AND)
return builder.create_and(arg, rhs, name); return builder.create_and(lhs, rhs, name);
if(op_ == XOR) if(op_ == XOR)
return builder.create_xor(arg, rhs, name); return builder.create_xor(lhs, rhs, name);
if(op_ == OR) if(op_ == OR)
return builder.create_or(arg, rhs, name); return builder.create_or(lhs, rhs, name);
if(op_ == LAND) if(op_ == LAND)
return builder.create_and(arg, rhs, name); return builder.create_and(lhs, rhs, name);
if(op_ == LOR) if(op_ == LOR)
return builder.create_or(arg, rhs, name); return builder.create_or(lhs, rhs, name);
throw std::runtime_error("unreachable"); throw std::runtime_error("unreachable");
} }

View File

@@ -105,49 +105,49 @@ Constant *selection::llvm_constant(ir::constant *cst, LLVMContext &ctx) {
/* convert ir::instruction to llvm::Instruction */ /* convert ir::instruction to llvm::Instruction */
Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir::value*)> value, LLVMContext & ctx) { Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir::value*)> value, LLVMContext & ctx, IRBuilder<> &builder) {
auto block = [&](ir::basic_block *x) { return bmap_.at(x); }; auto block = [&](ir::basic_block *x) { return bmap_.at(x); };
auto type = [&](ir::type *x) { return llvm_type(x, ctx); }; auto type = [&](ir::type *x) { return llvm_type(x, ctx); };
if(auto* ii = dynamic_cast<ir::cond_branch_inst*>(inst)){ if(auto* ii = dynamic_cast<ir::cond_branch_inst*>(inst)){
BasicBlock *true_dest = block(ii->get_true_dest()); BasicBlock *true_dest = block(ii->get_true_dest());
BasicBlock *false_dest = block(ii->get_false_dest()); BasicBlock *false_dest = block(ii->get_false_dest());
Value *cond = value(ii->get_cond()); Value *cond = value(ii->get_cond());
return BranchInst::Create(true_dest, false_dest, cond); return builder.CreateCondBr(cond, true_dest, false_dest);
} }
if(auto* ii = dynamic_cast<ir::uncond_branch_inst*>(inst)){ if(auto* ii = dynamic_cast<ir::uncond_branch_inst*>(inst)){
BasicBlock *dest = block(ii->get_dest()); BasicBlock *dest = block(ii->get_dest());
return BranchInst::Create(dest); return builder.CreateBr(dest);
} }
if(auto* ii = dynamic_cast<ir::phi_node*>(inst)){ if(auto* ii = dynamic_cast<ir::phi_node*>(inst)){
Type *ty = type(ii->get_type()->get_scalar_ty()); Type *ty = type(ii->get_type()->get_scalar_ty());
unsigned num_ops = ii->get_num_operands(); unsigned num_ops = ii->get_num_operands();
return PHINode::Create(ty, num_ops, ii->get_name()); return builder.CreatePHI(ty, num_ops);
} }
if(auto* ii = dynamic_cast<ir::return_inst*>(inst)){ if(auto* ii = dynamic_cast<ir::return_inst*>(inst)){
ir::value *ret_val = ii->get_return_value(); ir::value *ret_val = ii->get_return_value();
return ReturnInst::Create(ctx, ret_val?value(ret_val):nullptr); return builder.CreateRet(ret_val?value(ret_val):nullptr);
} }
if(auto* ii = dynamic_cast<ir::binary_operator*>(inst)){ if(auto* ii = dynamic_cast<ir::binary_operator*>(inst)){
Value *lhs = value(ii->get_operand(0)); Value *lhs = value(ii->get_operand(0));
Value *rhs = value(ii->get_operand(1)); Value *rhs = value(ii->get_operand(1));
return BinaryOperator::Create(ii->get_op(), lhs, rhs, ii->get_name()); return builder.Insert(BinaryOperator::Create(ii->get_op(), lhs, rhs));
} }
if(auto* ii = dynamic_cast<ir::icmp_inst*>(inst)){ if(auto* ii = dynamic_cast<ir::icmp_inst*>(inst)){
CmpInst::Predicate pred = ii->get_pred(); CmpInst::Predicate pred = ii->get_pred();
Value *lhs = value(ii->get_operand(0)); Value *lhs = value(ii->get_operand(0));
Value *rhs = value(ii->get_operand(1)); Value *rhs = value(ii->get_operand(1));
return CmpInst::Create(Instruction::ICmp, pred, lhs, rhs, ii->get_name()); return builder.Insert(CmpInst::Create(Instruction::ICmp, pred, lhs, rhs));
} }
if(auto* ii = dynamic_cast<ir::fcmp_inst*>(inst)){ if(auto* ii = dynamic_cast<ir::fcmp_inst*>(inst)){
CmpInst::Predicate pred = ii->get_pred(); CmpInst::Predicate pred = ii->get_pred();
Value *lhs = value(ii->get_operand(0)); Value *lhs = value(ii->get_operand(0));
Value *rhs = value(ii->get_operand(1)); Value *rhs = value(ii->get_operand(1));
return FCmpInst::Create(Instruction::FCmp, pred, lhs, rhs, ii->get_name()); return builder.Insert(FCmpInst::Create(Instruction::FCmp, pred, lhs, rhs));
} }
if(auto* ii = dynamic_cast<ir::cast_inst*>(inst)){ if(auto* ii = dynamic_cast<ir::cast_inst*>(inst)){
Value *arg = value(ii->get_operand(0)); Value *arg = value(ii->get_operand(0));
Type *dst_ty = type(ii->get_type()->get_scalar_ty()); Type *dst_ty = type(ii->get_type()->get_scalar_ty());
return CastInst::Create(ii->get_op(), arg, dst_ty, ii->get_name()); return builder.Insert(CastInst::Create(ii->get_op(), arg, dst_ty));
} }
if(auto* ii = dynamic_cast<ir::getelementptr_inst*>(inst)){ if(auto* ii = dynamic_cast<ir::getelementptr_inst*>(inst)){
std::vector<Value*> idx_vals; std::vector<Value*> idx_vals;
@@ -155,31 +155,31 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir
[&value](ir::value* x){ return value(x);}); [&value](ir::value* x){ return value(x);});
Type *source_ty = type(ii->get_source_elt_ty()->get_scalar_ty()); Type *source_ty = type(ii->get_source_elt_ty()->get_scalar_ty());
Value *arg = value(ii->get_operand(0)); Value *arg = value(ii->get_operand(0));
return GetElementPtrInst::Create(source_ty, arg, idx_vals, ii->get_name()); return builder.Insert(GetElementPtrInst::Create(source_ty, arg, idx_vals));
} }
if(ir::load_inst* ii = dynamic_cast<ir::load_inst*>(inst)){ if(ir::load_inst* ii = dynamic_cast<ir::load_inst*>(inst)){
Value *ptr = value(ii->get_pointer_operand()); Value *ptr = value(ii->get_pointer_operand());
return new LoadInst(ptr, ii->get_name()); return builder.CreateLoad(ptr);
} }
// unknown instruction // unknown instruction
throw std::runtime_error("unknown conversion from ir::type to Type"); throw std::runtime_error("unknown conversion from ir::type to Type");
} }
/* convert ir::value to llvm::Value */ /* convert ir::value to llvm::Value */
Value* selection::llvm_value(ir::value *v, LLVMContext &ctx) { Value* selection::llvm_value(ir::value *v, LLVMContext &ctx, IRBuilder<> &builder) {
assert(!v->get_type()->is_tile_ty()); assert(!v->get_type()->is_tile_ty());
if(vmap_.find(v) != vmap_.end()) if(vmap_.find(v) != vmap_.end())
return vmap_.at(v); return vmap_.at(v);
// create operands // create operands
if(auto *uu = dynamic_cast<ir::user*>(v)) if(auto *uu = dynamic_cast<ir::user*>(v))
for(ir::value* u: uu->ops()) for(ir::value* u: uu->ops())
vmap_[u] = llvm_value(u, ctx); vmap_.insert({u, llvm_value(u, ctx, builder)});
if(auto *cc = dynamic_cast<ir::constant*>(v)) if(auto *cc = dynamic_cast<ir::constant*>(v))
return llvm_constant(cc, ctx); return llvm_constant(cc, ctx);
// instruction // instruction
if(auto *ii = dynamic_cast<ir::instruction*>(v)){ if(auto *ii = dynamic_cast<ir::instruction*>(v)){
auto value = [&](ir::value *x) { return llvm_value(x, ctx); }; auto value = [&](ir::value *x) { return llvm_value(x, ctx, builder); };
return llvm_inst(ii, value, ctx); return llvm_inst(ii, value, ctx, builder);
} }
// unknown value // unknown value
throw std::runtime_error("unknown conversion from ir::value to Value"); throw std::runtime_error("unknown conversion from ir::value to Value");
@@ -308,7 +308,14 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
else else
axes[d].values = {builder.getInt32(0)}; axes[d].values = {builder.getInt32(0)};
} }
tmap_.insert({v, new distributed_tile(ty, shapes, axes)}); distributed_tile *T = new distributed_tile(ty, shapes, axes);
tmap_.insert({v, T});
// constant range
if(dynamic_cast<ir::constant*>(v))
T->for_each([&](indices_t idx){
T->set_value(idx, idx[0]);
});
} }
} }
@@ -340,76 +347,88 @@ void selection::init_grids(ir::function *fn, IRBuilder<> &builder){
void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &builder) { void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &builder) {
Module *module = builder.GetInsertBlock()->getModule(); Module *module = builder.GetInsertBlock()->getModule();
LLVMContext &ctx = builder.getContext(); LLVMContext &ctx = builder.getContext();
tile *ti = tmap_[ins]; // store
distributed_tile* result = (distributed_tile*)ti; if(auto *x = dynamic_cast<ir::store_inst*>(ins)) {
if(!ins->get_type()->is_tile_ty()) distributed_tile* ptr = (distributed_tile*)tmap_.at(x->get_pointer_operand());
return; tile *value = tmap_.at(x->get_value_operand());
const auto& shapes = ins->get_type()->get_tile_shapes(); ptr->for_each([&](indices_t idx){
// global_range builder.CreateStore(value->get_value(idx), ptr->get_value(idx));
if(auto *x = dynamic_cast<ir::get_global_range_inst*>(ins)) {
static std::array<Intrinsic::ID, 3> ctaid = {
Intrinsic::nvvm_read_ptx_sreg_ctaid_x,
Intrinsic::nvvm_read_ptx_sreg_ctaid_y,
Intrinsic::nvvm_read_ptx_sreg_ctaid_z
};
Function *get_group_id = Intrinsic::getDeclaration(module, ctaid[x->get_axis()]);
Value *group_id = builder.CreateCall(get_group_id, {});
Value *offset = builder.CreateMul(builder.getInt32(shapes[0]), group_id);
result->for_each([&](indices_t idx){
BinaryOperator *bin = static_cast<BinaryOperator*>(idx[0]);
result->set_value(idx, builder.CreateAdd(bin->getOperand(1),
builder.CreateAdd(bin->getOperand(0), offset)));
}); });
} }
// reshape
else if(dynamic_cast<ir::reshape_inst*>(ins)) {
ir::value* in = ins->get_operand(0);
distributed_tile *in_tile = (distributed_tile*)tmap_.at(in);
result->for_each([&](indices_t out_idx){
indices_t in_idx;
for(size_t k = 0; k < shapes.size(); k++){
if(shapes[k] > 1)
in_idx.push_back(out_idx[k]);
}
result->set_value(out_idx, in_tile->get_value(in_idx));
});
}
// splat
else if(dynamic_cast<ir::splat_inst*>(ins)) {
result->for_each([&](indices_t idx) {
result->set_value(idx, llvm_value(ins->get_operand(0), ctx));
});
}
// broadcast
else if(dynamic_cast<ir::broadcast_inst*>(ins)) {
ir::value* in = ins->get_operand(0);
const auto& in_shapes = in->get_type()->get_tile_shapes();
distributed_tile *in_tile = (distributed_tile*)tmap_.at(in);
result->for_each([&](indices_t out_idx){
indices_t in_idx = out_idx;
for(size_t k = 0; k < in_idx.size(); k++){
if(in_shapes[k] == 1)
in_idx[k] = builder.getInt32(0);
result->set_value(out_idx, in_tile->get_value(in_idx));
}
});
}
// copy to shared
else if(dynamic_cast<ir::copy_to_shared_inst*>(ins)) {
}
// element-wise
else { else {
result->for_each([&](indices_t idx){ tile *ti = tmap_[ins];
auto value = [&](ir::value *x) { distributed_tile* result = (distributed_tile*)ti;
if(x->get_type()->is_tile_ty()) if(!ins->get_type()->is_tile_ty())
return tmap_.at(x)->get_value(idx); return;
else const auto& shapes = ins->get_type()->get_tile_shapes();
return llvm_value(x, ctx); // global_range
if(auto *x = dynamic_cast<ir::get_global_range_inst*>(ins)) {
static std::array<Intrinsic::ID, 3> ctaid = {
Intrinsic::nvvm_read_ptx_sreg_ctaid_x,
Intrinsic::nvvm_read_ptx_sreg_ctaid_y,
Intrinsic::nvvm_read_ptx_sreg_ctaid_z
}; };
result->set_value(idx, llvm_inst(ins, value, ctx)); Function *get_group_id = Intrinsic::getDeclaration(module, ctaid[x->get_axis()]);
}); Value *group_id = builder.CreateCall(get_group_id, {});
Value *offset = builder.CreateMul(builder.getInt32(shapes[0]), group_id);
result->for_each([&](indices_t idx){
BinaryOperator *bin = static_cast<BinaryOperator*>(idx[0]);
result->set_value(idx, builder.CreateAdd(bin->getOperand(1),
builder.CreateAdd(bin->getOperand(0), offset)));
});
}
// reshape
else if(dynamic_cast<ir::reshape_inst*>(ins)) {
ir::value* in = ins->get_operand(0);
distributed_tile *in_tile = (distributed_tile*)tmap_.at(in);
result->for_each([&](indices_t out_idx){
indices_t in_idx;
for(size_t k = 0; k < shapes.size(); k++){
if(shapes[k] > 1)
in_idx.push_back(out_idx[k]);
}
result->set_value(out_idx, in_tile->get_value(in_idx));
});
}
// splat
else if(dynamic_cast<ir::splat_inst*>(ins)) {
result->for_each([&](indices_t idx) {
result->set_value(idx, llvm_value(ins->get_operand(0), ctx, builder));
});
}
// broadcast
else if(dynamic_cast<ir::broadcast_inst*>(ins)) {
ir::value* in = ins->get_operand(0);
const auto& in_shapes = in->get_type()->get_tile_shapes();
distributed_tile *in_tile = (distributed_tile*)tmap_.at(in);
result->for_each([&](indices_t out_idx){
indices_t in_idx = out_idx;
for(size_t k = 0; k < in_idx.size(); k++){
if(in_shapes[k] == 1)
in_idx[k] = builder.getInt32(0);
result->set_value(out_idx, in_tile->get_value(in_idx));
}
});
}
// copy to shared
else if(dynamic_cast<ir::copy_to_shared_inst*>(ins)) {
}
// element-wise
else {
result->for_each([&](indices_t idx){
auto value = [&](ir::value *x) {
if(x->get_type()->is_tile_ty())
return tmap_.at(x)->get_value(idx);
else
return llvm_value(x, ctx, builder);
};
result->set_value(idx, llvm_inst(ins, value, ctx, builder));
});
}
} }
} }
void selection::lower_instruction(ir::instruction *src, IRBuilder<> &builder) { void selection::lower_instruction(ir::instruction *src, IRBuilder<> &builder) {
@@ -418,7 +437,7 @@ void selection::lower_instruction(ir::instruction *src, IRBuilder<> &builder) {
lower_tile_instruction(src, builder); lower_tile_instruction(src, builder);
} }
else { else {
Instruction *i = (Instruction*)llvm_value(src, ctx); Instruction *i = (Instruction*)llvm_value(src, ctx, builder);
vmap_[src] = i; vmap_[src] = i;
builder.Insert(i); builder.Insert(i);
} }
@@ -459,7 +478,7 @@ void selection::run(ir::module &src, Module &dst){
for(unsigned i = 0; i < phi->get_num_incoming(); i++){ for(unsigned i = 0; i < phi->get_num_incoming(); i++){
ir::value *inc_val = phi->get_incoming_value(i); ir::value *inc_val = phi->get_incoming_value(i);
ir::basic_block *inc_block = phi->get_incoming_block(i); ir::basic_block *inc_block = phi->get_incoming_block(i);
Value *llvm_inc_val = llvm_value(inc_val, dst_ctx); Value *llvm_inc_val = llvm_value(inc_val, dst_ctx, dst_builder);
BasicBlock *llvm_block = bmap_[inc_block]; BasicBlock *llvm_block = bmap_[inc_block];
dst_phi->addIncoming(llvm_inc_val, llvm_block); dst_phi->addIncoming(llvm_inc_val, llvm_block);
} }