[code generation] some bugfixes

This commit is contained in:
Philippe Tillet
2019-01-05 19:23:00 -05:00
parent ec656af57c
commit f9ba69f1a4
5 changed files with 143 additions and 51 deletions

View File

@@ -4,6 +4,9 @@
#include "ir/context.h"
#include "ir/module.h"
#include "codegen/lowering.h"
#include "llvm/IR/IRPrintingPasses.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/LLVMContext.h"
typedef struct yy_buffer_state * YY_BUFFER_STATE;
extern int yyparse();
@@ -35,6 +38,9 @@ int main() {
tdl::ir::context context;
tdl::ir::module module("matrix", context);
program->codegen(&module);
llvm::LLVMContext llvm_context;
llvm::Module llvm_module("test", llvm_context);
tdl::codegen::lowering(module, llvm_module);
// llvm::PrintModulePass print(llvm::outs());
// llvm::AnalysisManager<llvm::Module> analysis;
// print.run(*module.handle(), analysis);

View File

@@ -12,79 +12,146 @@
namespace tdl{
namespace codegen{
/* convert ir::type to llvm::Type */
using namespace llvm;
llvm::Type *llvm_type(ir::type *ty, llvm::LLVMContext &ctx) {
/* convert ir::type to Type */
Type *llvm_type(ir::type *ty, LLVMContext &ctx) {
// function
if(auto* tt = dynamic_cast<ir::function_type*>(ty)){
llvm::Type *return_ty = llvm_type(tt->get_return_ty(), ctx);
std::vector<llvm::Type*> param_tys;
Type *return_ty = llvm_type(tt->get_return_ty(), ctx);
std::vector<Type*> param_tys;
std::transform(tt->params_begin(), tt->params_end(), std::back_inserter(param_tys),
[&ctx](ir::type* t){ return llvm_type(t, ctx);});
return llvm::FunctionType::get(return_ty, param_tys, false);
return FunctionType::get(return_ty, param_tys, false);
}
// pointer
if(ty->is_pointer_ty()){
llvm::Type *elt_ty = llvm_type(ty->get_pointer_element_ty(), ctx);
Type *elt_ty = llvm_type(ty->get_pointer_element_ty(), ctx);
unsigned addr_space = ty->get_pointer_address_space();
return llvm::PointerType::get(elt_ty, addr_space);
return PointerType::get(elt_ty, addr_space);
}
// integer
if(ty->is_integer_ty()){
unsigned bitwidth = ty->get_integer_bitwidth();
return llvm::IntegerType::get(ctx, bitwidth);
return IntegerType::get(ctx, bitwidth);
}
// primitive types
switch(ty->get_type_id()){
case ir::type::VoidTyID: return llvm::Type::getVoidTy(ctx);
case ir::type::HalfTyID: return llvm::Type::getHalfTy(ctx);
case ir::type::FloatTyID: return llvm::Type::getFloatTy(ctx);
case ir::type::DoubleTyID: return llvm::Type::getDoubleTy(ctx);
case ir::type::X86_FP80TyID: return llvm::Type::getX86_FP80Ty(ctx);
case ir::type::PPC_FP128TyID: return llvm::Type::getPPC_FP128Ty(ctx);
case ir::type::LabelTyID: return llvm::Type::getLabelTy(ctx);
case ir::type::MetadataTyID: return llvm::Type::getMetadataTy(ctx);
case ir::type::TokenTyID: return llvm::Type::getTokenTy(ctx);
case ir::type::VoidTyID: return Type::getVoidTy(ctx);
case ir::type::HalfTyID: return Type::getHalfTy(ctx);
case ir::type::FloatTyID: return Type::getFloatTy(ctx);
case ir::type::DoubleTyID: return Type::getDoubleTy(ctx);
case ir::type::X86_FP80TyID: return Type::getX86_FP80Ty(ctx);
case ir::type::PPC_FP128TyID: return Type::getPPC_FP128Ty(ctx);
case ir::type::LabelTyID: return Type::getLabelTy(ctx);
case ir::type::MetadataTyID: return Type::getMetadataTy(ctx);
case ir::type::TokenTyID: return Type::getTokenTy(ctx);
default: break;
}
// unknown type
throw std::runtime_error("unknown conversion from ir::type to llvm::Type");
throw std::runtime_error("unknown conversion from ir::type to Type");
}
/* convert ir::instruction to llvm::Instruction */
llvm::Instruction *llvm_inst(ir::instruction *inst, llvm::LLVMContext & ctx,
std::map<ir::value*, llvm::Value*> &v,
std::map<ir::basic_block*, llvm::BasicBlock*> &b) {
if(auto* ii = dynamic_cast<ir::cond_branch_inst*>(inst))
return llvm::BranchInst::Create(b[ii->get_true_dest()], b[ii->get_false_dest()], v[ii->get_cond()]);
if(auto* ii = dynamic_cast<ir::uncond_branch_inst*>(inst))
return llvm::BranchInst::Create(b[ii->get_dest()]);
if(auto* ii = dynamic_cast<ir::phi_node*>(inst))
return llvm::PHINode::Create(llvm_type(ii->get_type(), ctx), ii->get_num_operands(), ii->get_name());
if(auto* ii = dynamic_cast<ir::return_inst*>(inst))
return llvm::ReturnInst::Create(ctx, v[ii->get_return_value()]);
if(auto* ii = dynamic_cast<ir::binary_operator*>(inst))
return llvm::BinaryOperator::Create(ii->get_op(), v[ii->get_operand(0)], v[ii->get_operand(1)], ii->get_name());
if(auto* ii = dynamic_cast<ir::icmp_inst*>(inst))
return llvm::CmpInst::Create(llvm::Instruction::ICmp, ii->get_pred(), v[ii->get_operand(0)], v[ii->get_operand(1)], ii->get_name());
if(auto* ii = dynamic_cast<ir::fcmp_inst*>(inst))
return llvm::FCmpInst::Create(llvm::Instruction::FCmp, ii->get_pred(), v[ii->get_operand(0)], v[ii->get_operand(1)], ii->get_name());
if(auto* ii = dynamic_cast<ir::cast_inst*>(inst))
return llvm::CastInst::Create(ii->get_op(), v[ii->get_operand(0)], llvm_type(ii->get_type(), ctx), ii->get_name());
if(auto* ii = dynamic_cast<ir::getelementptr_inst*>(inst)){
std::vector<llvm::Value*> idx_vals;
std::transform(ii->idx_begin(), ii->idx_end(), std::back_inserter(idx_vals),
[&v](ir::value* x){ return v[x];});
return llvm::GetElementPtrInst::Create(llvm_type(ii->get_source_elt_ty(), ctx), v[ii->get_operand(0)], idx_vals, ii->get_name());
Value* llvm_value(ir::value *v, LLVMContext &ctx,
std::map<ir::value*, Value*> &vmap,
std::map<ir::basic_block*, BasicBlock*> &bmap);
/* convert ir::constant to Constant */
Constant *llvm_constant(ir::constant *cst, LLVMContext &ctx) {
Type *dst_ty = llvm_type(cst->get_type(), ctx);
if(auto* cc = dynamic_cast<ir::constant_int*>(cst))
return ConstantInt::get(dst_ty, cc->get_value());
if(auto* cc = dynamic_cast<ir::constant_fp*>(cst))
return ConstantFP::get(dst_ty, cc->get_value());
// unknown constant
throw std::runtime_error("unknown conversion from ir::constant to Constant");
}
/* convert ir::instruction to Instruction */
Instruction *llvm_inst(ir::instruction *inst, LLVMContext & ctx,
std::map<ir::value*, Value*> &vmap,
std::map<ir::basic_block*, BasicBlock*> &bmap) {
auto value = [&](ir::value *x) { return llvm_value(x, ctx, vmap, bmap); };
auto block = [&](ir::basic_block *x) { return bmap.at(x); };
auto type = [&](ir::type *x) { return llvm_type(x, ctx); };
if(auto* ii = dynamic_cast<ir::cond_branch_inst*>(inst)){
BasicBlock *true_dest = block(ii->get_true_dest());
BasicBlock *false_dest = block(ii->get_false_dest());
Value *cond = value(ii->get_cond());
return BranchInst::Create(true_dest, false_dest, cond);
}
if(auto* ii = dynamic_cast<ir::uncond_branch_inst*>(inst)){
BasicBlock *dest = block(ii->get_dest());
return BranchInst::Create(dest);
}
if(auto* ii = dynamic_cast<ir::phi_node*>(inst)){
Type *ty = type(ii->get_type());
unsigned num_ops = ii->get_num_operands();
return PHINode::Create(ty, num_ops, ii->get_name());
}
if(auto* ii = dynamic_cast<ir::return_inst*>(inst)){
Value *ret_val = value(ii->get_return_value());
return ReturnInst::Create(ctx, ret_val);
}
if(auto* ii = dynamic_cast<ir::binary_operator*>(inst)){
Value *lhs = value(ii->get_operand(0));
Value *rhs = value(ii->get_operand(1));
return BinaryOperator::Create(ii->get_op(), lhs, rhs, ii->get_name());
}
if(auto* ii = dynamic_cast<ir::icmp_inst*>(inst)){
CmpInst::Predicate pred = ii->get_pred();
Value *lhs = value(ii->get_operand(0));
Value *rhs = value(ii->get_operand(1));
return CmpInst::Create(Instruction::ICmp, pred, lhs, rhs, ii->get_name());
}
if(auto* ii = dynamic_cast<ir::fcmp_inst*>(inst)){
CmpInst::Predicate pred = ii->get_pred();
Value *lhs = value(ii->get_operand(0));
Value *rhs = value(ii->get_operand(1));
return FCmpInst::Create(Instruction::FCmp, pred, lhs, rhs, ii->get_name());
}
if(auto* ii = dynamic_cast<ir::cast_inst*>(inst)){
Value *arg = value(ii->get_operand(0));
Type *dst_ty = type(ii->get_type());
return CastInst::Create(ii->get_op(), arg, dst_ty, ii->get_name());
}
if(auto* ii = dynamic_cast<ir::getelementptr_inst*>(inst)){
std::vector<Value*> idx_vals;
std::transform(ii->idx_begin(), ii->idx_end(), std::back_inserter(idx_vals),
[&value](ir::value* x){ return value(x);});
Type *source_ty = type(ii->get_source_elt_ty());
Value *arg = value(ii->get_operand(0));
return GetElementPtrInst::Create(source_ty, arg, idx_vals, ii->get_name());
}
if(ir::load_inst* ii = dynamic_cast<ir::load_inst*>(inst)){
Value *ptr = value(ii->get_pointer_operand());
return new LoadInst(ptr, ii->get_name());
}
if(ir::load_inst* ii = dynamic_cast<ir::load_inst*>(inst))
return new llvm::LoadInst(v[ii->get_pointer_operand()], ii->get_name());
// unknown instruction
throw std::runtime_error("unknown conversion from ir::type to llvm::Type");
throw std::runtime_error("unknown conversion from ir::type to Type");
}
void lowering(ir::module &src, llvm::Module &dst){
using namespace llvm;
Value* llvm_value(ir::value *v, LLVMContext &ctx,
std::map<ir::value*, Value*> &vmap,
std::map<ir::basic_block*, BasicBlock*> &bmap) {
if(vmap.find(v) != vmap.end())
return vmap.at(v);
// create operands
if(auto *uu = dynamic_cast<ir::user*>(v))
for(ir::use u: uu->ops())
vmap[u.get()] = llvm_value(u, ctx, vmap, bmap);
// constant
if(auto *cc = dynamic_cast<ir::constant*>(v))
return llvm_constant(cc, ctx);
// instruction
if(auto *ii = dynamic_cast<ir::instruction*>(v))
return llvm_inst(ii, ctx, vmap, bmap);
// unknown value
throw std::runtime_error("unknown conversion from ir::value to Value");
}
void lowering(ir::module &src, Module &dst){
std::map<ir::value*, Value*> vmap;
std::map<ir::basic_block*, BasicBlock*> bmap;
LLVMContext &dst_ctx = dst.getContext();
@@ -93,10 +160,14 @@ void lowering(ir::module &src, llvm::Module &dst){
for(ir::function *fn: src.get_function_list()) {
// create LLVM function
Type *fn_ty = llvm_type(fn->get_type(), dst_ctx);
Function *dst_function = (Function*)dst.getOrInsertFunction(fn->get_name(), fn_ty);
Function *dst_fn = (Function*)dst.getOrInsertFunction(fn->get_name(), fn_ty);
// map parameters
for(unsigned i = 0; i < fn->args().size(); i++) {
vmap[fn->args()[i]] = &*(dst_fn->arg_begin() + i);
}
// create blocks
for(ir::basic_block *block: fn->blocks()) {
BasicBlock *dst_block = BasicBlock::Create(dst_ctx, block->get_name(), dst_function);
BasicBlock *dst_block = BasicBlock::Create(dst_ctx, block->get_name(), dst_fn);
bmap[block] = dst_block;
}
// iterate through block
@@ -108,6 +179,16 @@ void lowering(ir::module &src, llvm::Module &dst){
}
}
// add phi operands
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *inst: block->get_inst_list())
if(auto *phi = dynamic_cast<ir::phi_node*>(inst)){
PHINode *dst_phi = (PHINode*)vmap.at(phi);
for(unsigned i = 0; i < phi->get_num_incoming(); i++){
ir::value *inc_val = phi->get_incoming_value(i);
ir::basic_block *inc_block = phi->get_incoming_block(i);
dst_phi->addIncoming(vmap[inc_val], bmap[inc_block]);
}
}
}
}

View File

@@ -33,6 +33,7 @@ class constant_int: public constant{
constant_int(type *ty, uint64_t value);
public:
uint64_t get_value() const { return value_; }
static constant *get(type *ty, uint64_t value);
private:
@@ -44,6 +45,7 @@ class constant_fp: public constant{
constant_fp(context &ctx, double value);
public:
double get_value() { return value_; }
static constant* get_negative_zero(type *ty);
static constant* get_zero_value_for_negation(type *ty);
static constant *get(context &ctx, double v);

View File

@@ -42,7 +42,9 @@ private:
public:
void set_incoming_value(unsigned i, value *v);
void set_incoming_block(unsigned i, basic_block *block);
value *get_incoming_value(unsigned i) { return get_operand(i); }
basic_block *get_incoming_block(unsigned i) { return blocks_[i]; }
unsigned get_num_incoming() { return get_num_operands(); }
void add_incoming(value *v, basic_block *block);
// Factory methods

View File

@@ -71,6 +71,7 @@ public:
: value(ty, name), ops_(num_ops){ }
// Operands
const std::vector<use>& ops() { return ops_; }
op_iterator op_begin() { return ops_.begin(); }
op_iterator op_end() { return ops_.end(); }
void set_operand(unsigned i, value *x);