From f9ba69f1a4690d7e3c45f2e53276a04d688e0e1d Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sat, 5 Jan 2019 19:23:00 -0500 Subject: [PATCH] [code generation] some bugfixes --- examples/matrix.cpp | 6 ++ include/codegen/lowering.h | 181 +++++++++++++++++++++++++++---------- include/ir/constant.h | 2 + include/ir/instructions.h | 4 +- include/ir/value.h | 1 + 5 files changed, 143 insertions(+), 51 deletions(-) diff --git a/examples/matrix.cpp b/examples/matrix.cpp index 41e6120e2..bae0ee52b 100644 --- a/examples/matrix.cpp +++ b/examples/matrix.cpp @@ -4,6 +4,9 @@ #include "ir/context.h" #include "ir/module.h" #include "codegen/lowering.h" +#include "llvm/IR/IRPrintingPasses.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/LLVMContext.h" typedef struct yy_buffer_state * YY_BUFFER_STATE; extern int yyparse(); @@ -35,6 +38,9 @@ int main() { tdl::ir::context context; tdl::ir::module module("matrix", context); program->codegen(&module); + llvm::LLVMContext llvm_context; + llvm::Module llvm_module("test", llvm_context); + tdl::codegen::lowering(module, llvm_module); // llvm::PrintModulePass print(llvm::outs()); // llvm::AnalysisManager analysis; // print.run(*module.handle(), analysis); diff --git a/include/codegen/lowering.h b/include/codegen/lowering.h index d418c38b7..673bcbf09 100644 --- a/include/codegen/lowering.h +++ b/include/codegen/lowering.h @@ -12,79 +12,146 @@ namespace tdl{ namespace codegen{ -/* convert ir::type to llvm::Type */ +using namespace llvm; -llvm::Type *llvm_type(ir::type *ty, llvm::LLVMContext &ctx) { +/* convert ir::type to Type */ +Type *llvm_type(ir::type *ty, LLVMContext &ctx) { // function if(auto* tt = dynamic_cast(ty)){ - llvm::Type *return_ty = llvm_type(tt->get_return_ty(), ctx); - std::vector param_tys; + Type *return_ty = llvm_type(tt->get_return_ty(), ctx); + std::vector param_tys; std::transform(tt->params_begin(), tt->params_end(), std::back_inserter(param_tys), [&ctx](ir::type* t){ return llvm_type(t, ctx);}); - return llvm::FunctionType::get(return_ty, param_tys, false); + return FunctionType::get(return_ty, param_tys, false); } // pointer if(ty->is_pointer_ty()){ - llvm::Type *elt_ty = llvm_type(ty->get_pointer_element_ty(), ctx); + Type *elt_ty = llvm_type(ty->get_pointer_element_ty(), ctx); unsigned addr_space = ty->get_pointer_address_space(); - return llvm::PointerType::get(elt_ty, addr_space); + return PointerType::get(elt_ty, addr_space); } // integer if(ty->is_integer_ty()){ unsigned bitwidth = ty->get_integer_bitwidth(); - return llvm::IntegerType::get(ctx, bitwidth); + return IntegerType::get(ctx, bitwidth); } // primitive types switch(ty->get_type_id()){ - case ir::type::VoidTyID: return llvm::Type::getVoidTy(ctx); - case ir::type::HalfTyID: return llvm::Type::getHalfTy(ctx); - case ir::type::FloatTyID: return llvm::Type::getFloatTy(ctx); - case ir::type::DoubleTyID: return llvm::Type::getDoubleTy(ctx); - case ir::type::X86_FP80TyID: return llvm::Type::getX86_FP80Ty(ctx); - case ir::type::PPC_FP128TyID: return llvm::Type::getPPC_FP128Ty(ctx); - case ir::type::LabelTyID: return llvm::Type::getLabelTy(ctx); - case ir::type::MetadataTyID: return llvm::Type::getMetadataTy(ctx); - case ir::type::TokenTyID: return llvm::Type::getTokenTy(ctx); + case ir::type::VoidTyID: return Type::getVoidTy(ctx); + case ir::type::HalfTyID: return Type::getHalfTy(ctx); + case ir::type::FloatTyID: return Type::getFloatTy(ctx); + case ir::type::DoubleTyID: return Type::getDoubleTy(ctx); + case ir::type::X86_FP80TyID: return Type::getX86_FP80Ty(ctx); + case ir::type::PPC_FP128TyID: return Type::getPPC_FP128Ty(ctx); + case ir::type::LabelTyID: return Type::getLabelTy(ctx); + case ir::type::MetadataTyID: return Type::getMetadataTy(ctx); + case ir::type::TokenTyID: return Type::getTokenTy(ctx); default: break; } // unknown type - throw std::runtime_error("unknown conversion from ir::type to llvm::Type"); + throw std::runtime_error("unknown conversion from ir::type to Type"); } -/* convert ir::instruction to llvm::Instruction */ -llvm::Instruction *llvm_inst(ir::instruction *inst, llvm::LLVMContext & ctx, - std::map &v, - std::map &b) { - if(auto* ii = dynamic_cast(inst)) - return llvm::BranchInst::Create(b[ii->get_true_dest()], b[ii->get_false_dest()], v[ii->get_cond()]); - if(auto* ii = dynamic_cast(inst)) - return llvm::BranchInst::Create(b[ii->get_dest()]); - if(auto* ii = dynamic_cast(inst)) - return llvm::PHINode::Create(llvm_type(ii->get_type(), ctx), ii->get_num_operands(), ii->get_name()); - if(auto* ii = dynamic_cast(inst)) - return llvm::ReturnInst::Create(ctx, v[ii->get_return_value()]); - if(auto* ii = dynamic_cast(inst)) - return llvm::BinaryOperator::Create(ii->get_op(), v[ii->get_operand(0)], v[ii->get_operand(1)], ii->get_name()); - if(auto* ii = dynamic_cast(inst)) - return llvm::CmpInst::Create(llvm::Instruction::ICmp, ii->get_pred(), v[ii->get_operand(0)], v[ii->get_operand(1)], ii->get_name()); - if(auto* ii = dynamic_cast(inst)) - return llvm::FCmpInst::Create(llvm::Instruction::FCmp, ii->get_pred(), v[ii->get_operand(0)], v[ii->get_operand(1)], ii->get_name()); - if(auto* ii = dynamic_cast(inst)) - return llvm::CastInst::Create(ii->get_op(), v[ii->get_operand(0)], llvm_type(ii->get_type(), ctx), ii->get_name()); - if(auto* ii = dynamic_cast(inst)){ - std::vector idx_vals; - std::transform(ii->idx_begin(), ii->idx_end(), std::back_inserter(idx_vals), - [&v](ir::value* x){ return v[x];}); - return llvm::GetElementPtrInst::Create(llvm_type(ii->get_source_elt_ty(), ctx), v[ii->get_operand(0)], idx_vals, ii->get_name()); +Value* llvm_value(ir::value *v, LLVMContext &ctx, + std::map &vmap, + std::map &bmap); + +/* convert ir::constant to Constant */ +Constant *llvm_constant(ir::constant *cst, LLVMContext &ctx) { + Type *dst_ty = llvm_type(cst->get_type(), ctx); + if(auto* cc = dynamic_cast(cst)) + return ConstantInt::get(dst_ty, cc->get_value()); + if(auto* cc = dynamic_cast(cst)) + return ConstantFP::get(dst_ty, cc->get_value()); + // unknown constant + throw std::runtime_error("unknown conversion from ir::constant to Constant"); +} + + +/* convert ir::instruction to Instruction */ +Instruction *llvm_inst(ir::instruction *inst, LLVMContext & ctx, + std::map &vmap, + std::map &bmap) { + auto value = [&](ir::value *x) { return llvm_value(x, ctx, vmap, bmap); }; + auto block = [&](ir::basic_block *x) { return bmap.at(x); }; + auto type = [&](ir::type *x) { return llvm_type(x, ctx); }; + if(auto* ii = dynamic_cast(inst)){ + BasicBlock *true_dest = block(ii->get_true_dest()); + BasicBlock *false_dest = block(ii->get_false_dest()); + Value *cond = value(ii->get_cond()); + return BranchInst::Create(true_dest, false_dest, cond); + } + if(auto* ii = dynamic_cast(inst)){ + BasicBlock *dest = block(ii->get_dest()); + return BranchInst::Create(dest); + } + if(auto* ii = dynamic_cast(inst)){ + Type *ty = type(ii->get_type()); + unsigned num_ops = ii->get_num_operands(); + return PHINode::Create(ty, num_ops, ii->get_name()); + } + if(auto* ii = dynamic_cast(inst)){ + Value *ret_val = value(ii->get_return_value()); + return ReturnInst::Create(ctx, ret_val); + } + if(auto* ii = dynamic_cast(inst)){ + Value *lhs = value(ii->get_operand(0)); + Value *rhs = value(ii->get_operand(1)); + return BinaryOperator::Create(ii->get_op(), lhs, rhs, ii->get_name()); + } + if(auto* ii = dynamic_cast(inst)){ + CmpInst::Predicate pred = ii->get_pred(); + Value *lhs = value(ii->get_operand(0)); + Value *rhs = value(ii->get_operand(1)); + return CmpInst::Create(Instruction::ICmp, pred, lhs, rhs, ii->get_name()); + } + if(auto* ii = dynamic_cast(inst)){ + CmpInst::Predicate pred = ii->get_pred(); + Value *lhs = value(ii->get_operand(0)); + Value *rhs = value(ii->get_operand(1)); + return FCmpInst::Create(Instruction::FCmp, pred, lhs, rhs, ii->get_name()); + } + if(auto* ii = dynamic_cast(inst)){ + Value *arg = value(ii->get_operand(0)); + Type *dst_ty = type(ii->get_type()); + return CastInst::Create(ii->get_op(), arg, dst_ty, ii->get_name()); + } + if(auto* ii = dynamic_cast(inst)){ + std::vector idx_vals; + std::transform(ii->idx_begin(), ii->idx_end(), std::back_inserter(idx_vals), + [&value](ir::value* x){ return value(x);}); + Type *source_ty = type(ii->get_source_elt_ty()); + Value *arg = value(ii->get_operand(0)); + return GetElementPtrInst::Create(source_ty, arg, idx_vals, ii->get_name()); + } + if(ir::load_inst* ii = dynamic_cast(inst)){ + Value *ptr = value(ii->get_pointer_operand()); + return new LoadInst(ptr, ii->get_name()); } - if(ir::load_inst* ii = dynamic_cast(inst)) - return new llvm::LoadInst(v[ii->get_pointer_operand()], ii->get_name()); // unknown instruction - throw std::runtime_error("unknown conversion from ir::type to llvm::Type"); + throw std::runtime_error("unknown conversion from ir::type to Type"); } -void lowering(ir::module &src, llvm::Module &dst){ - using namespace llvm; +Value* llvm_value(ir::value *v, LLVMContext &ctx, + std::map &vmap, + std::map &bmap) { + if(vmap.find(v) != vmap.end()) + return vmap.at(v); + // create operands + if(auto *uu = dynamic_cast(v)) + for(ir::use u: uu->ops()) + vmap[u.get()] = llvm_value(u, ctx, vmap, bmap); + // constant + if(auto *cc = dynamic_cast(v)) + return llvm_constant(cc, ctx); + // instruction + if(auto *ii = dynamic_cast(v)) + return llvm_inst(ii, ctx, vmap, bmap); + // unknown value + throw std::runtime_error("unknown conversion from ir::value to Value"); +} + +void lowering(ir::module &src, Module &dst){ std::map vmap; std::map bmap; LLVMContext &dst_ctx = dst.getContext(); @@ -93,10 +160,14 @@ void lowering(ir::module &src, llvm::Module &dst){ for(ir::function *fn: src.get_function_list()) { // create LLVM function Type *fn_ty = llvm_type(fn->get_type(), dst_ctx); - Function *dst_function = (Function*)dst.getOrInsertFunction(fn->get_name(), fn_ty); + Function *dst_fn = (Function*)dst.getOrInsertFunction(fn->get_name(), fn_ty); + // map parameters + for(unsigned i = 0; i < fn->args().size(); i++) { + vmap[fn->args()[i]] = &*(dst_fn->arg_begin() + i); + } // create blocks for(ir::basic_block *block: fn->blocks()) { - BasicBlock *dst_block = BasicBlock::Create(dst_ctx, block->get_name(), dst_function); + BasicBlock *dst_block = BasicBlock::Create(dst_ctx, block->get_name(), dst_fn); bmap[block] = dst_block; } // iterate through block @@ -108,6 +179,16 @@ void lowering(ir::module &src, llvm::Module &dst){ } } // add phi operands + for(ir::basic_block *block: fn->blocks()) + for(ir::instruction *inst: block->get_inst_list()) + if(auto *phi = dynamic_cast(inst)){ + PHINode *dst_phi = (PHINode*)vmap.at(phi); + for(unsigned i = 0; i < phi->get_num_incoming(); i++){ + ir::value *inc_val = phi->get_incoming_value(i); + ir::basic_block *inc_block = phi->get_incoming_block(i); + dst_phi->addIncoming(vmap[inc_val], bmap[inc_block]); + } + } } } diff --git a/include/ir/constant.h b/include/ir/constant.h index dff6606f8..57af4ad33 100644 --- a/include/ir/constant.h +++ b/include/ir/constant.h @@ -33,6 +33,7 @@ class constant_int: public constant{ constant_int(type *ty, uint64_t value); public: + uint64_t get_value() const { return value_; } static constant *get(type *ty, uint64_t value); private: @@ -44,6 +45,7 @@ class constant_fp: public constant{ constant_fp(context &ctx, double value); public: + double get_value() { return value_; } static constant* get_negative_zero(type *ty); static constant* get_zero_value_for_negation(type *ty); static constant *get(context &ctx, double v); diff --git a/include/ir/instructions.h b/include/ir/instructions.h index 8ec38cd1e..8632d9098 100644 --- a/include/ir/instructions.h +++ b/include/ir/instructions.h @@ -42,7 +42,9 @@ private: public: void set_incoming_value(unsigned i, value *v); void set_incoming_block(unsigned i, basic_block *block); - + value *get_incoming_value(unsigned i) { return get_operand(i); } + basic_block *get_incoming_block(unsigned i) { return blocks_[i]; } + unsigned get_num_incoming() { return get_num_operands(); } void add_incoming(value *v, basic_block *block); // Factory methods diff --git a/include/ir/value.h b/include/ir/value.h index df2f099de..35c6ca839 100644 --- a/include/ir/value.h +++ b/include/ir/value.h @@ -71,6 +71,7 @@ public: : value(ty, name), ops_(num_ops){ } // Operands + const std::vector& ops() { return ops_; } op_iterator op_begin() { return ops_.begin(); } op_iterator op_end() { return ops_.end(); } void set_operand(unsigned i, value *x);