[syntax tree] added basic support for range
This commit is contained in:
@@ -20,11 +20,9 @@ extern translation_unit *ast_root;
|
||||
|
||||
const char src[] =
|
||||
"\
|
||||
void test(fp32 *A, fp32 *B, fp32 *C, int32 i){\
|
||||
int32 tile[16, 16] = 0;\
|
||||
fp32 *test[16, 16] = tile + A;\
|
||||
i = 1;\
|
||||
A = A + i;\
|
||||
void test(fp32 *A, fp32 *B, fp32 *C, int32 M, int32 N, int32 K){\
|
||||
fp32 acc[16, 16] = 0;\
|
||||
fp32 *pa[16, 8] = A;\
|
||||
}\
|
||||
";
|
||||
|
||||
@@ -42,6 +40,9 @@ int main() {
|
||||
tdl::codegen::selection selection;
|
||||
tdl::codegen::tune tune;
|
||||
tune.run(module);
|
||||
std::vector<unsigned*> params;
|
||||
tune.get_params(module, params);
|
||||
std::cout << params.size() << std::endl;
|
||||
// selection.run(module, llvm_module);
|
||||
// // print LLVM program
|
||||
// llvm::PrintModulePass print(llvm::outs());
|
||||
|
@@ -58,6 +58,7 @@ enum TYPE_T{
|
||||
|
||||
class pointer;
|
||||
class identifier;
|
||||
class constant;
|
||||
|
||||
// AST
|
||||
class node {
|
||||
@@ -121,6 +122,21 @@ class postfix_expression: public expression{
|
||||
|
||||
};
|
||||
|
||||
class builtin_expression: public node{
|
||||
|
||||
};
|
||||
|
||||
|
||||
class get_global_range: public builtin_expression{
|
||||
public:
|
||||
get_global_range(node *size, node *axis): size_((constant*)size), axis_((constant*)axis) { }
|
||||
ir::value* codegen(ir::module *) const;
|
||||
|
||||
private:
|
||||
const constant* size_;
|
||||
const constant* axis_;
|
||||
};
|
||||
|
||||
class indexing_expression: public postfix_expression{
|
||||
public:
|
||||
indexing_expression(node *id, node *ranges)
|
||||
@@ -133,7 +149,7 @@ private:
|
||||
const list<range*>* ranges_;
|
||||
};
|
||||
|
||||
class unary_expression: public node{
|
||||
class unary_expression: public expression{
|
||||
public:
|
||||
unary_expression(node *id): id_((const identifier*)id) {}
|
||||
const identifier *id() const;
|
||||
@@ -174,6 +190,17 @@ private:
|
||||
const int value_;
|
||||
};
|
||||
|
||||
class constant_range: public expression {
|
||||
public:
|
||||
constant_range(node *first, node *last)
|
||||
: first_((constant*)first), last_((constant*)last) { }
|
||||
|
||||
ir::value* codegen(ir::module *mod) const;
|
||||
|
||||
private:
|
||||
constant *first_;
|
||||
constant *last_;
|
||||
};
|
||||
|
||||
class string_literal: public expression{
|
||||
public:
|
||||
|
@@ -49,7 +49,8 @@ TYPE_T get_type_spec(node *op) { return ((token*)op)->type; }
|
||||
%token XOR_ASSIGN OR_ASSIGN TYPE_NAME
|
||||
%token VOID UINT8 UINT16 UINT32 UINT64 INT8 INT16 INT32 INT64 FP32 FP64
|
||||
%token IF ELSE FOR
|
||||
%token NEWAXIS
|
||||
%token NEWAXIS ELLIPSIS
|
||||
%token GET_GLOBAL_RANGE
|
||||
|
||||
%start translation_unit
|
||||
%%
|
||||
@@ -87,7 +88,8 @@ direct_abstract_declarator
|
||||
: '[' constant_list ']' { $$ = new tile(nullptr, $1); }
|
||||
|
||||
constant :
|
||||
CONSTANT { $$ = new constant(atoi(yytext)); }
|
||||
CONSTANT { $$ = new constant(atoi(yytext)); }
|
||||
| constant ELLIPSIS constant { $$ = new constant_range($1, $2); }
|
||||
;
|
||||
|
||||
constant_list
|
||||
@@ -107,11 +109,15 @@ type_name
|
||||
identifier
|
||||
: IDENTIFIER { $$ = new identifier(yytext); }
|
||||
;
|
||||
|
||||
|
||||
builtin
|
||||
: GET_GLOBAL_RANGE '[' constant ']' '(' constant ')' { $$ = new get_global_range($3, $6); }
|
||||
|
||||
primary_expression
|
||||
: identifier { $$ = new named_expression($1); }
|
||||
| constant { $$ = $1; }
|
||||
| STRING_LITERAL { $$ = new string_literal(yytext); }
|
||||
: identifier { $$ = new named_expression($1); }
|
||||
| constant { $$ = $1; }
|
||||
| builtin { $$ = $1; }
|
||||
| STRING_LITERAL { $$ = new string_literal(yytext); }
|
||||
| '(' expression ')' { $$ = $1; }
|
||||
;
|
||||
|
||||
|
@@ -31,6 +31,8 @@ int comment();
|
||||
"int64" { count(); return(INT64); }
|
||||
"fp32" { count(); return(FP32); }
|
||||
"fp64" { count(); return(FP64); }
|
||||
"..." { count(); return(ELLIPSIS); }
|
||||
"get_global_range" { count(); return GET_GLOBAL_RANGE; }
|
||||
|
||||
{L}({L}|{D})* { count(); return(check_type()); }
|
||||
|
||||
|
@@ -27,6 +27,7 @@ private:
|
||||
|
||||
|
||||
public:
|
||||
void get_params(ir::module& mod, std::vector<unsigned*> &result);
|
||||
unsigned *get_param(ir::value *value);
|
||||
bool check_constraints(ir::module &fn, std::map<ir::value *, std::vector<std::string>> &errors);
|
||||
void run(ir::module &mod);
|
||||
|
@@ -111,6 +111,8 @@ public:
|
||||
value *create_splat(value *arg, const std::vector<unsigned> &shapes, const std::string &name = "");
|
||||
value *create_reshape(value *arg, const std::vector<unsigned> &shapes, const std::string &name = "");
|
||||
value *create_broadcast(value *arg, const std::vector<unsigned> &shapes, const std::string &name = "");
|
||||
// Built-in instruction
|
||||
value *create_get_global_range(unsigned axis, unsigned size, const std::string &name = "");
|
||||
|
||||
private:
|
||||
context &ctx_;
|
||||
|
@@ -40,6 +40,18 @@ private:
|
||||
uint64_t value_;
|
||||
};
|
||||
|
||||
/* constant range */
|
||||
class constant_range: public constant{
|
||||
constant_range(type *ty, uint64_t first, uint64_t last);
|
||||
|
||||
public:
|
||||
static constant *get(constant *first, constant *last);
|
||||
|
||||
private:
|
||||
uint64_t first_;
|
||||
uint64_t last_;
|
||||
};
|
||||
|
||||
/* constant fp */
|
||||
class constant_fp: public constant{
|
||||
constant_fp(context &ctx, double value);
|
||||
|
@@ -333,11 +333,28 @@ public:
|
||||
|
||||
|
||||
// matmul
|
||||
|
||||
class matmul_inst: public instruction {
|
||||
|
||||
};
|
||||
|
||||
// built-in
|
||||
class builtin_inst: public instruction{
|
||||
protected:
|
||||
using instruction::instruction;
|
||||
};
|
||||
|
||||
class get_global_range_inst: public builtin_inst {
|
||||
get_global_range_inst(type *ty, unsigned axis, const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
static instruction* create(context &ctx, unsigned axis, unsigned size,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
|
||||
private:
|
||||
unsigned axis_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -415,6 +415,13 @@ ir::value* binary_operator::codegen(ir::module *mod) const{
|
||||
return result;
|
||||
}
|
||||
|
||||
/* Builtin expression */
|
||||
ir::value* get_global_range::codegen(ir::module *mod) const {
|
||||
ir::builder &builder = mod->get_builder();
|
||||
return builder.create_get_global_range(axis_->value(), size_->value());
|
||||
}
|
||||
|
||||
|
||||
/* Postfix expression */
|
||||
ir::value* indexing_expression::codegen(ir::module *mod) const{
|
||||
ir::value *in = mod->get_value(id_->name());
|
||||
@@ -509,6 +516,11 @@ int constant::value() const{
|
||||
return value_;
|
||||
}
|
||||
|
||||
/* Constant range */
|
||||
ir::value* constant_range::codegen(ir::module *mod) const{
|
||||
return ir::constant_range::get((ir::constant*)first_->codegen(mod),
|
||||
(ir::constant*)last_->codegen(mod));
|
||||
}
|
||||
|
||||
/* Unary expression */
|
||||
const identifier* unary_expression::id() const{
|
||||
|
@@ -62,6 +62,17 @@ void tune::connected_components(node_t x, const std::vector<unsigned *> vals, st
|
||||
}
|
||||
}
|
||||
|
||||
void tune::get_params(ir::module &mod, std::vector<unsigned *> &result) {
|
||||
result.clear();
|
||||
std::set<unsigned*> seen;
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i : block->get_inst_list())
|
||||
for(auto &x: params_[i])
|
||||
if(seen.insert(x.second).second)
|
||||
result.push_back(x.second);
|
||||
}
|
||||
|
||||
void tune::run(ir::module &mod) {
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
// Build constraints graph
|
||||
|
@@ -246,7 +246,13 @@ value *builder::create_broadcast(value *arg, const std::vector<unsigned> &shapes
|
||||
return insert(broadcast_inst::create(arg, shapes, name));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// built-in instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
value *builder::create_get_global_range(unsigned axis, unsigned size, const std::string &name) {
|
||||
return insert(get_global_range_inst::create(ctx_, axis, size, name));
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
@@ -1,3 +1,4 @@
|
||||
#include <cassert>
|
||||
#include "ir/constant.h"
|
||||
#include "ir/type.h"
|
||||
#include "ir/context.h"
|
||||
@@ -51,6 +52,21 @@ constant *constant_int::get(type *ty, uint64_t value) {
|
||||
return new constant_int(ty, value);
|
||||
}
|
||||
|
||||
// constant_range
|
||||
// FIXME use something like APInt
|
||||
|
||||
constant_range::constant_range(type *ty, uint64_t first, uint64_t last)
|
||||
: constant(ty, 0), first_(first), last_(last){ }
|
||||
|
||||
constant *constant_range::get(constant *first, constant *last) {
|
||||
assert(first->get_type()->is_integer_ty());
|
||||
assert(first->get_type() == last->get_type());
|
||||
uint64_t vfirst = ((constant_int*)first)->get_value();
|
||||
uint64_t vlast = ((constant_int*)first)->get_value();
|
||||
return new constant_range(first->get_type(), vfirst, vlast);
|
||||
}
|
||||
|
||||
|
||||
// constant_fp
|
||||
// FIXME use something like APFloat
|
||||
|
||||
|
@@ -246,15 +246,17 @@ getelementptr_inst::getelementptr_inst(type *pointee_ty, value *ptr, const std::
|
||||
set_operand(1 + i, idx[i]);
|
||||
}
|
||||
|
||||
type *getelementptr_inst::get_return_type(type *elt_ty, value *ptr, const std::vector<value *> &idx_list) {
|
||||
type *getelementptr_inst::get_return_type(type *elt_ty, value *x, const std::vector<value *> &idx_list) {
|
||||
// result pointer type
|
||||
type *ptr_ty = pointer_type::get(get_indexed_type(elt_ty, idx_list), ptr->get_type()->get_scalar_ty()->get_pointer_address_space());
|
||||
type *ty = x->get_type();
|
||||
unsigned addr_space = ty->get_scalar_ty()->get_pointer_address_space();
|
||||
type *ptr_ty = pointer_type::get(get_indexed_type(elt_ty, idx_list), addr_space);
|
||||
// Tile GEP
|
||||
if(ptr->get_type()->is_tile_ty())
|
||||
return tile_type::get_same_shapes(ptr_ty, ptr->get_type());
|
||||
if(ty->is_tile_ty())
|
||||
return tile_type::get_same_shapes(ptr_ty, ty);
|
||||
for(value *idx : idx_list)
|
||||
if (idx->get_type()->is_tile_ty())
|
||||
return tile_type::get_same_shapes(ptr_ty, idx->get_type());
|
||||
return tile_type::get_same_shapes(ptr_ty, ty);
|
||||
// Scalar GEP
|
||||
return ptr_ty;
|
||||
}
|
||||
@@ -329,5 +331,27 @@ instruction* broadcast_inst::create(value *arg, const std::vector<unsigned> &sha
|
||||
return new broadcast_inst(arg, shapes, name, next);
|
||||
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// matmul_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// builtin instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
get_global_range_inst::get_global_range_inst(type *ty, unsigned axis,
|
||||
const std::string &name, instruction *next)
|
||||
: builtin_inst(ty, 0, name, next), axis_(axis) {
|
||||
|
||||
}
|
||||
|
||||
instruction* get_global_range_inst::create(context &ctx, unsigned axis, unsigned size,
|
||||
const std::string &name, instruction *next) {
|
||||
type *int_ty = type::get_int32_ty(ctx);
|
||||
type *tile_ty = tile_type::get(int_ty, {size});
|
||||
return new get_global_range_inst(tile_ty, axis, name, next);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user