diff --git a/include/ast.h b/include/ast.h index 04590a800..2d5e46f98 100644 --- a/include/ast.h +++ b/include/ast.h @@ -56,6 +56,9 @@ enum TYPE_T{ FLOAT32_T, FLOAT64_T }; +class pointer; +class identifier; + // AST class node { public: @@ -91,6 +94,15 @@ public: virtual llvm::Value* codegen(module *) const = 0; }; +class named_expression: public expression { +public: + named_expression(node *id): id_((const identifier*)id){} + llvm::Value* codegen(module* mod) const; + +private: + const identifier *id_; +}; + class binary_operator: public expression{ private: llvm::Value* llvm_op(llvm::IRBuilder<> &bld, llvm::Value *lhs, llvm::Value *rhs, const std::string &name) const; @@ -285,9 +297,6 @@ public: }; /* Declarators */ -class pointer; -class identifier; - class declarator: public node{ virtual llvm::Type* type_impl(module*mod, llvm::Type *type) const = 0; @@ -311,7 +320,7 @@ protected: pointer *ptr_; }; -class identifier: public declarator{ +class identifier: public declarator { llvm::Type* type_impl(module*mod, llvm::Type *type) const; public: diff --git a/include/codegen.h b/include/codegen.h index 099ad5787..796845582 100644 --- a/include/codegen.h +++ b/include/codegen.h @@ -20,13 +20,13 @@ public: module(const std::string &name, context *ctx); llvm::Module* handle(); llvm::IRBuilder<>& builder(); - void value(ast::node* node, llvm::Value* value); - llvm::Value *value(ast::node* node); + void value(const ast::node* node, llvm::Value* value); + llvm::Value *value(const ast::node *node); private: llvm::Module handle_; llvm::IRBuilder<> builder_; - std::unordered_map values_; + std::unordered_map values_; }; diff --git a/include/parser.y b/include/parser.y index ab6e1f489..b5f4b56a4 100644 --- a/include/parser.y +++ b/include/parser.y @@ -109,7 +109,7 @@ identifier ; primary_expression - : identifier { $$ = $1; } + : identifier { $$ = new named_expression($1); } | constant { $$ = $1; } | STRING_LITERAL { $$ = new string_literal(yytext); } | '(' expression ')' { $$ = $1; } diff --git a/lib/codegen.cpp b/lib/codegen.cpp index 676b61f2f..5d264f1c1 100644 --- a/lib/codegen.cpp +++ b/lib/codegen.cpp @@ -28,11 +28,11 @@ llvm::IRBuilder<>& module::builder() { return builder_; } -void module::value(ast::node* node, llvm::Value* value){ +void module::value(const ast::node* node, llvm::Value* value){ values_[node] = value; } -llvm::Value *module::value(ast::node* node){ +llvm::Value *module::value(const ast::node* node){ return values_[node]; } @@ -87,7 +87,6 @@ const std::string &identifier::name() const{ return name_; } - // Tile Type* tile::type_impl(module*, Type *type) const{ return TileType::get(type, shapes_->values().size()); @@ -166,16 +165,90 @@ Value* initializer::codegen(module * mod) const{ /*------------------*/ /* Expression */ /*------------------*/ +llvm::Value *llvm_cast(llvm::IRBuilder<> &builder, Value *src, Type *dst_ty){ + Type *src_ty = src->getType(); + bool src_signed = false; + bool dst_signed = false; + if(src_ty == dst_ty) + return src; + else if(src_ty->isIntegerTy() && src_signed && dst_ty->isFloatingPointTy()) + return builder.CreateSIToFP(src, dst_ty); + + else if(src_ty->isIntegerTy() && !src_signed && dst_ty->isFloatingPointTy()) + return builder.CreateUIToFP(src, dst_ty); + + else if(src_ty->isFloatingPointTy() && dst_ty->isIntegerTy() && dst_signed) + return builder.CreateFPToSI(src, dst_ty); + + else if(src_ty->isFloatingPointTy() && dst_ty->isIntegerTy() && !dst_signed) + return builder.CreateFPToUI(src, dst_ty); + + else if(src_ty->isFloatingPointTy() && dst_ty->isFloatingPointTy() && + src_ty->getFPMantissaWidth() < dst_ty->getFPMantissaWidth()) + return builder.CreateFPExt(src, dst_ty); + + else if(src_ty->isFloatingPointTy() && dst_ty->isFloatingPointTy() && + src_ty->getFPMantissaWidth() > dst_ty->getFPMantissaWidth()) + return builder.CreateFPTrunc(src, dst_ty); + + else if(src_ty->isIntegerTy() && dst_ty->isIntegerTy() && + src_ty->getIntegerBitWidth()) + return builder.CreateIntCast(src, dst_ty, dst_signed); + + else{ + assert(false && "unreachable"); + throw; + } +} + +inline void implicit_cast(llvm::IRBuilder<> &builder, Value *&lhs, Value *&rhs, + bool &is_float, bool &is_ptr, bool &is_int, bool &is_signed){ + // Input types + Type *left_ty = lhs->getType(); + Type *right_ty = rhs->getType(); + // One operand is pointer + if(left_ty->isPointerTy()){ + is_ptr = true; + } + // One operand is double + else if(left_ty->isDoubleTy() || right_ty->isDoubleTy()){ + Value *&to_convert = left_ty->isDoubleTy()?rhs:lhs; + to_convert = llvm_cast(builder, to_convert, builder.getDoubleTy()); + is_float = true; + } + // One operand is float + else if(left_ty->isFloatTy() || right_ty->isFloatTy()){ + Value *&to_convert = left_ty->isFloatTy()?rhs:lhs; + to_convert = llvm_cast(builder, to_convert, builder.getFloatTy()); + is_float = true; + } + // Both operands are integers + else if(left_ty->isIntegerTy() && right_ty->isIntegerTy()){ + is_int = true; + is_signed = false; + if(left_ty->getIntegerBitWidth() != right_ty->getIntegerBitWidth()){ + Value *&to_convert = (left_ty->getIntegerBitWidth() > right_ty->getIntegerBitWidth())?rhs:lhs; + Type *dst_ty = (to_convert==lhs)?right_ty:left_ty; + to_convert = llvm_cast(builder, to_convert, dst_ty); + } + } + // Not reachable + else{ + assert(false); + throw; + } +} + +//inline void implicit_broadcast(llvm::IRBuilder<> &builder, Value *&lhs, Value *&rhs){ +// return; +//} /* Binary operator */ Value *binary_operator::llvm_op(llvm::IRBuilder<> &builder, Value *lhs, Value *rhs, const std::string &name) const { - Type *ltype = lhs->getType(); - Type *rtype = rhs->getType(); - bool is_float = ltype->isFloatingPointTy() || rtype->isFloatingPointTy(); - bool is_ptr = ltype->isPointerTy() || rtype->isPointerTy(); - bool is_int = ltype->isIntegerTy() || rtype->isIntegerTy(); - bool is_signed = false; + bool is_float, is_ptr, is_int, is_signed; + implicit_cast(builder, lhs, rhs, is_float, is_ptr, is_int, is_signed); +// implicit_broadcast(builder, lhs, rhs); // Mul if(op_==MUL && is_float) return builder.CreateFMul(lhs, rhs, name); @@ -357,6 +430,7 @@ Value *assignment_expression::llvm_op(llvm::IRBuilder<> &builder, Value *lvalue, Value *assignment_expression::codegen(module *mod) const{ Value *lvalue = lvalue_->codegen(mod); Value *rvalue = rvalue_->codegen(mod); + BasicBlock *block = mod->builder().GetInsertBlock(); return llvm_op(mod->builder(), lvalue, rvalue, ""); } @@ -375,6 +449,12 @@ llvm::Value* constant::codegen(module *mod) const{ return mod->builder().getInt32(value_); } +/* Named */ +llvm::Value* named_expression::codegen(module *mod) const{ + return mod->value(id_); +} + + } }