[AST] disambiguate named_expression vs identifier

This commit is contained in:
Philippe Tillet
2018-12-22 11:55:04 -05:00
parent eab275dc99
commit d06f0fa593
4 changed files with 106 additions and 17 deletions

View File

@@ -56,6 +56,9 @@ enum TYPE_T{
FLOAT32_T, FLOAT64_T
};
class pointer;
class identifier;
// AST
class node {
public:
@@ -91,6 +94,15 @@ public:
virtual llvm::Value* codegen(module *) const = 0;
};
class named_expression: public expression {
public:
named_expression(node *id): id_((const identifier*)id){}
llvm::Value* codegen(module* mod) const;
private:
const identifier *id_;
};
class binary_operator: public expression{
private:
llvm::Value* llvm_op(llvm::IRBuilder<> &bld, llvm::Value *lhs, llvm::Value *rhs, const std::string &name) const;
@@ -285,9 +297,6 @@ public:
};
/* Declarators */
class pointer;
class identifier;
class declarator: public node{
virtual llvm::Type* type_impl(module*mod, llvm::Type *type) const = 0;
@@ -311,7 +320,7 @@ protected:
pointer *ptr_;
};
class identifier: public declarator{
class identifier: public declarator {
llvm::Type* type_impl(module*mod, llvm::Type *type) const;
public:

View File

@@ -20,13 +20,13 @@ public:
module(const std::string &name, context *ctx);
llvm::Module* handle();
llvm::IRBuilder<>& builder();
void value(ast::node* node, llvm::Value* value);
llvm::Value *value(ast::node* node);
void value(const ast::node* node, llvm::Value* value);
llvm::Value *value(const ast::node *node);
private:
llvm::Module handle_;
llvm::IRBuilder<> builder_;
std::unordered_map<ast::node*, llvm::Value*> values_;
std::unordered_map<const ast::node*, llvm::Value*> values_;
};

View File

@@ -109,7 +109,7 @@ identifier
;
primary_expression
: identifier { $$ = $1; }
: identifier { $$ = new named_expression($1); }
| constant { $$ = $1; }
| STRING_LITERAL { $$ = new string_literal(yytext); }
| '(' expression ')' { $$ = $1; }

View File

@@ -28,11 +28,11 @@ llvm::IRBuilder<>& module::builder() {
return builder_;
}
void module::value(ast::node* node, llvm::Value* value){
void module::value(const ast::node* node, llvm::Value* value){
values_[node] = value;
}
llvm::Value *module::value(ast::node* node){
llvm::Value *module::value(const ast::node* node){
return values_[node];
}
@@ -87,7 +87,6 @@ const std::string &identifier::name() const{
return name_;
}
// Tile
Type* tile::type_impl(module*, Type *type) const{
return TileType::get(type, shapes_->values().size());
@@ -166,16 +165,90 @@ Value* initializer::codegen(module * mod) const{
/*------------------*/
/* Expression */
/*------------------*/
llvm::Value *llvm_cast(llvm::IRBuilder<> &builder, Value *src, Type *dst_ty){
Type *src_ty = src->getType();
bool src_signed = false;
bool dst_signed = false;
if(src_ty == dst_ty)
return src;
else if(src_ty->isIntegerTy() && src_signed && dst_ty->isFloatingPointTy())
return builder.CreateSIToFP(src, dst_ty);
else if(src_ty->isIntegerTy() && !src_signed && dst_ty->isFloatingPointTy())
return builder.CreateUIToFP(src, dst_ty);
else if(src_ty->isFloatingPointTy() && dst_ty->isIntegerTy() && dst_signed)
return builder.CreateFPToSI(src, dst_ty);
else if(src_ty->isFloatingPointTy() && dst_ty->isIntegerTy() && !dst_signed)
return builder.CreateFPToUI(src, dst_ty);
else if(src_ty->isFloatingPointTy() && dst_ty->isFloatingPointTy() &&
src_ty->getFPMantissaWidth() < dst_ty->getFPMantissaWidth())
return builder.CreateFPExt(src, dst_ty);
else if(src_ty->isFloatingPointTy() && dst_ty->isFloatingPointTy() &&
src_ty->getFPMantissaWidth() > dst_ty->getFPMantissaWidth())
return builder.CreateFPTrunc(src, dst_ty);
else if(src_ty->isIntegerTy() && dst_ty->isIntegerTy() &&
src_ty->getIntegerBitWidth())
return builder.CreateIntCast(src, dst_ty, dst_signed);
else{
assert(false && "unreachable");
throw;
}
}
inline void implicit_cast(llvm::IRBuilder<> &builder, Value *&lhs, Value *&rhs,
bool &is_float, bool &is_ptr, bool &is_int, bool &is_signed){
// Input types
Type *left_ty = lhs->getType();
Type *right_ty = rhs->getType();
// One operand is pointer
if(left_ty->isPointerTy()){
is_ptr = true;
}
// One operand is double
else if(left_ty->isDoubleTy() || right_ty->isDoubleTy()){
Value *&to_convert = left_ty->isDoubleTy()?rhs:lhs;
to_convert = llvm_cast(builder, to_convert, builder.getDoubleTy());
is_float = true;
}
// One operand is float
else if(left_ty->isFloatTy() || right_ty->isFloatTy()){
Value *&to_convert = left_ty->isFloatTy()?rhs:lhs;
to_convert = llvm_cast(builder, to_convert, builder.getFloatTy());
is_float = true;
}
// Both operands are integers
else if(left_ty->isIntegerTy() && right_ty->isIntegerTy()){
is_int = true;
is_signed = false;
if(left_ty->getIntegerBitWidth() != right_ty->getIntegerBitWidth()){
Value *&to_convert = (left_ty->getIntegerBitWidth() > right_ty->getIntegerBitWidth())?rhs:lhs;
Type *dst_ty = (to_convert==lhs)?right_ty:left_ty;
to_convert = llvm_cast(builder, to_convert, dst_ty);
}
}
// Not reachable
else{
assert(false);
throw;
}
}
//inline void implicit_broadcast(llvm::IRBuilder<> &builder, Value *&lhs, Value *&rhs){
// return;
//}
/* Binary operator */
Value *binary_operator::llvm_op(llvm::IRBuilder<> &builder, Value *lhs, Value *rhs, const std::string &name) const
{
Type *ltype = lhs->getType();
Type *rtype = rhs->getType();
bool is_float = ltype->isFloatingPointTy() || rtype->isFloatingPointTy();
bool is_ptr = ltype->isPointerTy() || rtype->isPointerTy();
bool is_int = ltype->isIntegerTy() || rtype->isIntegerTy();
bool is_signed = false;
bool is_float, is_ptr, is_int, is_signed;
implicit_cast(builder, lhs, rhs, is_float, is_ptr, is_int, is_signed);
// implicit_broadcast(builder, lhs, rhs);
// Mul
if(op_==MUL && is_float)
return builder.CreateFMul(lhs, rhs, name);
@@ -357,6 +430,7 @@ Value *assignment_expression::llvm_op(llvm::IRBuilder<> &builder, Value *lvalue,
Value *assignment_expression::codegen(module *mod) const{
Value *lvalue = lvalue_->codegen(mod);
Value *rvalue = rvalue_->codegen(mod);
BasicBlock *block = mod->builder().GetInsertBlock();
return llvm_op(mod->builder(), lvalue, rvalue, "");
}
@@ -375,6 +449,12 @@ llvm::Value* constant::codegen(module *mod) const{
return mod->builder().getInt32(value_);
}
/* Named */
llvm::Value* named_expression::codegen(module *mod) const{
return mod->value(id_);
}
}
}