[intermediate representation] transitioning towards more flexible tile
shapes
This commit is contained in:
1
TODO
1
TODO
@@ -2,3 +2,4 @@
|
||||
- proper naming scheme
|
||||
- symbols table
|
||||
- name conflicts on globals?
|
||||
- separate header for typedef (e.g., type::tile_shapes_t) to reduce compilation time
|
||||
|
@@ -6,6 +6,7 @@
|
||||
#include "ir/context.h"
|
||||
#include "ir/module.h"
|
||||
#include "ir/print.h"
|
||||
#include "ir/context_impl.h"
|
||||
#include "codegen/selection.h"
|
||||
#include "codegen/tune.h"
|
||||
#include "codegen/shared_copy.h"
|
||||
@@ -182,6 +183,9 @@ int main() {
|
||||
llvm::LLVMContext llvm_context;
|
||||
llvm::Module llvm_module("test", llvm_context);
|
||||
|
||||
// context.p_impl->mp_constants_[0]->set_value(16);
|
||||
// context.p_impl->mp_constants_[1]->set_value(16);
|
||||
// context.p_impl->mp_constants_[2]->set_value(8);
|
||||
|
||||
// create passes
|
||||
tdl::codegen::buffer_info_pass buffer_info;
|
||||
|
@@ -5,7 +5,7 @@
|
||||
#include <cassert>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
|
||||
namespace tdl{
|
||||
@@ -56,6 +56,12 @@ enum TYPE_T{
|
||||
FLOAT32_T, FLOAT64_T
|
||||
};
|
||||
|
||||
enum STORAGE_SPEC_T{
|
||||
TUNABLE_T,
|
||||
KERNEL_T,
|
||||
READONLY_T, WRITEONLY_T,
|
||||
};
|
||||
|
||||
class pointer;
|
||||
class identifier;
|
||||
class constant;
|
||||
@@ -75,7 +81,7 @@ public:
|
||||
template<class T>
|
||||
class list: public node {
|
||||
public:
|
||||
list(const T& x): values_{x} {}
|
||||
list(const T& x): values_(1, x) {}
|
||||
|
||||
node* append(const T& x){
|
||||
values_.push_back(x);
|
||||
@@ -389,16 +395,30 @@ public:
|
||||
class no_op: public statement { };
|
||||
|
||||
// Types
|
||||
|
||||
class declaration_specifier: public node{
|
||||
public:
|
||||
declaration_specifier(TYPE_T spec)
|
||||
: spec_(spec) { }
|
||||
using node::node;
|
||||
virtual ir::type* type(ir::module *mod) const = 0;
|
||||
};
|
||||
|
||||
class typed_declaration_specifier: public declaration_specifier {
|
||||
public:
|
||||
typed_declaration_specifier(TYPE_T ty): ty_(ty){ }
|
||||
ir::type* type(ir::module *mod) const;
|
||||
|
||||
private:
|
||||
const TYPE_T spec_;
|
||||
const TYPE_T ty_;
|
||||
};
|
||||
|
||||
class storage_declaration_specifier: public declaration_specifier {
|
||||
public:
|
||||
storage_declaration_specifier(STORAGE_SPEC_T storage_spec, node *decl_spec)
|
||||
: storage_spec_(storage_spec), decl_spec_((declaration_specifier*)decl_spec) {}
|
||||
ir::type* type(ir::module *mod) const;
|
||||
|
||||
private:
|
||||
const STORAGE_SPEC_T storage_spec_;
|
||||
const declaration_specifier* decl_spec_;
|
||||
};
|
||||
|
||||
class declarator;
|
||||
@@ -495,7 +515,7 @@ public:
|
||||
: declarator((node*)((declarator*)decl)->id()),
|
||||
decl_((declarator*)decl), expr_((expression*)init){ }
|
||||
|
||||
void specifier(const declaration_specifier *spec);
|
||||
void set_specifier(const declaration_specifier *spec);
|
||||
ir::value* codegen(ir::module *) const;
|
||||
|
||||
public:
|
||||
@@ -535,17 +555,17 @@ public:
|
||||
class translation_unit: public node{
|
||||
public:
|
||||
translation_unit(node *item)
|
||||
: decls_((list<node*>*)item) { }
|
||||
: decls_(item) { }
|
||||
|
||||
translation_unit *add(node *item) {
|
||||
decls_->append(item);
|
||||
decls_.append(item);
|
||||
return this;
|
||||
}
|
||||
|
||||
ir::value* codegen(ir::module * mod) const;
|
||||
|
||||
private:
|
||||
list<node*>* decls_;
|
||||
list<node*> decls_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -20,12 +20,14 @@ struct token: public node{
|
||||
token(BIN_OP_T value): bin_op(value){ }
|
||||
token(UNARY_OP_T value): unary_op(value){ }
|
||||
token(TYPE_T value): type(value){ }
|
||||
token(STORAGE_SPEC_T value): storage_spec(value){ }
|
||||
|
||||
union {
|
||||
ASSIGN_OP_T assign_op;
|
||||
BIN_OP_T bin_op;
|
||||
UNARY_OP_T unary_op;
|
||||
TYPE_T type;
|
||||
STORAGE_SPEC_T storage_spec;
|
||||
};
|
||||
};
|
||||
|
||||
@@ -39,10 +41,12 @@ node* append_ptr_list(node *result, node *in){
|
||||
ASSIGN_OP_T get_assign_op(node *op) { return ((token*)op)->assign_op; }
|
||||
UNARY_OP_T get_unary_op(node *op) { return ((token*)op)->unary_op; }
|
||||
TYPE_T get_type_spec(node *op) { return ((token*)op)->type; }
|
||||
STORAGE_SPEC_T get_storage_spec(node *op) { return ((token*)op)->storage_spec;}
|
||||
|
||||
%}
|
||||
|
||||
%token IDENTIFIER CONSTANT STRING_LITERAL
|
||||
%token TUNABLE KERNEL READONLY WRITEONLY
|
||||
%token PTR_OP INC_OP DEC_OP LEFT_OP RIGHT_OP LE_OP GE_OP EQ_OP NE_OP
|
||||
%token AND_OP OR_OP MUL_ASSIGN DIV_ASSIGN MOD_ASSIGN ADD_ASSIGN
|
||||
%token SUB_ASSIGN LEFT_ASSIGN RIGHT_ASSIGN AND_ASSIGN
|
||||
@@ -87,17 +91,12 @@ abstract_declarator
|
||||
;
|
||||
|
||||
direct_abstract_declarator
|
||||
: '[' constant_list ']' { $$ = new tile(nullptr, $1); }
|
||||
: '[' primary_expression_list ']' { $$ = new tile(nullptr, $1); }
|
||||
|
||||
constant :
|
||||
CONSTANT { $$ = new constant(atoi(yytext)); }
|
||||
;
|
||||
|
||||
constant_list
|
||||
: constant { $$ = new list<constant*>((constant*)$1); }
|
||||
| constant_list ',' constant { $$ = append_ptr_list<constant>($1, $3); }
|
||||
;
|
||||
|
||||
type_name
|
||||
: declaration_specifiers { $$ = new type_name($1, nullptr); }
|
||||
| declaration_specifiers abstract_declarator { $$ = new type_name($1, $2); }
|
||||
@@ -112,7 +111,7 @@ identifier
|
||||
;
|
||||
|
||||
builtin
|
||||
: GET_GLOBAL_RANGE '[' constant ']' '(' constant ')' { $$ = new get_global_range($3, $6); }
|
||||
: GET_GLOBAL_RANGE '[' primary_expression ']' '(' constant ')' { $$ = new get_global_range($3, $6); }
|
||||
| DOT '(' expression ',' expression ',' expression ')' { $$ = new matmul_expression($3, $5, $7); }
|
||||
|
||||
primary_expression
|
||||
@@ -124,6 +123,11 @@ primary_expression
|
||||
| '(' expression ')' { $$ = $2; }
|
||||
;
|
||||
|
||||
primary_expression_list
|
||||
: primary_expression { $$ = new list<expression*>((expression*)$1); }
|
||||
| primary_expression_list ',' primary_expression { $$ = append_ptr_list<expression>($1, $3); }
|
||||
;
|
||||
|
||||
slice
|
||||
: ':' { $$ = new slice(tdl::ast::ALL); }
|
||||
| NEWAXIS { $$ = new slice(tdl::ast::NEWAXIS); }
|
||||
@@ -312,7 +316,7 @@ jump_statement
|
||||
|
||||
direct_declarator
|
||||
: identifier { $$ = $1; }
|
||||
| identifier '[' constant_list ']' { $$ = new tile($1, $3); }
|
||||
| identifier '[' primary_expression_list ']' { $$ = new tile($1, $3); }
|
||||
| identifier '(' parameter_list ')' { $$ = new function($1, $3); }
|
||||
| identifier '(' ')' { $$ = new function($1, nullptr); }
|
||||
;
|
||||
@@ -330,7 +334,8 @@ parameter_declaration
|
||||
|
||||
|
||||
declaration_specifiers
|
||||
: type_specifier { $$ = new declaration_specifier(get_type_spec($1)); }
|
||||
: type_specifier { $$ = new typed_declaration_specifier(get_type_spec($1)); }
|
||||
| storage_class_specifier declaration_specifiers { $$ = new storage_declaration_specifier(get_storage_spec($1), $2); }
|
||||
;
|
||||
|
||||
init_declarator_list
|
||||
@@ -354,6 +359,13 @@ init_declarator
|
||||
| declarator '=' initialization_expression { $$ = new initializer($1, $3); }
|
||||
;
|
||||
|
||||
storage_class_specifier
|
||||
: TUNABLE { $$ = new token(TUNABLE_T); }
|
||||
| KERNEL { $$ = new token(KERNEL_T); }
|
||||
| READONLY { $$ = new token(READONLY_T); }
|
||||
| WRITEONLY { $$ = new token(WRITEONLY_T); }
|
||||
;
|
||||
|
||||
/* -------------------------- */
|
||||
/* Translation Unit */
|
||||
/* -------------------------- */
|
||||
|
@@ -16,6 +16,10 @@ int comment();
|
||||
%}
|
||||
|
||||
%%
|
||||
"tunable" { count(); return(TUNABLE); }
|
||||
"kernel" { count(); return(KERNEL); }
|
||||
"readonly" { count(); return(READONLY); }
|
||||
"writeonly" { count(); return(WRITEONLY); }
|
||||
"@" { count(); return(AT); }
|
||||
"newaxis" { count(); return(NEWAXIS); }
|
||||
"if" { count(); return(IF); }
|
||||
|
@@ -32,7 +32,7 @@ private:
|
||||
void add_reference(ir::value *v, interval_vec_t &res);
|
||||
void get_read_intervals(ir::instruction *i, interval_vec_t &res);
|
||||
void get_written_intervals(ir::instruction *i, interval_vec_t &res);
|
||||
void add(ir::basic_block *block, interval_vec_t ¬_synced, std::set<ir::instruction *> &insert_pts);
|
||||
void add(ir::basic_block *block, interval_vec_t ¬_synced, ir::builder &builder);
|
||||
|
||||
public:
|
||||
barriers(allocation *alloc, buffer_info_pass *buffer_info): alloc_(alloc), buffer_info_(buffer_info) {}
|
||||
|
@@ -100,6 +100,7 @@ class selection{
|
||||
private:
|
||||
// utils
|
||||
llvm::Type *make_vector_ty(llvm::Type *ty, size_t vector_size);
|
||||
std::vector<unsigned> extract_shapes(ir::value *v);
|
||||
|
||||
// LLVM conversions
|
||||
llvm::Type* llvm_type(ir::type *ty, llvm::LLVMContext &ctx);
|
||||
|
@@ -6,6 +6,7 @@
|
||||
#include <string>
|
||||
#include "instructions.h"
|
||||
#include "basic_block.h"
|
||||
#include "type.h"
|
||||
|
||||
namespace tdl{
|
||||
namespace ir{
|
||||
@@ -110,11 +111,11 @@ public:
|
||||
value *create_load(value *arg, const std::string &name = "");
|
||||
value *create_store(value *ptr, value *val, const std::string &name = "");
|
||||
// Tile instruction
|
||||
value *create_splat(value *arg, const std::vector<unsigned> &shapes, const std::string &name = "");
|
||||
value *create_reshape(value *arg, const std::vector<unsigned> &shapes, const std::string &name = "");
|
||||
value *create_broadcast(value *arg, const std::vector<unsigned> &shapes, const std::string &name = "");
|
||||
value *create_splat(value *arg, const type::tile_shapes_t &shapes, const std::string &name = "");
|
||||
value *create_reshape(value *arg, const type::tile_shapes_t &shapes, const std::string &name = "");
|
||||
value *create_broadcast(value *arg, const type::tile_shapes_t &shapes, const std::string &name = "");
|
||||
// Built-in instruction
|
||||
value *create_get_global_range(unsigned axis, unsigned size, const std::string &name = "");
|
||||
value *create_get_global_range(unsigned axis, type::tile_shapes_t::value_type size, const std::string &name = "");
|
||||
value *create_matmul(value *A, value *B, value *C, const std::string &name = "");
|
||||
// Intrinsics
|
||||
value *create_copy_to_shared(value *arg, const std::string &name = "");
|
||||
|
@@ -2,6 +2,7 @@
|
||||
#define TDL_INCLUDE_IR_CONSTANT_H
|
||||
|
||||
#include "value.h"
|
||||
#include <cassert>
|
||||
|
||||
namespace tdl{
|
||||
namespace ir{
|
||||
@@ -28,28 +29,43 @@ public:
|
||||
static undef_value* get(type* ty);
|
||||
};
|
||||
|
||||
|
||||
/* Constant int */
|
||||
class constant_int: public constant{
|
||||
protected:
|
||||
constant_int(type *ty, uint64_t value);
|
||||
|
||||
public:
|
||||
uint64_t get_value() const { return value_; }
|
||||
static constant *get(type *ty, uint64_t value);
|
||||
static constant_int *get(type *ty, uint64_t value);
|
||||
|
||||
protected:
|
||||
uint64_t value_;
|
||||
};
|
||||
|
||||
/* Metaparameter int */
|
||||
class metaparameter: public constant_int{
|
||||
metaparameter(type *ty, unsigned lo, unsigned hi);
|
||||
|
||||
public:
|
||||
static metaparameter *create(context &ctx, type *ty, unsigned lo, unsigned hi);
|
||||
void set_value(uint64_t value) { value_ = value; }
|
||||
|
||||
private:
|
||||
uint64_t value_;
|
||||
unsigned lo_;
|
||||
unsigned hi_;
|
||||
};
|
||||
|
||||
/* constant range */
|
||||
class constant_range: public constant{
|
||||
constant_range(type *ty, uint64_t first, uint64_t last);
|
||||
constant_range(type *ty, constant_int* first, constant_int* last);
|
||||
|
||||
public:
|
||||
static constant *get(constant *first, constant *last);
|
||||
static constant *get(constant_int *first, constant_int *last);
|
||||
|
||||
private:
|
||||
uint64_t first_;
|
||||
uint64_t last_;
|
||||
constant_int* first_;
|
||||
constant_int* last_;
|
||||
};
|
||||
|
||||
/* constant fp */
|
||||
|
@@ -12,6 +12,7 @@ class context;
|
||||
class constant_int;
|
||||
class constant_fp;
|
||||
class undef_value;
|
||||
class metaparameter;
|
||||
|
||||
/* Context impl */
|
||||
class context_impl {
|
||||
@@ -26,13 +27,15 @@ public:
|
||||
integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty;
|
||||
// Pointer types
|
||||
std::map<std::pair<type*, unsigned>, pointer_type*> ptr_tys;
|
||||
std::map<std::pair<type*,std::vector<unsigned>>, tile_type*> tile_tys;
|
||||
std::map<std::pair<type*, type::tile_shapes_t>, tile_type*> tile_tys;
|
||||
// Int constants
|
||||
std::map<uint64_t, constant_int*> int_constants_;
|
||||
std::map<std::pair<type*, uint64_t>, constant_int*> int_constants_;
|
||||
// Float constants
|
||||
std::map<double, constant_fp*> fp_constants_;
|
||||
// undef values
|
||||
std::map<type*, undef_value*> uv_constants_;
|
||||
// Metaparameters
|
||||
std::vector<metaparameter*> mp_constants_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -3,6 +3,7 @@
|
||||
|
||||
#include <vector>
|
||||
#include "value.h"
|
||||
#include "ir/type.h"
|
||||
#include "llvm/IR/Instructions.h"
|
||||
|
||||
namespace tdl{
|
||||
@@ -358,7 +359,7 @@ public:
|
||||
|
||||
class retile_inst: public unary_inst {
|
||||
protected:
|
||||
retile_inst(value *arg, const std::vector<unsigned> &shape_suffix, const std::string &name, instruction *next);
|
||||
retile_inst(value *arg, const type::tile_shapes_t &shapes, const std::string &name, instruction *next);
|
||||
static std::string shape_suffix(ir::type* ty);
|
||||
};
|
||||
|
||||
@@ -370,7 +371,7 @@ private:
|
||||
std::string repr_impl() const { return "reshape" + shape_suffix(get_type()); }
|
||||
|
||||
public:
|
||||
static instruction* create(value *arg, const std::vector<unsigned> &shape_suffix,
|
||||
static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
@@ -382,7 +383,7 @@ private:
|
||||
std::string repr_impl() const { return "splat" + shape_suffix(get_type()); }
|
||||
|
||||
public:
|
||||
static instruction* create(value *arg, const std::vector<unsigned> &shape_suffix,
|
||||
static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
@@ -394,7 +395,7 @@ private:
|
||||
std::string repr_impl() const { return "broadcast" + shape_suffix(get_type()); }
|
||||
|
||||
public:
|
||||
static instruction* create(value *arg, const std::vector<unsigned> &shape_suffix,
|
||||
static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
@@ -414,7 +415,7 @@ private:
|
||||
std::string repr_impl() const { return "get_global_range(" + std::to_string(axis_) + ")"; }
|
||||
|
||||
public:
|
||||
static instruction* create(context &ctx, unsigned axis, unsigned size,
|
||||
static instruction* create(context &ctx, unsigned axis, type::tile_shapes_t::value_type size,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
unsigned get_axis() const { return axis_; }
|
||||
|
@@ -3,6 +3,7 @@
|
||||
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <stack>
|
||||
#include <string>
|
||||
#include <functional>
|
||||
#include "builder.h"
|
||||
@@ -12,6 +13,7 @@ namespace tdl{
|
||||
namespace ast{
|
||||
|
||||
class iteration_statement;
|
||||
class compound_statement;
|
||||
|
||||
}
|
||||
|
||||
@@ -69,6 +71,10 @@ public:
|
||||
const functions_list_t &get_function_list() const { return functions_; }
|
||||
functions_list_t &get_function_list() { return functions_; }
|
||||
function *get_or_insert_function(const std::string &name, function_type *ty);
|
||||
// Scope
|
||||
void push_scope(const ast::compound_statement* scope) { scopes_.push(scope); }
|
||||
void pop_scope() { scopes_.pop(); }
|
||||
const ast::compound_statement* get_scope() { return scopes_.top(); }
|
||||
|
||||
|
||||
private:
|
||||
@@ -83,6 +89,7 @@ private:
|
||||
symbols_map_t symbols_;
|
||||
std::function<ir::value*()> continue_fn_;
|
||||
std::map<value*, value**> current_phi_;
|
||||
std::stack<const ast::compound_statement*> scopes_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -3,6 +3,7 @@
|
||||
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
#include <set>
|
||||
|
||||
namespace tdl{
|
||||
namespace ir{
|
||||
@@ -10,9 +11,13 @@ namespace ir{
|
||||
class context;
|
||||
class value;
|
||||
class integer_type;
|
||||
class constant_int;
|
||||
|
||||
/* Type */
|
||||
class type {
|
||||
public:
|
||||
typedef std::vector<constant_int*> tile_shapes_t;
|
||||
|
||||
protected:
|
||||
typedef std::vector<type*> contained_tys_vec_t;
|
||||
typedef contained_tys_vec_t::iterator ty_iterator;
|
||||
@@ -54,7 +59,7 @@ public:
|
||||
unsigned get_tile_bitwidth() const;
|
||||
unsigned get_primitive_size_in_bits() const;
|
||||
type *get_scalar_ty() const;
|
||||
const std::vector<unsigned> &get_tile_shapes() const;
|
||||
const tile_shapes_t& get_tile_shapes() const;
|
||||
unsigned get_tile_num_elements() const;
|
||||
type *get_tile_element_ty() const;
|
||||
unsigned get_pointer_address_space() const;
|
||||
@@ -94,9 +99,25 @@ public:
|
||||
static integer_type *get_int64_ty(context &ctx);
|
||||
static integer_type *get_int128_ty(context &ctx);
|
||||
|
||||
// Attributes
|
||||
type* set_tunable() { is_tunable_ = true; return this; }
|
||||
type* set_readonly() { is_readonly_ = true; return this; }
|
||||
type* set_writeonly() { is_writeonly_ = true; return this; }
|
||||
type* set_kernel() { is_kernel_ = true; return this; }
|
||||
|
||||
bool get_tunable() { return is_tunable_; }
|
||||
bool get_readonly() { return is_readonly_; }
|
||||
bool get_writeonly() { return is_writeonly_; }
|
||||
bool get_kernel() { return is_kernel_; }
|
||||
|
||||
private:
|
||||
context &ctx_;
|
||||
id_t id_;
|
||||
// attributes
|
||||
bool is_tunable_;
|
||||
bool is_readonly_;
|
||||
bool is_writeonly_;
|
||||
bool is_kernel_;
|
||||
|
||||
protected:
|
||||
contained_tys_vec_t contained_tys_;
|
||||
@@ -132,21 +153,24 @@ public:
|
||||
|
||||
class tile_type: public composite_type {
|
||||
private:
|
||||
tile_type(type *ty, const std::vector<unsigned> &shapes);
|
||||
tile_type(type *ty, const tile_shapes_t &shapes);
|
||||
static bool is_valid_elt_ty(type *ty);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
const std::vector<unsigned>& get_shapes() const { return shapes_; }
|
||||
const tile_shapes_t& get_shapes() const { return shapes_; }
|
||||
unsigned get_num_elements() const;
|
||||
unsigned get_bitwidth() const;
|
||||
|
||||
// factory methods
|
||||
static tile_type* get(type *ty, const std::vector<unsigned> &shapes);
|
||||
static tile_type* get(type *ty, const tile_shapes_t &shapes);
|
||||
static tile_type* get_same_shapes(type *ty, type *ref);
|
||||
|
||||
// shortcut to get a 1 element in the shape
|
||||
static tile_shapes_t::value_type make_one(context &ctx);
|
||||
|
||||
private:
|
||||
std::vector<unsigned> shapes_;
|
||||
tile_shapes_t shapes_;
|
||||
};
|
||||
|
||||
class pointer_type: public type {
|
||||
|
@@ -100,6 +100,7 @@ void node::implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs)
|
||||
ir::builder &builder = mod->get_builder();
|
||||
ir::type *lhs_ty = lhs->get_type();
|
||||
ir::type *rhs_ty = rhs->get_type();
|
||||
ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(mod->get_context());
|
||||
// Both are scalar
|
||||
if(!lhs_ty->is_tile_ty() && !rhs_ty->is_tile_ty())
|
||||
return;
|
||||
@@ -111,30 +112,30 @@ void node::implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs)
|
||||
return;
|
||||
}
|
||||
// Both are arrays
|
||||
std::vector<unsigned> lhs_shapes = lhs->get_type()->get_tile_shapes();
|
||||
std::vector<unsigned> rhs_shapes = rhs->get_type()->get_tile_shapes();
|
||||
auto lhs_shapes = lhs->get_type()->get_tile_shapes();
|
||||
auto rhs_shapes = rhs->get_type()->get_tile_shapes();
|
||||
if(lhs_shapes == rhs_shapes)
|
||||
return;
|
||||
int lhs_dim = lhs_shapes.size();
|
||||
int rhs_dim = rhs_shapes.size();
|
||||
std::vector<unsigned> &shortest = (lhs_dim < rhs_dim)?lhs_shapes:rhs_shapes;
|
||||
std::vector<unsigned> &longest = (lhs_dim < rhs_dim)?rhs_shapes:lhs_shapes;
|
||||
auto &shortest = (lhs_dim < rhs_dim)?lhs_shapes:rhs_shapes;
|
||||
auto &longest = (lhs_dim < rhs_dim)?rhs_shapes:lhs_shapes;
|
||||
size_t ndim = longest.size();
|
||||
int off = longest.size() - shortest.size();
|
||||
for(int i = longest.size() - 1; i>= 0; i--){
|
||||
if(shortest[off + i] != longest[i] && shortest[off + i] != 1 && longest[i] != 1)
|
||||
if(shortest[off + i] != longest[i] && shortest[off + i] != one && longest[i] != one)
|
||||
throw std::runtime_error("cannot broadcast");
|
||||
}
|
||||
// Pad
|
||||
for(size_t i = 0; i < off; i++)
|
||||
shortest.insert(shortest.begin(), 1);
|
||||
shortest.insert(shortest.begin(), one);
|
||||
ir::value *&target = (lhs_dim < rhs_dim)?lhs:rhs;
|
||||
if(off > 0)
|
||||
target = builder.create_reshape(target, shortest);
|
||||
// Broadcast
|
||||
std::vector<unsigned> shapes(ndim);
|
||||
ir::type::tile_shapes_t shapes(ndim);
|
||||
for(size_t i = 0; i < ndim; i++)
|
||||
shapes[i] = std::max(shortest[i], longest[i]);
|
||||
shapes[i] = shortest[i]==one?longest[i]:shortest[i];
|
||||
if(shapes != lhs_shapes)
|
||||
lhs = builder.create_broadcast(lhs, shapes);
|
||||
if(shapes != rhs_shapes)
|
||||
@@ -148,14 +149,15 @@ inline bool is_terminator(ir::value* x) {
|
||||
|
||||
/* Translation unit */
|
||||
ir::value* translation_unit::codegen(ir::module *mod) const{
|
||||
decls_->codegen(mod);
|
||||
mod->push_scope(nullptr);
|
||||
decls_.codegen(mod);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/* Declaration specifier */
|
||||
ir::type* declaration_specifier::type(ir::module *mod) const {
|
||||
ir::type* typed_declaration_specifier::type(ir::module *mod) const {
|
||||
ir::context &ctx = mod->get_context();
|
||||
switch (spec_) {
|
||||
switch (ty_) {
|
||||
case VOID_T: return ir::type::get_void_ty(ctx);
|
||||
case INT1_T: return ir::type::get_int1_ty(ctx);
|
||||
case INT8_T: return ir::type::get_int8_ty(ctx);
|
||||
@@ -164,7 +166,18 @@ ir::type* declaration_specifier::type(ir::module *mod) const {
|
||||
case INT64_T: return ir::type::get_int64_ty(ctx);
|
||||
case FLOAT32_T: return ir::type::get_float_ty(ctx);
|
||||
case FLOAT64_T: return ir::type::get_double_ty(ctx);
|
||||
default: throw std::runtime_error("unreachable");
|
||||
default: throw std::runtime_error("unreachable");
|
||||
}
|
||||
}
|
||||
|
||||
ir::type* storage_declaration_specifier::type(ir::module *mod) const {
|
||||
ir::type* result = decl_spec_->type(mod);
|
||||
switch(storage_spec_){
|
||||
case TUNABLE_T: return result->set_tunable();
|
||||
case KERNEL_T: return result->set_kernel();
|
||||
case READONLY_T: return result->set_readonly();
|
||||
case WRITEONLY_T: return result->set_writeonly();
|
||||
default: throw std::runtime_error("unreachable");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -194,10 +207,10 @@ const std::string &identifier::name() const{
|
||||
}
|
||||
|
||||
// Tile
|
||||
ir::type* tile::type_impl(ir::module*, ir::type *type) const{
|
||||
std::vector<unsigned> shapes;
|
||||
ir::type* tile::type_impl(ir::module *mod, ir::type *type) const{
|
||||
ir::type::tile_shapes_t shapes;
|
||||
for(constant *cst: shapes_->values())
|
||||
shapes.push_back(cst->value());
|
||||
shapes.push_back((ir::constant_int*)cst->codegen(mod));
|
||||
return ir::tile_type::get(type, shapes);
|
||||
}
|
||||
|
||||
@@ -245,6 +258,7 @@ ir::value* function_definition::codegen(ir::module *mod) const{
|
||||
|
||||
/* Statements */
|
||||
ir::value* compound_statement::codegen(ir::module* mod) const{
|
||||
mod->push_scope(this);
|
||||
if(decls_)
|
||||
decls_->codegen(mod);
|
||||
if(statements_){
|
||||
@@ -254,6 +268,7 @@ ir::value* compound_statement::codegen(ir::module* mod) const{
|
||||
return current;
|
||||
}
|
||||
}
|
||||
mod->pop_scope();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@@ -337,7 +352,7 @@ ir::value* continue_statement::codegen(ir::module *mod) const{
|
||||
/* Declaration */
|
||||
ir::value* declaration::codegen(ir::module* mod) const{
|
||||
for(initializer *init: init_->values())
|
||||
init->specifier(spec_);
|
||||
init->set_specifier(spec_);
|
||||
init_->codegen(mod);
|
||||
return nullptr;
|
||||
}
|
||||
@@ -347,7 +362,7 @@ ir::type* initializer::type_impl(ir::module *mod, ir::type *type) const{
|
||||
return decl_->type(mod, type);
|
||||
}
|
||||
|
||||
void initializer::specifier(const declaration_specifier *spec) {
|
||||
void initializer::set_specifier(const declaration_specifier *spec) {
|
||||
spec_ = spec;
|
||||
}
|
||||
|
||||
@@ -355,6 +370,11 @@ ir::value* initializer::codegen(ir::module * mod) const{
|
||||
ir::type *ty = decl_->type(mod, spec_->type(mod));
|
||||
std::string name = decl_->id()->name();
|
||||
ir::value *value = ir::undef_value::get(ty);
|
||||
if(ty->get_tunable()){
|
||||
assert(expr_ == nullptr);
|
||||
//TODO
|
||||
value = ir::metaparameter::create(mod->get_context(), ty, 4, 8);
|
||||
}
|
||||
if(expr_){
|
||||
value = expr_->codegen(mod);
|
||||
value = explicit_cast(mod->get_builder(), value, ty);
|
||||
@@ -464,7 +484,7 @@ ir::value* binary_operator::codegen(ir::module *mod) const{
|
||||
// get_global_range
|
||||
ir::value* get_global_range::codegen(ir::module *mod) const {
|
||||
ir::builder &builder = mod->get_builder();
|
||||
return builder.create_get_global_range(axis_->value(), size_->value());
|
||||
return builder.create_get_global_range(axis_->value(), (ir::constant_int*)size_->codegen(mod));
|
||||
}
|
||||
|
||||
|
||||
@@ -487,11 +507,13 @@ ir::value* matmul_expression::codegen(ir::module *mod) const {
|
||||
ir::value* indexing_expression::codegen(ir::module *mod) const{
|
||||
ir::value *in = mod->get_value(id_->name());
|
||||
const std::vector<slice*> &slices = slices_->values();
|
||||
std::vector<unsigned> in_shapes = in->get_type()->get_tile_shapes();
|
||||
std::vector<unsigned> out_shapes(slices.size());
|
||||
auto in_shapes = in->get_type()->get_tile_shapes();
|
||||
ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(mod->get_context());
|
||||
ir::type::tile_shapes_t out_shapes(slices.size());
|
||||
// create shapes
|
||||
size_t current = 0;
|
||||
for(size_t i = 0; i < out_shapes.size(); i++)
|
||||
out_shapes[i] = (slices[i]->type()==NEWAXIS)?1:in_shapes[current++];
|
||||
out_shapes[i] = (slices[i]->type()==NEWAXIS)?one:in_shapes[current++];
|
||||
return mod->get_builder().create_reshape(in, out_shapes);
|
||||
}
|
||||
|
||||
@@ -586,8 +608,8 @@ int constant::value() const{
|
||||
|
||||
/* Constant range */
|
||||
ir::value* constant_range::codegen(ir::module *mod) const{
|
||||
return ir::constant_range::get((ir::constant*)first_->codegen(mod),
|
||||
(ir::constant*)last_->codegen(mod));
|
||||
return ir::constant_range::get((ir::constant_int*)first_->codegen(mod),
|
||||
(ir::constant_int*)last_->codegen(mod));
|
||||
}
|
||||
|
||||
/* Named */
|
||||
|
@@ -45,10 +45,15 @@ void barriers::get_written_intervals(ir::instruction *i, interval_vec_t &res){
|
||||
|
||||
void barriers::insert_barrier(ir::instruction *instr, ir::builder &builder) {
|
||||
if(auto *phi = dynamic_cast<ir::phi_node*>(instr)) {
|
||||
std::set<ir::value*> incoming;
|
||||
for(unsigned n = 0; n < phi->get_num_incoming(); n++){
|
||||
ir::basic_block *block = phi->get_incoming_block(n);
|
||||
builder.set_insert_point(block->get_inst_list().back());
|
||||
builder.create_barrier();
|
||||
ir::instruction *inc_val = dynamic_cast<ir::instruction*>(phi->get_incoming_value(n));
|
||||
assert(inc_val);
|
||||
if(incoming.insert(inc_val).second){
|
||||
ir::basic_block *block = inc_val->get_parent();
|
||||
builder.set_insert_point(block->get_inst_list().back());
|
||||
builder.create_barrier();
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
@@ -57,15 +62,15 @@ void barriers::insert_barrier(ir::instruction *instr, ir::builder &builder) {
|
||||
}
|
||||
}
|
||||
|
||||
void barriers::add(ir::basic_block *block, interval_vec_t ¬_synced, std::set<ir::instruction*> &insert_pts) {
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
void barriers::add(ir::basic_block *block, interval_vec_t ¬_synced, ir::builder &builder) {
|
||||
ir::basic_block::inst_list_t instructions = block->get_inst_list();
|
||||
for(ir::instruction *i: instructions){
|
||||
interval_vec_t read, written;
|
||||
get_read_intervals(i, read);
|
||||
get_written_intervals(i, written);
|
||||
if(intersect(not_synced, read)
|
||||
|| intersect(not_synced, written)) {
|
||||
if(intersect(not_synced, read)) {
|
||||
not_synced.clear();
|
||||
insert_pts.insert(i);
|
||||
insert_barrier(i, builder);
|
||||
}
|
||||
std::copy(written.begin(), written.end(), std::back_inserter(not_synced));
|
||||
}
|
||||
@@ -76,12 +81,8 @@ void barriers::run(ir::module &mod) {
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
// find barrier location
|
||||
interval_vec_t not_synced;
|
||||
std::set<ir::instruction*> insert_pts;
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
add(block, not_synced, insert_pts);
|
||||
// insert barrier
|
||||
for(ir::instruction *i: insert_pts)
|
||||
insert_barrier(i, builder);
|
||||
add(block, not_synced, builder);
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -44,6 +44,7 @@ llvm::Type *distributed_tile::make_vector_ty(llvm::Type *ty, size_t vector_size)
|
||||
return VectorType::get(ty, vector_size);
|
||||
}
|
||||
|
||||
|
||||
distributed_tile::distributed_tile(Type *ty, const shapes_t &shapes, const axes_t &axes, llvm::IRBuilder<> &builder, bool vectorize)
|
||||
: tile(make_vector_ty(ty, vectorize?axes[0].contiguous:1), shapes), axes_(axes), builder_(builder) {
|
||||
vector_size_ = vectorize?ty_->getVectorNumElements():1;
|
||||
@@ -149,6 +150,16 @@ Value* shared_tile::get_value(indices_t idx) {
|
||||
return builder_.CreateLoad(ptr);
|
||||
}
|
||||
|
||||
/* Utils */
|
||||
std::vector<unsigned> selection::extract_shapes(ir::value *v) {
|
||||
const auto& shapes = v->get_type()->get_tile_shapes();
|
||||
std::vector<unsigned> result(shapes.size());
|
||||
for(ir::constant_int* cst: shapes)
|
||||
result.push_back(cst->get_value());
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
/* convert ir::type to Type */
|
||||
Type *selection::llvm_type(ir::type *ty, LLVMContext &ctx) {
|
||||
// function
|
||||
@@ -299,11 +310,12 @@ std::vector<Value*> delinearize(Value *trailing, std::vector<unsigned> &shapes,
|
||||
}
|
||||
|
||||
void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
|
||||
const auto& shapes = v->get_type()->get_tile_shapes();
|
||||
const auto& shapes = extract_shapes(v);
|
||||
size_t dim = shapes.size();
|
||||
std::vector<unsigned> contiguous(dim);
|
||||
std::vector<unsigned> warp_size(dim);
|
||||
std::vector<unsigned> n_warps(dim);
|
||||
std::cout << v->get_name() << " " << typeid(*v).name() << std::endl;
|
||||
for(unsigned i = 0; i < shapes.size(); i++){
|
||||
std::string str_i = std::to_string(i);
|
||||
contiguous[i] = *params_->get_param(v, "p0.d" + str_i);
|
||||
@@ -336,7 +348,7 @@ void selection::create_grids(std::vector<ir::value*> &grids,
|
||||
// get number of dimensions greater than 1
|
||||
auto get_tile_gt1_dim = [&](ir::value *v){
|
||||
unsigned result = 0;
|
||||
for(unsigned shape: v->get_type()->get_tile_shapes()) {
|
||||
for(unsigned shape: extract_shapes(v)) {
|
||||
result += (shape > 1)?shape:0;
|
||||
}
|
||||
return result;
|
||||
@@ -353,7 +365,7 @@ void selection::create_grids(std::vector<ir::value*> &grids,
|
||||
for(ir::value *op: user->ops())
|
||||
bind_references(op);
|
||||
// bind
|
||||
const auto& shapes = v->get_type()->get_tile_shapes();
|
||||
const auto& shapes = extract_shapes(v);
|
||||
if(dynamic_cast<ir::copy_to_shared_inst*>(v) || buffer_info_->is_double(v))
|
||||
return;
|
||||
for(size_t d = 0; d < shapes.size(); d++){
|
||||
@@ -385,7 +397,7 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
|
||||
for(ir::value *op: user->ops())
|
||||
create_tile(op, builder, references, seen, sh_mem_ptr);
|
||||
LLVMContext &ctx = builder.getContext();
|
||||
const auto& shapes = v->get_type()->get_tile_shapes();
|
||||
const auto& shapes = extract_shapes(v);
|
||||
Type* ty = llvm_type(v->get_type()->get_scalar_ty(), ctx);
|
||||
// create shared tile
|
||||
if(dynamic_cast<ir::copy_to_shared_inst*>(v) || (buffer_info_->is_double(v))){
|
||||
@@ -429,7 +441,7 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
|
||||
}
|
||||
// create distributed tile
|
||||
else {
|
||||
const auto &shapes = v->get_type()->get_tile_shapes();
|
||||
const auto &shapes = extract_shapes(v);
|
||||
std::vector<distributed_axis> axes(shapes.size());
|
||||
for(size_t d = 0; d < shapes.size(); d++){
|
||||
if(shapes[d] > 1){
|
||||
@@ -530,7 +542,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
distributed_tile* result = (distributed_tile*)ti;
|
||||
if(!ins->get_type()->is_tile_ty())
|
||||
return;
|
||||
const auto& shapes = ins->get_type()->get_tile_shapes();
|
||||
const auto& shapes = extract_shapes(ins);
|
||||
// global_range
|
||||
if(auto *x = dynamic_cast<ir::get_global_range_inst*>(ins)) {
|
||||
static std::array<Intrinsic::ID, 3> ctaid = {
|
||||
@@ -568,7 +580,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
// broadcast
|
||||
else if(dynamic_cast<ir::broadcast_inst*>(ins)) {
|
||||
ir::value* in = ins->get_operand(0);
|
||||
const auto& in_shapes = in->get_type()->get_tile_shapes();
|
||||
const auto& in_shapes = extract_shapes(in);
|
||||
distributed_tile *in_tile = (distributed_tile*)tmap_.at(in);
|
||||
result->for_each([&](indices_t out_idx){
|
||||
indices_t in_idx = out_idx;
|
||||
@@ -615,7 +627,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {llvm_type(C->get_type()->get_scalar_ty(), ctx)});
|
||||
result->for_each([&](indices_t idx){
|
||||
Value *res = tmap_.at(C)->get_value(idx);
|
||||
unsigned NK = A->get_type()->get_tile_shapes()[1];
|
||||
unsigned NK = extract_shapes(A)[1];
|
||||
for(unsigned K = 0; K < NK; ++K){
|
||||
indices_t a_idx = {idx[0], builder.getInt32(K)};
|
||||
indices_t b_idx = {idx[1], builder.getInt32(K)};
|
||||
|
@@ -3,6 +3,8 @@
|
||||
#include "ir/type.h"
|
||||
#include "ir/module.h"
|
||||
#include "ir/function.h"
|
||||
#include "ir/context_impl.h"
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
|
||||
@@ -29,7 +31,8 @@ void tune::init_c_phi(ir::instruction *v) {
|
||||
|
||||
void tune::init_c_graph(ir::instruction *v) {
|
||||
// Reference shape
|
||||
std::vector<unsigned> shapes;
|
||||
ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(v->get_parent()->get_context());
|
||||
ir::type::tile_shapes_t shapes;
|
||||
if(auto *store = dynamic_cast<ir::store_inst*>(v))
|
||||
shapes = store->get_pointer_operand()->get_type()->get_tile_shapes();
|
||||
else
|
||||
@@ -39,7 +42,7 @@ void tune::init_c_graph(ir::instruction *v) {
|
||||
ir::value *op = v->get_operand(0);
|
||||
unsigned current = 0;
|
||||
for(unsigned i = 0; i < shapes.size(); i ++){
|
||||
if(shapes[i] == 1)
|
||||
if(shapes[i] == one)
|
||||
static_params_.insert({{v, i}, 1});
|
||||
else
|
||||
add_constraint({v, i}, {op, current++});
|
||||
@@ -99,6 +102,7 @@ void tune::connected_components(node_t x, const std::vector<unsigned *> vals, st
|
||||
std::vector<unsigned*> tune::get_params(ir::module &mod) {
|
||||
std::vector<unsigned *> result;
|
||||
std::set<unsigned*> seen;
|
||||
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i : block->get_inst_list())
|
||||
@@ -143,8 +147,9 @@ void tune::create_grids(std::vector<ir::instruction*> &grids,
|
||||
// get number of dimensions greater than 1
|
||||
auto get_tile_gt1_dim = [&](ir::value *v){
|
||||
unsigned result = 0;
|
||||
for(unsigned shape: v->get_type()->get_tile_shapes()) {
|
||||
result += (shape > 1)?shape:0;
|
||||
auto one = ir::tile_type::make_one(fn->get_fn_type()->get_context());
|
||||
for(ir::constant_int *shape: v->get_type()->get_tile_shapes()) {
|
||||
result += (shape != one);
|
||||
}
|
||||
return result;
|
||||
};
|
||||
@@ -194,8 +199,8 @@ for(ir::function *fn: mod.get_function_list()){
|
||||
unsigned *s1 = params_[i]["p1.d" + strk];
|
||||
unsigned *s2 = params_[i]["p2.d" + strk];
|
||||
unsigned multiple = (*s0)*(*s1)*(*s2);
|
||||
if(shapes[k] % multiple != 0)
|
||||
errors[i].push_back("for dim " + strk + ": shape (" + to_string(shapes[k]) + ")"
|
||||
if(shapes[k]->get_value() % multiple != 0)
|
||||
errors[i].push_back("for dim " + strk + ": shape (" + to_string(shapes[k]->get_value()) + ")"
|
||||
" is not a multiple of layout (" + to_string(multiple) + ")");
|
||||
}
|
||||
// the number of thread per warp must be 32
|
||||
|
@@ -244,15 +244,15 @@ value *builder::create_store(value *ptr, value *val, const std::string &name){
|
||||
// tile instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
value *builder::create_reshape(value *arg, const std::vector<unsigned> &shapes, const std::string &name) {
|
||||
value *builder::create_reshape(value *arg, const type::tile_shapes_t &shapes, const std::string &name) {
|
||||
return insert(reshape_inst::create(arg, shapes, name));
|
||||
}
|
||||
|
||||
value *builder::create_splat(value *arg, const std::vector<unsigned> &shapes, const std::string &name) {
|
||||
value *builder::create_splat(value *arg, const type::tile_shapes_t &shapes, const std::string &name) {
|
||||
return insert(splat_inst::create(arg, shapes, name));
|
||||
}
|
||||
|
||||
value *builder::create_broadcast(value *arg, const std::vector<unsigned> &shapes, const std::string &name) {
|
||||
value *builder::create_broadcast(value *arg, const type::tile_shapes_t &shapes, const std::string &name) {
|
||||
return insert(broadcast_inst::create(arg, shapes, name));
|
||||
}
|
||||
|
||||
@@ -260,7 +260,7 @@ value *builder::create_broadcast(value *arg, const std::vector<unsigned> &shapes
|
||||
// built-in instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
value *builder::create_get_global_range(unsigned axis, unsigned size, const std::string &name) {
|
||||
value *builder::create_get_global_range(unsigned axis, type::tile_shapes_t::value_type size, const std::string &name) {
|
||||
return insert(get_global_range_inst::create(ctx_, axis, size, name));
|
||||
}
|
||||
|
||||
|
@@ -48,24 +48,27 @@ constant *constant::get_all_ones_value(type *ty) {
|
||||
constant_int::constant_int(type *ty, uint64_t value)
|
||||
: constant(ty, 0), value_(value){ }
|
||||
|
||||
constant *constant_int::get(type *ty, uint64_t value) {
|
||||
return new constant_int(ty, value);
|
||||
constant_int *constant_int::get(type *ty, uint64_t value) {
|
||||
context_impl *impl = ty->get_context().p_impl.get();
|
||||
constant_int *& cst = impl->int_constants_[std::make_pair(ty, value)];
|
||||
if(cst == nullptr)
|
||||
cst = new constant_int(ty, value);
|
||||
return cst;
|
||||
}
|
||||
|
||||
// constant_range
|
||||
// FIXME use something like APInt
|
||||
|
||||
constant_range::constant_range(type *ty, uint64_t first, uint64_t last)
|
||||
constant_range::constant_range(type *ty, constant_int *first, constant_int *last)
|
||||
: constant(ty, 0), first_(first), last_(last){ }
|
||||
|
||||
constant *constant_range::get(constant *first, constant *last) {
|
||||
constant *constant_range::get(constant_int *first, constant_int *last) {
|
||||
assert(first->get_type()->is_integer_ty());
|
||||
assert(first->get_type() == last->get_type());
|
||||
unsigned vfirst = ((constant_int*)first)->get_value();
|
||||
unsigned vlast = ((constant_int*)last)->get_value();
|
||||
assert(vlast > vfirst);
|
||||
type *ty = tile_type::get(first->get_type(), {vlast - vfirst});
|
||||
return new constant_range(ty, vfirst, vlast);
|
||||
assert(vfirst == 0);
|
||||
type *ty = tile_type::get(first->get_type(), {last});
|
||||
return new constant_range(ty, first, last);
|
||||
}
|
||||
|
||||
|
||||
@@ -94,6 +97,17 @@ constant *constant_fp::get(context &ctx, double v){
|
||||
return result;
|
||||
}
|
||||
|
||||
// metaparameter
|
||||
metaparameter::metaparameter(type *ty, unsigned lo, unsigned hi)
|
||||
: constant_int(ty, 0), lo_(lo), hi_(hi){ }
|
||||
|
||||
metaparameter* metaparameter::create(context &ctx, type *ty, unsigned lo, unsigned hi) {
|
||||
context_impl *impl = ctx.p_impl.get();
|
||||
metaparameter *result = new metaparameter(ty, lo, hi);
|
||||
impl->mp_constants_.push_back(result);
|
||||
return result;
|
||||
}
|
||||
|
||||
// undef value
|
||||
undef_value::undef_value(type *ty)
|
||||
: constant(ty, 0) { }
|
||||
|
@@ -409,7 +409,7 @@ std::string retile_inst::shape_suffix(ir::type* ty){
|
||||
std::string res = "[";
|
||||
const auto& shapes = ty->get_tile_shapes();
|
||||
for(unsigned i = 0; i < shapes.size(); i++){
|
||||
res += std::to_string(ty->get_tile_shapes()[i]);
|
||||
res += std::to_string(ty->get_tile_shapes()[i]->get_value());
|
||||
if(i < shapes.size() - 1)
|
||||
res += ", ";
|
||||
}
|
||||
@@ -417,13 +417,13 @@ std::string retile_inst::shape_suffix(ir::type* ty){
|
||||
return res;
|
||||
}
|
||||
|
||||
retile_inst::retile_inst(value *arg, const std::vector<unsigned> &shapes,
|
||||
retile_inst::retile_inst(value *arg, const type::tile_shapes_t &shapes,
|
||||
const std::string &name, instruction *next)
|
||||
: unary_inst(tile_type::get(arg->get_type()->get_scalar_ty(), shapes), arg, name, next) { }
|
||||
|
||||
// reshape
|
||||
|
||||
instruction* reshape_inst::create(value *arg, const std::vector<unsigned> &shapes,
|
||||
instruction* reshape_inst::create(value *arg, const type::tile_shapes_t &shapes,
|
||||
const std::string &name, instruction *next) {
|
||||
return new reshape_inst(arg, shapes, name, next);
|
||||
}
|
||||
@@ -431,14 +431,14 @@ instruction* reshape_inst::create(value *arg, const std::vector<unsigned> &shape
|
||||
|
||||
// splat
|
||||
|
||||
instruction* splat_inst::create(value *arg, const std::vector<unsigned> &shapes,
|
||||
instruction* splat_inst::create(value *arg, const type::tile_shapes_t &shapes,
|
||||
const std::string &name, instruction *next) {
|
||||
return new splat_inst(arg, shapes, name, next);
|
||||
}
|
||||
|
||||
// broadcast
|
||||
|
||||
instruction* broadcast_inst::create(value *arg, const std::vector<unsigned> &shapes,
|
||||
instruction* broadcast_inst::create(value *arg, const type::tile_shapes_t &shapes,
|
||||
const std::string &name, instruction *next) {
|
||||
return new broadcast_inst(arg, shapes, name, next);
|
||||
}
|
||||
@@ -470,7 +470,7 @@ get_global_range_inst::get_global_range_inst(type *ty, unsigned axis,
|
||||
|
||||
}
|
||||
|
||||
instruction* get_global_range_inst::create(context &ctx, unsigned axis, unsigned size,
|
||||
instruction* get_global_range_inst::create(context &ctx, unsigned axis, type::tile_shapes_t::value_type size,
|
||||
const std::string &name, instruction *next) {
|
||||
type *int_ty = type::get_int32_ty(ctx);
|
||||
type *tile_ty = tile_type::get(int_ty, {size});
|
||||
|
@@ -3,6 +3,7 @@
|
||||
#include "ir/context.h"
|
||||
#include "ir/context_impl.h"
|
||||
#include "ir/value.h"
|
||||
#include "ir/constant.h"
|
||||
|
||||
namespace tdl{
|
||||
namespace ir{
|
||||
@@ -63,7 +64,7 @@ type * type::get_pointer_element_ty() const {
|
||||
}
|
||||
|
||||
|
||||
const std::vector<unsigned> &type::get_tile_shapes() const {
|
||||
const type::tile_shapes_t &type::get_tile_shapes() const {
|
||||
assert(is_tile_ty());
|
||||
return ((tile_type*)this)->get_shapes();
|
||||
}
|
||||
@@ -148,7 +149,7 @@ bool composite_type::index_valid(value *idx) const{
|
||||
// tile_type class
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
tile_type::tile_type(type *ty, const std::vector<unsigned> &shapes)
|
||||
tile_type::tile_type(type *ty, const tile_shapes_t &shapes)
|
||||
: composite_type(ty->get_context(), TileTyID), shapes_(shapes) {
|
||||
contained_tys_.push_back(ty);
|
||||
}
|
||||
@@ -159,8 +160,8 @@ bool tile_type::is_valid_elt_ty(type *ty) {
|
||||
|
||||
unsigned tile_type::get_num_elements() const {
|
||||
unsigned res = 1;
|
||||
for(unsigned shape: shapes_)
|
||||
res *= shape;
|
||||
for(auto shape: shapes_)
|
||||
res *= shape->get_value();
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -168,7 +169,7 @@ unsigned tile_type::get_bitwidth() const {
|
||||
return get_num_elements() * get_tile_element_ty()->get_primitive_size_in_bits();
|
||||
}
|
||||
|
||||
tile_type* tile_type::get(type *elt_ty, const std::vector<unsigned> &shapes) {
|
||||
tile_type* tile_type::get(type *elt_ty, const tile_shapes_t &shapes) {
|
||||
assert(elt_ty && "Can't get a tile of <null> type!");
|
||||
assert(shapes.size() && "Can't create a tile with empty shapes!");
|
||||
assert(is_valid_elt_ty(elt_ty) && "Invalid type for pointer element!");
|
||||
@@ -185,6 +186,10 @@ tile_type* tile_type::get_same_shapes(type *ty, type *ref){
|
||||
return get(ty, ref->get_tile_shapes());
|
||||
}
|
||||
|
||||
type::tile_shapes_t::value_type tile_type::make_one(ir::context& ctx){
|
||||
return constant_int::get(type::get_int32_ty(ctx), 1);
|
||||
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// function_type class
|
||||
|
Reference in New Issue
Block a user