[abstract syntax tree] improved the grammar
This commit is contained in:
@@ -38,61 +38,47 @@ extern translation_unit *ast_root;
|
|||||||
|
|
||||||
const char src[] =
|
const char src[] =
|
||||||
"\
|
"\
|
||||||
__constant__ int32* delta = alloc_const int32[16];\
|
|
||||||
__constant__ int32* masks = alloc_const int32[16];\
|
|
||||||
\
|
|
||||||
const tunable int32 TM;\
|
const tunable int32 TM;\
|
||||||
const tunable int32 TN;\
|
const tunable int32 TN;\
|
||||||
const tunable int32 TK;\
|
const tunable int32 TK;\
|
||||||
\
|
\
|
||||||
void matmul(restrict readonly fp32 *a, restrict readonly fp32 *b, fp32 *c, int32 M, int32 N, int32 K, int32 bound){\
|
void matmul(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c,\
|
||||||
|
int32 M, int32 N, int32 K, int32 bound){\
|
||||||
int32 rxa[TM] = get_global_range[TM](0);\
|
int32 rxa[TM] = get_global_range[TM](0);\
|
||||||
int32 ryb[TN] = get_global_range[TN](1);\
|
int32 ryb[TN] = get_global_range[TN](1);\
|
||||||
int32 rka[TK] = 0 ... TK;\
|
int32 rka[TK] = 0 ... TK;\
|
||||||
int32 rkb[TK] = 0 ... TK;\
|
int32 rkb[TK] = 0 ... TK;\
|
||||||
int32 rxc[TM];\
|
|
||||||
int32 ryc[TN];\
|
|
||||||
fp32 C[TM, TN] = 0;\
|
fp32 C[TM, TN] = 0;\
|
||||||
int32 k;\
|
|
||||||
fp32* pa[TM, TK] = a + rka[newaxis, :]*M + rxa[:, newaxis];\
|
fp32* pa[TM, TK] = a + rka[newaxis, :]*M + rxa[:, newaxis];\
|
||||||
fp32* pb[TN, TK] = b + rkb[newaxis, :]*K + ryb[:, newaxis];\
|
fp32* pb[TN, TK] = b + rkb[newaxis, :]*K + ryb[:, newaxis];\
|
||||||
fp32* pc[TM, TN];\
|
|
||||||
fp32 a[TM, TK] = *pa;\
|
fp32 a[TM, TK] = *pa;\
|
||||||
fp32 b[TN, TK] = *pb;\
|
fp32 b[TN, TK] = *pb;\
|
||||||
int1 checkc0[TM];\
|
for(int32 k = K; k > 0;){\
|
||||||
int1 checkc1[TN];\
|
C = dot(a, b, C);\
|
||||||
int1 checkc[TM, TN];\
|
pa = pa + TK*M;\
|
||||||
for(k = K; k > 0; k = k - TK){\
|
pb = pb + TK*K;\
|
||||||
int1 checka[TM, TK];\
|
k = k - TK;\
|
||||||
int1 checkb[TN, TK];\
|
int1 checka[TM, TK] = k > bound;\
|
||||||
int1 checka0[TM];\
|
int1 checkb[TN, TK] = k > bound;\
|
||||||
int1 checka1[TK];\
|
@checka a = *pa;\
|
||||||
int1 checkb0[TN];\
|
@checkb b = *pb;\
|
||||||
int1 checkb1[TK];\
|
if(k > bound)\
|
||||||
C = dot(a, b, C);\
|
continue;\
|
||||||
pa = pa + TK*M;\
|
int1 checka0[TM] = rxa < M;\
|
||||||
pb = pb + TK*K;\
|
int1 checka1[TK] = rka < k;\
|
||||||
checka = k > bound;\
|
int1 checkb0[TN] = ryb < N;\
|
||||||
checkb = k > bound;\
|
int1 checkb1[TK] = rkb < k;\
|
||||||
@checka a = *pa;\
|
checka = checka0[:, newaxis] && checka1[newaxis, :];\
|
||||||
@checkb b = *pb;\
|
checkb = checkb0[:, newaxis] && checkb1[newaxis, :];\
|
||||||
if(k > bound)\
|
a = checka ? *pa : 0;\
|
||||||
continue;\
|
b = checkb ? *pb : 0;\
|
||||||
checka0 = rxa < M;\
|
}\
|
||||||
checka1 = rka < k;\
|
int32 rxc[TM] = get_global_range[TM](0);\
|
||||||
checkb0 = ryb < N;\
|
int32 ryc[TN] = get_global_range[TN](1);\
|
||||||
checkb1 = rkb < k;\
|
fp32* pc[TM, TN] = c + ryc[newaxis, :]*M + rxc[:, newaxis];\
|
||||||
checka = checka0[:, newaxis] && checka1[newaxis, :];\
|
int1 checkc0[TM] = rxc < M;\
|
||||||
checkb = checkb0[:, newaxis] && checkb1[newaxis, :];\
|
int1 checkc1[TN] = ryc < N;\
|
||||||
a = checka ? *pa : 0;\
|
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];\
|
||||||
b = checkb ? *pb : 0;\
|
|
||||||
}\
|
|
||||||
rxc = get_global_range[TM](0);\
|
|
||||||
ryc = get_global_range[TN](1);\
|
|
||||||
pc = c + ryc[newaxis, :]*M + rxc[:, newaxis];\
|
|
||||||
checkc0 = rxc < M;\
|
|
||||||
checkc1 = ryc < N;\
|
|
||||||
checkc = checkc0[:, newaxis] && checkc1[newaxis, :];\
|
|
||||||
@checkc *pc = C;\
|
@checkc *pc = C;\
|
||||||
}\
|
}\
|
||||||
";
|
";
|
||||||
|
@@ -323,7 +323,10 @@ public:
|
|||||||
class initializer;
|
class initializer;
|
||||||
class declaration_specifier;
|
class declaration_specifier;
|
||||||
|
|
||||||
class declaration: public node{
|
class block_item: public node{
|
||||||
|
};
|
||||||
|
|
||||||
|
class declaration: public block_item{
|
||||||
public:
|
public:
|
||||||
declaration(node *spec, node *init)
|
declaration(node *spec, node *init)
|
||||||
: spec_((declaration_specifier*)spec), init_((list<initializer*>*)init) { }
|
: spec_((declaration_specifier*)spec), init_((list<initializer*>*)init) { }
|
||||||
@@ -335,10 +338,7 @@ public:
|
|||||||
const list<initializer*> *init_;
|
const list<initializer*> *init_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class statement: public node{
|
class statement: public block_item{
|
||||||
|
|
||||||
private:
|
|
||||||
expression *pred_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
class expression_statement: public statement{
|
class expression_statement: public statement{
|
||||||
@@ -353,19 +353,19 @@ private:
|
|||||||
expression *mask_;
|
expression *mask_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
class compound_statement: public statement{
|
class compound_statement: public statement{
|
||||||
typedef list<declaration*>* declarations_t;
|
typedef list<declaration*>* declarations_t;
|
||||||
typedef list<statement*>* statements_t;
|
typedef list<statement*>* statements_t;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
compound_statement(node* decls, node* statements)
|
compound_statement(node* items)
|
||||||
: decls_((declarations_t)decls), statements_((statements_t)statements) {}
|
: items_((list<block_item*>*)items){}
|
||||||
|
|
||||||
ir::value* codegen(ir::module * mod) const;
|
ir::value* codegen(ir::module * mod) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
declarations_t decls_;
|
list<block_item*>* items_;
|
||||||
statements_t statements_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
class selection_statement: public statement{
|
class selection_statement: public statement{
|
||||||
@@ -413,7 +413,6 @@ class no_op: public statement { };
|
|||||||
// Types
|
// Types
|
||||||
class declaration_specifier: public node{
|
class declaration_specifier: public node{
|
||||||
public:
|
public:
|
||||||
using node::node;
|
|
||||||
virtual ir::type* type(ir::module *mod) const = 0;
|
virtual ir::type* type(ir::module *mod) const = 0;
|
||||||
virtual std::vector<STORAGE_SPEC_T> storage() const = 0;
|
virtual std::vector<STORAGE_SPEC_T> storage() const = 0;
|
||||||
};
|
};
|
||||||
|
@@ -275,21 +275,16 @@ statement
|
|||||||
;
|
;
|
||||||
|
|
||||||
compound_statement
|
compound_statement
|
||||||
: '{' '}' { $$ = new compound_statement(nullptr, nullptr); }
|
: '{' '}' { $$ = new compound_statement(nullptr); }
|
||||||
| '{' statement_list '}' { $$ = new compound_statement(nullptr, $2); }
|
| '{' block_item_list '}' { $$ = new compound_statement($2); }
|
||||||
| '{' declaration_list '}' { $$ = new compound_statement($2, nullptr); }
|
|
||||||
| '{' declaration_list statement_list '}' { $$ = new compound_statement($2, $3);}
|
|
||||||
;
|
|
||||||
|
|
||||||
|
block_item_list
|
||||||
|
: block_item { $$ = new list<block_item*>((block_item*)$1); }
|
||||||
|
| block_item_list block_item { $$ = append_ptr_list<block_item>($1, $2); }
|
||||||
|
|
||||||
declaration_list
|
block_item
|
||||||
: declaration { $$ = new list<declaration*>((declaration*)$1); }
|
: declaration { $$ = $1; }
|
||||||
| declaration_list declaration { $$ = append_ptr_list<declaration>($1, $2); }
|
| statement { $$ = $1; }
|
||||||
|
|
||||||
statement_list
|
|
||||||
: statement { $$ = new list<statement*>((statement*)$1); }
|
|
||||||
| statement_list statement { $$ = append_ptr_list<statement>($1, $2); }
|
|
||||||
;
|
|
||||||
|
|
||||||
expression_statement
|
expression_statement
|
||||||
: ';' { $$ = new no_op(); }
|
: ';' { $$ = new no_op(); }
|
||||||
@@ -304,6 +299,8 @@ selection_statement
|
|||||||
|
|
||||||
iteration_statement
|
iteration_statement
|
||||||
: FOR '(' expression_statement expression_statement expression ')' statement { $$ = new iteration_statement($3, $4, $5, $7); }
|
: FOR '(' expression_statement expression_statement expression ')' statement { $$ = new iteration_statement($3, $4, $5, $7); }
|
||||||
|
| FOR '(' declaration expression_statement ')' statement { $$ = new iteration_statement($3, $4, nullptr, $6); }
|
||||||
|
| FOR '(' declaration expression_statement expression ')' statement { $$ = new iteration_statement($3, $4, $5, $7); }
|
||||||
;
|
;
|
||||||
|
|
||||||
jump_statement
|
jump_statement
|
||||||
|
@@ -16,8 +16,8 @@ IS (u|U|l|L)*
|
|||||||
"tunable" { return(TUNABLE); }
|
"tunable" { return(TUNABLE); }
|
||||||
"kernel" { return(KERNEL); }
|
"kernel" { return(KERNEL); }
|
||||||
"restrict" { return(RESTRICT); }
|
"restrict" { return(RESTRICT); }
|
||||||
"readonly" { return(READONLY); }
|
"read_only" { return(READONLY); }
|
||||||
"writeonly" { return(WRITEONLY); }
|
"write_only" { return(WRITEONLY); }
|
||||||
"@" { return(AT); }
|
"@" { return(AT); }
|
||||||
"newaxis" { return(NEWAXIS); }
|
"newaxis" { return(NEWAXIS); }
|
||||||
"if" { return(IF); }
|
"if" { return(IF); }
|
||||||
|
@@ -287,15 +287,8 @@ ir::value* function_definition::codegen(ir::module *mod) const{
|
|||||||
/* Statements */
|
/* Statements */
|
||||||
ir::value* compound_statement::codegen(ir::module* mod) const{
|
ir::value* compound_statement::codegen(ir::module* mod) const{
|
||||||
mod->add_new_scope();
|
mod->add_new_scope();
|
||||||
if(decls_)
|
if(items_)
|
||||||
decls_->codegen(mod);
|
items_->codegen(mod);
|
||||||
if(statements_){
|
|
||||||
for(statement *stmt: statements_->values()){
|
|
||||||
ir::value *current = stmt->codegen(mod);
|
|
||||||
if(is_terminator(current))
|
|
||||||
return current;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
mod->pop_scope();
|
mod->pop_scope();
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
@@ -333,7 +326,8 @@ ir::value* iteration_statement::codegen(ir::module *mod) const{
|
|||||||
ir::basic_block *loop_bb = ir::basic_block::create(ctx, "loop", fn);
|
ir::basic_block *loop_bb = ir::basic_block::create(ctx, "loop", fn);
|
||||||
ir::basic_block *next_bb = ir::basic_block::create(ctx, "postloop", fn);
|
ir::basic_block *next_bb = ir::basic_block::create(ctx, "postloop", fn);
|
||||||
mod->set_continue_fn([&](){
|
mod->set_continue_fn([&](){
|
||||||
exec_->codegen(mod);
|
if(exec_)
|
||||||
|
exec_->codegen(mod);
|
||||||
ir::value *cond = stop_->codegen(mod);
|
ir::value *cond = stop_->codegen(mod);
|
||||||
return builder.create_cond_br(cond, loop_bb, next_bb);
|
return builder.create_cond_br(cond, loop_bb, next_bb);
|
||||||
});
|
});
|
||||||
|
Reference in New Issue
Block a user