[syntax tree] added basic support for range

This commit is contained in:
Philippe Tillet
2019-01-09 02:07:34 -05:00
parent 7dfa578c9d
commit 4f923accd7
13 changed files with 155 additions and 18 deletions

View File

@@ -20,11 +20,9 @@ extern translation_unit *ast_root;
const char src[] =
"\
void test(fp32 *A, fp32 *B, fp32 *C, int32 i){\
int32 tile[16, 16] = 0;\
fp32 *test[16, 16] = tile + A;\
i = 1;\
A = A + i;\
void test(fp32 *A, fp32 *B, fp32 *C, int32 M, int32 N, int32 K){\
fp32 acc[16, 16] = 0;\
fp32 *pa[16, 8] = A;\
}\
";
@@ -42,6 +40,9 @@ int main() {
tdl::codegen::selection selection;
tdl::codegen::tune tune;
tune.run(module);
std::vector<unsigned*> params;
tune.get_params(module, params);
std::cout << params.size() << std::endl;
// selection.run(module, llvm_module);
// // print LLVM program
// llvm::PrintModulePass print(llvm::outs());

View File

@@ -58,6 +58,7 @@ enum TYPE_T{
class pointer;
class identifier;
class constant;
// AST
class node {
@@ -121,6 +122,21 @@ class postfix_expression: public expression{
};
class builtin_expression: public node{
};
class get_global_range: public builtin_expression{
public:
get_global_range(node *size, node *axis): size_((constant*)size), axis_((constant*)axis) { }
ir::value* codegen(ir::module *) const;
private:
const constant* size_;
const constant* axis_;
};
class indexing_expression: public postfix_expression{
public:
indexing_expression(node *id, node *ranges)
@@ -133,7 +149,7 @@ private:
const list<range*>* ranges_;
};
class unary_expression: public node{
class unary_expression: public expression{
public:
unary_expression(node *id): id_((const identifier*)id) {}
const identifier *id() const;
@@ -174,6 +190,17 @@ private:
const int value_;
};
class constant_range: public expression {
public:
constant_range(node *first, node *last)
: first_((constant*)first), last_((constant*)last) { }
ir::value* codegen(ir::module *mod) const;
private:
constant *first_;
constant *last_;
};
class string_literal: public expression{
public:

View File

@@ -49,7 +49,8 @@ TYPE_T get_type_spec(node *op) { return ((token*)op)->type; }
%token XOR_ASSIGN OR_ASSIGN TYPE_NAME
%token VOID UINT8 UINT16 UINT32 UINT64 INT8 INT16 INT32 INT64 FP32 FP64
%token IF ELSE FOR
%token NEWAXIS
%token NEWAXIS ELLIPSIS
%token GET_GLOBAL_RANGE
%start translation_unit
%%
@@ -87,7 +88,8 @@ direct_abstract_declarator
: '[' constant_list ']' { $$ = new tile(nullptr, $1); }
constant :
CONSTANT { $$ = new constant(atoi(yytext)); }
CONSTANT { $$ = new constant(atoi(yytext)); }
| constant ELLIPSIS constant { $$ = new constant_range($1, $2); }
;
constant_list
@@ -107,11 +109,15 @@ type_name
identifier
: IDENTIFIER { $$ = new identifier(yytext); }
;
builtin
: GET_GLOBAL_RANGE '[' constant ']' '(' constant ')' { $$ = new get_global_range($3, $6); }
primary_expression
: identifier { $$ = new named_expression($1); }
| constant { $$ = $1; }
| STRING_LITERAL { $$ = new string_literal(yytext); }
: identifier { $$ = new named_expression($1); }
| constant { $$ = $1; }
| builtin { $$ = $1; }
| STRING_LITERAL { $$ = new string_literal(yytext); }
| '(' expression ')' { $$ = $1; }
;

View File

@@ -31,6 +31,8 @@ int comment();
"int64" { count(); return(INT64); }
"fp32" { count(); return(FP32); }
"fp64" { count(); return(FP64); }
"..." { count(); return(ELLIPSIS); }
"get_global_range" { count(); return GET_GLOBAL_RANGE; }
{L}({L}|{D})* { count(); return(check_type()); }

View File

@@ -27,6 +27,7 @@ private:
public:
void get_params(ir::module& mod, std::vector<unsigned*> &result);
unsigned *get_param(ir::value *value);
bool check_constraints(ir::module &fn, std::map<ir::value *, std::vector<std::string>> &errors);
void run(ir::module &mod);

View File

@@ -111,6 +111,8 @@ public:
value *create_splat(value *arg, const std::vector<unsigned> &shapes, const std::string &name = "");
value *create_reshape(value *arg, const std::vector<unsigned> &shapes, const std::string &name = "");
value *create_broadcast(value *arg, const std::vector<unsigned> &shapes, const std::string &name = "");
// Built-in instruction
value *create_get_global_range(unsigned axis, unsigned size, const std::string &name = "");
private:
context &ctx_;

View File

@@ -40,6 +40,18 @@ private:
uint64_t value_;
};
/* constant range */
class constant_range: public constant{
constant_range(type *ty, uint64_t first, uint64_t last);
public:
static constant *get(constant *first, constant *last);
private:
uint64_t first_;
uint64_t last_;
};
/* constant fp */
class constant_fp: public constant{
constant_fp(context &ctx, double value);

View File

@@ -333,11 +333,28 @@ public:
// matmul
class matmul_inst: public instruction {
};
// built-in
class builtin_inst: public instruction{
protected:
using instruction::instruction;
};
class get_global_range_inst: public builtin_inst {
get_global_range_inst(type *ty, unsigned axis, const std::string &name, instruction *next);
public:
static instruction* create(context &ctx, unsigned axis, unsigned size,
const std::string &name = "",
instruction *next = nullptr);
private:
unsigned axis_;
};
}
}

View File

@@ -415,6 +415,13 @@ ir::value* binary_operator::codegen(ir::module *mod) const{
return result;
}
/* Builtin expression */
ir::value* get_global_range::codegen(ir::module *mod) const {
ir::builder &builder = mod->get_builder();
return builder.create_get_global_range(axis_->value(), size_->value());
}
/* Postfix expression */
ir::value* indexing_expression::codegen(ir::module *mod) const{
ir::value *in = mod->get_value(id_->name());
@@ -509,6 +516,11 @@ int constant::value() const{
return value_;
}
/* Constant range */
ir::value* constant_range::codegen(ir::module *mod) const{
return ir::constant_range::get((ir::constant*)first_->codegen(mod),
(ir::constant*)last_->codegen(mod));
}
/* Unary expression */
const identifier* unary_expression::id() const{

View File

@@ -62,6 +62,17 @@ void tune::connected_components(node_t x, const std::vector<unsigned *> vals, st
}
}
void tune::get_params(ir::module &mod, std::vector<unsigned *> &result) {
result.clear();
std::set<unsigned*> seen;
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i : block->get_inst_list())
for(auto &x: params_[i])
if(seen.insert(x.second).second)
result.push_back(x.second);
}
void tune::run(ir::module &mod) {
for(ir::function *fn: mod.get_function_list()){
// Build constraints graph

View File

@@ -246,7 +246,13 @@ value *builder::create_broadcast(value *arg, const std::vector<unsigned> &shapes
return insert(broadcast_inst::create(arg, shapes, name));
}
//===----------------------------------------------------------------------===//
// built-in instructions
//===----------------------------------------------------------------------===//
value *builder::create_get_global_range(unsigned axis, unsigned size, const std::string &name) {
return insert(get_global_range_inst::create(ctx_, axis, size, name));
}
}
}

View File

@@ -1,3 +1,4 @@
#include <cassert>
#include "ir/constant.h"
#include "ir/type.h"
#include "ir/context.h"
@@ -51,6 +52,21 @@ constant *constant_int::get(type *ty, uint64_t value) {
return new constant_int(ty, value);
}
// constant_range
// FIXME use something like APInt
constant_range::constant_range(type *ty, uint64_t first, uint64_t last)
: constant(ty, 0), first_(first), last_(last){ }
constant *constant_range::get(constant *first, constant *last) {
assert(first->get_type()->is_integer_ty());
assert(first->get_type() == last->get_type());
uint64_t vfirst = ((constant_int*)first)->get_value();
uint64_t vlast = ((constant_int*)first)->get_value();
return new constant_range(first->get_type(), vfirst, vlast);
}
// constant_fp
// FIXME use something like APFloat

View File

@@ -246,15 +246,17 @@ getelementptr_inst::getelementptr_inst(type *pointee_ty, value *ptr, const std::
set_operand(1 + i, idx[i]);
}
type *getelementptr_inst::get_return_type(type *elt_ty, value *ptr, const std::vector<value *> &idx_list) {
type *getelementptr_inst::get_return_type(type *elt_ty, value *x, const std::vector<value *> &idx_list) {
// result pointer type
type *ptr_ty = pointer_type::get(get_indexed_type(elt_ty, idx_list), ptr->get_type()->get_scalar_ty()->get_pointer_address_space());
type *ty = x->get_type();
unsigned addr_space = ty->get_scalar_ty()->get_pointer_address_space();
type *ptr_ty = pointer_type::get(get_indexed_type(elt_ty, idx_list), addr_space);
// Tile GEP
if(ptr->get_type()->is_tile_ty())
return tile_type::get_same_shapes(ptr_ty, ptr->get_type());
if(ty->is_tile_ty())
return tile_type::get_same_shapes(ptr_ty, ty);
for(value *idx : idx_list)
if (idx->get_type()->is_tile_ty())
return tile_type::get_same_shapes(ptr_ty, idx->get_type());
return tile_type::get_same_shapes(ptr_ty, ty);
// Scalar GEP
return ptr_ty;
}
@@ -329,5 +331,27 @@ instruction* broadcast_inst::create(value *arg, const std::vector<unsigned> &sha
return new broadcast_inst(arg, shapes, name, next);
}
//===----------------------------------------------------------------------===//
// matmul_inst classes
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// builtin instructions
//===----------------------------------------------------------------------===//
get_global_range_inst::get_global_range_inst(type *ty, unsigned axis,
const std::string &name, instruction *next)
: builtin_inst(ty, 0, name, next), axis_(axis) {
}
instruction* get_global_range_inst::create(context &ctx, unsigned axis, unsigned size,
const std::string &name, instruction *next) {
type *int_ty = type::get_int32_ty(ctx);
type *tile_ty = tile_type::get(int_ty, {size});
return new get_global_range_inst(tile_ty, axis, name, next);
}
}
}