Patch summary (commentary placed before the first "diff --git" header; it is
not part of any hunk):

  * include/triton/codegen/selection.h
      - declares generator::finalize_shared_layout(analysis::layout_shared_t*)
      - adds a blank line before the "Selection pass" comment
  * lib/codegen/selection.cc, selection::run
      - renames the local generator object from `visitor` to `gen`; the
        constructor arguments are unchanged
      - adds blank lines and replaces the "visit" comment with "create tile"
  * lib/codegen/selection.cc, generator::finalize_shared_layout (new)
      - does nothing unless shared->double_buffer is set
      - looks up the tile mapped to the double-buffer phi and treats its
        pointer and offset as PHI nodes
      - for every incoming (block, value) pair of the phi it adds one incoming
        entry to each of those two PHI nodes:
          offset: if the incoming value is info.latch, the negation of the
                  offset, emitted before the incoming block's terminator;
                  otherwise the constant shared->size / (2 * num_bytes), where
                  num_bytes is the primitive size of shared->ty in bytes
          ptr:    the pointer of the tile mapped to the incoming value
  * lib/codegen/selection.cc, generator::finalize_function
      - calls finalize_shared_layout only for layouts that dynamic_cast to
        analysis::layout_shared_t*, instead of calling visit_layout on every
        layout
  * lib/codegen/selection.cc, remaining hunks
      - pointer-style spelling change (`T* x` -> `T *x`) in two signatures
      - removes two blank lines near the end of the file

NOTE(review): in the copy of this patch that was reviewed, all line breaks were
collapsed and the template argument of the dynamic_cast added to
finalize_function was missing. The line structure below is reconstructed to
match the hunk header counts (indentation is a best effort), and the cast is
restored as analysis::layout_shared_t* to agree with the parameter type of
finalize_shared_layout -- confirm against the original commit.
NOTE(review): the context lines `std::set seen_;`, `typedef std::map vmap_t;`
and `dynamic_cast(it->second)` also look like they lost template arguments.
They are left exactly as found; check them against the target files before
applying.

diff --git a/include/triton/codegen/selection.h b/include/triton/codegen/selection.h
index dfdf48ca1..02deedff6 100644
--- a/include/triton/codegen/selection.h
+++ b/include/triton/codegen/selection.h
@@ -225,6 +225,7 @@ private:
   void visit_outer_dot(ir::dot_inst*, distributed_tile *TA, distributed_tile *TB, distributed_tile *TD, unsigned NK,
                        Type *c_ty, Function *f_mul_add);
 
+  void finalize_shared_layout(analysis::layout_shared_t*);
   void finalize_function(ir::function*);
   void finalize_phi_node(ir::phi_node*);
 
@@ -322,6 +323,7 @@ private:
   std::set seen_;
 };
 
+
 // Selection pass
 class selection{
   typedef std::map vmap_t;
diff --git a/lib/codegen/selection.cc b/lib/codegen/selection.cc
index 61e4d9bdd..dfa12ba17 100644
--- a/lib/codegen/selection.cc
+++ b/lib/codegen/selection.cc
@@ -420,16 +420,20 @@ Value* selection::alloc_shared(IRBuilder<> &builder, Module& dst) {
 void selection::run(ir::module &src, Module &dst) {
   vmap_.clear();
   tmap_.clear();
+
   LLVMContext &ctx = dst.getContext();
   IRBuilder<> builder(ctx);
+
   // allocate shared memory
   Value *sh_mem_ptr = alloc_shared(builder, dst);
-  // visit
-  generator visitor(&ctx, &dst, &builder, a_axes_, axes_, vmap_, tmap_, tgt_, layouts_, alignment_, alloc_, sh_mem_ptr, num_warps_ );
+
+  // create tile
+  generator gen(&ctx, &dst, &builder, a_axes_, axes_, vmap_, tmap_, tgt_, layouts_, alignment_, alloc_, sh_mem_ptr, num_warps_ );
+
   for(ir::alloc_const *x: src.allocs())
-    visitor.visit_value(x);
+    gen.visit_value(x);
   for(ir::function *fn: src.get_function_list())
-    visitor.visit_value(fn);
+    gen.visit_value(fn);
 }
 
 
@@ -1441,10 +1445,36 @@ machine_layout_scanline_t::machine_layout_scanline_t(Module *mod, Builder *build
   }
 }
 
-void generator::finalize_function(ir::function* fn) {
+void generator::finalize_shared_layout(analysis::layout_shared_t *shared) {
+  if(shared->double_buffer) {
+    auto info = *shared->double_buffer;
+    ir::phi_node *phi = info.phi;
+    PHINode *ptr = (PHINode*)((shared_tile*)tmap_.at(phi))->get_pointer();
+    PHINode *offset = (PHINode*)((shared_tile*)tmap_.at(phi))->get_offset();
+    for(unsigned n = 0; n < phi->get_num_incoming(); n++){
+      ir::basic_block* inc_block = phi->get_incoming_block(n);
+      ir::value* inc_val = phi->get_incoming_value(n);
+      BasicBlock *llvm_inc_block = (BasicBlock*)vmap_.at(inc_block);
+      shared_tile *inc_shared = (shared_tile*)tmap_.at(inc_val);
+      if(inc_val == info.latch){
+        builder_->SetInsertPoint(llvm_inc_block->getTerminator());
+        Value *next_offset = builder_->CreateNeg(offset);
+        offset->addIncoming(next_offset, llvm_inc_block);
+      }
+      else {
+        unsigned num_bytes = shared->ty->get_primitive_size_in_bits() / 8;
+        offset->addIncoming(builder_->getInt32(shared->size / (2*num_bytes)), llvm_inc_block);
+      }
+      ptr->addIncoming(inc_shared->get_pointer(), llvm_inc_block);
+    }
+  }
+}
+
+void generator::finalize_function(ir::function *fn) {
   // finalize double-buffering
   for(const auto& x: layouts_->get_all())
-    visit_layout(x.second);
+    if(auto *shared = dynamic_cast<analysis::layout_shared_t*>(x.second))
+      finalize_shared_layout(shared);
   // finalize phi
   for(ir::basic_block *block: fn->blocks())
   for(ir::instruction *inst: block->get_inst_list())
@@ -1452,7 +1482,7 @@ void generator::finalize_function(ir::function* fn) {
       finalize_phi_node(phi);
 }
 
-void generator::finalize_phi_node(ir::phi_node* phi) {
+void generator::finalize_phi_node(ir::phi_node *phi) {
   auto it = tmap_.find(phi);
   if(it != tmap_.end() && dynamic_cast(it->second))
     return;
@@ -1467,7 +1497,5 @@ void generator::finalize_phi_node(ir::phi_node* phi) {
   }
 }
 
-
-
 }
 }