[intermediate representation] transitioning towards more flexible tile
shapes
This commit is contained in:
1
TODO
1
TODO
@@ -2,3 +2,4 @@
|
|||||||
- proper naming scheme
|
- proper naming scheme
|
||||||
- symbols table
|
- symbols table
|
||||||
- name conflicts on globals?
|
- name conflicts on globals?
|
||||||
|
- separate header for typedef (e.g., type::tile_shapes_t) to reduce compilation time
|
||||||
|
@@ -6,6 +6,7 @@
|
|||||||
#include "ir/context.h"
|
#include "ir/context.h"
|
||||||
#include "ir/module.h"
|
#include "ir/module.h"
|
||||||
#include "ir/print.h"
|
#include "ir/print.h"
|
||||||
|
#include "ir/context_impl.h"
|
||||||
#include "codegen/selection.h"
|
#include "codegen/selection.h"
|
||||||
#include "codegen/tune.h"
|
#include "codegen/tune.h"
|
||||||
#include "codegen/shared_copy.h"
|
#include "codegen/shared_copy.h"
|
||||||
@@ -182,6 +183,9 @@ int main() {
|
|||||||
llvm::LLVMContext llvm_context;
|
llvm::LLVMContext llvm_context;
|
||||||
llvm::Module llvm_module("test", llvm_context);
|
llvm::Module llvm_module("test", llvm_context);
|
||||||
|
|
||||||
|
// context.p_impl->mp_constants_[0]->set_value(16);
|
||||||
|
// context.p_impl->mp_constants_[1]->set_value(16);
|
||||||
|
// context.p_impl->mp_constants_[2]->set_value(8);
|
||||||
|
|
||||||
// create passes
|
// create passes
|
||||||
tdl::codegen::buffer_info_pass buffer_info;
|
tdl::codegen::buffer_info_pass buffer_info;
|
||||||
|
@@ -5,7 +5,7 @@
|
|||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
|
||||||
namespace tdl{
|
namespace tdl{
|
||||||
@@ -56,6 +56,12 @@ enum TYPE_T{
|
|||||||
FLOAT32_T, FLOAT64_T
|
FLOAT32_T, FLOAT64_T
|
||||||
};
|
};
|
||||||
|
|
||||||
|
enum STORAGE_SPEC_T{
|
||||||
|
TUNABLE_T,
|
||||||
|
KERNEL_T,
|
||||||
|
READONLY_T, WRITEONLY_T,
|
||||||
|
};
|
||||||
|
|
||||||
class pointer;
|
class pointer;
|
||||||
class identifier;
|
class identifier;
|
||||||
class constant;
|
class constant;
|
||||||
@@ -75,7 +81,7 @@ public:
|
|||||||
template<class T>
|
template<class T>
|
||||||
class list: public node {
|
class list: public node {
|
||||||
public:
|
public:
|
||||||
list(const T& x): values_{x} {}
|
list(const T& x): values_(1, x) {}
|
||||||
|
|
||||||
node* append(const T& x){
|
node* append(const T& x){
|
||||||
values_.push_back(x);
|
values_.push_back(x);
|
||||||
@@ -389,16 +395,30 @@ public:
|
|||||||
class no_op: public statement { };
|
class no_op: public statement { };
|
||||||
|
|
||||||
// Types
|
// Types
|
||||||
|
|
||||||
class declaration_specifier: public node{
|
class declaration_specifier: public node{
|
||||||
public:
|
public:
|
||||||
declaration_specifier(TYPE_T spec)
|
using node::node;
|
||||||
: spec_(spec) { }
|
virtual ir::type* type(ir::module *mod) const = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
class typed_declaration_specifier: public declaration_specifier {
|
||||||
|
public:
|
||||||
|
typed_declaration_specifier(TYPE_T ty): ty_(ty){ }
|
||||||
ir::type* type(ir::module *mod) const;
|
ir::type* type(ir::module *mod) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const TYPE_T spec_;
|
const TYPE_T ty_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class storage_declaration_specifier: public declaration_specifier {
|
||||||
|
public:
|
||||||
|
storage_declaration_specifier(STORAGE_SPEC_T storage_spec, node *decl_spec)
|
||||||
|
: storage_spec_(storage_spec), decl_spec_((declaration_specifier*)decl_spec) {}
|
||||||
|
ir::type* type(ir::module *mod) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const STORAGE_SPEC_T storage_spec_;
|
||||||
|
const declaration_specifier* decl_spec_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class declarator;
|
class declarator;
|
||||||
@@ -495,7 +515,7 @@ public:
|
|||||||
: declarator((node*)((declarator*)decl)->id()),
|
: declarator((node*)((declarator*)decl)->id()),
|
||||||
decl_((declarator*)decl), expr_((expression*)init){ }
|
decl_((declarator*)decl), expr_((expression*)init){ }
|
||||||
|
|
||||||
void specifier(const declaration_specifier *spec);
|
void set_specifier(const declaration_specifier *spec);
|
||||||
ir::value* codegen(ir::module *) const;
|
ir::value* codegen(ir::module *) const;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
@@ -535,17 +555,17 @@ public:
|
|||||||
class translation_unit: public node{
|
class translation_unit: public node{
|
||||||
public:
|
public:
|
||||||
translation_unit(node *item)
|
translation_unit(node *item)
|
||||||
: decls_((list<node*>*)item) { }
|
: decls_(item) { }
|
||||||
|
|
||||||
translation_unit *add(node *item) {
|
translation_unit *add(node *item) {
|
||||||
decls_->append(item);
|
decls_.append(item);
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
ir::value* codegen(ir::module * mod) const;
|
ir::value* codegen(ir::module * mod) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
list<node*>* decls_;
|
list<node*> decls_;
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@@ -20,12 +20,14 @@ struct token: public node{
|
|||||||
token(BIN_OP_T value): bin_op(value){ }
|
token(BIN_OP_T value): bin_op(value){ }
|
||||||
token(UNARY_OP_T value): unary_op(value){ }
|
token(UNARY_OP_T value): unary_op(value){ }
|
||||||
token(TYPE_T value): type(value){ }
|
token(TYPE_T value): type(value){ }
|
||||||
|
token(STORAGE_SPEC_T value): storage_spec(value){ }
|
||||||
|
|
||||||
union {
|
union {
|
||||||
ASSIGN_OP_T assign_op;
|
ASSIGN_OP_T assign_op;
|
||||||
BIN_OP_T bin_op;
|
BIN_OP_T bin_op;
|
||||||
UNARY_OP_T unary_op;
|
UNARY_OP_T unary_op;
|
||||||
TYPE_T type;
|
TYPE_T type;
|
||||||
|
STORAGE_SPEC_T storage_spec;
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -39,10 +41,12 @@ node* append_ptr_list(node *result, node *in){
|
|||||||
ASSIGN_OP_T get_assign_op(node *op) { return ((token*)op)->assign_op; }
|
ASSIGN_OP_T get_assign_op(node *op) { return ((token*)op)->assign_op; }
|
||||||
UNARY_OP_T get_unary_op(node *op) { return ((token*)op)->unary_op; }
|
UNARY_OP_T get_unary_op(node *op) { return ((token*)op)->unary_op; }
|
||||||
TYPE_T get_type_spec(node *op) { return ((token*)op)->type; }
|
TYPE_T get_type_spec(node *op) { return ((token*)op)->type; }
|
||||||
|
STORAGE_SPEC_T get_storage_spec(node *op) { return ((token*)op)->storage_spec;}
|
||||||
|
|
||||||
%}
|
%}
|
||||||
|
|
||||||
%token IDENTIFIER CONSTANT STRING_LITERAL
|
%token IDENTIFIER CONSTANT STRING_LITERAL
|
||||||
|
%token TUNABLE KERNEL READONLY WRITEONLY
|
||||||
%token PTR_OP INC_OP DEC_OP LEFT_OP RIGHT_OP LE_OP GE_OP EQ_OP NE_OP
|
%token PTR_OP INC_OP DEC_OP LEFT_OP RIGHT_OP LE_OP GE_OP EQ_OP NE_OP
|
||||||
%token AND_OP OR_OP MUL_ASSIGN DIV_ASSIGN MOD_ASSIGN ADD_ASSIGN
|
%token AND_OP OR_OP MUL_ASSIGN DIV_ASSIGN MOD_ASSIGN ADD_ASSIGN
|
||||||
%token SUB_ASSIGN LEFT_ASSIGN RIGHT_ASSIGN AND_ASSIGN
|
%token SUB_ASSIGN LEFT_ASSIGN RIGHT_ASSIGN AND_ASSIGN
|
||||||
@@ -87,17 +91,12 @@ abstract_declarator
|
|||||||
;
|
;
|
||||||
|
|
||||||
direct_abstract_declarator
|
direct_abstract_declarator
|
||||||
: '[' constant_list ']' { $$ = new tile(nullptr, $1); }
|
: '[' primary_expression_list ']' { $$ = new tile(nullptr, $1); }
|
||||||
|
|
||||||
constant :
|
constant :
|
||||||
CONSTANT { $$ = new constant(atoi(yytext)); }
|
CONSTANT { $$ = new constant(atoi(yytext)); }
|
||||||
;
|
;
|
||||||
|
|
||||||
constant_list
|
|
||||||
: constant { $$ = new list<constant*>((constant*)$1); }
|
|
||||||
| constant_list ',' constant { $$ = append_ptr_list<constant>($1, $3); }
|
|
||||||
;
|
|
||||||
|
|
||||||
type_name
|
type_name
|
||||||
: declaration_specifiers { $$ = new type_name($1, nullptr); }
|
: declaration_specifiers { $$ = new type_name($1, nullptr); }
|
||||||
| declaration_specifiers abstract_declarator { $$ = new type_name($1, $2); }
|
| declaration_specifiers abstract_declarator { $$ = new type_name($1, $2); }
|
||||||
@@ -112,7 +111,7 @@ identifier
|
|||||||
;
|
;
|
||||||
|
|
||||||
builtin
|
builtin
|
||||||
: GET_GLOBAL_RANGE '[' constant ']' '(' constant ')' { $$ = new get_global_range($3, $6); }
|
: GET_GLOBAL_RANGE '[' primary_expression ']' '(' constant ')' { $$ = new get_global_range($3, $6); }
|
||||||
| DOT '(' expression ',' expression ',' expression ')' { $$ = new matmul_expression($3, $5, $7); }
|
| DOT '(' expression ',' expression ',' expression ')' { $$ = new matmul_expression($3, $5, $7); }
|
||||||
|
|
||||||
primary_expression
|
primary_expression
|
||||||
@@ -124,6 +123,11 @@ primary_expression
|
|||||||
| '(' expression ')' { $$ = $2; }
|
| '(' expression ')' { $$ = $2; }
|
||||||
;
|
;
|
||||||
|
|
||||||
|
primary_expression_list
|
||||||
|
: primary_expression { $$ = new list<expression*>((expression*)$1); }
|
||||||
|
| primary_expression_list ',' primary_expression { $$ = append_ptr_list<expression>($1, $3); }
|
||||||
|
;
|
||||||
|
|
||||||
slice
|
slice
|
||||||
: ':' { $$ = new slice(tdl::ast::ALL); }
|
: ':' { $$ = new slice(tdl::ast::ALL); }
|
||||||
| NEWAXIS { $$ = new slice(tdl::ast::NEWAXIS); }
|
| NEWAXIS { $$ = new slice(tdl::ast::NEWAXIS); }
|
||||||
@@ -312,7 +316,7 @@ jump_statement
|
|||||||
|
|
||||||
direct_declarator
|
direct_declarator
|
||||||
: identifier { $$ = $1; }
|
: identifier { $$ = $1; }
|
||||||
| identifier '[' constant_list ']' { $$ = new tile($1, $3); }
|
| identifier '[' primary_expression_list ']' { $$ = new tile($1, $3); }
|
||||||
| identifier '(' parameter_list ')' { $$ = new function($1, $3); }
|
| identifier '(' parameter_list ')' { $$ = new function($1, $3); }
|
||||||
| identifier '(' ')' { $$ = new function($1, nullptr); }
|
| identifier '(' ')' { $$ = new function($1, nullptr); }
|
||||||
;
|
;
|
||||||
@@ -330,7 +334,8 @@ parameter_declaration
|
|||||||
|
|
||||||
|
|
||||||
declaration_specifiers
|
declaration_specifiers
|
||||||
: type_specifier { $$ = new declaration_specifier(get_type_spec($1)); }
|
: type_specifier { $$ = new typed_declaration_specifier(get_type_spec($1)); }
|
||||||
|
| storage_class_specifier declaration_specifiers { $$ = new storage_declaration_specifier(get_storage_spec($1), $2); }
|
||||||
;
|
;
|
||||||
|
|
||||||
init_declarator_list
|
init_declarator_list
|
||||||
@@ -354,6 +359,13 @@ init_declarator
|
|||||||
| declarator '=' initialization_expression { $$ = new initializer($1, $3); }
|
| declarator '=' initialization_expression { $$ = new initializer($1, $3); }
|
||||||
;
|
;
|
||||||
|
|
||||||
|
storage_class_specifier
|
||||||
|
: TUNABLE { $$ = new token(TUNABLE_T); }
|
||||||
|
| KERNEL { $$ = new token(KERNEL_T); }
|
||||||
|
| READONLY { $$ = new token(READONLY_T); }
|
||||||
|
| WRITEONLY { $$ = new token(WRITEONLY_T); }
|
||||||
|
;
|
||||||
|
|
||||||
/* -------------------------- */
|
/* -------------------------- */
|
||||||
/* Translation Unit */
|
/* Translation Unit */
|
||||||
/* -------------------------- */
|
/* -------------------------- */
|
||||||
|
@@ -16,6 +16,10 @@ int comment();
|
|||||||
%}
|
%}
|
||||||
|
|
||||||
%%
|
%%
|
||||||
|
"tunable" { count(); return(TUNABLE); }
|
||||||
|
"kernel" { count(); return(KERNEL); }
|
||||||
|
"readonly" { count(); return(READONLY); }
|
||||||
|
"writeonly" { count(); return(WRITEONLY); }
|
||||||
"@" { count(); return(AT); }
|
"@" { count(); return(AT); }
|
||||||
"newaxis" { count(); return(NEWAXIS); }
|
"newaxis" { count(); return(NEWAXIS); }
|
||||||
"if" { count(); return(IF); }
|
"if" { count(); return(IF); }
|
||||||
|
@@ -32,7 +32,7 @@ private:
|
|||||||
void add_reference(ir::value *v, interval_vec_t &res);
|
void add_reference(ir::value *v, interval_vec_t &res);
|
||||||
void get_read_intervals(ir::instruction *i, interval_vec_t &res);
|
void get_read_intervals(ir::instruction *i, interval_vec_t &res);
|
||||||
void get_written_intervals(ir::instruction *i, interval_vec_t &res);
|
void get_written_intervals(ir::instruction *i, interval_vec_t &res);
|
||||||
void add(ir::basic_block *block, interval_vec_t ¬_synced, std::set<ir::instruction *> &insert_pts);
|
void add(ir::basic_block *block, interval_vec_t ¬_synced, ir::builder &builder);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
barriers(allocation *alloc, buffer_info_pass *buffer_info): alloc_(alloc), buffer_info_(buffer_info) {}
|
barriers(allocation *alloc, buffer_info_pass *buffer_info): alloc_(alloc), buffer_info_(buffer_info) {}
|
||||||
|
@@ -100,6 +100,7 @@ class selection{
|
|||||||
private:
|
private:
|
||||||
// utils
|
// utils
|
||||||
llvm::Type *make_vector_ty(llvm::Type *ty, size_t vector_size);
|
llvm::Type *make_vector_ty(llvm::Type *ty, size_t vector_size);
|
||||||
|
std::vector<unsigned> extract_shapes(ir::value *v);
|
||||||
|
|
||||||
// LLVM conversions
|
// LLVM conversions
|
||||||
llvm::Type* llvm_type(ir::type *ty, llvm::LLVMContext &ctx);
|
llvm::Type* llvm_type(ir::type *ty, llvm::LLVMContext &ctx);
|
||||||
|
@@ -6,6 +6,7 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include "instructions.h"
|
#include "instructions.h"
|
||||||
#include "basic_block.h"
|
#include "basic_block.h"
|
||||||
|
#include "type.h"
|
||||||
|
|
||||||
namespace tdl{
|
namespace tdl{
|
||||||
namespace ir{
|
namespace ir{
|
||||||
@@ -110,11 +111,11 @@ public:
|
|||||||
value *create_load(value *arg, const std::string &name = "");
|
value *create_load(value *arg, const std::string &name = "");
|
||||||
value *create_store(value *ptr, value *val, const std::string &name = "");
|
value *create_store(value *ptr, value *val, const std::string &name = "");
|
||||||
// Tile instruction
|
// Tile instruction
|
||||||
value *create_splat(value *arg, const std::vector<unsigned> &shapes, const std::string &name = "");
|
value *create_splat(value *arg, const type::tile_shapes_t &shapes, const std::string &name = "");
|
||||||
value *create_reshape(value *arg, const std::vector<unsigned> &shapes, const std::string &name = "");
|
value *create_reshape(value *arg, const type::tile_shapes_t &shapes, const std::string &name = "");
|
||||||
value *create_broadcast(value *arg, const std::vector<unsigned> &shapes, const std::string &name = "");
|
value *create_broadcast(value *arg, const type::tile_shapes_t &shapes, const std::string &name = "");
|
||||||
// Built-in instruction
|
// Built-in instruction
|
||||||
value *create_get_global_range(unsigned axis, unsigned size, const std::string &name = "");
|
value *create_get_global_range(unsigned axis, type::tile_shapes_t::value_type size, const std::string &name = "");
|
||||||
value *create_matmul(value *A, value *B, value *C, const std::string &name = "");
|
value *create_matmul(value *A, value *B, value *C, const std::string &name = "");
|
||||||
// Intrinsics
|
// Intrinsics
|
||||||
value *create_copy_to_shared(value *arg, const std::string &name = "");
|
value *create_copy_to_shared(value *arg, const std::string &name = "");
|
||||||
|
@@ -2,6 +2,7 @@
|
|||||||
#define TDL_INCLUDE_IR_CONSTANT_H
|
#define TDL_INCLUDE_IR_CONSTANT_H
|
||||||
|
|
||||||
#include "value.h"
|
#include "value.h"
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
namespace tdl{
|
namespace tdl{
|
||||||
namespace ir{
|
namespace ir{
|
||||||
@@ -28,28 +29,43 @@ public:
|
|||||||
static undef_value* get(type* ty);
|
static undef_value* get(type* ty);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
/* Constant int */
|
/* Constant int */
|
||||||
class constant_int: public constant{
|
class constant_int: public constant{
|
||||||
|
protected:
|
||||||
constant_int(type *ty, uint64_t value);
|
constant_int(type *ty, uint64_t value);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
uint64_t get_value() const { return value_; }
|
uint64_t get_value() const { return value_; }
|
||||||
static constant *get(type *ty, uint64_t value);
|
static constant_int *get(type *ty, uint64_t value);
|
||||||
|
|
||||||
|
protected:
|
||||||
|
uint64_t value_;
|
||||||
|
};
|
||||||
|
|
||||||
|
/* Metaparameter int */
|
||||||
|
class metaparameter: public constant_int{
|
||||||
|
metaparameter(type *ty, unsigned lo, unsigned hi);
|
||||||
|
|
||||||
|
public:
|
||||||
|
static metaparameter *create(context &ctx, type *ty, unsigned lo, unsigned hi);
|
||||||
|
void set_value(uint64_t value) { value_ = value; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
uint64_t value_;
|
unsigned lo_;
|
||||||
|
unsigned hi_;
|
||||||
};
|
};
|
||||||
|
|
||||||
/* constant range */
|
/* constant range */
|
||||||
class constant_range: public constant{
|
class constant_range: public constant{
|
||||||
constant_range(type *ty, uint64_t first, uint64_t last);
|
constant_range(type *ty, constant_int* first, constant_int* last);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
static constant *get(constant *first, constant *last);
|
static constant *get(constant_int *first, constant_int *last);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
uint64_t first_;
|
constant_int* first_;
|
||||||
uint64_t last_;
|
constant_int* last_;
|
||||||
};
|
};
|
||||||
|
|
||||||
/* constant fp */
|
/* constant fp */
|
||||||
|
@@ -12,6 +12,7 @@ class context;
|
|||||||
class constant_int;
|
class constant_int;
|
||||||
class constant_fp;
|
class constant_fp;
|
||||||
class undef_value;
|
class undef_value;
|
||||||
|
class metaparameter;
|
||||||
|
|
||||||
/* Context impl */
|
/* Context impl */
|
||||||
class context_impl {
|
class context_impl {
|
||||||
@@ -26,13 +27,15 @@ public:
|
|||||||
integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty;
|
integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty;
|
||||||
// Pointer types
|
// Pointer types
|
||||||
std::map<std::pair<type*, unsigned>, pointer_type*> ptr_tys;
|
std::map<std::pair<type*, unsigned>, pointer_type*> ptr_tys;
|
||||||
std::map<std::pair<type*,std::vector<unsigned>>, tile_type*> tile_tys;
|
std::map<std::pair<type*, type::tile_shapes_t>, tile_type*> tile_tys;
|
||||||
// Int constants
|
// Int constants
|
||||||
std::map<uint64_t, constant_int*> int_constants_;
|
std::map<std::pair<type*, uint64_t>, constant_int*> int_constants_;
|
||||||
// Float constants
|
// Float constants
|
||||||
std::map<double, constant_fp*> fp_constants_;
|
std::map<double, constant_fp*> fp_constants_;
|
||||||
// undef values
|
// undef values
|
||||||
std::map<type*, undef_value*> uv_constants_;
|
std::map<type*, undef_value*> uv_constants_;
|
||||||
|
// Metaparameters
|
||||||
|
std::vector<metaparameter*> mp_constants_;
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@@ -3,6 +3,7 @@
|
|||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "value.h"
|
#include "value.h"
|
||||||
|
#include "ir/type.h"
|
||||||
#include "llvm/IR/Instructions.h"
|
#include "llvm/IR/Instructions.h"
|
||||||
|
|
||||||
namespace tdl{
|
namespace tdl{
|
||||||
@@ -358,7 +359,7 @@ public:
|
|||||||
|
|
||||||
class retile_inst: public unary_inst {
|
class retile_inst: public unary_inst {
|
||||||
protected:
|
protected:
|
||||||
retile_inst(value *arg, const std::vector<unsigned> &shape_suffix, const std::string &name, instruction *next);
|
retile_inst(value *arg, const type::tile_shapes_t &shapes, const std::string &name, instruction *next);
|
||||||
static std::string shape_suffix(ir::type* ty);
|
static std::string shape_suffix(ir::type* ty);
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -370,7 +371,7 @@ private:
|
|||||||
std::string repr_impl() const { return "reshape" + shape_suffix(get_type()); }
|
std::string repr_impl() const { return "reshape" + shape_suffix(get_type()); }
|
||||||
|
|
||||||
public:
|
public:
|
||||||
static instruction* create(value *arg, const std::vector<unsigned> &shape_suffix,
|
static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix,
|
||||||
const std::string &name = "", instruction *next = nullptr);
|
const std::string &name = "", instruction *next = nullptr);
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -382,7 +383,7 @@ private:
|
|||||||
std::string repr_impl() const { return "splat" + shape_suffix(get_type()); }
|
std::string repr_impl() const { return "splat" + shape_suffix(get_type()); }
|
||||||
|
|
||||||
public:
|
public:
|
||||||
static instruction* create(value *arg, const std::vector<unsigned> &shape_suffix,
|
static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix,
|
||||||
const std::string &name = "", instruction *next = nullptr);
|
const std::string &name = "", instruction *next = nullptr);
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -394,7 +395,7 @@ private:
|
|||||||
std::string repr_impl() const { return "broadcast" + shape_suffix(get_type()); }
|
std::string repr_impl() const { return "broadcast" + shape_suffix(get_type()); }
|
||||||
|
|
||||||
public:
|
public:
|
||||||
static instruction* create(value *arg, const std::vector<unsigned> &shape_suffix,
|
static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix,
|
||||||
const std::string &name = "", instruction *next = nullptr);
|
const std::string &name = "", instruction *next = nullptr);
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -414,7 +415,7 @@ private:
|
|||||||
std::string repr_impl() const { return "get_global_range(" + std::to_string(axis_) + ")"; }
|
std::string repr_impl() const { return "get_global_range(" + std::to_string(axis_) + ")"; }
|
||||||
|
|
||||||
public:
|
public:
|
||||||
static instruction* create(context &ctx, unsigned axis, unsigned size,
|
static instruction* create(context &ctx, unsigned axis, type::tile_shapes_t::value_type size,
|
||||||
const std::string &name = "",
|
const std::string &name = "",
|
||||||
instruction *next = nullptr);
|
instruction *next = nullptr);
|
||||||
unsigned get_axis() const { return axis_; }
|
unsigned get_axis() const { return axis_; }
|
||||||
|
@@ -3,6 +3,7 @@
|
|||||||
|
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <set>
|
#include <set>
|
||||||
|
#include <stack>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include "builder.h"
|
#include "builder.h"
|
||||||
@@ -12,6 +13,7 @@ namespace tdl{
|
|||||||
namespace ast{
|
namespace ast{
|
||||||
|
|
||||||
class iteration_statement;
|
class iteration_statement;
|
||||||
|
class compound_statement;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -69,6 +71,10 @@ public:
|
|||||||
const functions_list_t &get_function_list() const { return functions_; }
|
const functions_list_t &get_function_list() const { return functions_; }
|
||||||
functions_list_t &get_function_list() { return functions_; }
|
functions_list_t &get_function_list() { return functions_; }
|
||||||
function *get_or_insert_function(const std::string &name, function_type *ty);
|
function *get_or_insert_function(const std::string &name, function_type *ty);
|
||||||
|
// Scope
|
||||||
|
void push_scope(const ast::compound_statement* scope) { scopes_.push(scope); }
|
||||||
|
void pop_scope() { scopes_.pop(); }
|
||||||
|
const ast::compound_statement* get_scope() { return scopes_.top(); }
|
||||||
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@@ -83,6 +89,7 @@ private:
|
|||||||
symbols_map_t symbols_;
|
symbols_map_t symbols_;
|
||||||
std::function<ir::value*()> continue_fn_;
|
std::function<ir::value*()> continue_fn_;
|
||||||
std::map<value*, value**> current_phi_;
|
std::map<value*, value**> current_phi_;
|
||||||
|
std::stack<const ast::compound_statement*> scopes_;
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@@ -3,6 +3,7 @@
|
|||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
#include <set>
|
||||||
|
|
||||||
namespace tdl{
|
namespace tdl{
|
||||||
namespace ir{
|
namespace ir{
|
||||||
@@ -10,9 +11,13 @@ namespace ir{
|
|||||||
class context;
|
class context;
|
||||||
class value;
|
class value;
|
||||||
class integer_type;
|
class integer_type;
|
||||||
|
class constant_int;
|
||||||
|
|
||||||
/* Type */
|
/* Type */
|
||||||
class type {
|
class type {
|
||||||
|
public:
|
||||||
|
typedef std::vector<constant_int*> tile_shapes_t;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
typedef std::vector<type*> contained_tys_vec_t;
|
typedef std::vector<type*> contained_tys_vec_t;
|
||||||
typedef contained_tys_vec_t::iterator ty_iterator;
|
typedef contained_tys_vec_t::iterator ty_iterator;
|
||||||
@@ -54,7 +59,7 @@ public:
|
|||||||
unsigned get_tile_bitwidth() const;
|
unsigned get_tile_bitwidth() const;
|
||||||
unsigned get_primitive_size_in_bits() const;
|
unsigned get_primitive_size_in_bits() const;
|
||||||
type *get_scalar_ty() const;
|
type *get_scalar_ty() const;
|
||||||
const std::vector<unsigned> &get_tile_shapes() const;
|
const tile_shapes_t& get_tile_shapes() const;
|
||||||
unsigned get_tile_num_elements() const;
|
unsigned get_tile_num_elements() const;
|
||||||
type *get_tile_element_ty() const;
|
type *get_tile_element_ty() const;
|
||||||
unsigned get_pointer_address_space() const;
|
unsigned get_pointer_address_space() const;
|
||||||
@@ -94,9 +99,25 @@ public:
|
|||||||
static integer_type *get_int64_ty(context &ctx);
|
static integer_type *get_int64_ty(context &ctx);
|
||||||
static integer_type *get_int128_ty(context &ctx);
|
static integer_type *get_int128_ty(context &ctx);
|
||||||
|
|
||||||
|
// Attributes
|
||||||
|
type* set_tunable() { is_tunable_ = true; return this; }
|
||||||
|
type* set_readonly() { is_readonly_ = true; return this; }
|
||||||
|
type* set_writeonly() { is_writeonly_ = true; return this; }
|
||||||
|
type* set_kernel() { is_kernel_ = true; return this; }
|
||||||
|
|
||||||
|
bool get_tunable() { return is_tunable_; }
|
||||||
|
bool get_readonly() { return is_readonly_; }
|
||||||
|
bool get_writeonly() { return is_writeonly_; }
|
||||||
|
bool get_kernel() { return is_kernel_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
context &ctx_;
|
context &ctx_;
|
||||||
id_t id_;
|
id_t id_;
|
||||||
|
// attributes
|
||||||
|
bool is_tunable_;
|
||||||
|
bool is_readonly_;
|
||||||
|
bool is_writeonly_;
|
||||||
|
bool is_kernel_;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
contained_tys_vec_t contained_tys_;
|
contained_tys_vec_t contained_tys_;
|
||||||
@@ -132,21 +153,24 @@ public:
|
|||||||
|
|
||||||
class tile_type: public composite_type {
|
class tile_type: public composite_type {
|
||||||
private:
|
private:
|
||||||
tile_type(type *ty, const std::vector<unsigned> &shapes);
|
tile_type(type *ty, const tile_shapes_t &shapes);
|
||||||
static bool is_valid_elt_ty(type *ty);
|
static bool is_valid_elt_ty(type *ty);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
// accessors
|
// accessors
|
||||||
const std::vector<unsigned>& get_shapes() const { return shapes_; }
|
const tile_shapes_t& get_shapes() const { return shapes_; }
|
||||||
unsigned get_num_elements() const;
|
unsigned get_num_elements() const;
|
||||||
unsigned get_bitwidth() const;
|
unsigned get_bitwidth() const;
|
||||||
|
|
||||||
// factory methods
|
// factory methods
|
||||||
static tile_type* get(type *ty, const std::vector<unsigned> &shapes);
|
static tile_type* get(type *ty, const tile_shapes_t &shapes);
|
||||||
static tile_type* get_same_shapes(type *ty, type *ref);
|
static tile_type* get_same_shapes(type *ty, type *ref);
|
||||||
|
|
||||||
|
// shortcut to get a 1 element in the shape
|
||||||
|
static tile_shapes_t::value_type make_one(context &ctx);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<unsigned> shapes_;
|
tile_shapes_t shapes_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class pointer_type: public type {
|
class pointer_type: public type {
|
||||||
|
@@ -100,6 +100,7 @@ void node::implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs)
|
|||||||
ir::builder &builder = mod->get_builder();
|
ir::builder &builder = mod->get_builder();
|
||||||
ir::type *lhs_ty = lhs->get_type();
|
ir::type *lhs_ty = lhs->get_type();
|
||||||
ir::type *rhs_ty = rhs->get_type();
|
ir::type *rhs_ty = rhs->get_type();
|
||||||
|
ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(mod->get_context());
|
||||||
// Both are scalar
|
// Both are scalar
|
||||||
if(!lhs_ty->is_tile_ty() && !rhs_ty->is_tile_ty())
|
if(!lhs_ty->is_tile_ty() && !rhs_ty->is_tile_ty())
|
||||||
return;
|
return;
|
||||||
@@ -111,30 +112,30 @@ void node::implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs)
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
// Both are arrays
|
// Both are arrays
|
||||||
std::vector<unsigned> lhs_shapes = lhs->get_type()->get_tile_shapes();
|
auto lhs_shapes = lhs->get_type()->get_tile_shapes();
|
||||||
std::vector<unsigned> rhs_shapes = rhs->get_type()->get_tile_shapes();
|
auto rhs_shapes = rhs->get_type()->get_tile_shapes();
|
||||||
if(lhs_shapes == rhs_shapes)
|
if(lhs_shapes == rhs_shapes)
|
||||||
return;
|
return;
|
||||||
int lhs_dim = lhs_shapes.size();
|
int lhs_dim = lhs_shapes.size();
|
||||||
int rhs_dim = rhs_shapes.size();
|
int rhs_dim = rhs_shapes.size();
|
||||||
std::vector<unsigned> &shortest = (lhs_dim < rhs_dim)?lhs_shapes:rhs_shapes;
|
auto &shortest = (lhs_dim < rhs_dim)?lhs_shapes:rhs_shapes;
|
||||||
std::vector<unsigned> &longest = (lhs_dim < rhs_dim)?rhs_shapes:lhs_shapes;
|
auto &longest = (lhs_dim < rhs_dim)?rhs_shapes:lhs_shapes;
|
||||||
size_t ndim = longest.size();
|
size_t ndim = longest.size();
|
||||||
int off = longest.size() - shortest.size();
|
int off = longest.size() - shortest.size();
|
||||||
for(int i = longest.size() - 1; i>= 0; i--){
|
for(int i = longest.size() - 1; i>= 0; i--){
|
||||||
if(shortest[off + i] != longest[i] && shortest[off + i] != 1 && longest[i] != 1)
|
if(shortest[off + i] != longest[i] && shortest[off + i] != one && longest[i] != one)
|
||||||
throw std::runtime_error("cannot broadcast");
|
throw std::runtime_error("cannot broadcast");
|
||||||
}
|
}
|
||||||
// Pad
|
// Pad
|
||||||
for(size_t i = 0; i < off; i++)
|
for(size_t i = 0; i < off; i++)
|
||||||
shortest.insert(shortest.begin(), 1);
|
shortest.insert(shortest.begin(), one);
|
||||||
ir::value *&target = (lhs_dim < rhs_dim)?lhs:rhs;
|
ir::value *&target = (lhs_dim < rhs_dim)?lhs:rhs;
|
||||||
if(off > 0)
|
if(off > 0)
|
||||||
target = builder.create_reshape(target, shortest);
|
target = builder.create_reshape(target, shortest);
|
||||||
// Broadcast
|
// Broadcast
|
||||||
std::vector<unsigned> shapes(ndim);
|
ir::type::tile_shapes_t shapes(ndim);
|
||||||
for(size_t i = 0; i < ndim; i++)
|
for(size_t i = 0; i < ndim; i++)
|
||||||
shapes[i] = std::max(shortest[i], longest[i]);
|
shapes[i] = shortest[i]==one?longest[i]:shortest[i];
|
||||||
if(shapes != lhs_shapes)
|
if(shapes != lhs_shapes)
|
||||||
lhs = builder.create_broadcast(lhs, shapes);
|
lhs = builder.create_broadcast(lhs, shapes);
|
||||||
if(shapes != rhs_shapes)
|
if(shapes != rhs_shapes)
|
||||||
@@ -148,14 +149,15 @@ inline bool is_terminator(ir::value* x) {
|
|||||||
|
|
||||||
/* Translation unit */
|
/* Translation unit */
|
||||||
ir::value* translation_unit::codegen(ir::module *mod) const{
|
ir::value* translation_unit::codegen(ir::module *mod) const{
|
||||||
decls_->codegen(mod);
|
mod->push_scope(nullptr);
|
||||||
|
decls_.codegen(mod);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Declaration specifier */
|
/* Declaration specifier */
|
||||||
ir::type* declaration_specifier::type(ir::module *mod) const {
|
ir::type* typed_declaration_specifier::type(ir::module *mod) const {
|
||||||
ir::context &ctx = mod->get_context();
|
ir::context &ctx = mod->get_context();
|
||||||
switch (spec_) {
|
switch (ty_) {
|
||||||
case VOID_T: return ir::type::get_void_ty(ctx);
|
case VOID_T: return ir::type::get_void_ty(ctx);
|
||||||
case INT1_T: return ir::type::get_int1_ty(ctx);
|
case INT1_T: return ir::type::get_int1_ty(ctx);
|
||||||
case INT8_T: return ir::type::get_int8_ty(ctx);
|
case INT8_T: return ir::type::get_int8_ty(ctx);
|
||||||
@@ -164,7 +166,18 @@ ir::type* declaration_specifier::type(ir::module *mod) const {
|
|||||||
case INT64_T: return ir::type::get_int64_ty(ctx);
|
case INT64_T: return ir::type::get_int64_ty(ctx);
|
||||||
case FLOAT32_T: return ir::type::get_float_ty(ctx);
|
case FLOAT32_T: return ir::type::get_float_ty(ctx);
|
||||||
case FLOAT64_T: return ir::type::get_double_ty(ctx);
|
case FLOAT64_T: return ir::type::get_double_ty(ctx);
|
||||||
default: throw std::runtime_error("unreachable");
|
default: throw std::runtime_error("unreachable");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ir::type* storage_declaration_specifier::type(ir::module *mod) const {
|
||||||
|
ir::type* result = decl_spec_->type(mod);
|
||||||
|
switch(storage_spec_){
|
||||||
|
case TUNABLE_T: return result->set_tunable();
|
||||||
|
case KERNEL_T: return result->set_kernel();
|
||||||
|
case READONLY_T: return result->set_readonly();
|
||||||
|
case WRITEONLY_T: return result->set_writeonly();
|
||||||
|
default: throw std::runtime_error("unreachable");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -194,10 +207,10 @@ const std::string &identifier::name() const{
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Tile
|
// Tile
|
||||||
ir::type* tile::type_impl(ir::module*, ir::type *type) const{
|
ir::type* tile::type_impl(ir::module *mod, ir::type *type) const{
|
||||||
std::vector<unsigned> shapes;
|
ir::type::tile_shapes_t shapes;
|
||||||
for(constant *cst: shapes_->values())
|
for(constant *cst: shapes_->values())
|
||||||
shapes.push_back(cst->value());
|
shapes.push_back((ir::constant_int*)cst->codegen(mod));
|
||||||
return ir::tile_type::get(type, shapes);
|
return ir::tile_type::get(type, shapes);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -245,6 +258,7 @@ ir::value* function_definition::codegen(ir::module *mod) const{
|
|||||||
|
|
||||||
/* Statements */
|
/* Statements */
|
||||||
ir::value* compound_statement::codegen(ir::module* mod) const{
|
ir::value* compound_statement::codegen(ir::module* mod) const{
|
||||||
|
mod->push_scope(this);
|
||||||
if(decls_)
|
if(decls_)
|
||||||
decls_->codegen(mod);
|
decls_->codegen(mod);
|
||||||
if(statements_){
|
if(statements_){
|
||||||
@@ -254,6 +268,7 @@ ir::value* compound_statement::codegen(ir::module* mod) const{
|
|||||||
return current;
|
return current;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
mod->pop_scope();
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -337,7 +352,7 @@ ir::value* continue_statement::codegen(ir::module *mod) const{
|
|||||||
/* Declaration */
|
/* Declaration */
|
||||||
ir::value* declaration::codegen(ir::module* mod) const{
|
ir::value* declaration::codegen(ir::module* mod) const{
|
||||||
for(initializer *init: init_->values())
|
for(initializer *init: init_->values())
|
||||||
init->specifier(spec_);
|
init->set_specifier(spec_);
|
||||||
init_->codegen(mod);
|
init_->codegen(mod);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
@@ -347,7 +362,7 @@ ir::type* initializer::type_impl(ir::module *mod, ir::type *type) const{
|
|||||||
return decl_->type(mod, type);
|
return decl_->type(mod, type);
|
||||||
}
|
}
|
||||||
|
|
||||||
void initializer::specifier(const declaration_specifier *spec) {
|
void initializer::set_specifier(const declaration_specifier *spec) {
|
||||||
spec_ = spec;
|
spec_ = spec;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -355,6 +370,11 @@ ir::value* initializer::codegen(ir::module * mod) const{
|
|||||||
ir::type *ty = decl_->type(mod, spec_->type(mod));
|
ir::type *ty = decl_->type(mod, spec_->type(mod));
|
||||||
std::string name = decl_->id()->name();
|
std::string name = decl_->id()->name();
|
||||||
ir::value *value = ir::undef_value::get(ty);
|
ir::value *value = ir::undef_value::get(ty);
|
||||||
|
if(ty->get_tunable()){
|
||||||
|
assert(expr_ == nullptr);
|
||||||
|
//TODO
|
||||||
|
value = ir::metaparameter::create(mod->get_context(), ty, 4, 8);
|
||||||
|
}
|
||||||
if(expr_){
|
if(expr_){
|
||||||
value = expr_->codegen(mod);
|
value = expr_->codegen(mod);
|
||||||
value = explicit_cast(mod->get_builder(), value, ty);
|
value = explicit_cast(mod->get_builder(), value, ty);
|
||||||
@@ -464,7 +484,7 @@ ir::value* binary_operator::codegen(ir::module *mod) const{
|
|||||||
// get_global_range
|
// get_global_range
|
||||||
ir::value* get_global_range::codegen(ir::module *mod) const {
|
ir::value* get_global_range::codegen(ir::module *mod) const {
|
||||||
ir::builder &builder = mod->get_builder();
|
ir::builder &builder = mod->get_builder();
|
||||||
return builder.create_get_global_range(axis_->value(), size_->value());
|
return builder.create_get_global_range(axis_->value(), (ir::constant_int*)size_->codegen(mod));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -487,11 +507,13 @@ ir::value* matmul_expression::codegen(ir::module *mod) const {
|
|||||||
ir::value* indexing_expression::codegen(ir::module *mod) const{
|
ir::value* indexing_expression::codegen(ir::module *mod) const{
|
||||||
ir::value *in = mod->get_value(id_->name());
|
ir::value *in = mod->get_value(id_->name());
|
||||||
const std::vector<slice*> &slices = slices_->values();
|
const std::vector<slice*> &slices = slices_->values();
|
||||||
std::vector<unsigned> in_shapes = in->get_type()->get_tile_shapes();
|
auto in_shapes = in->get_type()->get_tile_shapes();
|
||||||
std::vector<unsigned> out_shapes(slices.size());
|
ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(mod->get_context());
|
||||||
|
ir::type::tile_shapes_t out_shapes(slices.size());
|
||||||
|
// create shapes
|
||||||
size_t current = 0;
|
size_t current = 0;
|
||||||
for(size_t i = 0; i < out_shapes.size(); i++)
|
for(size_t i = 0; i < out_shapes.size(); i++)
|
||||||
out_shapes[i] = (slices[i]->type()==NEWAXIS)?1:in_shapes[current++];
|
out_shapes[i] = (slices[i]->type()==NEWAXIS)?one:in_shapes[current++];
|
||||||
return mod->get_builder().create_reshape(in, out_shapes);
|
return mod->get_builder().create_reshape(in, out_shapes);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -586,8 +608,8 @@ int constant::value() const{
|
|||||||
|
|
||||||
/* Constant range */
|
/* Constant range */
|
||||||
ir::value* constant_range::codegen(ir::module *mod) const{
|
ir::value* constant_range::codegen(ir::module *mod) const{
|
||||||
return ir::constant_range::get((ir::constant*)first_->codegen(mod),
|
return ir::constant_range::get((ir::constant_int*)first_->codegen(mod),
|
||||||
(ir::constant*)last_->codegen(mod));
|
(ir::constant_int*)last_->codegen(mod));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Named */
|
/* Named */
|
||||||
|
@@ -45,10 +45,15 @@ void barriers::get_written_intervals(ir::instruction *i, interval_vec_t &res){
|
|||||||
|
|
||||||
void barriers::insert_barrier(ir::instruction *instr, ir::builder &builder) {
|
void barriers::insert_barrier(ir::instruction *instr, ir::builder &builder) {
|
||||||
if(auto *phi = dynamic_cast<ir::phi_node*>(instr)) {
|
if(auto *phi = dynamic_cast<ir::phi_node*>(instr)) {
|
||||||
|
std::set<ir::value*> incoming;
|
||||||
for(unsigned n = 0; n < phi->get_num_incoming(); n++){
|
for(unsigned n = 0; n < phi->get_num_incoming(); n++){
|
||||||
ir::basic_block *block = phi->get_incoming_block(n);
|
ir::instruction *inc_val = dynamic_cast<ir::instruction*>(phi->get_incoming_value(n));
|
||||||
builder.set_insert_point(block->get_inst_list().back());
|
assert(inc_val);
|
||||||
builder.create_barrier();
|
if(incoming.insert(inc_val).second){
|
||||||
|
ir::basic_block *block = inc_val->get_parent();
|
||||||
|
builder.set_insert_point(block->get_inst_list().back());
|
||||||
|
builder.create_barrier();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
@@ -57,15 +62,15 @@ void barriers::insert_barrier(ir::instruction *instr, ir::builder &builder) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void barriers::add(ir::basic_block *block, interval_vec_t ¬_synced, std::set<ir::instruction*> &insert_pts) {
|
void barriers::add(ir::basic_block *block, interval_vec_t ¬_synced, ir::builder &builder) {
|
||||||
for(ir::instruction *i: block->get_inst_list()){
|
ir::basic_block::inst_list_t instructions = block->get_inst_list();
|
||||||
|
for(ir::instruction *i: instructions){
|
||||||
interval_vec_t read, written;
|
interval_vec_t read, written;
|
||||||
get_read_intervals(i, read);
|
get_read_intervals(i, read);
|
||||||
get_written_intervals(i, written);
|
get_written_intervals(i, written);
|
||||||
if(intersect(not_synced, read)
|
if(intersect(not_synced, read)) {
|
||||||
|| intersect(not_synced, written)) {
|
|
||||||
not_synced.clear();
|
not_synced.clear();
|
||||||
insert_pts.insert(i);
|
insert_barrier(i, builder);
|
||||||
}
|
}
|
||||||
std::copy(written.begin(), written.end(), std::back_inserter(not_synced));
|
std::copy(written.begin(), written.end(), std::back_inserter(not_synced));
|
||||||
}
|
}
|
||||||
@@ -76,12 +81,8 @@ void barriers::run(ir::module &mod) {
|
|||||||
for(ir::function *fn: mod.get_function_list()){
|
for(ir::function *fn: mod.get_function_list()){
|
||||||
// find barrier location
|
// find barrier location
|
||||||
interval_vec_t not_synced;
|
interval_vec_t not_synced;
|
||||||
std::set<ir::instruction*> insert_pts;
|
|
||||||
for(ir::basic_block *block: fn->blocks())
|
for(ir::basic_block *block: fn->blocks())
|
||||||
add(block, not_synced, insert_pts);
|
add(block, not_synced, builder);
|
||||||
// insert barrier
|
|
||||||
for(ir::instruction *i: insert_pts)
|
|
||||||
insert_barrier(i, builder);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -44,6 +44,7 @@ llvm::Type *distributed_tile::make_vector_ty(llvm::Type *ty, size_t vector_size)
|
|||||||
return VectorType::get(ty, vector_size);
|
return VectorType::get(ty, vector_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
distributed_tile::distributed_tile(Type *ty, const shapes_t &shapes, const axes_t &axes, llvm::IRBuilder<> &builder, bool vectorize)
|
distributed_tile::distributed_tile(Type *ty, const shapes_t &shapes, const axes_t &axes, llvm::IRBuilder<> &builder, bool vectorize)
|
||||||
: tile(make_vector_ty(ty, vectorize?axes[0].contiguous:1), shapes), axes_(axes), builder_(builder) {
|
: tile(make_vector_ty(ty, vectorize?axes[0].contiguous:1), shapes), axes_(axes), builder_(builder) {
|
||||||
vector_size_ = vectorize?ty_->getVectorNumElements():1;
|
vector_size_ = vectorize?ty_->getVectorNumElements():1;
|
||||||
@@ -149,6 +150,16 @@ Value* shared_tile::get_value(indices_t idx) {
|
|||||||
return builder_.CreateLoad(ptr);
|
return builder_.CreateLoad(ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* Utils */
|
||||||
|
std::vector<unsigned> selection::extract_shapes(ir::value *v) {
|
||||||
|
const auto& shapes = v->get_type()->get_tile_shapes();
|
||||||
|
std::vector<unsigned> result(shapes.size());
|
||||||
|
for(ir::constant_int* cst: shapes)
|
||||||
|
result.push_back(cst->get_value());
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/* convert ir::type to Type */
|
/* convert ir::type to Type */
|
||||||
Type *selection::llvm_type(ir::type *ty, LLVMContext &ctx) {
|
Type *selection::llvm_type(ir::type *ty, LLVMContext &ctx) {
|
||||||
// function
|
// function
|
||||||
@@ -299,11 +310,12 @@ std::vector<Value*> delinearize(Value *trailing, std::vector<unsigned> &shapes,
|
|||||||
}
|
}
|
||||||
|
|
||||||
void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
|
void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
|
||||||
const auto& shapes = v->get_type()->get_tile_shapes();
|
const auto& shapes = extract_shapes(v);
|
||||||
size_t dim = shapes.size();
|
size_t dim = shapes.size();
|
||||||
std::vector<unsigned> contiguous(dim);
|
std::vector<unsigned> contiguous(dim);
|
||||||
std::vector<unsigned> warp_size(dim);
|
std::vector<unsigned> warp_size(dim);
|
||||||
std::vector<unsigned> n_warps(dim);
|
std::vector<unsigned> n_warps(dim);
|
||||||
|
std::cout << v->get_name() << " " << typeid(*v).name() << std::endl;
|
||||||
for(unsigned i = 0; i < shapes.size(); i++){
|
for(unsigned i = 0; i < shapes.size(); i++){
|
||||||
std::string str_i = std::to_string(i);
|
std::string str_i = std::to_string(i);
|
||||||
contiguous[i] = *params_->get_param(v, "p0.d" + str_i);
|
contiguous[i] = *params_->get_param(v, "p0.d" + str_i);
|
||||||
@@ -336,7 +348,7 @@ void selection::create_grids(std::vector<ir::value*> &grids,
|
|||||||
// get number of dimensions greater than 1
|
// get number of dimensions greater than 1
|
||||||
auto get_tile_gt1_dim = [&](ir::value *v){
|
auto get_tile_gt1_dim = [&](ir::value *v){
|
||||||
unsigned result = 0;
|
unsigned result = 0;
|
||||||
for(unsigned shape: v->get_type()->get_tile_shapes()) {
|
for(unsigned shape: extract_shapes(v)) {
|
||||||
result += (shape > 1)?shape:0;
|
result += (shape > 1)?shape:0;
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
@@ -353,7 +365,7 @@ void selection::create_grids(std::vector<ir::value*> &grids,
|
|||||||
for(ir::value *op: user->ops())
|
for(ir::value *op: user->ops())
|
||||||
bind_references(op);
|
bind_references(op);
|
||||||
// bind
|
// bind
|
||||||
const auto& shapes = v->get_type()->get_tile_shapes();
|
const auto& shapes = extract_shapes(v);
|
||||||
if(dynamic_cast<ir::copy_to_shared_inst*>(v) || buffer_info_->is_double(v))
|
if(dynamic_cast<ir::copy_to_shared_inst*>(v) || buffer_info_->is_double(v))
|
||||||
return;
|
return;
|
||||||
for(size_t d = 0; d < shapes.size(); d++){
|
for(size_t d = 0; d < shapes.size(); d++){
|
||||||
@@ -385,7 +397,7 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
|
|||||||
for(ir::value *op: user->ops())
|
for(ir::value *op: user->ops())
|
||||||
create_tile(op, builder, references, seen, sh_mem_ptr);
|
create_tile(op, builder, references, seen, sh_mem_ptr);
|
||||||
LLVMContext &ctx = builder.getContext();
|
LLVMContext &ctx = builder.getContext();
|
||||||
const auto& shapes = v->get_type()->get_tile_shapes();
|
const auto& shapes = extract_shapes(v);
|
||||||
Type* ty = llvm_type(v->get_type()->get_scalar_ty(), ctx);
|
Type* ty = llvm_type(v->get_type()->get_scalar_ty(), ctx);
|
||||||
// create shared tile
|
// create shared tile
|
||||||
if(dynamic_cast<ir::copy_to_shared_inst*>(v) || (buffer_info_->is_double(v))){
|
if(dynamic_cast<ir::copy_to_shared_inst*>(v) || (buffer_info_->is_double(v))){
|
||||||
@@ -429,7 +441,7 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
|
|||||||
}
|
}
|
||||||
// create distributed tile
|
// create distributed tile
|
||||||
else {
|
else {
|
||||||
const auto &shapes = v->get_type()->get_tile_shapes();
|
const auto &shapes = extract_shapes(v);
|
||||||
std::vector<distributed_axis> axes(shapes.size());
|
std::vector<distributed_axis> axes(shapes.size());
|
||||||
for(size_t d = 0; d < shapes.size(); d++){
|
for(size_t d = 0; d < shapes.size(); d++){
|
||||||
if(shapes[d] > 1){
|
if(shapes[d] > 1){
|
||||||
@@ -530,7 +542,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
|||||||
distributed_tile* result = (distributed_tile*)ti;
|
distributed_tile* result = (distributed_tile*)ti;
|
||||||
if(!ins->get_type()->is_tile_ty())
|
if(!ins->get_type()->is_tile_ty())
|
||||||
return;
|
return;
|
||||||
const auto& shapes = ins->get_type()->get_tile_shapes();
|
const auto& shapes = extract_shapes(ins);
|
||||||
// global_range
|
// global_range
|
||||||
if(auto *x = dynamic_cast<ir::get_global_range_inst*>(ins)) {
|
if(auto *x = dynamic_cast<ir::get_global_range_inst*>(ins)) {
|
||||||
static std::array<Intrinsic::ID, 3> ctaid = {
|
static std::array<Intrinsic::ID, 3> ctaid = {
|
||||||
@@ -568,7 +580,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
|||||||
// broadcast
|
// broadcast
|
||||||
else if(dynamic_cast<ir::broadcast_inst*>(ins)) {
|
else if(dynamic_cast<ir::broadcast_inst*>(ins)) {
|
||||||
ir::value* in = ins->get_operand(0);
|
ir::value* in = ins->get_operand(0);
|
||||||
const auto& in_shapes = in->get_type()->get_tile_shapes();
|
const auto& in_shapes = extract_shapes(in);
|
||||||
distributed_tile *in_tile = (distributed_tile*)tmap_.at(in);
|
distributed_tile *in_tile = (distributed_tile*)tmap_.at(in);
|
||||||
result->for_each([&](indices_t out_idx){
|
result->for_each([&](indices_t out_idx){
|
||||||
indices_t in_idx = out_idx;
|
indices_t in_idx = out_idx;
|
||||||
@@ -615,7 +627,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
|||||||
Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {llvm_type(C->get_type()->get_scalar_ty(), ctx)});
|
Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {llvm_type(C->get_type()->get_scalar_ty(), ctx)});
|
||||||
result->for_each([&](indices_t idx){
|
result->for_each([&](indices_t idx){
|
||||||
Value *res = tmap_.at(C)->get_value(idx);
|
Value *res = tmap_.at(C)->get_value(idx);
|
||||||
unsigned NK = A->get_type()->get_tile_shapes()[1];
|
unsigned NK = extract_shapes(A)[1];
|
||||||
for(unsigned K = 0; K < NK; ++K){
|
for(unsigned K = 0; K < NK; ++K){
|
||||||
indices_t a_idx = {idx[0], builder.getInt32(K)};
|
indices_t a_idx = {idx[0], builder.getInt32(K)};
|
||||||
indices_t b_idx = {idx[1], builder.getInt32(K)};
|
indices_t b_idx = {idx[1], builder.getInt32(K)};
|
||||||
|
@@ -3,6 +3,8 @@
|
|||||||
#include "ir/type.h"
|
#include "ir/type.h"
|
||||||
#include "ir/module.h"
|
#include "ir/module.h"
|
||||||
#include "ir/function.h"
|
#include "ir/function.h"
|
||||||
|
#include "ir/context_impl.h"
|
||||||
|
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
|
|
||||||
|
|
||||||
@@ -29,7 +31,8 @@ void tune::init_c_phi(ir::instruction *v) {
|
|||||||
|
|
||||||
void tune::init_c_graph(ir::instruction *v) {
|
void tune::init_c_graph(ir::instruction *v) {
|
||||||
// Reference shape
|
// Reference shape
|
||||||
std::vector<unsigned> shapes;
|
ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(v->get_parent()->get_context());
|
||||||
|
ir::type::tile_shapes_t shapes;
|
||||||
if(auto *store = dynamic_cast<ir::store_inst*>(v))
|
if(auto *store = dynamic_cast<ir::store_inst*>(v))
|
||||||
shapes = store->get_pointer_operand()->get_type()->get_tile_shapes();
|
shapes = store->get_pointer_operand()->get_type()->get_tile_shapes();
|
||||||
else
|
else
|
||||||
@@ -39,7 +42,7 @@ void tune::init_c_graph(ir::instruction *v) {
|
|||||||
ir::value *op = v->get_operand(0);
|
ir::value *op = v->get_operand(0);
|
||||||
unsigned current = 0;
|
unsigned current = 0;
|
||||||
for(unsigned i = 0; i < shapes.size(); i ++){
|
for(unsigned i = 0; i < shapes.size(); i ++){
|
||||||
if(shapes[i] == 1)
|
if(shapes[i] == one)
|
||||||
static_params_.insert({{v, i}, 1});
|
static_params_.insert({{v, i}, 1});
|
||||||
else
|
else
|
||||||
add_constraint({v, i}, {op, current++});
|
add_constraint({v, i}, {op, current++});
|
||||||
@@ -99,6 +102,7 @@ void tune::connected_components(node_t x, const std::vector<unsigned *> vals, st
|
|||||||
std::vector<unsigned*> tune::get_params(ir::module &mod) {
|
std::vector<unsigned*> tune::get_params(ir::module &mod) {
|
||||||
std::vector<unsigned *> result;
|
std::vector<unsigned *> result;
|
||||||
std::set<unsigned*> seen;
|
std::set<unsigned*> seen;
|
||||||
|
|
||||||
for(ir::function *fn: mod.get_function_list())
|
for(ir::function *fn: mod.get_function_list())
|
||||||
for(ir::basic_block *block: fn->blocks())
|
for(ir::basic_block *block: fn->blocks())
|
||||||
for(ir::instruction *i : block->get_inst_list())
|
for(ir::instruction *i : block->get_inst_list())
|
||||||
@@ -143,8 +147,9 @@ void tune::create_grids(std::vector<ir::instruction*> &grids,
|
|||||||
// get number of dimensions greater than 1
|
// get number of dimensions greater than 1
|
||||||
auto get_tile_gt1_dim = [&](ir::value *v){
|
auto get_tile_gt1_dim = [&](ir::value *v){
|
||||||
unsigned result = 0;
|
unsigned result = 0;
|
||||||
for(unsigned shape: v->get_type()->get_tile_shapes()) {
|
auto one = ir::tile_type::make_one(fn->get_fn_type()->get_context());
|
||||||
result += (shape > 1)?shape:0;
|
for(ir::constant_int *shape: v->get_type()->get_tile_shapes()) {
|
||||||
|
result += (shape != one);
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
};
|
};
|
||||||
@@ -194,8 +199,8 @@ for(ir::function *fn: mod.get_function_list()){
|
|||||||
unsigned *s1 = params_[i]["p1.d" + strk];
|
unsigned *s1 = params_[i]["p1.d" + strk];
|
||||||
unsigned *s2 = params_[i]["p2.d" + strk];
|
unsigned *s2 = params_[i]["p2.d" + strk];
|
||||||
unsigned multiple = (*s0)*(*s1)*(*s2);
|
unsigned multiple = (*s0)*(*s1)*(*s2);
|
||||||
if(shapes[k] % multiple != 0)
|
if(shapes[k]->get_value() % multiple != 0)
|
||||||
errors[i].push_back("for dim " + strk + ": shape (" + to_string(shapes[k]) + ")"
|
errors[i].push_back("for dim " + strk + ": shape (" + to_string(shapes[k]->get_value()) + ")"
|
||||||
" is not a multiple of layout (" + to_string(multiple) + ")");
|
" is not a multiple of layout (" + to_string(multiple) + ")");
|
||||||
}
|
}
|
||||||
// the number of thread per warp must be 32
|
// the number of thread per warp must be 32
|
||||||
|
@@ -244,15 +244,15 @@ value *builder::create_store(value *ptr, value *val, const std::string &name){
|
|||||||
// tile instructions
|
// tile instructions
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
value *builder::create_reshape(value *arg, const std::vector<unsigned> &shapes, const std::string &name) {
|
value *builder::create_reshape(value *arg, const type::tile_shapes_t &shapes, const std::string &name) {
|
||||||
return insert(reshape_inst::create(arg, shapes, name));
|
return insert(reshape_inst::create(arg, shapes, name));
|
||||||
}
|
}
|
||||||
|
|
||||||
value *builder::create_splat(value *arg, const std::vector<unsigned> &shapes, const std::string &name) {
|
value *builder::create_splat(value *arg, const type::tile_shapes_t &shapes, const std::string &name) {
|
||||||
return insert(splat_inst::create(arg, shapes, name));
|
return insert(splat_inst::create(arg, shapes, name));
|
||||||
}
|
}
|
||||||
|
|
||||||
value *builder::create_broadcast(value *arg, const std::vector<unsigned> &shapes, const std::string &name) {
|
value *builder::create_broadcast(value *arg, const type::tile_shapes_t &shapes, const std::string &name) {
|
||||||
return insert(broadcast_inst::create(arg, shapes, name));
|
return insert(broadcast_inst::create(arg, shapes, name));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -260,7 +260,7 @@ value *builder::create_broadcast(value *arg, const std::vector<unsigned> &shapes
|
|||||||
// built-in instructions
|
// built-in instructions
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
value *builder::create_get_global_range(unsigned axis, unsigned size, const std::string &name) {
|
value *builder::create_get_global_range(unsigned axis, type::tile_shapes_t::value_type size, const std::string &name) {
|
||||||
return insert(get_global_range_inst::create(ctx_, axis, size, name));
|
return insert(get_global_range_inst::create(ctx_, axis, size, name));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -48,24 +48,27 @@ constant *constant::get_all_ones_value(type *ty) {
|
|||||||
constant_int::constant_int(type *ty, uint64_t value)
|
constant_int::constant_int(type *ty, uint64_t value)
|
||||||
: constant(ty, 0), value_(value){ }
|
: constant(ty, 0), value_(value){ }
|
||||||
|
|
||||||
constant *constant_int::get(type *ty, uint64_t value) {
|
constant_int *constant_int::get(type *ty, uint64_t value) {
|
||||||
return new constant_int(ty, value);
|
context_impl *impl = ty->get_context().p_impl.get();
|
||||||
|
constant_int *& cst = impl->int_constants_[std::make_pair(ty, value)];
|
||||||
|
if(cst == nullptr)
|
||||||
|
cst = new constant_int(ty, value);
|
||||||
|
return cst;
|
||||||
}
|
}
|
||||||
|
|
||||||
// constant_range
|
// constant_range
|
||||||
// FIXME use something like APInt
|
// FIXME use something like APInt
|
||||||
|
|
||||||
constant_range::constant_range(type *ty, uint64_t first, uint64_t last)
|
constant_range::constant_range(type *ty, constant_int *first, constant_int *last)
|
||||||
: constant(ty, 0), first_(first), last_(last){ }
|
: constant(ty, 0), first_(first), last_(last){ }
|
||||||
|
|
||||||
constant *constant_range::get(constant *first, constant *last) {
|
constant *constant_range::get(constant_int *first, constant_int *last) {
|
||||||
assert(first->get_type()->is_integer_ty());
|
assert(first->get_type()->is_integer_ty());
|
||||||
assert(first->get_type() == last->get_type());
|
assert(first->get_type() == last->get_type());
|
||||||
unsigned vfirst = ((constant_int*)first)->get_value();
|
unsigned vfirst = ((constant_int*)first)->get_value();
|
||||||
unsigned vlast = ((constant_int*)last)->get_value();
|
assert(vfirst == 0);
|
||||||
assert(vlast > vfirst);
|
type *ty = tile_type::get(first->get_type(), {last});
|
||||||
type *ty = tile_type::get(first->get_type(), {vlast - vfirst});
|
return new constant_range(ty, first, last);
|
||||||
return new constant_range(ty, vfirst, vlast);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -94,6 +97,17 @@ constant *constant_fp::get(context &ctx, double v){
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// metaparameter
|
||||||
|
metaparameter::metaparameter(type *ty, unsigned lo, unsigned hi)
|
||||||
|
: constant_int(ty, 0), lo_(lo), hi_(hi){ }
|
||||||
|
|
||||||
|
metaparameter* metaparameter::create(context &ctx, type *ty, unsigned lo, unsigned hi) {
|
||||||
|
context_impl *impl = ctx.p_impl.get();
|
||||||
|
metaparameter *result = new metaparameter(ty, lo, hi);
|
||||||
|
impl->mp_constants_.push_back(result);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
// undef value
|
// undef value
|
||||||
undef_value::undef_value(type *ty)
|
undef_value::undef_value(type *ty)
|
||||||
: constant(ty, 0) { }
|
: constant(ty, 0) { }
|
||||||
|
@@ -409,7 +409,7 @@ std::string retile_inst::shape_suffix(ir::type* ty){
|
|||||||
std::string res = "[";
|
std::string res = "[";
|
||||||
const auto& shapes = ty->get_tile_shapes();
|
const auto& shapes = ty->get_tile_shapes();
|
||||||
for(unsigned i = 0; i < shapes.size(); i++){
|
for(unsigned i = 0; i < shapes.size(); i++){
|
||||||
res += std::to_string(ty->get_tile_shapes()[i]);
|
res += std::to_string(ty->get_tile_shapes()[i]->get_value());
|
||||||
if(i < shapes.size() - 1)
|
if(i < shapes.size() - 1)
|
||||||
res += ", ";
|
res += ", ";
|
||||||
}
|
}
|
||||||
@@ -417,13 +417,13 @@ std::string retile_inst::shape_suffix(ir::type* ty){
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
retile_inst::retile_inst(value *arg, const std::vector<unsigned> &shapes,
|
retile_inst::retile_inst(value *arg, const type::tile_shapes_t &shapes,
|
||||||
const std::string &name, instruction *next)
|
const std::string &name, instruction *next)
|
||||||
: unary_inst(tile_type::get(arg->get_type()->get_scalar_ty(), shapes), arg, name, next) { }
|
: unary_inst(tile_type::get(arg->get_type()->get_scalar_ty(), shapes), arg, name, next) { }
|
||||||
|
|
||||||
// reshape
|
// reshape
|
||||||
|
|
||||||
instruction* reshape_inst::create(value *arg, const std::vector<unsigned> &shapes,
|
instruction* reshape_inst::create(value *arg, const type::tile_shapes_t &shapes,
|
||||||
const std::string &name, instruction *next) {
|
const std::string &name, instruction *next) {
|
||||||
return new reshape_inst(arg, shapes, name, next);
|
return new reshape_inst(arg, shapes, name, next);
|
||||||
}
|
}
|
||||||
@@ -431,14 +431,14 @@ instruction* reshape_inst::create(value *arg, const std::vector<unsigned> &shape
|
|||||||
|
|
||||||
// splat
|
// splat
|
||||||
|
|
||||||
instruction* splat_inst::create(value *arg, const std::vector<unsigned> &shapes,
|
instruction* splat_inst::create(value *arg, const type::tile_shapes_t &shapes,
|
||||||
const std::string &name, instruction *next) {
|
const std::string &name, instruction *next) {
|
||||||
return new splat_inst(arg, shapes, name, next);
|
return new splat_inst(arg, shapes, name, next);
|
||||||
}
|
}
|
||||||
|
|
||||||
// broadcast
|
// broadcast
|
||||||
|
|
||||||
instruction* broadcast_inst::create(value *arg, const std::vector<unsigned> &shapes,
|
instruction* broadcast_inst::create(value *arg, const type::tile_shapes_t &shapes,
|
||||||
const std::string &name, instruction *next) {
|
const std::string &name, instruction *next) {
|
||||||
return new broadcast_inst(arg, shapes, name, next);
|
return new broadcast_inst(arg, shapes, name, next);
|
||||||
}
|
}
|
||||||
@@ -470,7 +470,7 @@ get_global_range_inst::get_global_range_inst(type *ty, unsigned axis,
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
instruction* get_global_range_inst::create(context &ctx, unsigned axis, unsigned size,
|
instruction* get_global_range_inst::create(context &ctx, unsigned axis, type::tile_shapes_t::value_type size,
|
||||||
const std::string &name, instruction *next) {
|
const std::string &name, instruction *next) {
|
||||||
type *int_ty = type::get_int32_ty(ctx);
|
type *int_ty = type::get_int32_ty(ctx);
|
||||||
type *tile_ty = tile_type::get(int_ty, {size});
|
type *tile_ty = tile_type::get(int_ty, {size});
|
||||||
|
@@ -3,6 +3,7 @@
|
|||||||
#include "ir/context.h"
|
#include "ir/context.h"
|
||||||
#include "ir/context_impl.h"
|
#include "ir/context_impl.h"
|
||||||
#include "ir/value.h"
|
#include "ir/value.h"
|
||||||
|
#include "ir/constant.h"
|
||||||
|
|
||||||
namespace tdl{
|
namespace tdl{
|
||||||
namespace ir{
|
namespace ir{
|
||||||
@@ -63,7 +64,7 @@ type * type::get_pointer_element_ty() const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
const std::vector<unsigned> &type::get_tile_shapes() const {
|
const type::tile_shapes_t &type::get_tile_shapes() const {
|
||||||
assert(is_tile_ty());
|
assert(is_tile_ty());
|
||||||
return ((tile_type*)this)->get_shapes();
|
return ((tile_type*)this)->get_shapes();
|
||||||
}
|
}
|
||||||
@@ -148,7 +149,7 @@ bool composite_type::index_valid(value *idx) const{
|
|||||||
// tile_type class
|
// tile_type class
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
tile_type::tile_type(type *ty, const std::vector<unsigned> &shapes)
|
tile_type::tile_type(type *ty, const tile_shapes_t &shapes)
|
||||||
: composite_type(ty->get_context(), TileTyID), shapes_(shapes) {
|
: composite_type(ty->get_context(), TileTyID), shapes_(shapes) {
|
||||||
contained_tys_.push_back(ty);
|
contained_tys_.push_back(ty);
|
||||||
}
|
}
|
||||||
@@ -159,8 +160,8 @@ bool tile_type::is_valid_elt_ty(type *ty) {
|
|||||||
|
|
||||||
unsigned tile_type::get_num_elements() const {
|
unsigned tile_type::get_num_elements() const {
|
||||||
unsigned res = 1;
|
unsigned res = 1;
|
||||||
for(unsigned shape: shapes_)
|
for(auto shape: shapes_)
|
||||||
res *= shape;
|
res *= shape->get_value();
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -168,7 +169,7 @@ unsigned tile_type::get_bitwidth() const {
|
|||||||
return get_num_elements() * get_tile_element_ty()->get_primitive_size_in_bits();
|
return get_num_elements() * get_tile_element_ty()->get_primitive_size_in_bits();
|
||||||
}
|
}
|
||||||
|
|
||||||
tile_type* tile_type::get(type *elt_ty, const std::vector<unsigned> &shapes) {
|
tile_type* tile_type::get(type *elt_ty, const tile_shapes_t &shapes) {
|
||||||
assert(elt_ty && "Can't get a tile of <null> type!");
|
assert(elt_ty && "Can't get a tile of <null> type!");
|
||||||
assert(shapes.size() && "Can't create a tile with empty shapes!");
|
assert(shapes.size() && "Can't create a tile with empty shapes!");
|
||||||
assert(is_valid_elt_ty(elt_ty) && "Invalid type for pointer element!");
|
assert(is_valid_elt_ty(elt_ty) && "Invalid type for pointer element!");
|
||||||
@@ -185,6 +186,10 @@ tile_type* tile_type::get_same_shapes(type *ty, type *ref){
|
|||||||
return get(ty, ref->get_tile_shapes());
|
return get(ty, ref->get_tile_shapes());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type::tile_shapes_t::value_type tile_type::make_one(ir::context& ctx){
|
||||||
|
return constant_int::get(type::get_int32_ty(ctx), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// function_type class
|
// function_type class
|
||||||
|
Reference in New Issue
Block a user