diff --git a/include/triton/codegen/analysis/allocation.h b/include/triton/codegen/analysis/allocation.h index 858152150..49f378886 100644 --- a/include/triton/codegen/analysis/allocation.h +++ b/include/triton/codegen/analysis/allocation.h @@ -27,14 +27,14 @@ public: allocation(liveness *live) : liveness_(live) { } // accessors - bool has_offset(ir::value *x) const { return offsets_.find(x) != offsets_.end(); } - unsigned offset(ir::value *x) const { return offsets_.at(x); } + bool has_offset(const layout_t *x) const { return offsets_.find(x) != offsets_.end(); } + unsigned offset(const layout_t *x) const { return offsets_.at(x); } unsigned allocated_size() const { return allocated_size_; } // run void run(ir::module& mod); private: - std::map offsets_; + std::map offsets_; size_t allocated_size_; // dependences liveness *liveness_; diff --git a/include/triton/codegen/selection.h b/include/triton/codegen/selection.h index 82466ab93..ec00cdd20 100644 --- a/include/triton/codegen/selection.h +++ b/include/triton/codegen/selection.h @@ -147,12 +147,32 @@ private: }; class machine_layout_t { +public: virtual tile* create(ir::value *v) = 0; }; class machine_layout_shared_t: public machine_layout_t { public: - shared_tile* create(ir::value *v); + machine_layout_shared_t(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc, Value *&sh_mem_ptr, analysis::layout_t* layout, + std::map& vmap, + std::map& tmap); + + tile* create(ir::value *v); + + Module *mod_; + Builder *builder_; + target *tgt_; + analysis::allocation* alloc_; + Value *&sh_mem_ptr_; + analysis::layout_t* layout_; + std::map& vmap_; + std::map& tmap_; + + Value *offset_; + Value *ptr_; + Value *pre_ptr_; + Value *next_ptr_; + }; class machine_layout_distributed_t: public machine_layout_t { @@ -161,7 +181,7 @@ public: analysis::axes *a_axes, std::map& axes, analysis::layout_t* layout); - distributed_tile* create(ir::value *v); + tile* create(ir::value *v); Module *mod_; Builder *builder_; target *tgt_; @@ -309,7 +329,72 @@ private: unsigned num_warps_; }; +class finalizer: public ir::visitor, public analysis::layout_visitor { +private: + void for_each(ir::value *x, const std::function& fn); + Value* get_value(ir::value *x, const indices_t& idx); + void set_value(ir::value *x, const indices_t& idx, Value* v); +public: + finalizer(Builder *builder, std::map& vmap, std::map& tmap); + + void visit_phi_node(ir::phi_node*); + void visit_binary_operator(ir::binary_operator*) { } + void visit_getelementptr_inst(ir::getelementptr_inst*) { } + + void visit_icmp_inst(ir::icmp_inst*) { } + void visit_fcmp_inst(ir::fcmp_inst*) { } + void visit_cast_inst(ir::cast_inst*) { } + + void visit_return_inst(ir::return_inst*) { } + void visit_cond_branch_inst(ir::cond_branch_inst*) { } + void visit_uncond_branch_inst(ir::uncond_branch_inst*) { } + + + void visit_unmasked_load_inst(ir::unmasked_load_inst*) { } + void visit_masked_load_inst(ir::masked_load_inst*) { } + void visit_unmasked_store_inst(ir::unmasked_store_inst*) { } + void visit_masked_store_inst(ir::masked_store_inst*) { } + + void visit_reshape_inst(ir::reshape_inst*) { } + void visit_splat_inst(ir::splat_inst*) { } + void visit_broadcast_inst(ir::broadcast_inst*) { } + void visit_downcast_inst(ir::downcast_inst*) { } + + void visit_get_program_id_inst(ir::get_program_id_inst*) { } + void visit_get_num_program_inst(ir::get_num_program_inst*) { } + void visit_atomic_cas_inst(ir::atomic_cas_inst*) { } + void visit_atomic_exch_inst(ir::atomic_exch_inst*) { } + void visit_atomic_add_inst(ir::atomic_add_inst*) { } + void visit_dot_inst(ir::dot_inst*) { } + void visit_trans_inst(ir::trans_inst*) { } + void visit_sqrt_inst(ir::sqrt_inst*) { } + void visit_reduce_inst(ir::reduce_inst*) { } + void visit_select_inst(ir::select_inst*) { } + + void visit_copy_to_shared_inst(ir::copy_to_shared_inst*) { } + void visit_copy_from_shared_inst(ir::copy_from_shared_inst*) { } + void visit_barrier_inst(ir::barrier_inst*) { } + void visit_make_range_dyn(ir::make_range_dyn*) { } + void visit_make_range(ir::make_range*) { } + + void visit_make_range_sta(ir::make_range_sta*) { } + void visit_undef_value(ir::undef_value*) { } + void visit_constant_int(ir::constant_int*) { } + void visit_constant_fp(ir::constant_fp*) { } + void visit_alloc_const(ir::alloc_const*) { } + + void visit_function(ir::function*) { } + + void visit_layout_hmma_884(analysis::layout_hmma_884_t*) { } + void visit_layout_scanline(analysis::layout_scanline_t*) { } + void visit_layout_shared(analysis::layout_shared_t*); + +private: + Builder *builder_; + std::map& vmap_; + std::map& tmap_; +}; // Selection pass class selection{ @@ -318,12 +403,8 @@ class selection{ private: // LLVM conversions - Type* llvm_type(ir::type *ty, LLVMContext &ctx); Value* alloc_shared(Builder &builder, Module& dst); - // grid construction - void create_shared_tile(ir::value *v, Builder &builder, Value *sh_mem_ptr); - // lower scalar instruction void lower_value(ir::value *src, Builder &builder, generator* gen, std::set& seen); diff --git a/lib/codegen/analysis/allocation.cc b/lib/codegen/analysis/allocation.cc index 2474acded..3ea0a758d 100644 --- a/lib/codegen/analysis/allocation.cc +++ b/lib/codegen/analysis/allocation.cc @@ -99,16 +99,10 @@ void allocation::run(ir::module &mod) { unsigned Adj = 0; for(layout_t* y: interferences[x]) Adj = std::max(Adj, starts[y] + y->size); - // create offsets - for(ir::value *v: x->values){ - offsets_[v] = starts[x] + colors[x] * Adj; - } - if(x->double_buffer){ - auto info = *x->double_buffer; - offsets_[info.latch] = offsets_[info.first] + x->size / 2; - } + offsets_[x] = starts[x] + colors[x] * Adj; } + // Save maximum size of induced memory space allocated_size_ = 0; for(layout_t* x: V) diff --git a/lib/codegen/selection.cc b/lib/codegen/selection.cc index be97ac2e4..96a4632f5 100644 --- a/lib/codegen/selection.cc +++ b/lib/codegen/selection.cc @@ -305,43 +305,6 @@ llvm::CmpInst::Predicate llvm_pred(ir::cmp_pred_t pred) { throw std::runtime_error("unknown operator"); } -/* convert ir::type to Type */ -Type *selection::llvm_type(ir::type *ty, LLVMContext &ctx) { - // function - if(auto* tt = dynamic_cast(ty)){ - Type *return_ty = llvm_type(tt->get_return_ty(), ctx); - std::vector param_tys; - std::transform(tt->params_begin(), tt->params_end(), std::back_inserter(param_tys), - [this,&ctx](ir::type* t){ return llvm_type(t, ctx);}); - return FunctionType::get(return_ty, param_tys, false); - } - // pointer - if(ty->is_pointer_ty()){ - Type *elt_ty = llvm_type(ty->get_pointer_element_ty(), ctx); - unsigned addr_space = ty->get_pointer_address_space(); - return PointerType::get(elt_ty, addr_space); - } - // integer - if(ty->is_integer_ty()){ - unsigned bitwidth = ty->get_integer_bitwidth(); - return IntegerType::get(ctx, bitwidth); - } - // primitive types - switch(ty->get_type_id()){ - case ir::type::VoidTyID: return Type::getVoidTy(ctx); - case ir::type::HalfTyID: return Type::getHalfTy(ctx); - case ir::type::FloatTyID: return Type::getFloatTy(ctx); - case ir::type::DoubleTyID: return Type::getDoubleTy(ctx); - case ir::type::X86_FP80TyID: return Type::getX86_FP80Ty(ctx); - case ir::type::PPC_FP128TyID: return Type::getPPC_FP128Ty(ctx); - case ir::type::LabelTyID: return Type::getLabelTy(ctx); - case ir::type::MetadataTyID: return Type::getMetadataTy(ctx); - case ir::type::TokenTyID: return Type::getTokenTy(ctx); - default: break; - } - // unknown type - throw std::runtime_error("unknown conversion from ir::type to Type"); -} Type *type(ir::type *ty, LLVMContext &ctx) { // function @@ -407,45 +370,6 @@ inline int32_t ceil(int32_t num, int32_t div){ * ---- Init Tiles ---- * ------------------- */ -void selection::create_shared_tile(ir::value *v, IRBuilder<> &builder, Value *sh_mem_ptr) { - if(tmap_.find(v) != tmap_.end()) - return; - analysis::layout_shared_t *layout = (analysis::layout_shared_t*)layouts_->get(v); - auto order = layout->order; - auto shapes = layout->shapes; - shapes[order[0]] += layout->pad; - - Type* ty = llvm_type(v->get_type()->get_scalar_ty(), builder.getContext()); - // shared copy - PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr->getType()->getPointerAddressSpace()); - // double-buffered - if(layout->double_buffer) { - auto info = *layout->double_buffer; - ir::phi_node *phi = info.phi; - BasicBlock *parent = (BasicBlock*)vmap_[phi->get_parent()]; - if(parent->empty()) - builder.SetInsertPoint(parent); - else - builder.SetInsertPoint(&*parent->getFirstNonPHI()); - // create double-buffered pointer - PHINode *ptr = builder.CreatePHI(ptr_ty, 2); - PHINode *offset = builder.CreatePHI(builder.getInt32Ty(), 2); - // next pointer - Value *pre_ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(alloc_->offset(v))); - pre_ptr = builder.CreateBitCast(pre_ptr, ptr->getType()); - Value *next_ptr = builder.CreateGEP(ptr, offset, "next_ptr"); - tmap_.insert({phi, new shared_tile(ty, shapes, order, ptr, builder, offset)}); - tmap_.insert({v, new shared_tile(ty, shapes, order, pre_ptr, builder)}); - tmap_.insert({info.latch, new shared_tile(ty, shapes, order, next_ptr, builder)}); - } - else { - size_t offset = alloc_->offset(v); - Value *ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(offset)); - ptr = builder.CreateBitCast(ptr, ptr_ty); - tmap_.insert({v, new shared_tile(ty, shapes, order, ptr, builder)}); - } -} - bool is_trans(ir::value *v) { if(dynamic_cast(v)) { @@ -465,18 +389,11 @@ void selection::lower_value(ir::value *src, IRBuilder<> &builder, generator* gen if(!seen.insert(src).second) return; + if(src->get_type()->is_tile_ty()) + tmap_[src] = gen->get_machine_layout(layouts_->get(src))->create(src); + + BasicBlock *current = builder.GetInsertBlock(); - if(src->get_type()->is_tile_ty()){ - builder.SetInsertPoint(&*builder.GetInsertBlock()->getParent()->begin()); - auto *i = dynamic_cast(src); - if(i && layouts_->get(i)->type == analysis::SHARED) - create_shared_tile(i, builder, sh_mem_ptr_); - else - tmap_[src] = ((machine_layout_distributed_t*)gen->get_machine_layout(layouts_->get(src)))->create(src); - } - builder.SetInsertPoint(current); - - auto *inst = dynamic_cast(src); if(inst && !dynamic_cast(src)) for(ir::value *op: inst->ops()) @@ -541,6 +458,7 @@ void selection::run(ir::module &src, Module &dst) { // create tile generator gen(&dst_ctx, &dst, &dst_builder, a_axes_, axes_, vmap_, tmap_, tgt_, layouts_, alignment_, alloc_, sh_mem_ptr_, offset_a_i_, offset_a_k_, offset_b_j_, offset_b_k_, num_packs_0_, num_packs_1_, pack_size_0_, pack_size_1_, num_warps_ ); + finalizer fin(&dst_builder, vmap_, tmap_); for(ir::alloc_const *x: src.allocs()) x->accept(&gen); @@ -554,69 +472,25 @@ void selection::run(ir::module &src, Module &dst) { x.second->accept(&gen); // generate LLVM-IR code - std::map last_block; for(ir::basic_block *block: fn->blocks()) { BasicBlock *parent = (BasicBlock*)vmap_[block]; dst_builder.SetInsertPoint(parent); for(ir::instruction *i: block->get_inst_list()) lower_value(i, dst_builder, &gen, seen); - last_block[block] = dst_builder.GetInsertBlock(); + vmap_[block] = dst_builder.GetInsertBlock(); } // finalize double-buffering - for(const auto& x: layouts_->get_all()) { - if(x.second->double_buffer) { - auto info = *x.second->double_buffer; - ir::phi_node *phi = info.phi; - PHINode *ptr = (PHINode*)((shared_tile*)tmap_.at(phi))->get_pointer(); - PHINode *offset = (PHINode*)((shared_tile*)tmap_.at(phi))->get_offset(); - for(unsigned n = 0; n < phi->get_num_incoming(); n++){ - ir::basic_block* inc_block = phi->get_incoming_block(n); - ir::value* inc_val = phi->get_incoming_value(n); - BasicBlock *llvm_inc_block = last_block.at(inc_block); - shared_tile *inc_shared = (shared_tile*)tmap_.at(inc_val); - if(inc_val == info.latch){ - dst_builder.SetInsertPoint(llvm_inc_block->getTerminator()); - Value *next_offset = dst_builder.CreateNeg(offset); - offset->addIncoming(next_offset, llvm_inc_block); - } - else { - unsigned num_bytes = x.second->ty->get_primitive_size_in_bits() / 8; - offset->addIncoming(dst_builder.getInt32(x.second->size / (2*num_bytes)), llvm_inc_block); - } - ptr->addIncoming(inc_shared->get_pointer(), llvm_inc_block); - } - } - } + for(const auto& x: layouts_->get_all()) + x.second->accept(&fin); // finalize phi for(ir::basic_block *block: fn->blocks()) for(ir::instruction *inst: block->get_inst_list()) - if(auto *phi = dynamic_cast(inst)){ - if(tmap_.find(phi) == tmap_.end() || - !dynamic_cast(tmap_.at(phi))) { - for(unsigned n = 0; n < phi->get_num_incoming(); n++){ - ir::value *inc_val = phi->get_incoming_value(n); - ir::basic_block *inc_block = phi->get_incoming_block(n); - BasicBlock *llvm_inc_block = last_block.at(inc_block); - if(phi->get_type()->is_tile_ty()) { - distributed_tile *phi_tile = (distributed_tile*)tmap_.at(phi); - distributed_tile *inc_tile = (distributed_tile*)tmap_.at(inc_val); - phi_tile->for_each([&](indices_t idx){ - PHINode *llvm_phi = (PHINode*)phi_tile->get_value(idx); - Value *llvm_inc_val = inc_tile->get_value(idx); - llvm_phi->addIncoming(llvm_inc_val, llvm_inc_block); - }); - } - else { - PHINode *llvm_phi = (PHINode*)vmap_.at(phi); - Value *llvm_inc_val = vmap_.at(inc_val); - llvm_phi->addIncoming(llvm_inc_val, llvm_inc_block); - } - } - } - } + inst->accept(&fin); + } + } @@ -895,7 +769,7 @@ void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) { Value *pred = builder_->CreateICmpEQ(tid, builder_->getInt32(0)); BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent()); BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent()); - Value *ptr = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(alloc_->offset(cas))); + Value *ptr = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(alloc_->offset(layouts_->get(cas)))); ptr = builder_->CreateBitCast(ptr, PointerType::get(builder_->getInt32Ty(), ptr->getType()->getPointerAddressSpace())); tgt_->add_memfence(module, *builder_); tgt_->add_barrier(module, *builder_); @@ -1328,7 +1202,7 @@ void generator::visit_layout_scanline(analysis::layout_scanline_t* layout) { void generator::visit_layout_shared(analysis::layout_shared_t* layout) { - machine_layouts_[layout] = new machine_layout_shared_t(); + machine_layouts_[layout] = new machine_layout_shared_t(mod_, builder_, tgt_, alloc_, sh_mem_ptr_, layout, vmap_, tmap_); } void generator::for_each(ir::value *x, const std::function& fn) { @@ -1355,8 +1229,61 @@ void generator::set_value(ir::value *x, const indices_t& idx, Value* v) { -shared_tile* machine_layout_shared_t::create(ir::value *v) { +machine_layout_shared_t::machine_layout_shared_t(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc, + Value *&sh_mem_ptr, analysis::layout_t *layout, + std::map& vmap, + std::map& tmap) + : mod_(mod), builder_(builder), tgt_(tgt), alloc_(alloc), sh_mem_ptr_(sh_mem_ptr), layout_(layout), vmap_(vmap), tmap_(tmap) { + auto order = layout_->order; + auto shapes = layout_->shapes; + shapes[order[0]] += layout_->pad; + + Type* ty = type(layout_->ty, builder_->getContext()); + + PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr_->getType()->getPointerAddressSpace()); + // double-buffered + if(layout_->double_buffer) { + BasicBlock *current = builder_->GetInsertBlock(); + auto info = *layout_->double_buffer; + ir::phi_node *phi = info.phi; + BasicBlock *parent = (BasicBlock*)vmap_.at(phi->get_parent()); + if(parent->empty()) + builder_->SetInsertPoint(parent); + else + builder_->SetInsertPoint(&*parent->getFirstNonPHI()); + // create pointers + ptr_ = builder_->CreatePHI(ptr_ty, 2); + pre_ptr_ = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(alloc_->offset(layout_))); + pre_ptr_ = builder_->CreateBitCast(pre_ptr_, ptr_->getType()); + offset_ = builder_->CreatePHI(builder_->getInt32Ty(), 2); + next_ptr_ = builder_->CreateGEP(ptr_, offset_, "next_ptr"); + builder_->SetInsertPoint(current); + } + else{ + size_t offset = alloc_->offset(layout_); + ptr_ = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(offset)); + ptr_ = builder_->CreateBitCast(ptr_, ptr_ty); + } +} + + +tile* machine_layout_shared_t::create(ir::value *v) { + auto order = layout_->order; + auto shapes = layout_->shapes; + shapes[order[0]] += layout_->pad; + Type* ty = type(layout_->ty, builder_->getContext()); + // double-buffered + if(layout_->double_buffer) { + if(v == layout_->double_buffer->phi) + return new shared_tile(ty, shapes, order, ptr_, *builder_, offset_); + if(v == layout_->double_buffer->latch) + return new shared_tile(ty, shapes, order, next_ptr_, *builder_); + return new shared_tile(ty, shapes, order, pre_ptr_, *builder_); + } + else { + return new shared_tile(ty, shapes, order, ptr_, *builder_); + } } machine_layout_distributed_t::machine_layout_distributed_t(Module *mod, Builder *builder, target *tgt, Type *ty, @@ -1366,7 +1293,7 @@ machine_layout_distributed_t::machine_layout_distributed_t(Module *mod, Builder } -distributed_tile* machine_layout_distributed_t::create(ir::value *v) { +tile *machine_layout_distributed_t::create(ir::value *v) { Type *ty = type(v->get_type()->get_scalar_ty(), builder_->getContext()); const auto &shapes = v->get_type()->get_tile_shapes(); std::vector axes(shapes.size()); @@ -1540,6 +1467,74 @@ machine_layout_scanline_t::machine_layout_scanline_t(Module *mod, Builder *build } } +finalizer::finalizer(Builder *builder, std::map& vmap, std::map& tmap) + : builder_(builder), vmap_(vmap), tmap_(tmap) { + +} + +void finalizer::for_each(ir::value *x, const std::function& fn) { + if(!x->get_type()->is_tile_ty()) + return fn({}); + else { + if(auto *dt = dynamic_cast(tmap_.at(x))) + dt->for_each(fn); + } +} + +Value* finalizer::get_value(ir::value *x, const indices_t& idx) { + if(x->get_type()->is_tile_ty()) + return tmap_.at(x)->get_value(idx); + return vmap_.at(x); +} + +void finalizer::set_value(ir::value *x, const indices_t& idx, Value* v) { + if(x->get_type()->is_tile_ty()) + tmap_.at(x)->set_value(idx, v); + else + vmap_[x] = v; +} + +void finalizer::visit_phi_node(ir::phi_node* phi) { + auto it = tmap_.find(phi); + if(it != tmap_.end() && dynamic_cast(it->second)) + return; + for(unsigned n = 0; n < phi->get_num_incoming(); n++){ + ir::basic_block *inc_block = phi->get_incoming_block(n); + BasicBlock *llvm_inc_block = (BasicBlock*)vmap_.at(inc_block); + for_each(phi, [&](indices_t idx){ + PHINode *llvm_phi = (PHINode*)get_value(phi, idx); + Value *llvm_inc_val = get_value(phi->get_incoming_value(n), idx); + llvm_phi->addIncoming(llvm_inc_val, llvm_inc_block); + }); + } +} + + +void finalizer::visit_layout_shared(analysis::layout_shared_t* layout) { + if(layout->double_buffer) { + auto info = *layout->double_buffer; + ir::phi_node *phi = info.phi; + PHINode *ptr = (PHINode*)((shared_tile*)tmap_.at(phi))->get_pointer(); + PHINode *offset = (PHINode*)((shared_tile*)tmap_.at(phi))->get_offset(); + for(unsigned n = 0; n < phi->get_num_incoming(); n++){ + ir::basic_block* inc_block = phi->get_incoming_block(n); + ir::value* inc_val = phi->get_incoming_value(n); + BasicBlock *llvm_inc_block = (BasicBlock*)vmap_.at(inc_block); + shared_tile *inc_shared = (shared_tile*)tmap_.at(inc_val); + if(inc_val == info.latch){ + builder_->SetInsertPoint(llvm_inc_block->getTerminator()); + Value *next_offset = builder_->CreateNeg(offset); + offset->addIncoming(next_offset, llvm_inc_block); + } + else { + unsigned num_bytes = layout->ty->get_primitive_size_in_bits() / 8; + offset->addIncoming(builder_->getInt32(layout->size / (2*num_bytes)), llvm_inc_block); + } + ptr->addIncoming(inc_shared->get_pointer(), llvm_inc_block); + } + } +} + } diff --git a/lib/codegen/transform/membar.cc b/lib/codegen/transform/membar.cc index 9a8ad7fd2..8c2f3d909 100644 --- a/lib/codegen/transform/membar.cc +++ b/lib/codegen/transform/membar.cc @@ -37,8 +37,10 @@ void membar::add_reference(ir::value *v, interval_vec_t &res){ auto *i = dynamic_cast(v); if(!i) return; - if(alloc_->has_offset(v)){ - unsigned offset = alloc_->offset(v); + if(!i->get_type()->is_tile_ty()) + return; + if(alloc_->has_offset(layouts_->get(v))){ + unsigned offset = alloc_->offset(layouts_->get(v)); unsigned size = layouts_->get(v)->size; res.push_back(interval_t(offset, offset + size)); } diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index 37b14145f..b83ea8442 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -242,11 +242,6 @@ std::unique_ptr function::make_bin(ir::module &module, driver::c if(allocation.allocated_size() > context->device()->max_shared_memory()) return std::unique_ptr(); barriers.run(module); - dce.run(module); - align.run(module); - axes.run(module); - layouts.run(module); - liveness.run(module); // ir::print(module, std::cout); selection.run(module, *llvm); // return binary