diff --git a/include/triton/codegen/analysis/layout.h b/include/triton/codegen/analysis/layout.h index 096c45ea3..923e13411 100644 --- a/include/triton/codegen/analysis/layout.h +++ b/include/triton/codegen/analysis/layout.h @@ -53,6 +53,7 @@ struct layout_t { const std::vector& _axes, const std::vector &_shapes, const std::vector &_values, + ir::type *_ty, size_t _id, analysis::align* align); @@ -79,6 +80,7 @@ struct layout_hmma_884_t: public layout_t { const std::vector& _axes, const std::vector& _shapes, const std::vector &_values, + ir::type *_ty, size_t _id, analysis::align* align); void accept(layout_visitor* vst) { vst->visit_layout_hmma_884(this); } @@ -89,6 +91,7 @@ struct layout_scanline_t: public layout_t { const std::vector& _axes, const std::vector& _shapes, const std::vector &values, + ir::type *_ty, size_t _id, analysis::align* align); void accept(layout_visitor* vst) { vst->visit_layout_scanline(this); } diff --git a/include/triton/codegen/selection.h b/include/triton/codegen/selection.h index 06ec94222..82466ab93 100644 --- a/include/triton/codegen/selection.h +++ b/include/triton/codegen/selection.h @@ -147,44 +147,54 @@ private: }; class machine_layout_t { - + virtual tile* create(ir::value *v) = 0; }; class machine_layout_shared_t: public machine_layout_t { - +public: + shared_tile* create(ir::value *v); }; -class machine_layout_hmma_884_t: public machine_layout_t { +class machine_layout_distributed_t: public machine_layout_t { +public: + machine_layout_distributed_t(Module *mod, Builder *builder, target *tgt, Type *ty, + analysis::axes *a_axes, std::map& axes, + analysis::layout_t* layout); + + distributed_tile* create(ir::value *v); + Module *mod_; + Builder *builder_; + target *tgt_; + Type *ty_; + analysis::axes *a_axes_; + std::map& axes_; + analysis::layout_t* layout_; +}; + + +class machine_layout_hmma_884_t: public machine_layout_distributed_t { public: machine_layout_hmma_884_t(Module *mod, Builder *builder, - target *tgt, std::map& axes, + target *tgt, Type *ty, + analysis::axes *a_axes, std::map& axes, Value *&offset_a_i, Value *&offset_a_k, Value *&offset_b_j, Value *&offset_b_k, unsigned &pack_size_0, unsigned &pack_size_1, unsigned &num_packs_0, unsigned &num_packs_1, analysis::layout_hmma_884_t* layout); - Module *mod_; - Builder *builder_; - target *tgt_; - std::map& axes_; Value *&offset_a_i_, *&offset_a_k_; Value *&offset_b_j_, *&offset_b_k_; unsigned &pack_size_0_; unsigned& pack_size_1_; unsigned &num_packs_0_; unsigned& num_packs_1_; - analysis::layout_hmma_884_t* layout_; }; -class machine_layout_scanline_t: public machine_layout_t { +class machine_layout_scanline_t: public machine_layout_distributed_t { public: machine_layout_scanline_t(Module *mod, Builder *builder, - target *tgt, std::map& axes, + target *tgt, Type *ty, + analysis::axes *a_axes, std::map& axes, analysis::layout_scanline_t* layout); - Module *mod_; - Builder *builder_; - target *tgt_; - std::map& axes_; - analysis::layout_scanline_t* layout_; }; class generator: public ir::visitor, public analysis::layout_visitor { @@ -194,7 +204,6 @@ private: void visit_outer_dot(ir::dot_inst*, distributed_tile *TC, distributed_tile *TA, distributed_tile *TB, distributed_tile *TD, unsigned NK, Type *c_ty, Function *f_mul_add); - Type *type(ir::type *ty); void for_each(ir::value *x, const std::function& fn); Value* get_value(ir::value *x, const indices_t& idx); void set_value(ir::value *x, const indices_t& idx, Value* v); @@ -203,6 +212,7 @@ public: generator(LLVMContext *ctx, Module *dst, Builder *builder, + analysis::axes *a_axes, std::map& axes, std::map& vmap, std::map& tmap, @@ -216,12 +226,13 @@ public: unsigned num_packs_0, unsigned num_packs_1, unsigned pack_size_0, unsigned pack_size_1, unsigned num_warps) - : ctx_(ctx), mod_(dst), builder_(builder), axes_(axes), vmap_(vmap), tmap_(tmap), tgt_(tgt), + : ctx_(ctx), mod_(dst), builder_(builder), a_axes_(a_axes), axes_(axes), vmap_(vmap), tmap_(tmap), tgt_(tgt), layouts_(layouts), alignment_(alignment), alloc_(alloc), sh_mem_ptr_(sh_mem_ptr), offset_a_i_(offset_a_i), offset_a_k_(offset_a_k), offset_b_j_(offset_b_j), offset_b_k_(offset_b_k), num_packs_0_(num_packs_0), num_packs_1_(num_packs_1), pack_size_0_(pack_size_0), pack_size_1_(pack_size_1), num_warps_(num_warps) { } + machine_layout_t *get_machine_layout(const analysis::layout_t *layout) { return machine_layouts_.at(layout); } void visit_phi_node(ir::phi_node*); void visit_binary_operator(ir::binary_operator*); @@ -281,7 +292,8 @@ private: Builder *builder_; Module *mod_; - std::map machine_layouts_; + std::map machine_layouts_; + analysis::axes *a_axes_; std::map& axes_; std::map& vmap_; std::map& tmap_; @@ -311,7 +323,6 @@ private: // grid construction void create_shared_tile(ir::value *v, Builder &builder, Value *sh_mem_ptr); - void create_distributed_tile(ir::value *v, Builder &builder); // lower scalar instruction void lower_value(ir::value *src, Builder &builder, generator* gen, std::set& seen); diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc index cfd3b3c47..f435efef8 100644 --- a/lib/codegen/analysis/layout.cc +++ b/lib/codegen/analysis/layout.cc @@ -128,9 +128,9 @@ inline bool is_trans(ir::value *v) { layout_t::layout_t(layout_type_t _type, const std::vector &_axes, const std::vector &_shapes, - const std::vector &_values, + const std::vector &_values, ir::type *_ty, size_t _id, - analysis::align* align): type(_type), axes(_axes), shapes(_shapes), values(_values), id(_id) { + analysis::align* align): type(_type), axes(_axes), shapes(_shapes), values(_values), id(_id), ty(_ty) { // io pointer std::set ptr; for(ir::value* v: values) @@ -152,8 +152,8 @@ inline unsigned clamp(unsigned x, unsigned lo, unsigned hi) { layout_hmma_884_t::layout_hmma_884_t(size_t num_warps, const std::vector& _axes, const std::vector& _shapes, - const std::vector &values, size_t _id, - analysis::align* align): layout_t(HMMA_884, _axes, _shapes, values, _id, align) { + const std::vector &values, ir::type *_ty, size_t _id, + analysis::align* align): layout_t(HMMA_884, _axes, _shapes, values, _ty, _id, align) { unsigned shape_0 = shapes[order[0]]; unsigned shape_1 = shapes[order[1]]; @@ -194,9 +194,9 @@ layout_hmma_884_t::layout_hmma_884_t(size_t num_warps, layout_scanline_t::layout_scanline_t(size_t num_warps, const std::vector& _axes, const std::vector& _shapes, - const std::vector &values, + const std::vector &values, ir::type *_ty, size_t _id, - analysis::align* align): layout_t(SCANLINE, _axes, _shapes, values, _id, align){ + analysis::align* align): layout_t(SCANLINE, _axes, _shapes, values, _ty, _id, align){ unsigned size = std::accumulate(shapes.begin(), shapes.end(), 1, std::multiplies()); unsigned num_threads = num_warps * 32; nts.resize(shapes.size()); @@ -263,9 +263,8 @@ layout_shared_t::layout_shared_t(const layout_t *arg, const std::vector &values, ir::type *ty, size_t _id, - analysis::align* align): layout_t(SHARED, _axes, _shapes, values, _id, align) { + analysis::align* align): layout_t(SHARED, _axes, _shapes, values, ty, _id, align) { - this->ty = ty; size = 0; // double-buffering @@ -333,7 +332,7 @@ void layout::create(size_t id, const std::vector& values) { }); // type if(it_hmma_c != values.end()) - layouts_[id] = new layout_hmma_884_t(num_warps_, axes, shapes, values, id, align_); + layouts_[id] = new layout_hmma_884_t(num_warps_, axes, shapes, values, largest->get_type()->get_scalar_ty(), id, align_); else if(it_cts != values.end()){ ir::copy_to_shared_inst *cts = (ir::copy_to_shared_inst*)*it_cts; ir::value *arg = cts->get_operand(0); @@ -341,7 +340,7 @@ void layout::create(size_t id, const std::vector& values) { layouts_[id] = new layout_shared_t(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), id, align_); } else - layouts_[id] = new layout_scanline_t(num_warps_, axes, shapes, values, id, align_); + layouts_[id] = new layout_scanline_t(num_warps_, axes, shapes, values, largest->get_type()->get_scalar_ty(), id, align_); } void layout::run(ir::module &mod) { diff --git a/lib/codegen/selection.cc b/lib/codegen/selection.cc index 9321d49fa..be97ac2e4 100644 --- a/lib/codegen/selection.cc +++ b/lib/codegen/selection.cc @@ -343,6 +343,42 @@ Type *selection::llvm_type(ir::type *ty, LLVMContext &ctx) { throw std::runtime_error("unknown conversion from ir::type to Type"); } +Type *type(ir::type *ty, LLVMContext &ctx) { + // function + if(auto* tt = dynamic_cast(ty)){ + Type *return_ty = type(tt->get_return_ty(), ctx); + std::vector param_tys; + std::transform(tt->params_begin(), tt->params_end(), std::back_inserter(param_tys), + [&ctx](ir::type* t){ return type(t, ctx);}); + return FunctionType::get(return_ty, param_tys, false); + } + // pointer + if(ty->is_pointer_ty()){ + Type *elt_ty = type(ty->get_pointer_element_ty(), ctx); + unsigned addr_space = ty->get_pointer_address_space(); + return PointerType::get(elt_ty, addr_space); + } + // integer + if(ty->is_integer_ty()){ + unsigned bitwidth = ty->get_integer_bitwidth(); + return IntegerType::get(ctx, bitwidth); + } + // primitive types + switch(ty->get_type_id()){ + case ir::type::VoidTyID: return Type::getVoidTy(ctx); + case ir::type::HalfTyID: return Type::getHalfTy(ctx); + case ir::type::FloatTyID: return Type::getFloatTy(ctx); + case ir::type::DoubleTyID: return Type::getDoubleTy(ctx); + case ir::type::X86_FP80TyID: return Type::getX86_FP80Ty(ctx); + case ir::type::PPC_FP128TyID: return Type::getPPC_FP128Ty(ctx); + case ir::type::LabelTyID: return Type::getLabelTy(ctx); + case ir::type::MetadataTyID: return Type::getMetadataTy(ctx); + case ir::type::TokenTyID: return Type::getTokenTy(ctx); + default: break; + } + // unknown type + throw std::runtime_error("unknown conversion from ir::type to Type"); +} /* ------------------- @@ -410,24 +446,6 @@ void selection::create_shared_tile(ir::value *v, IRBuilder<> &builder, Value *sh } } -void selection::create_distributed_tile(ir::value *v, IRBuilder<> &builder) { - Type* ty = llvm_type(v->get_type()->get_scalar_ty(), builder.getContext()); - const auto &shapes = v->get_type()->get_tile_shapes(); - std::vector axes(shapes.size()); - for(size_t d = 0; d < shapes.size(); d++){ - if(shapes[d] > 1){ - unsigned x = a_axes_->get(v, d); - axes[d] = axes_.at(x); - } - else{ - axes[d].contiguous = 1; - axes[d].values = {builder.getInt32(0)}; - } - } - distributed_tile *T = new distributed_tile(ty, shapes, layouts_->get(v)->order, axes, builder, false); - tmap_.insert({v, T}); -} - bool is_trans(ir::value *v) { if(dynamic_cast(v)) { @@ -454,7 +472,7 @@ void selection::lower_value(ir::value *src, IRBuilder<> &builder, generator* gen if(i && layouts_->get(i)->type == analysis::SHARED) create_shared_tile(i, builder, sh_mem_ptr_); else - create_distributed_tile(src, builder); + tmap_[src] = ((machine_layout_distributed_t*)gen->get_machine_layout(layouts_->get(src)))->create(src); } builder.SetInsertPoint(current); @@ -521,7 +539,7 @@ void selection::run(ir::module &src, Module &dst) { std::set seen; // create tile - generator gen(&dst_ctx, &dst, &dst_builder, axes_, vmap_, tmap_, tgt_, layouts_, alignment_, alloc_, sh_mem_ptr_, + generator gen(&dst_ctx, &dst, &dst_builder, a_axes_, axes_, vmap_, tmap_, tgt_, layouts_, alignment_, alloc_, sh_mem_ptr_, offset_a_i_, offset_a_k_, offset_b_j_, offset_b_k_, num_packs_0_, num_packs_1_, pack_size_0_, pack_size_1_, num_warps_ ); for(ir::alloc_const *x: src.allocs()) @@ -606,7 +624,7 @@ void selection::run(ir::module &src, Module &dst) { void generator::visit_phi_node(ir::phi_node* phi) { - Type *ty = type(phi->get_type()->get_scalar_ty()); + Type *ty = type(phi->get_type()->get_scalar_ty(), *ctx_); unsigned num_ops = phi->get_num_operands(); for_each(phi, [&](indices_t idx){ set_value(phi, idx, builder_->Insert(PHINode::Create(ty, num_ops))); @@ -628,7 +646,7 @@ void generator::visit_getelementptr_inst(ir::getelementptr_inst* gep) { std::vector idx_vals; std::transform(gep->idx_begin(), gep->idx_end(), std::back_inserter(idx_vals), [&](ir::value* x){ return get_value(x, idx);}); - Type *source_ty = type(gep->get_source_elt_ty()->get_scalar_ty()); + Type *source_ty = type(gep->get_source_elt_ty()->get_scalar_ty(), *ctx_); Value *ret = builder_->Insert(GetElementPtrInst::CreateInBounds(source_ty, ptr, idx_vals)); set_value(gep, idx, ret); }); @@ -657,7 +675,7 @@ void generator::visit_fcmp_inst(ir::fcmp_inst* fcmp) { void generator::visit_cast_inst(ir::cast_inst* cast) { for_each(cast, [&](indices_t idx){ Value *arg = get_value(cast->get_operand(0), idx); - Type *dst_ty = type(cast->get_type()->get_scalar_ty()); + Type *dst_ty = type(cast->get_type()->get_scalar_ty(), *ctx_); Value *ret = builder_->Insert(CastInst::Create(llvm_op(cast->get_op()), arg, dst_ty)); set_value(cast, idx, ret); }); @@ -1102,7 +1120,7 @@ void generator::visit_dot_inst(ir::dot_inst* dot) { ir::value *D = dot->get_operand(2); distributed_tile *TD = (distributed_tile*)tmap_.at(D); - Type *c_ty = type(D->get_type()->get_scalar_ty()); + Type *c_ty = type(D->get_type()->get_scalar_ty(), *ctx_); Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {c_ty}); auto A_shapes = A->get_type()->get_tile_shapes(); size_t red_axis = 1; @@ -1228,60 +1246,25 @@ void generator::visit_make_range(ir::make_range* x) { }); } -Type *generator::type(ir::type *ty) { - // function - if(auto* tt = dynamic_cast(ty)){ - Type *return_ty = type(tt->get_return_ty()); - std::vector param_tys; - std::transform(tt->params_begin(), tt->params_end(), std::back_inserter(param_tys), - [this](ir::type* t){ return type(t);}); - return FunctionType::get(return_ty, param_tys, false); - } - // pointer - if(ty->is_pointer_ty()){ - Type *elt_ty = type(ty->get_pointer_element_ty()); - unsigned addr_space = ty->get_pointer_address_space(); - return PointerType::get(elt_ty, addr_space); - } - // integer - if(ty->is_integer_ty()){ - unsigned bitwidth = ty->get_integer_bitwidth(); - return IntegerType::get(*ctx_, bitwidth); - } - // primitive types - switch(ty->get_type_id()){ - case ir::type::VoidTyID: return Type::getVoidTy(*ctx_); - case ir::type::HalfTyID: return Type::getHalfTy(*ctx_); - case ir::type::FloatTyID: return Type::getFloatTy(*ctx_); - case ir::type::DoubleTyID: return Type::getDoubleTy(*ctx_); - case ir::type::X86_FP80TyID: return Type::getX86_FP80Ty(*ctx_); - case ir::type::PPC_FP128TyID: return Type::getPPC_FP128Ty(*ctx_); - case ir::type::LabelTyID: return Type::getLabelTy(*ctx_); - case ir::type::MetadataTyID: return Type::getMetadataTy(*ctx_); - case ir::type::TokenTyID: return Type::getTokenTy(*ctx_); - default: break; - } - // unknown type - throw std::runtime_error("unknown conversion from ir::type to Type"); -} + void generator::visit_undef_value(ir::undef_value *ud) { - vmap_[ud] = llvm::UndefValue::get(type(ud->get_type())); + vmap_[ud] = llvm::UndefValue::get(type(ud->get_type(), *ctx_)); } void generator::visit_constant_int(ir::constant_int *cst){ - Type *ty = type(cst->get_type()->get_scalar_ty()); + Type *ty = type(cst->get_type()->get_scalar_ty(), *ctx_); vmap_[cst] = ConstantInt::get(ty, cst->get_value()); } void generator::visit_constant_fp(ir::constant_fp *cst){ - Type *ty = type(cst->get_type()->get_scalar_ty()); + Type *ty = type(cst->get_type()->get_scalar_ty(), *ctx_); vmap_[cst] = ConstantFP::get(ty, cst->get_value()); } void generator::visit_alloc_const(ir::alloc_const *alloc) { unsigned size = ((ir::constant_int*)alloc->get_operand(0))->get_value(); - Type *element_ty = type(alloc->get_type()->get_pointer_element_ty()); + Type *element_ty = type(alloc->get_type()->get_pointer_element_ty(), *ctx_); Type *array_ty = llvm::ArrayType::get(element_ty, size); Value *array = new llvm::GlobalVariable(*mod_, array_ty, false, llvm::GlobalVariable::ExternalLinkage, nullptr, alloc->get_name(), nullptr, llvm::GlobalVariable::NotThreadLocal, 4); @@ -1291,7 +1274,7 @@ void generator::visit_alloc_const(ir::alloc_const *alloc) { void generator::visit_function(ir::function* fn) { LLVMContext &ctx = builder_->getContext(); - FunctionType *fn_ty = (FunctionType*)type(fn->get_fn_type()); + FunctionType *fn_ty = (FunctionType*)type(fn->get_fn_type(), *ctx_); FunctionType *dst_fn_ty = fn_ty; if(!tgt_->is_gpu()){ Type *dst_fn_ret_ty = fn_ty->getReturnType(); @@ -1331,16 +1314,86 @@ void generator::visit_function(ir::function* fn) { fn_ = ret; } +void generator::visit_layout_hmma_884(analysis::layout_hmma_884_t* layout) { + machine_layouts_[layout] = new machine_layout_hmma_884_t(mod_, builder_, tgt_, type(layout->ty->get_scalar_ty(), *ctx_), a_axes_, axes_, + offset_a_i_, offset_a_k_, offset_b_j_, offset_b_k_, + pack_size_0_, pack_size_1_, + num_packs_0_, num_packs_1_, + layout); +} + +void generator::visit_layout_scanline(analysis::layout_scanline_t* layout) { + machine_layouts_[layout] = new machine_layout_scanline_t(mod_, builder_, tgt_, type(layout->ty->get_scalar_ty(), *ctx_), a_axes_, axes_, layout); +} + +void generator::visit_layout_shared(analysis::layout_shared_t* layout) { + + machine_layouts_[layout] = new machine_layout_shared_t(); +} + +void generator::for_each(ir::value *x, const std::function& fn) { + if(!x->get_type()->is_tile_ty()) + return fn({}); + else { + if(auto *dt = dynamic_cast(tmap_.at(x))) + dt->for_each(fn); + } +} + +Value* generator::get_value(ir::value *x, const indices_t& idx) { + if(x->get_type()->is_tile_ty()) + return tmap_.at(x)->get_value(idx); + return vmap_.at(x); +} + +void generator::set_value(ir::value *x, const indices_t& idx, Value* v) { + if(x->get_type()->is_tile_ty()) + tmap_.at(x)->set_value(idx, v); + else + vmap_[x] = v; +} + + + +shared_tile* machine_layout_shared_t::create(ir::value *v) { + +} + +machine_layout_distributed_t::machine_layout_distributed_t(Module *mod, Builder *builder, target *tgt, Type *ty, + analysis::axes *a_axes, std::map& axes, + analysis::layout_t *layout) + : mod_(mod), builder_(builder), tgt_(tgt), ty_(ty), a_axes_(a_axes), axes_(axes), layout_(layout) { + +} + +distributed_tile* machine_layout_distributed_t::create(ir::value *v) { + Type *ty = type(v->get_type()->get_scalar_ty(), builder_->getContext()); + const auto &shapes = v->get_type()->get_tile_shapes(); + std::vector axes(shapes.size()); + for(size_t d = 0; d < shapes.size(); d++){ + if(shapes[d] > 1){ + unsigned x = a_axes_->get(v, d); + axes[d] = axes_.at(x); + } + else{ + axes[d].contiguous = 1; + axes[d].values = {builder_->getInt32(0)}; + } + } + return new distributed_tile(ty, shapes, layout_->order, axes, *builder_, false); +} + machine_layout_hmma_884_t::machine_layout_hmma_884_t(Module *mod, Builder *builder, - target *tgt, std::map& axes, + target *tgt, Type *ty, analysis::axes *a_axes, + std::map& axes, Value *&offset_a_i, Value *&offset_a_k, Value *&offset_b_j, Value *&offset_b_k, unsigned &pack_size_0, unsigned &pack_size_1, unsigned &num_packs_0, unsigned &num_packs_1, analysis::layout_hmma_884_t* layout) - : mod_(mod), builder_(builder), tgt_(tgt), axes_(axes), + : machine_layout_distributed_t(mod, builder, tgt, ty, a_axes, axes, layout), offset_a_i_(offset_a_i), offset_a_k_(offset_a_k), offset_b_j_(offset_b_j), offset_b_k_(offset_b_k), - pack_size_0_(pack_size_0), pack_size_1_(pack_size_1), num_packs_0_(num_packs_0), num_packs_1_(num_packs_1), - layout_(layout) { + pack_size_0_(pack_size_0), pack_size_1_(pack_size_1), num_packs_0_(num_packs_0), num_packs_1_(num_packs_1) { + Value *warp_size = builder_->getInt32(32); Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0); Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size); @@ -1454,10 +1507,11 @@ machine_layout_hmma_884_t::machine_layout_hmma_884_t(Module *mod, Builder *build machine_layout_scanline_t::machine_layout_scanline_t(Module *mod, Builder *builder, - target *tgt, std::map &axes, + target *tgt, Type *ty, + analysis::axes *a_axes, std::map &axes, analysis::layout_scanline_t* layout) - : mod_(mod), builder_(builder), tgt_(tgt), axes_(axes), layout_(layout) -{ + : machine_layout_distributed_t(mod, builder, tgt, ty, a_axes, axes, layout) { + Value *warp_size = builder_->getInt32(32); Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0); Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size); @@ -1486,44 +1540,6 @@ machine_layout_scanline_t::machine_layout_scanline_t(Module *mod, Builder *build } } -void generator::visit_layout_hmma_884(analysis::layout_hmma_884_t* layout) { - machine_layouts_[layout] = new machine_layout_hmma_884_t(mod_, builder_, tgt_, axes_, offset_a_i_, offset_a_k_, offset_b_j_, offset_b_k_, - pack_size_0_, pack_size_1_, - num_packs_0_, num_packs_1_, - layout); -} - -void generator::visit_layout_scanline(analysis::layout_scanline_t* layout) { - machine_layouts_[layout] = new machine_layout_scanline_t(mod_, builder_, tgt_, axes_, layout); -} - -void generator::visit_layout_shared(analysis::layout_shared_t* layout) { - - machine_layouts_[layout] = new machine_layout_shared_t(); -} - -void generator::for_each(ir::value *x, const std::function& fn) { - if(!x->get_type()->is_tile_ty()) - return fn({}); - else { - if(auto *dt = dynamic_cast(tmap_.at(x))) - dt->for_each(fn); - } -} - -Value* generator::get_value(ir::value *x, const indices_t& idx) { - if(x->get_type()->is_tile_ty()) - return tmap_.at(x)->get_value(idx); - return vmap_.at(x); -} - -void generator::set_value(ir::value *x, const indices_t& idx, Value* v) { - if(x->get_type()->is_tile_ty()) - tmap_.at(x)->set_value(idx, v); - else - vmap_[x] = v; -} - }