[general] added support for constant memory declaration

This commit is contained in:
Philippe Tillet
2019-03-03 23:16:33 -05:00
parent 1f30e111ec
commit 4189e130bf
11 changed files with 211 additions and 144 deletions

View File

@@ -38,6 +38,9 @@ extern translation_unit *ast_root;
const char src[] = const char src[] =
"\ "\
__constant__ int32* delta = alloc_const int32[16];\
__constant__ int32* masks = alloc_const int32[16];\
\
const tunable int32 TM;\ const tunable int32 TM;\
const tunable int32 TN;\ const tunable int32 TN;\
const tunable int32 TK;\ const tunable int32 TK;\

View File

@@ -62,6 +62,7 @@ enum STORAGE_SPEC_T{
KERNEL_T, KERNEL_T,
RESTRICT_T, RESTRICT_T,
READONLY_T, READONLY_T,
CONSTANT_SPACE_T,
WRITEONLY_T WRITEONLY_T
}; };
@@ -142,6 +143,16 @@ class builtin_expression: public node{
}; };
class typed_declaration_specifier;
class alloc_const: public builtin_expression{
public:
alloc_const(node *spec, node *size): spec_((typed_declaration_specifier*)spec), size_((constant*)size) { }
ir::value* codegen(ir::module *mod) const;
private:
const typed_declaration_specifier* spec_;
const constant* size_;
};
class get_global_range: public builtin_expression{ class get_global_range: public builtin_expression{
public: public:
@@ -447,13 +458,18 @@ public:
/* Declarators */ /* Declarators */
class declarator: public node{ class declarator: public node{
virtual ir::type* type_impl(ir::module *mod, ir::type *type) const = 0; protected:
typedef std::vector<STORAGE_SPEC_T> storage_spec_vec_t;
typedef const storage_spec_vec_t& storage_spec_vec_const_ref_t;
public:
virtual ir::type* type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const = 0;
public: public:
declarator(node *lhs) declarator(node *lhs)
: lhs_((declarator*)lhs), ptr_(nullptr){ } : lhs_((declarator*)lhs), ptr_(nullptr){ }
ir::type* type(ir::module *mod, ir::type *type) const; ir::type* type(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const;
const identifier* id() const { const identifier* id() const {
return (const identifier*)lhs_; return (const identifier*)lhs_;
@@ -464,13 +480,18 @@ public:
return this; return this;
} }
void set_addr_space(unsigned addr_space){
addr_space_ = addr_space;
}
protected: protected:
declarator *lhs_; declarator *lhs_;
pointer *ptr_; pointer *ptr_;
unsigned addr_space_;
}; };
class identifier: public declarator { class identifier: public declarator {
ir::type* type_impl(ir::module *mod, ir::type *type) const; ir::type* type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const;
public: public:
identifier(char *&name): declarator(this), name_(name) { } identifier(char *&name): declarator(this), name_(name) { }
@@ -482,7 +503,7 @@ private:
class pointer: public declarator{ class pointer: public declarator{
private: private:
ir::type* type_impl(ir::module *mod, ir::type *type) const; ir::type* type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const;
public: public:
pointer(node *id): declarator(id) { } pointer(node *id): declarator(id) { }
@@ -490,7 +511,7 @@ public:
class tile: public declarator{ class tile: public declarator{
private: private:
ir::type* type_impl(ir::module *mod, ir::type *type) const; ir::type* type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const;
public: public:
tile(node *id, node *shapes) tile(node *id, node *shapes)
@@ -502,7 +523,7 @@ public:
class function: public declarator{ class function: public declarator{
private: private:
ir::type* type_impl(ir::module *mod, ir::type *type) const; ir::type* type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const;
public: public:
function(node *id, node *args) function(node *id, node *args)
@@ -519,7 +540,7 @@ public:
class initializer : public declarator{ class initializer : public declarator{
private: private:
ir::type* type_impl(ir::module * mod, ir::type *type) const; ir::type* type_impl(ir::module * mod, ir::type *type, storage_spec_vec_const_ref_t storage) const;
public: public:
initializer(node *decl, node *init) initializer(node *decl, node *init)
@@ -531,7 +552,7 @@ public:
public: public:
const declaration_specifier *spec_; const declaration_specifier *spec_;
const declarator *decl_; declarator *decl_;
const expression *expr_; const expression *expr_;
}; };

View File

@@ -8,6 +8,7 @@ using namespace triton::ast;
#define YYSTYPE node* #define YYSTYPE node*
#include "../include/triton/ast/ast.h" #include "../include/triton/ast/ast.h"
#define YYERROR_VERBOSE 1
extern char* yytext; extern char* yytext;
void yyerror(const char *s); void yyerror(const char *s);
int yylex(void); int yylex(void);
@@ -42,11 +43,10 @@ ASSIGN_OP_T get_assign_op(node *op) { return ((token*)op)->assign_op; }
UNARY_OP_T get_unary_op(node *op) { return ((token*)op)->unary_op; } UNARY_OP_T get_unary_op(node *op) { return ((token*)op)->unary_op; }
TYPE_T get_type_spec(node *op) { return ((token*)op)->type; } TYPE_T get_type_spec(node *op) { return ((token*)op)->type; }
STORAGE_SPEC_T get_storage_spec(node *op) { return ((token*)op)->storage_spec;} STORAGE_SPEC_T get_storage_spec(node *op) { return ((token*)op)->storage_spec;}
%} %}
%token IDENTIFIER CONSTANT STRING_LITERAL %token IDENTIFIER CONSTANT STRING_LITERAL
%token TUNABLE KERNEL RESTRICT READONLY WRITEONLY CONST %token TUNABLE KERNEL RESTRICT READONLY WRITEONLY CONST CONSTANT_SPACE
%token PTR_OP INC_OP DEC_OP LEFT_OP RIGHT_OP LE_OP GE_OP EQ_OP NE_OP %token PTR_OP INC_OP DEC_OP LEFT_OP RIGHT_OP LE_OP GE_OP EQ_OP NE_OP
%token AND_OP OR_OP MUL_ASSIGN DIV_ASSIGN MOD_ASSIGN ADD_ASSIGN %token AND_OP OR_OP MUL_ASSIGN DIV_ASSIGN MOD_ASSIGN ADD_ASSIGN
%token SUB_ASSIGN LEFT_ASSIGN RIGHT_ASSIGN AND_ASSIGN %token SUB_ASSIGN LEFT_ASSIGN RIGHT_ASSIGN AND_ASSIGN
@@ -54,7 +54,7 @@ STORAGE_SPEC_T get_storage_spec(node *op) { return ((token*)op)->storage_spec;}
%token VOID UINT1 UINT8 UINT16 UINT32 UINT64 INT1 INT8 INT16 INT32 INT64 FP32 FP64 %token VOID UINT1 UINT8 UINT16 UINT32 UINT64 INT1 INT8 INT16 INT32 INT64 FP32 FP64
%token IF ELSE FOR CONTINUE %token IF ELSE FOR CONTINUE
%token NEWAXIS ELLIPSIS AT %token NEWAXIS ELLIPSIS AT
%token GET_GLOBAL_RANGE DOT %token GET_GLOBAL_RANGE DOT ALLOC_CONST
%start translation_unit %start translation_unit
%% %%
@@ -112,7 +112,8 @@ identifier
builtin builtin
: GET_GLOBAL_RANGE '[' primary_expression ']' '(' constant ')' { $$ = new get_global_range($3, $6); } : GET_GLOBAL_RANGE '[' primary_expression ']' '(' constant ')' { $$ = new get_global_range($3, $6); }
| DOT '(' expression ',' expression ',' expression ')' { $$ = new matmul_expression($3, $5, $7); } | DOT '(' expression ',' expression ',' expression ')' { $$ = new matmul_expression($3, $5, $7); }
| ALLOC_CONST type_specifier '[' constant ']' { $$ = new alloc_const(new typed_declaration_specifier(get_type_spec($2)), $4); }
primary_expression primary_expression
: identifier { $$ = new named_expression($1); } : identifier { $$ = new named_expression($1); }
@@ -366,6 +367,7 @@ storage_class_specifier
| RESTRICT { $$ = new token(RESTRICT_T); } | RESTRICT { $$ = new token(RESTRICT_T); }
| READONLY { $$ = new token(READONLY_T); } | READONLY { $$ = new token(READONLY_T); }
| WRITEONLY { $$ = new token(WRITEONLY_T); } | WRITEONLY { $$ = new token(WRITEONLY_T); }
| CONSTANT_SPACE { $$ = new token(CONSTANT_SPACE_T); }
; ;
/* -------------------------- */ /* -------------------------- */

View File

@@ -8,133 +8,107 @@ IS (u|U|l|L)*
%{ %{
#include <stdio.h> #include <stdio.h>
#include "parser.hpp" #include "parser.hpp"
void count();
int check_type();
int comment();
%} %}
%% %%
"const" { count(); return(CONST); } "__constant__" { return(CONSTANT_SPACE); }
"tunable" { count(); return(TUNABLE); } "const" { return(CONST); }
"kernel" { count(); return(KERNEL); } "tunable" { return(TUNABLE); }
"restrict" { count(); return(RESTRICT); } "kernel" { return(KERNEL); }
"readonly" { count(); return(READONLY); } "restrict" { return(RESTRICT); }
"writeonly" { count(); return(WRITEONLY); } "readonly" { return(READONLY); }
"@" { count(); return(AT); } "writeonly" { return(WRITEONLY); }
"newaxis" { count(); return(NEWAXIS); } "@" { return(AT); }
"if" { count(); return(IF); } "newaxis" { return(NEWAXIS); }
"else" { count(); return(ELSE); } "if" { return(IF); }
"for" { count(); return(FOR); } "else" { return(ELSE); }
"void" { count(); return(VOID); } "for" { return(FOR); }
"uint1" { count(); return(UINT1); } "void" { return(VOID); }
"uint8" { count(); return(UINT8); } "uint1" { return(UINT1); }
"uint16" { count(); return(UINT16); } "uint8" { return(UINT8); }
"uint32" { count(); return(UINT32); } "uint16" { return(UINT16); }
"uint64" { count(); return(UINT64); } "uint32" { return(UINT32); }
"int1" { count(); return(INT1); } "uint64" { return(UINT64); }
"int8" { count(); return(INT8); } "int1" { return(INT1); }
"int16" { count(); return(INT16); } "int8" { return(INT8); }
"int32" { count(); return(INT32); } "int16" { return(INT16); }
"int64" { count(); return(INT64); } "int32" { return(INT32); }
"fp32" { count(); return(FP32); } "int64" { return(INT64); }
"fp64" { count(); return(FP64); } "fp32" { return(FP32); }
"..." { count(); return(ELLIPSIS); } "fp64" { return(FP64); }
"get_global_range" { count(); return GET_GLOBAL_RANGE; } "..." { return(ELLIPSIS); }
"dot" { count(); return DOT;} "get_global_range" { return GET_GLOBAL_RANGE; }
"continue" { count(); return(CONTINUE); } "dot" { return DOT;}
"continue" { return(CONTINUE); }
"alloc_const" { return(ALLOC_CONST); }
{L}({L}|{D})* { return(IDENTIFIER); }
{L}({L}|{D})* { count(); return(check_type()); } 0[xX]{H}+{IS}? { return(CONSTANT); }
0{D}+{IS}? { return(CONSTANT); }
{D}+{IS}? { return(CONSTANT); }
L?'(\\.|[^\\'])+' { return(CONSTANT); }
0[xX]{H}+{IS}? { count(); return(CONSTANT); } {D}+{E}{FS}? { return(CONSTANT); }
0{D}+{IS}? { count(); return(CONSTANT); } {D}*"."{D}+({E})?{FS}? { return(CONSTANT); }
{D}+{IS}? { count(); return(CONSTANT); } {D}+"."{D}*({E})?{FS}? { return(CONSTANT); }
L?'(\\.|[^\\'])+' { count(); return(CONSTANT); }
{D}+{E}{FS}? { count(); return(CONSTANT); } L?\"(\\.|[^\\"])*\" { return(STRING_LITERAL); }
{D}*"."{D}+({E})?{FS}? { count(); return(CONSTANT); }
{D}+"."{D}*({E})?{FS}? { count(); return(CONSTANT); }
L?\"(\\.|[^\\"])*\" { count(); return(STRING_LITERAL); } ">>=" { return(RIGHT_ASSIGN); }
"<<=" { return(LEFT_ASSIGN); }
"+=" { return(ADD_ASSIGN); }
"-=" { return(SUB_ASSIGN); }
"*=" { return(MUL_ASSIGN); }
"/=" { return(DIV_ASSIGN); }
"%=" { return(MOD_ASSIGN); }
"&=" { return(AND_ASSIGN); }
"^=" { return(XOR_ASSIGN); }
"|=" { return(OR_ASSIGN); }
">>" { return(RIGHT_OP); }
"<<" { return(LEFT_OP); }
"++" { return(INC_OP); }
"--" { return(DEC_OP); }
"->" { return(PTR_OP); }
"&&" { return(AND_OP); }
"||" { return(OR_OP); }
"<=" { return(LE_OP); }
">=" { return(GE_OP); }
"==" { return(EQ_OP); }
"!=" { return(NE_OP); }
";" { return(';'); }
("{"|"<%") { return('{'); }
("}"|"%>") { return('}'); }
"," { return(','); }
":" { return(':'); }
"=" { return('='); }
"(" { return('('); }
")" { return(')'); }
("["|"<:") { return('['); }
("]"|":>") { return(']'); }
"." { return('.'); }
"&" { return('&'); }
"!" { return('!'); }
"~" { return('~'); }
"-" { return('-'); }
"+" { return('+'); }
"*" { return('*'); }
"/" { return('/'); }
"%" { return('%'); }
"<" { return('<'); }
">" { return('>'); }
"^" { return('^'); }
"|" { return('|'); }
"?" { return('?'); }
">>=" { count(); return(RIGHT_ASSIGN); } [ \t\v\n\f] { }
"<<=" { count(); return(LEFT_ASSIGN); } . { /* ignore bad characters */ }
"+=" { count(); return(ADD_ASSIGN); }
"-=" { count(); return(SUB_ASSIGN); }
"*=" { count(); return(MUL_ASSIGN); }
"/=" { count(); return(DIV_ASSIGN); }
"%=" { count(); return(MOD_ASSIGN); }
"&=" { count(); return(AND_ASSIGN); }
"^=" { count(); return(XOR_ASSIGN); }
"|=" { count(); return(OR_ASSIGN); }
">>" { count(); return(RIGHT_OP); }
"<<" { count(); return(LEFT_OP); }
"++" { count(); return(INC_OP); }
"--" { count(); return(DEC_OP); }
"->" { count(); return(PTR_OP); }
"&&" { count(); return(AND_OP); }
"||" { count(); return(OR_OP); }
"<=" { count(); return(LE_OP); }
">=" { count(); return(GE_OP); }
"==" { count(); return(EQ_OP); }
"!=" { count(); return(NE_OP); }
";" { count(); return(';'); }
("{"|"<%") { count(); return('{'); }
("}"|"%>") { count(); return('}'); }
"," { count(); return(','); }
":" { count(); return(':'); }
"=" { count(); return('='); }
"(" { count(); return('('); }
")" { count(); return(')'); }
("["|"<:") { count(); return('['); }
("]"|":>") { count(); return(']'); }
"." { count(); return('.'); }
"&" { count(); return('&'); }
"!" { count(); return('!'); }
"~" { count(); return('~'); }
"-" { count(); return('-'); }
"+" { count(); return('+'); }
"*" { count(); return('*'); }
"/" { count(); return('/'); }
"%" { count(); return('%'); }
"<" { count(); return('<'); }
">" { count(); return('>'); }
"^" { count(); return('^'); }
"|" { count(); return('|'); }
"?" { count(); return('?'); }
[ \t\v\n\f] { count(); }
. { /* ignore bad characters */ }
%% %%
int yywrap() int yywrap()
{ return(1); } { return(1); }
int column = 0;
void count()
{
int i;
for (i = 0; yytext[i] != '\0'; i++)
if (yytext[i] == '\n')
column = 0;
else if (yytext[i] == '\t')
column += 8 - (column % 8);
else
column++;
//ECHO;
}
void yyerror (const char *s) /* Called by yyparse on error */ void yyerror (const char *s) /* Called by yyparse on error */
{ {
printf ("Error: %s\n", s); printf ("Error: %s\n", s);
} }
int check_type()
{
return(IDENTIFIER);
}

View File

@@ -112,6 +112,8 @@ private:
llvm::Value* llvm_value(ir::value *v, llvm::IRBuilder<> &builder); llvm::Value* llvm_value(ir::value *v, llvm::IRBuilder<> &builder);
llvm::Instruction* llvm_inst(ir::instruction *inst, std::function<llvm::Value*(ir::value*)> value, llvm::IRBuilder<> &builder); llvm::Instruction* llvm_inst(ir::instruction *inst, std::function<llvm::Value*(ir::value*)> value, llvm::IRBuilder<> &builder);
llvm::Constant* llvm_constant(ir::constant *cst, llvm::LLVMContext &ctx); llvm::Constant* llvm_constant(ir::constant *cst, llvm::LLVMContext &ctx);
llvm::Value* llvm_alloc_const(ir::alloc_const *v, llvm::Module *module, llvm::IRBuilder<> &builder);
llvm::ArrayType* llvm_linearized_tile_type(ir::type *ty, llvm::LLVMContext &ctx);
// grid construction // grid construction
void create_grids(std::vector<ir::value *> &grids, void create_grids(std::vector<ir::value *> &grids,

View File

@@ -106,6 +106,12 @@ public:
unsigned addr_space = 0); unsigned addr_space = 0);
}; };
/* global variable */
class alloc_const: public global_object {
public:
alloc_const(type *ty, constant_int *size,
const std::string &name = "");
};
} }
} }

View File

@@ -28,6 +28,7 @@ class attribute;
class function_type; class function_type;
class constant; class constant;
class global_value; class global_value;
class alloc_const;
/* Module */ /* Module */
struct scope { struct scope {
@@ -76,7 +77,9 @@ public:
void add_new_scope() { if(scopes_.empty()) scopes_.push(scope()); else scopes_.push(scope(get_scope())); } void add_new_scope() { if(scopes_.empty()) scopes_.push(scope()); else scopes_.push(scope(get_scope())); }
void pop_scope() { scopes_.pop(); } void pop_scope() { scopes_.pop(); }
scope& get_scope() { return scopes_.top(); } scope& get_scope() { return scopes_.top(); }
// Const allocation
void add_alloc(ir::alloc_const* x) { allocs_.push_back(x); }
const std::vector<ir::alloc_const*>& allocs() { return allocs_; }
private: private:
std::string name_; std::string name_;
@@ -92,6 +95,7 @@ private:
std::function<ir::value*()> continue_fn_; std::function<ir::value*()> continue_fn_;
std::map<value*, value**> current_phi_; std::map<value*, value**> current_phi_;
std::stack<scope> scopes_; std::stack<scope> scopes_;
std::vector<ir::alloc_const*> allocs_;
}; };
} }

View File

@@ -165,9 +165,8 @@ private:
public: public:
// accessors // accessors
unsigned get_address_space() const { return address_space_; } unsigned get_address_space() const { return address_space_; }
type *get_element_ty() const { return contained_tys_[0]; } type *get_element_ty() const { return contained_tys_[0]; }
// factory methods // factory methods
static pointer_type* get(type *ty, unsigned address_space); static pointer_type* get(type *ty, unsigned address_space);

View File

@@ -188,7 +188,7 @@ std::vector<STORAGE_SPEC_T> storage_declaration_specifier::storage() const {
/* Parameter */ /* Parameter */
ir::type* parameter::type(ir::module *mod) const { ir::type* parameter::type(ir::module *mod) const {
return decl_->type(mod, spec_->type(mod)); return decl_->type(mod, spec_->type(mod), {});
} }
std::vector<STORAGE_SPEC_T> parameter::storage() const { std::vector<STORAGE_SPEC_T> parameter::storage() const {
@@ -200,14 +200,14 @@ const identifier *parameter::id() const {
} }
/* Declarators */ /* Declarators */
ir::type* declarator::type(ir::module *mod, ir::type *type) const{ ir::type* declarator::type(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const{
if(ptr_) if(ptr_)
return type_impl(mod, ptr_->type(mod, type)); return type_impl(mod, ptr_->type(mod, type, storage), storage);
return type_impl(mod, type); return type_impl(mod, type, storage);
} }
// Identifier // Identifier
ir::type* identifier::type_impl(ir::module *, ir::type *type) const{ ir::type* identifier::type_impl(ir::module *, ir::type *type, storage_spec_vec_const_ref_t) const{
return type; return type;
} }
@@ -216,7 +216,7 @@ const std::string &identifier::name() const{
} }
// Tile // Tile
ir::type* tile::type_impl(ir::module *mod, ir::type *type) const{ ir::type* tile::type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t) const{
ir::type::tile_shapes_t shapes; ir::type::tile_shapes_t shapes;
for(expression *expr: shapes_->values()){ for(expression *expr: shapes_->values()){
ir::constant_int *shape = dynamic_cast<ir::constant_int*>(expr->codegen(mod)); ir::constant_int *shape = dynamic_cast<ir::constant_int*>(expr->codegen(mod));
@@ -228,8 +228,9 @@ ir::type* tile::type_impl(ir::module *mod, ir::type *type) const{
// Pointer // Pointer
ir::type* pointer::type_impl(ir::module*, ir::type *type) const{ ir::type* pointer::type_impl(ir::module*, ir::type *type, storage_spec_vec_const_ref_t storage) const{
return ir::pointer_type::get(type, 1); bool is_ptr_to_const = std::find(storage.begin(), storage.end(), CONSTANT_SPACE_T) != storage.end();
return ir::pointer_type::get(type, is_ptr_to_const?4:1);
} }
// Function // Function
@@ -247,7 +248,7 @@ void function::bind_parameters(ir::module *mod, ir::function *fn) const{
} }
} }
ir::type* function::type_impl(ir::module* mod, ir::type *type) const{ ir::type* function::type_impl(ir::module* mod, ir::type *type, storage_spec_vec_const_ref_t) const{
std::vector<ir::type*> types; std::vector<ir::type*> types;
for(parameter* param: args_->values()) for(parameter* param: args_->values())
types.push_back(param->type(mod)); types.push_back(param->type(mod));
@@ -265,7 +266,7 @@ ir::attribute_t get_ir_attr(STORAGE_SPEC_T spec){
} }
ir::value* function_definition::codegen(ir::module *mod) const{ ir::value* function_definition::codegen(ir::module *mod) const{
ir::function_type *prototype = (ir::function_type*)header_->type(mod, spec_->type(mod)); ir::function_type *prototype = (ir::function_type*)header_->type(mod, spec_->type(mod), spec_->storage());
const std::string &name = header_->id()->name(); const std::string &name = header_->id()->name();
ir::function *fn = mod->get_or_insert_function(name, prototype); ir::function *fn = mod->get_or_insert_function(name, prototype);
for(unsigned i = 0; i < header_->get_num_args(); i++){ for(unsigned i = 0; i < header_->get_num_args(); i++){
@@ -397,8 +398,8 @@ ir::value* declaration::codegen(ir::module* mod) const{
} }
/* Initializer */ /* Initializer */
ir::type* initializer::type_impl(ir::module *mod, ir::type *type) const{ ir::type* initializer::type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const{
return decl_->type(mod, type); return decl_->type(mod, type, storage);
} }
void initializer::set_specifier(const declaration_specifier *spec) { void initializer::set_specifier(const declaration_specifier *spec) {
@@ -406,8 +407,8 @@ void initializer::set_specifier(const declaration_specifier *spec) {
} }
ir::value* initializer::codegen(ir::module * mod) const{ ir::value* initializer::codegen(ir::module * mod) const{
ir::type *ty = decl_->type(mod, spec_->type(mod));
std::vector<STORAGE_SPEC_T> storage = spec_->storage(); std::vector<STORAGE_SPEC_T> storage = spec_->storage();
ir::type *ty = decl_->type(mod, spec_->type(mod), storage);
std::string name = decl_->id()->name(); std::string name = decl_->id()->name();
ir::value *value = ir::undef_value::get(ty); ir::value *value = ir::undef_value::get(ty);
if(std::find(storage.begin(), storage.end(), TUNABLE_T) != storage.end()){ if(std::find(storage.begin(), storage.end(), TUNABLE_T) != storage.end()){
@@ -423,6 +424,8 @@ ir::value* initializer::codegen(ir::module * mod) const{
value->set_name(name); value->set_name(name);
mod->set_value(name, value); mod->set_value(name, value);
mod->get_scope().types[name] = ty; mod->get_scope().types[name] = ty;
if(auto *x = dynamic_cast<ir::alloc_const*>(value))
mod->add_alloc(x);
if(std::find(storage.begin(), storage.end(), CONST_T) != storage.end()) if(std::find(storage.begin(), storage.end(), CONST_T) != storage.end())
mod->set_const(name); mod->set_const(name);
return value; return value;
@@ -523,13 +526,21 @@ ir::value* binary_operator::codegen(ir::module *mod) const{
/* Builtin expression */ /* Builtin expression */
// alloc constant
ir::value* alloc_const::codegen(ir::module *mod) const {
ir::type *ty = spec_->type(mod);
ir::constant_int *size = (ir::constant_int*)size_->codegen(mod);
ir::alloc_const *res = new ir::alloc_const(ty, size);
return res;
}
// get_global_range // get_global_range
ir::value* get_global_range::codegen(ir::module *mod) const { ir::value* get_global_range::codegen(ir::module *mod) const {
ir::builder &builder = mod->get_builder(); ir::builder &builder = mod->get_builder();
return builder.create_get_global_range(axis_->value(), (ir::constant_int*)size_->codegen(mod)); return builder.create_get_global_range(axis_->value(), (ir::constant_int*)size_->codegen(mod));
} }
// matmul
ir::value* matmul_expression::codegen(ir::module *mod) const { ir::value* matmul_expression::codegen(ir::module *mod) const {
ir::value *A = A_->codegen(mod); ir::value *A = A_->codegen(mod);
ir::value *B = B_->codegen(mod); ir::value *B = B_->codegen(mod);
@@ -666,7 +677,7 @@ ir::value *assignment_expression::codegen(ir::module *mod) const{
/* Type name */ /* Type name */
ir::type *type_name::type(ir::module *mod) const{ ir::type *type_name::type(ir::module *mod) const{
return decl_->type(mod, spec_->type(mod)); return decl_->type(mod, spec_->type(mod), {});
} }
/* String literal */ /* String literal */
@@ -693,6 +704,9 @@ ir::value* constant_range::codegen(ir::module *mod) const{
/* Named */ /* Named */
ir::value* named_expression::codegen(ir::module *mod) const{ ir::value* named_expression::codegen(ir::module *mod) const{
const std::string &name = id()->name(); const std::string &name = id()->name();
const auto& declarations = mod->get_scope().types;
if(declarations.find(name) == declarations.end())
throw std::runtime_error("variable " + name + " not declared");
return mod->get_value(name); return mod->get_value(name);
} }

View File

@@ -315,6 +315,16 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir
throw std::runtime_error("unknown conversion from ir::instruction to Instruction"); throw std::runtime_error("unknown conversion from ir::instruction to Instruction");
} }
/* convert ir::alloc_const to llvm::GlobalVariable */
Value* selection::llvm_alloc_const(ir::alloc_const *v, Module *module, IRBuilder<> &builder) {
unsigned size = ((ir::constant_int*)v->get_operand(0))->get_value();
Type *element_ty = llvm_type(v->get_type()->get_pointer_element_ty(), module->getContext());
Type *array_ty = llvm::ArrayType::get(element_ty, size);
Value *array = new llvm::GlobalVariable(*module, array_ty, false, llvm::GlobalVariable::ExternalLinkage,
nullptr, v->get_name(), nullptr, llvm::GlobalVariable::NotThreadLocal, 4);
return builder.CreateBitCast(array, element_ty->getPointerTo(4));
}
/* convert ir::value to llvm::Value */ /* convert ir::value to llvm::Value */
Value* selection::llvm_value(ir::value *v, IRBuilder<> &builder) { Value* selection::llvm_value(ir::value *v, IRBuilder<> &builder) {
assert(!v->get_type()->is_tile_ty()); assert(!v->get_type()->is_tile_ty());
@@ -324,6 +334,20 @@ Value* selection::llvm_value(ir::value *v, IRBuilder<> &builder) {
// create operands // create operands
if(auto *cc = dynamic_cast<ir::constant*>(v)) if(auto *cc = dynamic_cast<ir::constant*>(v))
return llvm_constant(cc, ctx); return llvm_constant(cc, ctx);
// alloc const
if(auto *cc = dynamic_cast<ir::alloc_const*>(v)){
BasicBlock *block = builder.GetInsertBlock();
Module *module = block->getModule();
unsigned size = ((ir::constant_int*)cc->get_operand(0))->get_value();
Type *element_ty = llvm_type(cc->get_type()->get_pointer_element_ty(), ctx);
Type *array_ty = llvm::ArrayType::get(element_ty, size);
if(vmap_.find(v) == vmap_.end()){
Value *array = new llvm::GlobalVariable(*module, array_ty, false, llvm::GlobalVariable::ExternalLinkage,
nullptr, cc->get_name(), nullptr, llvm::GlobalVariable::NotThreadLocal, 4);
vmap_[v] = builder.CreateBitCast(array, array->getType()->getArrayElementType()->getPointerTo(4));
}
return vmap_.at(v);
}
// instruction // instruction
if(auto *ii = dynamic_cast<ir::instruction*>(v)){ if(auto *ii = dynamic_cast<ir::instruction*>(v)){
auto value = [&](ir::value *x) { return llvm_value(x, builder); }; auto value = [&](ir::value *x) { return llvm_value(x, builder); };
@@ -755,11 +779,22 @@ inline llvm::Attribute::AttrKind llvm_attr(ir::attribute_t attr) {
} }
} }
ArrayType* selection::llvm_linearized_tile_type(ir::type *ty, LLVMContext &ctx) {
unsigned size = 1;
for(ir::constant_int* shape: ty->get_tile_shapes())
size *= shape->get_value();
return ArrayType::get(llvm_type(ty->get_scalar_ty(), ctx), size);
}
void selection::run(ir::module &src, Module &dst){ void selection::run(ir::module &src, Module &dst){
vmap_.clear(); vmap_.clear();
LLVMContext &dst_ctx = dst.getContext(); LLVMContext &dst_ctx = dst.getContext();
IRBuilder<> dst_builder(dst_ctx); IRBuilder<> dst_builder(dst_ctx);
for(ir::alloc_const *x: src.allocs()) {
vmap_[x] = llvm_alloc_const(x, &dst, dst_builder);
}
// iterate over functions // iterate over functions
for(ir::function *fn: src.get_function_list()) { for(ir::function *fn: src.get_function_list()) {
// create LLVM function // create LLVM function
@@ -795,7 +830,7 @@ void selection::run(ir::module &src, Module &dst){
ArrayType *array_ty = ArrayType::get(int_8_ty, alloc_size); ArrayType *array_ty = ArrayType::get(int_8_ty, alloc_size);
Type *ptr_ty = PointerType::get(int_8_ty, 3); Type *ptr_ty = PointerType::get(int_8_ty, 3);
GlobalVariable *sh_mem_array = GlobalVariable *sh_mem_array =
new GlobalVariable(*dst_fn->getParent(), array_ty, false, GlobalVariable::ExternalLinkage, new GlobalVariable(dst, array_ty, false, GlobalVariable::ExternalLinkage,
nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3); nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3);
sh_mem_ptr = dst_builder.CreateBitCast(sh_mem_array, ptr_ty); sh_mem_ptr = dst_builder.CreateBitCast(sh_mem_array, ptr_ty);
} }

View File

@@ -135,5 +135,12 @@ global_object::global_object(type *ty, unsigned num_ops,
: global_value(ty, num_ops, linkage, name, addr_space) { } : global_value(ty, num_ops, linkage, name, addr_space) { }
/* alloc const */
alloc_const::alloc_const(type *ty, constant_int *size, const std::string &name)
: global_object(ty, 1, global_value::external, name, 4) {
set_operand(0, size);
}
} }
} }