diff --git a/examples/matrix.cpp b/examples/matrix.cpp index 1a7484ab0..ec818ae58 100644 --- a/examples/matrix.cpp +++ b/examples/matrix.cpp @@ -36,52 +36,52 @@ extern void yy_delete_buffer(YY_BUFFER_STATE buffer); using triton::ast::translation_unit; extern translation_unit *ast_root; -const char src[] = -"\ -const tunable int32 TM;\ -const tunable int32 TN;\ -const tunable int32 TK;\ -\ -void matmul(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c,\ - int32 M, int32 N, int32 K, int32 bound){\ - int32 rxa[TM] = get_global_range[TM](0);\ - int32 ryb[TN] = get_global_range[TN](1);\ - int32 rka[TK] = 0 ... TK;\ - int32 rkb[TK] = 0 ... TK;\ - fp32 C[TM, TN] = 0;\ - fp32* pa[TM, TK] = a + rka[newaxis, :]*M + rxa[:, newaxis];\ - fp32* pb[TN, TK] = b + rkb[newaxis, :]*K + ryb[:, newaxis];\ - fp32 a[TM, TK] = *pa;\ - fp32 b[TN, TK] = *pb;\ - for(int32 k = K; k > 0;){\ - C = dot(a, b, C);\ - pa = pa + TK*M;\ - pb = pb + TK*K;\ - k = k - TK;\ - int1 checka[TM, TK] = k > bound;\ - int1 checkb[TN, TK] = k > bound;\ - @checka a = *pa;\ - @checkb b = *pb;\ - if(k > bound)\ - continue;\ - int1 checka0[TM] = rxa < M;\ - int1 checka1[TK] = rka < k;\ - int1 checkb0[TN] = ryb < N;\ - int1 checkb1[TK] = rkb < k;\ - checka = checka0[:, newaxis] && checka1[newaxis, :];\ - checkb = checkb0[:, newaxis] && checkb1[newaxis, :];\ - a = checka ? *pa : 0;\ - b = checkb ? *pb : 0;\ - }\ - int32 rxc[TM] = get_global_range[TM](0);\ - int32 ryc[TN] = get_global_range[TN](1);\ - fp32* pc[TM, TN] = c + ryc[newaxis, :]*M + rxc[:, newaxis];\ - int1 checkc0[TM] = rxc < M;\ - int1 checkc1[TN] = ryc < N;\ - int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];\ - @checkc *pc = C;\ -}\ -"; +const char* src = +R"( +const tunable int32 TM; +const tunable int32 TN; +const tunable int32 TK; + +void matmul(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c, + int32 M, int32 N, int32 K, int32 bound){ + int32 rxa[TM] = get_global_range[TM](0) + int32 ryb[TN] = get_global_range[TN](1); + int32 rka[TK] = 0 ... TK; + int32 rkb[TK] = 0 ... TK; + fp32 C[TM, TN] = 0; + fp32* pa[TM, TK] = a + rka[newaxis, :]*M + rxa[:, newaxis]; + fp32* pb[TN, TK] = b + rkb[newaxis, :]*K + ryb[:, newaxis]; + fp32 a[TM, TK] = *pa; + fp32 b[TN, TK] = *pb; + for(int32 k = K; k > 0;){ + C = dot(a, b, C); + pa = pa + TK*M; + pb = pb + TK*K; + k = k - TK; + int1 checka[TM, TK] = k > bound; + int1 checkb[TN, TK] = k > bound; + @checka a = *pa; + @checkb b = *pb; + if(k > bound) + continue; + int1 checka0[TM] = rxa < M; + int1 checka1[TK] = rka < k; + int1 checkb0[TN] = ryb < N; + int1 checkb1[TK] = rkb < k; + checka = checka0[:, newaxis] && checka1[newaxis, :]; + checkb = checkb0[:, newaxis] && checkb1[newaxis, :]; + a = checka ? *pa : 0; + b = checkb ? *pb : 0; + } + int32 rxc[TM] = get_global_range[TM](0); + int32 ryc[TN] = get_global_range[TN](1); + fp32* pc[TM, TN] = c + ryc[newaxis, :]*M + rxc[:, newaxis]; + int1 checkc0[TM] = rxc < M; + int1 checkc1[TN] = ryc < N; + int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :]; + @checkc *pc = C; +} +)"; static std::string compute_data_layout(bool is64Bit, bool UseShortPointers) { std::string Ret = "e"; diff --git a/include/triton/ast/ast.h b/include/triton/ast/ast.h index cc77f66b5..b286c5a79 100644 --- a/include/triton/ast/ast.h +++ b/include/triton/ast/ast.h @@ -599,6 +599,12 @@ private: list decls_; }; +void update_location(const char *t); +void print_error(const char *error); +char return_impl(char t, const char * yytext); +yytokentype return_impl(yytokentype t, const char * yytext); +void return_void(const char * yytext); + } } diff --git a/include/triton/ast/parser.y b/include/triton/ast/parser.y index 724f4240b..8ce55f372 100644 --- a/include/triton/ast/parser.y +++ b/include/triton/ast/parser.y @@ -1,3 +1,5 @@ +%define parse.error verbose + %{ namespace triton{ namespace ast{ @@ -8,7 +10,6 @@ using namespace triton::ast; #define YYSTYPE node* #include "../include/triton/ast/ast.h" -#define YYERROR_VERBOSE 1 extern char* yytext; void yyerror(const char *s); int yylex(void); @@ -44,7 +45,7 @@ UNARY_OP_T get_unary_op(node *op) { return ((token*)op)->unary_op; } TYPE_T get_type_spec(node *op) { return ((token*)op)->type; } STORAGE_SPEC_T get_storage_spec(node *op) { return ((token*)op)->storage_spec;} %} - + %token IDENTIFIER CONSTANT STRING_LITERAL %token TUNABLE KERNEL RESTRICT READONLY WRITEONLY CONST CONSTANT_SPACE %token PTR_OP INC_OP DEC_OP LEFT_OP RIGHT_OP LE_OP GE_OP EQ_OP NE_OP @@ -385,3 +386,7 @@ function_definition : declaration_specifiers declarator compound_statement { $$ = new function_definition($1, $2, $3); } ; +%% +void yyerror (const char *s){ + print_error(s); +} diff --git a/include/triton/ast/scanner.l b/include/triton/ast/scanner.l index e3bf32be0..91b700655 100644 --- a/include/triton/ast/scanner.l +++ b/include/triton/ast/scanner.l @@ -8,107 +8,104 @@ IS (u|U|l|L)* %{ #include #include "parser.hpp" +#include "../include/triton/ast/ast.h" +using triton::ast::return_impl; +using triton::ast::return_void; %} %% -"__constant__" { return(CONSTANT_SPACE); } -"const" { return(CONST); } -"tunable" { return(TUNABLE); } -"kernel" { return(KERNEL); } -"restrict" { return(RESTRICT); } -"read_only" { return(READONLY); } -"write_only" { return(WRITEONLY); } -"@" { return(AT); } -"newaxis" { return(NEWAXIS); } -"if" { return(IF); } -"else" { return(ELSE); } -"for" { return(FOR); } -"void" { return(VOID); } -"uint1" { return(UINT1); } -"uint8" { return(UINT8); } -"uint16" { return(UINT16); } -"uint32" { return(UINT32); } -"uint64" { return(UINT64); } -"int1" { return(INT1); } -"int8" { return(INT8); } -"int16" { return(INT16); } -"int32" { return(INT32); } -"int64" { return(INT64); } -"fp32" { return(FP32); } -"fp64" { return(FP64); } -"..." { return(ELLIPSIS); } -"get_global_range" { return GET_GLOBAL_RANGE; } -"dot" { return DOT;} -"continue" { return(CONTINUE); } -"alloc_const" { return(ALLOC_CONST); } -{L}({L}|{D})* { return(IDENTIFIER); } +"__constant__" { return return_impl(CONSTANT_SPACE, yytext); } +"const" { return return_impl(CONST, yytext); } +"tunable" { return return_impl(TUNABLE, yytext); } +"kernel" { return return_impl(KERNEL, yytext); } +"restrict" { return return_impl(RESTRICT, yytext); } +"read_only" { return return_impl(READONLY, yytext); } +"write_only" { return return_impl(WRITEONLY, yytext); } +"@" { return return_impl(AT, yytext); } +"newaxis" { return return_impl(NEWAXIS, yytext); } +"if" { return return_impl(IF, yytext); } +"else" { return return_impl(ELSE, yytext); } +"for" { return return_impl(FOR, yytext); } +"void" { return return_impl(VOID, yytext); } +"uint1" { return return_impl(UINT1, yytext); } +"uint8" { return return_impl(UINT8, yytext); } +"uint16" { return return_impl(UINT16, yytext); } +"uint32" { return return_impl(UINT32, yytext); } +"uint64" { return return_impl(UINT64, yytext); } +"int1" { return return_impl(INT1, yytext); } +"int8" { return return_impl(INT8, yytext); } +"int16" { return return_impl(INT16, yytext); } +"int32" { return return_impl(INT32, yytext); } +"int64" { return return_impl(INT64, yytext); } +"fp32" { return return_impl(FP32, yytext); } +"fp64" { return return_impl(FP64, yytext); } +"..." { return return_impl(ELLIPSIS, yytext); } +"get_global_range" { return return_impl(GET_GLOBAL_RANGE, yytext); } +"dot" { return return_impl(DOT, yytext); } +"continue" { return return_impl(CONTINUE, yytext); } +"alloc_const" { return return_impl(ALLOC_CONST, yytext); } +{L}({L}|{D})* { return return_impl(IDENTIFIER, yytext); } -0[xX]{H}+{IS}? { return(CONSTANT); } -0{D}+{IS}? { return(CONSTANT); } -{D}+{IS}? { return(CONSTANT); } -L?'(\\.|[^\\'])+' { return(CONSTANT); } +0[xX]{H}+{IS}? { return return_impl(CONSTANT, yytext); } +0{D}+{IS}? { return return_impl(CONSTANT, yytext); } +{D}+{IS}? { return return_impl(CONSTANT, yytext); } +L?'(\\.|[^\\'])+' { return return_impl(CONSTANT, yytext); } -{D}+{E}{FS}? { return(CONSTANT); } -{D}*"."{D}+({E})?{FS}? { return(CONSTANT); } -{D}+"."{D}*({E})?{FS}? { return(CONSTANT); } +{D}+{E}{FS}? { return return_impl(CONSTANT, yytext); } +{D}*"."{D}+({E})?{FS}? { return return_impl(CONSTANT, yytext); } +{D}+"."{D}*({E})?{FS}? { return return_impl(CONSTANT, yytext); } -L?\"(\\.|[^\\"])*\" { return(STRING_LITERAL); } +L?\"(\\.|[^\\"])*\" { return return_impl(STRING_LITERAL, yytext); } -">>=" { return(RIGHT_ASSIGN); } -"<<=" { return(LEFT_ASSIGN); } -"+=" { return(ADD_ASSIGN); } -"-=" { return(SUB_ASSIGN); } -"*=" { return(MUL_ASSIGN); } -"/=" { return(DIV_ASSIGN); } -"%=" { return(MOD_ASSIGN); } -"&=" { return(AND_ASSIGN); } -"^=" { return(XOR_ASSIGN); } -"|=" { return(OR_ASSIGN); } -">>" { return(RIGHT_OP); } -"<<" { return(LEFT_OP); } -"++" { return(INC_OP); } -"--" { return(DEC_OP); } -"->" { return(PTR_OP); } -"&&" { return(AND_OP); } -"||" { return(OR_OP); } -"<=" { return(LE_OP); } -">=" { return(GE_OP); } -"==" { return(EQ_OP); } -"!=" { return(NE_OP); } -";" { return(';'); } -("{"|"<%") { return('{'); } -("}"|"%>") { return('}'); } -"," { return(','); } -":" { return(':'); } -"=" { return('='); } -"(" { return('('); } -")" { return(')'); } -("["|"<:") { return('['); } -("]"|":>") { return(']'); } -"." { return('.'); } -"&" { return('&'); } -"!" { return('!'); } -"~" { return('~'); } -"-" { return('-'); } -"+" { return('+'); } -"*" { return('*'); } -"/" { return('/'); } -"%" { return('%'); } -"<" { return('<'); } -">" { return('>'); } -"^" { return('^'); } -"|" { return('|'); } -"?" { return('?'); } - -[ \t\v\n\f] { } +">>=" { return return_impl(RIGHT_ASSIGN, yytext); } +"<<=" { return return_impl(LEFT_ASSIGN, yytext); } +"+=" { return return_impl(ADD_ASSIGN, yytext); } +"-=" { return return_impl(SUB_ASSIGN, yytext); } +"*=" { return return_impl(MUL_ASSIGN, yytext); } +"/=" { return return_impl(DIV_ASSIGN, yytext); } +"%=" { return return_impl(MOD_ASSIGN, yytext); } +"&=" { return return_impl(AND_ASSIGN, yytext); } +"^=" { return return_impl(XOR_ASSIGN, yytext); } +"|=" { return return_impl(OR_ASSIGN, yytext); } +">>" { return return_impl(RIGHT_OP, yytext); } +"<<" { return return_impl(LEFT_OP, yytext); } +"++" { return return_impl(INC_OP, yytext); } +"--" { return return_impl(DEC_OP, yytext); } +"->" { return return_impl(PTR_OP, yytext); } +"&&" { return return_impl(AND_OP, yytext); } +"||" { return return_impl(OR_OP, yytext); } +"<=" { return return_impl(LE_OP, yytext); } +">=" { return return_impl(GE_OP, yytext); } +"==" { return return_impl(EQ_OP, yytext); } +"!=" { return return_impl(NE_OP, yytext); } +";" { return return_impl(';', yytext); } +("{"|"<%") { return return_impl('{', yytext); } +("}"|"%>") { return return_impl('}', yytext); } +"," { return return_impl(',', yytext); } +":" { return return_impl(':', yytext); } +"=" { return return_impl('=', yytext); } +"(" { return return_impl('(', yytext); } +")" { return return_impl(')', yytext); } +("["|"<:") { return return_impl('[', yytext); } +("]"|":>") { return return_impl(']', yytext); } +"." { return return_impl('.', yytext); } +"&" { return return_impl('&', yytext); } +"!" { return return_impl('!', yytext); } +"~" { return return_impl('~', yytext); } +"-" { return return_impl('-', yytext); } +"+" { return return_impl('+', yytext); } +"*" { return return_impl('*', yytext); } +"/" { return return_impl('/', yytext); } +"%" { return return_impl('%', yytext); } +"<" { return return_impl('<', yytext); } +">" { return return_impl('>', yytext); } +"^" { return return_impl('^', yytext); } +"|" { return return_impl('|', yytext); } +"?" { return return_impl('?', yytext); } +[ \t\v\n\f] { return_void(yytext);} . { /* ignore bad characters */ } %% int yywrap() { return(1); } - -void yyerror (const char *s) /* Called by yyparse on error */ -{ - printf ("Error: %s\n", s); -} diff --git a/lib/ast/lowering.cpp b/lib/ast/lowering.cpp index db9c9ed2a..cac80b262 100644 --- a/lib/ast/lowering.cpp +++ b/lib/ast/lowering.cpp @@ -8,12 +8,16 @@ #include "triton/ir/builder.h" #include "triton/ir/type.h" #include +#include namespace triton{ namespace ast{ +static int current_line = 0; +static int current_column = 0; + /* node */ ir::value *node::explicit_cast(ir::builder &builder, ir::value *src, ir::type *dst_ty){ ir::type *src_scalar_ty = src->get_type()->get_scalar_ty(); @@ -705,6 +709,42 @@ ir::value* named_expression::codegen(ir::module *mod) const{ } +// begin token +void update_location(const char *text) { + for (int i = 0; text[i] != '\0'; i++){ + if (text[i] == '\n'){ + current_column = 0; + current_line++; + } + else if (text[i] == '\t') + current_column += 8 - (current_column % 8); + else + current_column++; + } +} + +void print_error(const char *cerror) { + std::string error(cerror); + auto it = error.find("syntax error,"); + error.replace(it, 13, ""); + std::cerr << "error at line " << current_line << " (column " << current_column << "): " << error << std::endl; + throw std::runtime_error("compilation failed"); +} + +char return_impl(char t, const char * yytext) { + update_location(yytext); + return t; +} + +yytokentype return_impl(yytokentype t, const char * yytext){ + update_location(yytext); + return t; +} + +void return_void(const char * yytext){ + update_location(yytext); +} + } }