This commit is contained in:
Philippe Tillet
2019-02-06 17:21:07 -05:00
parent 5aec34a094
commit 4490061950
4 changed files with 149 additions and 130 deletions

View File

@@ -82,9 +82,9 @@ int main() {
// std::vector<unsigned*> params = tune.get_params(module);
// std::cout << params.size() << std::endl;
// selection.run(module, llvm_module);
// // print LLVM program
// llvm::PrintModulePass print(llvm::outs());
// llvm::AnalysisManager<llvm::Module> analysis;
// print.run(llvm_module, analysis);
// print LLVM program
llvm::PrintModulePass print(llvm::outs());
llvm::AnalysisManager<llvm::Module> analysis;
print.run(llvm_module, analysis);
return 0;
}

View File

@@ -78,8 +78,8 @@ class selection{
private:
// LLVM conversions
llvm::Type* llvm_type(ir::type *ty, llvm::LLVMContext &ctx);
llvm::Value* llvm_value(ir::value *v,llvm:: LLVMContext &ctx);
llvm::Instruction* llvm_inst(ir::instruction *inst, std::function<llvm::Value*(ir::value*)> value, llvm::LLVMContext &ctx);
llvm::Value* llvm_value(ir::value *v, llvm:: LLVMContext &ctx, llvm::IRBuilder<> &builder);
llvm::Instruction* llvm_inst(ir::instruction *inst, std::function<llvm::Value*(ir::value*)> value, llvm::LLVMContext &ctx, llvm::IRBuilder<> &builder);
llvm::Constant* llvm_constant(ir::constant *cst, llvm::LLVMContext &ctx);
// grid construction

View File

@@ -324,7 +324,7 @@ ir::value* initializer::codegen(ir::module * mod) const{
ir::value *value = ir::undef_value::get(ty);
if(expr_){
value = expr_->codegen(mod);
explicit_cast(mod->get_builder(), value, ty->get_scalar_ty());
value = explicit_cast(mod->get_builder(), value, ty->get_scalar_ty());
implicit_broadcast(mod, value, ty);
}
value->set_name(name);
@@ -336,85 +336,85 @@ ir::value* initializer::codegen(ir::module * mod) const{
/* Expression */
/*------------------*/
/* Binary operator */
ir::value *binary_operator::llvm_op(ir::module *mod, ir::builder &builder, ir::value *arg, ir::value *rhs, const std::string &name) const
ir::value *binary_operator::llvm_op(ir::module *mod, ir::builder &builder, ir::value *lhs, ir::value *rhs, const std::string &name) const
{
bool is_float = false, is_ptr = false, is_int = false, is_signed = false;
implicit_cast(builder, arg, rhs, is_float, is_ptr, is_int, is_signed);
implicit_broadcast(mod, arg, rhs);
implicit_cast(builder, lhs, rhs, is_float, is_ptr, is_int, is_signed);
implicit_broadcast(mod, lhs, rhs);
if(op_==MUL && is_float)
return builder.create_fmul(arg, rhs, name);
return builder.create_fmul(lhs, rhs, name);
if(op_==MUL && is_int)
return builder.create_mul(arg, rhs, name);
return builder.create_mul(lhs, rhs, name);
if(op_==DIV && is_float)
return builder.create_fdiv(arg, rhs, name);
return builder.create_fdiv(lhs, rhs, name);
if(op_==DIV && is_int && is_signed)
return builder.create_sdiv(arg, rhs, name);
return builder.create_sdiv(lhs, rhs, name);
if(op_==DIV && is_int && !is_signed)
return builder.create_udiv(arg, rhs, name);
return builder.create_udiv(lhs, rhs, name);
if(op_==MOD && is_float)
return builder.create_frem(arg, rhs, name);
return builder.create_frem(lhs, rhs, name);
if(op_==MOD && is_int && is_signed)
return builder.create_srem(arg, rhs, name);
return builder.create_srem(lhs, rhs, name);
if(op_==MOD && is_int && !is_signed)
return builder.create_urem(arg, rhs, name);
return builder.create_urem(lhs, rhs, name);
if(op_==ADD && is_float)
return builder.create_fadd(arg, rhs, name);
return builder.create_fadd(lhs, rhs, name);
if(op_==ADD && is_int)
return builder.create_add(arg, rhs);
return builder.create_add(lhs, rhs);
if(op_==ADD && is_ptr)
return builder.create_gep(arg, {rhs});
return builder.create_gep(lhs, {rhs});
if(op_==SUB && is_float)
return builder.create_fsub(arg, rhs, name);
return builder.create_fsub(lhs, rhs, name);
if(op_==SUB && is_int)
return builder.create_sub(arg, rhs, name);
return builder.create_sub(lhs, rhs, name);
if(op_==SUB && is_ptr)
return builder.create_gep(arg, {builder.create_neg(rhs)});
return builder.create_gep(lhs, {builder.create_neg(rhs)});
if(op_==LEFT_SHIFT)
return builder.create_shl(arg, rhs, name);
return builder.create_shl(lhs, rhs, name);
if(op_==RIGHT_SHIFT)
return builder.create_ashr(arg, rhs, name);
return builder.create_ashr(lhs, rhs, name);
if(op_ == LT && is_float)
return builder.create_fcmpOLT(arg, rhs, name);
return builder.create_fcmpOLT(lhs, rhs, name);
if(op_ == LT && is_int && is_signed)
return builder.create_icmpSLT(arg, rhs, name);
return builder.create_icmpSLT(lhs, rhs, name);
if(op_ == LT && is_int && !is_signed)
return builder.create_icmpULT(arg, rhs, name);
return builder.create_icmpULT(lhs, rhs, name);
if(op_ == GT && is_float)
return builder.create_fcmpOGT(arg, rhs, name);
return builder.create_fcmpOGT(lhs, rhs, name);
if(op_ == GT && is_int && is_signed)
return builder.create_icmpSGT(arg, rhs, name);
return builder.create_icmpSGT(lhs, rhs, name);
if(op_ == GT && is_int && !is_signed)
return builder.create_icmpUGT(arg, rhs, name);
return builder.create_icmpUGT(lhs, rhs, name);
if(op_ == LE && is_float)
return builder.create_fcmpOLE(arg, rhs, name);
return builder.create_fcmpOLE(lhs, rhs, name);
if(op_ == LE && is_int && is_signed)
return builder.create_icmpSLE(arg, rhs, name);
return builder.create_icmpSLE(lhs, rhs, name);
if(op_ == LE && is_int && !is_signed)
return builder.create_icmpULE(arg, rhs, name);
return builder.create_icmpULE(lhs, rhs, name);
if(op_ == GE && is_float)
return builder.create_fcmpOGE(arg, rhs, name);
return builder.create_fcmpOGE(lhs, rhs, name);
if(op_ == GE && is_int && is_signed)
return builder.create_icmpSGE(arg, rhs, name);
return builder.create_icmpSGE(lhs, rhs, name);
if(op_ == GE && is_int && !is_signed)
return builder.create_icmpUGE(arg, rhs, name);
return builder.create_icmpUGE(lhs, rhs, name);
if(op_ == EQ && is_float)
return builder.create_fcmpOEQ(arg, rhs, name);
return builder.create_fcmpOEQ(lhs, rhs, name);
if(op_ == EQ && is_int)
return builder.create_icmpEQ(arg, rhs, name);
return builder.create_icmpEQ(lhs, rhs, name);
if(op_ == NE && is_float)
return builder.create_fcmpONE(arg, rhs, name);
return builder.create_fcmpONE(lhs, rhs, name);
if(op_ == NE && is_int)
return builder.create_icmpNE(arg, rhs, name);
return builder.create_icmpNE(lhs, rhs, name);
if(op_ == AND)
return builder.create_and(arg, rhs, name);
return builder.create_and(lhs, rhs, name);
if(op_ == XOR)
return builder.create_xor(arg, rhs, name);
return builder.create_xor(lhs, rhs, name);
if(op_ == OR)
return builder.create_or(arg, rhs, name);
return builder.create_or(lhs, rhs, name);
if(op_ == LAND)
return builder.create_and(arg, rhs, name);
return builder.create_and(lhs, rhs, name);
if(op_ == LOR)
return builder.create_or(arg, rhs, name);
return builder.create_or(lhs, rhs, name);
throw std::runtime_error("unreachable");
}

View File

@@ -105,49 +105,49 @@ Constant *selection::llvm_constant(ir::constant *cst, LLVMContext &ctx) {
/* convert ir::instruction to llvm::Instruction */
Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir::value*)> value, LLVMContext & ctx) {
Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir::value*)> value, LLVMContext & ctx, IRBuilder<> &builder) {
auto block = [&](ir::basic_block *x) { return bmap_.at(x); };
auto type = [&](ir::type *x) { return llvm_type(x, ctx); };
if(auto* ii = dynamic_cast<ir::cond_branch_inst*>(inst)){
BasicBlock *true_dest = block(ii->get_true_dest());
BasicBlock *false_dest = block(ii->get_false_dest());
Value *cond = value(ii->get_cond());
return BranchInst::Create(true_dest, false_dest, cond);
return builder.CreateCondBr(cond, true_dest, false_dest);
}
if(auto* ii = dynamic_cast<ir::uncond_branch_inst*>(inst)){
BasicBlock *dest = block(ii->get_dest());
return BranchInst::Create(dest);
return builder.CreateBr(dest);
}
if(auto* ii = dynamic_cast<ir::phi_node*>(inst)){
Type *ty = type(ii->get_type()->get_scalar_ty());
unsigned num_ops = ii->get_num_operands();
return PHINode::Create(ty, num_ops, ii->get_name());
return builder.CreatePHI(ty, num_ops);
}
if(auto* ii = dynamic_cast<ir::return_inst*>(inst)){
ir::value *ret_val = ii->get_return_value();
return ReturnInst::Create(ctx, ret_val?value(ret_val):nullptr);
return builder.CreateRet(ret_val?value(ret_val):nullptr);
}
if(auto* ii = dynamic_cast<ir::binary_operator*>(inst)){
Value *lhs = value(ii->get_operand(0));
Value *rhs = value(ii->get_operand(1));
return BinaryOperator::Create(ii->get_op(), lhs, rhs, ii->get_name());
return builder.Insert(BinaryOperator::Create(ii->get_op(), lhs, rhs));
}
if(auto* ii = dynamic_cast<ir::icmp_inst*>(inst)){
CmpInst::Predicate pred = ii->get_pred();
Value *lhs = value(ii->get_operand(0));
Value *rhs = value(ii->get_operand(1));
return CmpInst::Create(Instruction::ICmp, pred, lhs, rhs, ii->get_name());
return builder.Insert(CmpInst::Create(Instruction::ICmp, pred, lhs, rhs));
}
if(auto* ii = dynamic_cast<ir::fcmp_inst*>(inst)){
CmpInst::Predicate pred = ii->get_pred();
Value *lhs = value(ii->get_operand(0));
Value *rhs = value(ii->get_operand(1));
return FCmpInst::Create(Instruction::FCmp, pred, lhs, rhs, ii->get_name());
return builder.Insert(FCmpInst::Create(Instruction::FCmp, pred, lhs, rhs));
}
if(auto* ii = dynamic_cast<ir::cast_inst*>(inst)){
Value *arg = value(ii->get_operand(0));
Type *dst_ty = type(ii->get_type()->get_scalar_ty());
return CastInst::Create(ii->get_op(), arg, dst_ty, ii->get_name());
return builder.Insert(CastInst::Create(ii->get_op(), arg, dst_ty));
}
if(auto* ii = dynamic_cast<ir::getelementptr_inst*>(inst)){
std::vector<Value*> idx_vals;
@@ -155,31 +155,31 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir
[&value](ir::value* x){ return value(x);});
Type *source_ty = type(ii->get_source_elt_ty()->get_scalar_ty());
Value *arg = value(ii->get_operand(0));
return GetElementPtrInst::Create(source_ty, arg, idx_vals, ii->get_name());
return builder.Insert(GetElementPtrInst::Create(source_ty, arg, idx_vals));
}
if(ir::load_inst* ii = dynamic_cast<ir::load_inst*>(inst)){
Value *ptr = value(ii->get_pointer_operand());
return new LoadInst(ptr, ii->get_name());
return builder.CreateLoad(ptr);
}
// unknown instruction
throw std::runtime_error("unknown conversion from ir::type to Type");
}
/* convert ir::value to llvm::Value */
Value* selection::llvm_value(ir::value *v, LLVMContext &ctx) {
Value* selection::llvm_value(ir::value *v, LLVMContext &ctx, IRBuilder<> &builder) {
assert(!v->get_type()->is_tile_ty());
if(vmap_.find(v) != vmap_.end())
return vmap_.at(v);
// create operands
if(auto *uu = dynamic_cast<ir::user*>(v))
for(ir::value* u: uu->ops())
vmap_[u] = llvm_value(u, ctx);
vmap_.insert({u, llvm_value(u, ctx, builder)});
if(auto *cc = dynamic_cast<ir::constant*>(v))
return llvm_constant(cc, ctx);
// instruction
if(auto *ii = dynamic_cast<ir::instruction*>(v)){
auto value = [&](ir::value *x) { return llvm_value(x, ctx); };
return llvm_inst(ii, value, ctx);
auto value = [&](ir::value *x) { return llvm_value(x, ctx, builder); };
return llvm_inst(ii, value, ctx, builder);
}
// unknown value
throw std::runtime_error("unknown conversion from ir::value to Value");
@@ -308,7 +308,14 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
else
axes[d].values = {builder.getInt32(0)};
}
tmap_.insert({v, new distributed_tile(ty, shapes, axes)});
distributed_tile *T = new distributed_tile(ty, shapes, axes);
tmap_.insert({v, T});
// constant range
if(dynamic_cast<ir::constant*>(v))
T->for_each([&](indices_t idx){
T->set_value(idx, idx[0]);
});
}
}
@@ -340,6 +347,15 @@ void selection::init_grids(ir::function *fn, IRBuilder<> &builder){
void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &builder) {
Module *module = builder.GetInsertBlock()->getModule();
LLVMContext &ctx = builder.getContext();
// store
if(auto *x = dynamic_cast<ir::store_inst*>(ins)) {
distributed_tile* ptr = (distributed_tile*)tmap_.at(x->get_pointer_operand());
tile *value = tmap_.at(x->get_value_operand());
ptr->for_each([&](indices_t idx){
builder.CreateStore(value->get_value(idx), ptr->get_value(idx));
});
}
else {
tile *ti = tmap_[ins];
distributed_tile* result = (distributed_tile*)ti;
if(!ins->get_type()->is_tile_ty())
@@ -377,7 +393,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
// splat
else if(dynamic_cast<ir::splat_inst*>(ins)) {
result->for_each([&](indices_t idx) {
result->set_value(idx, llvm_value(ins->get_operand(0), ctx));
result->set_value(idx, llvm_value(ins->get_operand(0), ctx, builder));
});
}
// broadcast
@@ -405,20 +421,23 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
if(x->get_type()->is_tile_ty())
return tmap_.at(x)->get_value(idx);
else
return llvm_value(x, ctx);
return llvm_value(x, ctx, builder);
};
result->set_value(idx, llvm_inst(ins, value, ctx));
result->set_value(idx, llvm_inst(ins, value, ctx, builder));
});
}
}
}
void selection::lower_instruction(ir::instruction *src, IRBuilder<> &builder) {
LLVMContext &ctx = builder.getContext();
if(src->has_tile_result_or_op()) {
lower_tile_instruction(src, builder);
}
else {
Instruction *i = (Instruction*)llvm_value(src, ctx);
Instruction *i = (Instruction*)llvm_value(src, ctx, builder);
vmap_[src] = i;
builder.Insert(i);
}
@@ -459,7 +478,7 @@ void selection::run(ir::module &src, Module &dst){
for(unsigned i = 0; i < phi->get_num_incoming(); i++){
ir::value *inc_val = phi->get_incoming_value(i);
ir::basic_block *inc_block = phi->get_incoming_block(i);
Value *llvm_inc_val = llvm_value(inc_val, dst_ctx);
Value *llvm_inc_val = llvm_value(inc_val, dst_ctx, dst_builder);
BasicBlock *llvm_block = bmap_[inc_block];
dst_phi->addIncoming(llvm_inc_val, llvm_block);
}