diff --git a/include/triton/ast/ast.h b/include/triton/ast/ast.h index 3d1da3064..26282894e 100644 --- a/include/triton/ast/ast.h +++ b/include/triton/ast/ast.h @@ -1,691 +1,12 @@ -#ifndef TDL_INCLUDE_AST_H -#define TDL_INCLUDE_AST_H +#ifndef TRITON_INCLUDE_AST_AST_H +#define TRITON_INCLUDE_AST_AST_H +#include "ops.h" #include "parser.hpp" -#include -#include -#include -#include - - -namespace triton{ - - -namespace ir{ - class function; - class value; - class type; - class builder; - class module; -} - -namespace ast{ - -// Enumerations -enum ASSIGN_OP_T{ - ASSIGN, - INPLACE_MUL, INPLACE_DIV, INPLACE_MOD, - INPLACE_ADD, INPLACE_SUB, - INPLACE_LSHIFT, INPLACE_RSHIFT, - INPLACE_AND, INPLACE_XOR, - INPLACE_OR -}; - -enum BIN_OP_T{ - MUL, DIV, MOD, - ADD, SUB, - LEFT_SHIFT, RIGHT_SHIFT, - LT, GT, - LE, GE, - EQ, NE, - AND, XOR, OR, - LAND, LOR -}; - -enum UNARY_OP_T{ - INC, DEC, - PLUS, MINUS, - ADDR, DEREF, - COMPL, NOT -}; - -enum TYPE_T{ - VOID_T, - UINT1_T, UINT8_T, UINT16_T, UINT32_T, UINT64_T, - INT1_T, INT8_T, INT16_T, INT32_T, INT64_T, - FLOAT32_T, FLOAT64_T -}; - -enum STORAGE_SPEC_T{ - CONST_T, - TUNABLE_T, - KERNEL_T, - RESTRICT_T, - READONLY_T, - CONSTANT_SPACE_T, - WRITEONLY_T -}; - -class pointer; -class identifier; -class constant; - -// AST -class node { -protected: - static ir::value* explicit_cast(ir::builder &builder, ir::value *src, ir::type *dst_ty); - static void implicit_broadcast(ir::module *mod, ir::type *dst_ty, ir::value *&src); - static void implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs); - static void implicit_cast(ir::builder &builder, ir::value *&lhs, ir::value *&rhs, - bool &is_float, bool &is_ptr, bool &is_int, bool &is_signed); -public: - virtual ir::value* codegen(ir::module *) const { return nullptr; } -}; - -template -class list: public node { -public: - list(const T& x): values_(1, x) {} - - node* append(const T& x){ - values_.push_back(x); - return this; - } - - ir::value* codegen(ir::module * mod) const{ - for(T x: values_){ - x->codegen(mod); - } - return nullptr; - } - - const std::vector &values() const - { return values_; } - -private: - std::vector values_; -}; - -enum slice_enum_t{ - ALL, - NEWAXIS -}; - -class slice: public node{ -public: - slice(slice_enum_t type) - : type_(type){} - - slice_enum_t type() const{ - return type_; - } - -public: - const slice_enum_t type_; -}; - -class named_expression; - -class expression: public node{ -public: - virtual ir::value* codegen(ir::module *) const = 0; - named_expression *lvalue() const { return lvalue_; } - -protected: - named_expression *lvalue_; -}; - -class postfix_expression: public expression{ - -}; - -class builtin_expression: public node{ - -}; - -class typed_declaration_specifier; -class alloc_const: public builtin_expression{ -public: - alloc_const(node *spec, node *size): spec_((typed_declaration_specifier*)spec), size_((constant*)size) { } - ir::value* codegen(ir::module *mod) const; - -private: - const typed_declaration_specifier* spec_; - const constant* size_; -}; - -class get_global_range: public builtin_expression{ -public: - get_global_range(node *size, node *axis): size_((constant*)size), axis_((constant*)axis) { } - ir::value* codegen(ir::module *) const; - -private: - const constant* size_; - const constant* axis_; -}; - -class get_range_id: public builtin_expression{ -public: - get_range_id(node *axis): axis_((constant*)axis) { } - ir::value* codegen(ir::module *) const; - -private: - const constant* axis_; -}; - -class atomic_cas: public builtin_expression{ -public: - atomic_cas(node *ptr, node *cmp, node *val): ptr_(ptr), cmp_(cmp), val_(val) { } - ir::value* codegen(ir::module *) const; - -private: - const node *ptr_; - const node *cmp_; - const node *val_; -}; - - -class matmul_expression: public builtin_expression{ -public: - matmul_expression(node* A, node *B, node *C): - A_((expression*)A), B_((expression*)B), C_((expression*)C) { } - ir::value* codegen(ir::module *) const; - -private: - const expression *A_; - const expression *B_; - const expression *C_; -}; - -class max_expression: public builtin_expression{ -public: - max_expression(node* x, node* y) - : x_((expression*)x), y_((expression*)y){ } - ir::value* codegen(ir::module *) const; - -private: - const expression *x_; - const expression *y_; -}; - -class min_expression: public builtin_expression{ -public: - min_expression(node* x, node* y) - : x_((expression*)x), y_((expression*)y){ } - ir::value* codegen(ir::module *mod) const; - -private: - const expression *x_; - const expression *y_; -}; - -class select_expression: public builtin_expression{ -public: - select_expression(node* pred, node* if_value, node* else_value) - : pred_((expression*)pred), if_value_((expression*)if_value), else_value_((expression*)else_value) { } - ir::value* codegen(ir::module *mod) const; - -private: - const expression *pred_; - const expression *if_value_; - const expression *else_value_; -}; - -class trans_expression: public builtin_expression{ -public: - trans_expression(node *arg): arg_(arg) {} - ir::value* codegen(ir::module *mod) const; - -private: - node* arg_; -}; - - -class indexing_expression: public postfix_expression{ -public: - indexing_expression(node *id, node *slices) - : id_((const identifier*)id), slices_((const list*)slices) {} - - ir::value* codegen(ir::module *) const; - -private: - const identifier* id_; - const list* slices_; -}; - - - -class named_expression: public expression { -public: - named_expression(node *id): id_((const identifier*)id) { lvalue_ = this; } - const identifier *id() const { return id_; } - ir::value* codegen(ir::module * mod) const; - -private: - const identifier *id_; -}; - -class binary_operator: public expression{ -private: - ir::value* llvm_op(ir::module *mod, ir::builder &bld, ir::value *lhs, ir::value *rhs, const std::string &name) const; - -public: - binary_operator(BIN_OP_T op, node *lhs, node *rhs) - : op_(op), lhs_((expression*)lhs), rhs_((expression*)rhs) { - } - ir::value* codegen(ir::module *) const; - -private: - const BIN_OP_T op_; - const expression *lhs_; - const expression *rhs_; -}; - - -class constant: public expression{ -public: - constant(int value): value_(value) { } - ir::value* codegen(ir::module *mod) const; - int value() const; - -private: - const int value_; -}; - -class constant_range: public expression { -public: - constant_range(node *first, node *last) - : first_((constant*)first), last_((constant*)last) { } - - ir::value* codegen(ir::module *mod) const; - -private: - constant *first_; - constant *last_; -}; - -class string_literal: public expression{ -public: - string_literal(char *&value): value_(value) { } - ir::value* codegen(ir::module *mod) const; - -public: - std::string value_; -}; - -class unary_operator: public expression{ -private: - ir::value *llvm_op(ir::builder &builder, ir::value *arg, const std::string &name) const; - -public: - unary_operator(UNARY_OP_T op, node *arg) - : op_(op), - arg_((expression*)arg) { - if(op == DEREF) - this->lvalue_ = arg_->lvalue(); - } - - UNARY_OP_T get_op() const { return op_; } - ir::value* codegen(ir::module *mod) const; - -private: - const UNARY_OP_T op_; - const expression *arg_; -}; - -class type_name; -class cast_operator: public expression{ -private: - ir::value *llvm_op(ir::builder &builder, ir::type *T, ir::value *arg, const std::string &name) const; - -public: - cast_operator(node *T, node *arg): - T_((type_name*)T), - arg_((expression*)arg) { } - - ir::value* codegen(ir::module *mod) const; - -public: - const type_name *T_; - const expression *arg_; -}; - -class conditional_expression: public expression{ -private: - ir::value *llvm_op(ir::builder &builder, - ir::value *cond, ir::value *true_value, ir::value *false_value, - const std::string &name) const; - -public: - conditional_expression(node *cond, node *true_value, node *false_value) - : cond_((expression*)cond), - true_value_((expression*)true_value), - false_value_((expression*)false_value) { } - - ir::value* codegen(ir::module *mod) const; - -public: - const expression *cond_; - const expression *true_value_; - const expression *false_value_; -}; - -class assignment_expression: public expression{ -public: - assignment_expression(node *lvalue, ASSIGN_OP_T op, node *rvalue) - : lvalue_((named_expression*)lvalue), op_(op), rvalue_((expression*)rvalue) { } - - ir::value* codegen(ir::module *mod) const; - const expression *lvalue() const { return lvalue_; } - const expression *rvalue() const { return rvalue_; } - -public: - ASSIGN_OP_T op_; - const expression *lvalue_; - const expression *rvalue_; -}; - - -class initializer; -class declaration_specifier; - -class block_item: public node{ -}; - -class declaration: public block_item{ -public: - declaration(node *spec, node *init) - : spec_((declaration_specifier*)spec), init_((list*)init) { } - - ir::value* codegen(ir::module * mod) const; - -public: - const declaration_specifier *spec_; - const list *init_; -}; - -class statement: public block_item{ -}; - -class expression_statement: public statement{ -public: - expression_statement(node *expr, node *mask = nullptr) - : expr_((expression*)expr), pred_((expression*)mask){ } - - ir::value* codegen(ir::module * mod) const; - -private: - expression *expr_; - expression *pred_; -}; - - -class compound_statement: public statement{ - typedef list* declarations_t; - typedef list* statements_t; - -public: - compound_statement(node* items) - : items_((list*)items){} - - ir::value* codegen(ir::module * mod) const; - -private: - list* items_; -}; - -class selection_statement: public statement{ -public: - selection_statement(node *cond, node *if_value, node *else_value = nullptr) - : cond_(cond), then_value_(if_value), else_value_(else_value) { } - - ir::value* codegen(ir::module *mod) const; - -public: - const node *cond_; - const node *then_value_; - const node *else_value_; -}; - -class iteration_statement: public statement{ -public: - iteration_statement(node *init, node *stop, node *exec, node *statements) - : init_(init), stop_(stop), exec_(exec), statements_(statements) - { } - - ir::value* codegen(ir::module *mod) const; - -private: - const node *init_; - const node *stop_; - const node *exec_; - const node *statements_; -}; - -class while_statement: public statement{ -public: - while_statement(node *cond, node *statements) - : cond_(cond), statements_(statements) - { } - - ir::value* codegen(ir::module *) const; - -private: - const node *cond_; - const node *statements_; -}; - -// Jump - -class jump_statement: public statement{ -public: - using statement::statement; -}; - -class continue_statement: public jump_statement{ -public: - ir::value* codegen(ir::module *mod) const; -}; - -class no_op: public statement { }; - -// Types -class declaration_specifier: public node{ -public: - virtual ir::type* type(ir::module *mod) const = 0; - virtual std::vector storage() const = 0; -}; - -class typed_declaration_specifier: public declaration_specifier { -public: - typed_declaration_specifier(TYPE_T ty): ty_(ty){ } - ir::type* type(ir::module *mod) const; - std::vector storage() const; - -private: - const TYPE_T ty_; -}; - -class storage_declaration_specifier: public declaration_specifier { -public: - storage_declaration_specifier(STORAGE_SPEC_T storage_spec, node *decl_spec) - : storage_spec_(storage_spec), decl_spec_((declaration_specifier*)decl_spec) {} - ir::type* type(ir::module *mod) const; - std::vector storage() const; - -private: - const STORAGE_SPEC_T storage_spec_; - const declaration_specifier* decl_spec_; -}; - -class declarator; -class parameter: public node { -public: - parameter(node *spec, node *decl) - : spec_((declaration_specifier*)spec), - decl_((declarator*)decl) { } - - ir::type* type(ir::module *mod) const; - std::vector storage() const; - const identifier* id() const; - -public: - const declaration_specifier *spec_; - const declarator *decl_; -}; - -/* Declarators */ -class declarator: public node{ -protected: - typedef std::vector storage_spec_vec_t; - typedef const storage_spec_vec_t& storage_spec_vec_const_ref_t; - -public: - virtual ir::type* type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const = 0; - -public: - declarator(node *lhs) - : lhs_((declarator*)lhs), ptr_(nullptr){ } - - ir::type* type(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const; - - const identifier* id() const { - return (const identifier*)lhs_; - } - - declarator *set_ptr(node *ptr){ - ptr_ = (pointer*)ptr; - return this; - } - - void set_addr_space(unsigned addr_space){ - addr_space_ = addr_space; - } - -protected: - declarator *lhs_; - pointer *ptr_; - unsigned addr_space_; -}; - -class identifier: public declarator { - ir::type* type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const; - -public: - identifier(char *&name): declarator(this), name_(name) { } - const std::string &name() const; - -private: - std::string name_; -}; - -class pointer: public declarator{ -private: - ir::type* type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const; - -public: - pointer(node *id): declarator(id) { } -}; - -class tile: public declarator{ -private: - ir::type* type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const; - -public: - tile(node *id, node *shapes) - : declarator(id), shapes_((list*)(shapes)) { } - -public: - const list* shapes_; -}; - -class function: public declarator{ -private: - ir::type* type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const; - -public: - function(node *id, node *args) - : declarator(id), args_((list*)args) { } - - void bind_parameters(ir::module *mod, ir::function *fn) const; - unsigned get_num_args() const { return args_->values().size(); } - parameter* get_arg(unsigned i) const { return args_->values().at(i); } - -public: - const list* args_; -}; - - -class initializer : public declarator{ -private: - ir::type* type_impl(ir::module * mod, ir::type *type, storage_spec_vec_const_ref_t storage) const; - -public: - initializer(node *decl, node *init) - : declarator((node*)((declarator*)decl)->id()), - decl_((declarator*)decl), expr_((expression*)init){ } - - void set_specifier(const declaration_specifier *spec); - ir::value* codegen(ir::module *) const; - -public: - const declaration_specifier *spec_; - declarator *decl_; - const expression *expr_; -}; - - -class type_name: public node{ -public: - type_name(node *spec, node * decl) - : spec_((declaration_specifier*)spec), decl_((declarator*)decl) { } - - ir::type *type(ir::module *mod) const; - -public: - const declaration_specifier *spec_; - const declarator *decl_; -}; - -/* Function definition */ -class function_definition: public node{ -public: - function_definition(node *spec, node *header, node *body) - : spec_((declaration_specifier*)spec), header_((function *)header), body_((compound_statement*)body) { } - - ir::value* codegen(ir::module * mod) const; - -public: - const declaration_specifier *spec_; - const function *header_; - const compound_statement *body_; -}; - -/* Translation Unit */ -class translation_unit: public node{ -public: - translation_unit(node *item) - : decls_(item) { } - - translation_unit *add(node *item) { - decls_.append(item); - return this; - } - - ir::value* codegen(ir::module * mod) const; - -private: - list decls_; -}; - -void update_location(const char *t); -void print_error(const char *error); -char return_impl(char t, const char * yytext); -yytokentype return_impl(yytokentype t, const char * yytext); -void return_void(const char * yytext); - -} - -} +#include "declaration.h" +#include "error.h" +#include "expression.h" +#include "node.h" +#include "ops.h" #endif diff --git a/include/triton/ast/declaration.h b/include/triton/ast/declaration.h new file mode 100644 index 000000000..5a51c3f9a --- /dev/null +++ b/include/triton/ast/declaration.h @@ -0,0 +1,222 @@ +#ifndef TRITON_INCLUDE_AST_DECLARATION_H +#define TRITON_INCLUDE_AST_DECLARATION_H + +#include "node.h" +#include "parser.hpp" +#include +#include +#include +#include + + +namespace triton{ + + +namespace ir{ + class function; + class value; + class type; + class builder; + class module; +} + +namespace ast{ + +class expression; +class pointer; +class identifier; +class constant; +class compound_statement; +class initializer; +class declaration_specifier; + + +class declaration: public block_item{ +public: + declaration(node *spec, node *init) + : spec_((declaration_specifier*)spec), init_((list*)init) { } + + ir::value* codegen(ir::module * mod) const; + +public: + const declaration_specifier *spec_; + const list *init_; +}; + +// Types +class declaration_specifier: public node{ +public: + virtual ir::type* type(ir::module *mod) const = 0; + virtual std::vector storage() const = 0; +}; + +class typed_declaration_specifier: public declaration_specifier { +public: + typed_declaration_specifier(TYPE_T ty): ty_(ty){ } + ir::type* type(ir::module *mod) const; + std::vector storage() const; + +private: + const TYPE_T ty_; +}; + +class storage_declaration_specifier: public declaration_specifier { +public: + storage_declaration_specifier(STORAGE_SPEC_T storage_spec, node *decl_spec) + : storage_spec_(storage_spec), decl_spec_((declaration_specifier*)decl_spec) {} + ir::type* type(ir::module *mod) const; + std::vector storage() const; + +private: + const STORAGE_SPEC_T storage_spec_; + const declaration_specifier* decl_spec_; +}; + +class declarator; +class parameter: public node { +public: + parameter(node *spec, node *decl) + : spec_((declaration_specifier*)spec), + decl_((declarator*)decl) { } + + ir::type* type(ir::module *mod) const; + std::vector storage() const; + const identifier* id() const; + +public: + const declaration_specifier *spec_; + const declarator *decl_; +}; + +/* Declarators */ +class declarator: public node{ +protected: + typedef std::vector storage_spec_vec_t; + typedef const storage_spec_vec_t& storage_spec_vec_const_ref_t; + +public: + virtual ir::type* type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const = 0; + +public: + declarator(node *lhs) + : lhs_((declarator*)lhs), ptr_(nullptr){ } + + ir::type* type(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const; + + const identifier* id() const { + return (const identifier*)lhs_; + } + + declarator *set_ptr(node *ptr){ + ptr_ = (pointer*)ptr; + return this; + } + + void set_addr_space(unsigned addr_space){ + addr_space_ = addr_space; + } + +protected: + declarator *lhs_; + pointer *ptr_; + unsigned addr_space_; +}; + +class identifier: public declarator { + ir::type* type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const; + +public: + identifier(char *&name): declarator(this), name_(name) { } + const std::string &name() const; + +private: + std::string name_; +}; + +class pointer: public declarator{ +private: + ir::type* type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const; + +public: + pointer(node *id): declarator(id) { } +}; + +class tile: public declarator{ +private: + ir::type* type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const; + +public: + tile(node *id, node *shapes) + : declarator(id), shapes_((list*)(shapes)) { } + +public: + const list* shapes_; +}; + +class function: public declarator{ +private: + ir::type* type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const; + +public: + function(node *id, node *args) + : declarator(id), args_((list*)args) { } + + void bind_parameters(ir::module *mod, ir::function *fn) const; + unsigned get_num_args() const { return args_->values().size(); } + parameter* get_arg(unsigned i) const { return args_->values().at(i); } + +public: + const list* args_; +}; + + +class initializer : public declarator{ +private: + ir::type* type_impl(ir::module * mod, ir::type *type, storage_spec_vec_const_ref_t storage) const; + +public: + initializer(node *decl, node *init) + : declarator((node*)((declarator*)decl)->id()), + decl_((declarator*)decl), expr_((expression*)init){ } + + void set_specifier(const declaration_specifier *spec); + ir::value* codegen(ir::module *) const; + +public: + const declaration_specifier *spec_; + declarator *decl_; + const expression *expr_; +}; + + +class type_name: public node{ +public: + type_name(node *spec, node * decl) + : spec_((declaration_specifier*)spec), decl_((declarator*)decl) { } + + ir::type *type(ir::module *mod) const; + +public: + const declaration_specifier *spec_; + const declarator *decl_; +}; + +/* Function definition */ +class function_definition: public node{ +public: + function_definition(node *spec, node *header, node *body) + : spec_((declaration_specifier*)spec), header_((function *)header), body_((compound_statement*)body) { } + + ir::value* codegen(ir::module * mod) const; + +public: + const declaration_specifier *spec_; + const function *header_; + const compound_statement *body_; +}; + +} + +} + +#endif diff --git a/include/triton/ast/error.h b/include/triton/ast/error.h new file mode 100644 index 000000000..5834d55f6 --- /dev/null +++ b/include/triton/ast/error.h @@ -0,0 +1,62 @@ +#ifndef TRITON_INCLUDE_AST_ERROR_H +#define TRITON_INCLUDE_AST_ERROR_H + +#include "ops.h" +#include "parser.hpp" +#include "node.h" +#include +#include +#include +#include + + +namespace triton{ + + +namespace ir{ + class function; + class value; + class type; + class builder; + class module; +} + +namespace ast{ + +class expression; +class pointer; +class identifier; +class constant; +class compound_statement; +class initializer; +class declaration_specifier; +class function; + +/* Translation Unit */ +class translation_unit: public node{ +public: + translation_unit(node *item) + : decls_(item) { } + + translation_unit *add(node *item) { + decls_.append(item); + return this; + } + + ir::value* codegen(ir::module * mod) const; + +private: + list decls_; +}; + +void update_location(const char *t); +void print_error(const char *error); +char return_impl(char t, const char * yytext); +yytokentype return_impl(yytokentype t, const char * yytext); +void return_void(const char * yytext); + +} + +} + +#endif diff --git a/include/triton/ast/expression.h b/include/triton/ast/expression.h new file mode 100644 index 000000000..27d72dec8 --- /dev/null +++ b/include/triton/ast/expression.h @@ -0,0 +1,311 @@ +#ifndef TDL_INCLUDE_AST_EXPRESSION_H +#define TDL_INCLUDE_AST_EXPRESSION_H + +#include "parser.hpp" +#include "ast.h" +#include +#include +#include +#include + + +namespace triton{ + + +namespace ir{ + class function; + class value; + class type; + class builder; + class module; +} + +namespace ast{ + + +enum slice_enum_t{ + ALL, + NEWAXIS +}; + +class slice: public node{ +public: + slice(slice_enum_t type) + : type_(type){} + + slice_enum_t type() const{ + return type_; + } + +public: + const slice_enum_t type_; +}; + + +class named_expression; + +class expression: public node{ +public: + virtual ir::value* codegen(ir::module *) const = 0; + named_expression *lvalue() const { return lvalue_; } + +protected: + named_expression *lvalue_; +}; + +class postfix_expression: public expression{ + +}; + +class builtin_expression: public node{ + +}; + +class typed_declaration_specifier; +class alloc_const_expression: public builtin_expression{ +public: + alloc_const_expression(node *spec, node *size): spec_((typed_declaration_specifier*)spec), size_((constant*)size) { } + ir::value* codegen(ir::module *mod) const; + +private: + const typed_declaration_specifier* spec_; + const constant* size_; +}; + +class get_global_range_expression: public builtin_expression{ +public: + get_global_range_expression(node *size, node *axis): size_((constant*)size), axis_((constant*)axis) { } + ir::value* codegen(ir::module *) const; + +private: + const constant* size_; + const constant* axis_; +}; + +class get_range_id_expression: public builtin_expression{ +public: + get_range_id_expression(node *axis): axis_((constant*)axis) { } + ir::value* codegen(ir::module *) const; + +private: + const constant* axis_; +}; + +class atomic_cas_expression: public builtin_expression{ +public: + atomic_cas_expression(node *ptr, node *cmp, node *val): ptr_(ptr), cmp_(cmp), val_(val) { } + ir::value* codegen(ir::module *) const; + +private: + const node *ptr_; + const node *cmp_; + const node *val_; +}; + + +class matmul_expression: public builtin_expression{ +public: + matmul_expression(node* A, node *B, node *C): + A_((expression*)A), B_((expression*)B), C_((expression*)C) { } + ir::value* codegen(ir::module *) const; + +private: + const expression *A_; + const expression *B_; + const expression *C_; +}; + +class max_expression: public builtin_expression{ +public: + max_expression(node* x, node* y) + : x_((expression*)x), y_((expression*)y){ } + ir::value* codegen(ir::module *) const; + +private: + const expression *x_; + const expression *y_; +}; + +class min_expression: public builtin_expression{ +public: + min_expression(node* x, node* y) + : x_((expression*)x), y_((expression*)y){ } + ir::value* codegen(ir::module *mod) const; + +private: + const expression *x_; + const expression *y_; +}; + +class select_expression: public builtin_expression{ +public: + select_expression(node* pred, node* if_value, node* else_value) + : pred_((expression*)pred), if_value_((expression*)if_value), else_value_((expression*)else_value) { } + ir::value* codegen(ir::module *mod) const; + +private: + const expression *pred_; + const expression *if_value_; + const expression *else_value_; +}; + +class trans_expression: public builtin_expression{ +public: + trans_expression(node *arg): arg_(arg) {} + ir::value* codegen(ir::module *mod) const; + +private: + node* arg_; +}; + + +class indexing_expression: public postfix_expression{ +public: + indexing_expression(node *id, node *slices) + : id_((const identifier*)id), slices_((const list*)slices) {} + + ir::value* codegen(ir::module *) const; + +private: + const identifier* id_; + const list* slices_; +}; + + + +class named_expression: public expression { +public: + named_expression(node *id): id_((const identifier*)id) { lvalue_ = this; } + const identifier *id() const { return id_; } + ir::value* codegen(ir::module * mod) const; + +private: + const identifier *id_; +}; + +class binary_expression: public expression{ +private: + ir::value* llvm_op(ir::module *mod, ir::builder &bld, ir::value *lhs, ir::value *rhs, const std::string &name) const; + +public: + binary_expression(BIN_OP_T op, node *lhs, node *rhs) + : op_(op), lhs_((expression*)lhs), rhs_((expression*)rhs) { + } + ir::value* codegen(ir::module *) const; + +private: + const BIN_OP_T op_; + const expression *lhs_; + const expression *rhs_; +}; + + +class constant: public expression{ +public: + constant(int value): value_(value) { } + ir::value* codegen(ir::module *mod) const; + int value() const; + +private: + const int value_; +}; + +class constant_range: public expression { +public: + constant_range(node *first, node *last) + : first_((constant*)first), last_((constant*)last) { } + + ir::value* codegen(ir::module *mod) const; + +private: + constant *first_; + constant *last_; +}; + +class string_literal: public expression{ +public: + string_literal(char *&value): value_(value) { } + ir::value* codegen(ir::module *mod) const; + +public: + std::string value_; +}; + +class unary_expression: public expression{ +private: + ir::value *llvm_op(ir::builder &builder, ir::value *arg, const std::string &name) const; + +public: + unary_expression(UNARY_OP_T op, node *arg) + : op_(op), + arg_((expression*)arg) { + if(op == DEREF) + this->lvalue_ = arg_->lvalue(); + } + + UNARY_OP_T get_op() const { return op_; } + ir::value* codegen(ir::module *mod) const; + +private: + const UNARY_OP_T op_; + const expression *arg_; +}; + +class type_name; +class cast_expression: public expression{ +private: + ir::value *llvm_op(ir::builder &builder, ir::type *T, ir::value *arg, const std::string &name) const; + +public: + cast_expression(node *T, node *arg): + T_((type_name*)T), + arg_((expression*)arg) { } + + ir::value* codegen(ir::module *mod) const; + +public: + const type_name *T_; + const expression *arg_; +}; + +class conditional_expression: public expression{ +private: + ir::value *llvm_op(ir::builder &builder, + ir::value *cond, ir::value *true_value, ir::value *false_value, + const std::string &name) const; + +public: + conditional_expression(node *cond, node *true_value, node *false_value) + : cond_((expression*)cond), + true_value_((expression*)true_value), + false_value_((expression*)false_value) { } + + ir::value* codegen(ir::module *mod) const; + +public: + const expression *cond_; + const expression *true_value_; + const expression *false_value_; +}; + +class assignment_expression: public expression{ +public: + assignment_expression(node *lvalue, ASSIGN_OP_T op, node *rvalue) + : lvalue_((named_expression*)lvalue), op_(op), rvalue_((expression*)rvalue) { } + + ir::value* codegen(ir::module *mod) const; + const expression *lvalue() const { return lvalue_; } + const expression *rvalue() const { return rvalue_; } + +public: + ASSIGN_OP_T op_; + const expression *lvalue_; + const expression *rvalue_; +}; + + +} + +} + +#endif diff --git a/include/triton/ast/module.h b/include/triton/ast/module.h new file mode 100644 index 000000000..6d72753ce --- /dev/null +++ b/include/triton/ast/module.h @@ -0,0 +1,37 @@ +#ifndef TRITON_INCLUDE_AST_MODULE_H +#define TRITON_INCLUDE_AST_MODULE_H + +#include "ops.h" +#include "parser.hpp" +#include "node.h" +#include +#include +#include +#include + + +namespace triton{ +namespace ast{ + +/* Translation Unit */ +class translation_unit: public node{ +public: + translation_unit(node *item) + : decls_(item) { } + + translation_unit *add(node *item) { + decls_.append(item); + return this; + } + + ir::value* codegen(ir::module * mod) const; + +private: + list decls_; +}; + +} + +} + +#endif diff --git a/include/triton/ast/node.h b/include/triton/ast/node.h new file mode 100644 index 000000000..265443397 --- /dev/null +++ b/include/triton/ast/node.h @@ -0,0 +1,77 @@ +#ifndef TRITON_INCLUDE_AST_NODE_H +#define TRITON_INCLUDE_AST_NODE_H + +#include "ops.h" +#include "parser.hpp" +#include +#include +#include +#include + + +namespace triton{ + + +namespace ir{ + class function; + class value; + class type; + class builder; + class module; +} + +namespace ast{ + +class expression; +class pointer; +class identifier; +class constant; +class compound_statement; +class initializer; +class declaration_specifier; +class function; + +// Node +class node { +protected: + static ir::value* explicit_cast(ir::builder &builder, ir::value *src, ir::type *dst_ty); + static void implicit_broadcast(ir::module *mod, ir::type *dst_ty, ir::value *&src); + static void implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs); + static void implicit_cast(ir::builder &builder, ir::value *&lhs, ir::value *&rhs, + bool &is_float, bool &is_ptr, bool &is_int, bool &is_signed); +public: + virtual ir::value* codegen(ir::module *) const { return nullptr; } +}; + +class block_item: public node{ +}; + +template +class list: public node { +public: + list(const T& x): values_(1, x) {} + + node* append(const T& x){ + values_.push_back(x); + return this; + } + + ir::value* codegen(ir::module * mod) const{ + for(T x: values_){ + x->codegen(mod); + } + return nullptr; + } + + const std::vector &values() const + { return values_; } + +private: + std::vector values_; +}; + +} + +} + +#endif diff --git a/include/triton/ast/ops.h b/include/triton/ast/ops.h new file mode 100644 index 000000000..316fdccb3 --- /dev/null +++ b/include/triton/ast/ops.h @@ -0,0 +1,60 @@ +#ifndef TRITON_INCLUDE_AST_OPS_H +#define TRITON_INCLUDE_AST_OPS_H + +#include "parser.hpp" +#include +#include +#include +#include + +namespace triton{ +namespace ast{ + +enum ASSIGN_OP_T{ + ASSIGN, + INPLACE_MUL, INPLACE_DIV, INPLACE_MOD, + INPLACE_ADD, INPLACE_SUB, + INPLACE_LSHIFT, INPLACE_RSHIFT, + INPLACE_AND, INPLACE_XOR, + INPLACE_OR +}; + +enum BIN_OP_T{ + MUL, DIV, MOD, + ADD, SUB, + LEFT_SHIFT, RIGHT_SHIFT, + LT, GT, + LE, GE, + EQ, NE, + AND, XOR, OR, + LAND, LOR +}; + +enum UNARY_OP_T{ + INC, DEC, + PLUS, MINUS, + ADDR, DEREF, + COMPL, NOT +}; + +enum TYPE_T{ + VOID_T, + UINT1_T, UINT8_T, UINT16_T, UINT32_T, UINT64_T, + INT1_T, INT8_T, INT16_T, INT32_T, INT64_T, + FLOAT32_T, FLOAT64_T +}; + +enum STORAGE_SPEC_T{ + CONST_T, + TUNABLE_T, + KERNEL_T, + RESTRICT_T, + READONLY_T, + CONSTANT_SPACE_T, + WRITEONLY_T +}; + +} +} + +#endif diff --git a/include/triton/ast/parser.y b/include/triton/ast/parser.y index 9dab092de..c71f8a20e 100644 --- a/include/triton/ast/parser.y +++ b/include/triton/ast/parser.y @@ -9,6 +9,9 @@ class node; using namespace triton::ast; #define YYSTYPE node* #include "../include/triton/ast/ast.h" +#include "../include/triton/ast/expression.h" +#include "../include/triton/ast/statement.h" +#include "../include/triton/ast/declaration.h" extern char* yytext; void yyerror(const char *s); @@ -86,82 +89,80 @@ pointer | '*' pointer { $$ = new pointer($1); } abstract_declarator - : pointer { $$ = $1; } + : pointer { $$ = $1; } | pointer direct_abstract_declarator { $$ = ((declarator*)$2)->set_ptr($1); } - | direct_abstract_declarator { $$ = $1; } - ; + | direct_abstract_declarator { $$ = $1; } + ; direct_abstract_declarator - : '[' primary_expression_list ']' { $$ = new tile(nullptr, $1); } - -constant: - CONSTANT { $$ = new constant(atoi(yytext)); } - ; - -constant_list: - constant { $$ = new list((constant*)$1); } - | constant_list ',' constant { $$ = append_ptr_list($1, $3); } - ; + : '[' primary_expression_list ']' { $$ = new tile(nullptr, $1); } type_name : declaration_specifiers { $$ = new type_name($1, nullptr); } | declaration_specifiers abstract_declarator { $$ = new type_name($1, $2); } - ; + ; /* -------------------------- */ /* Expressions */ /* -------------------------- */ -identifier - : IDENTIFIER { $$ = new identifier(yytext); } - ; - -builtin - : GET_GLOBAL_RANGE '[' primary_expression ']' '(' constant ')' { $$ = new get_global_range($3, $6); } - | GET_RANGE_ID '(' constant ')' { $$ = new get_range_id($3); } - | DOT '(' expression ',' expression ',' expression ')' { $$ = new matmul_expression($3, $5, $7); } - | ALLOC_CONST type_specifier '[' constant ']' { $$ = new alloc_const(new typed_declaration_specifier(get_type_spec($2)), $4); } - | TRANS '(' expression ')' { $$ = new trans_expression($3); } - | MAX '(' expression ',' expression ')' { $$ = new max_expression($3, $5); } - | MIN '(' expression ',' expression ')' { $$ = new min_expression($3, $5); } - | SELECT '(' expression ',' expression ',' expression ')' { $$ = new select_expression($3, $5, $7); } - | ATOMIC_CAS '(' expression ',' expression ',' expression ')' { $$ = new atomic_cas($3, $5, $7); } +/* Constants */ +constant + : CONSTANT { $$ = new constant(atoi(yytext)); } ; +constant_list + : constant { $$ = new list((constant*)$1); } + | constant_list ',' constant { $$ = append_ptr_list($1, $3); } + ; + +identifier + : IDENTIFIER { $$ = new identifier(yytext); } + ; + +/* Built-in */ +builtin_expression + : GET_GLOBAL_RANGE '[' primary_expression ']' '(' constant ')' { $$ = new get_global_range_expression($3, $6); } + | GET_RANGE_ID '(' constant ')' { $$ = new get_range_id_expression($3); } + | DOT '(' expression ',' expression ',' expression ')' { $$ = new matmul_expression($3, $5, $7); } + | ALLOC_CONST type_specifier '[' constant ']' { $$ = new alloc_const_expression(new typed_declaration_specifier(get_type_spec($2)), $4); } + | TRANS '(' expression ')' { $$ = new trans_expression($3); } + | MAX '(' expression ',' expression ')' { $$ = new max_expression($3, $5); } + | MIN '(' expression ',' expression ')' { $$ = new min_expression($3, $5); } + | SELECT '(' expression ',' expression ',' expression ')' { $$ = new select_expression($3, $5, $7); } + | ATOMIC_CAS '(' expression ',' expression ',' expression ')' { $$ = new atomic_cas_expression($3, $5, $7); } + ; + +/* Primary */ primary_expression - : identifier { $$ = new named_expression($1); } - | constant { $$ = $1; } + : identifier { $$ = new named_expression($1); } + | constant { $$ = $1; } | primary_expression ELLIPSIS primary_expression { $$ = new constant_range($1, $3); } - | builtin { $$ = $1; } - | STRING_LITERAL { $$ = new string_literal(yytext); } - | '(' expression ')' { $$ = $2; } - ; + | builtin_expression { $$ = $1; } + | STRING_LITERAL { $$ = new string_literal(yytext); } + | '(' expression ')' { $$ = $2; } + ; primary_expression_list - : primary_expression { $$ = new list((expression*)$1); } + : primary_expression { $$ = new list((expression*)$1); } | primary_expression_list ',' primary_expression { $$ = append_ptr_list($1, $3); } ; +/* Postfix */ slice - : ':' { $$ = new slice(triton::ast::ALL); } - | NEWAXIS { $$ = new slice(triton::ast::NEWAXIS); } + : ':' { $$ = new slice(triton::ast::ALL); } + | NEWAXIS { $$ = new slice(triton::ast::NEWAXIS); } slice_list - : slice { $$ = new list((slice*)$1); } - | slice_list ',' slice { $$ = append_ptr_list($1, $3); } + : slice { $$ = new list((slice*)$1); } + | slice_list ',' slice { $$ = append_ptr_list($1, $3); } postfix_expression - : primary_expression { $$ = $1;} - | identifier '[' slice_list ']' { $$ = new indexing_expression($1, $3);} + : primary_expression { $$ = $1;} + | identifier '[' slice_list ']' { $$ = new indexing_expression($1, $3);} ; -unary_expression - : postfix_expression { $$ = $1; } - | INC_OP unary_expression { $$ = new unary_operator(INC, $2); } - | DEC_OP unary_expression { $$ = new unary_operator(DEC, $2); } - | unary_operator cast_expression { $$ = new unary_operator(get_unary_op($1), $2); } - ; - +/* Unary */ unary_operator : '&' { $$ = new token(ADDR); } | '*' { $$ = new token(DEREF); } @@ -169,79 +170,86 @@ unary_operator | '-' { $$ = new token(MINUS); } | '~' { $$ = new token(COMPL); } | '!' { $$ = new token(NOT); } - ; + ; + +unary_expression + : postfix_expression { $$ = $1; } + | INC_OP unary_expression { $$ = new unary_expression(INC, $2); } + | DEC_OP unary_expression { $$ = new unary_expression(DEC, $2); } + | unary_operator cast_expression { $$ = new unary_expression(get_unary_op($1), $2); } + ; cast_expression - : unary_expression { $$ = $1; } - | '(' type_name ')' cast_expression { $$ = new cast_operator($2, $4); } - ; + : unary_expression { $$ = $1; } + | '(' type_name ')' cast_expression { $$ = new cast_expression($2, $4); } + ; multiplicative_expression - : cast_expression { $$ = $1; } - | multiplicative_expression '*' cast_expression { $$ = new binary_operator(MUL, $1, $3); } - | multiplicative_expression '/' cast_expression { $$ = new binary_operator(DIV, $1, $3); } - | multiplicative_expression '%' cast_expression { $$ = new binary_operator(MOD, $1, $3); } - ; + : cast_expression { $$ = $1; } + | multiplicative_expression '*' cast_expression { $$ = new binary_expression(MUL, $1, $3); } + | multiplicative_expression '/' cast_expression { $$ = new binary_expression(DIV, $1, $3); } + | multiplicative_expression '%' cast_expression { $$ = new binary_expression(MOD, $1, $3); } + ; additive_expression - : multiplicative_expression { $$ = $1; } - | additive_expression '+' multiplicative_expression { $$ = new binary_operator(ADD, $1, $3); } - | additive_expression '-' multiplicative_expression { $$ = new binary_operator(SUB, $1, $3); } - ; + : multiplicative_expression { $$ = $1; } + | additive_expression '+' multiplicative_expression { $$ = new binary_expression(ADD, $1, $3); } + | additive_expression '-' multiplicative_expression { $$ = new binary_expression(SUB, $1, $3); } + ; shift_expression - : additive_expression { $$ = $1; } - | shift_expression LEFT_OP additive_expression { $$ = new binary_operator(LEFT_SHIFT, $1, $3); } - | shift_expression RIGHT_OP additive_expression { $$ = new binary_operator(RIGHT_SHIFT, $1, $3); } - ; + : additive_expression { $$ = $1; } + | shift_expression LEFT_OP additive_expression { $$ = new binary_expression(LEFT_SHIFT, $1, $3); } + | shift_expression RIGHT_OP additive_expression { $$ = new binary_expression(RIGHT_SHIFT, $1, $3); } + ; /* Comparison */ relational_expression - : shift_expression { $$ = $1; } - | relational_expression '<' shift_expression { $$ = new binary_operator(LT, $1, $3); } - | relational_expression '>' shift_expression { $$ = new binary_operator(GT, $1, $3); } - | relational_expression LE_OP shift_expression { $$ = new binary_operator(LE, $1, $3); } - | relational_expression GE_OP shift_expression { $$ = new binary_operator(GE, $1, $3); } - ; + : shift_expression { $$ = $1; } + | relational_expression '<' shift_expression { $$ = new binary_expression(LT, $1, $3); } + | relational_expression '>' shift_expression { $$ = new binary_expression(GT, $1, $3); } + | relational_expression LE_OP shift_expression { $$ = new binary_expression(LE, $1, $3); } + | relational_expression GE_OP shift_expression { $$ = new binary_expression(GE, $1, $3); } + ; equality_expression - : relational_expression { $$ = $1; } - | equality_expression EQ_OP relational_expression { $$ = new binary_operator(EQ, $1, $3); } - | equality_expression NE_OP relational_expression { $$ = new binary_operator(NE, $1, $3); } - ; + : relational_expression { $$ = $1; } + | equality_expression EQ_OP relational_expression { $$ = new binary_expression(EQ, $1, $3); } + | equality_expression NE_OP relational_expression { $$ = new binary_expression(NE, $1, $3); } + ; /* Binary */ and_expression - : equality_expression { $$ = $1; } - | and_expression '&' equality_expression { $$ = new binary_operator(AND, $1, $3); } - ; + : equality_expression { $$ = $1; } + | and_expression '&' equality_expression { $$ = new binary_expression(AND, $1, $3); } + ; exclusive_or_expression - : and_expression { $$ = $1; } - | exclusive_or_expression '^' and_expression { $$ = new binary_operator(XOR, $1, $3); } - ; + : and_expression { $$ = $1; } + | exclusive_or_expression '^' and_expression { $$ = new binary_expression(XOR, $1, $3); } + ; inclusive_or_expression - : exclusive_or_expression { $$ = $1; } - | inclusive_or_expression '|' exclusive_or_expression { $$ = new binary_operator(OR, $1, $3); } - ; + : exclusive_or_expression { $$ = $1; } + | inclusive_or_expression '|' exclusive_or_expression { $$ = new binary_expression(OR, $1, $3); } + ; /* Logical */ logical_and_expression - : inclusive_or_expression { $$ = $1; } - | logical_and_expression AND_OP inclusive_or_expression { $$ = new binary_operator(LAND, $1, $3); } - ; + : inclusive_or_expression { $$ = $1; } + | logical_and_expression AND_OP inclusive_or_expression { $$ = new binary_expression(LAND, $1, $3); } + ; logical_or_expression - : logical_and_expression { $$ = $1; } - | logical_or_expression OR_OP logical_and_expression { $$ = new binary_operator(LOR, $1, $3); } - ; + : logical_and_expression { $$ = $1; } + | logical_or_expression OR_OP logical_and_expression { $$ = new binary_expression(LOR, $1, $3); } + ; /* Conditional */ conditional_expression - : logical_or_expression { $$ = $1; } + : logical_or_expression { $$ = $1; } | logical_or_expression '?' conditional_expression ':' conditional_expression { $$ = new conditional_expression($1, $3, $5); } - ; + ; /* Assignment */ assignment_operator @@ -259,14 +267,14 @@ assignment_operator ; assignment_expression - : conditional_expression { $$ = $1; } + : conditional_expression { $$ = $1; } | unary_expression assignment_operator assignment_expression { $$ = new assignment_expression($1, get_assign_op($2), $3); } - ; + ; /* Expression */ expression - : assignment_expression { $$ = $1; } - ; + : assignment_expression { $$ = $1; } + ; /* Initialization */ initialization_expression @@ -280,16 +288,16 @@ initialization_expression /* -------------------------- */ statement - : compound_statement { $$ = $1; } - | expression_statement { $$ = $1; } - | selection_statement { $$ = $1; } - | iteration_statement { $$ = $1; } - | jump_statement { $$ = $1; } - ; + : compound_statement { $$ = $1; } + | expression_statement { $$ = $1; } + | selection_statement { $$ = $1; } + | iteration_statement { $$ = $1; } + | jump_statement { $$ = $1; } + ; compound_statement - : '{' '}' { $$ = new compound_statement(nullptr); } - | '{' block_item_list '}' { $$ = new compound_statement($2); } + : '{' '}' { $$ = new compound_statement(nullptr); } + | '{' block_item_list '}' { $$ = new compound_statement($2); } block_item_list : block_item { $$ = new list((block_item*)$1); } @@ -300,7 +308,7 @@ block_item | statement { $$ = $1; } expression_statement - : ';' { $$ = new no_op(); } + : ';' { $$ = new no_op(); } | expression ';' { $$ = new expression_statement($1); } | AT primary_expression expression ';' { $$ = new expression_statement($3, $2); } ; @@ -334,7 +342,7 @@ direct_declarator parameter_list - : parameter_declaration { $$ = new list((parameter*)$1); } + : parameter_declaration { $$ = new list((parameter*)$1); } | parameter_list ',' parameter_declaration { $$ = append_ptr_list($1, $3); } ; @@ -355,20 +363,19 @@ init_declarator_list ; declaration - : declaration_specifiers ';' { $$ = new declaration($1, nullptr); } - | declaration_specifiers init_declarator_list ';' { $$ = new declaration($1, $2); } - ; + : declaration_specifiers ';' { $$ = new declaration($1, nullptr); } + | declaration_specifiers init_declarator_list ';' { $$ = new declaration($1, $2); } + ; declarator : pointer direct_declarator { $$ = ((declarator*)$2)->set_ptr($1); } - | direct_declarator { $$ = $1; } - ; - + | direct_declarator { $$ = $1; } + ; init_declarator : declarator { $$ = new initializer($1, nullptr); } | declarator '=' initialization_expression { $$ = new initializer($1, $3); } - ; + ; storage_class_specifier : CONST { $$ = new token(CONST_T); } @@ -381,13 +388,13 @@ storage_class_specifier ; /* -------------------------- */ -/* Translation Unit */ +/* Translation Unit */ /* -------------------------- */ translation_unit : external_declaration { ast_root = new translation_unit($1); $$ = ast_root; } - | translation_unit external_declaration { $$ = ((translation_unit*)($1))->add($2); } - ; + | translation_unit external_declaration { $$ = ((translation_unit*)($1))->add($2); } + ; external_declaration : function_definition { $$ = $1; } @@ -396,7 +403,7 @@ external_declaration function_definition : declaration_specifiers declarator compound_statement { $$ = new function_definition($1, $2, $3); } - ; + ; %% void yyerror (const char *s){ diff --git a/include/triton/ast/statement.h b/include/triton/ast/statement.h new file mode 100644 index 000000000..575d70690 --- /dev/null +++ b/include/triton/ast/statement.h @@ -0,0 +1,121 @@ +#ifndef TRITON_INCLUDE_AST_STATEMENT_H +#define TRITON_INCLUDE_AST_STATEMENT_H + +#include "parser.hpp" +#include "triton/ast/ast.h" +#include +#include +#include +#include + + +namespace triton{ + + +namespace ir{ + class function; + class value; + class type; + class builder; + class module; +} + +namespace ast{ + +class declaration; + +class statement: public block_item{ +}; + +// Expression +class expression_statement: public statement{ +public: + expression_statement(node *expr, node *mask = nullptr) + : expr_((expression*)expr), pred_((expression*)mask){ } + + ir::value* codegen(ir::module * mod) const; + +private: + expression *expr_; + expression *pred_; +}; + +// Compound +class compound_statement: public statement{ + typedef list* declarations_t; + typedef list* statements_t; + +public: + compound_statement(node* items) + : items_((list*)items){} + + ir::value* codegen(ir::module * mod) const; + +private: + list* items_; +}; + +// Selection +class selection_statement: public statement{ +public: + selection_statement(node *cond, node *if_value, node *else_value = nullptr) + : cond_(cond), then_value_(if_value), else_value_(else_value) { } + + ir::value* codegen(ir::module *mod) const; + +public: + const node *cond_; + const node *then_value_; + const node *else_value_; +}; + +// Iteration +class iteration_statement: public statement{ +public: + iteration_statement(node *init, node *stop, node *exec, node *statements) + : init_(init), stop_(stop), exec_(exec), statements_(statements) + { } + + ir::value* codegen(ir::module *mod) const; + +private: + const node *init_; + const node *stop_; + const node *exec_; + const node *statements_; +}; + +// While +class while_statement: public statement{ +public: + while_statement(node *cond, node *statements) + : cond_(cond), statements_(statements) + { } + + ir::value* codegen(ir::module *) const; + +private: + const node *cond_; + const node *statements_; +}; + +// Jump +class jump_statement: public statement{ +public: + using statement::statement; +}; + +// Continue +class continue_statement: public jump_statement{ +public: + ir::value* codegen(ir::module *mod) const; +}; + +// No op +class no_op: public statement { }; + +} + +} + +#endif diff --git a/lib/ast/declaration.cpp b/lib/ast/declaration.cpp new file mode 100644 index 000000000..888cdf7ff --- /dev/null +++ b/lib/ast/declaration.cpp @@ -0,0 +1,199 @@ +#include "triton/ast/statement.h" +#include "triton/ast/declaration.h" +#include "triton/ir/function.h" +#include "triton/ir/module.h" +#include "triton/ir/basic_block.h" +#include "triton/ir/builder.h" +#include "triton/ir/type.h" + + +namespace triton{ + +namespace ast{ + +/* Declaration specifier */ +ir::type* typed_declaration_specifier::type(ir::module *mod) const { + ir::context &ctx = mod->get_context(); + switch (ty_) { + case VOID_T: return ir::type::get_void_ty(ctx); + case INT1_T: return ir::type::get_int1_ty(ctx); + case INT8_T: return ir::type::get_int8_ty(ctx); + case INT16_T: return ir::type::get_int16_ty(ctx); + case INT32_T: return ir::type::get_int32_ty(ctx); + case INT64_T: return ir::type::get_int64_ty(ctx); + case FLOAT32_T: return ir::type::get_float_ty(ctx); + case FLOAT64_T: return ir::type::get_double_ty(ctx); + default: throw std::runtime_error("unreachable"); + } +} + +std::vector typed_declaration_specifier::storage() const { + return {}; +} + + +ir::type* storage_declaration_specifier::type(ir::module *mod) const { + return decl_spec_->type(mod); +} + +std::vector storage_declaration_specifier::storage() const { + auto result = decl_spec_->storage(); + result.push_back(storage_spec_); + return result; +} + + +/* Parameter */ +ir::type* parameter::type(ir::module *mod) const { + return decl_->type(mod, spec_->type(mod), {}); +} + +std::vector parameter::storage() const { + return spec_->storage(); +} + +const identifier *parameter::id() const { + return decl_->id(); +} + +/* Declarators */ +ir::type* declarator::type(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const{ + if(ptr_) + return type_impl(mod, ptr_->type(mod, type, storage), storage); + return type_impl(mod, type, storage); +} + +// Identifier +ir::type* identifier::type_impl(ir::module *, ir::type *type, storage_spec_vec_const_ref_t) const{ + return type; +} + +const std::string &identifier::name() const{ + return name_; +} + +// Tile +ir::type* tile::type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t) const{ + ir::type::tile_shapes_t shapes; + for(expression *expr: shapes_->values()){ + ir::constant_int *shape = dynamic_cast(expr->codegen(mod)); + assert(shape); + shapes.push_back(shape); + } + return ir::tile_type::get(type, shapes); +} + + +// Pointer +ir::type* pointer::type_impl(ir::module*, ir::type *type, storage_spec_vec_const_ref_t storage) const{ + bool is_ptr_to_const = std::find(storage.begin(), storage.end(), CONSTANT_SPACE_T) != storage.end(); + return ir::pointer_type::get(type, is_ptr_to_const?4:1); +} + +// Function +void function::bind_parameters(ir::module *mod, ir::function *fn) const{ + std::vector args = fn->args(); + assert(args.size() == args_->values().size()); + for(size_t i = 0; i < args.size(); i++){ + parameter *param_i = args_->values().at(i); + const identifier *id_i = param_i->id(); + if(id_i){ + args[i]->set_name(id_i->name()); + mod->set_value(id_i->name(), nullptr, args[i]); + mod->get_scope().types[id_i->name()] = args[i]->get_type(); + } + } +} + +ir::type* function::type_impl(ir::module* mod, ir::type *type, storage_spec_vec_const_ref_t) const{ + std::vector types; + for(parameter* param: args_->values()) + types.push_back(param->type(mod)); + return ir::function_type::get(type, types); +} + + +/* Declaration */ +ir::value* declaration::codegen(ir::module* mod) const{ + for(initializer *init: init_->values()) + init->set_specifier(spec_); + init_->codegen(mod); + return nullptr; +} + +/* Initializer */ +ir::type* initializer::type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const{ + return decl_->type(mod, type, storage); +} + +void initializer::set_specifier(const declaration_specifier *spec) { + spec_ = spec; +} + +ir::value* initializer::codegen(ir::module * mod) const{ + std::vector storage = spec_->storage(); + ir::type *ty = decl_->type(mod, spec_->type(mod), storage); + std::string name = decl_->id()->name(); + ir::value *value = ir::undef_value::get(ty); + if(std::find(storage.begin(), storage.end(), TUNABLE_T) != storage.end()){ + auto csts = dynamic_cast*>((node*)expr_); + if(csts == nullptr) + throw std::runtime_error("must specify constant list for metaparameters"); + std::vector values; + for(constant* cst: csts->values()) + values.push_back(cst->value()); + value = ir::metaparameter::create(mod->get_context(), ty, values); + mod->register_global(name, value); + } + else if(expr_){ + value = expr_->codegen(mod); + value = explicit_cast(mod->get_builder(), value, ty); + implicit_broadcast(mod, ty, value); + } + value->set_name(name); + mod->set_value(name, value); + mod->get_scope().types[name] = ty; + if(auto *x = dynamic_cast(value)) + mod->add_alloc(x); + if(std::find(storage.begin(), storage.end(), CONST_T) != storage.end()) + mod->set_const(name); + return value; +} + +/* Type name */ +ir::type *type_name::type(ir::module *mod) const{ + return decl_->type(mod, spec_->type(mod), {}); +} + +/* Function definition */ +ir::attribute_t get_ir_attr(STORAGE_SPEC_T spec){ + switch(spec){ + case RESTRICT_T: return ir::noalias; + case READONLY_T: return ir::readonly; + case WRITEONLY_T: return ir::writeonly; + default: throw std::runtime_error("cannot convert storage specifier to IR function attribute"); + } +} + +ir::value* function_definition::codegen(ir::module *mod) const{ + ir::function_type *prototype = (ir::function_type*)header_->type(mod, spec_->type(mod), spec_->storage()); + const std::string &name = header_->id()->name(); + ir::function *fn = mod->get_or_insert_function(name, prototype); + for(unsigned i = 0; i < header_->get_num_args(); i++){ + parameter *param = header_->get_arg(i); + std::vector storage = param->storage(); + for(STORAGE_SPEC_T spec: storage) + fn->add_attr(1 + i, get_ir_attr(spec)); + } + header_->bind_parameters(mod, fn); + ir::basic_block *entry = ir::basic_block::create(mod->get_context(), "entry", fn); + mod->seal_block(entry); + mod->get_builder().set_insert_point(entry); + body_->codegen(mod); + mod->get_builder().create_ret_void(); + return nullptr; +} + +} + +} diff --git a/lib/ast/error.cpp b/lib/ast/error.cpp new file mode 100644 index 000000000..72c18277d --- /dev/null +++ b/lib/ast/error.cpp @@ -0,0 +1,49 @@ +#include "triton/ast/error.h" + + +namespace triton{ + +namespace ast{ + +static int current_line = 0; +static int current_column = 0; + +// begin token +void update_location(const char *text) { + for (int i = 0; text[i] != '\0'; i++){ + if (text[i] == '\n'){ + current_column = 0; + current_line++; + } + else if (text[i] == '\t') + current_column += 8 - (current_column % 8); + else + current_column++; + } +} + +void print_error(const char *cerror) { + std::string error(cerror); + auto it = error.find("syntax error,"); + error.replace(it, 13, ""); + std::cerr << "error at line " << current_line << " (column " << current_column << "): " << error << std::endl; + throw std::runtime_error("compilation failed"); +} + +char return_impl(char t, const char * yytext) { + update_location(yytext); + return t; +} + +yytokentype return_impl(yytokentype t, const char * yytext){ + update_location(yytext); + return t; +} + +void return_void(const char * yytext){ + update_location(yytext); +} + +} + +} diff --git a/lib/ast/expression.cpp b/lib/ast/expression.cpp new file mode 100644 index 000000000..7b6f43429 --- /dev/null +++ b/lib/ast/expression.cpp @@ -0,0 +1,329 @@ +#include "triton/ast/expression.h" +#include "triton/ast/declaration.h" +#include "triton/ir/constant.h" +#include "triton/ir/module.h" +#include "triton/ir/builder.h" +#include "triton/ir/type.h" + + +namespace triton{ + +namespace ast{ + + +/* Binary operator */ +ir::value *binary_expression::llvm_op(ir::module *mod, ir::builder &builder, ir::value *lhs, ir::value *rhs, const std::string &name) const +{ + bool is_float = false, is_ptr = false, is_int = false, is_signed = false; + implicit_cast(builder, lhs, rhs, is_float, is_ptr, is_int, is_signed); + implicit_broadcast(mod, lhs, rhs); + if(op_==MUL && is_float) + return builder.create_fmul(lhs, rhs, name); + if(op_==MUL && is_int) + return builder.create_mul(lhs, rhs, name); + if(op_==DIV && is_float) + return builder.create_fdiv(lhs, rhs, name); + if(op_==DIV && is_int && is_signed) + return builder.create_sdiv(lhs, rhs, name); + if(op_==DIV && is_int && !is_signed) + return builder.create_udiv(lhs, rhs, name); + if(op_==MOD && is_float) + return builder.create_frem(lhs, rhs, name); + if(op_==MOD && is_int && is_signed) + return builder.create_srem(lhs, rhs, name); + if(op_==MOD && is_int && !is_signed) + return builder.create_urem(lhs, rhs, name); + if(op_==ADD && is_float) + return builder.create_fadd(lhs, rhs, name); + if(op_==ADD && is_int) + return builder.create_add(lhs, rhs); + if(op_==ADD && is_ptr) + return builder.create_gep(lhs, {rhs}); + if(op_==SUB && is_float) + return builder.create_fsub(lhs, rhs, name); + if(op_==SUB && is_int) + return builder.create_sub(lhs, rhs, name); + if(op_==SUB && is_ptr) + return builder.create_gep(lhs, {builder.create_neg(rhs)}); + if(op_==LEFT_SHIFT) + return builder.create_shl(lhs, rhs, name); + if(op_==RIGHT_SHIFT) + return builder.create_ashr(lhs, rhs, name); + if(op_ == LT && is_float) + return builder.create_fcmpOLT(lhs, rhs, name); + if(op_ == LT && is_int && is_signed) + return builder.create_icmpSLT(lhs, rhs, name); + if(op_ == LT && is_int && !is_signed) + return builder.create_icmpULT(lhs, rhs, name); + if(op_ == GT && is_float) + return builder.create_fcmpOGT(lhs, rhs, name); + if(op_ == GT && is_int && is_signed) + return builder.create_icmpSGT(lhs, rhs, name); + if(op_ == GT && is_int && !is_signed) + return builder.create_icmpUGT(lhs, rhs, name); + if(op_ == LE && is_float) + return builder.create_fcmpOLE(lhs, rhs, name); + if(op_ == LE && is_int && is_signed) + return builder.create_icmpSLE(lhs, rhs, name); + if(op_ == LE && is_int && !is_signed) + return builder.create_icmpULE(lhs, rhs, name); + if(op_ == GE && is_float) + return builder.create_fcmpOGE(lhs, rhs, name); + if(op_ == GE && is_int && is_signed) + return builder.create_icmpSGE(lhs, rhs, name); + if(op_ == GE && is_int && !is_signed) + return builder.create_icmpUGE(lhs, rhs, name); + if(op_ == EQ && is_float) + return builder.create_fcmpOEQ(lhs, rhs, name); + if(op_ == EQ && is_int) + return builder.create_icmpEQ(lhs, rhs, name); + if(op_ == NE && is_float) + return builder.create_fcmpONE(lhs, rhs, name); + if(op_ == NE && is_int) + return builder.create_icmpNE(lhs, rhs, name); + if(op_ == AND) + return builder.create_and(lhs, rhs, name); + if(op_ == XOR) + return builder.create_xor(lhs, rhs, name); + if(op_ == OR) + return builder.create_or(lhs, rhs, name); + if(op_ == LAND) + return builder.create_and(lhs, rhs, name); + if(op_ == LOR) + return builder.create_or(lhs, rhs, name); + throw std::runtime_error("unreachable"); +} + +ir::value* binary_expression::codegen(ir::module *mod) const{ + ir::value *lhs = lhs_->codegen(mod); + ir::value *rhs = rhs_->codegen(mod); + ir::value *result = llvm_op(mod, mod->get_builder(), lhs, rhs, ""); + return result; +} + +/* Builtin expression */ + +// alloc constant +ir::value* alloc_const_expression::codegen(ir::module *mod) const { + ir::type *ty = spec_->type(mod); + ir::constant_int *size = (ir::constant_int*)size_->codegen(mod); + ir::alloc_const *res = new ir::alloc_const(ty, size); + return res; +} + +// get_global_range +ir::value* get_global_range_expression::codegen(ir::module *mod) const { + ir::builder &builder = mod->get_builder(); + return builder.create_get_global_range(axis_->value(), (ir::constant_int*)size_->codegen(mod)); +} + +// get_range_id +ir::value* get_range_id_expression::codegen(ir::module *mod) const { + return mod->get_builder().create_get_range_id(axis_->value()); +} + +// atomic cas +ir::value* atomic_cas_expression::codegen(ir::module *mod) const { + ir::value *ptr = ptr_->codegen(mod); + ir::value *cmp = cmp_->codegen(mod); + ir::value *val = val_->codegen(mod); + return mod->get_builder().create_atomic_cas(ptr, cmp, val); +} + +// matmul +ir::value* matmul_expression::codegen(ir::module *mod) const { + ir::value *A = A_->codegen(mod); + ir::value *B = B_->codegen(mod); + ir::value *C = C_->codegen(mod); +// unsigned M = A->get_type()->get_tile_shapes()[0]; +// unsigned N = B->get_type()->get_tile_shapes()[1]; +// ir::type *scalar_ty = A->get_type()->get_scalar_ty(); +// ir::type *tile_ty = ir::tile_type::get(scalar_ty, {M, N}); +// ir::value *tmp = ir::undef_value::get(tile_ty); +// implicit_broadcast(mod, tmp, C); + return mod->get_builder().create_dot(A, B, C); +} + +// min +ir::value* min_expression::codegen(ir::module *mod) const { + ir::value* cmp = binary_expression(LT, (node*)x_, (node*)y_).codegen(mod); + ir::value* x = ((ir::cmp_inst*)cmp)->get_operand(0); + ir::value* y = ((ir::cmp_inst*)cmp)->get_operand(1); + return mod->get_builder().create_select(cmp, x, y); +} + +// max +ir::value* max_expression::codegen(ir::module *mod) const { + ir::value* cmp = binary_expression(GT, (node*)x_, (node*)y_).codegen(mod); + ir::value* x = ((ir::cmp_inst*)cmp)->get_operand(0); + ir::value* y = ((ir::cmp_inst*)cmp)->get_operand(1); + return mod->get_builder().create_select(cmp, x, y); +} + +// select +ir::value* select_expression::codegen(ir::module *mod) const { + ir::value* pred = pred_->codegen(mod); + ir::value* if_value = if_value_->codegen(mod); + ir::value* else_value = else_value_->codegen(mod); + return mod->get_builder().create_select(pred, if_value, else_value); +} + +// Trans +ir::value* trans_expression::codegen(ir::module *mod) const { + return mod->get_builder().create_trans(arg_->codegen(mod)); +} + +/* Postfix expression */ +ir::value* indexing_expression::codegen(ir::module *mod) const{ + ir::value *in = mod->get_value(id_->name()); + const std::vector &slices = slices_->values(); + auto in_shapes = in->get_type()->get_tile_shapes(); + ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(mod->get_context()); + ir::type::tile_shapes_t out_shapes(slices.size()); + // create shapes + size_t current = 0; + for(size_t i = 0; i < out_shapes.size(); i++) + out_shapes[i] = (slices[i]->type()==NEWAXIS)?one:in_shapes[current++]; + return mod->get_builder().create_reshape(in, out_shapes); +} + + +/* Unary operator */ +ir::value *unary_expression::llvm_op(ir::builder &builder, ir::value *arg, const std::string &name) const{ + ir::type *atype = arg->get_type(); + bool is_float = atype->is_floating_point_ty(); + bool is_int = atype->is_integer_ty(); + if(op_ == INC) + return builder.create_add(arg, builder.get_int32(1), name); + if(op_ == DEC) + return builder.create_sub(arg, builder.get_int32(1), name); + if(op_ == PLUS) + return arg; + if(op_ == MINUS && is_float) + return builder.create_fneg(arg, name); + if(op_ == MINUS && is_int) + return builder.create_neg(arg, name); + if(op_ == ADDR) + throw std::runtime_error("not supported"); + if(op_ == DEREF) + return builder.create_load(arg, name); + if(op_ == COMPL) + throw std::runtime_error("not supported"); + if(op_ == NOT) + return builder.create_not(arg, name); + throw std::runtime_error("unreachable"); +} + +ir::value* unary_expression::codegen(ir::module *mod) const{ + ir::value *arg = arg_->codegen(mod); + ir::value *result = llvm_op(mod->get_builder(), arg, ""); + return result; +} + +/* Cast operator */ +ir::value *cast_expression::llvm_op(ir::builder &builder, ir::type *T, ir::value *arg, const std::string &name) const{ + return nullptr; +} + +ir::value* cast_expression::codegen(ir::module *mod) const{ + ir::value *arg = arg_->codegen(mod); + ir::type *T = T_->type(mod); + return llvm_op(mod->get_builder(), T, arg, ""); +} + +/* Conditional expression */ +ir::value *conditional_expression::codegen(ir::module *mod) const{ + ir::builder &builder = mod->get_builder(); + ir::value *pred = cond_->codegen(mod); + ir::instruction *mask = (ir::instruction*)builder.create_mask(pred); + ir::value *true_mask = mask->get_result(0); + ir::value *false_mask = mask->get_result(1); + ir::value *true_value = true_value_->codegen(mod); + ir::value *false_value = false_value_->codegen(mod); + if(auto *itn = dynamic_cast(true_value)) + itn->set_mask_pred(true_mask); + if(auto *itn = dynamic_cast(false_value)) + itn->set_mask_pred(false_mask); + bool is_float, is_ptr, is_int, is_signed; + ir::value *uncasted_true_value = true_value; + ir::value *uncasted_false_value = false_value; + implicit_cast(builder, true_value, false_value, is_float, is_ptr, is_int, is_signed); + implicit_broadcast(mod, true_value, false_value); + { + ir::value *current = true_value; + while(current != uncasted_true_value) { + if(auto *itn = dynamic_cast(current)){ + itn->set_mask_pred(true_mask); + current = itn->get_operand(0); + } + else + break; + } + } + { + ir::value *current = false_value; + while(current != uncasted_false_value) { + if(auto *itn = dynamic_cast(current)){ + itn->set_mask_pred(false_mask); + current = itn->get_operand(0); + } + else + break; + } + } + ir::value *result = builder.create_merge(true_mask, true_value, false_mask, false_value); + return result; +} + +/* Assignment expression */ +ir::value *assignment_expression::codegen(ir::module *mod) const{ + ir::value *rvalue = rvalue_->codegen(mod); + if(auto *x = dynamic_cast(lvalue_)){ + ir::type *ty = mod->get_scope().types.at(x->id()->name()); + rvalue = explicit_cast(mod->get_builder(), rvalue, ty); + implicit_broadcast(mod, ty, rvalue); + mod->set_value(x->id()->name(), rvalue); + } + else if(auto* x = dynamic_cast(lvalue_)){ + assert(x->get_op()==DEREF); + assert(x->lvalue()); + ir::value *ptr = x->lvalue()->codegen(mod); + rvalue = mod->get_builder().create_store(ptr, rvalue); + } + return rvalue; +} + + +/* String literal */ +ir::value* string_literal::codegen(ir::module *) const{ + throw std::runtime_error("not supported"); +// return ir::constant_data_array::get_string(mod->get_context(), value_); +} + +/* Constant */ +ir::value* constant::codegen(ir::module *mod) const{ + return mod->get_builder().get_int32(value_); +} + +int constant::value() const{ + return value_; +} + +/* Constant range */ +ir::value* constant_range::codegen(ir::module *mod) const{ + return ir::constant_range::get((ir::constant_int*)first_->codegen(mod), + (ir::constant_int*)last_->codegen(mod)); +} + +/* Named */ +ir::value* named_expression::codegen(ir::module *mod) const{ + const std::string &name = id()->name(); + const auto& declarations = mod->get_scope().types; + if(declarations.find(name) == declarations.end()) + throw std::runtime_error("variable " + name + " not declared"); + return mod->get_value(name); +} + +} + +} diff --git a/lib/ast/lowering.cpp b/lib/ast/lowering.cpp deleted file mode 100644 index 3f8623e1c..000000000 --- a/lib/ast/lowering.cpp +++ /dev/null @@ -1,855 +0,0 @@ -#include -#include -#include "triton/ast/ast.h" -#include "triton/ir/constant.h" -#include "triton/ir/function.h" -#include "triton/ir/module.h" -#include "triton/ir/basic_block.h" -#include "triton/ir/builder.h" -#include "triton/ir/type.h" -#include -#include - - -namespace triton{ - -namespace ast{ - -static int current_line = 0; -static int current_column = 0; - -/* node */ -ir::value *node::explicit_cast(ir::builder &builder, ir::value *src, ir::type *dst_ty){ - ir::type *src_scalar_ty = src->get_type()->get_scalar_ty(); - ir::type *dst_scalar_ty = dst_ty->get_scalar_ty(); - bool src_signed = false; - bool dst_signed = false; - if(src_scalar_ty == dst_scalar_ty) - return src; - else if(src_scalar_ty->is_integer_ty() && src_signed && dst_scalar_ty->is_floating_point_ty()) - return builder.create_si_to_fp(src, dst_ty); - - else if(src_scalar_ty->is_integer_ty() && !src_signed && dst_scalar_ty->is_floating_point_ty()) - return builder.create_ui_to_fp(src, dst_ty); - - else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_integer_ty() && dst_signed) - return builder.create_fp_to_si(src, dst_ty); - - else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_integer_ty() && !dst_signed) - return builder.create_fp_to_ui(src, dst_ty); - - else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_floating_point_ty() && - src_scalar_ty->get_fp_mantissa_width() < dst_scalar_ty->get_fp_mantissa_width()) - return builder.create_fp_ext(src, dst_ty); - - else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_floating_point_ty() && - src_scalar_ty->get_fp_mantissa_width() > dst_scalar_ty->get_fp_mantissa_width()) - return builder.create_fp_trunc(src, dst_ty); - - else if(src_scalar_ty->is_integer_ty() && dst_scalar_ty->is_integer_ty() && - src_scalar_ty->get_integer_bitwidth()) - return builder.create_int_cast(src, dst_ty, dst_signed); - - else - throw std::runtime_error("unreachable"); -} - - -void node::implicit_cast(ir::builder &builder, ir::value *&lhs, ir::value *&rhs, - bool &is_float, bool &is_ptr, bool &is_int, bool &is_signed){ - // Input types - ir::type *left_ty = lhs->get_type()->get_scalar_ty(); - ir::type *right_ty = rhs->get_type()->get_scalar_ty(); - // One operand is pointer - if(left_ty->is_pointer_ty() || right_ty->is_pointer_ty()){ - if(left_ty->is_pointer_ty() && right_ty->is_pointer_ty()) - throw std::runtime_error("invalid operands"); - if(right_ty->is_pointer_ty()) - std::swap(lhs, rhs); - is_ptr = true; - } - // One operand is double - else if(left_ty->is_double_ty() || right_ty->is_double_ty()){ - ir::value *&to_convert = left_ty->is_double_ty()?rhs:lhs; - to_convert = explicit_cast(builder, to_convert, builder.get_double_ty()); - is_float = true; - } - // One operand is float - else if(left_ty->is_float_ty() || right_ty->is_float_ty()){ - ir::value *&to_convert = left_ty->is_float_ty()?rhs:lhs; - to_convert = explicit_cast(builder, to_convert, builder.get_float_ty()); - is_float = true; - } - // Both operands are integers - else if(left_ty->is_integer_ty() && right_ty->is_integer_ty()){ - is_int = true; - is_signed = true; // always signed for now - if(left_ty->get_integer_bitwidth() != right_ty->get_integer_bitwidth()){ - ir::value *&to_convert = (left_ty->get_integer_bitwidth() > right_ty->get_integer_bitwidth())?rhs:lhs; - ir::type *dst_ty = (to_convert==lhs)?right_ty:left_ty; - to_convert = explicit_cast(builder, to_convert, dst_ty); - } - } - // Not reachable - else - throw std::runtime_error("unreachable"); -} - -void node::implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs) { - ir::type *lhs_ty = lhs->get_type(); - ir::type *rhs_ty = rhs->get_type(); - ir::type *res_ty = nullptr; - if(!lhs_ty->is_tile_ty() && !rhs_ty->is_tile_ty()) - return; - else if(lhs_ty->is_tile_ty() && !rhs_ty->is_tile_ty()) - res_ty = lhs_ty; - else if(!lhs_ty->is_tile_ty() && rhs_ty->is_tile_ty()) - res_ty = rhs_ty; - else{ - auto lhs_shapes = lhs_ty->get_tile_shapes(); - auto rhs_shapes = rhs_ty->get_tile_shapes(); - size_t lhs_size = lhs_shapes.size(); - size_t rhs_size = rhs_shapes.size(); - size_t res_size = std::max(lhs_size, rhs_size); - ir::type::tile_shapes_t res_shapes(res_size); - ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(mod->get_context()); - for(int i = 0; i < res_size; i++){ - if(i >= res_size - lhs_size && i >= res_size - rhs_size) - res_shapes[i] = lhs_shapes[i]==one?rhs_shapes[i]:lhs_shapes[i]; - else if(i >= res_size - lhs_size) - res_shapes[i] = lhs_shapes[i]; - else if(i >= res_size - rhs_size) - res_shapes[i] = rhs_shapes[i]; - } - res_ty = ir::tile_type::get(lhs_ty->get_scalar_ty(), res_shapes); - } - implicit_broadcast(mod, res_ty, rhs); - implicit_broadcast(mod, res_ty, lhs); -} - -void node::implicit_broadcast(ir::module *mod, ir::type *ty, ir::value *&src){ - ir::builder &builder = mod->get_builder(); - ir::type *src_ty = src->get_type(); - ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(mod->get_context()); - // Both are scalar - if(!ty->is_tile_ty() && !src_ty->is_tile_ty()) - return; - // Broadcast scalar - if(ty->is_tile_ty() && !src_ty->is_tile_ty()){ - src = builder.create_splat(src, ty->get_tile_shapes()); - return; - } - // Downcast tile - if(!ty->is_tile_ty() && src_ty->is_tile_ty()){ - for(ir::constant *shape: src_ty->get_tile_shapes()) - if(shape != one) - throw std::runtime_error("cannot downcast"); - src = builder.create_downcast(src); - return; - } - // Both are arrays - auto dst_shapes = ty->get_tile_shapes(); - auto src_shapes = src_ty->get_tile_shapes(); - int dst_dim = dst_shapes.size(); - int src_dim = src_shapes.size(); - // Pad - int off = dst_dim - src_dim; - for(size_t i = 0; i < off; i++) - src_shapes.insert(src_shapes.begin(), one); - if(off > 0) - src = builder.create_reshape(src, src_shapes); - // Broadcast - for(int i = dst_dim - 1; i>= 0; i--) - if(dst_shapes[i] != src_shapes[i] && dst_shapes[i] != one && src_shapes[i] != one) - throw std::runtime_error("cannot broadcast"); - if(dst_shapes != src_shapes) - src = builder.create_broadcast(src, dst_shapes); -} - -/* Helper */ -inline bool is_terminator(ir::value* x) { - return x && dynamic_cast(x); -} - -/* Translation unit */ -ir::value* translation_unit::codegen(ir::module *mod) const{ - mod->add_new_scope(); - decls_.codegen(mod); - return nullptr; -} - -/* Declaration specifier */ -ir::type* typed_declaration_specifier::type(ir::module *mod) const { - ir::context &ctx = mod->get_context(); - switch (ty_) { - case VOID_T: return ir::type::get_void_ty(ctx); - case INT1_T: return ir::type::get_int1_ty(ctx); - case INT8_T: return ir::type::get_int8_ty(ctx); - case INT16_T: return ir::type::get_int16_ty(ctx); - case INT32_T: return ir::type::get_int32_ty(ctx); - case INT64_T: return ir::type::get_int64_ty(ctx); - case FLOAT32_T: return ir::type::get_float_ty(ctx); - case FLOAT64_T: return ir::type::get_double_ty(ctx); - default: throw std::runtime_error("unreachable"); - } -} - -std::vector typed_declaration_specifier::storage() const { - return {}; -} - - -ir::type* storage_declaration_specifier::type(ir::module *mod) const { - return decl_spec_->type(mod); -} - -std::vector storage_declaration_specifier::storage() const { - auto result = decl_spec_->storage(); - result.push_back(storage_spec_); - return result; -} - - -/* Parameter */ -ir::type* parameter::type(ir::module *mod) const { - return decl_->type(mod, spec_->type(mod), {}); -} - -std::vector parameter::storage() const { - return spec_->storage(); -} - -const identifier *parameter::id() const { - return decl_->id(); -} - -/* Declarators */ -ir::type* declarator::type(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const{ - if(ptr_) - return type_impl(mod, ptr_->type(mod, type, storage), storage); - return type_impl(mod, type, storage); -} - -// Identifier -ir::type* identifier::type_impl(ir::module *, ir::type *type, storage_spec_vec_const_ref_t) const{ - return type; -} - -const std::string &identifier::name() const{ - return name_; -} - -// Tile -ir::type* tile::type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t) const{ - ir::type::tile_shapes_t shapes; - for(expression *expr: shapes_->values()){ - ir::constant_int *shape = dynamic_cast(expr->codegen(mod)); - assert(shape); - shapes.push_back(shape); - } - return ir::tile_type::get(type, shapes); -} - - -// Pointer -ir::type* pointer::type_impl(ir::module*, ir::type *type, storage_spec_vec_const_ref_t storage) const{ - bool is_ptr_to_const = std::find(storage.begin(), storage.end(), CONSTANT_SPACE_T) != storage.end(); - return ir::pointer_type::get(type, is_ptr_to_const?4:1); -} - -// Function -void function::bind_parameters(ir::module *mod, ir::function *fn) const{ - std::vector args = fn->args(); - assert(args.size() == args_->values().size()); - for(size_t i = 0; i < args.size(); i++){ - parameter *param_i = args_->values().at(i); - const identifier *id_i = param_i->id(); - if(id_i){ - args[i]->set_name(id_i->name()); - mod->set_value(id_i->name(), nullptr, args[i]); - mod->get_scope().types[id_i->name()] = args[i]->get_type(); - } - } -} - -ir::type* function::type_impl(ir::module* mod, ir::type *type, storage_spec_vec_const_ref_t) const{ - std::vector types; - for(parameter* param: args_->values()) - types.push_back(param->type(mod)); - return ir::function_type::get(type, types); -} - -/* Function definition */ -ir::attribute_t get_ir_attr(STORAGE_SPEC_T spec){ - switch(spec){ - case RESTRICT_T: return ir::noalias; - case READONLY_T: return ir::readonly; - case WRITEONLY_T: return ir::writeonly; - default: throw std::runtime_error("cannot convert storage specifier to IR function attribute"); - } -} - -ir::value* function_definition::codegen(ir::module *mod) const{ - ir::function_type *prototype = (ir::function_type*)header_->type(mod, spec_->type(mod), spec_->storage()); - const std::string &name = header_->id()->name(); - ir::function *fn = mod->get_or_insert_function(name, prototype); - for(unsigned i = 0; i < header_->get_num_args(); i++){ - parameter *param = header_->get_arg(i); - std::vector storage = param->storage(); - for(STORAGE_SPEC_T spec: storage) - fn->add_attr(1 + i, get_ir_attr(spec)); - } - header_->bind_parameters(mod, fn); - ir::basic_block *entry = ir::basic_block::create(mod->get_context(), "entry", fn); - mod->seal_block(entry); - mod->get_builder().set_insert_point(entry); - body_->codegen(mod); - mod->get_builder().create_ret_void(); - return nullptr; -} - -/* Statements */ -ir::value* compound_statement::codegen(ir::module* mod) const{ - mod->add_new_scope(); - if(items_) - items_->codegen(mod); - mod->pop_scope(); - return nullptr; -} - -/* expression statement */ -ir::value* expression_statement::codegen(ir::module *mod) const{ - ir::builder &builder = mod->get_builder(); - ir::basic_block *block = builder.get_insert_block(); - if(pred_) { - // check that it is an assignment - assignment_expression *assignment = dynamic_cast(expr_); - assert(assignment); - // generate mask - ir::value *pred = pred_->codegen(mod); - ir::mask_inst *mask = (ir::mask_inst*)builder.create_mask(pred); - // generate expression - unsigned szbegin = block->get_inst_list().size(); - ir::value *expr = expr_->codegen(mod); - ir::basic_block::iterator begin = block->begin(); - std::advance(begin, szbegin); - // set mask - ir::type *ty = expr->get_type(); - for(auto it = begin; it != builder.get_insert_point(); it++) - (*it)->set_mask_pred(mask->get_result(0)); -// if(auto *itn = dynamic_cast(expr)) -// itn->set_mask_pred(mask->get_result(0)); - if(ty->is_void_ty()) - return expr; - // merge with psi - ir::psi_inst *psi = (ir::psi_inst*)builder.create_merge(mask->get_result(0), expr, - mask->get_result(1), ir::undef_value::get(ty)); - std::string name = ((named_expression*)assignment->lvalue())->id()->name(); - mod->set_value(name, psi); - return psi; - } - return expr_->codegen(mod); -} - -/* For statement */ -ir::value* iteration_statement::codegen(ir::module *mod) const{ - ir::builder &builder = mod->get_builder(); - ir::context &ctx = mod->get_context(); - ir::basic_block *current_bb = builder.get_insert_block(); - ir::function *fn = current_bb->get_parent(); - ir::basic_block *loop_bb = ir::basic_block::create(ctx, "loop", fn); - ir::basic_block *next_bb = ir::basic_block::create(ctx, "postloop", fn); - mod->set_continue_fn([&](){ - if(exec_) - exec_->codegen(mod); - ir::value *cond = explicit_cast(builder, stop_->codegen(mod), ir::type::get_int1_ty(ctx)); - return builder.create_cond_br(cond, loop_bb, next_bb); - }); - init_->codegen(mod); - ir::value *cond = explicit_cast(builder, stop_->codegen(mod), ir::type::get_int1_ty(ctx)); - builder.create_cond_br(cond, loop_bb, next_bb); -// builder.create_br(loop_bb); - builder.set_insert_point(loop_bb); - if(!is_terminator(statements_->codegen(mod))) - mod->get_continue_fn()(); - ir::basic_block *stop_bb = builder.get_insert_block(); - mod->seal_block(stop_bb); - mod->seal_block(loop_bb); - mod->seal_block(builder.get_insert_block()); - mod->seal_block(next_bb); - builder.set_insert_point(next_bb); - return nullptr; -} - -/* While statement */ -ir::value* while_statement::codegen(ir::module* mod) const{ - ir::builder &builder = mod->get_builder(); - ir::context &ctx = mod->get_context(); - ir::basic_block *current_bb = builder.get_insert_block(); - ir::function *fn = current_bb->get_parent(); - ir::basic_block *loop_bb = ir::basic_block::create(ctx, "loop", fn); - ir::basic_block *next_bb = ir::basic_block::create(ctx, "postloop", fn); - mod->set_continue_fn([&](){ - ir::value *cond = explicit_cast(builder, cond_->codegen(mod), ir::type::get_int1_ty(ctx)); - return builder.create_cond_br(cond, loop_bb, next_bb); - }); - ir::value *cond = explicit_cast(builder, cond_->codegen(mod), ir::type::get_int1_ty(ctx)); - builder.create_cond_br(cond, loop_bb, next_bb); - builder.set_insert_point(loop_bb); - if(!is_terminator(statements_->codegen(mod))) - mod->get_continue_fn()(); - ir::basic_block *stop_bb = builder.get_insert_block(); - mod->seal_block(stop_bb); - mod->seal_block(loop_bb); - mod->seal_block(builder.get_insert_block()); - mod->seal_block(next_bb); - builder.set_insert_point(next_bb); - return nullptr; -} - -/* Selection statement */ -ir::value* selection_statement::codegen(ir::module* mod) const{ - ir::builder &builder = mod->get_builder(); - ir::context &ctx = mod->get_context(); - ir::function *fn = builder.get_insert_block()->get_parent(); - ir::value *cond = cond_->codegen(mod); - ir::basic_block *then_bb = ir::basic_block::create(ctx, "then", fn); - ir::basic_block *else_bb = else_value_?ir::basic_block::create(ctx, "else", fn):nullptr; - ir::basic_block *endif_bb = ir::basic_block::create(ctx, "endif", fn); - mod->seal_block(then_bb); - if(else_value_) - mod->seal_block(else_bb); - - // Branch - if(else_value_) - builder.create_cond_br(cond, then_bb, else_bb); - else - builder.create_cond_br(cond, then_bb, endif_bb); - // Then - builder.set_insert_point(then_bb); - if(!is_terminator(then_value_->codegen(mod))) - builder.create_br(endif_bb); - // Else - if(else_value_){ - builder.set_insert_point(else_bb); - if(!is_terminator(else_value_->codegen(mod))) - builder.create_br(endif_bb); - } - // Endif - mod->seal_block(endif_bb); - builder.set_insert_point(endif_bb); - return nullptr; -} - -/* Continue statement */ -ir::value* continue_statement::codegen(ir::module *mod) const{ - return mod->get_continue_fn()(); -} - -/* Declaration */ -ir::value* declaration::codegen(ir::module* mod) const{ - for(initializer *init: init_->values()) - init->set_specifier(spec_); - init_->codegen(mod); - return nullptr; -} - -/* Initializer */ -ir::type* initializer::type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const{ - return decl_->type(mod, type, storage); -} - -void initializer::set_specifier(const declaration_specifier *spec) { - spec_ = spec; -} - -ir::value* initializer::codegen(ir::module * mod) const{ - std::vector storage = spec_->storage(); - ir::type *ty = decl_->type(mod, spec_->type(mod), storage); - std::string name = decl_->id()->name(); - ir::value *value = ir::undef_value::get(ty); - if(std::find(storage.begin(), storage.end(), TUNABLE_T) != storage.end()){ - auto csts = dynamic_cast*>((node*)expr_); - if(csts == nullptr) - throw std::runtime_error("must specify constant list for metaparameters"); - std::vector values; - for(constant* cst: csts->values()) - values.push_back(cst->value()); - value = ir::metaparameter::create(mod->get_context(), ty, values); - mod->register_global(name, value); - } - else if(expr_){ - value = expr_->codegen(mod); - value = explicit_cast(mod->get_builder(), value, ty); - implicit_broadcast(mod, ty, value); - } - value->set_name(name); - mod->set_value(name, value); - mod->get_scope().types[name] = ty; - if(auto *x = dynamic_cast(value)) - mod->add_alloc(x); - if(std::find(storage.begin(), storage.end(), CONST_T) != storage.end()) - mod->set_const(name); - return value; -} - -/*------------------*/ -/* Expression */ -/*------------------*/ -/* Binary operator */ -ir::value *binary_operator::llvm_op(ir::module *mod, ir::builder &builder, ir::value *lhs, ir::value *rhs, const std::string &name) const -{ - bool is_float = false, is_ptr = false, is_int = false, is_signed = false; - implicit_cast(builder, lhs, rhs, is_float, is_ptr, is_int, is_signed); - implicit_broadcast(mod, lhs, rhs); - if(op_==MUL && is_float) - return builder.create_fmul(lhs, rhs, name); - if(op_==MUL && is_int) - return builder.create_mul(lhs, rhs, name); - if(op_==DIV && is_float) - return builder.create_fdiv(lhs, rhs, name); - if(op_==DIV && is_int && is_signed) - return builder.create_sdiv(lhs, rhs, name); - if(op_==DIV && is_int && !is_signed) - return builder.create_udiv(lhs, rhs, name); - if(op_==MOD && is_float) - return builder.create_frem(lhs, rhs, name); - if(op_==MOD && is_int && is_signed) - return builder.create_srem(lhs, rhs, name); - if(op_==MOD && is_int && !is_signed) - return builder.create_urem(lhs, rhs, name); - if(op_==ADD && is_float) - return builder.create_fadd(lhs, rhs, name); - if(op_==ADD && is_int) - return builder.create_add(lhs, rhs); - if(op_==ADD && is_ptr) - return builder.create_gep(lhs, {rhs}); - if(op_==SUB && is_float) - return builder.create_fsub(lhs, rhs, name); - if(op_==SUB && is_int) - return builder.create_sub(lhs, rhs, name); - if(op_==SUB && is_ptr) - return builder.create_gep(lhs, {builder.create_neg(rhs)}); - if(op_==LEFT_SHIFT) - return builder.create_shl(lhs, rhs, name); - if(op_==RIGHT_SHIFT) - return builder.create_ashr(lhs, rhs, name); - if(op_ == LT && is_float) - return builder.create_fcmpOLT(lhs, rhs, name); - if(op_ == LT && is_int && is_signed) - return builder.create_icmpSLT(lhs, rhs, name); - if(op_ == LT && is_int && !is_signed) - return builder.create_icmpULT(lhs, rhs, name); - if(op_ == GT && is_float) - return builder.create_fcmpOGT(lhs, rhs, name); - if(op_ == GT && is_int && is_signed) - return builder.create_icmpSGT(lhs, rhs, name); - if(op_ == GT && is_int && !is_signed) - return builder.create_icmpUGT(lhs, rhs, name); - if(op_ == LE && is_float) - return builder.create_fcmpOLE(lhs, rhs, name); - if(op_ == LE && is_int && is_signed) - return builder.create_icmpSLE(lhs, rhs, name); - if(op_ == LE && is_int && !is_signed) - return builder.create_icmpULE(lhs, rhs, name); - if(op_ == GE && is_float) - return builder.create_fcmpOGE(lhs, rhs, name); - if(op_ == GE && is_int && is_signed) - return builder.create_icmpSGE(lhs, rhs, name); - if(op_ == GE && is_int && !is_signed) - return builder.create_icmpUGE(lhs, rhs, name); - if(op_ == EQ && is_float) - return builder.create_fcmpOEQ(lhs, rhs, name); - if(op_ == EQ && is_int) - return builder.create_icmpEQ(lhs, rhs, name); - if(op_ == NE && is_float) - return builder.create_fcmpONE(lhs, rhs, name); - if(op_ == NE && is_int) - return builder.create_icmpNE(lhs, rhs, name); - if(op_ == AND) - return builder.create_and(lhs, rhs, name); - if(op_ == XOR) - return builder.create_xor(lhs, rhs, name); - if(op_ == OR) - return builder.create_or(lhs, rhs, name); - if(op_ == LAND) - return builder.create_and(lhs, rhs, name); - if(op_ == LOR) - return builder.create_or(lhs, rhs, name); - throw std::runtime_error("unreachable"); -} - -ir::value* binary_operator::codegen(ir::module *mod) const{ - ir::value *lhs = lhs_->codegen(mod); - ir::value *rhs = rhs_->codegen(mod); - ir::value *result = llvm_op(mod, mod->get_builder(), lhs, rhs, ""); - return result; -} - -/* Builtin expression */ - -// alloc constant -ir::value* alloc_const::codegen(ir::module *mod) const { - ir::type *ty = spec_->type(mod); - ir::constant_int *size = (ir::constant_int*)size_->codegen(mod); - ir::alloc_const *res = new ir::alloc_const(ty, size); - return res; -} - -// get_global_range -ir::value* get_global_range::codegen(ir::module *mod) const { - ir::builder &builder = mod->get_builder(); - return builder.create_get_global_range(axis_->value(), (ir::constant_int*)size_->codegen(mod)); -} - -// get_range_id -ir::value* get_range_id::codegen(ir::module *mod) const { - return mod->get_builder().create_get_range_id(axis_->value()); -} - -// atomic cas -ir::value* atomic_cas::codegen(ir::module *mod) const { - ir::value *ptr = ptr_->codegen(mod); - ir::value *cmp = cmp_->codegen(mod); - ir::value *val = val_->codegen(mod); - return mod->get_builder().create_atomic_cas(ptr, cmp, val); -} - -// matmul -ir::value* matmul_expression::codegen(ir::module *mod) const { - ir::value *A = A_->codegen(mod); - ir::value *B = B_->codegen(mod); - ir::value *C = C_->codegen(mod); -// unsigned M = A->get_type()->get_tile_shapes()[0]; -// unsigned N = B->get_type()->get_tile_shapes()[1]; -// ir::type *scalar_ty = A->get_type()->get_scalar_ty(); -// ir::type *tile_ty = ir::tile_type::get(scalar_ty, {M, N}); -// ir::value *tmp = ir::undef_value::get(tile_ty); -// implicit_broadcast(mod, tmp, C); - return mod->get_builder().create_dot(A, B, C); -} - -// min -ir::value* min_expression::codegen(ir::module *mod) const { - ir::value* cmp = binary_operator(LT, (node*)x_, (node*)y_).codegen(mod); - ir::value* x = ((ir::cmp_inst*)cmp)->get_operand(0); - ir::value* y = ((ir::cmp_inst*)cmp)->get_operand(1); - return mod->get_builder().create_select(cmp, x, y); -} - -// max -ir::value* max_expression::codegen(ir::module *mod) const { - ir::value* cmp = binary_operator(GT, (node*)x_, (node*)y_).codegen(mod); - ir::value* x = ((ir::cmp_inst*)cmp)->get_operand(0); - ir::value* y = ((ir::cmp_inst*)cmp)->get_operand(1); - return mod->get_builder().create_select(cmp, x, y); -} - -// select -ir::value* select_expression::codegen(ir::module *mod) const { - ir::value* pred = pred_->codegen(mod); - ir::value* if_value = if_value_->codegen(mod); - ir::value* else_value = else_value_->codegen(mod); - return mod->get_builder().create_select(pred, if_value, else_value); -} - -// Trans -ir::value* trans_expression::codegen(ir::module *mod) const { - return mod->get_builder().create_trans(arg_->codegen(mod)); -} - -/* Postfix expression */ -ir::value* indexing_expression::codegen(ir::module *mod) const{ - ir::value *in = mod->get_value(id_->name()); - const std::vector &slices = slices_->values(); - auto in_shapes = in->get_type()->get_tile_shapes(); - ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(mod->get_context()); - ir::type::tile_shapes_t out_shapes(slices.size()); - // create shapes - size_t current = 0; - for(size_t i = 0; i < out_shapes.size(); i++) - out_shapes[i] = (slices[i]->type()==NEWAXIS)?one:in_shapes[current++]; - return mod->get_builder().create_reshape(in, out_shapes); -} - - -/* Unary operator */ -ir::value *unary_operator::llvm_op(ir::builder &builder, ir::value *arg, const std::string &name) const{ - ir::type *atype = arg->get_type(); - bool is_float = atype->is_floating_point_ty(); - bool is_int = atype->is_integer_ty(); - if(op_ == INC) - return builder.create_add(arg, builder.get_int32(1), name); - if(op_ == DEC) - return builder.create_sub(arg, builder.get_int32(1), name); - if(op_ == PLUS) - return arg; - if(op_ == MINUS && is_float) - return builder.create_fneg(arg, name); - if(op_ == MINUS && is_int) - return builder.create_neg(arg, name); - if(op_ == ADDR) - throw std::runtime_error("not supported"); - if(op_ == DEREF) - return builder.create_load(arg, name); - if(op_ == COMPL) - throw std::runtime_error("not supported"); - if(op_ == NOT) - return builder.create_not(arg, name); - throw std::runtime_error("unreachable"); -} - -ir::value* unary_operator::codegen(ir::module *mod) const{ - ir::value *arg = arg_->codegen(mod); - ir::value *result = llvm_op(mod->get_builder(), arg, ""); - return result; -} - -/* Cast operator */ -ir::value *cast_operator::llvm_op(ir::builder &builder, ir::type *T, ir::value *arg, const std::string &name) const{ - return nullptr; -} - -ir::value* cast_operator::codegen(ir::module *mod) const{ - ir::value *arg = arg_->codegen(mod); - ir::type *T = T_->type(mod); - return llvm_op(mod->get_builder(), T, arg, ""); -} - -/* Conditional expression */ -ir::value *conditional_expression::codegen(ir::module *mod) const{ - ir::builder &builder = mod->get_builder(); - ir::value *pred = cond_->codegen(mod); - ir::instruction *mask = (ir::instruction*)builder.create_mask(pred); - ir::value *true_mask = mask->get_result(0); - ir::value *false_mask = mask->get_result(1); - ir::value *true_value = true_value_->codegen(mod); - ir::value *false_value = false_value_->codegen(mod); - if(auto *itn = dynamic_cast(true_value)) - itn->set_mask_pred(true_mask); - if(auto *itn = dynamic_cast(false_value)) - itn->set_mask_pred(false_mask); - bool is_float, is_ptr, is_int, is_signed; - ir::value *uncasted_true_value = true_value; - ir::value *uncasted_false_value = false_value; - implicit_cast(builder, true_value, false_value, is_float, is_ptr, is_int, is_signed); - implicit_broadcast(mod, true_value, false_value); - { - ir::value *current = true_value; - while(current != uncasted_true_value) { - if(auto *itn = dynamic_cast(current)){ - itn->set_mask_pred(true_mask); - current = itn->get_operand(0); - } - else - break; - } - } - { - ir::value *current = false_value; - while(current != uncasted_false_value) { - if(auto *itn = dynamic_cast(current)){ - itn->set_mask_pred(false_mask); - current = itn->get_operand(0); - } - else - break; - } - } - ir::value *result = builder.create_merge(true_mask, true_value, false_mask, false_value); - return result; -} - -/* Assignment expression */ -ir::value *assignment_expression::codegen(ir::module *mod) const{ - ir::value *rvalue = rvalue_->codegen(mod); - if(auto *x = dynamic_cast(lvalue_)){ - ir::type *ty = mod->get_scope().types.at(x->id()->name()); - rvalue = explicit_cast(mod->get_builder(), rvalue, ty); - implicit_broadcast(mod, ty, rvalue); - mod->set_value(x->id()->name(), rvalue); - } - else if(auto* x = dynamic_cast(lvalue_)){ - assert(x->get_op()==DEREF); - assert(x->lvalue()); - ir::value *ptr = x->lvalue()->codegen(mod); - rvalue = mod->get_builder().create_store(ptr, rvalue); - } - return rvalue; -} - -/* Type name */ -ir::type *type_name::type(ir::module *mod) const{ - return decl_->type(mod, spec_->type(mod), {}); -} - -/* String literal */ -ir::value* string_literal::codegen(ir::module *) const{ - throw std::runtime_error("not supported"); -// return ir::constant_data_array::get_string(mod->get_context(), value_); -} - -/* Constant */ -ir::value* constant::codegen(ir::module *mod) const{ - return mod->get_builder().get_int32(value_); -} - -int constant::value() const{ - return value_; -} - -/* Constant range */ -ir::value* constant_range::codegen(ir::module *mod) const{ - return ir::constant_range::get((ir::constant_int*)first_->codegen(mod), - (ir::constant_int*)last_->codegen(mod)); -} - -/* Named */ -ir::value* named_expression::codegen(ir::module *mod) const{ - const std::string &name = id()->name(); - const auto& declarations = mod->get_scope().types; - if(declarations.find(name) == declarations.end()) - throw std::runtime_error("variable " + name + " not declared"); - return mod->get_value(name); -} - - -// begin token -void update_location(const char *text) { - for (int i = 0; text[i] != '\0'; i++){ - if (text[i] == '\n'){ - current_column = 0; - current_line++; - } - else if (text[i] == '\t') - current_column += 8 - (current_column % 8); - else - current_column++; - } -} - -void print_error(const char *cerror) { - std::string error(cerror); - auto it = error.find("syntax error,"); - error.replace(it, 13, ""); - std::cerr << "error at line " << current_line << " (column " << current_column << "): " << error << std::endl; - throw std::runtime_error("compilation failed"); -} - -char return_impl(char t, const char * yytext) { - update_location(yytext); - return t; -} - -yytokentype return_impl(yytokentype t, const char * yytext){ - update_location(yytext); - return t; -} - -void return_void(const char * yytext){ - update_location(yytext); -} - -} - -} diff --git a/lib/ast/module.cpp b/lib/ast/module.cpp new file mode 100644 index 000000000..32ae8b4c0 --- /dev/null +++ b/lib/ast/module.cpp @@ -0,0 +1,18 @@ +#include "triton/ast/module.h" +#include "triton/ir/module.h" + + +namespace triton{ + +namespace ast{ + +/* Translation unit */ +ir::value* translation_unit::codegen(ir::module *mod) const{ + mod->add_new_scope(); + decls_.codegen(mod); + return nullptr; +} + +} + +} diff --git a/lib/ast/node.cpp b/lib/ast/node.cpp new file mode 100644 index 000000000..c13bf3db7 --- /dev/null +++ b/lib/ast/node.cpp @@ -0,0 +1,160 @@ +#include "triton/ast/node.h" +#include "triton/ir/builder.h" +#include "triton/ir/module.h" +#include "triton/ir/constant.h" + +namespace triton{ + +namespace ast{ + +/* node */ +ir::value *node::explicit_cast(ir::builder &builder, ir::value *src, ir::type *dst_ty){ + ir::type *src_scalar_ty = src->get_type()->get_scalar_ty(); + ir::type *dst_scalar_ty = dst_ty->get_scalar_ty(); + bool src_signed = false; + bool dst_signed = false; + if(src_scalar_ty == dst_scalar_ty) + return src; + else if(src_scalar_ty->is_integer_ty() && src_signed && dst_scalar_ty->is_floating_point_ty()) + return builder.create_si_to_fp(src, dst_ty); + + else if(src_scalar_ty->is_integer_ty() && !src_signed && dst_scalar_ty->is_floating_point_ty()) + return builder.create_ui_to_fp(src, dst_ty); + + else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_integer_ty() && dst_signed) + return builder.create_fp_to_si(src, dst_ty); + + else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_integer_ty() && !dst_signed) + return builder.create_fp_to_ui(src, dst_ty); + + else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_floating_point_ty() && + src_scalar_ty->get_fp_mantissa_width() < dst_scalar_ty->get_fp_mantissa_width()) + return builder.create_fp_ext(src, dst_ty); + + else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_floating_point_ty() && + src_scalar_ty->get_fp_mantissa_width() > dst_scalar_ty->get_fp_mantissa_width()) + return builder.create_fp_trunc(src, dst_ty); + + else if(src_scalar_ty->is_integer_ty() && dst_scalar_ty->is_integer_ty() && + src_scalar_ty->get_integer_bitwidth()) + return builder.create_int_cast(src, dst_ty, dst_signed); + + else + throw std::runtime_error("unreachable"); +} + + +void node::implicit_cast(ir::builder &builder, ir::value *&lhs, ir::value *&rhs, + bool &is_float, bool &is_ptr, bool &is_int, bool &is_signed){ + // Input types + ir::type *left_ty = lhs->get_type()->get_scalar_ty(); + ir::type *right_ty = rhs->get_type()->get_scalar_ty(); + // One operand is pointer + if(left_ty->is_pointer_ty() || right_ty->is_pointer_ty()){ + if(left_ty->is_pointer_ty() && right_ty->is_pointer_ty()) + throw std::runtime_error("invalid operands"); + if(right_ty->is_pointer_ty()) + std::swap(lhs, rhs); + is_ptr = true; + } + // One operand is double + else if(left_ty->is_double_ty() || right_ty->is_double_ty()){ + ir::value *&to_convert = left_ty->is_double_ty()?rhs:lhs; + to_convert = explicit_cast(builder, to_convert, builder.get_double_ty()); + is_float = true; + } + // One operand is float + else if(left_ty->is_float_ty() || right_ty->is_float_ty()){ + ir::value *&to_convert = left_ty->is_float_ty()?rhs:lhs; + to_convert = explicit_cast(builder, to_convert, builder.get_float_ty()); + is_float = true; + } + // Both operands are integers + else if(left_ty->is_integer_ty() && right_ty->is_integer_ty()){ + is_int = true; + is_signed = true; // always signed for now + if(left_ty->get_integer_bitwidth() != right_ty->get_integer_bitwidth()){ + ir::value *&to_convert = (left_ty->get_integer_bitwidth() > right_ty->get_integer_bitwidth())?rhs:lhs; + ir::type *dst_ty = (to_convert==lhs)?right_ty:left_ty; + to_convert = explicit_cast(builder, to_convert, dst_ty); + } + } + // Not reachable + else + throw std::runtime_error("unreachable"); +} + +void node::implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs) { + ir::type *lhs_ty = lhs->get_type(); + ir::type *rhs_ty = rhs->get_type(); + ir::type *res_ty = nullptr; + if(!lhs_ty->is_tile_ty() && !rhs_ty->is_tile_ty()) + return; + else if(lhs_ty->is_tile_ty() && !rhs_ty->is_tile_ty()) + res_ty = lhs_ty; + else if(!lhs_ty->is_tile_ty() && rhs_ty->is_tile_ty()) + res_ty = rhs_ty; + else{ + auto lhs_shapes = lhs_ty->get_tile_shapes(); + auto rhs_shapes = rhs_ty->get_tile_shapes(); + size_t lhs_size = lhs_shapes.size(); + size_t rhs_size = rhs_shapes.size(); + size_t res_size = std::max(lhs_size, rhs_size); + ir::type::tile_shapes_t res_shapes(res_size); + ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(mod->get_context()); + for(int i = 0; i < res_size; i++){ + if(i >= res_size - lhs_size && i >= res_size - rhs_size) + res_shapes[i] = lhs_shapes[i]==one?rhs_shapes[i]:lhs_shapes[i]; + else if(i >= res_size - lhs_size) + res_shapes[i] = lhs_shapes[i]; + else if(i >= res_size - rhs_size) + res_shapes[i] = rhs_shapes[i]; + } + res_ty = ir::tile_type::get(lhs_ty->get_scalar_ty(), res_shapes); + } + implicit_broadcast(mod, res_ty, rhs); + implicit_broadcast(mod, res_ty, lhs); +} + +void node::implicit_broadcast(ir::module *mod, ir::type *ty, ir::value *&src){ + ir::builder &builder = mod->get_builder(); + ir::type *src_ty = src->get_type(); + ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(mod->get_context()); + // Both are scalar + if(!ty->is_tile_ty() && !src_ty->is_tile_ty()) + return; + // Broadcast scalar + if(ty->is_tile_ty() && !src_ty->is_tile_ty()){ + src = builder.create_splat(src, ty->get_tile_shapes()); + return; + } + // Downcast tile + if(!ty->is_tile_ty() && src_ty->is_tile_ty()){ + for(ir::constant *shape: src_ty->get_tile_shapes()) + if(shape != one) + throw std::runtime_error("cannot downcast"); + src = builder.create_downcast(src); + return; + } + // Both are arrays + auto dst_shapes = ty->get_tile_shapes(); + auto src_shapes = src_ty->get_tile_shapes(); + int dst_dim = dst_shapes.size(); + int src_dim = src_shapes.size(); + // Pad + int off = dst_dim - src_dim; + for(size_t i = 0; i < off; i++) + src_shapes.insert(src_shapes.begin(), one); + if(off > 0) + src = builder.create_reshape(src, src_shapes); + // Broadcast + for(int i = dst_dim - 1; i>= 0; i--) + if(dst_shapes[i] != src_shapes[i] && dst_shapes[i] != one && src_shapes[i] != one) + throw std::runtime_error("cannot broadcast"); + if(dst_shapes != src_shapes) + src = builder.create_broadcast(src, dst_shapes); +} + +} + +} diff --git a/lib/ast/statement.cpp b/lib/ast/statement.cpp new file mode 100644 index 000000000..265dcbb19 --- /dev/null +++ b/lib/ast/statement.cpp @@ -0,0 +1,160 @@ +#include "triton/ast/expression.h" +#include "triton/ast/statement.h" +#include "triton/ast/declaration.h" +#include "triton/ir/constant.h" +#include "triton/ir/module.h" +#include "triton/ir/basic_block.h" +#include "triton/ir/builder.h" +#include "triton/ir/type.h" + +namespace triton{ + +namespace ast{ + +/* Helpers */ +inline bool is_terminator(ir::value* x) { + return x && dynamic_cast(x); +} + + +/* Statements */ +ir::value* compound_statement::codegen(ir::module* mod) const{ + mod->add_new_scope(); + if(items_) + items_->codegen(mod); + mod->pop_scope(); + return nullptr; +} + +/* Expression statement */ +ir::value* expression_statement::codegen(ir::module *mod) const{ + ir::builder &builder = mod->get_builder(); + ir::basic_block *block = builder.get_insert_block(); + if(pred_) { + // check that it is an assignment + assignment_expression *assignment = dynamic_cast(expr_); + assert(assignment); + // generate mask + ir::value *pred = pred_->codegen(mod); + ir::mask_inst *mask = (ir::mask_inst*)builder.create_mask(pred); + // generate expression + unsigned szbegin = block->get_inst_list().size(); + ir::value *expr = expr_->codegen(mod); + ir::basic_block::iterator begin = block->begin(); + std::advance(begin, szbegin); + // set mask + ir::type *ty = expr->get_type(); + for(auto it = begin; it != builder.get_insert_point(); it++) + (*it)->set_mask_pred(mask->get_result(0)); +// if(auto *itn = dynamic_cast(expr)) +// itn->set_mask_pred(mask->get_result(0)); + if(ty->is_void_ty()) + return expr; + // merge with psi + ir::psi_inst *psi = (ir::psi_inst*)builder.create_merge(mask->get_result(0), expr, + mask->get_result(1), ir::undef_value::get(ty)); + std::string name = ((named_expression*)assignment->lvalue())->id()->name(); + mod->set_value(name, psi); + return psi; + } + return expr_->codegen(mod); +} + +/* For statement */ +ir::value* iteration_statement::codegen(ir::module *mod) const{ + ir::builder &builder = mod->get_builder(); + ir::context &ctx = mod->get_context(); + ir::basic_block *current_bb = builder.get_insert_block(); + ir::function *fn = current_bb->get_parent(); + ir::basic_block *loop_bb = ir::basic_block::create(ctx, "loop", fn); + ir::basic_block *next_bb = ir::basic_block::create(ctx, "postloop", fn); + mod->set_continue_fn([&](){ + if(exec_) + exec_->codegen(mod); + ir::value *cond = explicit_cast(builder, stop_->codegen(mod), ir::type::get_int1_ty(ctx)); + return builder.create_cond_br(cond, loop_bb, next_bb); + }); + init_->codegen(mod); + ir::value *cond = explicit_cast(builder, stop_->codegen(mod), ir::type::get_int1_ty(ctx)); + builder.create_cond_br(cond, loop_bb, next_bb); +// builder.create_br(loop_bb); + builder.set_insert_point(loop_bb); + if(!is_terminator(statements_->codegen(mod))) + mod->get_continue_fn()(); + ir::basic_block *stop_bb = builder.get_insert_block(); + mod->seal_block(stop_bb); + mod->seal_block(loop_bb); + mod->seal_block(builder.get_insert_block()); + mod->seal_block(next_bb); + builder.set_insert_point(next_bb); + return nullptr; +} + +/* While statement */ +ir::value* while_statement::codegen(ir::module* mod) const{ + ir::builder &builder = mod->get_builder(); + ir::context &ctx = mod->get_context(); + ir::basic_block *current_bb = builder.get_insert_block(); + ir::function *fn = current_bb->get_parent(); + ir::basic_block *loop_bb = ir::basic_block::create(ctx, "loop", fn); + ir::basic_block *next_bb = ir::basic_block::create(ctx, "postloop", fn); + mod->set_continue_fn([&](){ + ir::value *cond = explicit_cast(builder, cond_->codegen(mod), ir::type::get_int1_ty(ctx)); + return builder.create_cond_br(cond, loop_bb, next_bb); + }); + ir::value *cond = explicit_cast(builder, cond_->codegen(mod), ir::type::get_int1_ty(ctx)); + builder.create_cond_br(cond, loop_bb, next_bb); + builder.set_insert_point(loop_bb); + if(!is_terminator(statements_->codegen(mod))) + mod->get_continue_fn()(); + ir::basic_block *stop_bb = builder.get_insert_block(); + mod->seal_block(stop_bb); + mod->seal_block(loop_bb); + mod->seal_block(builder.get_insert_block()); + mod->seal_block(next_bb); + builder.set_insert_point(next_bb); + return nullptr; +} + +/* Selection statement */ +ir::value* selection_statement::codegen(ir::module* mod) const{ + ir::builder &builder = mod->get_builder(); + ir::context &ctx = mod->get_context(); + ir::function *fn = builder.get_insert_block()->get_parent(); + ir::value *cond = cond_->codegen(mod); + ir::basic_block *then_bb = ir::basic_block::create(ctx, "then", fn); + ir::basic_block *else_bb = else_value_?ir::basic_block::create(ctx, "else", fn):nullptr; + ir::basic_block *endif_bb = ir::basic_block::create(ctx, "endif", fn); + mod->seal_block(then_bb); + if(else_value_) + mod->seal_block(else_bb); + + // Branch + if(else_value_) + builder.create_cond_br(cond, then_bb, else_bb); + else + builder.create_cond_br(cond, then_bb, endif_bb); + // Then + builder.set_insert_point(then_bb); + if(!is_terminator(then_value_->codegen(mod))) + builder.create_br(endif_bb); + // Else + if(else_value_){ + builder.set_insert_point(else_bb); + if(!is_terminator(else_value_->codegen(mod))) + builder.create_br(endif_bb); + } + // Endif + mod->seal_block(endif_bb); + builder.set_insert_point(endif_bb); + return nullptr; +} + +/* Continue statement */ +ir::value* continue_statement::codegen(ir::module *mod) const{ + return mod->get_continue_fn()(); +} + +} + +}