From 896e856b071c73a93a58dd116c173a76fa6f5cad Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Wed, 13 Feb 2019 15:41:03 -0500 Subject: [PATCH] [syntax] added support for predicated expressions --- examples/matrix.cpp | 9 ++++++--- include/ast/ast.h | 12 ++++++++---- include/ast/parser.y | 11 +++++++---- include/ast/scanner.l | 9 ++++++--- include/ir/instructions.h | 1 + lib/ast/lowering.cpp | 1 + 6 files changed, 29 insertions(+), 14 deletions(-) diff --git a/examples/matrix.cpp b/examples/matrix.cpp index 938a4eddb..f5334769e 100644 --- a/examples/matrix.cpp +++ b/examples/matrix.cpp @@ -50,6 +50,9 @@ void test(fp32 *a, fp32 *b, fp32 *c, int32 M, int32 N, int32 K){\ fp32* pc[16, 16] = c + rxc[:, newaxis] + ryc[newaxis, :]*M;\ fp32 a[16, 8] = *pa;\ fp32 b[16, 8] = *pb;\ + int1 checkc0[16] = (rxc < M);\ + int1 checkc1[16] = (ryc < N);\ + int1 checkc[16, 16] = checkc0[:, newaxis] && checkc1[newaxis, :];\ for(k = K; k > 0; k = k - 8){\ C = dot(a, b, C);\ pa = pa + 8*M;\ @@ -57,7 +60,7 @@ void test(fp32 *a, fp32 *b, fp32 *c, int32 M, int32 N, int32 K){\ a = *pa;\ b = *pb;\ }\ - *pc = C;\ + @checkc *pc = C;\ }\ "; @@ -215,7 +218,7 @@ int main() { manager.run(llvm_module); std::string src = generate_machine_code(llvm_module, "nvptx64-nvidia-cuda", compute_data_layout(true, true)); -// std::cout << src << std::endl; + std::cout << src << std::endl; // compile machine code CUdevice cu_device; @@ -229,7 +232,7 @@ int main() { // execute machine code // Allocate buffers typedef float numeric_t; - size_t M = 32, N = 32, K = 32; + size_t M = 128, N = 128, K = 128; std::vector c(M*N); std::vector rc(M*N); std::vector a(M*K); diff --git a/include/ast/ast.h b/include/ast/ast.h index 7a2a62563..6471b2296 100644 --- a/include/ast/ast.h +++ b/include/ast/ast.h @@ -51,8 +51,8 @@ enum UNARY_OP_T{ enum TYPE_T{ VOID_T, - UINT8_T, UINT16_T, UINT32_T, UINT64_T, - INT8_T, INT16_T, INT32_T, INT64_T, + UINT1_T, UINT8_T, UINT16_T, UINT32_T, UINT64_T, + INT1_T, INT8_T, INT16_T, INT32_T, INT64_T, FLOAT32_T, FLOAT64_T }; @@ -313,17 +313,21 @@ public: }; class statement: public node{ + +private: + expression *pred_; }; class expression_statement: public statement{ public: - expression_statement(node *expr) - : expr_((expression*)expr){ } + expression_statement(node *expr, node *pred = nullptr) + : expr_((expression*)expr), pred_((expression*)pred){ } ir::value* codegen(ir::module * mod) const; private: expression *expr_; + expression *pred_; }; class compound_statement: public statement{ diff --git a/include/ast/parser.y b/include/ast/parser.y index 0b68443ce..442bee12e 100644 --- a/include/ast/parser.y +++ b/include/ast/parser.y @@ -47,9 +47,9 @@ TYPE_T get_type_spec(node *op) { return ((token*)op)->type; } %token AND_OP OR_OP MUL_ASSIGN DIV_ASSIGN MOD_ASSIGN ADD_ASSIGN %token SUB_ASSIGN LEFT_ASSIGN RIGHT_ASSIGN AND_ASSIGN %token XOR_ASSIGN OR_ASSIGN TYPE_NAME -%token VOID UINT8 UINT16 UINT32 UINT64 INT8 INT16 INT32 INT64 FP32 FP64 +%token VOID UINT1 UINT8 UINT16 UINT32 UINT64 INT1 INT8 INT16 INT32 INT64 FP32 FP64 %token IF ELSE FOR -%token NEWAXIS ELLIPSIS +%token NEWAXIS ELLIPSIS AT %token GET_GLOBAL_RANGE DOT %start translation_unit @@ -62,10 +62,12 @@ TYPE_T get_type_spec(node *op) { return ((token*)op)->type; } type_specifier : VOID { $$ = new token(VOID_T); } + | UINT1 { $$ = new token(UINT1_T); } | UINT8 { $$ = new token(UINT8_T); } | UINT16 { $$ = new token(UINT16_T); } | UINT32 { $$ = new token(UINT32_T); } | UINT64 { $$ = new token(UINT64_T); } + | INT1 { $$ = new token(INT1_T);} | INT8 { $$ = new token(INT8_T); } | INT16 { $$ = new token(INT16_T); } | INT32 { $$ = new token(INT32_T); } @@ -282,11 +284,12 @@ statement_list : statement { $$ = new list((statement*)$1); } | statement_list statement { $$ = append_ptr_list($1, $2); } ; - + expression_statement : ';' { $$ = new no_op(); } | expression ';' { $$ = new expression_statement($1); } - ; + | AT primary_expression expression ';' { $$ = new expression_statement($3, $2); } + ; selection_statement : IF '(' expression ')' statement { $$ = new selection_statement($3, $5); } diff --git a/include/ast/scanner.l b/include/ast/scanner.l index 6b5ed66b0..8e2d89f14 100644 --- a/include/ast/scanner.l +++ b/include/ast/scanner.l @@ -16,15 +16,18 @@ int comment(); %} %% -"newaxis" { count(); return(NEWAXIS); } -"if" { count(); return(IF); } +"@" { count(); return(AT); } +"newaxis" { count(); return(NEWAXIS); } +"if" { count(); return(IF); } "else" { count(); return(ELSE); } -"for" { count(); return(FOR); } +"for" { count(); return(FOR); } "void" { count(); return(VOID); } +"uint1" { count(); return(UINT1); } "uint8" { count(); return(UINT8); } "uint16" { count(); return(UINT16); } "uint32" { count(); return(UINT32); } "uint64" { count(); return(UINT64); } +"int1" { count(); return(INT1); } "int8" { count(); return(INT8); } "int16" { count(); return(INT16); } "int32" { count(); return(INT32); } diff --git a/include/ir/instructions.h b/include/ir/instructions.h index 08f472786..28feeb442 100644 --- a/include/ir/instructions.h +++ b/include/ir/instructions.h @@ -31,6 +31,7 @@ public: private: basic_block *parent_; + value *pred_; }; //===----------------------------------------------------------------------===// diff --git a/lib/ast/lowering.cpp b/lib/ast/lowering.cpp index c9d8c6ff8..36bd50adb 100644 --- a/lib/ast/lowering.cpp +++ b/lib/ast/lowering.cpp @@ -151,6 +151,7 @@ ir::type* declaration_specifier::type(ir::module *mod) const { ir::context &ctx = mod->get_context(); switch (spec_) { case VOID_T: return ir::type::get_void_ty(ctx); + case INT1_T: return ir::type::get_int1_ty(ctx); case INT8_T: return ir::type::get_int8_ty(ctx); case INT16_T: return ir::type::get_int16_ty(ctx); case INT32_T: return ir::type::get_int32_ty(ctx);