From 1b8199b82d0bfecff7a6d52b7d76ddc49343b7cd Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Mon, 24 Dec 2018 01:04:55 -0500 Subject: [PATCH] [Code generation] added support for FOR and IF/THEN/ELSE --- examples/matrix.cpp | 10 +++- include/ast.h | 31 ++++++---- include/codegen.h | 16 +++--- include/parser.y | 7 +-- include/scanner.l | 1 + lib/codegen.cpp | 134 ++++++++++++++++++++++++++++++++++---------- 6 files changed, 144 insertions(+), 55 deletions(-) diff --git a/examples/matrix.cpp b/examples/matrix.cpp index 9bf15ba00..8fa19a840 100644 --- a/examples/matrix.cpp +++ b/examples/matrix.cpp @@ -15,10 +15,16 @@ extern translation_unit *ast_root; const char src[] = "\ -void test(fp32 *A, fp32 *B, fp32 *C){\ - int32 i = 0;\ +void test(fp32 *A, fp32 *B, fp32 *C, int32 i){\ int32 j = 1;\ + int32 k;\ i = i + j;\ + for(k = 0; k < 10; k = k+1){\ + int32 u = 1;\ + u = u + i;\ + if(k == 0)\ + u = u + 2;\ + }\ }\ "; diff --git a/include/ast.h b/include/ast.h index 1359c761b..6b8e48570 100644 --- a/include/ast.h +++ b/include/ast.h @@ -94,15 +94,21 @@ public: virtual llvm::Value* codegen(module *) const = 0; }; -class named_expression: public expression { +class unary_expression: public node{ public: - named_expression(node *id): id_((const identifier*)id){} - llvm::Value* codegen(module* mod) const; + unary_expression(node *id): id_((const identifier*)id) {} + const identifier *id() const; private: const identifier *id_; }; +class named_expression: public unary_expression { +public: + named_expression(node *id): unary_expression(id){ } + llvm::Value* codegen(module* mod) const; +}; + class binary_operator: public expression{ private: llvm::Value* llvm_op(llvm::IRBuilder<> &bld, llvm::Value *lhs, llvm::Value *rhs, const std::string &name) const; @@ -194,13 +200,13 @@ public: class assignment_expression: public expression{ public: assignment_expression(node *lvalue, ASSIGN_OP_T op, node *rvalue) - : lvalue_((identifier*)lvalue), op_(op), rvalue_((expression*)rvalue) { } + : lvalue_((unary_expression*)lvalue), op_(op), rvalue_((expression*)rvalue) { } llvm::Value* codegen(module *mod) const; public: ASSIGN_OP_T op_; - const identifier *lvalue_; + const unary_expression *lvalue_; const expression *rvalue_; }; @@ -241,18 +247,23 @@ private: class selection_statement: public statement{ public: selection_statement(node *cond, node *if_value, node *else_value = nullptr) - : cond_(cond), if_value_(if_value), else_value_(else_value) { } + : cond_(cond), then_value_(if_value), else_value_(else_value) { } + + llvm::Value* codegen(module *mod) const; public: const node *cond_; - const node *if_value_; + const node *then_value_; const node *else_value_; }; class iteration_statement: public statement{ public: iteration_statement(node *init, node *stop, node *exec, node *statements) - : init_(init), stop_(stop), exec_(exec), statements_(statements) { } + : init_(init), stop_(stop), exec_(exec), statements_(statements) + { } + + llvm::Value* codegen(module *mod) const; private: const node *init_; @@ -368,7 +379,7 @@ private: public: initializer(node *decl, node *init) : declarator((node*)((declarator*)decl)->id()), - decl_((declarator*)decl), init_((expression*)init){ } + decl_((declarator*)decl), expr_((expression*)init){ } void specifier(const declaration_specifier *spec); llvm::Value* codegen(module *) const; @@ -376,7 +387,7 @@ public: public: const declaration_specifier *spec_; const declarator *decl_; - const expression *init_; + const expression *expr_; }; diff --git a/include/codegen.h b/include/codegen.h index 87a6e0f30..316e3b859 100644 --- a/include/codegen.h +++ b/include/codegen.h @@ -17,21 +17,21 @@ private: }; class module { - typedef std::pair val_key_t; + typedef std::pair val_key_t; llvm::PHINode *make_phi(llvm::Type *type, unsigned num_values, llvm::BasicBlock *block); - llvm::Value *add_phi_operands(const ast::node *node, llvm::PHINode *&phi); - llvm::Value *get_value_recursive(const ast::node* node, llvm::BasicBlock *block); + llvm::Value *add_phi_operands(const std::string& name, llvm::PHINode *&phi); + llvm::Value *get_value_recursive(const std::string& name, llvm::BasicBlock *block); public: module(const std::string &name, context *ctx); llvm::Module* handle(); llvm::IRBuilder<>& builder(); // Setters - void set_value(const ast::node *node, llvm::BasicBlock* block, llvm::Value *value); - void set_value(const ast::node* node, llvm::Value* value); + void set_value(const std::string& name, llvm::BasicBlock* block, llvm::Value *value); + void set_value(const std::string& name, llvm::Value* value); // Getters - llvm::Value *get_value(const ast::node *node, llvm::BasicBlock* block); - llvm::Value *get_value(const ast::node *node); + llvm::Value *get_value(const std::string& name, llvm::BasicBlock* block); + llvm::Value *get_value(const std::string& name); // Seal block -- no more predecessors will be added llvm::Value *seal_block(llvm::BasicBlock *block); @@ -40,7 +40,7 @@ private: llvm::IRBuilder<> builder_; std::map values_; std::set sealed_blocks_; - std::map> incomplete_phis_; + std::map> incomplete_phis_; }; diff --git a/include/parser.y b/include/parser.y index b5f4b56a4..6d1e81936 100644 --- a/include/parser.y +++ b/include/parser.y @@ -269,13 +269,12 @@ expression_statement ; selection_statement - : IF '(' expression ')' statement { $$ = new selection_statement($1, $3); } - | IF '(' expression ')' statement ELSE statement { $$ = new selection_statement($1, $3, $5); } + : IF '(' expression ')' statement { $$ = new selection_statement($3, $5); } + | IF '(' expression ')' statement ELSE statement { $$ = new selection_statement($3, $5, $7); } ; iteration_statement - : FOR '(' expression_statement expression_statement ')' statement { $$ = new iteration_statement($1, $3, NULL, $4); } - | FOR '(' expression_statement expression_statement expression ')' statement { $$ = new iteration_statement($1, $3, $4, $5); } + : FOR '(' expression_statement expression_statement expression ')' statement { $$ = new iteration_statement($3, $4, $5, $7); } ; diff --git a/include/scanner.l b/include/scanner.l index df730aec8..5ecf37b1b 100644 --- a/include/scanner.l +++ b/include/scanner.l @@ -113,6 +113,7 @@ void count() column += 8 - (column % 8); else column++; + //ECHO; } void yyerror (const char *s) /* Called by yyparse on error */ diff --git a/lib/codegen.cpp b/lib/codegen.cpp index 9657a2c7a..3c20c1487 100644 --- a/lib/codegen.cpp +++ b/lib/codegen.cpp @@ -4,6 +4,9 @@ #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Module.h" #include "llvm/IR/CFG.h" +#include "llvm/IR/Instructions.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include using namespace llvm; @@ -19,6 +22,7 @@ LLVMContext *context::handle() { /* Module */ module::module(const std::string &name, context *ctx) : handle_(name.c_str(), *ctx->handle()), builder_(*ctx->handle()) { + sealed_blocks_.insert(nullptr); } llvm::Module* module::handle() { @@ -29,58 +33,64 @@ llvm::IRBuilder<>& module::builder() { return builder_; } -void module::set_value(const ast::node *node, BasicBlock *block, Value *value){ - values_[val_key_t{node, block}] = value; +void module::set_value(const std::string& name, BasicBlock *block, Value *value){ + values_[val_key_t{name, block}] = value; } -void module::set_value(const ast::node* node, llvm::Value* value){ - return set_value(node, builder_.GetInsertBlock(), value); +void module::set_value(const std::string& name, llvm::Value* value){ + return set_value(name, builder_.GetInsertBlock(), value); } PHINode* module::make_phi(Type *type, unsigned num_values, BasicBlock *block){ - llvm::BasicBlock::iterator save = builder_.GetInsertPoint(); - builder_.SetInsertPoint(&*block->getFirstInsertionPt()); + Instruction* instr = block->getFirstNonPHIOrDbg(); + if(instr) + builder_.SetInsertPoint(instr); PHINode *res = builder_.CreatePHI(type, num_values); - builder_.SetInsertPoint(&*save); + if(instr) + builder_.SetInsertPoint(block); return res; } -Value *module::add_phi_operands(const ast::node *node, PHINode *&phi){ +Value *module::add_phi_operands(const std::string& name, PHINode *&phi){ BasicBlock *block = phi->getParent(); for(BasicBlock *pred: predecessors(block)){ - llvm::Value *value = get_value(node, pred); - if(phi->getType()==nullptr){ - phi = make_phi(value->getType(), pred_size(block), block); - } + llvm::Value *value = get_value(name, pred); phi->addIncoming(value, pred); } + return phi; } -llvm::Value *module::get_value_recursive(const ast::node* node, BasicBlock *block) { +llvm::Value *module::get_value_recursive(const std::string& name, BasicBlock *block) { llvm::Value *result; if(sealed_blocks_.find(block) == sealed_blocks_.end()){ - incomplete_phis_[block][node] = make_phi(nullptr, 1, block); + llvm::Value *pred = get_value(name, *pred_begin(block)); + incomplete_phis_[block][name] = make_phi(pred->getType(), 1, block); + result = (Value*)incomplete_phis_[block][name]; } else if(pred_size(block) <= 1){ - result = get_value(node, *pred_begin(block)); + bool has_pred = pred_size(block); + result = get_value(name, has_pred?*pred_begin(block):nullptr); } else{ - result = make_phi(nullptr, 1, block); - set_value(node, block, result); - add_phi_operands(node, (PHINode*&)result); + llvm::Value *pred = get_value(name, *pred_begin(block)); + result = make_phi(pred->getType(), 1, block); + set_value(name, block, result); + add_phi_operands(name, (PHINode*&)result); } - set_value(node, block, result); + set_value(name, block, result); + return result; } -llvm::Value *module::get_value(const ast::node* node, BasicBlock *block) { - val_key_t key(node, block); - if(values_.find(key) != values_.end()) +llvm::Value *module::get_value(const std::string& name, BasicBlock *block) { + val_key_t key(name, block); + if(values_.find(key) != values_.end()){ return values_.at(key); - return get_value_recursive(node, block); + } + return get_value_recursive(name, block); } -llvm::Value *module::get_value(const ast::node *node) { - return get_value(node, builder_.GetInsertBlock()); +llvm::Value *module::get_value(const std::string& name) { + return get_value(name, builder_.GetInsertBlock()); } llvm::Value *module::seal_block(BasicBlock *block){ @@ -162,7 +172,7 @@ void function::bind_parameters(module *mod, Function *fn) const{ const identifier *id_i = param_i->id(); if(id_i){ args[i]->setName(id_i->name()); - mod->set_value(id_i, nullptr, args[i]); + mod->set_value(id_i->name(), nullptr, args[i]); } } } @@ -182,8 +192,10 @@ Value* function_definition::codegen(module *mod) const{ Function *fn = Function::Create(prototype, Function::ExternalLinkage, name, mod->handle()); header_->bind_parameters(mod, fn); BasicBlock *entry = BasicBlock::Create(mod->handle()->getContext(), "entry", fn); + mod->seal_block(entry); mod->builder().SetInsertPoint(entry); body_->codegen(mod); + mod->builder().CreateRetVoid(); return nullptr; } @@ -195,6 +207,55 @@ Value* compound_statement::codegen(module* mod) const{ return nullptr; } +/* Iteration statement */ +Value* iteration_statement::codegen(module *mod) const{ + IRBuilder<> &builder = mod->builder(); + LLVMContext &ctx = mod->handle()->getContext(); + Function *fn = builder.GetInsertBlock()->getParent(); + BasicBlock *loop_bb = BasicBlock::Create(ctx, "loop", fn); + BasicBlock *next_bb = BasicBlock::Create(ctx, "postloop", fn); + init_->codegen(mod); + builder.CreateBr(loop_bb); + builder.SetInsertPoint(loop_bb); + statements_->codegen(mod); + exec_->codegen(mod); + Value *cond = stop_->codegen(mod); + builder.CreateCondBr(cond, loop_bb, next_bb); + builder.SetInsertPoint(next_bb); + mod->seal_block(loop_bb); + mod->seal_block(next_bb); + return nullptr; +} + +/* Selection statement */ +Value* selection_statement::codegen(module* mod) const{ + IRBuilder<> &builder = mod->builder(); + LLVMContext &ctx = mod->handle()->getContext(); + Function *fn = builder.GetInsertBlock()->getParent(); + Value *cond = cond_->codegen(mod); + BasicBlock *then_bb = BasicBlock::Create(ctx, "then", fn); + BasicBlock *else_bb = else_value_?BasicBlock::Create(ctx, "else", fn):nullptr; + BasicBlock *endif_bb = BasicBlock::Create(ctx, "endif", fn); + // Branch + if(else_value_) + builder.CreateCondBr(cond, then_bb, else_bb); + else + builder.CreateCondBr(cond, then_bb, endif_bb); + // Then + builder.SetInsertPoint(then_bb); + then_value_->codegen(mod); + if(else_value_) + builder.CreateBr(endif_bb); + // Else + if(else_value_){ + builder.SetInsertPoint(else_bb); + else_value_->codegen(mod); + builder.CreateBr(endif_bb); + } + // Endif + builder.SetInsertPoint(endif_bb); +} + /* Declaration */ Value* declaration::codegen(module* mod) const{ for(initializer *init: init_->values()) @@ -211,9 +272,14 @@ void initializer::specifier(const declaration_specifier *spec) { Value* initializer::codegen(module * mod) const{ Type *ty = decl_->type(mod, spec_->type(mod)); std::string name = decl_->id()->name(); - Value *value = llvm::UndefValue::get(ty); + Value *value; + if(expr_) + value = expr_->codegen(mod); + else + value = llvm::UndefValue::get(ty); value->setName(name); - return nullptr; + mod->set_value(name, value); + return value; } /*------------------*/ @@ -300,7 +366,7 @@ inline void implicit_cast(llvm::IRBuilder<> &builder, Value *&lhs, Value *&rhs, /* Binary operator */ Value *binary_operator::llvm_op(llvm::IRBuilder<> &builder, Value *lhs, Value *rhs, const std::string &name) const { - bool is_float, is_ptr, is_int, is_signed; + bool is_float = false, is_ptr = false, is_int = false, is_signed = false; implicit_cast(builder, lhs, rhs, is_float, is_ptr, is_int, is_signed); // implicit_broadcast(builder, lhs, rhs); // Mul @@ -479,7 +545,7 @@ Value *conditional_expression::codegen(module *mod) const{ /* Assignment expression */ Value *assignment_expression::codegen(module *mod) const{ Value *rvalue = rvalue_->codegen(mod); - mod->set_value(lvalue_, rvalue); + mod->set_value(lvalue_->id()->name(), rvalue); return rvalue; } @@ -498,9 +564,15 @@ llvm::Value* constant::codegen(module *mod) const{ return mod->builder().getInt32(value_); } +/* Unary expression */ +const identifier* unary_expression::id() const{ + return id_; +} + /* Named */ llvm::Value* named_expression::codegen(module *mod) const{ - return mod->get_value(id_); + const std::string &name = id()->name(); + return mod->get_value(name); }