[abstract syntax tree] improved the grammar

This commit is contained in:
Philippe Tillet
2019-03-05 21:03:19 -05:00
parent 4189e130bf
commit 20ff9543ac
5 changed files with 53 additions and 77 deletions

View File

@@ -38,61 +38,47 @@ extern translation_unit *ast_root;
const char src[] = const char src[] =
"\ "\
__constant__ int32* delta = alloc_const int32[16];\
__constant__ int32* masks = alloc_const int32[16];\
\
const tunable int32 TM;\ const tunable int32 TM;\
const tunable int32 TN;\ const tunable int32 TN;\
const tunable int32 TK;\ const tunable int32 TK;\
\ \
void matmul(restrict readonly fp32 *a, restrict readonly fp32 *b, fp32 *c, int32 M, int32 N, int32 K, int32 bound){\ void matmul(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c,\
int32 M, int32 N, int32 K, int32 bound){\
int32 rxa[TM] = get_global_range[TM](0);\ int32 rxa[TM] = get_global_range[TM](0);\
int32 ryb[TN] = get_global_range[TN](1);\ int32 ryb[TN] = get_global_range[TN](1);\
int32 rka[TK] = 0 ... TK;\ int32 rka[TK] = 0 ... TK;\
int32 rkb[TK] = 0 ... TK;\ int32 rkb[TK] = 0 ... TK;\
int32 rxc[TM];\
int32 ryc[TN];\
fp32 C[TM, TN] = 0;\ fp32 C[TM, TN] = 0;\
int32 k;\
fp32* pa[TM, TK] = a + rka[newaxis, :]*M + rxa[:, newaxis];\ fp32* pa[TM, TK] = a + rka[newaxis, :]*M + rxa[:, newaxis];\
fp32* pb[TN, TK] = b + rkb[newaxis, :]*K + ryb[:, newaxis];\ fp32* pb[TN, TK] = b + rkb[newaxis, :]*K + ryb[:, newaxis];\
fp32* pc[TM, TN];\
fp32 a[TM, TK] = *pa;\ fp32 a[TM, TK] = *pa;\
fp32 b[TN, TK] = *pb;\ fp32 b[TN, TK] = *pb;\
int1 checkc0[TM];\ for(int32 k = K; k > 0;){\
int1 checkc1[TN];\ C = dot(a, b, C);\
int1 checkc[TM, TN];\ pa = pa + TK*M;\
for(k = K; k > 0; k = k - TK){\ pb = pb + TK*K;\
int1 checka[TM, TK];\ k = k - TK;\
int1 checkb[TN, TK];\ int1 checka[TM, TK] = k > bound;\
int1 checka0[TM];\ int1 checkb[TN, TK] = k > bound;\
int1 checka1[TK];\ @checka a = *pa;\
int1 checkb0[TN];\ @checkb b = *pb;\
int1 checkb1[TK];\ if(k > bound)\
C = dot(a, b, C);\ continue;\
pa = pa + TK*M;\ int1 checka0[TM] = rxa < M;\
pb = pb + TK*K;\ int1 checka1[TK] = rka < k;\
checka = k > bound;\ int1 checkb0[TN] = ryb < N;\
checkb = k > bound;\ int1 checkb1[TK] = rkb < k;\
@checka a = *pa;\ checka = checka0[:, newaxis] && checka1[newaxis, :];\
@checkb b = *pb;\ checkb = checkb0[:, newaxis] && checkb1[newaxis, :];\
if(k > bound)\ a = checka ? *pa : 0;\
continue;\ b = checkb ? *pb : 0;\
checka0 = rxa < M;\ }\
checka1 = rka < k;\ int32 rxc[TM] = get_global_range[TM](0);\
checkb0 = ryb < N;\ int32 ryc[TN] = get_global_range[TN](1);\
checkb1 = rkb < k;\ fp32* pc[TM, TN] = c + ryc[newaxis, :]*M + rxc[:, newaxis];\
checka = checka0[:, newaxis] && checka1[newaxis, :];\ int1 checkc0[TM] = rxc < M;\
checkb = checkb0[:, newaxis] && checkb1[newaxis, :];\ int1 checkc1[TN] = ryc < N;\
a = checka ? *pa : 0;\ int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];\
b = checkb ? *pb : 0;\
}\
rxc = get_global_range[TM](0);\
ryc = get_global_range[TN](1);\
pc = c + ryc[newaxis, :]*M + rxc[:, newaxis];\
checkc0 = rxc < M;\
checkc1 = ryc < N;\
checkc = checkc0[:, newaxis] && checkc1[newaxis, :];\
@checkc *pc = C;\ @checkc *pc = C;\
}\ }\
"; ";

View File

@@ -323,7 +323,10 @@ public:
class initializer; class initializer;
class declaration_specifier; class declaration_specifier;
class declaration: public node{ class block_item: public node{
};
class declaration: public block_item{
public: public:
declaration(node *spec, node *init) declaration(node *spec, node *init)
: spec_((declaration_specifier*)spec), init_((list<initializer*>*)init) { } : spec_((declaration_specifier*)spec), init_((list<initializer*>*)init) { }
@@ -335,10 +338,7 @@ public:
const list<initializer*> *init_; const list<initializer*> *init_;
}; };
class statement: public node{ class statement: public block_item{
private:
expression *pred_;
}; };
class expression_statement: public statement{ class expression_statement: public statement{
@@ -353,19 +353,19 @@ private:
expression *mask_; expression *mask_;
}; };
class compound_statement: public statement{ class compound_statement: public statement{
typedef list<declaration*>* declarations_t; typedef list<declaration*>* declarations_t;
typedef list<statement*>* statements_t; typedef list<statement*>* statements_t;
public: public:
compound_statement(node* decls, node* statements) compound_statement(node* items)
: decls_((declarations_t)decls), statements_((statements_t)statements) {} : items_((list<block_item*>*)items){}
ir::value* codegen(ir::module * mod) const; ir::value* codegen(ir::module * mod) const;
private: private:
declarations_t decls_; list<block_item*>* items_;
statements_t statements_;
}; };
class selection_statement: public statement{ class selection_statement: public statement{
@@ -413,7 +413,6 @@ class no_op: public statement { };
// Types // Types
class declaration_specifier: public node{ class declaration_specifier: public node{
public: public:
using node::node;
virtual ir::type* type(ir::module *mod) const = 0; virtual ir::type* type(ir::module *mod) const = 0;
virtual std::vector<STORAGE_SPEC_T> storage() const = 0; virtual std::vector<STORAGE_SPEC_T> storage() const = 0;
}; };

View File

@@ -275,21 +275,16 @@ statement
; ;
compound_statement compound_statement
: '{' '}' { $$ = new compound_statement(nullptr, nullptr); } : '{' '}' { $$ = new compound_statement(nullptr); }
| '{' statement_list '}' { $$ = new compound_statement(nullptr, $2); } | '{' block_item_list '}' { $$ = new compound_statement($2); }
| '{' declaration_list '}' { $$ = new compound_statement($2, nullptr); }
| '{' declaration_list statement_list '}' { $$ = new compound_statement($2, $3);}
;
block_item_list
: block_item { $$ = new list<block_item*>((block_item*)$1); }
| block_item_list block_item { $$ = append_ptr_list<block_item>($1, $2); }
declaration_list block_item
: declaration { $$ = new list<declaration*>((declaration*)$1); } : declaration { $$ = $1; }
| declaration_list declaration { $$ = append_ptr_list<declaration>($1, $2); } | statement { $$ = $1; }
statement_list
: statement { $$ = new list<statement*>((statement*)$1); }
| statement_list statement { $$ = append_ptr_list<statement>($1, $2); }
;
expression_statement expression_statement
: ';' { $$ = new no_op(); } : ';' { $$ = new no_op(); }
@@ -304,6 +299,8 @@ selection_statement
iteration_statement iteration_statement
: FOR '(' expression_statement expression_statement expression ')' statement { $$ = new iteration_statement($3, $4, $5, $7); } : FOR '(' expression_statement expression_statement expression ')' statement { $$ = new iteration_statement($3, $4, $5, $7); }
| FOR '(' declaration expression_statement ')' statement { $$ = new iteration_statement($3, $4, nullptr, $6); }
| FOR '(' declaration expression_statement expression ')' statement { $$ = new iteration_statement($3, $4, $5, $7); }
; ;
jump_statement jump_statement

View File

@@ -16,8 +16,8 @@ IS (u|U|l|L)*
"tunable" { return(TUNABLE); } "tunable" { return(TUNABLE); }
"kernel" { return(KERNEL); } "kernel" { return(KERNEL); }
"restrict" { return(RESTRICT); } "restrict" { return(RESTRICT); }
"readonly" { return(READONLY); } "read_only" { return(READONLY); }
"writeonly" { return(WRITEONLY); } "write_only" { return(WRITEONLY); }
"@" { return(AT); } "@" { return(AT); }
"newaxis" { return(NEWAXIS); } "newaxis" { return(NEWAXIS); }
"if" { return(IF); } "if" { return(IF); }

View File

@@ -287,15 +287,8 @@ ir::value* function_definition::codegen(ir::module *mod) const{
/* Statements */ /* Statements */
ir::value* compound_statement::codegen(ir::module* mod) const{ ir::value* compound_statement::codegen(ir::module* mod) const{
mod->add_new_scope(); mod->add_new_scope();
if(decls_) if(items_)
decls_->codegen(mod); items_->codegen(mod);
if(statements_){
for(statement *stmt: statements_->values()){
ir::value *current = stmt->codegen(mod);
if(is_terminator(current))
return current;
}
}
mod->pop_scope(); mod->pop_scope();
return nullptr; return nullptr;
} }
@@ -333,7 +326,8 @@ ir::value* iteration_statement::codegen(ir::module *mod) const{
ir::basic_block *loop_bb = ir::basic_block::create(ctx, "loop", fn); ir::basic_block *loop_bb = ir::basic_block::create(ctx, "loop", fn);
ir::basic_block *next_bb = ir::basic_block::create(ctx, "postloop", fn); ir::basic_block *next_bb = ir::basic_block::create(ctx, "postloop", fn);
mod->set_continue_fn([&](){ mod->set_continue_fn([&](){
exec_->codegen(mod); if(exec_)
exec_->codegen(mod);
ir::value *cond = stop_->codegen(mod); ir::value *cond = stop_->codegen(mod);
return builder.create_cond_br(cond, loop_bb, next_bb); return builder.create_cond_br(cond, loop_bb, next_bb);
}); });