diff --git a/TODO b/TODO index 340218d7a..72f439161 100644 --- a/TODO +++ b/TODO @@ -2,3 +2,4 @@ - proper naming scheme - symbols table - name conflicts on globals? + - separate header for typedef (e.g., type::tile_shapes_t) to reduce compilation time diff --git a/examples/matrix.cpp b/examples/matrix.cpp index d54f19e38..86ec4f249 100644 --- a/examples/matrix.cpp +++ b/examples/matrix.cpp @@ -6,6 +6,7 @@ #include "ir/context.h" #include "ir/module.h" #include "ir/print.h" +#include "ir/context_impl.h" #include "codegen/selection.h" #include "codegen/tune.h" #include "codegen/shared_copy.h" @@ -182,6 +183,9 @@ int main() { llvm::LLVMContext llvm_context; llvm::Module llvm_module("test", llvm_context); +// context.p_impl->mp_constants_[0]->set_value(16); +// context.p_impl->mp_constants_[1]->set_value(16); +// context.p_impl->mp_constants_[2]->set_value(8); // create passes tdl::codegen::buffer_info_pass buffer_info; diff --git a/include/ast/ast.h b/include/ast/ast.h index 529c4b01b..a24d0480f 100644 --- a/include/ast/ast.h +++ b/include/ast/ast.h @@ -5,7 +5,7 @@ #include #include #include - +#include namespace tdl{ @@ -56,6 +56,12 @@ enum TYPE_T{ FLOAT32_T, FLOAT64_T }; +enum STORAGE_SPEC_T{ + TUNABLE_T, + KERNEL_T, + READONLY_T, WRITEONLY_T, +}; + class pointer; class identifier; class constant; @@ -75,7 +81,7 @@ public: template class list: public node { public: - list(const T& x): values_{x} {} + list(const T& x): values_(1, x) {} node* append(const T& x){ values_.push_back(x); @@ -389,16 +395,30 @@ public: class no_op: public statement { }; // Types - class declaration_specifier: public node{ public: - declaration_specifier(TYPE_T spec) - : spec_(spec) { } + using node::node; + virtual ir::type* type(ir::module *mod) const = 0; +}; +class typed_declaration_specifier: public declaration_specifier { +public: + typed_declaration_specifier(TYPE_T ty): ty_(ty){ } ir::type* type(ir::module *mod) const; private: - const TYPE_T spec_; + const TYPE_T ty_; +}; + +class 
storage_declaration_specifier: public declaration_specifier { +public: + storage_declaration_specifier(STORAGE_SPEC_T storage_spec, node *decl_spec) + : storage_spec_(storage_spec), decl_spec_((declaration_specifier*)decl_spec) {} + ir::type* type(ir::module *mod) const; + +private: + const STORAGE_SPEC_T storage_spec_; + const declaration_specifier* decl_spec_; }; class declarator; @@ -495,7 +515,7 @@ public: : declarator((node*)((declarator*)decl)->id()), decl_((declarator*)decl), expr_((expression*)init){ } - void specifier(const declaration_specifier *spec); + void set_specifier(const declaration_specifier *spec); ir::value* codegen(ir::module *) const; public: @@ -535,17 +555,17 @@ public: class translation_unit: public node{ public: translation_unit(node *item) - : decls_((list*)item) { } + : decls_(item) { } translation_unit *add(node *item) { - decls_->append(item); + decls_.append(item); return this; } ir::value* codegen(ir::module * mod) const; private: - list* decls_; + list decls_; }; } diff --git a/include/ast/parser.y b/include/ast/parser.y index 43c530e12..826204f8a 100644 --- a/include/ast/parser.y +++ b/include/ast/parser.y @@ -20,12 +20,14 @@ struct token: public node{ token(BIN_OP_T value): bin_op(value){ } token(UNARY_OP_T value): unary_op(value){ } token(TYPE_T value): type(value){ } + token(STORAGE_SPEC_T value): storage_spec(value){ } union { ASSIGN_OP_T assign_op; BIN_OP_T bin_op; UNARY_OP_T unary_op; TYPE_T type; + STORAGE_SPEC_T storage_spec; }; }; @@ -39,10 +41,12 @@ node* append_ptr_list(node *result, node *in){ ASSIGN_OP_T get_assign_op(node *op) { return ((token*)op)->assign_op; } UNARY_OP_T get_unary_op(node *op) { return ((token*)op)->unary_op; } TYPE_T get_type_spec(node *op) { return ((token*)op)->type; } +STORAGE_SPEC_T get_storage_spec(node *op) { return ((token*)op)->storage_spec;} %} %token IDENTIFIER CONSTANT STRING_LITERAL +%token TUNABLE KERNEL READONLY WRITEONLY %token PTR_OP INC_OP DEC_OP LEFT_OP RIGHT_OP LE_OP GE_OP EQ_OP 
NE_OP %token AND_OP OR_OP MUL_ASSIGN DIV_ASSIGN MOD_ASSIGN ADD_ASSIGN %token SUB_ASSIGN LEFT_ASSIGN RIGHT_ASSIGN AND_ASSIGN @@ -87,17 +91,12 @@ abstract_declarator ; direct_abstract_declarator - : '[' constant_list ']' { $$ = new tile(nullptr, $1); } + : '[' primary_expression_list ']' { $$ = new tile(nullptr, $1); } constant : CONSTANT { $$ = new constant(atoi(yytext)); } ; -constant_list - : constant { $$ = new list((constant*)$1); } - | constant_list ',' constant { $$ = append_ptr_list($1, $3); } - ; - type_name : declaration_specifiers { $$ = new type_name($1, nullptr); } | declaration_specifiers abstract_declarator { $$ = new type_name($1, $2); } @@ -112,7 +111,7 @@ identifier ; builtin - : GET_GLOBAL_RANGE '[' constant ']' '(' constant ')' { $$ = new get_global_range($3, $6); } + : GET_GLOBAL_RANGE '[' primary_expression ']' '(' constant ')' { $$ = new get_global_range($3, $6); } | DOT '(' expression ',' expression ',' expression ')' { $$ = new matmul_expression($3, $5, $7); } primary_expression @@ -124,6 +123,11 @@ primary_expression | '(' expression ')' { $$ = $2; } ; +primary_expression_list + : primary_expression { $$ = new list((expression*)$1); } + | primary_expression_list ',' primary_expression { $$ = append_ptr_list($1, $3); } + ; + slice : ':' { $$ = new slice(tdl::ast::ALL); } | NEWAXIS { $$ = new slice(tdl::ast::NEWAXIS); } @@ -312,7 +316,7 @@ jump_statement direct_declarator : identifier { $$ = $1; } - | identifier '[' constant_list ']' { $$ = new tile($1, $3); } + | identifier '[' primary_expression_list ']' { $$ = new tile($1, $3); } | identifier '(' parameter_list ')' { $$ = new function($1, $3); } | identifier '(' ')' { $$ = new function($1, nullptr); } ; @@ -330,7 +334,8 @@ parameter_declaration declaration_specifiers - : type_specifier { $$ = new declaration_specifier(get_type_spec($1)); } + : type_specifier { $$ = new typed_declaration_specifier(get_type_spec($1)); } + | storage_class_specifier declaration_specifiers { $$ = new 
storage_declaration_specifier(get_storage_spec($1), $2); } ; init_declarator_list @@ -354,6 +359,13 @@ init_declarator | declarator '=' initialization_expression { $$ = new initializer($1, $3); } ; +storage_class_specifier + : TUNABLE { $$ = new token(TUNABLE_T); } + | KERNEL { $$ = new token(KERNEL_T); } + | READONLY { $$ = new token(READONLY_T); } + | WRITEONLY { $$ = new token(WRITEONLY_T); } +; + /* -------------------------- */ /* Translation Unit */ /* -------------------------- */ diff --git a/include/ast/scanner.l b/include/ast/scanner.l index 80da95dad..885404ca3 100644 --- a/include/ast/scanner.l +++ b/include/ast/scanner.l @@ -16,6 +16,10 @@ int comment(); %} %% +"tunable" { count(); return(TUNABLE); } +"kernel" { count(); return(KERNEL); } +"readonly" { count(); return(READONLY); } +"writeonly" { count(); return(WRITEONLY); } "@" { count(); return(AT); } "newaxis" { count(); return(NEWAXIS); } "if" { count(); return(IF); } diff --git a/include/codegen/barriers.h b/include/codegen/barriers.h index 9b476ae75..5199f94ad 100644 --- a/include/codegen/barriers.h +++ b/include/codegen/barriers.h @@ -32,7 +32,7 @@ private: void add_reference(ir::value *v, interval_vec_t &res); void get_read_intervals(ir::instruction *i, interval_vec_t &res); void get_written_intervals(ir::instruction *i, interval_vec_t &res); - void add(ir::basic_block *block, interval_vec_t &not_synced, std::set &insert_pts); + void add(ir::basic_block *block, interval_vec_t &not_synced, ir::builder &builder); public: barriers(allocation *alloc, buffer_info_pass *buffer_info): alloc_(alloc), buffer_info_(buffer_info) {} diff --git a/include/codegen/selection.h b/include/codegen/selection.h index 2531dc74c..ec733ed57 100644 --- a/include/codegen/selection.h +++ b/include/codegen/selection.h @@ -100,6 +100,7 @@ class selection{ private: // utils llvm::Type *make_vector_ty(llvm::Type *ty, size_t vector_size); + std::vector extract_shapes(ir::value *v); // LLVM conversions llvm::Type* llvm_type(ir::type 
*ty, llvm::LLVMContext &ctx); diff --git a/include/ir/builder.h b/include/ir/builder.h index d7e49cf14..a6c0013fd 100644 --- a/include/ir/builder.h +++ b/include/ir/builder.h @@ -6,6 +6,7 @@ #include #include "instructions.h" #include "basic_block.h" +#include "type.h" namespace tdl{ namespace ir{ @@ -110,11 +111,11 @@ public: value *create_load(value *arg, const std::string &name = ""); value *create_store(value *ptr, value *val, const std::string &name = ""); // Tile instruction - value *create_splat(value *arg, const std::vector &shapes, const std::string &name = ""); - value *create_reshape(value *arg, const std::vector &shapes, const std::string &name = ""); - value *create_broadcast(value *arg, const std::vector &shapes, const std::string &name = ""); + value *create_splat(value *arg, const type::tile_shapes_t &shapes, const std::string &name = ""); + value *create_reshape(value *arg, const type::tile_shapes_t &shapes, const std::string &name = ""); + value *create_broadcast(value *arg, const type::tile_shapes_t &shapes, const std::string &name = ""); // Built-in instruction - value *create_get_global_range(unsigned axis, unsigned size, const std::string &name = ""); + value *create_get_global_range(unsigned axis, type::tile_shapes_t::value_type size, const std::string &name = ""); value *create_matmul(value *A, value *B, value *C, const std::string &name = ""); // Intrinsics value *create_copy_to_shared(value *arg, const std::string &name = ""); diff --git a/include/ir/constant.h b/include/ir/constant.h index 78814283c..132902d3a 100644 --- a/include/ir/constant.h +++ b/include/ir/constant.h @@ -2,6 +2,7 @@ #define TDL_INCLUDE_IR_CONSTANT_H #include "value.h" +#include namespace tdl{ namespace ir{ @@ -28,28 +29,43 @@ public: static undef_value* get(type* ty); }; + /* Constant int */ class constant_int: public constant{ +protected: constant_int(type *ty, uint64_t value); public: uint64_t get_value() const { return value_; } - static constant *get(type *ty, 
uint64_t value); + static constant_int *get(type *ty, uint64_t value); + +protected: + uint64_t value_; +}; + +/* Metaparameter int */ +class metaparameter: public constant_int{ + metaparameter(type *ty, unsigned lo, unsigned hi); + +public: + static metaparameter *create(context &ctx, type *ty, unsigned lo, unsigned hi); + void set_value(uint64_t value) { value_ = value; } private: - uint64_t value_; + unsigned lo_; + unsigned hi_; }; /* constant range */ class constant_range: public constant{ - constant_range(type *ty, uint64_t first, uint64_t last); + constant_range(type *ty, constant_int* first, constant_int* last); public: - static constant *get(constant *first, constant *last); + static constant *get(constant_int *first, constant_int *last); private: - uint64_t first_; - uint64_t last_; + constant_int* first_; + constant_int* last_; }; /* constant fp */ diff --git a/include/ir/context_impl.h b/include/ir/context_impl.h index cb3acc186..b9017b39c 100644 --- a/include/ir/context_impl.h +++ b/include/ir/context_impl.h @@ -12,6 +12,7 @@ class context; class constant_int; class constant_fp; class undef_value; +class metaparameter; /* Context impl */ class context_impl { @@ -26,13 +27,15 @@ public: integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty; // Pointer types std::map, pointer_type*> ptr_tys; - std::map>, tile_type*> tile_tys; + std::map, tile_type*> tile_tys; // Int constants - std::map int_constants_; + std::map, constant_int*> int_constants_; // Float constants std::map fp_constants_; // undef values std::map uv_constants_; + // Metaparameters + std::vector mp_constants_; }; } diff --git a/include/ir/instructions.h b/include/ir/instructions.h index 047126cf2..ae752f78e 100644 --- a/include/ir/instructions.h +++ b/include/ir/instructions.h @@ -3,6 +3,7 @@ #include #include "value.h" +#include "ir/type.h" #include "llvm/IR/Instructions.h" namespace tdl{ @@ -358,7 +359,7 @@ public: class retile_inst: public unary_inst { protected: - 
retile_inst(value *arg, const std::vector &shape_suffix, const std::string &name, instruction *next); + retile_inst(value *arg, const type::tile_shapes_t &shapes, const std::string &name, instruction *next); static std::string shape_suffix(ir::type* ty); }; @@ -370,7 +371,7 @@ private: std::string repr_impl() const { return "reshape" + shape_suffix(get_type()); } public: - static instruction* create(value *arg, const std::vector &shape_suffix, + static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix, const std::string &name = "", instruction *next = nullptr); }; @@ -382,7 +383,7 @@ private: std::string repr_impl() const { return "splat" + shape_suffix(get_type()); } public: - static instruction* create(value *arg, const std::vector &shape_suffix, + static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix, const std::string &name = "", instruction *next = nullptr); }; @@ -394,7 +395,7 @@ private: std::string repr_impl() const { return "broadcast" + shape_suffix(get_type()); } public: - static instruction* create(value *arg, const std::vector &shape_suffix, + static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix, const std::string &name = "", instruction *next = nullptr); }; @@ -414,7 +415,7 @@ private: std::string repr_impl() const { return "get_global_range(" + std::to_string(axis_) + ")"; } public: - static instruction* create(context &ctx, unsigned axis, unsigned size, + static instruction* create(context &ctx, unsigned axis, type::tile_shapes_t::value_type size, const std::string &name = "", instruction *next = nullptr); unsigned get_axis() const { return axis_; } diff --git a/include/ir/module.h b/include/ir/module.h index 347178fda..26b6c6769 100644 --- a/include/ir/module.h +++ b/include/ir/module.h @@ -3,6 +3,7 @@ #include #include +#include #include #include #include "builder.h" @@ -12,6 +13,7 @@ namespace tdl{ namespace ast{ class iteration_statement; +class compound_statement; } @@ 
-69,6 +71,10 @@ public: const functions_list_t &get_function_list() const { return functions_; } functions_list_t &get_function_list() { return functions_; } function *get_or_insert_function(const std::string &name, function_type *ty); + // Scope + void push_scope(const ast::compound_statement* scope) { scopes_.push(scope); } + void pop_scope() { scopes_.pop(); } + const ast::compound_statement* get_scope() { return scopes_.top(); } private: @@ -83,6 +89,7 @@ private: symbols_map_t symbols_; std::function continue_fn_; std::map current_phi_; + std::stack scopes_; }; } diff --git a/include/ir/type.h b/include/ir/type.h index 9f29b465b..6e2049ddd 100644 --- a/include/ir/type.h +++ b/include/ir/type.h @@ -3,6 +3,7 @@ #include #include +#include namespace tdl{ namespace ir{ @@ -10,9 +11,13 @@ namespace ir{ class context; class value; class integer_type; +class constant_int; /* Type */ class type { +public: + typedef std::vector tile_shapes_t; + protected: typedef std::vector contained_tys_vec_t; typedef contained_tys_vec_t::iterator ty_iterator; @@ -54,7 +59,7 @@ public: unsigned get_tile_bitwidth() const; unsigned get_primitive_size_in_bits() const; type *get_scalar_ty() const; - const std::vector &get_tile_shapes() const; + const tile_shapes_t& get_tile_shapes() const; unsigned get_tile_num_elements() const; type *get_tile_element_ty() const; unsigned get_pointer_address_space() const; @@ -94,9 +99,25 @@ public: static integer_type *get_int64_ty(context &ctx); static integer_type *get_int128_ty(context &ctx); + // Attributes + type* set_tunable() { is_tunable_ = true; return this; } + type* set_readonly() { is_readonly_ = true; return this; } + type* set_writeonly() { is_writeonly_ = true; return this; } + type* set_kernel() { is_kernel_ = true; return this; } + + bool get_tunable() { return is_tunable_; } + bool get_readonly() { return is_readonly_; } + bool get_writeonly() { return is_writeonly_; } + bool get_kernel() { return is_kernel_; } + private: context 
&ctx_; id_t id_; + // attributes + bool is_tunable_; + bool is_readonly_; + bool is_writeonly_; + bool is_kernel_; protected: contained_tys_vec_t contained_tys_; @@ -132,21 +153,24 @@ public: class tile_type: public composite_type { private: - tile_type(type *ty, const std::vector &shapes); + tile_type(type *ty, const tile_shapes_t &shapes); static bool is_valid_elt_ty(type *ty); public: // accessors - const std::vector& get_shapes() const { return shapes_; } + const tile_shapes_t& get_shapes() const { return shapes_; } unsigned get_num_elements() const; unsigned get_bitwidth() const; // factory methods - static tile_type* get(type *ty, const std::vector &shapes); + static tile_type* get(type *ty, const tile_shapes_t &shapes); static tile_type* get_same_shapes(type *ty, type *ref); + // shortcut to get a 1 element in the shape + static tile_shapes_t::value_type make_one(context &ctx); + private: - std::vector shapes_; + tile_shapes_t shapes_; }; class pointer_type: public type { diff --git a/lib/ast/lowering.cpp b/lib/ast/lowering.cpp index dee27d139..993856b2d 100644 --- a/lib/ast/lowering.cpp +++ b/lib/ast/lowering.cpp @@ -100,6 +100,7 @@ void node::implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs) ir::builder &builder = mod->get_builder(); ir::type *lhs_ty = lhs->get_type(); ir::type *rhs_ty = rhs->get_type(); + ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(mod->get_context()); // Both are scalar if(!lhs_ty->is_tile_ty() && !rhs_ty->is_tile_ty()) return; @@ -111,30 +112,30 @@ void node::implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs) return; } // Both are arrays - std::vector lhs_shapes = lhs->get_type()->get_tile_shapes(); - std::vector rhs_shapes = rhs->get_type()->get_tile_shapes(); + auto lhs_shapes = lhs->get_type()->get_tile_shapes(); + auto rhs_shapes = rhs->get_type()->get_tile_shapes(); if(lhs_shapes == rhs_shapes) return; int lhs_dim = lhs_shapes.size(); int rhs_dim = rhs_shapes.size(); - 
std::vector &shortest = (lhs_dim < rhs_dim)?lhs_shapes:rhs_shapes; - std::vector &longest = (lhs_dim < rhs_dim)?rhs_shapes:lhs_shapes; + auto &shortest = (lhs_dim < rhs_dim)?lhs_shapes:rhs_shapes; + auto &longest = (lhs_dim < rhs_dim)?rhs_shapes:lhs_shapes; size_t ndim = longest.size(); int off = longest.size() - shortest.size(); for(int i = longest.size() - 1; i>= 0; i--){ - if(shortest[off + i] != longest[i] && shortest[off + i] != 1 && longest[i] != 1) + if(shortest[off + i] != longest[i] && shortest[off + i] != one && longest[i] != one) throw std::runtime_error("cannot broadcast"); } // Pad for(size_t i = 0; i < off; i++) - shortest.insert(shortest.begin(), 1); + shortest.insert(shortest.begin(), one); ir::value *&target = (lhs_dim < rhs_dim)?lhs:rhs; if(off > 0) target = builder.create_reshape(target, shortest); // Broadcast - std::vector shapes(ndim); + ir::type::tile_shapes_t shapes(ndim); for(size_t i = 0; i < ndim; i++) - shapes[i] = std::max(shortest[i], longest[i]); + shapes[i] = shortest[i]==one?longest[i]:shortest[i]; if(shapes != lhs_shapes) lhs = builder.create_broadcast(lhs, shapes); if(shapes != rhs_shapes) @@ -148,14 +149,15 @@ inline bool is_terminator(ir::value* x) { /* Translation unit */ ir::value* translation_unit::codegen(ir::module *mod) const{ - decls_->codegen(mod); + mod->push_scope(nullptr); + decls_.codegen(mod); return nullptr; } /* Declaration specifier */ -ir::type* declaration_specifier::type(ir::module *mod) const { +ir::type* typed_declaration_specifier::type(ir::module *mod) const { ir::context &ctx = mod->get_context(); - switch (spec_) { + switch (ty_) { case VOID_T: return ir::type::get_void_ty(ctx); case INT1_T: return ir::type::get_int1_ty(ctx); case INT8_T: return ir::type::get_int8_ty(ctx); @@ -164,7 +166,18 @@ ir::type* declaration_specifier::type(ir::module *mod) const { case INT64_T: return ir::type::get_int64_ty(ctx); case FLOAT32_T: return ir::type::get_float_ty(ctx); case FLOAT64_T: return 
ir::type::get_double_ty(ctx); - default: throw std::runtime_error("unreachable"); + default: throw std::runtime_error("unreachable"); + } +} + +ir::type* storage_declaration_specifier::type(ir::module *mod) const { + ir::type* result = decl_spec_->type(mod); + switch(storage_spec_){ + case TUNABLE_T: return result->set_tunable(); + case KERNEL_T: return result->set_kernel(); + case READONLY_T: return result->set_readonly(); + case WRITEONLY_T: return result->set_writeonly(); + default: throw std::runtime_error("unreachable"); } } @@ -194,10 +207,10 @@ const std::string &identifier::name() const{ } // Tile -ir::type* tile::type_impl(ir::module*, ir::type *type) const{ - std::vector shapes; +ir::type* tile::type_impl(ir::module *mod, ir::type *type) const{ + ir::type::tile_shapes_t shapes; for(constant *cst: shapes_->values()) - shapes.push_back(cst->value()); + shapes.push_back((ir::constant_int*)cst->codegen(mod)); return ir::tile_type::get(type, shapes); } @@ -245,6 +258,7 @@ ir::value* function_definition::codegen(ir::module *mod) const{ /* Statements */ ir::value* compound_statement::codegen(ir::module* mod) const{ + mod->push_scope(this); if(decls_) decls_->codegen(mod); if(statements_){ @@ -254,6 +268,7 @@ ir::value* compound_statement::codegen(ir::module* mod) const{ return current; } } + mod->pop_scope(); return nullptr; } @@ -337,7 +352,7 @@ ir::value* continue_statement::codegen(ir::module *mod) const{ /* Declaration */ ir::value* declaration::codegen(ir::module* mod) const{ for(initializer *init: init_->values()) - init->specifier(spec_); + init->set_specifier(spec_); init_->codegen(mod); return nullptr; } @@ -347,7 +362,7 @@ ir::type* initializer::type_impl(ir::module *mod, ir::type *type) const{ return decl_->type(mod, type); } -void initializer::specifier(const declaration_specifier *spec) { +void initializer::set_specifier(const declaration_specifier *spec) { spec_ = spec; } @@ -355,6 +370,11 @@ ir::value* initializer::codegen(ir::module * mod) const{ 
ir::type *ty = decl_->type(mod, spec_->type(mod)); std::string name = decl_->id()->name(); ir::value *value = ir::undef_value::get(ty); + if(ty->get_tunable()){ + assert(expr_ == nullptr); + //TODO + value = ir::metaparameter::create(mod->get_context(), ty, 4, 8); + } if(expr_){ value = expr_->codegen(mod); value = explicit_cast(mod->get_builder(), value, ty); @@ -464,7 +484,7 @@ ir::value* binary_operator::codegen(ir::module *mod) const{ // get_global_range ir::value* get_global_range::codegen(ir::module *mod) const { ir::builder &builder = mod->get_builder(); - return builder.create_get_global_range(axis_->value(), size_->value()); + return builder.create_get_global_range(axis_->value(), (ir::constant_int*)size_->codegen(mod)); } @@ -487,11 +507,13 @@ ir::value* matmul_expression::codegen(ir::module *mod) const { ir::value* indexing_expression::codegen(ir::module *mod) const{ ir::value *in = mod->get_value(id_->name()); const std::vector &slices = slices_->values(); - std::vector in_shapes = in->get_type()->get_tile_shapes(); - std::vector out_shapes(slices.size()); + auto in_shapes = in->get_type()->get_tile_shapes(); + ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(mod->get_context()); + ir::type::tile_shapes_t out_shapes(slices.size()); + // create shapes size_t current = 0; for(size_t i = 0; i < out_shapes.size(); i++) - out_shapes[i] = (slices[i]->type()==NEWAXIS)?1:in_shapes[current++]; + out_shapes[i] = (slices[i]->type()==NEWAXIS)?one:in_shapes[current++]; return mod->get_builder().create_reshape(in, out_shapes); } @@ -586,8 +608,8 @@ int constant::value() const{ /* Constant range */ ir::value* constant_range::codegen(ir::module *mod) const{ - return ir::constant_range::get((ir::constant*)first_->codegen(mod), - (ir::constant*)last_->codegen(mod)); + return ir::constant_range::get((ir::constant_int*)first_->codegen(mod), + (ir::constant_int*)last_->codegen(mod)); } /* Named */ diff --git a/lib/codegen/barriers.cpp 
b/lib/codegen/barriers.cpp index 0466d5ef3..df017931b 100644 --- a/lib/codegen/barriers.cpp +++ b/lib/codegen/barriers.cpp @@ -45,10 +45,15 @@ void barriers::get_written_intervals(ir::instruction *i, interval_vec_t &res){ void barriers::insert_barrier(ir::instruction *instr, ir::builder &builder) { if(auto *phi = dynamic_cast(instr)) { + std::set incoming; for(unsigned n = 0; n < phi->get_num_incoming(); n++){ - ir::basic_block *block = phi->get_incoming_block(n); - builder.set_insert_point(block->get_inst_list().back()); - builder.create_barrier(); + ir::instruction *inc_val = dynamic_cast(phi->get_incoming_value(n)); + assert(inc_val); + if(incoming.insert(inc_val).second){ + ir::basic_block *block = inc_val->get_parent(); + builder.set_insert_point(block->get_inst_list().back()); + builder.create_barrier(); + } } } else { @@ -57,15 +62,15 @@ void barriers::insert_barrier(ir::instruction *instr, ir::builder &builder) { } } -void barriers::add(ir::basic_block *block, interval_vec_t &not_synced, std::set &insert_pts) { - for(ir::instruction *i: block->get_inst_list()){ +void barriers::add(ir::basic_block *block, interval_vec_t &not_synced, ir::builder &builder) { + ir::basic_block::inst_list_t instructions = block->get_inst_list(); + for(ir::instruction *i: instructions){ interval_vec_t read, written; get_read_intervals(i, read); get_written_intervals(i, written); - if(intersect(not_synced, read) - || intersect(not_synced, written)) { + if(intersect(not_synced, read)) { not_synced.clear(); - insert_pts.insert(i); + insert_barrier(i, builder); } std::copy(written.begin(), written.end(), std::back_inserter(not_synced)); } @@ -76,12 +81,8 @@ void barriers::run(ir::module &mod) { for(ir::function *fn: mod.get_function_list()){ // find barrier location interval_vec_t not_synced; - std::set insert_pts; for(ir::basic_block *block: fn->blocks()) - add(block, not_synced, insert_pts); - // insert barrier - for(ir::instruction *i: insert_pts) - insert_barrier(i, builder); + 
add(block, not_synced, builder); } } diff --git a/lib/codegen/selection.cpp b/lib/codegen/selection.cpp index ed17b2fcb..8665714a2 100644 --- a/lib/codegen/selection.cpp +++ b/lib/codegen/selection.cpp @@ -44,6 +44,7 @@ llvm::Type *distributed_tile::make_vector_ty(llvm::Type *ty, size_t vector_size) return VectorType::get(ty, vector_size); } + distributed_tile::distributed_tile(Type *ty, const shapes_t &shapes, const axes_t &axes, llvm::IRBuilder<> &builder, bool vectorize) : tile(make_vector_ty(ty, vectorize?axes[0].contiguous:1), shapes), axes_(axes), builder_(builder) { vector_size_ = vectorize?ty_->getVectorNumElements():1; @@ -149,6 +150,16 @@ Value* shared_tile::get_value(indices_t idx) { return builder_.CreateLoad(ptr); } +/* Utils */ +std::vector selection::extract_shapes(ir::value *v) { + const auto& shapes = v->get_type()->get_tile_shapes(); + std::vector result(shapes.size()); + for(ir::constant_int* cst: shapes) + result.push_back(cst->get_value()); + return result; +} + + /* convert ir::type to Type */ Type *selection::llvm_type(ir::type *ty, LLVMContext &ctx) { // function @@ -299,11 +310,12 @@ std::vector delinearize(Value *trailing, std::vector &shapes, } void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) { - const auto& shapes = v->get_type()->get_tile_shapes(); + const auto& shapes = extract_shapes(v); size_t dim = shapes.size(); std::vector contiguous(dim); std::vector warp_size(dim); std::vector n_warps(dim); + std::cout << v->get_name() << " " << typeid(*v).name() << std::endl; for(unsigned i = 0; i < shapes.size(); i++){ std::string str_i = std::to_string(i); contiguous[i] = *params_->get_param(v, "p0.d" + str_i); @@ -336,7 +348,7 @@ void selection::create_grids(std::vector &grids, // get number of dimensions greater than 1 auto get_tile_gt1_dim = [&](ir::value *v){ unsigned result = 0; - for(unsigned shape: v->get_type()->get_tile_shapes()) { + for(unsigned shape: extract_shapes(v)) { result += 
(shape > 1)?shape:0; } return result; @@ -353,7 +365,7 @@ void selection::create_grids(std::vector &grids, for(ir::value *op: user->ops()) bind_references(op); // bind - const auto& shapes = v->get_type()->get_tile_shapes(); + const auto& shapes = extract_shapes(v); if(dynamic_cast(v) || buffer_info_->is_double(v)) return; for(size_t d = 0; d < shapes.size(); d++){ @@ -385,7 +397,7 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder, for(ir::value *op: user->ops()) create_tile(op, builder, references, seen, sh_mem_ptr); LLVMContext &ctx = builder.getContext(); - const auto& shapes = v->get_type()->get_tile_shapes(); + const auto& shapes = extract_shapes(v); Type* ty = llvm_type(v->get_type()->get_scalar_ty(), ctx); // create shared tile if(dynamic_cast(v) || (buffer_info_->is_double(v))){ @@ -429,7 +441,7 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder, } // create distributed tile else { - const auto &shapes = v->get_type()->get_tile_shapes(); + const auto &shapes = extract_shapes(v); std::vector axes(shapes.size()); for(size_t d = 0; d < shapes.size(); d++){ if(shapes[d] > 1){ @@ -530,7 +542,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> & distributed_tile* result = (distributed_tile*)ti; if(!ins->get_type()->is_tile_ty()) return; - const auto& shapes = ins->get_type()->get_tile_shapes(); + const auto& shapes = extract_shapes(ins); // global_range if(auto *x = dynamic_cast(ins)) { static std::array ctaid = { @@ -568,7 +580,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> & // broadcast else if(dynamic_cast(ins)) { ir::value* in = ins->get_operand(0); - const auto& in_shapes = in->get_type()->get_tile_shapes(); + const auto& in_shapes = extract_shapes(in); distributed_tile *in_tile = (distributed_tile*)tmap_.at(in); result->for_each([&](indices_t out_idx){ indices_t in_idx = out_idx; @@ -615,7 +627,7 @@ void selection::lower_tile_instruction(ir::instruction 
*ins, llvm::IRBuilder<> & Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {llvm_type(C->get_type()->get_scalar_ty(), ctx)}); result->for_each([&](indices_t idx){ Value *res = tmap_.at(C)->get_value(idx); - unsigned NK = A->get_type()->get_tile_shapes()[1]; + unsigned NK = extract_shapes(A)[1]; for(unsigned K = 0; K < NK; ++K){ indices_t a_idx = {idx[0], builder.getInt32(K)}; indices_t b_idx = {idx[1], builder.getInt32(K)}; diff --git a/lib/codegen/tune.cpp b/lib/codegen/tune.cpp index 5de551924..3dc5c4e87 100644 --- a/lib/codegen/tune.cpp +++ b/lib/codegen/tune.cpp @@ -3,6 +3,8 @@ #include "ir/type.h" #include "ir/module.h" #include "ir/function.h" +#include "ir/context_impl.h" + #include @@ -29,7 +31,8 @@ void tune::init_c_phi(ir::instruction *v) { void tune::init_c_graph(ir::instruction *v) { // Reference shape - std::vector shapes; + ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(v->get_parent()->get_context()); + ir::type::tile_shapes_t shapes; if(auto *store = dynamic_cast(v)) shapes = store->get_pointer_operand()->get_type()->get_tile_shapes(); else @@ -39,7 +42,7 @@ void tune::init_c_graph(ir::instruction *v) { ir::value *op = v->get_operand(0); unsigned current = 0; for(unsigned i = 0; i < shapes.size(); i ++){ - if(shapes[i] == 1) + if(shapes[i] == one) static_params_.insert({{v, i}, 1}); else add_constraint({v, i}, {op, current++}); @@ -99,6 +102,7 @@ void tune::connected_components(node_t x, const std::vector vals, st std::vector tune::get_params(ir::module &mod) { std::vector result; std::set seen; + for(ir::function *fn: mod.get_function_list()) for(ir::basic_block *block: fn->blocks()) for(ir::instruction *i : block->get_inst_list()) @@ -143,8 +147,9 @@ void tune::create_grids(std::vector &grids, // get number of dimensions greater than 1 auto get_tile_gt1_dim = [&](ir::value *v){ unsigned result = 0; - for(unsigned shape: v->get_type()->get_tile_shapes()) { - result += (shape > 1)?shape:0; + auto one = 
ir::tile_type::make_one(fn->get_fn_type()->get_context()); + for(ir::constant_int *shape: v->get_type()->get_tile_shapes()) { + result += (shape != one); } return result; }; @@ -194,8 +199,8 @@ for(ir::function *fn: mod.get_function_list()){ unsigned *s1 = params_[i]["p1.d" + strk]; unsigned *s2 = params_[i]["p2.d" + strk]; unsigned multiple = (*s0)*(*s1)*(*s2); - if(shapes[k] % multiple != 0) - errors[i].push_back("for dim " + strk + ": shape (" + to_string(shapes[k]) + ")" + if(shapes[k]->get_value() % multiple != 0) + errors[i].push_back("for dim " + strk + ": shape (" + to_string(shapes[k]->get_value()) + ")" " is not a multiple of layout (" + to_string(multiple) + ")"); } // the number of thread per warp must be 32 diff --git a/lib/ir/builder.cpp b/lib/ir/builder.cpp index cb5edd2b6..6622125b5 100644 --- a/lib/ir/builder.cpp +++ b/lib/ir/builder.cpp @@ -244,15 +244,15 @@ value *builder::create_store(value *ptr, value *val, const std::string &name){ // tile instructions //===----------------------------------------------------------------------===// -value *builder::create_reshape(value *arg, const std::vector &shapes, const std::string &name) { +value *builder::create_reshape(value *arg, const type::tile_shapes_t &shapes, const std::string &name) { return insert(reshape_inst::create(arg, shapes, name)); } -value *builder::create_splat(value *arg, const std::vector &shapes, const std::string &name) { +value *builder::create_splat(value *arg, const type::tile_shapes_t &shapes, const std::string &name) { return insert(splat_inst::create(arg, shapes, name)); } -value *builder::create_broadcast(value *arg, const std::vector &shapes, const std::string &name) { +value *builder::create_broadcast(value *arg, const type::tile_shapes_t &shapes, const std::string &name) { return insert(broadcast_inst::create(arg, shapes, name)); } @@ -260,7 +260,7 @@ value *builder::create_broadcast(value *arg, const std::vector &shapes // built-in instructions 
//===----------------------------------------------------------------------===// -value *builder::create_get_global_range(unsigned axis, unsigned size, const std::string &name) { +value *builder::create_get_global_range(unsigned axis, type::tile_shapes_t::value_type size, const std::string &name) { return insert(get_global_range_inst::create(ctx_, axis, size, name)); } diff --git a/lib/ir/constant.cpp b/lib/ir/constant.cpp index 58f3b1ab7..87f669e4d 100644 --- a/lib/ir/constant.cpp +++ b/lib/ir/constant.cpp @@ -48,24 +48,27 @@ constant *constant::get_all_ones_value(type *ty) { constant_int::constant_int(type *ty, uint64_t value) : constant(ty, 0), value_(value){ } -constant *constant_int::get(type *ty, uint64_t value) { - return new constant_int(ty, value); +constant_int *constant_int::get(type *ty, uint64_t value) { + context_impl *impl = ty->get_context().p_impl.get(); + constant_int *& cst = impl->int_constants_[std::make_pair(ty, value)]; + if(cst == nullptr) + cst = new constant_int(ty, value); + return cst; } // constant_range // FIXME use something like APInt -constant_range::constant_range(type *ty, uint64_t first, uint64_t last) +constant_range::constant_range(type *ty, constant_int *first, constant_int *last) : constant(ty, 0), first_(first), last_(last){ } -constant *constant_range::get(constant *first, constant *last) { +constant *constant_range::get(constant_int *first, constant_int *last) { assert(first->get_type()->is_integer_ty()); assert(first->get_type() == last->get_type()); unsigned vfirst = ((constant_int*)first)->get_value(); - unsigned vlast = ((constant_int*)last)->get_value(); - assert(vlast > vfirst); - type *ty = tile_type::get(first->get_type(), {vlast - vfirst}); - return new constant_range(ty, vfirst, vlast); + assert(vfirst == 0); + type *ty = tile_type::get(first->get_type(), {last}); + return new constant_range(ty, first, last); } @@ -94,6 +97,17 @@ constant *constant_fp::get(context &ctx, double v){ return result; } +// 
metaparameter +metaparameter::metaparameter(type *ty, unsigned lo, unsigned hi) + : constant_int(ty, 0), lo_(lo), hi_(hi){ } + +metaparameter* metaparameter::create(context &ctx, type *ty, unsigned lo, unsigned hi) { + context_impl *impl = ctx.p_impl.get(); + metaparameter *result = new metaparameter(ty, lo, hi); + impl->mp_constants_.push_back(result); + return result; +} + // undef value undef_value::undef_value(type *ty) : constant(ty, 0) { } diff --git a/lib/ir/instructions.cpp b/lib/ir/instructions.cpp index acf0c0329..38adcc377 100644 --- a/lib/ir/instructions.cpp +++ b/lib/ir/instructions.cpp @@ -409,7 +409,7 @@ std::string retile_inst::shape_suffix(ir::type* ty){ std::string res = "["; const auto& shapes = ty->get_tile_shapes(); for(unsigned i = 0; i < shapes.size(); i++){ - res += std::to_string(ty->get_tile_shapes()[i]); + res += std::to_string(ty->get_tile_shapes()[i]->get_value()); if(i < shapes.size() - 1) res += ", "; } @@ -417,13 +417,13 @@ std::string retile_inst::shape_suffix(ir::type* ty){ return res; } -retile_inst::retile_inst(value *arg, const std::vector &shapes, +retile_inst::retile_inst(value *arg, const type::tile_shapes_t &shapes, const std::string &name, instruction *next) : unary_inst(tile_type::get(arg->get_type()->get_scalar_ty(), shapes), arg, name, next) { } // reshape -instruction* reshape_inst::create(value *arg, const std::vector &shapes, +instruction* reshape_inst::create(value *arg, const type::tile_shapes_t &shapes, const std::string &name, instruction *next) { return new reshape_inst(arg, shapes, name, next); } @@ -431,14 +431,14 @@ instruction* reshape_inst::create(value *arg, const std::vector &shape // splat -instruction* splat_inst::create(value *arg, const std::vector &shapes, +instruction* splat_inst::create(value *arg, const type::tile_shapes_t &shapes, const std::string &name, instruction *next) { return new splat_inst(arg, shapes, name, next); } // broadcast -instruction* broadcast_inst::create(value *arg, const 
std::vector &shapes, +instruction* broadcast_inst::create(value *arg, const type::tile_shapes_t &shapes, const std::string &name, instruction *next) { return new broadcast_inst(arg, shapes, name, next); } @@ -470,7 +470,7 @@ get_global_range_inst::get_global_range_inst(type *ty, unsigned axis, } -instruction* get_global_range_inst::create(context &ctx, unsigned axis, unsigned size, +instruction* get_global_range_inst::create(context &ctx, unsigned axis, type::tile_shapes_t::value_type size, const std::string &name, instruction *next) { type *int_ty = type::get_int32_ty(ctx); type *tile_ty = tile_type::get(int_ty, {size}); diff --git a/lib/ir/type.cpp b/lib/ir/type.cpp index c790120fb..5aebd94a5 100644 --- a/lib/ir/type.cpp +++ b/lib/ir/type.cpp @@ -3,6 +3,7 @@ #include "ir/context.h" #include "ir/context_impl.h" #include "ir/value.h" +#include "ir/constant.h" namespace tdl{ namespace ir{ @@ -63,7 +64,7 @@ type * type::get_pointer_element_ty() const { } -const std::vector &type::get_tile_shapes() const { +const type::tile_shapes_t &type::get_tile_shapes() const { assert(is_tile_ty()); return ((tile_type*)this)->get_shapes(); } @@ -148,7 +149,7 @@ bool composite_type::index_valid(value *idx) const{ // tile_type class //===----------------------------------------------------------------------===// -tile_type::tile_type(type *ty, const std::vector &shapes) +tile_type::tile_type(type *ty, const tile_shapes_t &shapes) : composite_type(ty->get_context(), TileTyID), shapes_(shapes) { contained_tys_.push_back(ty); } @@ -159,8 +160,8 @@ bool tile_type::is_valid_elt_ty(type *ty) { unsigned tile_type::get_num_elements() const { unsigned res = 1; - for(unsigned shape: shapes_) - res *= shape; + for(auto shape: shapes_) + res *= shape->get_value(); return res; } @@ -168,7 +169,7 @@ unsigned tile_type::get_bitwidth() const { return get_num_elements() * get_tile_element_ty()->get_primitive_size_in_bits(); } -tile_type* tile_type::get(type *elt_ty, const std::vector &shapes) { 
+tile_type* tile_type::get(type *elt_ty, const tile_shapes_t &shapes) { assert(elt_ty && "Can't get a tile of type!"); assert(shapes.size() && "Can't create a tile with empty shapes!"); assert(is_valid_elt_ty(elt_ty) && "Invalid type for pointer element!"); @@ -185,6 +186,10 @@ tile_type* tile_type::get_same_shapes(type *ty, type *ref){ return get(ty, ref->get_tile_shapes()); } +type::tile_shapes_t::value_type tile_type::make_one(ir::context& ctx){ + return constant_int::get(type::get_int32_ty(ctx), 1); +} + //===----------------------------------------------------------------------===// // function_type class