#ifndef TDL_INCLUDE_AST_H #define TDL_INCLUDE_AST_H #include "parser.hpp" #include #include #include namespace tdl{ namespace ir{ class function; class value; class type; class builder; class module; } namespace ast{ // Enumerations enum ASSIGN_OP_T{ ASSIGN, INPLACE_MUL, INPLACE_DIV, INPLACE_MOD, INPLACE_ADD, INPLACE_SUB, INPLACE_LSHIFT, INPLACE_RSHIFT, INPLACE_AND, INPLACE_XOR, INPLACE_OR }; enum BIN_OP_T{ MUL, DIV, MOD, ADD, SUB, LEFT_SHIFT, RIGHT_SHIFT, LT, GT, LE, GE, EQ, NE, AND, XOR, OR, LAND, LOR }; enum UNARY_OP_T{ INC, DEC, PLUS, MINUS, ADDR, DEREF, COMPL, NOT }; enum TYPE_T{ VOID_T, UINT1_T, UINT8_T, UINT16_T, UINT32_T, UINT64_T, INT1_T, INT8_T, INT16_T, INT32_T, INT64_T, FLOAT32_T, FLOAT64_T }; class pointer; class identifier; class constant; // AST class node { protected: static ir::value* explicit_cast(ir::builder &builder, ir::value *src, ir::type *dst_ty); static void implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs); static void implicit_broadcast(ir::module *mod, ir::value *&arg, ir::type *ty); static void implicit_cast(ir::builder &builder, ir::value *&lhs, ir::value *&rhs, bool &is_float, bool &is_ptr, bool &is_int, bool &is_signed); public: virtual ir::value* codegen(ir::module *) const { return nullptr; } }; template class list: public node { public: list(const T& x): values_{x} {} node* append(const T& x){ values_.push_back(x); return this; } ir::value* codegen(ir::module * mod) const{ for(T x: values_){ x->codegen(mod); } return nullptr; } const std::vector &values() const { return values_; } private: std::vector values_; }; enum slice_enum_t{ ALL, NEWAXIS }; class slice: public node{ public: slice(slice_enum_t type) : type_(type){} slice_enum_t type() const{ return type_; } public: const slice_enum_t type_; }; class named_expression; class expression: public node{ public: virtual ir::value* codegen(ir::module *) const = 0; named_expression *lvalue() const { return lvalue_; } protected: named_expression *lvalue_; }; class postfix_expression: public expression{ }; class builtin_expression: public node{ }; class get_global_range: public builtin_expression{ public: get_global_range(node *size, node *axis): size_((constant*)size), axis_((constant*)axis) { } ir::value* codegen(ir::module *) const; private: const constant* size_; const constant* axis_; }; class matmul_expression: public builtin_expression{ public: matmul_expression(node* A, node *B, node *C): A_((expression*)A), B_((expression*)B), C_((expression*)C) { } ir::value* codegen(ir::module *) const; private: const expression *A_; const expression *B_; const expression *C_; }; class indexing_expression: public postfix_expression{ public: indexing_expression(node *id, node *slices) : id_((const identifier*)id), slices_((const list*)slices) {} ir::value* codegen(ir::module *) const; private: const identifier* id_; const list* slices_; }; class named_expression: public expression { public: named_expression(node *id): id_((const identifier*)id) { lvalue_ = this; } const identifier *id() const { return id_; } ir::value* codegen(ir::module * mod) const; private: const identifier *id_; }; class binary_operator: public expression{ private: ir::value* llvm_op(ir::module *mod, ir::builder &bld, ir::value *lhs, ir::value *rhs, const std::string &name) const; public: binary_operator(BIN_OP_T op, node *lhs, node *rhs) : op_(op), lhs_((expression*)lhs), rhs_((expression*)rhs) { } ir::value* codegen(ir::module *) const; private: const BIN_OP_T op_; const expression *lhs_; const expression *rhs_; }; class constant: public expression{ public: constant(int value): value_(value) { } ir::value* codegen(ir::module *mod) const; int value() const; private: const int value_; }; class constant_range: public expression { public: constant_range(node *first, node *last) : first_((constant*)first), last_((constant*)last) { } ir::value* codegen(ir::module *mod) const; private: constant *first_; constant *last_; }; class string_literal: public expression{ public: string_literal(char *&value): value_(value) { } ir::value* codegen(ir::module *mod) const; public: std::string value_; }; class unary_operator: public expression{ private: ir::value *llvm_op(ir::builder &builder, ir::value *arg, const std::string &name) const; public: unary_operator(UNARY_OP_T op, node *arg) : op_(op), arg_((expression*)arg) { if(op == DEREF) this->lvalue_ = arg_->lvalue(); } UNARY_OP_T get_op() const { return op_; } ir::value* codegen(ir::module *mod) const; private: const UNARY_OP_T op_; const expression *arg_; }; class type_name; class cast_operator: public expression{ private: ir::value *llvm_op(ir::builder &builder, ir::type *T, ir::value *arg, const std::string &name) const; public: cast_operator(node *T, node *arg): T_((type_name*)T), arg_((expression*)arg) { } ir::value* codegen(ir::module *mod) const; public: const type_name *T_; const expression *arg_; }; class conditional_expression: public expression{ private: ir::value *llvm_op(ir::builder &builder, ir::value *cond, ir::value *true_value, ir::value *false_value, const std::string &name) const; public: conditional_expression(node *cond, node *true_value, node *false_value) : cond_((expression*)cond), true_value_((expression*)true_value), false_value_((expression*)false_value) { } ir::value* codegen(ir::module *mod) const; public: const expression *cond_; const expression *true_value_; const expression *false_value_; }; class assignment_expression: public expression{ public: assignment_expression(node *lvalue, ASSIGN_OP_T op, node *rvalue) : lhs_((named_expression*)lvalue), op_(op), rhs_((expression*)rvalue) { } const expression *lhs() const { return lhs_; } const expression *rhs() const { return rhs_; } ir::value* codegen(ir::module *mod) const; public: ASSIGN_OP_T op_; const expression *lhs_; const expression *rhs_; }; class initializer; class declaration_specifier; class declaration: public node{ public: declaration(node *spec, node *init) : spec_((declaration_specifier*)spec), init_((list*)init) { } ir::value* codegen(ir::module * mod) const; public: const declaration_specifier *spec_; const list *init_; }; class statement: public node{ private: expression *pred_; }; class expression_statement: public statement{ public: expression_statement(node *expr, node *mask = nullptr) : expr_((expression*)expr), mask_((expression*)mask){ } ir::value* codegen(ir::module * mod) const; private: expression *expr_; expression *mask_; }; class compound_statement: public statement{ typedef list* declarations_t; typedef list* statements_t; public: compound_statement(node* decls, node* statements) : decls_((declarations_t)decls), statements_((statements_t)statements) {} ir::value* codegen(ir::module * mod) const; private: declarations_t decls_; statements_t statements_; }; class selection_statement: public statement{ public: selection_statement(node *cond, node *if_value, node *else_value = nullptr) : cond_(cond), then_value_(if_value), else_value_(else_value) { } ir::value* codegen(ir::module *mod) const; public: const node *cond_; const node *then_value_; const node *else_value_; }; class iteration_statement: public statement{ public: iteration_statement(node *init, node *stop, node *exec, node *statements) : init_(init), stop_(stop), exec_(exec), statements_(statements) { } ir::value* codegen(ir::module *mod) const; private: const node *init_; const node *stop_; const node *exec_; const node *statements_; }; class no_op: public statement { }; // Types class declaration_specifier: public node{ public: declaration_specifier(TYPE_T spec) : spec_(spec) { } ir::type* type(ir::module *mod) const; private: const TYPE_T spec_; }; class declarator; class parameter: public node { public: parameter(node *spec, node *decl) : spec_((declaration_specifier*)spec), decl_((declarator*)decl) { } ir::type* type(ir::module *mod) const; const identifier* id() const; public: const declaration_specifier *spec_; const declarator *decl_; }; /* Declarators */ class declarator: public node{ virtual ir::type* type_impl(ir::module *mod, ir::type *type) const = 0; public: declarator(node *lhs) : lhs_((declarator*)lhs), ptr_(nullptr){ } ir::type* type(ir::module *mod, ir::type *type) const; const identifier* id() const { return (const identifier*)lhs_; } declarator *set_ptr(node *ptr){ ptr_ = (pointer*)ptr; return this; } protected: declarator *lhs_; pointer *ptr_; }; class identifier: public declarator { ir::type* type_impl(ir::module *mod, ir::type *type) const; public: identifier(char *&name): declarator(this), name_(name) { } const std::string &name() const; private: std::string name_; }; class pointer: public declarator{ private: ir::type* type_impl(ir::module *mod, ir::type *type) const; public: pointer(node *id): declarator(id) { } }; class tile: public declarator{ private: ir::type* type_impl(ir::module *mod, ir::type *type) const; public: tile(node *id, node *shapes) : declarator(id), shapes_((list*)(shapes)) { } public: const list* shapes_; }; class function: public declarator{ private: ir::type* type_impl(ir::module *mod, ir::type *type) const; public: function(node *id, node *args) : declarator(id), args_((list*)args) { } void bind_parameters(ir::module *mod, ir::function *fn) const; public: const list* args_; }; class initializer : public declarator{ private: ir::type* type_impl(ir::module * mod, ir::type *type) const; public: initializer(node *decl, node *init) : declarator((node*)((declarator*)decl)->id()), decl_((declarator*)decl), expr_((expression*)init){ } void specifier(const declaration_specifier *spec); ir::value* codegen(ir::module *) const; public: const declaration_specifier *spec_; const declarator *decl_; const expression *expr_; }; class type_name: public node{ public: type_name(node *spec, node * decl) : spec_((declaration_specifier*)spec), decl_((declarator*)decl) { } ir::type *type(ir::module *mod) const; public: const declaration_specifier *spec_; const declarator *decl_; }; /* Function definition */ class function_definition: public node{ public: function_definition(node *spec, node *header, node *body) : spec_((declaration_specifier*)spec), header_((function *)header), body_((compound_statement*)body) { } ir::value* codegen(ir::module * mod) const; public: const declaration_specifier *spec_; const function *header_; const compound_statement *body_; }; /* Translation Unit */ class translation_unit: public node{ public: translation_unit(node *item) : decls_((list*)item) { } translation_unit *add(node *item) { decls_->append(item); return this; } ir::value* codegen(ir::module * mod) const; private: list* decls_; }; } } #endif