diff --git a/examples/matrix.cpp b/examples/matrix.cpp index 0baf844dc..1a7484ab0 100644 --- a/examples/matrix.cpp +++ b/examples/matrix.cpp @@ -38,61 +38,47 @@ extern translation_unit *ast_root; const char src[] = "\ -__constant__ int32* delta = alloc_const int32[16];\ -__constant__ int32* masks = alloc_const int32[16];\ -\ const tunable int32 TM;\ const tunable int32 TN;\ const tunable int32 TK;\ \ -void matmul(restrict readonly fp32 *a, restrict readonly fp32 *b, fp32 *c, int32 M, int32 N, int32 K, int32 bound){\ +void matmul(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c,\ + int32 M, int32 N, int32 K, int32 bound){\ int32 rxa[TM] = get_global_range[TM](0);\ int32 ryb[TN] = get_global_range[TN](1);\ int32 rka[TK] = 0 ... TK;\ int32 rkb[TK] = 0 ... TK;\ - int32 rxc[TM];\ - int32 ryc[TN];\ fp32 C[TM, TN] = 0;\ - int32 k;\ fp32* pa[TM, TK] = a + rka[newaxis, :]*M + rxa[:, newaxis];\ fp32* pb[TN, TK] = b + rkb[newaxis, :]*K + ryb[:, newaxis];\ - fp32* pc[TM, TN];\ fp32 a[TM, TK] = *pa;\ fp32 b[TN, TK] = *pb;\ - int1 checkc0[TM];\ - int1 checkc1[TN];\ - int1 checkc[TM, TN];\ - for(k = K; k > 0; k = k - TK){\ - int1 checka[TM, TK];\ - int1 checkb[TN, TK];\ - int1 checka0[TM];\ - int1 checka1[TK];\ - int1 checkb0[TN];\ - int1 checkb1[TK];\ - C = dot(a, b, C);\ - pa = pa + TK*M;\ - pb = pb + TK*K;\ - checka = k > bound;\ - checkb = k > bound;\ - @checka a = *pa;\ - @checkb b = *pb;\ - if(k > bound)\ - continue;\ - checka0 = rxa < M;\ - checka1 = rka < k;\ - checkb0 = ryb < N;\ - checkb1 = rkb < k;\ - checka = checka0[:, newaxis] && checka1[newaxis, :];\ - checkb = checkb0[:, newaxis] && checkb1[newaxis, :];\ - a = checka ? *pa : 0;\ - b = checkb ? *pb : 0;\ - }\ - rxc = get_global_range[TM](0);\ - ryc = get_global_range[TN](1);\ - pc = c + ryc[newaxis, :]*M + rxc[:, newaxis];\ - checkc0 = rxc < M;\ - checkc1 = ryc < N;\ - checkc = checkc0[:, newaxis] && checkc1[newaxis, :];\ + for(int32 k = K; k > 0;){\ + C = dot(a, b, C);\ + pa = pa + TK*M;\ + pb = pb + TK*K;\ + k = k - TK;\ + int1 checka[TM, TK] = k > bound;\ + int1 checkb[TN, TK] = k > bound;\ + @checka a = *pa;\ + @checkb b = *pb;\ + if(k > bound)\ + continue;\ + int1 checka0[TM] = rxa < M;\ + int1 checka1[TK] = rka < k;\ + int1 checkb0[TN] = ryb < N;\ + int1 checkb1[TK] = rkb < k;\ + checka = checka0[:, newaxis] && checka1[newaxis, :];\ + checkb = checkb0[:, newaxis] && checkb1[newaxis, :];\ + a = checka ? *pa : 0;\ + b = checkb ? *pb : 0;\ + }\ + int32 rxc[TM] = get_global_range[TM](0);\ + int32 ryc[TN] = get_global_range[TN](1);\ + fp32* pc[TM, TN] = c + ryc[newaxis, :]*M + rxc[:, newaxis];\ + int1 checkc0[TM] = rxc < M;\ + int1 checkc1[TN] = ryc < N;\ + int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];\ @checkc *pc = C;\ }\ "; diff --git a/include/triton/ast/ast.h b/include/triton/ast/ast.h index f511ac132..cc77f66b5 100644 --- a/include/triton/ast/ast.h +++ b/include/triton/ast/ast.h @@ -323,7 +323,10 @@ public: class initializer; class declaration_specifier; -class declaration: public node{ +class block_item: public node{ +}; + +class declaration: public block_item{ public: declaration(node *spec, node *init) : spec_((declaration_specifier*)spec), init_((list*)init) { } @@ -335,10 +338,7 @@ public: const list *init_; }; -class statement: public node{ - -private: - expression *pred_; +class statement: public block_item{ }; class expression_statement: public statement{ @@ -353,19 +353,19 @@ private: expression *mask_; }; + class compound_statement: public statement{ typedef list* declarations_t; typedef list* statements_t; public: - compound_statement(node* decls, node* statements) - : decls_((declarations_t)decls), statements_((statements_t)statements) {} + compound_statement(node* items) + : items_((list*)items){} ir::value* codegen(ir::module * mod) const; private: - declarations_t decls_; - statements_t statements_; + list* items_; }; class selection_statement: public statement{ @@ -413,7 +413,6 @@ class no_op: public statement { }; // Types class declaration_specifier: public node{ public: - using node::node; virtual ir::type* type(ir::module *mod) const = 0; virtual std::vector storage() const = 0; }; diff --git a/include/triton/ast/parser.y b/include/triton/ast/parser.y index acd31d995..724f4240b 100644 --- a/include/triton/ast/parser.y +++ b/include/triton/ast/parser.y @@ -275,21 +275,16 @@ statement ; compound_statement - : '{' '}' { $$ = new compound_statement(nullptr, nullptr); } - | '{' statement_list '}' { $$ = new compound_statement(nullptr, $2); } - | '{' declaration_list '}' { $$ = new compound_statement($2, nullptr); } - | '{' declaration_list statement_list '}' { $$ = new compound_statement($2, $3);} - ; + : '{' '}' { $$ = new compound_statement(nullptr); } + | '{' block_item_list '}' { $$ = new compound_statement($2); } +block_item_list + : block_item { $$ = new list((block_item*)$1); } + | block_item_list block_item { $$ = append_ptr_list($1, $2); } -declaration_list - : declaration { $$ = new list((declaration*)$1); } - | declaration_list declaration { $$ = append_ptr_list($1, $2); } - -statement_list - : statement { $$ = new list((statement*)$1); } - | statement_list statement { $$ = append_ptr_list($1, $2); } - ; +block_item + : declaration { $$ = $1; } + | statement { $$ = $1; } expression_statement : ';' { $$ = new no_op(); } @@ -304,6 +299,8 @@ selection_statement iteration_statement : FOR '(' expression_statement expression_statement expression ')' statement { $$ = new iteration_statement($3, $4, $5, $7); } + | FOR '(' declaration expression_statement ')' statement { $$ = new iteration_statement($3, $4, nullptr, $6); } + | FOR '(' declaration expression_statement expression ')' statement { $$ = new iteration_statement($3, $4, $5, $7); } ; jump_statement diff --git a/include/triton/ast/scanner.l b/include/triton/ast/scanner.l index 56cc777a7..e3bf32be0 100644 --- a/include/triton/ast/scanner.l +++ b/include/triton/ast/scanner.l @@ -16,8 +16,8 @@ IS (u|U|l|L)* "tunable" { return(TUNABLE); } "kernel" { return(KERNEL); } "restrict" { return(RESTRICT); } -"readonly" { return(READONLY); } -"writeonly" { return(WRITEONLY); } +"read_only" { return(READONLY); } +"write_only" { return(WRITEONLY); } "@" { return(AT); } "newaxis" { return(NEWAXIS); } "if" { return(IF); } diff --git a/lib/ast/lowering.cpp b/lib/ast/lowering.cpp index 49fe03206..db9c9ed2a 100644 --- a/lib/ast/lowering.cpp +++ b/lib/ast/lowering.cpp @@ -287,15 +287,8 @@ ir::value* function_definition::codegen(ir::module *mod) const{ /* Statements */ ir::value* compound_statement::codegen(ir::module* mod) const{ mod->add_new_scope(); - if(decls_) - decls_->codegen(mod); - if(statements_){ - for(statement *stmt: statements_->values()){ - ir::value *current = stmt->codegen(mod); - if(is_terminator(current)) - return current; - } - } + if(items_) + items_->codegen(mod); mod->pop_scope(); return nullptr; } @@ -333,7 +326,8 @@ ir::value* iteration_statement::codegen(ir::module *mod) const{ ir::basic_block *loop_bb = ir::basic_block::create(ctx, "loop", fn); ir::basic_block *next_bb = ir::basic_block::create(ctx, "postloop", fn); mod->set_continue_fn([&](){ - exec_->codegen(mod); + if(exec_) + exec_->codegen(mod); ir::value *cond = stop_->codegen(mod); return builder.create_cond_br(cond, loop_bb, next_bb); });