[intermediate representation] transitioning towards more flexible tile shapes
This commit is contained in:
Philippe Tillet
2019-02-23 11:37:01 -05:00
parent 7cda55df16
commit 8f4798b81a
21 changed files with 268 additions and 115 deletions

1
TODO
View File

@@ -2,3 +2,4 @@
- proper naming scheme - proper naming scheme
- symbols table - symbols table
- name conflicts on globals? - name conflicts on globals?
- separate header for typedef (e.g., type::tile_shapes_t) to reduce compilation time

View File

@@ -6,6 +6,7 @@
#include "ir/context.h" #include "ir/context.h"
#include "ir/module.h" #include "ir/module.h"
#include "ir/print.h" #include "ir/print.h"
#include "ir/context_impl.h"
#include "codegen/selection.h" #include "codegen/selection.h"
#include "codegen/tune.h" #include "codegen/tune.h"
#include "codegen/shared_copy.h" #include "codegen/shared_copy.h"
@@ -182,6 +183,9 @@ int main() {
llvm::LLVMContext llvm_context; llvm::LLVMContext llvm_context;
llvm::Module llvm_module("test", llvm_context); llvm::Module llvm_module("test", llvm_context);
// context.p_impl->mp_constants_[0]->set_value(16);
// context.p_impl->mp_constants_[1]->set_value(16);
// context.p_impl->mp_constants_[2]->set_value(8);
// create passes // create passes
tdl::codegen::buffer_info_pass buffer_info; tdl::codegen::buffer_info_pass buffer_info;

View File

@@ -5,7 +5,7 @@
#include <cassert> #include <cassert>
#include <vector> #include <vector>
#include <string> #include <string>
#include <iostream>
namespace tdl{ namespace tdl{
@@ -56,6 +56,12 @@ enum TYPE_T{
FLOAT32_T, FLOAT64_T FLOAT32_T, FLOAT64_T
}; };
// Storage specifiers that can prefix a declaration's type specifier.
// The grammar creates one of these for each of the keywords
// `tunable`, `kernel`, `readonly` and `writeonly`.
enum STORAGE_SPEC_T{
// `tunable` keyword
TUNABLE_T,
// `kernel` keyword
KERNEL_T,
// `readonly` / `writeonly` keywords
READONLY_T, WRITEONLY_T,
};
class pointer; class pointer;
class identifier; class identifier;
class constant; class constant;
@@ -75,7 +81,7 @@ public:
template<class T> template<class T>
class list: public node { class list: public node {
public: public:
list(const T& x): values_{x} {} list(const T& x): values_(1, x) {}
node* append(const T& x){ node* append(const T& x){
values_.push_back(x); values_.push_back(x);
@@ -389,16 +395,30 @@ public:
class no_op: public statement { }; class no_op: public statement { };
// Types // Types
class declaration_specifier: public node{ class declaration_specifier: public node{
public: public:
declaration_specifier(TYPE_T spec) using node::node;
: spec_(spec) { } virtual ir::type* type(ir::module *mod) const = 0;
};
class typed_declaration_specifier: public declaration_specifier {
public:
typed_declaration_specifier(TYPE_T ty): ty_(ty){ }
ir::type* type(ir::module *mod) const; ir::type* type(ir::module *mod) const;
private: private:
const TYPE_T spec_; const TYPE_T ty_;
};
class storage_declaration_specifier: public declaration_specifier {
public:
storage_declaration_specifier(STORAGE_SPEC_T storage_spec, node *decl_spec)
: storage_spec_(storage_spec), decl_spec_((declaration_specifier*)decl_spec) {}
ir::type* type(ir::module *mod) const;
private:
const STORAGE_SPEC_T storage_spec_;
const declaration_specifier* decl_spec_;
}; };
class declarator; class declarator;
@@ -495,7 +515,7 @@ public:
: declarator((node*)((declarator*)decl)->id()), : declarator((node*)((declarator*)decl)->id()),
decl_((declarator*)decl), expr_((expression*)init){ } decl_((declarator*)decl), expr_((expression*)init){ }
void specifier(const declaration_specifier *spec); void set_specifier(const declaration_specifier *spec);
ir::value* codegen(ir::module *) const; ir::value* codegen(ir::module *) const;
public: public:
@@ -535,17 +555,17 @@ public:
class translation_unit: public node{ class translation_unit: public node{
public: public:
translation_unit(node *item) translation_unit(node *item)
: decls_((list<node*>*)item) { } : decls_(item) { }
translation_unit *add(node *item) { translation_unit *add(node *item) {
decls_->append(item); decls_.append(item);
return this; return this;
} }
ir::value* codegen(ir::module * mod) const; ir::value* codegen(ir::module * mod) const;
private: private:
list<node*>* decls_; list<node*> decls_;
}; };
} }

View File

@@ -20,12 +20,14 @@ struct token: public node{
token(BIN_OP_T value): bin_op(value){ } token(BIN_OP_T value): bin_op(value){ }
token(UNARY_OP_T value): unary_op(value){ } token(UNARY_OP_T value): unary_op(value){ }
token(TYPE_T value): type(value){ } token(TYPE_T value): type(value){ }
token(STORAGE_SPEC_T value): storage_spec(value){ }
union { union {
ASSIGN_OP_T assign_op; ASSIGN_OP_T assign_op;
BIN_OP_T bin_op; BIN_OP_T bin_op;
UNARY_OP_T unary_op; UNARY_OP_T unary_op;
TYPE_T type; TYPE_T type;
STORAGE_SPEC_T storage_spec;
}; };
}; };
@@ -39,10 +41,12 @@ node* append_ptr_list(node *result, node *in){
ASSIGN_OP_T get_assign_op(node *op) { return ((token*)op)->assign_op; } ASSIGN_OP_T get_assign_op(node *op) { return ((token*)op)->assign_op; }
UNARY_OP_T get_unary_op(node *op) { return ((token*)op)->unary_op; } UNARY_OP_T get_unary_op(node *op) { return ((token*)op)->unary_op; }
TYPE_T get_type_spec(node *op) { return ((token*)op)->type; } TYPE_T get_type_spec(node *op) { return ((token*)op)->type; }
STORAGE_SPEC_T get_storage_spec(node *op) { return ((token*)op)->storage_spec;}
%} %}
%token IDENTIFIER CONSTANT STRING_LITERAL %token IDENTIFIER CONSTANT STRING_LITERAL
%token TUNABLE KERNEL READONLY WRITEONLY
%token PTR_OP INC_OP DEC_OP LEFT_OP RIGHT_OP LE_OP GE_OP EQ_OP NE_OP %token PTR_OP INC_OP DEC_OP LEFT_OP RIGHT_OP LE_OP GE_OP EQ_OP NE_OP
%token AND_OP OR_OP MUL_ASSIGN DIV_ASSIGN MOD_ASSIGN ADD_ASSIGN %token AND_OP OR_OP MUL_ASSIGN DIV_ASSIGN MOD_ASSIGN ADD_ASSIGN
%token SUB_ASSIGN LEFT_ASSIGN RIGHT_ASSIGN AND_ASSIGN %token SUB_ASSIGN LEFT_ASSIGN RIGHT_ASSIGN AND_ASSIGN
@@ -87,17 +91,12 @@ abstract_declarator
; ;
direct_abstract_declarator direct_abstract_declarator
: '[' constant_list ']' { $$ = new tile(nullptr, $1); } : '[' primary_expression_list ']' { $$ = new tile(nullptr, $1); }
constant : constant :
CONSTANT { $$ = new constant(atoi(yytext)); } CONSTANT { $$ = new constant(atoi(yytext)); }
; ;
constant_list
: constant { $$ = new list<constant*>((constant*)$1); }
| constant_list ',' constant { $$ = append_ptr_list<constant>($1, $3); }
;
type_name type_name
: declaration_specifiers { $$ = new type_name($1, nullptr); } : declaration_specifiers { $$ = new type_name($1, nullptr); }
| declaration_specifiers abstract_declarator { $$ = new type_name($1, $2); } | declaration_specifiers abstract_declarator { $$ = new type_name($1, $2); }
@@ -112,7 +111,7 @@ identifier
; ;
builtin builtin
: GET_GLOBAL_RANGE '[' constant ']' '(' constant ')' { $$ = new get_global_range($3, $6); } : GET_GLOBAL_RANGE '[' primary_expression ']' '(' constant ')' { $$ = new get_global_range($3, $6); }
| DOT '(' expression ',' expression ',' expression ')' { $$ = new matmul_expression($3, $5, $7); } | DOT '(' expression ',' expression ',' expression ')' { $$ = new matmul_expression($3, $5, $7); }
primary_expression primary_expression
@@ -124,6 +123,11 @@ primary_expression
| '(' expression ')' { $$ = $2; } | '(' expression ')' { $$ = $2; }
; ;
/* Comma-separated list of primary expressions.
   The first element creates a list<expression*>; every further element is
   appended to that list through append_ptr_list<expression>. */
primary_expression_list
: primary_expression { $$ = new list<expression*>((expression*)$1); }
| primary_expression_list ',' primary_expression { $$ = append_ptr_list<expression>($1, $3); }
;
slice slice
: ':' { $$ = new slice(tdl::ast::ALL); } : ':' { $$ = new slice(tdl::ast::ALL); }
| NEWAXIS { $$ = new slice(tdl::ast::NEWAXIS); } | NEWAXIS { $$ = new slice(tdl::ast::NEWAXIS); }
@@ -312,7 +316,7 @@ jump_statement
direct_declarator direct_declarator
: identifier { $$ = $1; } : identifier { $$ = $1; }
| identifier '[' constant_list ']' { $$ = new tile($1, $3); } | identifier '[' primary_expression_list ']' { $$ = new tile($1, $3); }
| identifier '(' parameter_list ')' { $$ = new function($1, $3); } | identifier '(' parameter_list ')' { $$ = new function($1, $3); }
| identifier '(' ')' { $$ = new function($1, nullptr); } | identifier '(' ')' { $$ = new function($1, nullptr); }
; ;
@@ -330,7 +334,8 @@ parameter_declaration
declaration_specifiers declaration_specifiers
: type_specifier { $$ = new declaration_specifier(get_type_spec($1)); } : type_specifier { $$ = new typed_declaration_specifier(get_type_spec($1)); }
| storage_class_specifier declaration_specifiers { $$ = new storage_declaration_specifier(get_storage_spec($1), $2); }
; ;
init_declarator_list init_declarator_list
@@ -354,6 +359,13 @@ init_declarator
| declarator '=' initialization_expression { $$ = new initializer($1, $3); } | declarator '=' initialization_expression { $$ = new initializer($1, $3); }
; ;
/* Storage-class keywords.
   Each alternative wraps the matching STORAGE_SPEC_T value in a token node;
   the declaration_specifiers rule reads it back with get_storage_spec(). */
storage_class_specifier
: TUNABLE { $$ = new token(TUNABLE_T); }
| KERNEL { $$ = new token(KERNEL_T); }
| READONLY { $$ = new token(READONLY_T); }
| WRITEONLY { $$ = new token(WRITEONLY_T); }
;
/* -------------------------- */ /* -------------------------- */
/* Translation Unit */ /* Translation Unit */
/* -------------------------- */ /* -------------------------- */

View File

@@ -16,6 +16,10 @@ int comment();
%} %}
%% %%
"tunable" { count(); return(TUNABLE); }
"kernel" { count(); return(KERNEL); }
"readonly" { count(); return(READONLY); }
"writeonly" { count(); return(WRITEONLY); }
"@" { count(); return(AT); } "@" { count(); return(AT); }
"newaxis" { count(); return(NEWAXIS); } "newaxis" { count(); return(NEWAXIS); }
"if" { count(); return(IF); } "if" { count(); return(IF); }

View File

@@ -32,7 +32,7 @@ private:
void add_reference(ir::value *v, interval_vec_t &res); void add_reference(ir::value *v, interval_vec_t &res);
void get_read_intervals(ir::instruction *i, interval_vec_t &res); void get_read_intervals(ir::instruction *i, interval_vec_t &res);
void get_written_intervals(ir::instruction *i, interval_vec_t &res); void get_written_intervals(ir::instruction *i, interval_vec_t &res);
void add(ir::basic_block *block, interval_vec_t &not_synced, std::set<ir::instruction *> &insert_pts); void add(ir::basic_block *block, interval_vec_t &not_synced, ir::builder &builder);
public: public:
barriers(allocation *alloc, buffer_info_pass *buffer_info): alloc_(alloc), buffer_info_(buffer_info) {} barriers(allocation *alloc, buffer_info_pass *buffer_info): alloc_(alloc), buffer_info_(buffer_info) {}

View File

@@ -100,6 +100,7 @@ class selection{
private: private:
// utils // utils
llvm::Type *make_vector_ty(llvm::Type *ty, size_t vector_size); llvm::Type *make_vector_ty(llvm::Type *ty, size_t vector_size);
std::vector<unsigned> extract_shapes(ir::value *v);
// LLVM conversions // LLVM conversions
llvm::Type* llvm_type(ir::type *ty, llvm::LLVMContext &ctx); llvm::Type* llvm_type(ir::type *ty, llvm::LLVMContext &ctx);

View File

@@ -6,6 +6,7 @@
#include <string> #include <string>
#include "instructions.h" #include "instructions.h"
#include "basic_block.h" #include "basic_block.h"
#include "type.h"
namespace tdl{ namespace tdl{
namespace ir{ namespace ir{
@@ -110,11 +111,11 @@ public:
value *create_load(value *arg, const std::string &name = ""); value *create_load(value *arg, const std::string &name = "");
value *create_store(value *ptr, value *val, const std::string &name = ""); value *create_store(value *ptr, value *val, const std::string &name = "");
// Tile instruction // Tile instruction
value *create_splat(value *arg, const std::vector<unsigned> &shapes, const std::string &name = ""); value *create_splat(value *arg, const type::tile_shapes_t &shapes, const std::string &name = "");
value *create_reshape(value *arg, const std::vector<unsigned> &shapes, const std::string &name = ""); value *create_reshape(value *arg, const type::tile_shapes_t &shapes, const std::string &name = "");
value *create_broadcast(value *arg, const std::vector<unsigned> &shapes, const std::string &name = ""); value *create_broadcast(value *arg, const type::tile_shapes_t &shapes, const std::string &name = "");
// Built-in instruction // Built-in instruction
value *create_get_global_range(unsigned axis, unsigned size, const std::string &name = ""); value *create_get_global_range(unsigned axis, type::tile_shapes_t::value_type size, const std::string &name = "");
value *create_matmul(value *A, value *B, value *C, const std::string &name = ""); value *create_matmul(value *A, value *B, value *C, const std::string &name = "");
// Intrinsics // Intrinsics
value *create_copy_to_shared(value *arg, const std::string &name = ""); value *create_copy_to_shared(value *arg, const std::string &name = "");

View File

@@ -2,6 +2,7 @@
#define TDL_INCLUDE_IR_CONSTANT_H #define TDL_INCLUDE_IR_CONSTANT_H
#include "value.h" #include "value.h"
#include <cassert>
namespace tdl{ namespace tdl{
namespace ir{ namespace ir{
@@ -28,28 +29,43 @@ public:
static undef_value* get(type* ty); static undef_value* get(type* ty);
}; };
/* Constant int */ /* Constant int */
class constant_int: public constant{ class constant_int: public constant{
protected:
constant_int(type *ty, uint64_t value); constant_int(type *ty, uint64_t value);
public: public:
uint64_t get_value() const { return value_; } uint64_t get_value() const { return value_; }
static constant *get(type *ty, uint64_t value); static constant_int *get(type *ty, uint64_t value);
protected:
uint64_t value_;
};
/* Metaparameter int */
class metaparameter: public constant_int{
metaparameter(type *ty, unsigned lo, unsigned hi);
public:
static metaparameter *create(context &ctx, type *ty, unsigned lo, unsigned hi);
void set_value(uint64_t value) { value_ = value; }
private: private:
uint64_t value_; unsigned lo_;
unsigned hi_;
}; };
/* constant range */ /* constant range */
class constant_range: public constant{ class constant_range: public constant{
constant_range(type *ty, uint64_t first, uint64_t last); constant_range(type *ty, constant_int* first, constant_int* last);
public: public:
static constant *get(constant *first, constant *last); static constant *get(constant_int *first, constant_int *last);
private: private:
uint64_t first_; constant_int* first_;
uint64_t last_; constant_int* last_;
}; };
/* constant fp */ /* constant fp */

View File

@@ -12,6 +12,7 @@ class context;
class constant_int; class constant_int;
class constant_fp; class constant_fp;
class undef_value; class undef_value;
class metaparameter;
/* Context impl */ /* Context impl */
class context_impl { class context_impl {
@@ -26,13 +27,15 @@ public:
integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty; integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty;
// Pointer types // Pointer types
std::map<std::pair<type*, unsigned>, pointer_type*> ptr_tys; std::map<std::pair<type*, unsigned>, pointer_type*> ptr_tys;
std::map<std::pair<type*,std::vector<unsigned>>, tile_type*> tile_tys; std::map<std::pair<type*, type::tile_shapes_t>, tile_type*> tile_tys;
// Int constants // Int constants
std::map<uint64_t, constant_int*> int_constants_; std::map<std::pair<type*, uint64_t>, constant_int*> int_constants_;
// Float constants // Float constants
std::map<double, constant_fp*> fp_constants_; std::map<double, constant_fp*> fp_constants_;
// undef values // undef values
std::map<type*, undef_value*> uv_constants_; std::map<type*, undef_value*> uv_constants_;
// Metaparameters
std::vector<metaparameter*> mp_constants_;
}; };
} }

View File

@@ -3,6 +3,7 @@
#include <vector> #include <vector>
#include "value.h" #include "value.h"
#include "ir/type.h"
#include "llvm/IR/Instructions.h" #include "llvm/IR/Instructions.h"
namespace tdl{ namespace tdl{
@@ -358,7 +359,7 @@ public:
class retile_inst: public unary_inst { class retile_inst: public unary_inst {
protected: protected:
retile_inst(value *arg, const std::vector<unsigned> &shape_suffix, const std::string &name, instruction *next); retile_inst(value *arg, const type::tile_shapes_t &shapes, const std::string &name, instruction *next);
static std::string shape_suffix(ir::type* ty); static std::string shape_suffix(ir::type* ty);
}; };
@@ -370,7 +371,7 @@ private:
std::string repr_impl() const { return "reshape" + shape_suffix(get_type()); } std::string repr_impl() const { return "reshape" + shape_suffix(get_type()); }
public: public:
static instruction* create(value *arg, const std::vector<unsigned> &shape_suffix, static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix,
const std::string &name = "", instruction *next = nullptr); const std::string &name = "", instruction *next = nullptr);
}; };
@@ -382,7 +383,7 @@ private:
std::string repr_impl() const { return "splat" + shape_suffix(get_type()); } std::string repr_impl() const { return "splat" + shape_suffix(get_type()); }
public: public:
static instruction* create(value *arg, const std::vector<unsigned> &shape_suffix, static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix,
const std::string &name = "", instruction *next = nullptr); const std::string &name = "", instruction *next = nullptr);
}; };
@@ -394,7 +395,7 @@ private:
std::string repr_impl() const { return "broadcast" + shape_suffix(get_type()); } std::string repr_impl() const { return "broadcast" + shape_suffix(get_type()); }
public: public:
static instruction* create(value *arg, const std::vector<unsigned> &shape_suffix, static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix,
const std::string &name = "", instruction *next = nullptr); const std::string &name = "", instruction *next = nullptr);
}; };
@@ -414,7 +415,7 @@ private:
std::string repr_impl() const { return "get_global_range(" + std::to_string(axis_) + ")"; } std::string repr_impl() const { return "get_global_range(" + std::to_string(axis_) + ")"; }
public: public:
static instruction* create(context &ctx, unsigned axis, unsigned size, static instruction* create(context &ctx, unsigned axis, type::tile_shapes_t::value_type size,
const std::string &name = "", const std::string &name = "",
instruction *next = nullptr); instruction *next = nullptr);
unsigned get_axis() const { return axis_; } unsigned get_axis() const { return axis_; }

View File

@@ -3,6 +3,7 @@
#include <map> #include <map>
#include <set> #include <set>
#include <stack>
#include <string> #include <string>
#include <functional> #include <functional>
#include "builder.h" #include "builder.h"
@@ -12,6 +13,7 @@ namespace tdl{
namespace ast{ namespace ast{
class iteration_statement; class iteration_statement;
class compound_statement;
} }
@@ -69,6 +71,10 @@ public:
const functions_list_t &get_function_list() const { return functions_; } const functions_list_t &get_function_list() const { return functions_; }
functions_list_t &get_function_list() { return functions_; } functions_list_t &get_function_list() { return functions_; }
function *get_or_insert_function(const std::string &name, function_type *ty); function *get_or_insert_function(const std::string &name, function_type *ty);
// Scope
// Pushes `scope` onto the scope stack (scopes_). The pointer is stored as-is;
// nullptr is accepted (translation_unit::codegen pushes nullptr).
void push_scope(const ast::compound_statement* scope) { scopes_.push(scope); }
// Removes the top entry of the scope stack.
// NOTE(review): no emptiness check is done here; callers must keep
// push_scope/pop_scope balanced.
void pop_scope() { scopes_.pop(); }
// Returns the top entry of the scope stack without removing it.
// The result can be nullptr because nullptr entries may be pushed.
// NOTE(review): no emptiness check is done here either.
const ast::compound_statement* get_scope() { return scopes_.top(); }
private: private:
@@ -83,6 +89,7 @@ private:
symbols_map_t symbols_; symbols_map_t symbols_;
std::function<ir::value*()> continue_fn_; std::function<ir::value*()> continue_fn_;
std::map<value*, value**> current_phi_; std::map<value*, value**> current_phi_;
std::stack<const ast::compound_statement*> scopes_;
}; };
} }

View File

@@ -3,6 +3,7 @@
#include <vector> #include <vector>
#include <iostream> #include <iostream>
#include <set>
namespace tdl{ namespace tdl{
namespace ir{ namespace ir{
@@ -10,9 +11,13 @@ namespace ir{
class context; class context;
class value; class value;
class integer_type; class integer_type;
class constant_int;
/* Type */ /* Type */
class type { class type {
public:
typedef std::vector<constant_int*> tile_shapes_t;
protected: protected:
typedef std::vector<type*> contained_tys_vec_t; typedef std::vector<type*> contained_tys_vec_t;
typedef contained_tys_vec_t::iterator ty_iterator; typedef contained_tys_vec_t::iterator ty_iterator;
@@ -54,7 +59,7 @@ public:
unsigned get_tile_bitwidth() const; unsigned get_tile_bitwidth() const;
unsigned get_primitive_size_in_bits() const; unsigned get_primitive_size_in_bits() const;
type *get_scalar_ty() const; type *get_scalar_ty() const;
const std::vector<unsigned> &get_tile_shapes() const; const tile_shapes_t& get_tile_shapes() const;
unsigned get_tile_num_elements() const; unsigned get_tile_num_elements() const;
type *get_tile_element_ty() const; type *get_tile_element_ty() const;
unsigned get_pointer_address_space() const; unsigned get_pointer_address_space() const;
@@ -94,9 +99,25 @@ public:
static integer_type *get_int64_ty(context &ctx); static integer_type *get_int64_ty(context &ctx);
static integer_type *get_int128_ty(context &ctx); static integer_type *get_int128_ty(context &ctx);
// Attributes
// Each setter turns its flag on and returns `this`, so the call can be used
// directly as the resulting type expression. No member that clears a flag is
// visible in this hunk.
// NOTE(review): the flag is stored on the type object itself. If type
// instances are shared between declarations (e.g. one int32 type per
// context), marking one declaration would mark every user of that type --
// confirm against the type factory.
type* set_tunable() { is_tunable_ = true; return this; }
type* set_readonly() { is_readonly_ = true; return this; }
type* set_writeonly() { is_writeonly_ = true; return this; }
type* set_kernel() { is_kernel_ = true; return this; }
// Getters return the current value of the corresponding flag.
bool get_tunable() { return is_tunable_; }
bool get_readonly() { return is_readonly_; }
bool get_writeonly() { return is_writeonly_; }
bool get_kernel() { return is_kernel_; }
private: private:
context &ctx_; context &ctx_;
id_t id_; id_t id_;
// attributes
// Storage-specifier flags, switched on by set_tunable()/set_readonly()/
// set_writeonly()/set_kernel() and read by the matching get_*() members.
// NOTE(review): no default member initializer is given here -- confirm that
// the constructor sets all four to false before any getter is called.
bool is_tunable_;
bool is_readonly_;
bool is_writeonly_;
bool is_kernel_;
protected: protected:
contained_tys_vec_t contained_tys_; contained_tys_vec_t contained_tys_;
@@ -132,21 +153,24 @@ public:
class tile_type: public composite_type { class tile_type: public composite_type {
private: private:
tile_type(type *ty, const std::vector<unsigned> &shapes); tile_type(type *ty, const tile_shapes_t &shapes);
static bool is_valid_elt_ty(type *ty); static bool is_valid_elt_ty(type *ty);
public: public:
// accessors // accessors
const std::vector<unsigned>& get_shapes() const { return shapes_; } const tile_shapes_t& get_shapes() const { return shapes_; }
unsigned get_num_elements() const; unsigned get_num_elements() const;
unsigned get_bitwidth() const; unsigned get_bitwidth() const;
// factory methods // factory methods
static tile_type* get(type *ty, const std::vector<unsigned> &shapes); static tile_type* get(type *ty, const tile_shapes_t &shapes);
static tile_type* get_same_shapes(type *ty, type *ref); static tile_type* get_same_shapes(type *ty, type *ref);
// shortcut to get a 1 element in the shape
static tile_shapes_t::value_type make_one(context &ctx);
private: private:
std::vector<unsigned> shapes_; tile_shapes_t shapes_;
}; };
class pointer_type: public type { class pointer_type: public type {

View File

@@ -100,6 +100,7 @@ void node::implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs)
ir::builder &builder = mod->get_builder(); ir::builder &builder = mod->get_builder();
ir::type *lhs_ty = lhs->get_type(); ir::type *lhs_ty = lhs->get_type();
ir::type *rhs_ty = rhs->get_type(); ir::type *rhs_ty = rhs->get_type();
ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(mod->get_context());
// Both are scalar // Both are scalar
if(!lhs_ty->is_tile_ty() && !rhs_ty->is_tile_ty()) if(!lhs_ty->is_tile_ty() && !rhs_ty->is_tile_ty())
return; return;
@@ -111,30 +112,30 @@ void node::implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs)
return; return;
} }
// Both are arrays // Both are arrays
std::vector<unsigned> lhs_shapes = lhs->get_type()->get_tile_shapes(); auto lhs_shapes = lhs->get_type()->get_tile_shapes();
std::vector<unsigned> rhs_shapes = rhs->get_type()->get_tile_shapes(); auto rhs_shapes = rhs->get_type()->get_tile_shapes();
if(lhs_shapes == rhs_shapes) if(lhs_shapes == rhs_shapes)
return; return;
int lhs_dim = lhs_shapes.size(); int lhs_dim = lhs_shapes.size();
int rhs_dim = rhs_shapes.size(); int rhs_dim = rhs_shapes.size();
std::vector<unsigned> &shortest = (lhs_dim < rhs_dim)?lhs_shapes:rhs_shapes; auto &shortest = (lhs_dim < rhs_dim)?lhs_shapes:rhs_shapes;
std::vector<unsigned> &longest = (lhs_dim < rhs_dim)?rhs_shapes:lhs_shapes; auto &longest = (lhs_dim < rhs_dim)?rhs_shapes:lhs_shapes;
size_t ndim = longest.size(); size_t ndim = longest.size();
int off = longest.size() - shortest.size(); int off = longest.size() - shortest.size();
for(int i = longest.size() - 1; i>= 0; i--){ for(int i = longest.size() - 1; i>= 0; i--){
if(shortest[off + i] != longest[i] && shortest[off + i] != 1 && longest[i] != 1) if(shortest[off + i] != longest[i] && shortest[off + i] != one && longest[i] != one)
throw std::runtime_error("cannot broadcast"); throw std::runtime_error("cannot broadcast");
} }
// Pad // Pad
for(size_t i = 0; i < off; i++) for(size_t i = 0; i < off; i++)
shortest.insert(shortest.begin(), 1); shortest.insert(shortest.begin(), one);
ir::value *&target = (lhs_dim < rhs_dim)?lhs:rhs; ir::value *&target = (lhs_dim < rhs_dim)?lhs:rhs;
if(off > 0) if(off > 0)
target = builder.create_reshape(target, shortest); target = builder.create_reshape(target, shortest);
// Broadcast // Broadcast
std::vector<unsigned> shapes(ndim); ir::type::tile_shapes_t shapes(ndim);
for(size_t i = 0; i < ndim; i++) for(size_t i = 0; i < ndim; i++)
shapes[i] = std::max(shortest[i], longest[i]); shapes[i] = shortest[i]==one?longest[i]:shortest[i];
if(shapes != lhs_shapes) if(shapes != lhs_shapes)
lhs = builder.create_broadcast(lhs, shapes); lhs = builder.create_broadcast(lhs, shapes);
if(shapes != rhs_shapes) if(shapes != rhs_shapes)
@@ -148,14 +149,15 @@ inline bool is_terminator(ir::value* x) {
/* Translation unit */ /* Translation unit */
ir::value* translation_unit::codegen(ir::module *mod) const{ ir::value* translation_unit::codegen(ir::module *mod) const{
decls_->codegen(mod); mod->push_scope(nullptr);
decls_.codegen(mod);
return nullptr; return nullptr;
} }
/* Declaration specifier */ /* Declaration specifier */
ir::type* declaration_specifier::type(ir::module *mod) const { ir::type* typed_declaration_specifier::type(ir::module *mod) const {
ir::context &ctx = mod->get_context(); ir::context &ctx = mod->get_context();
switch (spec_) { switch (ty_) {
case VOID_T: return ir::type::get_void_ty(ctx); case VOID_T: return ir::type::get_void_ty(ctx);
case INT1_T: return ir::type::get_int1_ty(ctx); case INT1_T: return ir::type::get_int1_ty(ctx);
case INT8_T: return ir::type::get_int8_ty(ctx); case INT8_T: return ir::type::get_int8_ty(ctx);
@@ -164,7 +166,18 @@ ir::type* declaration_specifier::type(ir::module *mod) const {
case INT64_T: return ir::type::get_int64_ty(ctx); case INT64_T: return ir::type::get_int64_ty(ctx);
case FLOAT32_T: return ir::type::get_float_ty(ctx); case FLOAT32_T: return ir::type::get_float_ty(ctx);
case FLOAT64_T: return ir::type::get_double_ty(ctx); case FLOAT64_T: return ir::type::get_double_ty(ctx);
default: throw std::runtime_error("unreachable"); default: throw std::runtime_error("unreachable");
}
}
ir::type* storage_declaration_specifier::type(ir::module *mod) const {
ir::type* result = decl_spec_->type(mod);
switch(storage_spec_){
case TUNABLE_T: return result->set_tunable();
case KERNEL_T: return result->set_kernel();
case READONLY_T: return result->set_readonly();
case WRITEONLY_T: return result->set_writeonly();
default: throw std::runtime_error("unreachable");
} }
} }
@@ -194,10 +207,10 @@ const std::string &identifier::name() const{
} }
// Tile // Tile
ir::type* tile::type_impl(ir::module*, ir::type *type) const{ ir::type* tile::type_impl(ir::module *mod, ir::type *type) const{
std::vector<unsigned> shapes; ir::type::tile_shapes_t shapes;
for(constant *cst: shapes_->values()) for(constant *cst: shapes_->values())
shapes.push_back(cst->value()); shapes.push_back((ir::constant_int*)cst->codegen(mod));
return ir::tile_type::get(type, shapes); return ir::tile_type::get(type, shapes);
} }
@@ -245,6 +258,7 @@ ir::value* function_definition::codegen(ir::module *mod) const{
/* Statements */ /* Statements */
ir::value* compound_statement::codegen(ir::module* mod) const{ ir::value* compound_statement::codegen(ir::module* mod) const{
mod->push_scope(this);
if(decls_) if(decls_)
decls_->codegen(mod); decls_->codegen(mod);
if(statements_){ if(statements_){
@@ -254,6 +268,7 @@ ir::value* compound_statement::codegen(ir::module* mod) const{
return current; return current;
} }
} }
mod->pop_scope();
return nullptr; return nullptr;
} }
@@ -337,7 +352,7 @@ ir::value* continue_statement::codegen(ir::module *mod) const{
/* Declaration */ /* Declaration */
ir::value* declaration::codegen(ir::module* mod) const{ ir::value* declaration::codegen(ir::module* mod) const{
for(initializer *init: init_->values()) for(initializer *init: init_->values())
init->specifier(spec_); init->set_specifier(spec_);
init_->codegen(mod); init_->codegen(mod);
return nullptr; return nullptr;
} }
@@ -347,7 +362,7 @@ ir::type* initializer::type_impl(ir::module *mod, ir::type *type) const{
return decl_->type(mod, type); return decl_->type(mod, type);
} }
void initializer::specifier(const declaration_specifier *spec) { void initializer::set_specifier(const declaration_specifier *spec) {
spec_ = spec; spec_ = spec;
} }
@@ -355,6 +370,11 @@ ir::value* initializer::codegen(ir::module * mod) const{
ir::type *ty = decl_->type(mod, spec_->type(mod)); ir::type *ty = decl_->type(mod, spec_->type(mod));
std::string name = decl_->id()->name(); std::string name = decl_->id()->name();
ir::value *value = ir::undef_value::get(ty); ir::value *value = ir::undef_value::get(ty);
if(ty->get_tunable()){
assert(expr_ == nullptr);
//TODO
value = ir::metaparameter::create(mod->get_context(), ty, 4, 8);
}
if(expr_){ if(expr_){
value = expr_->codegen(mod); value = expr_->codegen(mod);
value = explicit_cast(mod->get_builder(), value, ty); value = explicit_cast(mod->get_builder(), value, ty);
@@ -464,7 +484,7 @@ ir::value* binary_operator::codegen(ir::module *mod) const{
// get_global_range // get_global_range
ir::value* get_global_range::codegen(ir::module *mod) const { ir::value* get_global_range::codegen(ir::module *mod) const {
ir::builder &builder = mod->get_builder(); ir::builder &builder = mod->get_builder();
return builder.create_get_global_range(axis_->value(), size_->value()); return builder.create_get_global_range(axis_->value(), (ir::constant_int*)size_->codegen(mod));
} }
@@ -487,11 +507,13 @@ ir::value* matmul_expression::codegen(ir::module *mod) const {
ir::value* indexing_expression::codegen(ir::module *mod) const{ ir::value* indexing_expression::codegen(ir::module *mod) const{
ir::value *in = mod->get_value(id_->name()); ir::value *in = mod->get_value(id_->name());
const std::vector<slice*> &slices = slices_->values(); const std::vector<slice*> &slices = slices_->values();
std::vector<unsigned> in_shapes = in->get_type()->get_tile_shapes(); auto in_shapes = in->get_type()->get_tile_shapes();
std::vector<unsigned> out_shapes(slices.size()); ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(mod->get_context());
ir::type::tile_shapes_t out_shapes(slices.size());
// create shapes
size_t current = 0; size_t current = 0;
for(size_t i = 0; i < out_shapes.size(); i++) for(size_t i = 0; i < out_shapes.size(); i++)
out_shapes[i] = (slices[i]->type()==NEWAXIS)?1:in_shapes[current++]; out_shapes[i] = (slices[i]->type()==NEWAXIS)?one:in_shapes[current++];
return mod->get_builder().create_reshape(in, out_shapes); return mod->get_builder().create_reshape(in, out_shapes);
} }
@@ -586,8 +608,8 @@ int constant::value() const{
/* Constant range */ /* Constant range */
ir::value* constant_range::codegen(ir::module *mod) const{ ir::value* constant_range::codegen(ir::module *mod) const{
return ir::constant_range::get((ir::constant*)first_->codegen(mod), return ir::constant_range::get((ir::constant_int*)first_->codegen(mod),
(ir::constant*)last_->codegen(mod)); (ir::constant_int*)last_->codegen(mod));
} }
/* Named */ /* Named */

View File

@@ -45,10 +45,15 @@ void barriers::get_written_intervals(ir::instruction *i, interval_vec_t &res){
void barriers::insert_barrier(ir::instruction *instr, ir::builder &builder) { void barriers::insert_barrier(ir::instruction *instr, ir::builder &builder) {
if(auto *phi = dynamic_cast<ir::phi_node*>(instr)) { if(auto *phi = dynamic_cast<ir::phi_node*>(instr)) {
std::set<ir::value*> incoming;
for(unsigned n = 0; n < phi->get_num_incoming(); n++){ for(unsigned n = 0; n < phi->get_num_incoming(); n++){
ir::basic_block *block = phi->get_incoming_block(n); ir::instruction *inc_val = dynamic_cast<ir::instruction*>(phi->get_incoming_value(n));
builder.set_insert_point(block->get_inst_list().back()); assert(inc_val);
builder.create_barrier(); if(incoming.insert(inc_val).second){
ir::basic_block *block = inc_val->get_parent();
builder.set_insert_point(block->get_inst_list().back());
builder.create_barrier();
}
} }
} }
else { else {
@@ -57,15 +62,15 @@ void barriers::insert_barrier(ir::instruction *instr, ir::builder &builder) {
} }
} }
void barriers::add(ir::basic_block *block, interval_vec_t &not_synced, std::set<ir::instruction*> &insert_pts) { void barriers::add(ir::basic_block *block, interval_vec_t &not_synced, ir::builder &builder) {
for(ir::instruction *i: block->get_inst_list()){ ir::basic_block::inst_list_t instructions = block->get_inst_list();
for(ir::instruction *i: instructions){
interval_vec_t read, written; interval_vec_t read, written;
get_read_intervals(i, read); get_read_intervals(i, read);
get_written_intervals(i, written); get_written_intervals(i, written);
if(intersect(not_synced, read) if(intersect(not_synced, read)) {
|| intersect(not_synced, written)) {
not_synced.clear(); not_synced.clear();
insert_pts.insert(i); insert_barrier(i, builder);
} }
std::copy(written.begin(), written.end(), std::back_inserter(not_synced)); std::copy(written.begin(), written.end(), std::back_inserter(not_synced));
} }
@@ -76,12 +81,8 @@ void barriers::run(ir::module &mod) {
for(ir::function *fn: mod.get_function_list()){ for(ir::function *fn: mod.get_function_list()){
// find barrier location // find barrier location
interval_vec_t not_synced; interval_vec_t not_synced;
std::set<ir::instruction*> insert_pts;
for(ir::basic_block *block: fn->blocks()) for(ir::basic_block *block: fn->blocks())
add(block, not_synced, insert_pts); add(block, not_synced, builder);
// insert barrier
for(ir::instruction *i: insert_pts)
insert_barrier(i, builder);
} }
} }

View File

@@ -44,6 +44,7 @@ llvm::Type *distributed_tile::make_vector_ty(llvm::Type *ty, size_t vector_size)
return VectorType::get(ty, vector_size); return VectorType::get(ty, vector_size);
} }
distributed_tile::distributed_tile(Type *ty, const shapes_t &shapes, const axes_t &axes, llvm::IRBuilder<> &builder, bool vectorize) distributed_tile::distributed_tile(Type *ty, const shapes_t &shapes, const axes_t &axes, llvm::IRBuilder<> &builder, bool vectorize)
: tile(make_vector_ty(ty, vectorize?axes[0].contiguous:1), shapes), axes_(axes), builder_(builder) { : tile(make_vector_ty(ty, vectorize?axes[0].contiguous:1), shapes), axes_(axes), builder_(builder) {
vector_size_ = vectorize?ty_->getVectorNumElements():1; vector_size_ = vectorize?ty_->getVectorNumElements():1;
@@ -149,6 +150,16 @@ Value* shared_tile::get_value(indices_t idx) {
return builder_.CreateLoad(ptr); return builder_.CreateLoad(ptr);
} }
/* Utils */
// Converts the tile shapes of `v` from their IR representation (a sequence of
// ir::constant_int*) into plain unsigned extents, one entry per dimension, by
// reading each constant's current value.
// `v` must have a tile type: get_tile_shapes() is called on its type.
std::vector<unsigned> selection::extract_shapes(ir::value *v) {
  const auto& shapes = v->get_type()->get_tile_shapes();
  std::vector<unsigned> result;
  // Reserve instead of pre-sizing: constructing the vector with
  // shapes.size() elements and then calling push_back would yield
  // 2 * shapes.size() entries, the first half of them being zeros.
  result.reserve(shapes.size());
  for(ir::constant_int* cst: shapes)
    result.push_back(cst->get_value());
  return result;
}
/* convert ir::type to Type */ /* convert ir::type to Type */
Type *selection::llvm_type(ir::type *ty, LLVMContext &ctx) { Type *selection::llvm_type(ir::type *ty, LLVMContext &ctx) {
// function // function
@@ -299,11 +310,12 @@ std::vector<Value*> delinearize(Value *trailing, std::vector<unsigned> &shapes,
} }
void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) { void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
const auto& shapes = v->get_type()->get_tile_shapes(); const auto& shapes = extract_shapes(v);
size_t dim = shapes.size(); size_t dim = shapes.size();
std::vector<unsigned> contiguous(dim); std::vector<unsigned> contiguous(dim);
std::vector<unsigned> warp_size(dim); std::vector<unsigned> warp_size(dim);
std::vector<unsigned> n_warps(dim); std::vector<unsigned> n_warps(dim);
std::cout << v->get_name() << " " << typeid(*v).name() << std::endl;
for(unsigned i = 0; i < shapes.size(); i++){ for(unsigned i = 0; i < shapes.size(); i++){
std::string str_i = std::to_string(i); std::string str_i = std::to_string(i);
contiguous[i] = *params_->get_param(v, "p0.d" + str_i); contiguous[i] = *params_->get_param(v, "p0.d" + str_i);
@@ -336,7 +348,7 @@ void selection::create_grids(std::vector<ir::value*> &grids,
// get number of dimensions greater than 1 // get number of dimensions greater than 1
auto get_tile_gt1_dim = [&](ir::value *v){ auto get_tile_gt1_dim = [&](ir::value *v){
unsigned result = 0; unsigned result = 0;
for(unsigned shape: v->get_type()->get_tile_shapes()) { for(unsigned shape: extract_shapes(v)) {
result += (shape > 1)?shape:0; result += (shape > 1)?shape:0;
} }
return result; return result;
@@ -353,7 +365,7 @@ void selection::create_grids(std::vector<ir::value*> &grids,
for(ir::value *op: user->ops()) for(ir::value *op: user->ops())
bind_references(op); bind_references(op);
// bind // bind
const auto& shapes = v->get_type()->get_tile_shapes(); const auto& shapes = extract_shapes(v);
if(dynamic_cast<ir::copy_to_shared_inst*>(v) || buffer_info_->is_double(v)) if(dynamic_cast<ir::copy_to_shared_inst*>(v) || buffer_info_->is_double(v))
return; return;
for(size_t d = 0; d < shapes.size(); d++){ for(size_t d = 0; d < shapes.size(); d++){
@@ -385,7 +397,7 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
for(ir::value *op: user->ops()) for(ir::value *op: user->ops())
create_tile(op, builder, references, seen, sh_mem_ptr); create_tile(op, builder, references, seen, sh_mem_ptr);
LLVMContext &ctx = builder.getContext(); LLVMContext &ctx = builder.getContext();
const auto& shapes = v->get_type()->get_tile_shapes(); const auto& shapes = extract_shapes(v);
Type* ty = llvm_type(v->get_type()->get_scalar_ty(), ctx); Type* ty = llvm_type(v->get_type()->get_scalar_ty(), ctx);
// create shared tile // create shared tile
if(dynamic_cast<ir::copy_to_shared_inst*>(v) || (buffer_info_->is_double(v))){ if(dynamic_cast<ir::copy_to_shared_inst*>(v) || (buffer_info_->is_double(v))){
@@ -429,7 +441,7 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
} }
// create distributed tile // create distributed tile
else { else {
const auto &shapes = v->get_type()->get_tile_shapes(); const auto &shapes = extract_shapes(v);
std::vector<distributed_axis> axes(shapes.size()); std::vector<distributed_axis> axes(shapes.size());
for(size_t d = 0; d < shapes.size(); d++){ for(size_t d = 0; d < shapes.size(); d++){
if(shapes[d] > 1){ if(shapes[d] > 1){
@@ -530,7 +542,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
distributed_tile* result = (distributed_tile*)ti; distributed_tile* result = (distributed_tile*)ti;
if(!ins->get_type()->is_tile_ty()) if(!ins->get_type()->is_tile_ty())
return; return;
const auto& shapes = ins->get_type()->get_tile_shapes(); const auto& shapes = extract_shapes(ins);
// global_range // global_range
if(auto *x = dynamic_cast<ir::get_global_range_inst*>(ins)) { if(auto *x = dynamic_cast<ir::get_global_range_inst*>(ins)) {
static std::array<Intrinsic::ID, 3> ctaid = { static std::array<Intrinsic::ID, 3> ctaid = {
@@ -568,7 +580,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
// broadcast // broadcast
else if(dynamic_cast<ir::broadcast_inst*>(ins)) { else if(dynamic_cast<ir::broadcast_inst*>(ins)) {
ir::value* in = ins->get_operand(0); ir::value* in = ins->get_operand(0);
const auto& in_shapes = in->get_type()->get_tile_shapes(); const auto& in_shapes = extract_shapes(in);
distributed_tile *in_tile = (distributed_tile*)tmap_.at(in); distributed_tile *in_tile = (distributed_tile*)tmap_.at(in);
result->for_each([&](indices_t out_idx){ result->for_each([&](indices_t out_idx){
indices_t in_idx = out_idx; indices_t in_idx = out_idx;
@@ -615,7 +627,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {llvm_type(C->get_type()->get_scalar_ty(), ctx)}); Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {llvm_type(C->get_type()->get_scalar_ty(), ctx)});
result->for_each([&](indices_t idx){ result->for_each([&](indices_t idx){
Value *res = tmap_.at(C)->get_value(idx); Value *res = tmap_.at(C)->get_value(idx);
unsigned NK = A->get_type()->get_tile_shapes()[1]; unsigned NK = extract_shapes(A)[1];
for(unsigned K = 0; K < NK; ++K){ for(unsigned K = 0; K < NK; ++K){
indices_t a_idx = {idx[0], builder.getInt32(K)}; indices_t a_idx = {idx[0], builder.getInt32(K)};
indices_t b_idx = {idx[1], builder.getInt32(K)}; indices_t b_idx = {idx[1], builder.getInt32(K)};

View File

@@ -3,6 +3,8 @@
#include "ir/type.h" #include "ir/type.h"
#include "ir/module.h" #include "ir/module.h"
#include "ir/function.h" #include "ir/function.h"
#include "ir/context_impl.h"
#include <cstdlib> #include <cstdlib>
@@ -29,7 +31,8 @@ void tune::init_c_phi(ir::instruction *v) {
void tune::init_c_graph(ir::instruction *v) { void tune::init_c_graph(ir::instruction *v) {
// Reference shape // Reference shape
std::vector<unsigned> shapes; ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(v->get_parent()->get_context());
ir::type::tile_shapes_t shapes;
if(auto *store = dynamic_cast<ir::store_inst*>(v)) if(auto *store = dynamic_cast<ir::store_inst*>(v))
shapes = store->get_pointer_operand()->get_type()->get_tile_shapes(); shapes = store->get_pointer_operand()->get_type()->get_tile_shapes();
else else
@@ -39,7 +42,7 @@ void tune::init_c_graph(ir::instruction *v) {
ir::value *op = v->get_operand(0); ir::value *op = v->get_operand(0);
unsigned current = 0; unsigned current = 0;
for(unsigned i = 0; i < shapes.size(); i ++){ for(unsigned i = 0; i < shapes.size(); i ++){
if(shapes[i] == 1) if(shapes[i] == one)
static_params_.insert({{v, i}, 1}); static_params_.insert({{v, i}, 1});
else else
add_constraint({v, i}, {op, current++}); add_constraint({v, i}, {op, current++});
@@ -99,6 +102,7 @@ void tune::connected_components(node_t x, const std::vector<unsigned *> vals, st
std::vector<unsigned*> tune::get_params(ir::module &mod) { std::vector<unsigned*> tune::get_params(ir::module &mod) {
std::vector<unsigned *> result; std::vector<unsigned *> result;
std::set<unsigned*> seen; std::set<unsigned*> seen;
for(ir::function *fn: mod.get_function_list()) for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks()) for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i : block->get_inst_list()) for(ir::instruction *i : block->get_inst_list())
@@ -143,8 +147,9 @@ void tune::create_grids(std::vector<ir::instruction*> &grids,
// get number of dimensions greater than 1 // get number of dimensions greater than 1
auto get_tile_gt1_dim = [&](ir::value *v){ auto get_tile_gt1_dim = [&](ir::value *v){
unsigned result = 0; unsigned result = 0;
for(unsigned shape: v->get_type()->get_tile_shapes()) { auto one = ir::tile_type::make_one(fn->get_fn_type()->get_context());
result += (shape > 1)?shape:0; for(ir::constant_int *shape: v->get_type()->get_tile_shapes()) {
result += (shape != one);
} }
return result; return result;
}; };
@@ -194,8 +199,8 @@ for(ir::function *fn: mod.get_function_list()){
unsigned *s1 = params_[i]["p1.d" + strk]; unsigned *s1 = params_[i]["p1.d" + strk];
unsigned *s2 = params_[i]["p2.d" + strk]; unsigned *s2 = params_[i]["p2.d" + strk];
unsigned multiple = (*s0)*(*s1)*(*s2); unsigned multiple = (*s0)*(*s1)*(*s2);
if(shapes[k] % multiple != 0) if(shapes[k]->get_value() % multiple != 0)
errors[i].push_back("for dim " + strk + ": shape (" + to_string(shapes[k]) + ")" errors[i].push_back("for dim " + strk + ": shape (" + to_string(shapes[k]->get_value()) + ")"
" is not a multiple of layout (" + to_string(multiple) + ")"); " is not a multiple of layout (" + to_string(multiple) + ")");
} }
// the number of thread per warp must be 32 // the number of thread per warp must be 32

View File

@@ -244,15 +244,15 @@ value *builder::create_store(value *ptr, value *val, const std::string &name){
// tile instructions // tile instructions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
value *builder::create_reshape(value *arg, const std::vector<unsigned> &shapes, const std::string &name) { value *builder::create_reshape(value *arg, const type::tile_shapes_t &shapes, const std::string &name) {
return insert(reshape_inst::create(arg, shapes, name)); return insert(reshape_inst::create(arg, shapes, name));
} }
value *builder::create_splat(value *arg, const std::vector<unsigned> &shapes, const std::string &name) { value *builder::create_splat(value *arg, const type::tile_shapes_t &shapes, const std::string &name) {
return insert(splat_inst::create(arg, shapes, name)); return insert(splat_inst::create(arg, shapes, name));
} }
value *builder::create_broadcast(value *arg, const std::vector<unsigned> &shapes, const std::string &name) { value *builder::create_broadcast(value *arg, const type::tile_shapes_t &shapes, const std::string &name) {
return insert(broadcast_inst::create(arg, shapes, name)); return insert(broadcast_inst::create(arg, shapes, name));
} }
@@ -260,7 +260,7 @@ value *builder::create_broadcast(value *arg, const std::vector<unsigned> &shapes
// built-in instructions // built-in instructions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
value *builder::create_get_global_range(unsigned axis, unsigned size, const std::string &name) { value *builder::create_get_global_range(unsigned axis, type::tile_shapes_t::value_type size, const std::string &name) {
return insert(get_global_range_inst::create(ctx_, axis, size, name)); return insert(get_global_range_inst::create(ctx_, axis, size, name));
} }

View File

@@ -48,24 +48,27 @@ constant *constant::get_all_ones_value(type *ty) {
constant_int::constant_int(type *ty, uint64_t value) constant_int::constant_int(type *ty, uint64_t value)
: constant(ty, 0), value_(value){ } : constant(ty, 0), value_(value){ }
constant *constant_int::get(type *ty, uint64_t value) { constant_int *constant_int::get(type *ty, uint64_t value) {
return new constant_int(ty, value); context_impl *impl = ty->get_context().p_impl.get();
constant_int *& cst = impl->int_constants_[std::make_pair(ty, value)];
if(cst == nullptr)
cst = new constant_int(ty, value);
return cst;
} }
// constant_range // constant_range
// FIXME use something like APInt // FIXME use something like APInt
constant_range::constant_range(type *ty, uint64_t first, uint64_t last) constant_range::constant_range(type *ty, constant_int *first, constant_int *last)
: constant(ty, 0), first_(first), last_(last){ } : constant(ty, 0), first_(first), last_(last){ }
constant *constant_range::get(constant *first, constant *last) { constant *constant_range::get(constant_int *first, constant_int *last) {
assert(first->get_type()->is_integer_ty()); assert(first->get_type()->is_integer_ty());
assert(first->get_type() == last->get_type()); assert(first->get_type() == last->get_type());
unsigned vfirst = ((constant_int*)first)->get_value(); unsigned vfirst = ((constant_int*)first)->get_value();
unsigned vlast = ((constant_int*)last)->get_value(); assert(vfirst == 0);
assert(vlast > vfirst); type *ty = tile_type::get(first->get_type(), {last});
type *ty = tile_type::get(first->get_type(), {vlast - vfirst}); return new constant_range(ty, first, last);
return new constant_range(ty, vfirst, vlast);
} }
@@ -94,6 +97,17 @@ constant *constant_fp::get(context &ctx, double v){
return result; return result;
} }
// metaparameter
// Constructs a metaparameter of type `ty`. The constant_int base is
// initialised with the value 0, and `lo` / `hi` are stored in lo_ / hi_
// (presumably the bounds of the admissible range -- confirm against the
// class declaration).
metaparameter::metaparameter(type *ty, unsigned lo, unsigned hi)
: constant_int(ty, 0), lo_(lo), hi_(hi){ }
// Allocates a new metaparameter of type `ty` with the given `lo` / `hi`
// arguments and appends it to the mp_constants_ list held by the
// implementation object of `ctx`. Every call allocates a fresh object;
// the pointer that was registered is returned to the caller.
metaparameter* metaparameter::create(context &ctx, type *ty, unsigned lo, unsigned hi) {
  context_impl *ctx_impl = ctx.p_impl.get();
  metaparameter *mp = new metaparameter(ty, lo, hi);
  // register the parameter with the context so it can be found later
  ctx_impl->mp_constants_.push_back(mp);
  return mp;
}
// undef value // undef value
undef_value::undef_value(type *ty) undef_value::undef_value(type *ty)
: constant(ty, 0) { } : constant(ty, 0) { }

View File

@@ -409,7 +409,7 @@ std::string retile_inst::shape_suffix(ir::type* ty){
std::string res = "["; std::string res = "[";
const auto& shapes = ty->get_tile_shapes(); const auto& shapes = ty->get_tile_shapes();
for(unsigned i = 0; i < shapes.size(); i++){ for(unsigned i = 0; i < shapes.size(); i++){
res += std::to_string(ty->get_tile_shapes()[i]); res += std::to_string(ty->get_tile_shapes()[i]->get_value());
if(i < shapes.size() - 1) if(i < shapes.size() - 1)
res += ", "; res += ", ";
} }
@@ -417,13 +417,13 @@ std::string retile_inst::shape_suffix(ir::type* ty){
return res; return res;
} }
retile_inst::retile_inst(value *arg, const std::vector<unsigned> &shapes, retile_inst::retile_inst(value *arg, const type::tile_shapes_t &shapes,
const std::string &name, instruction *next) const std::string &name, instruction *next)
: unary_inst(tile_type::get(arg->get_type()->get_scalar_ty(), shapes), arg, name, next) { } : unary_inst(tile_type::get(arg->get_type()->get_scalar_ty(), shapes), arg, name, next) { }
// reshape // reshape
instruction* reshape_inst::create(value *arg, const std::vector<unsigned> &shapes, instruction* reshape_inst::create(value *arg, const type::tile_shapes_t &shapes,
const std::string &name, instruction *next) { const std::string &name, instruction *next) {
return new reshape_inst(arg, shapes, name, next); return new reshape_inst(arg, shapes, name, next);
} }
@@ -431,14 +431,14 @@ instruction* reshape_inst::create(value *arg, const std::vector<unsigned> &shape
// splat // splat
instruction* splat_inst::create(value *arg, const std::vector<unsigned> &shapes, instruction* splat_inst::create(value *arg, const type::tile_shapes_t &shapes,
const std::string &name, instruction *next) { const std::string &name, instruction *next) {
return new splat_inst(arg, shapes, name, next); return new splat_inst(arg, shapes, name, next);
} }
// broadcast // broadcast
instruction* broadcast_inst::create(value *arg, const std::vector<unsigned> &shapes, instruction* broadcast_inst::create(value *arg, const type::tile_shapes_t &shapes,
const std::string &name, instruction *next) { const std::string &name, instruction *next) {
return new broadcast_inst(arg, shapes, name, next); return new broadcast_inst(arg, shapes, name, next);
} }
@@ -470,7 +470,7 @@ get_global_range_inst::get_global_range_inst(type *ty, unsigned axis,
} }
instruction* get_global_range_inst::create(context &ctx, unsigned axis, unsigned size, instruction* get_global_range_inst::create(context &ctx, unsigned axis, type::tile_shapes_t::value_type size,
const std::string &name, instruction *next) { const std::string &name, instruction *next) {
type *int_ty = type::get_int32_ty(ctx); type *int_ty = type::get_int32_ty(ctx);
type *tile_ty = tile_type::get(int_ty, {size}); type *tile_ty = tile_type::get(int_ty, {size});

View File

@@ -3,6 +3,7 @@
#include "ir/context.h" #include "ir/context.h"
#include "ir/context_impl.h" #include "ir/context_impl.h"
#include "ir/value.h" #include "ir/value.h"
#include "ir/constant.h"
namespace tdl{ namespace tdl{
namespace ir{ namespace ir{
@@ -63,7 +64,7 @@ type * type::get_pointer_element_ty() const {
} }
const std::vector<unsigned> &type::get_tile_shapes() const { const type::tile_shapes_t &type::get_tile_shapes() const {
assert(is_tile_ty()); assert(is_tile_ty());
return ((tile_type*)this)->get_shapes(); return ((tile_type*)this)->get_shapes();
} }
@@ -148,7 +149,7 @@ bool composite_type::index_valid(value *idx) const{
// tile_type class // tile_type class
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
tile_type::tile_type(type *ty, const std::vector<unsigned> &shapes) tile_type::tile_type(type *ty, const tile_shapes_t &shapes)
: composite_type(ty->get_context(), TileTyID), shapes_(shapes) { : composite_type(ty->get_context(), TileTyID), shapes_(shapes) {
contained_tys_.push_back(ty); contained_tys_.push_back(ty);
} }
@@ -159,8 +160,8 @@ bool tile_type::is_valid_elt_ty(type *ty) {
unsigned tile_type::get_num_elements() const { unsigned tile_type::get_num_elements() const {
unsigned res = 1; unsigned res = 1;
for(unsigned shape: shapes_) for(auto shape: shapes_)
res *= shape; res *= shape->get_value();
return res; return res;
} }
@@ -168,7 +169,7 @@ unsigned tile_type::get_bitwidth() const {
return get_num_elements() * get_tile_element_ty()->get_primitive_size_in_bits(); return get_num_elements() * get_tile_element_ty()->get_primitive_size_in_bits();
} }
tile_type* tile_type::get(type *elt_ty, const std::vector<unsigned> &shapes) { tile_type* tile_type::get(type *elt_ty, const tile_shapes_t &shapes) {
assert(elt_ty && "Can't get a tile of <null> type!"); assert(elt_ty && "Can't get a tile of <null> type!");
assert(shapes.size() && "Can't create a tile with empty shapes!"); assert(shapes.size() && "Can't create a tile with empty shapes!");
assert(is_valid_elt_ty(elt_ty) && "Invalid type for pointer element!"); assert(is_valid_elt_ty(elt_ty) && "Invalid type for pointer element!");
@@ -185,6 +186,10 @@ tile_type* tile_type::get_same_shapes(type *ty, type *ref){
return get(ty, ref->get_tile_shapes()); return get(ty, ref->get_tile_shapes());
} }
// Returns the shape entry representing an extent of 1: the 32-bit integer
// constant with value 1 obtained from constant_int::get for context `ctx`.
type::tile_shapes_t::value_type tile_type::make_one(ir::context& ctx){
  type *int32_ty = type::get_int32_ty(ctx);
  return constant_int::get(int32_ty, 1);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// function_type class // function_type class