[syntax tree] added syntactic support for dereferencing

This commit is contained in:
Philippe Tillet
2019-01-10 23:53:27 -05:00
parent b5c8c25d43
commit 80d019ec16
12 changed files with 128 additions and 45 deletions

View File

@@ -20,19 +20,25 @@ extern translation_unit *ast_root;
const char src[] =
"\
void test(fp32 *A, fp32 *B, fp32 *C, int32 M, int32 N, int32 K){\
void test(fp32 *a, fp32 *b, fp32 *c, int32 M, int32 N, int32 K){\
int32 rx[16] = get_global_range[16](0);\
int32 ry[16] = get_global_range[16](1);\
int32 rk[8] = 0 ... 8;\
fp32 c[16, 16] = 0;\
int32 rka[8] = 0 ... 8;\
int32 rkb[8] = 0 ... 8;\
fp32 C[16, 16] = 0;\
int32 k;\
fp32* pa[16, 8] = A + rx[:, newaxis] + rk[newaxis, :]*M;\
fp32* pb[16, 8] = B + ry[:, newaxis] + rk[newaxis, :]*K;\
fp32* pa[16, 8] = a + rx[:, newaxis] + rka[newaxis, :]*M;\
fp32* pb[16, 8] = b + ry[:, newaxis] + rkb[newaxis, :]*K;\
fp32* pc[16, 16];\
for(k = 0; k < K; k = k + 8){\
fp32 a[16, 8] = *pa;\
fp32 b[16, 8] = *pb;\
pa = pa + 8;\
fp32 A[16, 8] = *pa;\
fp32 B[16, 8] = *pb;\
C = dot(A, B, C);\
pa = pa + 8*M;\
pb = pb + 8*K;\
}\
pc = c + rx[:, newaxis] + ry[newaxis, :];\
*pc = C;\
}\
";

View File

@@ -137,6 +137,19 @@ private:
const constant* axis_;
};
class matmul_expression: public builtin_expression{
public:
matmul_expression(node* A, node *B, node *C):
A_((expression*)A), B_((expression*)B), C_((expression*)C) { }
ir::value* codegen(ir::module *) const;
private:
const expression *A_;
const expression *B_;
const expression *C_;
};
class indexing_expression: public postfix_expression{
public:
indexing_expression(node *id, node *slices)
@@ -149,21 +162,17 @@ private:
const list<slice*>* slices_;
};
class unary_expression: public expression{
class named_expression: public expression {
public:
unary_expression(node *id): id_((const identifier*)id) {}
const identifier *id() const;
named_expression(node *id): id_((const identifier*)id) { }
const identifier *id() const { return id_; }
ir::value* codegen(ir::module * mod) const;
private:
const identifier *id_;
};
class named_expression: public unary_expression {
public:
named_expression(node *id): unary_expression(id){ }
ir::value* codegen(ir::module * mod) const;
};
class binary_operator: public expression{
private:
ir::value* llvm_op(ir::module *mod, ir::builder &bld, ir::value *lhs, ir::value *rhs, const std::string &name) const;
@@ -220,6 +229,7 @@ public:
: op_(op),
arg_((expression*)arg) { }
UNARY_OP_T get_op() const { return op_; }
ir::value* codegen(ir::module *mod) const;
private:
@@ -267,13 +277,13 @@ public:
class assignment_expression: public expression{
public:
assignment_expression(node *lvalue, ASSIGN_OP_T op, node *rvalue)
: lvalue_((unary_expression*)lvalue), op_(op), rvalue_((expression*)rvalue) { }
: lvalue_((named_expression*)lvalue), op_(op), rvalue_((expression*)rvalue) { }
ir::value* codegen(ir::module *mod) const;
public:
ASSIGN_OP_T op_;
const unary_expression *lvalue_;
const expression *lvalue_;
const expression *rvalue_;
};

View File

@@ -50,7 +50,7 @@ TYPE_T get_type_spec(node *op) { return ((token*)op)->type; }
%token VOID UINT8 UINT16 UINT32 UINT64 INT8 INT16 INT32 INT64 FP32 FP64
%token IF ELSE FOR
%token NEWAXIS ELLIPSIS
%token GET_GLOBAL_RANGE
%token GET_GLOBAL_RANGE DOT
%start translation_unit
%%
@@ -111,6 +111,7 @@ identifier
builtin
: GET_GLOBAL_RANGE '[' constant ']' '(' constant ')' { $$ = new get_global_range($3, $6); }
| DOT '(' expression ',' expression ',' expression ')' { $$ = new matmul_expression($3, $5, $7); }
primary_expression
: identifier { $$ = new named_expression($1); }

View File

@@ -33,6 +33,7 @@ int comment();
"fp64" { count(); return(FP64); }
"..." { count(); return(ELLIPSIS); }
"get_global_range" { count(); return GET_GLOBAL_RANGE; }
"dot" { count(); return DOT;}
{L}({L}|{D})* { count(); return(check_type()); }

View File

@@ -37,6 +37,7 @@ private:
std::vector<unsigned*> pool_;
graph_t dependencies_;
std::set<node_t> nodes_;
std::map<node_t, unsigned> static_params_;
};

View File

@@ -105,14 +105,17 @@ public:
// Side effects
value *create_fneg(value *arg, const std::string &name = "");
value *create_neg(value *arg, const std::string &name = "");
value *create_load(value *arg, const std::string &name = "");
value *create_not(value *arg, const std::string &name = "");
// Input/Output
value *create_load(value *arg, const std::string &name = "");
value *create_store(value *ptr, value *val, const std::string &name = "");
// Tile instruction
value *create_splat(value *arg, const std::vector<unsigned> &shapes, const std::string &name = "");
value *create_reshape(value *arg, const std::vector<unsigned> &shapes, const std::string &name = "");
value *create_broadcast(value *arg, const std::vector<unsigned> &shapes, const std::string &name = "");
// Built-in instruction
value *create_get_global_range(unsigned axis, unsigned size, const std::string &name = "");
value *create_matmul(value *A, value *B, value *C, const std::string &name = "");
private:
context &ctx_;

View File

@@ -291,7 +291,18 @@ public:
// factory method
static load_inst* create(value *ptr, const std::string &name = "",
instruction *next = nullptr);
};
class store_inst: public instruction{
private:
store_inst(value *ptr, value *v, const std::string &name, instruction *next);
public:
value *get_pointer_operand() { return get_operand(0); }
value *get_value_operand() { return get_operand(1); }
// factory method
static store_inst* create(value* ptr, value *v, const std::string &name = "",
instruction *next = nullptr);
};
//===----------------------------------------------------------------------===//
@@ -336,11 +347,6 @@ public:
};
// matmul
class matmul_inst: public instruction {
};
// built-in
class builtin_inst: public instruction{
protected:
@@ -359,6 +365,15 @@ private:
unsigned axis_;
};
class matmul_inst: public builtin_inst {
matmul_inst(value *A, value *B, value *C, const std::string &name, instruction *next);
public:
static instruction* create(value *A, value *B, value *C,
const std::string &name = "",
instruction *next = nullptr);
};
}
}

View File

@@ -318,9 +318,9 @@ ir::value* initializer::codegen(ir::module * mod) const{
std::string name = decl_->id()->name();
ir::value *value = ir::undef_value::get(ty);
if(expr_){
ir::value* target = expr_->codegen(mod);
explicit_cast(mod->get_builder(), target, ty->get_scalar_ty());
implicit_broadcast(mod, value, target);
value = expr_->codegen(mod);
explicit_cast(mod->get_builder(), value, ty->get_scalar_ty());
implicit_broadcast(mod, value, value);
}
value->set_name(name);
mod->set_value(name, value);
@@ -421,12 +421,23 @@ ir::value* binary_operator::codegen(ir::module *mod) const{
}
/* Builtin expression */
// get_global_range
ir::value* get_global_range::codegen(ir::module *mod) const {
ir::builder &builder = mod->get_builder();
return builder.create_get_global_range(axis_->value(), size_->value());
}
ir::value* matmul_expression::codegen(ir::module *mod) const {
ir::value *A = A_->codegen(mod);
ir::value *B = B_->codegen(mod);
ir::value *C = C_->codegen(mod);
return mod->get_builder().create_matmul(A, B, C);
}
/* Postfix expression */
ir::value* indexing_expression::codegen(ir::module *mod) const{
ir::value *in = mod->get_value(id_->name());
@@ -497,7 +508,13 @@ ir::value *conditional_expression::codegen(ir::module *mod) const{
/* Assignment expression */
ir::value *assignment_expression::codegen(ir::module *mod) const{
ir::value *rvalue = rvalue_->codegen(mod);
mod->set_value(lvalue_->id()->name(), rvalue);
if(auto *x = dynamic_cast<const named_expression*>(lvalue_))
mod->set_value(x->id()->name(), rvalue);
else if(auto* x = dynamic_cast<const unary_operator*>(lvalue_)){
assert(x->get_op()==DEREF);
ir::value *ptr = x->codegen(mod);
mod->get_builder().create_store(ptr, rvalue);
}
return rvalue;
}
@@ -527,11 +544,6 @@ ir::value* constant_range::codegen(ir::module *mod) const{
(ir::constant*)last_->codegen(mod));
}
/* Unary expression */
const identifier* unary_expression::id() const{
return id_;
}
/* Named */
ir::value* named_expression::codegen(ir::module *mod) const{
const std::string &name = id()->name();

View File

@@ -33,7 +33,9 @@ void tune::init_c_graph(ir::instruction *v) {
ir::value *op = v->get_operand(0);
unsigned current = 0;
for(unsigned i = 0; i < shapes.size(); i ++)
if(shapes[i] > 1)
if(shapes[i] == 1)
static_params_.insert({{v, i}, 1});
else
add_constraint({v, i}, {op, current++});
}
else if(dynamic_cast<ir::splat_inst*>(v)){
@@ -70,10 +72,10 @@ void tune::connected_components(node_t x, const std::vector<unsigned *> vals, st
params_[instr].insert({"p1" + suffix, vals[1]});
params_[instr].insert({"p2" + suffix, vals[2]});
}
if(auto *cst = dynamic_cast<ir::constant_int*>(x.first)){
*vals[0] = cst->get_value();
*vals[1] = cst->get_value();
*vals[2] = cst->get_value();
if(static_params_.find(x) != static_params_.end()){
*vals[0] = static_params_.at(x);
*vals[1] = static_params_.at(x);
*vals[2] = static_params_.at(x);
}
for(const node_t &y: graph[x])
connected_components(y, vals, nodes, graph);
@@ -88,7 +90,6 @@ void tune::get_params(ir::module &mod, std::vector<unsigned *> &result) {
for(ir::instruction *i : block->get_inst_list())
for(auto &x: params_[i])
if(seen.insert(x.second).second && *x.second == 0){
std::cout << typeid(*i).name() << " " << i << std::endl;
result.push_back(x.second);
}
}

View File

@@ -227,7 +227,11 @@ DEFINE_FCMP_INSTR(ONE, llvm::FCmpInst::FCMP_ONE)
//===----------------------------------------------------------------------===//
value *builder::create_load(value *arg, const std::string &name){
return load_inst::create(arg, name);
return insert(load_inst::create(arg, name));
}
value *builder::create_store(value *ptr, value *val, const std::string &name){
return insert(store_inst::create(ptr, val, name));
}
//===----------------------------------------------------------------------===//
@@ -254,5 +258,9 @@ value *builder::create_get_global_range(unsigned axis, unsigned size, const std:
return insert(get_global_range_inst::create(ctx_, axis, size, name));
}
value *builder::create_matmul(value *A, value *B, value *C, const std::string &name) {
return insert(matmul_inst::create(A, B, C, name));
}
}
}

View File

@@ -61,9 +61,11 @@ constant_range::constant_range(type *ty, uint64_t first, uint64_t last)
constant *constant_range::get(constant *first, constant *last) {
assert(first->get_type()->is_integer_ty());
assert(first->get_type() == last->get_type());
uint64_t vfirst = ((constant_int*)first)->get_value();
uint64_t vlast = ((constant_int*)first)->get_value();
return new constant_range(first->get_type(), vfirst, vlast);
unsigned vfirst = ((constant_int*)first)->get_value();
unsigned vlast = ((constant_int*)last)->get_value();
assert(vlast > vfirst);
type *ty = tile_type::get(first->get_type(), {vlast - vfirst});
return new constant_range(ty, vfirst, vlast);
}

View File

@@ -310,6 +310,17 @@ load_inst* load_inst::create(value *ptr, const std::string &name, instruction *n
return new load_inst(ptr, name, next);
}
// store
store_inst::store_inst(value *ptr, value *v, const std::string &name, instruction *next)
: instruction(type::get_void_ty(ptr->get_type()->get_context()), 2, name, next) {
set_operand(0, ptr);
set_operand(1, v);
}
store_inst* store_inst::create(value *ptr, value *v, const std::string &name, instruction *next) {
return new store_inst(ptr, v, name, next);
}
//===----------------------------------------------------------------------===//
// retile_inst classes
//===----------------------------------------------------------------------===//
@@ -345,6 +356,18 @@ instruction* broadcast_inst::create(value *arg, const std::vector<unsigned> &sha
// matmul_inst classes
//===----------------------------------------------------------------------===//
matmul_inst::matmul_inst(value *A, value *B, value *C,
const std::string &name, instruction *next)
: builtin_inst(C->get_type(), 3, name, next) {
set_operand(0, A);
set_operand(0, B);
set_operand(0, C);
}
instruction *matmul_inst::create(value *A, value *B, value *C,
const std::string &name, instruction *next) {
return new matmul_inst(A, B, C, name, next);
}
//===----------------------------------------------------------------------===//
// builtin instructions