[code generation] basic metaparameter support
This commit is contained in:
4
TODO
4
TODO
@@ -1,5 +1,7 @@
|
||||
[Frontend]
|
||||
- SCOPES
|
||||
|
||||
[Intermediate Representation]
|
||||
- proper naming scheme
|
||||
- symbols table
|
||||
- name conflicts on globals?
|
||||
- separate header for typedef (e.g., type::tile_shapes_t) to reduce compilation time
|
||||
|
@@ -38,29 +38,31 @@ extern translation_unit *ast_root;
|
||||
|
||||
const char src[] =
|
||||
"\
|
||||
const tunable int32 TM;\
|
||||
const tunable int32 TN;\
|
||||
void test(fp32 *a, fp32 *b, fp32 *c, int32 M, int32 N, int32 K, int32 bound){\
|
||||
int32 rxa[16] = get_global_range[16](0);\
|
||||
int32 ryb[16] = get_global_range[16](1);\
|
||||
int32 rxa[TM] = get_global_range[TM](0);\
|
||||
int32 ryb[TN] = get_global_range[TN](1);\
|
||||
int32 rka[8] = 0 ... 8;\
|
||||
int32 rkb[8] = 0 ... 8;\
|
||||
int32 rxc[16] = get_global_range[16](0);\
|
||||
int32 ryc[16] = get_global_range[16](1);\
|
||||
fp32 C[16, 16] = 0;\
|
||||
int32 rxc[TM] = get_global_range[TM](0);\
|
||||
int32 ryc[TN] = get_global_range[TN](1);\
|
||||
fp32 C[TM, TN] = 0;\
|
||||
int32 k;\
|
||||
fp32* pa[16, 8] = a + rxa[:, newaxis] + rka[newaxis, :]*M;\
|
||||
fp32* pb[16, 8] = b + ryb[:, newaxis] + rkb[newaxis, :]*K;\
|
||||
fp32* pc[16, 16] = c + rxc[:, newaxis] + ryc[newaxis, :]*M;\
|
||||
fp32 a[16, 8] = *pa;\
|
||||
fp32 b[16, 8] = *pb;\
|
||||
int1 checkc0[16] = rxc < M;\
|
||||
int1 checkc1[16] = ryc < N;\
|
||||
int1 checkc[16, 16] = checkc0[:, newaxis] && checkc1[newaxis, :];\
|
||||
fp32* pa[TM, 8] = a + rxa[:, newaxis] + rka[newaxis, :]*M;\
|
||||
fp32* pb[TN, 8] = b + ryb[:, newaxis] + rkb[newaxis, :]*K;\
|
||||
fp32* pc[TM, TN] = c + rxc[:, newaxis] + ryc[newaxis, :]*M;\
|
||||
fp32 a[TM, 8] = *pa;\
|
||||
fp32 b[TN, 8] = *pb;\
|
||||
int1 checkc0[TM] = rxc < M;\
|
||||
int1 checkc1[TN] = ryc < N;\
|
||||
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];\
|
||||
for(k = K; k > 0; k = k - 8){\
|
||||
int1 checka[16, 8] = (k > 8);\
|
||||
int1 checkb[16, 8] = (k > 8);\
|
||||
int1 checka0[16];\
|
||||
int1 checka[TM, 8] = (k > 8);\
|
||||
int1 checkb[TN, 8] = (k > 8);\
|
||||
int1 checka0[TM];\
|
||||
int1 checka1[8];\
|
||||
int1 checkb0[16];\
|
||||
int1 checkb0[TN];\
|
||||
int1 checkb1[8];\
|
||||
C = dot(a, b, C);\
|
||||
pa = pa + 8*M;\
|
||||
@@ -183,8 +185,8 @@ int main() {
|
||||
llvm::LLVMContext llvm_context;
|
||||
llvm::Module llvm_module("test", llvm_context);
|
||||
|
||||
// context.p_impl->mp_constants_[0]->set_value(16);
|
||||
// context.p_impl->mp_constants_[1]->set_value(16);
|
||||
context.p_impl->mp_constants_[0]->set_value(16);
|
||||
context.p_impl->mp_constants_[1]->set_value(16);
|
||||
// context.p_impl->mp_constants_[2]->set_value(8);
|
||||
|
||||
// create passes
|
||||
|
@@ -57,6 +57,7 @@ enum TYPE_T{
|
||||
};
|
||||
|
||||
enum STORAGE_SPEC_T{
|
||||
CONST_T,
|
||||
TUNABLE_T,
|
||||
KERNEL_T,
|
||||
READONLY_T, WRITEONLY_T,
|
||||
@@ -399,12 +400,14 @@ class declaration_specifier: public node{
|
||||
public:
|
||||
using node::node;
|
||||
virtual ir::type* type(ir::module *mod) const = 0;
|
||||
virtual std::vector<STORAGE_SPEC_T> storage() const = 0;
|
||||
};
|
||||
|
||||
class typed_declaration_specifier: public declaration_specifier {
|
||||
public:
|
||||
typed_declaration_specifier(TYPE_T ty): ty_(ty){ }
|
||||
ir::type* type(ir::module *mod) const;
|
||||
std::vector<STORAGE_SPEC_T> storage() const;
|
||||
|
||||
private:
|
||||
const TYPE_T ty_;
|
||||
@@ -415,6 +418,7 @@ public:
|
||||
storage_declaration_specifier(STORAGE_SPEC_T storage_spec, node *decl_spec)
|
||||
: storage_spec_(storage_spec), decl_spec_((declaration_specifier*)decl_spec) {}
|
||||
ir::type* type(ir::module *mod) const;
|
||||
std::vector<STORAGE_SPEC_T> storage() const;
|
||||
|
||||
private:
|
||||
const STORAGE_SPEC_T storage_spec_;
|
||||
@@ -429,6 +433,7 @@ public:
|
||||
decl_((declarator*)decl) { }
|
||||
|
||||
ir::type* type(ir::module *mod) const;
|
||||
std::vector<STORAGE_SPEC_T> storage() const;
|
||||
const identifier* id() const;
|
||||
|
||||
public:
|
||||
@@ -485,10 +490,10 @@ private:
|
||||
|
||||
public:
|
||||
tile(node *id, node *shapes)
|
||||
: declarator(id), shapes_((list<constant*>*)(shapes)) { }
|
||||
: declarator(id), shapes_((list<expression*>*)(shapes)) { }
|
||||
|
||||
public:
|
||||
const list<constant*>* shapes_;
|
||||
const list<expression*>* shapes_;
|
||||
};
|
||||
|
||||
class function: public declarator{
|
||||
|
@@ -46,7 +46,7 @@ STORAGE_SPEC_T get_storage_spec(node *op) { return ((token*)op)->storage_spec;}
|
||||
%}
|
||||
|
||||
%token IDENTIFIER CONSTANT STRING_LITERAL
|
||||
%token TUNABLE KERNEL READONLY WRITEONLY
|
||||
%token TUNABLE KERNEL READONLY WRITEONLY CONST
|
||||
%token PTR_OP INC_OP DEC_OP LEFT_OP RIGHT_OP LE_OP GE_OP EQ_OP NE_OP
|
||||
%token AND_OP OR_OP MUL_ASSIGN DIV_ASSIGN MOD_ASSIGN ADD_ASSIGN
|
||||
%token SUB_ASSIGN LEFT_ASSIGN RIGHT_ASSIGN AND_ASSIGN
|
||||
@@ -360,7 +360,8 @@ init_declarator
|
||||
;
|
||||
|
||||
storage_class_specifier
|
||||
: TUNABLE { $$ = new token(TUNABLE_T); }
|
||||
: CONST { $$ = new token(CONST_T); }
|
||||
| TUNABLE { $$ = new token(TUNABLE_T); }
|
||||
| KERNEL { $$ = new token(KERNEL_T); }
|
||||
| READONLY { $$ = new token(READONLY_T); }
|
||||
| WRITEONLY { $$ = new token(WRITEONLY_T); }
|
||||
|
@@ -16,6 +16,7 @@ int comment();
|
||||
%}
|
||||
|
||||
%%
|
||||
"const" { count(); return(CONST); }
|
||||
"tunable" { count(); return(TUNABLE); }
|
||||
"kernel" { count(); return(KERNEL); }
|
||||
"readonly" { count(); return(READONLY); }
|
||||
|
@@ -58,6 +58,7 @@ public:
|
||||
void set_value(const std::string& name, value* x);
|
||||
void set_type(const std::string& name, basic_block* block, type* x);
|
||||
void set_type(const std::string& name, type* x);
|
||||
void set_const(const std::string& name);
|
||||
void set_continue_fn(std::function<ir::value*()> fn);
|
||||
// Getters
|
||||
value *get_value(const std::string& name, basic_block* block);
|
||||
@@ -83,6 +84,7 @@ private:
|
||||
builder builder_;
|
||||
std::map<val_key_t, value*> values_;
|
||||
std::map<val_key_t, type*> types_;
|
||||
std::set<std::string> const_;
|
||||
std::set<basic_block*> sealed_blocks_;
|
||||
std::map<basic_block*, std::map<std::string, phi_node*>> incomplete_phis_;
|
||||
functions_list_t functions_;
|
||||
|
@@ -99,25 +99,10 @@ public:
|
||||
static integer_type *get_int64_ty(context &ctx);
|
||||
static integer_type *get_int128_ty(context &ctx);
|
||||
|
||||
// Attributes
|
||||
type* set_tunable() { is_tunable_ = true; return this; }
|
||||
type* set_readonly() { is_readonly_ = true; return this; }
|
||||
type* set_writeonly() { is_writeonly_ = true; return this; }
|
||||
type* set_kernel() { is_kernel_ = true; return this; }
|
||||
|
||||
bool get_tunable() { return is_tunable_; }
|
||||
bool get_readonly() { return is_readonly_; }
|
||||
bool get_writeonly() { return is_writeonly_; }
|
||||
bool get_kernel() { return is_kernel_; }
|
||||
|
||||
private:
|
||||
context &ctx_;
|
||||
id_t id_;
|
||||
// attributes
|
||||
bool is_tunable_;
|
||||
bool is_readonly_;
|
||||
bool is_writeonly_;
|
||||
bool is_kernel_;
|
||||
|
||||
protected:
|
||||
contained_tys_vec_t contained_tys_;
|
||||
|
@@ -170,22 +170,31 @@ ir::type* typed_declaration_specifier::type(ir::module *mod) const {
|
||||
}
|
||||
}
|
||||
|
||||
ir::type* storage_declaration_specifier::type(ir::module *mod) const {
|
||||
ir::type* result = decl_spec_->type(mod);
|
||||
switch(storage_spec_){
|
||||
case TUNABLE_T: return result->set_tunable();
|
||||
case KERNEL_T: return result->set_kernel();
|
||||
case READONLY_T: return result->set_readonly();
|
||||
case WRITEONLY_T: return result->set_writeonly();
|
||||
default: throw std::runtime_error("unreachable");
|
||||
}
|
||||
std::vector<STORAGE_SPEC_T> typed_declaration_specifier::storage() const {
|
||||
return {};
|
||||
}
|
||||
|
||||
|
||||
ir::type* storage_declaration_specifier::type(ir::module *mod) const {
|
||||
return decl_spec_->type(mod);
|
||||
}
|
||||
|
||||
std::vector<STORAGE_SPEC_T> storage_declaration_specifier::storage() const {
|
||||
auto result = decl_spec_->storage();
|
||||
result.push_back(storage_spec_);
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
/* Parameter */
|
||||
ir::type* parameter::type(ir::module *mod) const {
|
||||
return decl_->type(mod, spec_->type(mod));
|
||||
}
|
||||
|
||||
std::vector<STORAGE_SPEC_T> parameter::storage() const {
|
||||
return spec_->storage();
|
||||
}
|
||||
|
||||
const identifier *parameter::id() const {
|
||||
return decl_->id();
|
||||
}
|
||||
@@ -209,8 +218,11 @@ const std::string &identifier::name() const{
|
||||
// Tile
|
||||
ir::type* tile::type_impl(ir::module *mod, ir::type *type) const{
|
||||
ir::type::tile_shapes_t shapes;
|
||||
for(constant *cst: shapes_->values())
|
||||
shapes.push_back((ir::constant_int*)cst->codegen(mod));
|
||||
for(expression *expr: shapes_->values()){
|
||||
ir::constant_int *shape = dynamic_cast<ir::constant_int*>(expr->codegen(mod));
|
||||
assert(shape);
|
||||
shapes.push_back(shape);
|
||||
}
|
||||
return ir::tile_type::get(type, shapes);
|
||||
}
|
||||
|
||||
@@ -368,11 +380,12 @@ void initializer::set_specifier(const declaration_specifier *spec) {
|
||||
|
||||
ir::value* initializer::codegen(ir::module * mod) const{
|
||||
ir::type *ty = decl_->type(mod, spec_->type(mod));
|
||||
std::vector<STORAGE_SPEC_T> storage = spec_->storage();
|
||||
std::string name = decl_->id()->name();
|
||||
ir::value *value = ir::undef_value::get(ty);
|
||||
if(ty->get_tunable()){
|
||||
if(std::find(storage.begin(), storage.end(), TUNABLE_T) != storage.end()){
|
||||
assert(expr_ == nullptr);
|
||||
//TODO
|
||||
//TODO: implement ranges
|
||||
value = ir::metaparameter::create(mod->get_context(), ty, 4, 8);
|
||||
}
|
||||
if(expr_){
|
||||
@@ -383,6 +396,8 @@ ir::value* initializer::codegen(ir::module * mod) const{
|
||||
value->set_name(name);
|
||||
mod->set_value(name, value);
|
||||
mod->set_type(name, ty);
|
||||
if(std::find(storage.begin(), storage.end(), CONST_T) != storage.end())
|
||||
mod->set_const(name);
|
||||
return value;
|
||||
}
|
||||
|
||||
|
@@ -44,7 +44,6 @@ llvm::Type *distributed_tile::make_vector_ty(llvm::Type *ty, size_t vector_size)
|
||||
return VectorType::get(ty, vector_size);
|
||||
}
|
||||
|
||||
|
||||
distributed_tile::distributed_tile(Type *ty, const shapes_t &shapes, const axes_t &axes, llvm::IRBuilder<> &builder, bool vectorize)
|
||||
: tile(make_vector_ty(ty, vectorize?axes[0].contiguous:1), shapes), axes_(axes), builder_(builder) {
|
||||
vector_size_ = vectorize?ty_->getVectorNumElements():1;
|
||||
@@ -150,16 +149,6 @@ Value* shared_tile::get_value(indices_t idx) {
|
||||
return builder_.CreateLoad(ptr);
|
||||
}
|
||||
|
||||
/* Utils */
|
||||
std::vector<unsigned> selection::extract_shapes(ir::value *v) {
|
||||
const auto& shapes = v->get_type()->get_tile_shapes();
|
||||
std::vector<unsigned> result(shapes.size());
|
||||
for(ir::constant_int* cst: shapes)
|
||||
result.push_back(cst->get_value());
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
/* convert ir::type to Type */
|
||||
Type *selection::llvm_type(ir::type *ty, LLVMContext &ctx) {
|
||||
// function
|
||||
@@ -310,12 +299,11 @@ std::vector<Value*> delinearize(Value *trailing, std::vector<unsigned> &shapes,
|
||||
}
|
||||
|
||||
void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
|
||||
const auto& shapes = extract_shapes(v);
|
||||
const auto& shapes = v->get_type()->get_tile_shapes();
|
||||
size_t dim = shapes.size();
|
||||
std::vector<unsigned> contiguous(dim);
|
||||
std::vector<unsigned> warp_size(dim);
|
||||
std::vector<unsigned> n_warps(dim);
|
||||
std::cout << v->get_name() << " " << typeid(*v).name() << std::endl;
|
||||
for(unsigned i = 0; i < shapes.size(); i++){
|
||||
std::string str_i = std::to_string(i);
|
||||
contiguous[i] = *params_->get_param(v, "p0.d" + str_i);
|
||||
@@ -332,7 +320,7 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
|
||||
Value *thread_id = builder.CreateAdd(thread_id_in_warp[k], builder.CreateMul(warp_id[k], warp_size_k));
|
||||
thread_id = builder.CreateMul(thread_id, contiguous_k);
|
||||
unsigned per_block = contiguous[k] * warp_size[k] * n_warps[k];
|
||||
unsigned per_thread = contiguous[k] * shapes[k] / per_block;
|
||||
unsigned per_thread = contiguous[k] * shapes[k]->get_value() / per_block;
|
||||
std::vector<Value*> idx_list(per_thread);
|
||||
for(unsigned n = 0 ; n < per_thread; n++){
|
||||
unsigned offset = n / contiguous[k] * per_block + n % contiguous[k];
|
||||
@@ -348,8 +336,8 @@ void selection::create_grids(std::vector<ir::value*> &grids,
|
||||
// get number of dimensions greater than 1
|
||||
auto get_tile_gt1_dim = [&](ir::value *v){
|
||||
unsigned result = 0;
|
||||
for(unsigned shape: extract_shapes(v)) {
|
||||
result += (shape > 1)?shape:0;
|
||||
for(ir::constant_int* shape: v->get_type()->get_tile_shapes()) {
|
||||
result += (shape->get_value() > 1)?shape->get_value():0;
|
||||
}
|
||||
return result;
|
||||
};
|
||||
@@ -365,11 +353,11 @@ void selection::create_grids(std::vector<ir::value*> &grids,
|
||||
for(ir::value *op: user->ops())
|
||||
bind_references(op);
|
||||
// bind
|
||||
const auto& shapes = extract_shapes(v);
|
||||
const auto& shapes = v->get_type()->get_tile_shapes();
|
||||
if(dynamic_cast<ir::copy_to_shared_inst*>(v) || buffer_info_->is_double(v))
|
||||
return;
|
||||
for(size_t d = 0; d < shapes.size(); d++){
|
||||
if(shapes[d] == 1)
|
||||
if(shapes[d]->get_value() == 1)
|
||||
continue;
|
||||
unsigned *x = params_->get_param(v, "p0.d" + std::to_string(d));
|
||||
ir::value *&r = references[x];
|
||||
@@ -397,7 +385,10 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
|
||||
for(ir::value *op: user->ops())
|
||||
create_tile(op, builder, references, seen, sh_mem_ptr);
|
||||
LLVMContext &ctx = builder.getContext();
|
||||
const auto& shapes = extract_shapes(v);
|
||||
const auto& shapes = v->get_type()->get_tile_shapes();
|
||||
std::vector<unsigned> shapes2;
|
||||
for(ir::constant_int* shape: shapes)
|
||||
shapes2.push_back(shape->get_value());
|
||||
Type* ty = llvm_type(v->get_type()->get_scalar_ty(), ctx);
|
||||
// create shared tile
|
||||
if(dynamic_cast<ir::copy_to_shared_inst*>(v) || (buffer_info_->is_double(v))){
|
||||
@@ -408,7 +399,7 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
|
||||
size_t offset = alloc_->get_offset(v);
|
||||
Value *ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(offset));
|
||||
ptr = builder.CreateBitCast(ptr, ptr_ty);
|
||||
tmap_.insert({v, new shared_tile(ty, shapes, ptr, builder)});
|
||||
tmap_.insert({v, new shared_tile(ty, shapes2, ptr, builder)});
|
||||
}
|
||||
}
|
||||
// phi-node (double-buffering)
|
||||
@@ -427,13 +418,13 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
|
||||
Value *pre_ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(alloc_->get_offset(phi)));
|
||||
pre_ptr = builder.CreateBitCast(pre_ptr, ptr->getType());
|
||||
Value *next_ptr = builder.CreateGEP(ptr, offset);
|
||||
tmap_.insert({phi, new shared_tile(ty, shapes, ptr, builder, offset)});
|
||||
tmap_.insert({phi, new shared_tile(ty, shapes2, ptr, builder, offset)});
|
||||
for(unsigned i = 0; i < phi->get_num_incoming(); i++) {
|
||||
ir::basic_block* inc_block = phi->get_incoming_block(i);
|
||||
ir::value* inc_value = phi->get_incoming_value(i);
|
||||
ir::value* terminator = inc_block->get_inst_list().back();
|
||||
bool is_loop_latch = buffer_info_->is_loop_latch(phi, terminator);
|
||||
tmap_.insert({inc_value, new shared_tile(ty, shapes, is_loop_latch?next_ptr:pre_ptr, builder)});
|
||||
tmap_.insert({inc_value, new shared_tile(ty, shapes2, is_loop_latch?next_ptr:pre_ptr, builder)});
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -441,10 +432,10 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
|
||||
}
|
||||
// create distributed tile
|
||||
else {
|
||||
const auto &shapes = extract_shapes(v);
|
||||
const auto &shapes = v->get_type()->get_tile_shapes();
|
||||
std::vector<distributed_axis> axes(shapes.size());
|
||||
for(size_t d = 0; d < shapes.size(); d++){
|
||||
if(shapes[d] > 1){
|
||||
if(shapes[d]->get_value() > 1){
|
||||
unsigned *x = params_->get_param(v, "p0.d" + std::to_string(d));
|
||||
axes[d] = axes_.at(x);
|
||||
}
|
||||
@@ -454,7 +445,7 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
|
||||
}
|
||||
}
|
||||
bool vectorize = dynamic_cast<ir::vectorize_inst*>(v);
|
||||
distributed_tile *T = new distributed_tile(ty, shapes, axes, builder, vectorize);
|
||||
distributed_tile *T = new distributed_tile(ty, shapes2, axes, builder, vectorize);
|
||||
tmap_.insert({v, T});
|
||||
// constant range
|
||||
if(dynamic_cast<ir::constant*>(v)){
|
||||
@@ -542,7 +533,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
distributed_tile* result = (distributed_tile*)ti;
|
||||
if(!ins->get_type()->is_tile_ty())
|
||||
return;
|
||||
const auto& shapes = extract_shapes(ins);
|
||||
const auto& shapes = ins->get_type()->get_tile_shapes();
|
||||
// global_range
|
||||
if(auto *x = dynamic_cast<ir::get_global_range_inst*>(ins)) {
|
||||
static std::array<Intrinsic::ID, 3> ctaid = {
|
||||
@@ -552,7 +543,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
};
|
||||
Function *get_group_id = Intrinsic::getDeclaration(module, ctaid[x->get_axis()]);
|
||||
Value *group_id = builder.CreateCall(get_group_id, {});
|
||||
Value *offset = builder.CreateMul(builder.getInt32(shapes[0]), group_id);
|
||||
Value *offset = builder.CreateMul(builder.getInt32(shapes[0]->get_value()), group_id);
|
||||
result->for_each([&](indices_t idx){
|
||||
BinaryOperator *bin = static_cast<BinaryOperator*>(idx[0]);
|
||||
result->set_value(idx, insert_masked(idx, [&]{ return builder.CreateAdd(bin, offset); }));
|
||||
@@ -565,7 +556,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
result->for_each([&](indices_t out_idx){
|
||||
indices_t in_idx;
|
||||
for(size_t k = 0; k < shapes.size(); k++){
|
||||
if(shapes[k] > 1)
|
||||
if(shapes[k]->get_value() > 1)
|
||||
in_idx.push_back(out_idx[k]);
|
||||
}
|
||||
result->set_value(out_idx, in_tile->get_value(in_idx));
|
||||
@@ -580,12 +571,12 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
// broadcast
|
||||
else if(dynamic_cast<ir::broadcast_inst*>(ins)) {
|
||||
ir::value* in = ins->get_operand(0);
|
||||
const auto& in_shapes = extract_shapes(in);
|
||||
const auto& in_shapes = in->get_type()->get_tile_shapes();
|
||||
distributed_tile *in_tile = (distributed_tile*)tmap_.at(in);
|
||||
result->for_each([&](indices_t out_idx){
|
||||
indices_t in_idx = out_idx;
|
||||
for(size_t k = 0; k < in_idx.size(); k++){
|
||||
if(in_shapes[k] == 1)
|
||||
if(in_shapes[k]->get_value() == 1)
|
||||
in_idx[k] = builder.getInt32(0);
|
||||
}
|
||||
result->set_value(out_idx, in_tile->get_value(in_idx));
|
||||
@@ -627,7 +618,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {llvm_type(C->get_type()->get_scalar_ty(), ctx)});
|
||||
result->for_each([&](indices_t idx){
|
||||
Value *res = tmap_.at(C)->get_value(idx);
|
||||
unsigned NK = extract_shapes(A)[1];
|
||||
unsigned NK = A->get_type()->get_tile_shapes()[1]->get_value();
|
||||
for(unsigned K = 0; K < NK; ++K){
|
||||
indices_t a_idx = {idx[0], builder.getInt32(K)};
|
||||
indices_t b_idx = {idx[1], builder.getInt32(K)};
|
||||
|
@@ -10,7 +10,7 @@ namespace tdl{
|
||||
namespace ir{
|
||||
|
||||
builder::builder(context &ctx):
|
||||
ctx_(ctx){}
|
||||
ctx_(ctx), block_(nullptr), insert_point_(nullptr) {}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// utilities
|
||||
|
@@ -37,6 +37,10 @@ void module::set_type(const std::string& name, ir::type *type){
|
||||
return set_type(name, builder_.get_insert_block(), type);
|
||||
}
|
||||
|
||||
void module::set_const(const std::string& name){
|
||||
const_.insert(name);
|
||||
}
|
||||
|
||||
void module::set_continue_fn(std::function<ir::value*()> fn) {
|
||||
continue_fn_ = fn;
|
||||
}
|
||||
@@ -91,10 +95,12 @@ ir::value *module::add_phi_operands(const std::string& name, ir::phi_node *&phi)
|
||||
|
||||
ir::value *module::get_value_recursive(const std::string& name, ir::basic_block *block) {
|
||||
ir::value *result;
|
||||
bool is_const = const_.find(name) != const_.end();
|
||||
auto &preds = block->get_predecessors();
|
||||
ir::type *ty = get_type(name, block);
|
||||
if(block)
|
||||
if(sealed_blocks_.find(block) == sealed_blocks_.end()){
|
||||
incomplete_phis_[block][name] = make_phi(get_type(name, block), 1, block);
|
||||
if(!is_const && sealed_blocks_.find(block) == sealed_blocks_.end()){
|
||||
incomplete_phis_[block][name] = make_phi(ty, 1, block);
|
||||
result = (ir::value*)incomplete_phis_[block][name];
|
||||
}
|
||||
else if(preds.size() <= 1){
|
||||
@@ -102,7 +108,7 @@ ir::value *module::get_value_recursive(const std::string& name, ir::basic_block
|
||||
result = get_value(name, has_pred?preds.front():nullptr);
|
||||
}
|
||||
else{
|
||||
result = make_phi(get_type(name, block), 1, block);
|
||||
result = make_phi(ty, 1, block);
|
||||
set_value(name, block, result);
|
||||
result = add_phi_operands(name, (ir::phi_node*&)result);
|
||||
}
|
||||
|
Reference in New Issue
Block a user