From 80d019ec16cc461bf9916d53a5f2b919c5fdf0da Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 10 Jan 2019 23:53:27 -0500 Subject: [PATCH] [syntax tree] added syntactic support for dereferencing --- examples/matrix.cpp | 22 ++++++++++++++-------- include/ast/ast.h | 32 +++++++++++++++++++++----------- include/ast/parser.y | 3 ++- include/ast/scanner.l | 1 + include/codegen/tune.h | 1 + include/ir/builder.h | 5 ++++- include/ir/instructions.h | 25 ++++++++++++++++++++----- lib/ast/lowering.cpp | 30 +++++++++++++++++++++--------- lib/codegen/tune.cpp | 13 +++++++------ lib/ir/builder.cpp | 10 +++++++++- lib/ir/constant.cpp | 8 +++++--- lib/ir/instructions.cpp | 23 +++++++++++++++++++++++ 12 files changed, 128 insertions(+), 45 deletions(-) diff --git a/examples/matrix.cpp b/examples/matrix.cpp index e2ea19527..e9d380f39 100644 --- a/examples/matrix.cpp +++ b/examples/matrix.cpp @@ -20,19 +20,25 @@ extern translation_unit *ast_root; const char src[] = "\ -void test(fp32 *A, fp32 *B, fp32 *C, int32 M, int32 N, int32 K){\ +void test(fp32 *a, fp32 *b, fp32 *c, int32 M, int32 N, int32 K){\ int32 rx[16] = get_global_range[16](0);\ int32 ry[16] = get_global_range[16](1);\ - int32 rk[8] = 0 ... 8;\ - fp32 c[16, 16] = 0;\ + int32 rka[8] = 0 ... 8;\ + int32 rkb[8] = 0 ... 
8;\ + fp32 C[16, 16] = 0;\ int32 k;\ - fp32* pa[16, 8] = A + rx[:, newaxis] + rk[newaxis, :]*M;\ - fp32* pb[16, 8] = B + ry[:, newaxis] + rk[newaxis, :]*K;\ + fp32* pa[16, 8] = a + rx[:, newaxis] + rka[newaxis, :]*M;\ + fp32* pb[16, 8] = b + ry[:, newaxis] + rkb[newaxis, :]*K;\ + fp32* pc[16, 16];\ for(k = 0; k < K; k = k + 8){\ - fp32 a[16, 8] = *pa;\ - fp32 b[16, 8] = *pb;\ - pa = pa + 8;\ + fp32 A[16, 8] = *pa;\ + fp32 B[16, 8] = *pb;\ + C = dot(A, B, C);\ + pa = pa + 8*M;\ + pb = pb + 8*K;\ }\ + pc = c + rx[:, newaxis] + ry[newaxis, :];\ + *pc = C;\ }\ "; diff --git a/include/ast/ast.h b/include/ast/ast.h index e3da73802..dea0af010 100644 --- a/include/ast/ast.h +++ b/include/ast/ast.h @@ -137,6 +137,19 @@ private: const constant* axis_; }; +class matmul_expression: public builtin_expression{ +public: + matmul_expression(node* A, node *B, node *C): + A_((expression*)A), B_((expression*)B), C_((expression*)C) { } + ir::value* codegen(ir::module *) const; + +private: + const expression *A_; + const expression *B_; + const expression *C_; +}; + + class indexing_expression: public postfix_expression{ public: indexing_expression(node *id, node *slices) @@ -149,21 +162,17 @@ private: const list* slices_; }; -class unary_expression: public expression{ + +class named_expression: public expression { public: - unary_expression(node *id): id_((const identifier*)id) {} - const identifier *id() const; + named_expression(node *id): id_((const identifier*)id) { } + const identifier *id() const { return id_; } + ir::value* codegen(ir::module * mod) const; private: const identifier *id_; }; -class named_expression: public unary_expression { -public: - named_expression(node *id): unary_expression(id){ } - ir::value* codegen(ir::module * mod) const; -}; - class binary_operator: public expression{ private: ir::value* llvm_op(ir::module *mod, ir::builder &bld, ir::value *lhs, ir::value *rhs, const std::string &name) const; @@ -220,6 +229,7 @@ public: : op_(op), 
arg_((expression*)arg) { } + UNARY_OP_T get_op() const { return op_; } ir::value* codegen(ir::module *mod) const; private: @@ -267,13 +277,13 @@ public: class assignment_expression: public expression{ public: assignment_expression(node *lvalue, ASSIGN_OP_T op, node *rvalue) - : lvalue_((unary_expression*)lvalue), op_(op), rvalue_((expression*)rvalue) { } + : lvalue_((named_expression*)lvalue), op_(op), rvalue_((expression*)rvalue) { } ir::value* codegen(ir::module *mod) const; public: ASSIGN_OP_T op_; - const unary_expression *lvalue_; + const expression *lvalue_; const expression *rvalue_; }; diff --git a/include/ast/parser.y b/include/ast/parser.y index 351556ff7..0b68443ce 100644 --- a/include/ast/parser.y +++ b/include/ast/parser.y @@ -50,7 +50,7 @@ TYPE_T get_type_spec(node *op) { return ((token*)op)->type; } %token VOID UINT8 UINT16 UINT32 UINT64 INT8 INT16 INT32 INT64 FP32 FP64 %token IF ELSE FOR %token NEWAXIS ELLIPSIS -%token GET_GLOBAL_RANGE +%token GET_GLOBAL_RANGE DOT %start translation_unit %% @@ -111,6 +111,7 @@ identifier builtin : GET_GLOBAL_RANGE '[' constant ']' '(' constant ')' { $$ = new get_global_range($3, $6); } + | DOT '(' expression ',' expression ',' expression ')' { $$ = new matmul_expression($3, $5, $7); } primary_expression : identifier { $$ = new named_expression($1); } diff --git a/include/ast/scanner.l b/include/ast/scanner.l index 55366859c..6b5ed66b0 100644 --- a/include/ast/scanner.l +++ b/include/ast/scanner.l @@ -33,6 +33,7 @@ int comment(); "fp64" { count(); return(FP64); } "..." 
{ count(); return(ELLIPSIS); } "get_global_range" { count(); return GET_GLOBAL_RANGE; } +"dot" { count(); return DOT;} {L}({L}|{D})* { count(); return(check_type()); } diff --git a/include/codegen/tune.h b/include/codegen/tune.h index f7ce1b10d..5904fddcf 100644 --- a/include/codegen/tune.h +++ b/include/codegen/tune.h @@ -37,6 +37,7 @@ private: std::vector pool_; graph_t dependencies_; std::set nodes_; + std::map static_params_; }; diff --git a/include/ir/builder.h b/include/ir/builder.h index a96408f18..c84ef02d8 100644 --- a/include/ir/builder.h +++ b/include/ir/builder.h @@ -105,14 +105,17 @@ public: // Side effects value *create_fneg(value *arg, const std::string &name = ""); value *create_neg(value *arg, const std::string &name = ""); - value *create_load(value *arg, const std::string &name = ""); value *create_not(value *arg, const std::string &name = ""); + // Input/Output + value *create_load(value *arg, const std::string &name = ""); + value *create_store(value *ptr, value *val, const std::string &name = ""); // Tile instruction value *create_splat(value *arg, const std::vector &shapes, const std::string &name = ""); value *create_reshape(value *arg, const std::vector &shapes, const std::string &name = ""); value *create_broadcast(value *arg, const std::vector &shapes, const std::string &name = ""); // Built-in instruction value *create_get_global_range(unsigned axis, unsigned size, const std::string &name = ""); + value *create_matmul(value *A, value *B, value *C, const std::string &name = ""); private: context &ctx_; diff --git a/include/ir/instructions.h b/include/ir/instructions.h index e700bea04..fa7a0d6e3 100644 --- a/include/ir/instructions.h +++ b/include/ir/instructions.h @@ -291,7 +291,18 @@ public: // factory method static load_inst* create(value *ptr, const std::string &name = "", instruction *next = nullptr); +}; +class store_inst: public instruction{ +private: + store_inst(value *ptr, value *v, const std::string &name, instruction *next); + 
+public: + value *get_pointer_operand() { return get_operand(0); } + value *get_value_operand() { return get_operand(1); } + // factory method + static store_inst* create(value* ptr, value *v, const std::string &name = "", + instruction *next = nullptr); }; //===----------------------------------------------------------------------===// @@ -336,11 +347,6 @@ public: }; -// matmul -class matmul_inst: public instruction { - -}; - // built-in class builtin_inst: public instruction{ protected: @@ -359,6 +365,15 @@ private: unsigned axis_; }; +class matmul_inst: public builtin_inst { + matmul_inst(value *A, value *B, value *C, const std::string &name, instruction *next); + +public: + static instruction* create(value *A, value *B, value *C, + const std::string &name = "", + instruction *next = nullptr); +}; + } } diff --git a/lib/ast/lowering.cpp b/lib/ast/lowering.cpp index 14a5249cc..fcce76729 100644 --- a/lib/ast/lowering.cpp +++ b/lib/ast/lowering.cpp @@ -318,9 +318,9 @@ ir::value* initializer::codegen(ir::module * mod) const{ std::string name = decl_->id()->name(); ir::value *value = ir::undef_value::get(ty); if(expr_){ - ir::value* target = expr_->codegen(mod); - explicit_cast(mod->get_builder(), target, ty->get_scalar_ty()); - implicit_broadcast(mod, value, target); + value = expr_->codegen(mod); + explicit_cast(mod->get_builder(), value, ty->get_scalar_ty()); + implicit_broadcast(mod, value, value); } value->set_name(name); mod->set_value(name, value); @@ -421,12 +421,23 @@ ir::value* binary_operator::codegen(ir::module *mod) const{ } /* Builtin expression */ + +// get_global_range ir::value* get_global_range::codegen(ir::module *mod) const { ir::builder &builder = mod->get_builder(); return builder.create_get_global_range(axis_->value(), size_->value()); } +ir::value* matmul_expression::codegen(ir::module *mod) const { + ir::value *A = A_->codegen(mod); + ir::value *B = B_->codegen(mod); + ir::value *C = C_->codegen(mod); + return 
mod->get_builder().create_matmul(A, B, C); +} + + + /* Postfix expression */ ir::value* indexing_expression::codegen(ir::module *mod) const{ ir::value *in = mod->get_value(id_->name()); @@ -497,7 +508,13 @@ ir::value *conditional_expression::codegen(ir::module *mod) const{ /* Assignment expression */ ir::value *assignment_expression::codegen(ir::module *mod) const{ ir::value *rvalue = rvalue_->codegen(mod); - mod->set_value(lvalue_->id()->name(), rvalue); + if(auto *x = dynamic_cast(lvalue_)) + mod->set_value(x->id()->name(), rvalue); + else if(auto* x = dynamic_cast(lvalue_)){ + assert(x->get_op()==DEREF); + ir::value *ptr = x->codegen(mod); + mod->get_builder().create_store(ptr, rvalue); + } return rvalue; } @@ -527,11 +544,6 @@ ir::value* constant_range::codegen(ir::module *mod) const{ (ir::constant*)last_->codegen(mod)); } -/* Unary expression */ -const identifier* unary_expression::id() const{ - return id_; -} - /* Named */ ir::value* named_expression::codegen(ir::module *mod) const{ const std::string &name = id()->name(); diff --git a/lib/codegen/tune.cpp b/lib/codegen/tune.cpp index fcc5930a5..deb2f858b 100644 --- a/lib/codegen/tune.cpp +++ b/lib/codegen/tune.cpp @@ -33,7 +33,9 @@ void tune::init_c_graph(ir::instruction *v) { ir::value *op = v->get_operand(0); unsigned current = 0; for(unsigned i = 0; i < shapes.size(); i ++) - if(shapes[i] > 1) + if(shapes[i] == 1) + static_params_.insert({{v, i}, 1}); + else add_constraint({v, i}, {op, current++}); } else if(dynamic_cast(v)){ @@ -70,10 +72,10 @@ void tune::connected_components(node_t x, const std::vector vals, st params_[instr].insert({"p1" + suffix, vals[1]}); params_[instr].insert({"p2" + suffix, vals[2]}); } - if(auto *cst = dynamic_cast(x.first)){ - *vals[0] = cst->get_value(); - *vals[1] = cst->get_value(); - *vals[2] = cst->get_value(); + if(static_params_.find(x) != static_params_.end()){ + *vals[0] = static_params_.at(x); + *vals[1] = static_params_.at(x); + *vals[2] = static_params_.at(x); } 
for(const node_t &y: graph[x]) connected_components(y, vals, nodes, graph); @@ -88,7 +90,6 @@ void tune::get_params(ir::module &mod, std::vector &result) { for(ir::instruction *i : block->get_inst_list()) for(auto &x: params_[i]) if(seen.insert(x.second).second && *x.second == 0){ - std::cout << typeid(*i).name() << " " << i << std::endl; result.push_back(x.second); } } diff --git a/lib/ir/builder.cpp b/lib/ir/builder.cpp index 1cfdeefa3..8d3f58792 100644 --- a/lib/ir/builder.cpp +++ b/lib/ir/builder.cpp @@ -227,7 +227,11 @@ DEFINE_FCMP_INSTR(ONE, llvm::FCmpInst::FCMP_ONE) //===----------------------------------------------------------------------===// value *builder::create_load(value *arg, const std::string &name){ - return load_inst::create(arg, name); + return insert(load_inst::create(arg, name)); +} + +value *builder::create_store(value *ptr, value *val, const std::string &name){ + return insert(store_inst::create(ptr, val, name)); } //===----------------------------------------------------------------------===// @@ -254,5 +258,9 @@ value *builder::create_get_global_range(unsigned axis, unsigned size, const std: return insert(get_global_range_inst::create(ctx_, axis, size, name)); } +value *builder::create_matmul(value *A, value *B, value *C, const std::string &name) { + return insert(matmul_inst::create(A, B, C, name)); +} + } } diff --git a/lib/ir/constant.cpp b/lib/ir/constant.cpp index f2779b75b..58f3b1ab7 100644 --- a/lib/ir/constant.cpp +++ b/lib/ir/constant.cpp @@ -61,9 +61,11 @@ constant_range::constant_range(type *ty, uint64_t first, uint64_t last) constant *constant_range::get(constant *first, constant *last) { assert(first->get_type()->is_integer_ty()); assert(first->get_type() == last->get_type()); - uint64_t vfirst = ((constant_int*)first)->get_value(); - uint64_t vlast = ((constant_int*)first)->get_value(); - return new constant_range(first->get_type(), vfirst, vlast); + unsigned vfirst = ((constant_int*)first)->get_value(); + unsigned vlast = 
((constant_int*)last)->get_value(); + assert(vlast > vfirst); + type *ty = tile_type::get(first->get_type(), {vlast - vfirst}); + return new constant_range(ty, vfirst, vlast); } diff --git a/lib/ir/instructions.cpp b/lib/ir/instructions.cpp index eb2483132..0a62e9a6d 100644 --- a/lib/ir/instructions.cpp +++ b/lib/ir/instructions.cpp @@ -310,6 +310,17 @@ load_inst* load_inst::create(value *ptr, const std::string &name, instruction *n return new load_inst(ptr, name, next); } +// store +store_inst::store_inst(value *ptr, value *v, const std::string &name, instruction *next) + : instruction(type::get_void_ty(ptr->get_type()->get_context()), 2, name, next) { + set_operand(0, ptr); + set_operand(1, v); +} + +store_inst* store_inst::create(value *ptr, value *v, const std::string &name, instruction *next) { + return new store_inst(ptr, v, name, next); +} + //===----------------------------------------------------------------------===// // retile_inst classes //===----------------------------------------------------------------------===// @@ -345,6 +356,18 @@ instruction* broadcast_inst::create(value *arg, const std::vector &sha // matmul_inst classes //===----------------------------------------------------------------------===// +matmul_inst::matmul_inst(value *A, value *B, value *C, + const std::string &name, instruction *next) + : builtin_inst(C->get_type(), 3, name, next) { /* operand slots: 0 = A, 1 = B, 2 = C */ + set_operand(0, A); + set_operand(1, B); + set_operand(2, C); +} + +instruction *matmul_inst::create(value *A, value *B, value *C, + const std::string &name, instruction *next) { + return new matmul_inst(A, B, C, name, next); +} //===----------------------------------------------------------------------===// // builtin instructions