diff --git a/examples/matrix.cpp b/examples/matrix.cpp index 0bc553849..9bf15ba00 100644 --- a/examples/matrix.cpp +++ b/examples/matrix.cpp @@ -2,6 +2,9 @@ #include #include "ast.h" #include "codegen.h" +#include "llvm/IR/IRPrintingPasses.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/PassManager.h" typedef struct yy_buffer_state * YY_BUFFER_STATE; extern int yyparse(); @@ -14,6 +17,8 @@ const char src[] = "\ void test(fp32 *A, fp32 *B, fp32 *C){\ int32 i = 0;\ + int32 j = 1;\ + i = i + j;\ }\ "; @@ -25,5 +30,8 @@ int main() { tdl::context context; tdl::module module("matrix", &context); program->codegen(&module); + llvm::PrintModulePass print(llvm::outs()); + llvm::AnalysisManager analysis; + print.run(*module.handle(), analysis); return 0; } diff --git a/include/ast.h b/include/ast.h index 031c974c1..1359c761b 100644 --- a/include/ast.h +++ b/include/ast.h @@ -192,25 +192,18 @@ public: }; class assignment_expression: public expression{ -private: - llvm::Value *llvm_op(llvm::IRBuilder<> &builder, - llvm::Value *lvalue, llvm::Value *rvalue, - const std::string &name) const; - public: assignment_expression(node *lvalue, ASSIGN_OP_T op, node *rvalue) - : lvalue_((expression*)lvalue), op_(op), rvalue_((expression*)rvalue) { } + : lvalue_((identifier*)lvalue), op_(op), rvalue_((expression*)rvalue) { } llvm::Value* codegen(module *mod) const; public: ASSIGN_OP_T op_; - const expression *lvalue_; + const identifier *lvalue_; const expression *rvalue_; }; -class statement: public node{ -}; class initializer; class declaration_specifier; @@ -227,6 +220,8 @@ public: const list *init_; }; +class statement: public node{ +}; class compound_statement: public statement{ typedef list* declarations_t; diff --git a/include/codegen.h b/include/codegen.h index e81c8cd01..87a6e0f30 100644 --- a/include/codegen.h +++ b/include/codegen.h @@ -18,6 +18,8 @@ private: class module { typedef std::pair val_key_t; + llvm::PHINode *make_phi(llvm::Type *type, unsigned num_values, llvm::BasicBlock *block); + llvm::Value *add_phi_operands(const ast::node *node, llvm::PHINode *&phi); llvm::Value *get_value_recursive(const ast::node* node, llvm::BasicBlock *block); public: @@ -30,13 +32,15 @@ public: // Getters llvm::Value *get_value(const ast::node *node, llvm::BasicBlock* block); llvm::Value *get_value(const ast::node *node); + // Seal block -- no more predecessors will be added + llvm::Value *seal_block(llvm::BasicBlock *block); private: llvm::Module handle_; llvm::IRBuilder<> builder_; std::map values_; std::set sealed_blocks_; - std::map incomplete_phis_; + std::map> incomplete_phis_; }; diff --git a/lib/codegen.cpp b/lib/codegen.cpp index ac23557f0..9657a2c7a 100644 --- a/lib/codegen.cpp +++ b/lib/codegen.cpp @@ -37,22 +37,37 @@ void module::set_value(const ast::node* node, llvm::Value* value){ return set_value(node, builder_.GetInsertBlock(), value); } +PHINode* module::make_phi(Type *type, unsigned num_values, BasicBlock *block){ + llvm::BasicBlock::iterator save = builder_.GetInsertPoint(); + builder_.SetInsertPoint(&*block->getFirstInsertionPt()); + PHINode *res = builder_.CreatePHI(type, num_values); + builder_.SetInsertPoint(&*save); + return res; +} + +Value *module::add_phi_operands(const ast::node *node, PHINode *&phi){ + BasicBlock *block = phi->getParent(); + for(BasicBlock *pred: predecessors(block)){ + llvm::Value *value = get_value(node, pred); + if(phi->getType()==nullptr){ + phi = make_phi(value->getType(), pred_size(block), block); + } + phi->addIncoming(value, pred); + } +} + llvm::Value *module::get_value_recursive(const ast::node* node, BasicBlock *block) { llvm::Value *result; if(sealed_blocks_.find(block) == sealed_blocks_.end()){ - result = builder_.CreatePHI(nullptr, 1); - incomplete_phis_[val_key_t(node, block)] = (PHINode*)result; + incomplete_phis_[block][node] = make_phi(nullptr, 1, block); } else if(pred_size(block) <= 1){ result = get_value(node, *pred_begin(block)); } else{ - result = builder_.CreatePHI(nullptr, 1); + result = make_phi(nullptr, 1, block); set_value(node, block, result); - for(BasicBlock *pred: predecessors(block)){ - llvm::Value *value = get_value(node, pred); - static_cast(result)->addIncoming(value, pred); - } + add_phi_operands(node, (PHINode*&)result); } set_value(node, block, result); } @@ -68,6 +83,11 @@ llvm::Value *module::get_value(const ast::node *node) { return get_value(node, builder_.GetInsertBlock()); } +llvm::Value *module::seal_block(BasicBlock *block){ + for(auto &x: incomplete_phis_[block]) + add_phi_operands(x.first, x.second); + sealed_blocks_.insert(block); +} namespace ast{ @@ -170,7 +190,8 @@ Value* function_definition::codegen(module *mod) const{ /* Statements */ Value* compound_statement::codegen(module* mod) const{ decls_->codegen(mod); -// statements_->codegen(mod); + if(statements_) + statements_->codegen(mod); return nullptr; } @@ -456,15 +477,10 @@ Value *conditional_expression::codegen(module *mod) const{ } /* Assignment expression */ -Value *assignment_expression::llvm_op(llvm::IRBuilder<> &builder, Value *lvalue, Value *rvalue, const std::string &name) const{ - return nullptr; -} - Value *assignment_expression::codegen(module *mod) const{ - Value *lvalue = lvalue_->codegen(mod); Value *rvalue = rvalue_->codegen(mod); - BasicBlock *block = mod->builder().GetInsertBlock(); - return llvm_op(mod->builder(), lvalue, rvalue, ""); + mod->set_value(lvalue_, rvalue); + return rvalue; } /* Type name */