[syntax tree] added syntactic support for dereferencing
This commit is contained in:
@@ -20,19 +20,25 @@ extern translation_unit *ast_root;
|
||||
|
||||
const char src[] =
|
||||
"\
|
||||
void test(fp32 *A, fp32 *B, fp32 *C, int32 M, int32 N, int32 K){\
|
||||
void test(fp32 *a, fp32 *b, fp32 *c, int32 M, int32 N, int32 K){\
|
||||
int32 rx[16] = get_global_range[16](0);\
|
||||
int32 ry[16] = get_global_range[16](1);\
|
||||
int32 rk[8] = 0 ... 8;\
|
||||
fp32 c[16, 16] = 0;\
|
||||
int32 rka[8] = 0 ... 8;\
|
||||
int32 rkb[8] = 0 ... 8;\
|
||||
fp32 C[16, 16] = 0;\
|
||||
int32 k;\
|
||||
fp32* pa[16, 8] = A + rx[:, newaxis] + rk[newaxis, :]*M;\
|
||||
fp32* pb[16, 8] = B + ry[:, newaxis] + rk[newaxis, :]*K;\
|
||||
fp32* pa[16, 8] = a + rx[:, newaxis] + rka[newaxis, :]*M;\
|
||||
fp32* pb[16, 8] = b + ry[:, newaxis] + rkb[newaxis, :]*K;\
|
||||
fp32* pc[16, 16];\
|
||||
for(k = 0; k < K; k = k + 8){\
|
||||
fp32 a[16, 8] = *pa;\
|
||||
fp32 b[16, 8] = *pb;\
|
||||
pa = pa + 8;\
|
||||
fp32 A[16, 8] = *pa;\
|
||||
fp32 B[16, 8] = *pb;\
|
||||
C = dot(A, B, C);\
|
||||
pa = pa + 8*M;\
|
||||
pb = pb + 8*K;\
|
||||
}\
|
||||
pc = c + rx[:, newaxis] + ry[newaxis, :];\
|
||||
*pc = C;\
|
||||
}\
|
||||
";
|
||||
|
||||
|
@@ -137,6 +137,19 @@ private:
|
||||
const constant* axis_;
|
||||
};
|
||||
|
||||
class matmul_expression: public builtin_expression{
|
||||
public:
|
||||
matmul_expression(node* A, node *B, node *C):
|
||||
A_((expression*)A), B_((expression*)B), C_((expression*)C) { }
|
||||
ir::value* codegen(ir::module *) const;
|
||||
|
||||
private:
|
||||
const expression *A_;
|
||||
const expression *B_;
|
||||
const expression *C_;
|
||||
};
|
||||
|
||||
|
||||
class indexing_expression: public postfix_expression{
|
||||
public:
|
||||
indexing_expression(node *id, node *slices)
|
||||
@@ -149,21 +162,17 @@ private:
|
||||
const list<slice*>* slices_;
|
||||
};
|
||||
|
||||
class unary_expression: public expression{
|
||||
|
||||
class named_expression: public expression {
|
||||
public:
|
||||
unary_expression(node *id): id_((const identifier*)id) {}
|
||||
const identifier *id() const;
|
||||
named_expression(node *id): id_((const identifier*)id) { }
|
||||
const identifier *id() const { return id_; }
|
||||
ir::value* codegen(ir::module * mod) const;
|
||||
|
||||
private:
|
||||
const identifier *id_;
|
||||
};
|
||||
|
||||
class named_expression: public unary_expression {
|
||||
public:
|
||||
named_expression(node *id): unary_expression(id){ }
|
||||
ir::value* codegen(ir::module * mod) const;
|
||||
};
|
||||
|
||||
class binary_operator: public expression{
|
||||
private:
|
||||
ir::value* llvm_op(ir::module *mod, ir::builder &bld, ir::value *lhs, ir::value *rhs, const std::string &name) const;
|
||||
@@ -220,6 +229,7 @@ public:
|
||||
: op_(op),
|
||||
arg_((expression*)arg) { }
|
||||
|
||||
UNARY_OP_T get_op() const { return op_; }
|
||||
ir::value* codegen(ir::module *mod) const;
|
||||
|
||||
private:
|
||||
@@ -267,13 +277,13 @@ public:
|
||||
class assignment_expression: public expression{
|
||||
public:
|
||||
assignment_expression(node *lvalue, ASSIGN_OP_T op, node *rvalue)
|
||||
: lvalue_((unary_expression*)lvalue), op_(op), rvalue_((expression*)rvalue) { }
|
||||
: lvalue_((named_expression*)lvalue), op_(op), rvalue_((expression*)rvalue) { }
|
||||
|
||||
ir::value* codegen(ir::module *mod) const;
|
||||
|
||||
public:
|
||||
ASSIGN_OP_T op_;
|
||||
const unary_expression *lvalue_;
|
||||
const expression *lvalue_;
|
||||
const expression *rvalue_;
|
||||
};
|
||||
|
||||
|
@@ -50,7 +50,7 @@ TYPE_T get_type_spec(node *op) { return ((token*)op)->type; }
|
||||
%token VOID UINT8 UINT16 UINT32 UINT64 INT8 INT16 INT32 INT64 FP32 FP64
|
||||
%token IF ELSE FOR
|
||||
%token NEWAXIS ELLIPSIS
|
||||
%token GET_GLOBAL_RANGE
|
||||
%token GET_GLOBAL_RANGE DOT
|
||||
|
||||
%start translation_unit
|
||||
%%
|
||||
@@ -111,6 +111,7 @@ identifier
|
||||
|
||||
builtin
|
||||
: GET_GLOBAL_RANGE '[' constant ']' '(' constant ')' { $$ = new get_global_range($3, $6); }
|
||||
| DOT '(' expression ',' expression ',' expression ')' { $$ = new matmul_expression($3, $5, $7); }
|
||||
|
||||
primary_expression
|
||||
: identifier { $$ = new named_expression($1); }
|
||||
|
@@ -33,6 +33,7 @@ int comment();
|
||||
"fp64" { count(); return(FP64); }
|
||||
"..." { count(); return(ELLIPSIS); }
|
||||
"get_global_range" { count(); return GET_GLOBAL_RANGE; }
|
||||
"dot" { count(); return DOT;}
|
||||
|
||||
{L}({L}|{D})* { count(); return(check_type()); }
|
||||
|
||||
|
@@ -37,6 +37,7 @@ private:
|
||||
std::vector<unsigned*> pool_;
|
||||
graph_t dependencies_;
|
||||
std::set<node_t> nodes_;
|
||||
std::map<node_t, unsigned> static_params_;
|
||||
};
|
||||
|
||||
|
||||
|
@@ -105,14 +105,17 @@ public:
|
||||
// Side effects
|
||||
value *create_fneg(value *arg, const std::string &name = "");
|
||||
value *create_neg(value *arg, const std::string &name = "");
|
||||
value *create_load(value *arg, const std::string &name = "");
|
||||
value *create_not(value *arg, const std::string &name = "");
|
||||
// Input/Output
|
||||
value *create_load(value *arg, const std::string &name = "");
|
||||
value *create_store(value *ptr, value *val, const std::string &name = "");
|
||||
// Tile instruction
|
||||
value *create_splat(value *arg, const std::vector<unsigned> &shapes, const std::string &name = "");
|
||||
value *create_reshape(value *arg, const std::vector<unsigned> &shapes, const std::string &name = "");
|
||||
value *create_broadcast(value *arg, const std::vector<unsigned> &shapes, const std::string &name = "");
|
||||
// Built-in instruction
|
||||
value *create_get_global_range(unsigned axis, unsigned size, const std::string &name = "");
|
||||
value *create_matmul(value *A, value *B, value *C, const std::string &name = "");
|
||||
|
||||
private:
|
||||
context &ctx_;
|
||||
|
@@ -291,7 +291,18 @@ public:
|
||||
// factory method
|
||||
static load_inst* create(value *ptr, const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
};
|
||||
|
||||
class store_inst: public instruction{
|
||||
private:
|
||||
store_inst(value *ptr, value *v, const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
value *get_pointer_operand() { return get_operand(0); }
|
||||
value *get_value_operand() { return get_operand(1); }
|
||||
// factory method
|
||||
static store_inst* create(value* ptr, value *v, const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -336,11 +347,6 @@ public:
|
||||
};
|
||||
|
||||
|
||||
// matmul
|
||||
class matmul_inst: public instruction {
|
||||
|
||||
};
|
||||
|
||||
// built-in
|
||||
class builtin_inst: public instruction{
|
||||
protected:
|
||||
@@ -359,6 +365,15 @@ private:
|
||||
unsigned axis_;
|
||||
};
|
||||
|
||||
class matmul_inst: public builtin_inst {
|
||||
matmul_inst(value *A, value *B, value *C, const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
static instruction* create(value *A, value *B, value *C,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -318,9 +318,9 @@ ir::value* initializer::codegen(ir::module * mod) const{
|
||||
std::string name = decl_->id()->name();
|
||||
ir::value *value = ir::undef_value::get(ty);
|
||||
if(expr_){
|
||||
ir::value* target = expr_->codegen(mod);
|
||||
explicit_cast(mod->get_builder(), target, ty->get_scalar_ty());
|
||||
implicit_broadcast(mod, value, target);
|
||||
value = expr_->codegen(mod);
|
||||
explicit_cast(mod->get_builder(), value, ty->get_scalar_ty());
|
||||
implicit_broadcast(mod, value, value);
|
||||
}
|
||||
value->set_name(name);
|
||||
mod->set_value(name, value);
|
||||
@@ -421,12 +421,23 @@ ir::value* binary_operator::codegen(ir::module *mod) const{
|
||||
}
|
||||
|
||||
/* Builtin expression */
|
||||
|
||||
// get_global_range
|
||||
ir::value* get_global_range::codegen(ir::module *mod) const {
|
||||
ir::builder &builder = mod->get_builder();
|
||||
return builder.create_get_global_range(axis_->value(), size_->value());
|
||||
}
|
||||
|
||||
|
||||
ir::value* matmul_expression::codegen(ir::module *mod) const {
|
||||
ir::value *A = A_->codegen(mod);
|
||||
ir::value *B = B_->codegen(mod);
|
||||
ir::value *C = C_->codegen(mod);
|
||||
return mod->get_builder().create_matmul(A, B, C);
|
||||
}
|
||||
|
||||
|
||||
|
||||
/* Postfix expression */
|
||||
ir::value* indexing_expression::codegen(ir::module *mod) const{
|
||||
ir::value *in = mod->get_value(id_->name());
|
||||
@@ -497,7 +508,13 @@ ir::value *conditional_expression::codegen(ir::module *mod) const{
|
||||
/* Assignment expression */
|
||||
ir::value *assignment_expression::codegen(ir::module *mod) const{
|
||||
ir::value *rvalue = rvalue_->codegen(mod);
|
||||
mod->set_value(lvalue_->id()->name(), rvalue);
|
||||
if(auto *x = dynamic_cast<const named_expression*>(lvalue_))
|
||||
mod->set_value(x->id()->name(), rvalue);
|
||||
else if(auto* x = dynamic_cast<const unary_operator*>(lvalue_)){
|
||||
assert(x->get_op()==DEREF);
|
||||
ir::value *ptr = x->codegen(mod);
|
||||
mod->get_builder().create_store(ptr, rvalue);
|
||||
}
|
||||
return rvalue;
|
||||
}
|
||||
|
||||
@@ -527,11 +544,6 @@ ir::value* constant_range::codegen(ir::module *mod) const{
|
||||
(ir::constant*)last_->codegen(mod));
|
||||
}
|
||||
|
||||
/* Unary expression */
|
||||
const identifier* unary_expression::id() const{
|
||||
return id_;
|
||||
}
|
||||
|
||||
/* Named */
|
||||
ir::value* named_expression::codegen(ir::module *mod) const{
|
||||
const std::string &name = id()->name();
|
||||
|
@@ -33,7 +33,9 @@ void tune::init_c_graph(ir::instruction *v) {
|
||||
ir::value *op = v->get_operand(0);
|
||||
unsigned current = 0;
|
||||
for(unsigned i = 0; i < shapes.size(); i ++)
|
||||
if(shapes[i] > 1)
|
||||
if(shapes[i] == 1)
|
||||
static_params_.insert({{v, i}, 1});
|
||||
else
|
||||
add_constraint({v, i}, {op, current++});
|
||||
}
|
||||
else if(dynamic_cast<ir::splat_inst*>(v)){
|
||||
@@ -70,10 +72,10 @@ void tune::connected_components(node_t x, const std::vector<unsigned *> vals, st
|
||||
params_[instr].insert({"p1" + suffix, vals[1]});
|
||||
params_[instr].insert({"p2" + suffix, vals[2]});
|
||||
}
|
||||
if(auto *cst = dynamic_cast<ir::constant_int*>(x.first)){
|
||||
*vals[0] = cst->get_value();
|
||||
*vals[1] = cst->get_value();
|
||||
*vals[2] = cst->get_value();
|
||||
if(static_params_.find(x) != static_params_.end()){
|
||||
*vals[0] = static_params_.at(x);
|
||||
*vals[1] = static_params_.at(x);
|
||||
*vals[2] = static_params_.at(x);
|
||||
}
|
||||
for(const node_t &y: graph[x])
|
||||
connected_components(y, vals, nodes, graph);
|
||||
@@ -88,7 +90,6 @@ void tune::get_params(ir::module &mod, std::vector<unsigned *> &result) {
|
||||
for(ir::instruction *i : block->get_inst_list())
|
||||
for(auto &x: params_[i])
|
||||
if(seen.insert(x.second).second && *x.second == 0){
|
||||
std::cout << typeid(*i).name() << " " << i << std::endl;
|
||||
result.push_back(x.second);
|
||||
}
|
||||
}
|
||||
|
@@ -227,7 +227,11 @@ DEFINE_FCMP_INSTR(ONE, llvm::FCmpInst::FCMP_ONE)
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
value *builder::create_load(value *arg, const std::string &name){
|
||||
return load_inst::create(arg, name);
|
||||
return insert(load_inst::create(arg, name));
|
||||
}
|
||||
|
||||
value *builder::create_store(value *ptr, value *val, const std::string &name){
|
||||
return insert(store_inst::create(ptr, val, name));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -254,5 +258,9 @@ value *builder::create_get_global_range(unsigned axis, unsigned size, const std:
|
||||
return insert(get_global_range_inst::create(ctx_, axis, size, name));
|
||||
}
|
||||
|
||||
value *builder::create_matmul(value *A, value *B, value *C, const std::string &name) {
|
||||
return insert(matmul_inst::create(A, B, C, name));
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
@@ -61,9 +61,11 @@ constant_range::constant_range(type *ty, uint64_t first, uint64_t last)
|
||||
constant *constant_range::get(constant *first, constant *last) {
|
||||
assert(first->get_type()->is_integer_ty());
|
||||
assert(first->get_type() == last->get_type());
|
||||
uint64_t vfirst = ((constant_int*)first)->get_value();
|
||||
uint64_t vlast = ((constant_int*)first)->get_value();
|
||||
return new constant_range(first->get_type(), vfirst, vlast);
|
||||
unsigned vfirst = ((constant_int*)first)->get_value();
|
||||
unsigned vlast = ((constant_int*)last)->get_value();
|
||||
assert(vlast > vfirst);
|
||||
type *ty = tile_type::get(first->get_type(), {vlast - vfirst});
|
||||
return new constant_range(ty, vfirst, vlast);
|
||||
}
|
||||
|
||||
|
||||
|
@@ -310,6 +310,17 @@ load_inst* load_inst::create(value *ptr, const std::string &name, instruction *n
|
||||
return new load_inst(ptr, name, next);
|
||||
}
|
||||
|
||||
// store
|
||||
store_inst::store_inst(value *ptr, value *v, const std::string &name, instruction *next)
|
||||
: instruction(type::get_void_ty(ptr->get_type()->get_context()), 2, name, next) {
|
||||
set_operand(0, ptr);
|
||||
set_operand(1, v);
|
||||
}
|
||||
|
||||
store_inst* store_inst::create(value *ptr, value *v, const std::string &name, instruction *next) {
|
||||
return new store_inst(ptr, v, name, next);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// retile_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -345,6 +356,18 @@ instruction* broadcast_inst::create(value *arg, const std::vector<unsigned> &sha
|
||||
// matmul_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
matmul_inst::matmul_inst(value *A, value *B, value *C,
|
||||
const std::string &name, instruction *next)
|
||||
: builtin_inst(C->get_type(), 3, name, next) {
|
||||
set_operand(0, A);
|
||||
set_operand(0, B);
|
||||
set_operand(0, C);
|
||||
}
|
||||
|
||||
instruction *matmul_inst::create(value *A, value *B, value *C,
|
||||
const std::string &name, instruction *next) {
|
||||
return new matmul_inst(A, B, C, name, next);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// builtin instructions
|
||||
|
Reference in New Issue
Block a user