diff --git a/examples/matrix.cpp b/examples/matrix.cpp index 20a911387..0baf844dc 100644 --- a/examples/matrix.cpp +++ b/examples/matrix.cpp @@ -38,6 +38,9 @@ extern translation_unit *ast_root; const char src[] = "\ +__constant__ int32* delta = alloc_const int32[16];\ +__constant__ int32* masks = alloc_const int32[16];\ +\ const tunable int32 TM;\ const tunable int32 TN;\ const tunable int32 TK;\ diff --git a/include/triton/ast/ast.h b/include/triton/ast/ast.h index b9ae16ea0..f511ac132 100644 --- a/include/triton/ast/ast.h +++ b/include/triton/ast/ast.h @@ -62,6 +62,7 @@ enum STORAGE_SPEC_T{ KERNEL_T, RESTRICT_T, READONLY_T, + CONSTANT_SPACE_T, WRITEONLY_T }; @@ -142,6 +143,16 @@ class builtin_expression: public node{ }; +class typed_declaration_specifier; +class alloc_const: public builtin_expression{ +public: + alloc_const(node *spec, node *size): spec_((typed_declaration_specifier*)spec), size_((constant*)size) { } + ir::value* codegen(ir::module *mod) const; + +private: + const typed_declaration_specifier* spec_; + const constant* size_; +}; class get_global_range: public builtin_expression{ public: @@ -447,13 +458,18 @@ public: /* Declarators */ class declarator: public node{ - virtual ir::type* type_impl(ir::module *mod, ir::type *type) const = 0; +protected: + typedef std::vector storage_spec_vec_t; + typedef const storage_spec_vec_t& storage_spec_vec_const_ref_t; + +public: + virtual ir::type* type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const = 0; public: declarator(node *lhs) : lhs_((declarator*)lhs), ptr_(nullptr){ } - ir::type* type(ir::module *mod, ir::type *type) const; + ir::type* type(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const; const identifier* id() const { return (const identifier*)lhs_; @@ -464,13 +480,18 @@ public: return this; } + void set_addr_space(unsigned addr_space){ + addr_space_ = addr_space; + } + protected: declarator *lhs_; pointer *ptr_; + unsigned addr_space_; }; class identifier: public declarator { - ir::type* type_impl(ir::module *mod, ir::type *type) const; + ir::type* type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const; public: identifier(char *&name): declarator(this), name_(name) { } @@ -482,7 +503,7 @@ private: class pointer: public declarator{ private: - ir::type* type_impl(ir::module *mod, ir::type *type) const; + ir::type* type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const; public: pointer(node *id): declarator(id) { } @@ -490,7 +511,7 @@ public: class tile: public declarator{ private: - ir::type* type_impl(ir::module *mod, ir::type *type) const; + ir::type* type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const; public: tile(node *id, node *shapes) @@ -502,7 +523,7 @@ public: class function: public declarator{ private: - ir::type* type_impl(ir::module *mod, ir::type *type) const; + ir::type* type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const; public: function(node *id, node *args) @@ -519,7 +540,7 @@ public: class initializer : public declarator{ private: - ir::type* type_impl(ir::module * mod, ir::type *type) const; + ir::type* type_impl(ir::module * mod, ir::type *type, storage_spec_vec_const_ref_t storage) const; public: initializer(node *decl, node *init) @@ -531,7 +552,7 @@ public: public: const declaration_specifier *spec_; - const declarator *decl_; + declarator *decl_; const expression *expr_; }; diff --git a/include/triton/ast/parser.y b/include/triton/ast/parser.y index a7d46e5a7..acd31d995 100644 --- a/include/triton/ast/parser.y +++ b/include/triton/ast/parser.y @@ -8,6 +8,7 @@ using namespace triton::ast; #define YYSTYPE node* #include "../include/triton/ast/ast.h" +#define YYERROR_VERBOSE 1 extern char* yytext; void yyerror(const char *s); int yylex(void); @@ -42,11 +43,10 @@ ASSIGN_OP_T get_assign_op(node *op) { return ((token*)op)->assign_op; } UNARY_OP_T get_unary_op(node *op) { return ((token*)op)->unary_op; } TYPE_T get_type_spec(node *op) { return ((token*)op)->type; } STORAGE_SPEC_T get_storage_spec(node *op) { return ((token*)op)->storage_spec;} - %} %token IDENTIFIER CONSTANT STRING_LITERAL -%token TUNABLE KERNEL RESTRICT READONLY WRITEONLY CONST +%token TUNABLE KERNEL RESTRICT READONLY WRITEONLY CONST CONSTANT_SPACE %token PTR_OP INC_OP DEC_OP LEFT_OP RIGHT_OP LE_OP GE_OP EQ_OP NE_OP %token AND_OP OR_OP MUL_ASSIGN DIV_ASSIGN MOD_ASSIGN ADD_ASSIGN %token SUB_ASSIGN LEFT_ASSIGN RIGHT_ASSIGN AND_ASSIGN @@ -54,7 +54,7 @@ STORAGE_SPEC_T get_storage_spec(node *op) { return ((token*)op)->storage_spec;} %token VOID UINT1 UINT8 UINT16 UINT32 UINT64 INT1 INT8 INT16 INT32 INT64 FP32 FP64 %token IF ELSE FOR CONTINUE %token NEWAXIS ELLIPSIS AT -%token GET_GLOBAL_RANGE DOT +%token GET_GLOBAL_RANGE DOT ALLOC_CONST %start translation_unit %% @@ -112,7 +112,8 @@ identifier builtin : GET_GLOBAL_RANGE '[' primary_expression ']' '(' constant ')' { $$ = new get_global_range($3, $6); } - | DOT '(' expression ',' expression ',' expression ')' { $$ = new matmul_expression($3, $5, $7); } + | DOT '(' expression ',' expression ',' expression ')' { $$ = new matmul_expression($3, $5, $7); } + | ALLOC_CONST type_specifier '[' constant ']' { $$ = new alloc_const(new typed_declaration_specifier(get_type_spec($2)), $4); } primary_expression : identifier { $$ = new named_expression($1); } @@ -366,6 +367,7 @@ storage_class_specifier | RESTRICT { $$ = new token(RESTRICT_T); } | READONLY { $$ = new token(READONLY_T); } | WRITEONLY { $$ = new token(WRITEONLY_T); } + | CONSTANT_SPACE { $$ = new token(CONSTANT_SPACE_T); } ; /* -------------------------- */ diff --git a/include/triton/ast/scanner.l b/include/triton/ast/scanner.l index 4c0635dbc..56cc777a7 100644 --- a/include/triton/ast/scanner.l +++ b/include/triton/ast/scanner.l @@ -8,133 +8,107 @@ IS (u|U|l|L)* %{ #include #include "parser.hpp" - -void count(); -int check_type(); -int comment(); - %} %% -"const" { count(); return(CONST); } -"tunable" { count(); return(TUNABLE); } -"kernel" { count(); return(KERNEL); } -"restrict" { count(); return(RESTRICT); } -"readonly" { count(); return(READONLY); } -"writeonly" { count(); return(WRITEONLY); } -"@" { count(); return(AT); } -"newaxis" { count(); return(NEWAXIS); } -"if" { count(); return(IF); } -"else" { count(); return(ELSE); } -"for" { count(); return(FOR); } -"void" { count(); return(VOID); } -"uint1" { count(); return(UINT1); } -"uint8" { count(); return(UINT8); } -"uint16" { count(); return(UINT16); } -"uint32" { count(); return(UINT32); } -"uint64" { count(); return(UINT64); } -"int1" { count(); return(INT1); } -"int8" { count(); return(INT8); } -"int16" { count(); return(INT16); } -"int32" { count(); return(INT32); } -"int64" { count(); return(INT64); } -"fp32" { count(); return(FP32); } -"fp64" { count(); return(FP64); } -"..." { count(); return(ELLIPSIS); } -"get_global_range" { count(); return GET_GLOBAL_RANGE; } -"dot" { count(); return DOT;} -"continue" { count(); return(CONTINUE); } +"__constant__" { return(CONSTANT_SPACE); } +"const" { return(CONST); } +"tunable" { return(TUNABLE); } +"kernel" { return(KERNEL); } +"restrict" { return(RESTRICT); } +"readonly" { return(READONLY); } +"writeonly" { return(WRITEONLY); } +"@" { return(AT); } +"newaxis" { return(NEWAXIS); } +"if" { return(IF); } +"else" { return(ELSE); } +"for" { return(FOR); } +"void" { return(VOID); } +"uint1" { return(UINT1); } +"uint8" { return(UINT8); } +"uint16" { return(UINT16); } +"uint32" { return(UINT32); } +"uint64" { return(UINT64); } +"int1" { return(INT1); } +"int8" { return(INT8); } +"int16" { return(INT16); } +"int32" { return(INT32); } +"int64" { return(INT64); } +"fp32" { return(FP32); } +"fp64" { return(FP64); } +"..." { return(ELLIPSIS); } +"get_global_range" { return GET_GLOBAL_RANGE; } +"dot" { return DOT;} +"continue" { return(CONTINUE); } +"alloc_const" { return(ALLOC_CONST); } +{L}({L}|{D})* { return(IDENTIFIER); } -{L}({L}|{D})* { count(); return(check_type()); } +0[xX]{H}+{IS}? { return(CONSTANT); } +0{D}+{IS}? { return(CONSTANT); } +{D}+{IS}? { return(CONSTANT); } +L?'(\\.|[^\\'])+' { return(CONSTANT); } -0[xX]{H}+{IS}? { count(); return(CONSTANT); } -0{D}+{IS}? { count(); return(CONSTANT); } -{D}+{IS}? { count(); return(CONSTANT); } -L?'(\\.|[^\\'])+' { count(); return(CONSTANT); } +{D}+{E}{FS}? { return(CONSTANT); } +{D}*"."{D}+({E})?{FS}? { return(CONSTANT); } +{D}+"."{D}*({E})?{FS}? { return(CONSTANT); } -{D}+{E}{FS}? { count(); return(CONSTANT); } -{D}*"."{D}+({E})?{FS}? { count(); return(CONSTANT); } -{D}+"."{D}*({E})?{FS}? { count(); return(CONSTANT); } +L?\"(\\.|[^\\"])*\" { return(STRING_LITERAL); } -L?\"(\\.|[^\\"])*\" { count(); return(STRING_LITERAL); } +">>=" { return(RIGHT_ASSIGN); } +"<<=" { return(LEFT_ASSIGN); } +"+=" { return(ADD_ASSIGN); } +"-=" { return(SUB_ASSIGN); } +"*=" { return(MUL_ASSIGN); } +"/=" { return(DIV_ASSIGN); } +"%=" { return(MOD_ASSIGN); } +"&=" { return(AND_ASSIGN); } +"^=" { return(XOR_ASSIGN); } +"|=" { return(OR_ASSIGN); } +">>" { return(RIGHT_OP); } +"<<" { return(LEFT_OP); } +"++" { return(INC_OP); } +"--" { return(DEC_OP); } +"->" { return(PTR_OP); } +"&&" { return(AND_OP); } +"||" { return(OR_OP); } +"<=" { return(LE_OP); } +">=" { return(GE_OP); } +"==" { return(EQ_OP); } +"!=" { return(NE_OP); } +";" { return(';'); } +("{"|"<%") { return('{'); } +("}"|"%>") { return('}'); } +"," { return(','); } +":" { return(':'); } +"=" { return('='); } +"(" { return('('); } +")" { return(')'); } +("["|"<:") { return('['); } +("]"|":>") { return(']'); } +"." { return('.'); } +"&" { return('&'); } +"!" { return('!'); } +"~" { return('~'); } +"-" { return('-'); } +"+" { return('+'); } +"*" { return('*'); } +"/" { return('/'); } +"%" { return('%'); } +"<" { return('<'); } +">" { return('>'); } +"^" { return('^'); } +"|" { return('|'); } +"?" { return('?'); } -">>=" { count(); return(RIGHT_ASSIGN); } -"<<=" { count(); return(LEFT_ASSIGN); } -"+=" { count(); return(ADD_ASSIGN); } -"-=" { count(); return(SUB_ASSIGN); } -"*=" { count(); return(MUL_ASSIGN); } -"/=" { count(); return(DIV_ASSIGN); } -"%=" { count(); return(MOD_ASSIGN); } -"&=" { count(); return(AND_ASSIGN); } -"^=" { count(); return(XOR_ASSIGN); } -"|=" { count(); return(OR_ASSIGN); } -">>" { count(); return(RIGHT_OP); } -"<<" { count(); return(LEFT_OP); } -"++" { count(); return(INC_OP); } -"--" { count(); return(DEC_OP); } -"->" { count(); return(PTR_OP); } -"&&" { count(); return(AND_OP); } -"||" { count(); return(OR_OP); } -"<=" { count(); return(LE_OP); } -">=" { count(); return(GE_OP); } -"==" { count(); return(EQ_OP); } -"!=" { count(); return(NE_OP); } -";" { count(); return(';'); } -("{"|"<%") { count(); return('{'); } -("}"|"%>") { count(); return('}'); } -"," { count(); return(','); } -":" { count(); return(':'); } -"=" { count(); return('='); } -"(" { count(); return('('); } -")" { count(); return(')'); } -("["|"<:") { count(); return('['); } -("]"|":>") { count(); return(']'); } -"." { count(); return('.'); } -"&" { count(); return('&'); } -"!" { count(); return('!'); } -"~" { count(); return('~'); } -"-" { count(); return('-'); } -"+" { count(); return('+'); } -"*" { count(); return('*'); } -"/" { count(); return('/'); } -"%" { count(); return('%'); } -"<" { count(); return('<'); } -">" { count(); return('>'); } -"^" { count(); return('^'); } -"|" { count(); return('|'); } -"?" { count(); return('?'); } - -[ \t\v\n\f] { count(); } -. { /* ignore bad characters */ } +[ \t\v\n\f] { } +. { /* ignore bad characters */ } %% int yywrap() { return(1); } - -int column = 0; - -void count() -{ - int i; - - for (i = 0; yytext[i] != '\0'; i++) - if (yytext[i] == '\n') - column = 0; - else if (yytext[i] == '\t') - column += 8 - (column % 8); - else - column++; - //ECHO; -} - void yyerror (const char *s) /* Called by yyparse on error */ { printf ("Error: %s\n", s); } - -int check_type() -{ - return(IDENTIFIER); -} diff --git a/include/triton/codegen/selection.h b/include/triton/codegen/selection.h index c8632262e..5c81ca8a0 100644 --- a/include/triton/codegen/selection.h +++ b/include/triton/codegen/selection.h @@ -112,6 +112,8 @@ private: llvm::Value* llvm_value(ir::value *v, llvm::IRBuilder<> &builder); llvm::Instruction* llvm_inst(ir::instruction *inst, std::function value, llvm::IRBuilder<> &builder); llvm::Constant* llvm_constant(ir::constant *cst, llvm::LLVMContext &ctx); + llvm::Value* llvm_alloc_const(ir::alloc_const *v, llvm::Module *module, llvm::IRBuilder<> &builder); + llvm::ArrayType* llvm_linearized_tile_type(ir::type *ty, llvm::LLVMContext &ctx); // grid construction void create_grids(std::vector &grids, diff --git a/include/triton/ir/constant.h b/include/triton/ir/constant.h index 11403c6dd..9f2baf618 100644 --- a/include/triton/ir/constant.h +++ b/include/triton/ir/constant.h @@ -106,6 +106,12 @@ public: unsigned addr_space = 0); }; +/* global variable */ +class alloc_const: public global_object { +public: + alloc_const(type *ty, constant_int *size, + const std::string &name = ""); +}; } } diff --git a/include/triton/ir/module.h b/include/triton/ir/module.h index e4026f8b6..4ec681f67 100644 --- a/include/triton/ir/module.h +++ b/include/triton/ir/module.h @@ -28,6 +28,7 @@ class attribute; class function_type; class constant; class global_value; +class alloc_const; /* Module */ struct scope { @@ -76,7 +77,9 @@ public: void add_new_scope() { if(scopes_.empty()) scopes_.push(scope()); else scopes_.push(scope(get_scope())); } void pop_scope() { scopes_.pop(); } scope& get_scope() { return scopes_.top(); } - + // Const allocation + void add_alloc(ir::alloc_const* x) { allocs_.push_back(x); } + const std::vector& allocs() { return allocs_; } private: std::string name_; @@ -92,6 +95,7 @@ private: std::function continue_fn_; std::map current_phi_; std::stack scopes_; + std::vector allocs_; }; } diff --git a/include/triton/ir/type.h b/include/triton/ir/type.h index 1977ff47c..04da05b60 100644 --- a/include/triton/ir/type.h +++ b/include/triton/ir/type.h @@ -165,9 +165,8 @@ private: public: // accessors - unsigned get_address_space() const { return address_space_; } - type *get_element_ty() const { return contained_tys_[0]; } - + unsigned get_address_space() const { return address_space_; } + type *get_element_ty() const { return contained_tys_[0]; } // factory methods static pointer_type* get(type *ty, unsigned address_space); diff --git a/lib/ast/lowering.cpp b/lib/ast/lowering.cpp index d70955d98..49fe03206 100644 --- a/lib/ast/lowering.cpp +++ b/lib/ast/lowering.cpp @@ -188,7 +188,7 @@ std::vector storage_declaration_specifier::storage() const { /* Parameter */ ir::type* parameter::type(ir::module *mod) const { - return decl_->type(mod, spec_->type(mod)); + return decl_->type(mod, spec_->type(mod), {}); } std::vector parameter::storage() const { @@ -200,14 +200,14 @@ const identifier *parameter::id() const { } /* Declarators */ -ir::type* declarator::type(ir::module *mod, ir::type *type) const{ +ir::type* declarator::type(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const{ if(ptr_) - return type_impl(mod, ptr_->type(mod, type)); - return type_impl(mod, type); + return type_impl(mod, ptr_->type(mod, type, storage), storage); + return type_impl(mod, type, storage); } // Identifier -ir::type* identifier::type_impl(ir::module *, ir::type *type) const{ +ir::type* identifier::type_impl(ir::module *, ir::type *type, storage_spec_vec_const_ref_t) const{ return type; } @@ -216,7 +216,7 @@ const std::string &identifier::name() const{ } // Tile -ir::type* tile::type_impl(ir::module *mod, ir::type *type) const{ +ir::type* tile::type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t) const{ ir::type::tile_shapes_t shapes; for(expression *expr: shapes_->values()){ ir::constant_int *shape = dynamic_cast(expr->codegen(mod)); @@ -228,8 +228,9 @@ ir::type* tile::type_impl(ir::module *mod, ir::type *type) const{ // Pointer -ir::type* pointer::type_impl(ir::module*, ir::type *type) const{ - return ir::pointer_type::get(type, 1); +ir::type* pointer::type_impl(ir::module*, ir::type *type, storage_spec_vec_const_ref_t storage) const{ + bool is_ptr_to_const = std::find(storage.begin(), storage.end(), CONSTANT_SPACE_T) != storage.end(); + return ir::pointer_type::get(type, is_ptr_to_const?4:1); } // Function @@ -247,7 +248,7 @@ void function::bind_parameters(ir::module *mod, ir::function *fn) const{ } } -ir::type* function::type_impl(ir::module* mod, ir::type *type) const{ +ir::type* function::type_impl(ir::module* mod, ir::type *type, storage_spec_vec_const_ref_t) const{ std::vector types; for(parameter* param: args_->values()) types.push_back(param->type(mod)); @@ -265,7 +266,7 @@ ir::attribute_t get_ir_attr(STORAGE_SPEC_T spec){ } ir::value* function_definition::codegen(ir::module *mod) const{ - ir::function_type *prototype = (ir::function_type*)header_->type(mod, spec_->type(mod)); + ir::function_type *prototype = (ir::function_type*)header_->type(mod, spec_->type(mod), spec_->storage()); const std::string &name = header_->id()->name(); ir::function *fn = mod->get_or_insert_function(name, prototype); for(unsigned i = 0; i < header_->get_num_args(); i++){ @@ -397,8 +398,8 @@ ir::value* declaration::codegen(ir::module* mod) const{ } /* Initializer */ -ir::type* initializer::type_impl(ir::module *mod, ir::type *type) const{ - return decl_->type(mod, type); +ir::type* initializer::type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const{ + return decl_->type(mod, type, storage); } void initializer::set_specifier(const declaration_specifier *spec) { @@ -406,8 +407,8 @@ void initializer::set_specifier(const declaration_specifier *spec) { } ir::value* initializer::codegen(ir::module * mod) const{ - ir::type *ty = decl_->type(mod, spec_->type(mod)); std::vector storage = spec_->storage(); + ir::type *ty = decl_->type(mod, spec_->type(mod), storage); std::string name = decl_->id()->name(); ir::value *value = ir::undef_value::get(ty); if(std::find(storage.begin(), storage.end(), TUNABLE_T) != storage.end()){ @@ -423,6 +424,8 @@ ir::value* initializer::codegen(ir::module * mod) const{ value->set_name(name); mod->set_value(name, value); mod->get_scope().types[name] = ty; + if(auto *x = dynamic_cast(value)) + mod->add_alloc(x); if(std::find(storage.begin(), storage.end(), CONST_T) != storage.end()) mod->set_const(name); return value; @@ -523,13 +526,21 @@ ir::value* binary_operator::codegen(ir::module *mod) const{ /* Builtin expression */ +// alloc constant +ir::value* alloc_const::codegen(ir::module *mod) const { + ir::type *ty = spec_->type(mod); + ir::constant_int *size = (ir::constant_int*)size_->codegen(mod); + ir::alloc_const *res = new ir::alloc_const(ty, size); + return res; +} + // get_global_range ir::value* get_global_range::codegen(ir::module *mod) const { ir::builder &builder = mod->get_builder(); return builder.create_get_global_range(axis_->value(), (ir::constant_int*)size_->codegen(mod)); } - +// matmul ir::value* matmul_expression::codegen(ir::module *mod) const { ir::value *A = A_->codegen(mod); ir::value *B = B_->codegen(mod); @@ -666,7 +677,7 @@ ir::value *assignment_expression::codegen(ir::module *mod) const{ /* Type name */ ir::type *type_name::type(ir::module *mod) const{ - return decl_->type(mod, spec_->type(mod)); + return decl_->type(mod, spec_->type(mod), {}); } /* String literal */ @@ -693,6 +704,9 @@ ir::value* constant_range::codegen(ir::module *mod) const{ /* Named */ ir::value* named_expression::codegen(ir::module *mod) const{ const std::string &name = id()->name(); + const auto& declarations = mod->get_scope().types; + if(declarations.find(name) == declarations.end()) + throw std::runtime_error("variable " + name + " not declared"); return mod->get_value(name); } diff --git a/lib/codegen/selection.cpp b/lib/codegen/selection.cpp index 32a713428..3f79c9375 100644 --- a/lib/codegen/selection.cpp +++ b/lib/codegen/selection.cpp @@ -315,6 +315,16 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function &builder) { + unsigned size = ((ir::constant_int*)v->get_operand(0))->get_value(); + Type *element_ty = llvm_type(v->get_type()->get_pointer_element_ty(), module->getContext()); + Type *array_ty = llvm::ArrayType::get(element_ty, size); + Value *array = new llvm::GlobalVariable(*module, array_ty, false, llvm::GlobalVariable::ExternalLinkage, + nullptr, v->get_name(), nullptr, llvm::GlobalVariable::NotThreadLocal, 4); + return builder.CreateBitCast(array, element_ty->getPointerTo(4)); +} + /* convert ir::value to llvm::Value */ Value* selection::llvm_value(ir::value *v, IRBuilder<> &builder) { assert(!v->get_type()->is_tile_ty()); @@ -324,6 +334,20 @@ Value* selection::llvm_value(ir::value *v, IRBuilder<> &builder) { // create operands if(auto *cc = dynamic_cast(v)) return llvm_constant(cc, ctx); + // alloc const + if(auto *cc = dynamic_cast(v)){ + BasicBlock *block = builder.GetInsertBlock(); + Module *module = block->getModule(); + unsigned size = ((ir::constant_int*)cc->get_operand(0))->get_value(); + Type *element_ty = llvm_type(cc->get_type()->get_pointer_element_ty(), ctx); + Type *array_ty = llvm::ArrayType::get(element_ty, size); + if(vmap_.find(v) == vmap_.end()){ + Value *array = new llvm::GlobalVariable(*module, array_ty, false, llvm::GlobalVariable::ExternalLinkage, + nullptr, cc->get_name(), nullptr, llvm::GlobalVariable::NotThreadLocal, 4); + vmap_[v] = builder.CreateBitCast(array, array->getType()->getArrayElementType()->getPointerTo(4)); + } + return vmap_.at(v); + } // instruction if(auto *ii = dynamic_cast(v)){ auto value = [&](ir::value *x) { return llvm_value(x, builder); }; @@ -755,11 +779,22 @@ inline llvm::Attribute::AttrKind llvm_attr(ir::attribute_t attr) { } } +ArrayType* selection::llvm_linearized_tile_type(ir::type *ty, LLVMContext &ctx) { + unsigned size = 1; + for(ir::constant_int* shape: ty->get_tile_shapes()) + size *= shape->get_value(); + return ArrayType::get(llvm_type(ty->get_scalar_ty(), ctx), size); +} + void selection::run(ir::module &src, Module &dst){ vmap_.clear(); LLVMContext &dst_ctx = dst.getContext(); IRBuilder<> dst_builder(dst_ctx); + for(ir::alloc_const *x: src.allocs()) { + vmap_[x] = llvm_alloc_const(x, &dst, dst_builder); + } + // iterate over functions for(ir::function *fn: src.get_function_list()) { // create LLVM function @@ -795,7 +830,7 @@ void selection::run(ir::module &src, Module &dst){ ArrayType *array_ty = ArrayType::get(int_8_ty, alloc_size); Type *ptr_ty = PointerType::get(int_8_ty, 3); GlobalVariable *sh_mem_array = - new GlobalVariable(*dst_fn->getParent(), array_ty, false, GlobalVariable::ExternalLinkage, + new GlobalVariable(dst, array_ty, false, GlobalVariable::ExternalLinkage, nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3); sh_mem_ptr = dst_builder.CreateBitCast(sh_mem_array, ptr_ty); } diff --git a/lib/ir/constant.cpp b/lib/ir/constant.cpp index 929af2228..314714c04 100644 --- a/lib/ir/constant.cpp +++ b/lib/ir/constant.cpp @@ -135,5 +135,12 @@ global_object::global_object(type *ty, unsigned num_ops, : global_value(ty, num_ops, linkage, name, addr_space) { } +/* alloc const */ +alloc_const::alloc_const(type *ty, constant_int *size, const std::string &name) + : global_object(ty, 1, global_value::external, name, 4) { + set_operand(0, size); +} + + } }