[syntax] added support for predicated expressions

This commit is contained in:
Philippe Tillet
2019-02-13 15:41:03 -05:00
parent 32562677e9
commit 896e856b07
6 changed files with 29 additions and 14 deletions

View File

@@ -50,6 +50,9 @@ void test(fp32 *a, fp32 *b, fp32 *c, int32 M, int32 N, int32 K){\
fp32* pc[16, 16] = c + rxc[:, newaxis] + ryc[newaxis, :]*M;\ fp32* pc[16, 16] = c + rxc[:, newaxis] + ryc[newaxis, :]*M;\
fp32 a[16, 8] = *pa;\ fp32 a[16, 8] = *pa;\
fp32 b[16, 8] = *pb;\ fp32 b[16, 8] = *pb;\
int1 checkc0[16] = (rxc < M);\
int1 checkc1[16] = (ryc < N);\
int1 checkc[16, 16] = checkc0[:, newaxis] && checkc1[newaxis, :];\
for(k = K; k > 0; k = k - 8){\ for(k = K; k > 0; k = k - 8){\
C = dot(a, b, C);\ C = dot(a, b, C);\
pa = pa + 8*M;\ pa = pa + 8*M;\
@@ -57,7 +60,7 @@ void test(fp32 *a, fp32 *b, fp32 *c, int32 M, int32 N, int32 K){\
a = *pa;\ a = *pa;\
b = *pb;\ b = *pb;\
}\ }\
*pc = C;\ @checkc *pc = C;\
}\ }\
"; ";
@@ -215,7 +218,7 @@ int main() {
manager.run(llvm_module); manager.run(llvm_module);
std::string src = generate_machine_code(llvm_module, "nvptx64-nvidia-cuda", compute_data_layout(true, true)); std::string src = generate_machine_code(llvm_module, "nvptx64-nvidia-cuda", compute_data_layout(true, true));
// std::cout << src << std::endl; std::cout << src << std::endl;
// compile machine code // compile machine code
CUdevice cu_device; CUdevice cu_device;
@@ -229,7 +232,7 @@ int main() {
// execute machine code // execute machine code
// Allocate buffers // Allocate buffers
typedef float numeric_t; typedef float numeric_t;
size_t M = 32, N = 32, K = 32; size_t M = 128, N = 128, K = 128;
std::vector<numeric_t> c(M*N); std::vector<numeric_t> c(M*N);
std::vector<numeric_t> rc(M*N); std::vector<numeric_t> rc(M*N);
std::vector<numeric_t> a(M*K); std::vector<numeric_t> a(M*K);

View File

@@ -51,8 +51,8 @@ enum UNARY_OP_T{
enum TYPE_T{ enum TYPE_T{
VOID_T, VOID_T,
UINT8_T, UINT16_T, UINT32_T, UINT64_T, UINT1_T, UINT8_T, UINT16_T, UINT32_T, UINT64_T,
INT8_T, INT16_T, INT32_T, INT64_T, INT1_T, INT8_T, INT16_T, INT32_T, INT64_T,
FLOAT32_T, FLOAT64_T FLOAT32_T, FLOAT64_T
}; };
@@ -313,17 +313,21 @@ public:
}; };
class statement: public node{ class statement: public node{
private:
expression *pred_;
}; };
class expression_statement: public statement{ class expression_statement: public statement{
public: public:
expression_statement(node *expr) expression_statement(node *expr, node *pred = nullptr)
: expr_((expression*)expr){ } : expr_((expression*)expr), pred_((expression*)pred){ }
ir::value* codegen(ir::module * mod) const; ir::value* codegen(ir::module * mod) const;
private: private:
expression *expr_; expression *expr_;
expression *pred_;
}; };
class compound_statement: public statement{ class compound_statement: public statement{

View File

@@ -47,9 +47,9 @@ TYPE_T get_type_spec(node *op) { return ((token*)op)->type; }
%token AND_OP OR_OP MUL_ASSIGN DIV_ASSIGN MOD_ASSIGN ADD_ASSIGN %token AND_OP OR_OP MUL_ASSIGN DIV_ASSIGN MOD_ASSIGN ADD_ASSIGN
%token SUB_ASSIGN LEFT_ASSIGN RIGHT_ASSIGN AND_ASSIGN %token SUB_ASSIGN LEFT_ASSIGN RIGHT_ASSIGN AND_ASSIGN
%token XOR_ASSIGN OR_ASSIGN TYPE_NAME %token XOR_ASSIGN OR_ASSIGN TYPE_NAME
%token VOID UINT8 UINT16 UINT32 UINT64 INT8 INT16 INT32 INT64 FP32 FP64 %token VOID UINT1 UINT8 UINT16 UINT32 UINT64 INT1 INT8 INT16 INT32 INT64 FP32 FP64
%token IF ELSE FOR %token IF ELSE FOR
%token NEWAXIS ELLIPSIS %token NEWAXIS ELLIPSIS AT
%token GET_GLOBAL_RANGE DOT %token GET_GLOBAL_RANGE DOT
%start translation_unit %start translation_unit
@@ -62,10 +62,12 @@ TYPE_T get_type_spec(node *op) { return ((token*)op)->type; }
type_specifier type_specifier
: VOID { $$ = new token(VOID_T); } : VOID { $$ = new token(VOID_T); }
| UINT1 { $$ = new token(UINT1_T); }
| UINT8 { $$ = new token(UINT8_T); } | UINT8 { $$ = new token(UINT8_T); }
| UINT16 { $$ = new token(UINT16_T); } | UINT16 { $$ = new token(UINT16_T); }
| UINT32 { $$ = new token(UINT32_T); } | UINT32 { $$ = new token(UINT32_T); }
| UINT64 { $$ = new token(UINT64_T); } | UINT64 { $$ = new token(UINT64_T); }
| INT1 { $$ = new token(INT1_T);}
| INT8 { $$ = new token(INT8_T); } | INT8 { $$ = new token(INT8_T); }
| INT16 { $$ = new token(INT16_T); } | INT16 { $$ = new token(INT16_T); }
| INT32 { $$ = new token(INT32_T); } | INT32 { $$ = new token(INT32_T); }
@@ -282,11 +284,12 @@ statement_list
: statement { $$ = new list<statement*>((statement*)$1); } : statement { $$ = new list<statement*>((statement*)$1); }
| statement_list statement { $$ = append_ptr_list<statement>($1, $2); } | statement_list statement { $$ = append_ptr_list<statement>($1, $2); }
; ;
expression_statement expression_statement
: ';' { $$ = new no_op(); } : ';' { $$ = new no_op(); }
| expression ';' { $$ = new expression_statement($1); } | expression ';' { $$ = new expression_statement($1); }
; | AT primary_expression expression ';' { $$ = new expression_statement($3, $2); }
;
selection_statement selection_statement
: IF '(' expression ')' statement { $$ = new selection_statement($3, $5); } : IF '(' expression ')' statement { $$ = new selection_statement($3, $5); }

View File

@@ -16,15 +16,18 @@ int comment();
%} %}
%% %%
"newaxis" { count(); return(NEWAXIS); } "@" { count(); return(AT); }
"if" { count(); return(IF); } "newaxis" { count(); return(NEWAXIS); }
"if" { count(); return(IF); }
"else" { count(); return(ELSE); } "else" { count(); return(ELSE); }
"for" { count(); return(FOR); } "for" { count(); return(FOR); }
"void" { count(); return(VOID); } "void" { count(); return(VOID); }
"uint1" { count(); return(UINT1); }
"uint8" { count(); return(UINT8); } "uint8" { count(); return(UINT8); }
"uint16" { count(); return(UINT16); } "uint16" { count(); return(UINT16); }
"uint32" { count(); return(UINT32); } "uint32" { count(); return(UINT32); }
"uint64" { count(); return(UINT64); } "uint64" { count(); return(UINT64); }
"int1" { count(); return(INT1); }
"int8" { count(); return(INT8); } "int8" { count(); return(INT8); }
"int16" { count(); return(INT16); } "int16" { count(); return(INT16); }
"int32" { count(); return(INT32); } "int32" { count(); return(INT32); }

View File

@@ -31,6 +31,7 @@ public:
private: private:
basic_block *parent_; basic_block *parent_;
value *pred_;
}; };
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@@ -151,6 +151,7 @@ ir::type* declaration_specifier::type(ir::module *mod) const {
ir::context &ctx = mod->get_context(); ir::context &ctx = mod->get_context();
switch (spec_) { switch (spec_) {
case VOID_T: return ir::type::get_void_ty(ctx); case VOID_T: return ir::type::get_void_ty(ctx);
case INT1_T: return ir::type::get_int1_ty(ctx);
case INT8_T: return ir::type::get_int8_ty(ctx); case INT8_T: return ir::type::get_int8_ty(ctx);
case INT16_T: return ir::type::get_int16_ty(ctx); case INT16_T: return ir::type::get_int16_ty(ctx);
case INT32_T: return ir::type::get_int32_ty(ctx); case INT32_T: return ir::type::get_int32_ty(ctx);