[abstract syntax tree] improved the grammar
This commit is contained in:
@@ -38,61 +38,47 @@ extern translation_unit *ast_root;
|
||||
|
||||
const char src[] =
|
||||
"\
|
||||
__constant__ int32* delta = alloc_const int32[16];\
|
||||
__constant__ int32* masks = alloc_const int32[16];\
|
||||
\
|
||||
const tunable int32 TM;\
|
||||
const tunable int32 TN;\
|
||||
const tunable int32 TK;\
|
||||
\
|
||||
void matmul(restrict readonly fp32 *a, restrict readonly fp32 *b, fp32 *c, int32 M, int32 N, int32 K, int32 bound){\
|
||||
void matmul(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c,\
|
||||
int32 M, int32 N, int32 K, int32 bound){\
|
||||
int32 rxa[TM] = get_global_range[TM](0);\
|
||||
int32 ryb[TN] = get_global_range[TN](1);\
|
||||
int32 rka[TK] = 0 ... TK;\
|
||||
int32 rkb[TK] = 0 ... TK;\
|
||||
int32 rxc[TM];\
|
||||
int32 ryc[TN];\
|
||||
fp32 C[TM, TN] = 0;\
|
||||
int32 k;\
|
||||
fp32* pa[TM, TK] = a + rka[newaxis, :]*M + rxa[:, newaxis];\
|
||||
fp32* pb[TN, TK] = b + rkb[newaxis, :]*K + ryb[:, newaxis];\
|
||||
fp32* pc[TM, TN];\
|
||||
fp32 a[TM, TK] = *pa;\
|
||||
fp32 b[TN, TK] = *pb;\
|
||||
int1 checkc0[TM];\
|
||||
int1 checkc1[TN];\
|
||||
int1 checkc[TM, TN];\
|
||||
for(k = K; k > 0; k = k - TK){\
|
||||
int1 checka[TM, TK];\
|
||||
int1 checkb[TN, TK];\
|
||||
int1 checka0[TM];\
|
||||
int1 checka1[TK];\
|
||||
int1 checkb0[TN];\
|
||||
int1 checkb1[TK];\
|
||||
C = dot(a, b, C);\
|
||||
pa = pa + TK*M;\
|
||||
pb = pb + TK*K;\
|
||||
checka = k > bound;\
|
||||
checkb = k > bound;\
|
||||
@checka a = *pa;\
|
||||
@checkb b = *pb;\
|
||||
if(k > bound)\
|
||||
continue;\
|
||||
checka0 = rxa < M;\
|
||||
checka1 = rka < k;\
|
||||
checkb0 = ryb < N;\
|
||||
checkb1 = rkb < k;\
|
||||
checka = checka0[:, newaxis] && checka1[newaxis, :];\
|
||||
checkb = checkb0[:, newaxis] && checkb1[newaxis, :];\
|
||||
a = checka ? *pa : 0;\
|
||||
b = checkb ? *pb : 0;\
|
||||
}\
|
||||
rxc = get_global_range[TM](0);\
|
||||
ryc = get_global_range[TN](1);\
|
||||
pc = c + ryc[newaxis, :]*M + rxc[:, newaxis];\
|
||||
checkc0 = rxc < M;\
|
||||
checkc1 = ryc < N;\
|
||||
checkc = checkc0[:, newaxis] && checkc1[newaxis, :];\
|
||||
for(int32 k = K; k > 0;){\
|
||||
C = dot(a, b, C);\
|
||||
pa = pa + TK*M;\
|
||||
pb = pb + TK*K;\
|
||||
k = k - TK;\
|
||||
int1 checka[TM, TK] = k > bound;\
|
||||
int1 checkb[TN, TK] = k > bound;\
|
||||
@checka a = *pa;\
|
||||
@checkb b = *pb;\
|
||||
if(k > bound)\
|
||||
continue;\
|
||||
int1 checka0[TM] = rxa < M;\
|
||||
int1 checka1[TK] = rka < k;\
|
||||
int1 checkb0[TN] = ryb < N;\
|
||||
int1 checkb1[TK] = rkb < k;\
|
||||
checka = checka0[:, newaxis] && checka1[newaxis, :];\
|
||||
checkb = checkb0[:, newaxis] && checkb1[newaxis, :];\
|
||||
a = checka ? *pa : 0;\
|
||||
b = checkb ? *pb : 0;\
|
||||
}\
|
||||
int32 rxc[TM] = get_global_range[TM](0);\
|
||||
int32 ryc[TN] = get_global_range[TN](1);\
|
||||
fp32* pc[TM, TN] = c + ryc[newaxis, :]*M + rxc[:, newaxis];\
|
||||
int1 checkc0[TM] = rxc < M;\
|
||||
int1 checkc1[TN] = ryc < N;\
|
||||
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];\
|
||||
@checkc *pc = C;\
|
||||
}\
|
||||
";
|
||||
|
@@ -323,7 +323,10 @@ public:
|
||||
class initializer;
|
||||
class declaration_specifier;
|
||||
|
||||
class declaration: public node{
|
||||
class block_item: public node{
|
||||
};
|
||||
|
||||
class declaration: public block_item{
|
||||
public:
|
||||
declaration(node *spec, node *init)
|
||||
: spec_((declaration_specifier*)spec), init_((list<initializer*>*)init) { }
|
||||
@@ -335,10 +338,7 @@ public:
|
||||
const list<initializer*> *init_;
|
||||
};
|
||||
|
||||
class statement: public node{
|
||||
|
||||
private:
|
||||
expression *pred_;
|
||||
class statement: public block_item{
|
||||
};
|
||||
|
||||
class expression_statement: public statement{
|
||||
@@ -353,19 +353,19 @@ private:
|
||||
expression *mask_;
|
||||
};
|
||||
|
||||
|
||||
class compound_statement: public statement{
|
||||
typedef list<declaration*>* declarations_t;
|
||||
typedef list<statement*>* statements_t;
|
||||
|
||||
public:
|
||||
compound_statement(node* decls, node* statements)
|
||||
: decls_((declarations_t)decls), statements_((statements_t)statements) {}
|
||||
compound_statement(node* items)
|
||||
: items_((list<block_item*>*)items){}
|
||||
|
||||
ir::value* codegen(ir::module * mod) const;
|
||||
|
||||
private:
|
||||
declarations_t decls_;
|
||||
statements_t statements_;
|
||||
list<block_item*>* items_;
|
||||
};
|
||||
|
||||
class selection_statement: public statement{
|
||||
@@ -413,7 +413,6 @@ class no_op: public statement { };
|
||||
// Types
|
||||
class declaration_specifier: public node{
|
||||
public:
|
||||
using node::node;
|
||||
virtual ir::type* type(ir::module *mod) const = 0;
|
||||
virtual std::vector<STORAGE_SPEC_T> storage() const = 0;
|
||||
};
|
||||
|
@@ -275,21 +275,16 @@ statement
|
||||
;
|
||||
|
||||
compound_statement
|
||||
: '{' '}' { $$ = new compound_statement(nullptr, nullptr); }
|
||||
| '{' statement_list '}' { $$ = new compound_statement(nullptr, $2); }
|
||||
| '{' declaration_list '}' { $$ = new compound_statement($2, nullptr); }
|
||||
| '{' declaration_list statement_list '}' { $$ = new compound_statement($2, $3);}
|
||||
;
|
||||
: '{' '}' { $$ = new compound_statement(nullptr); }
|
||||
| '{' block_item_list '}' { $$ = new compound_statement($2); }
|
||||
|
||||
block_item_list
|
||||
: block_item { $$ = new list<block_item*>((block_item*)$1); }
|
||||
| block_item_list block_item { $$ = append_ptr_list<block_item>($1, $2); }
|
||||
|
||||
declaration_list
|
||||
: declaration { $$ = new list<declaration*>((declaration*)$1); }
|
||||
| declaration_list declaration { $$ = append_ptr_list<declaration>($1, $2); }
|
||||
|
||||
statement_list
|
||||
: statement { $$ = new list<statement*>((statement*)$1); }
|
||||
| statement_list statement { $$ = append_ptr_list<statement>($1, $2); }
|
||||
;
|
||||
block_item
|
||||
: declaration { $$ = $1; }
|
||||
| statement { $$ = $1; }
|
||||
|
||||
expression_statement
|
||||
: ';' { $$ = new no_op(); }
|
||||
@@ -304,6 +299,8 @@ selection_statement
|
||||
|
||||
iteration_statement
|
||||
: FOR '(' expression_statement expression_statement expression ')' statement { $$ = new iteration_statement($3, $4, $5, $7); }
|
||||
| FOR '(' declaration expression_statement ')' statement { $$ = new iteration_statement($3, $4, nullptr, $6); }
|
||||
| FOR '(' declaration expression_statement expression ')' statement { $$ = new iteration_statement($3, $4, $5, $7); }
|
||||
;
|
||||
|
||||
jump_statement
|
||||
|
@@ -16,8 +16,8 @@ IS (u|U|l|L)*
|
||||
"tunable" { return(TUNABLE); }
|
||||
"kernel" { return(KERNEL); }
|
||||
"restrict" { return(RESTRICT); }
|
||||
"readonly" { return(READONLY); }
|
||||
"writeonly" { return(WRITEONLY); }
|
||||
"read_only" { return(READONLY); }
|
||||
"write_only" { return(WRITEONLY); }
|
||||
"@" { return(AT); }
|
||||
"newaxis" { return(NEWAXIS); }
|
||||
"if" { return(IF); }
|
||||
|
@@ -287,15 +287,8 @@ ir::value* function_definition::codegen(ir::module *mod) const{
|
||||
/* Statements */
|
||||
ir::value* compound_statement::codegen(ir::module* mod) const{
|
||||
mod->add_new_scope();
|
||||
if(decls_)
|
||||
decls_->codegen(mod);
|
||||
if(statements_){
|
||||
for(statement *stmt: statements_->values()){
|
||||
ir::value *current = stmt->codegen(mod);
|
||||
if(is_terminator(current))
|
||||
return current;
|
||||
}
|
||||
}
|
||||
if(items_)
|
||||
items_->codegen(mod);
|
||||
mod->pop_scope();
|
||||
return nullptr;
|
||||
}
|
||||
@@ -333,7 +326,8 @@ ir::value* iteration_statement::codegen(ir::module *mod) const{
|
||||
ir::basic_block *loop_bb = ir::basic_block::create(ctx, "loop", fn);
|
||||
ir::basic_block *next_bb = ir::basic_block::create(ctx, "postloop", fn);
|
||||
mod->set_continue_fn([&](){
|
||||
exec_->codegen(mod);
|
||||
if(exec_)
|
||||
exec_->codegen(mod);
|
||||
ir::value *cond = stop_->codegen(mod);
|
||||
return builder.create_cond_br(cond, loop_bb, next_bb);
|
||||
});
|
||||
|
Reference in New Issue
Block a user