From ae246218252a893bf62bc2c3e37f9ee6b9527a27 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 17 Oct 2019 00:36:46 -0400 Subject: [PATCH] more cleaning --- include/triton/codegen/analysis/layout.h | 2 + include/triton/codegen/selection.h | 118 ++----- include/triton/ir/basic_block.h | 4 + include/triton/ir/function.h | 2 + include/triton/ir/value.h | 4 +- include/triton/ir/visitor.h | 10 + lib/codegen/analysis/layout.cc | 5 + lib/codegen/selection.cc | 372 +++++++++-------------- lib/ir/function.cc | 4 + lib/ir/value.cc | 6 + 10 files changed, 205 insertions(+), 322 deletions(-) diff --git a/include/triton/codegen/analysis/layout.h b/include/triton/codegen/analysis/layout.h index 923e13411..70260542a 100644 --- a/include/triton/codegen/analysis/layout.h +++ b/include/triton/codegen/analysis/layout.h @@ -36,6 +36,7 @@ struct double_buffer_info_t { }; class layout_visitor; +class layout_t; class layout_hmma_884_t; class layout_scanline_t; class layout_shared_t; @@ -43,6 +44,7 @@ class layout_shared_t; class layout_visitor { public: + virtual void visit_layout(layout_t *); virtual void visit_layout_hmma_884(layout_hmma_884_t*) = 0; virtual void visit_layout_scanline(layout_scanline_t*) = 0; virtual void visit_layout_shared(layout_shared_t*) = 0; diff --git a/include/triton/codegen/selection.h b/include/triton/codegen/selection.h index ec00cdd20..dfdf48ca1 100644 --- a/include/triton/codegen/selection.h +++ b/include/triton/codegen/selection.h @@ -197,16 +197,13 @@ public: machine_layout_hmma_884_t(Module *mod, Builder *builder, target *tgt, Type *ty, analysis::axes *a_axes, std::map& axes, - Value *&offset_a_i, Value *&offset_a_k, Value *&offset_b_j, Value *&offset_b_k, - unsigned &pack_size_0, unsigned &pack_size_1, - unsigned &num_packs_0, unsigned &num_packs_1, analysis::layout_hmma_884_t* layout); - Value *&offset_a_i_, *&offset_a_k_; - Value *&offset_b_j_, *&offset_b_k_; - unsigned &pack_size_0_; - unsigned& pack_size_1_; - unsigned &num_packs_0_; - unsigned& num_packs_1_; + Value *offset_a_i_, *offset_a_k_; + Value *offset_b_j_, *offset_b_k_; + unsigned pack_size_0_; + unsigned pack_size_1_; + unsigned num_packs_0_; + unsigned num_packs_1_; }; class machine_layout_scanline_t: public machine_layout_distributed_t { @@ -219,15 +216,18 @@ public: class generator: public ir::visitor, public analysis::layout_visitor { private: - void visit_hmma_dot(ir::dot_inst*, distributed_tile *TC, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK); - void visit_scanline_dot(ir::dot_inst*, distributed_tile *TC, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK, Type *c_ty, Function *f_mul_add); - void visit_outer_dot(ir::dot_inst*, distributed_tile *TC, distributed_tile *TA, distributed_tile *TB, distributed_tile *TD, unsigned NK, - Type *c_ty, Function *f_mul_add); - void for_each(ir::value *x, const std::function& fn); Value* get_value(ir::value *x, const indices_t& idx); void set_value(ir::value *x, const indices_t& idx, Value* v); + void visit_hmma_dot(ir::dot_inst*, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK); + void visit_scanline_dot(ir::dot_inst*, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK, Type *c_ty, Function *f_mul_add); + void visit_outer_dot(ir::dot_inst*, distributed_tile *TA, distributed_tile *TB, distributed_tile *TD, unsigned NK, + Type *c_ty, Function *f_mul_add); + + void finalize_function(ir::function*); + void finalize_phi_node(ir::phi_node*); + public: generator(LLVMContext *ctx, Module *dst, @@ -241,18 +241,12 @@ public: analysis::align *alignment, analysis::allocation *alloc, Value *sh_mem_ptr, - Value *offset_a_i, Value *offset_a_k, - Value *offset_b_j, Value *offset_b_k, - unsigned num_packs_0, unsigned num_packs_1, - unsigned pack_size_0, unsigned pack_size_1, unsigned num_warps) : ctx_(ctx), mod_(dst), builder_(builder), a_axes_(a_axes), axes_(axes), vmap_(vmap), tmap_(tmap), tgt_(tgt), layouts_(layouts), alignment_(alignment), alloc_(alloc), sh_mem_ptr_(sh_mem_ptr), - offset_a_i_(offset_a_i), offset_a_k_(offset_a_k), offset_b_j_(offset_b_j), offset_b_k_(offset_b_k), - num_packs_0_(num_packs_0), num_packs_1_(num_packs_1), pack_size_0_(pack_size_0), pack_size_1_(pack_size_1), num_warps_(num_warps) { } - machine_layout_t *get_machine_layout(const analysis::layout_t *layout) { return machine_layouts_.at(layout); } + void visit_value(ir::value* v); void visit_phi_node(ir::phi_node*); void visit_binary_operator(ir::binary_operator*); @@ -301,6 +295,8 @@ public: void visit_alloc_const(ir::alloc_const*); void visit_function(ir::function*); + void visit_basic_block(ir::basic_block*); + void visit_argument(ir::argument*); void visit_layout_hmma_884(analysis::layout_hmma_884_t*); void visit_layout_scanline(analysis::layout_scanline_t*); @@ -308,7 +304,6 @@ public: private: LLVMContext *ctx_; - Function *fn_; Builder *builder_; Module *mod_; @@ -322,78 +317,9 @@ private: analysis::align *alignment_; analysis::allocation *alloc_; Value *sh_mem_ptr_; - Value *offset_a_i_, *offset_a_k_; - Value *offset_b_j_, *offset_b_k_; - unsigned num_packs_0_, num_packs_1_; - unsigned pack_size_0_, pack_size_1_; unsigned num_warps_; -}; -class finalizer: public ir::visitor, public analysis::layout_visitor { -private: - void for_each(ir::value *x, const std::function& fn); - Value* get_value(ir::value *x, const indices_t& idx); - void set_value(ir::value *x, const indices_t& idx, Value* v); - -public: - finalizer(Builder *builder, std::map& vmap, std::map& tmap); - - void visit_phi_node(ir::phi_node*); - void visit_binary_operator(ir::binary_operator*) { } - void visit_getelementptr_inst(ir::getelementptr_inst*) { } - - void visit_icmp_inst(ir::icmp_inst*) { } - void visit_fcmp_inst(ir::fcmp_inst*) { } - void visit_cast_inst(ir::cast_inst*) { } - - void visit_return_inst(ir::return_inst*) { } - void visit_cond_branch_inst(ir::cond_branch_inst*) { } - void visit_uncond_branch_inst(ir::uncond_branch_inst*) { } - - - void visit_unmasked_load_inst(ir::unmasked_load_inst*) { } - void visit_masked_load_inst(ir::masked_load_inst*) { } - void visit_unmasked_store_inst(ir::unmasked_store_inst*) { } - void visit_masked_store_inst(ir::masked_store_inst*) { } - - void visit_reshape_inst(ir::reshape_inst*) { } - void visit_splat_inst(ir::splat_inst*) { } - void visit_broadcast_inst(ir::broadcast_inst*) { } - void visit_downcast_inst(ir::downcast_inst*) { } - - void visit_get_program_id_inst(ir::get_program_id_inst*) { } - void visit_get_num_program_inst(ir::get_num_program_inst*) { } - void visit_atomic_cas_inst(ir::atomic_cas_inst*) { } - void visit_atomic_exch_inst(ir::atomic_exch_inst*) { } - void visit_atomic_add_inst(ir::atomic_add_inst*) { } - void visit_dot_inst(ir::dot_inst*) { } - void visit_trans_inst(ir::trans_inst*) { } - void visit_sqrt_inst(ir::sqrt_inst*) { } - void visit_reduce_inst(ir::reduce_inst*) { } - void visit_select_inst(ir::select_inst*) { } - - void visit_copy_to_shared_inst(ir::copy_to_shared_inst*) { } - void visit_copy_from_shared_inst(ir::copy_from_shared_inst*) { } - void visit_barrier_inst(ir::barrier_inst*) { } - void visit_make_range_dyn(ir::make_range_dyn*) { } - void visit_make_range(ir::make_range*) { } - - void visit_make_range_sta(ir::make_range_sta*) { } - void visit_undef_value(ir::undef_value*) { } - void visit_constant_int(ir::constant_int*) { } - void visit_constant_fp(ir::constant_fp*) { } - void visit_alloc_const(ir::alloc_const*) { } - - void visit_function(ir::function*) { } - - void visit_layout_hmma_884(analysis::layout_hmma_884_t*) { } - void visit_layout_scanline(analysis::layout_scanline_t*) { } - void visit_layout_shared(analysis::layout_shared_t*); - -private: - Builder *builder_; - std::map& vmap_; - std::map& tmap_; + std::set seen_; }; // Selection pass @@ -405,9 +331,6 @@ private: // LLVM conversions Value* alloc_shared(Builder &builder, Module& dst); - // lower scalar instruction - void lower_value(ir::value *src, Builder &builder, generator* gen, std::set& seen); - public: selection(analysis::liveness* liveness, analysis::allocation *alloc, analysis::align *alignment, analysis::axes *axes, @@ -428,11 +351,6 @@ private: analysis::align *alignment_; target *tgt_; std::map axes_; - Value *sh_mem_ptr_; - Value *offset_a_i_, *offset_a_k_; - Value *offset_b_j_, *offset_b_k_; - unsigned num_packs_0_, num_packs_1_; - unsigned pack_size_0_, pack_size_1_; unsigned num_warps_; }; diff --git a/include/triton/ir/basic_block.h b/include/triton/ir/basic_block.h index 4a60586f0..3d274815a 100644 --- a/include/triton/ir/basic_block.h +++ b/include/triton/ir/basic_block.h @@ -6,6 +6,7 @@ #include #include #include "value.h" +#include "visitor.h" namespace triton{ namespace ir{ @@ -66,6 +67,9 @@ public: // factory functions static basic_block* create(context &ctx, const std::string &name, function *parent); + // visitor + void accept(visitor *v) { v->visit_basic_block(this); } + private: context &ctx_; std::string name_; diff --git a/include/triton/ir/function.h b/include/triton/ir/function.h index 8cf11275a..d3ebe199b 100644 --- a/include/triton/ir/function.h +++ b/include/triton/ir/function.h @@ -26,6 +26,8 @@ public: function* get_parent() const; unsigned get_arg_no() const; + void accept(visitor *v); + private: function *parent_; unsigned arg_no_; diff --git a/include/triton/ir/value.h b/include/triton/ir/value.h index bf4a4aa9c..e192a54ef 100644 --- a/include/triton/ir/value.h +++ b/include/triton/ir/value.h @@ -33,6 +33,8 @@ public: void set_name(const std::string &name); const std::string &get_name() const { return name_; } type* get_type() const { return ty_; } + // visitor + virtual void accept(visitor *v) = 0; private: std::string name_; @@ -75,8 +77,6 @@ public: void replace_all_uses_with(value *target); void replace_uses_of_with(value *before, value *after); - // Visitor - virtual void accept(visitor *v) = 0; private: ops_t ops_; diff --git a/include/triton/ir/visitor.h b/include/triton/ir/visitor.h index e2310e94a..62e63e6c4 100644 --- a/include/triton/ir/visitor.h +++ b/include/triton/ir/visitor.h @@ -7,6 +7,8 @@ namespace triton{ namespace ir{ +class value; + class instruction; class phi_node; @@ -81,10 +83,18 @@ class alloc_const; class function; +class basic_block; + +class argument; + class visitor { public: virtual ~visitor() {} + virtual void visit_value(ir::value*); + + virtual void visit_basic_block(basic_block*) = 0; + virtual void visit_argument(argument*) = 0; virtual void visit_phi_node(phi_node*) = 0; virtual void visit_binary_operator(binary_operator*) = 0; virtual void visit_getelementptr_inst(getelementptr_inst*) = 0; diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc index f435efef8..dc43f8ea6 100644 --- a/lib/codegen/analysis/layout.cc +++ b/lib/codegen/analysis/layout.cc @@ -124,6 +124,10 @@ inline bool is_trans(ir::value *v) { } +void layout_visitor::visit_layout(layout_t *layout) { + layout->accept(this); +} + layout_t::layout_t(layout_type_t _type, const std::vector &_axes, @@ -145,6 +149,7 @@ layout_t::layout_t(layout_type_t _type, } } + inline unsigned clamp(unsigned x, unsigned lo, unsigned hi) { return std::min(std::max(x, lo), hi); } diff --git a/lib/codegen/selection.cc b/lib/codegen/selection.cc index 96a4632f5..61e4d9bdd 100644 --- a/lib/codegen/selection.cc +++ b/lib/codegen/selection.cc @@ -385,31 +385,6 @@ bool is_trans(ir::value *v) { } -void selection::lower_value(ir::value *src, IRBuilder<> &builder, generator* gen, std::set& seen) { - if(!seen.insert(src).second) - return; - - if(src->get_type()->is_tile_ty()) - tmap_[src] = gen->get_machine_layout(layouts_->get(src))->create(src); - - - BasicBlock *current = builder.GetInsertBlock(); - auto *inst = dynamic_cast(src); - if(inst && !dynamic_cast(src)) - for(ir::value *op: inst->ops()) - lower_value(op, builder, gen, seen); - - builder.SetInsertPoint(current); - auto *phi = dynamic_cast(src); - if(phi && !current->empty() && current->getFirstNonPHI()) - builder.SetInsertPoint(&*current->getFirstNonPHI()); - - if(auto *usr = dynamic_cast(src)) - usr->accept(gen); - - if(phi && !current->empty() && current->getFirstNonPHI()) - builder.SetInsertPoint(current); -} /* ---------------------------- * ---- Generate LLVM code ---- @@ -445,57 +420,44 @@ Value* selection::alloc_shared(IRBuilder<> &builder, Module& dst) { void selection::run(ir::module &src, Module &dst) { vmap_.clear(); tmap_.clear(); - - LLVMContext &dst_ctx = dst.getContext(); - IRBuilder<> dst_builder(dst_ctx); - + LLVMContext &ctx = dst.getContext(); + IRBuilder<> builder(ctx); // allocate shared memory - sh_mem_ptr_ = alloc_shared(dst_builder, dst); - - // iterate over functions - std::set seen; - - // create tile - generator gen(&dst_ctx, &dst, &dst_builder, a_axes_, axes_, vmap_, tmap_, tgt_, layouts_, alignment_, alloc_, sh_mem_ptr_, - offset_a_i_, offset_a_k_, offset_b_j_, offset_b_k_, num_packs_0_, num_packs_1_, pack_size_0_, pack_size_1_, num_warps_ ); - finalizer fin(&dst_builder, vmap_, tmap_); - + Value *sh_mem_ptr = alloc_shared(builder, dst); + // visit + generator visitor(&ctx, &dst, &builder, a_axes_, axes_, vmap_, tmap_, tgt_, layouts_, alignment_, alloc_, sh_mem_ptr, num_warps_ ); for(ir::alloc_const *x: src.allocs()) - x->accept(&gen); - - for(ir::function *fn: src.get_function_list()) { - - fn->accept(&gen); - - // initialize layouts - for(auto x: layouts_->get_all()) - x.second->accept(&gen); - - // generate LLVM-IR code - for(ir::basic_block *block: fn->blocks()) { - BasicBlock *parent = (BasicBlock*)vmap_[block]; - dst_builder.SetInsertPoint(parent); - for(ir::instruction *i: block->get_inst_list()) - lower_value(i, dst_builder, &gen, seen); - vmap_[block] = dst_builder.GetInsertBlock(); - } - - // finalize double-buffering - for(const auto& x: layouts_->get_all()) - x.second->accept(&fin); - - // finalize phi - for(ir::basic_block *block: fn->blocks()) - for(ir::instruction *inst: block->get_inst_list()) - inst->accept(&fin); - - } - + visitor.visit_value(x); + for(ir::function *fn: src.get_function_list()) + visitor.visit_value(fn); } - +void generator::visit_value(ir::value* v) { + if(!seen_.insert(v).second) + return; + // create machine tile + if(v->get_type()->is_tile_ty()) + tmap_[v] = machine_layouts_.at(layouts_->get(v))->create(v); + // visit operands + BasicBlock *current = builder_->GetInsertBlock(); + auto *inst = dynamic_cast(v); + if(inst && !dynamic_cast(v)) + for(ir::value *op: inst->ops()) + visit_value(op); + // change insert point for phi node + builder_->SetInsertPoint(current); + auto *phi = dynamic_cast(v); + if(phi && !current->empty() && current->getFirstNonPHI()) + builder_->SetInsertPoint(&*current->getFirstNonPHI()); + // visit user + if(auto *usr = dynamic_cast(v)) + usr->accept(this); + // revert insert point + if(phi && !current->empty() && current->getFirstNonPHI()) + builder_->SetInsertPoint(current); +} void generator::visit_phi_node(ir::phi_node* phi) { Type *ty = type(phi->get_type()->get_scalar_ty(), *ctx_); @@ -574,19 +536,19 @@ void generator::visit_uncond_branch_inst(ir::uncond_branch_inst* br) { void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) { - distributed_tile* result = (distributed_tile*)tmap_.at(x); // find vector size ir::value *ptr = x->get_pointer_operand(); size_t ld = layouts_->get(ptr)->order[0]; unsigned alignment = alignment_->get(ptr, ld); - unsigned vector_size = std::min(result->axis(ld).contiguous, alignment); - distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr); + unsigned vector_size = std::min(axes_.at(a_axes_->get(x, ld)).contiguous, alignment); // vector loads std::map packets; - result->for_each([&](indices_t idx){ + for_each(x, [&](indices_t idx){ + distributed_tile* result = (distributed_tile*)tmap_.at(x); unsigned linear = result->get_linear_index(idx); unsigned id = linear / vector_size; if(linear % vector_size == 0) { + distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr); Value *ptr = pointers->get_value(idx); ptr = builder_->CreateBitCast(ptr, PointerType::get(VectorType::get(result->get_ty(), vector_size), ptr->getType()->getPointerAddressSpace())); @@ -594,25 +556,26 @@ void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) { } }); // extract result element - result->for_each([&](indices_t idx){ + for_each(x, [&](indices_t idx){ + distributed_tile* result = (distributed_tile*)tmap_.at(x); unsigned linear = result->get_linear_index(idx); unsigned id = linear / vector_size; - result->set_value(idx, builder_->CreateExtractElement(packets.at(id), linear % vector_size)); + set_value(x, idx, builder_->CreateExtractElement(packets.at(id), linear % vector_size)); }); } void generator::visit_masked_load_inst(ir::masked_load_inst* x) { // find vector size - distributed_tile* result = (distributed_tile*)tmap_.at(x); ir::value *ptr = x->get_pointer_operand(); size_t ld = layouts_->get(ptr)->order[0]; unsigned alignment = alignment_->get(ptr, ld); - unsigned vector_size = std::min(result->axis(ld).contiguous, alignment); + unsigned vector_size = std::min(axes_.at(a_axes_->get(x, ld)).contiguous, alignment); distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr); distributed_tile *masks = (distributed_tile*)tmap_.at(x->get_mask_operand()); distributed_tile *false_values = (distributed_tile*)tmap_.at(x->get_false_value_operand()); std::map packets; - result->for_each([&](indices_t idx){ + for_each(x, [&](indices_t idx){ + distributed_tile* result = (distributed_tile*)tmap_.at(x); unsigned linear = result->get_linear_index(idx); unsigned id = linear / vector_size; if(linear % vector_size == 0) { @@ -664,7 +627,8 @@ void generator::visit_masked_load_inst(ir::masked_load_inst* x) { } }); // extract result element - result->for_each([&](indices_t idx){ + for_each(x, [&](indices_t idx){ + distributed_tile* result = (distributed_tile*)tmap_.at(x); unsigned linear = result->get_linear_index(idx); unsigned id = linear / vector_size; // Value *tmp = builder_->CreateExtractValue(packets.at(id), {(linear % vector_size) / 2}); @@ -714,13 +678,13 @@ void generator::visit_masked_store_inst(ir::masked_store_inst* st) { void generator::visit_reshape_inst(ir::reshape_inst* reshape) { - distributed_tile* result = (distributed_tile*)tmap_.at(reshape); - ir::value* in = reshape->get_operand(0); - distributed_tile *in_tile = (distributed_tile*)tmap_.at(in); for_each(reshape, [&](indices_t out_idx){ + distributed_tile* result = (distributed_tile*)tmap_.at(reshape); unsigned pos = result->get_linear_index(out_idx); + ir::value* in = reshape->get_operand(0); + distributed_tile *in_tile = (distributed_tile*)tmap_.at(in); indices_t in_idx = in_tile->get_ordered_indices(pos); - result->set_value(out_idx, in_tile->get_value(in_idx)); + set_value(reshape, out_idx, get_value(in, in_idx)); }); } @@ -732,17 +696,16 @@ void generator::visit_splat_inst(ir::splat_inst* splat) { } void generator::visit_broadcast_inst(ir::broadcast_inst* bcast) { - distributed_tile* result = (distributed_tile*)tmap_.at(bcast); ir::value* in = bcast->get_operand(0); const auto& in_shapes = in->get_type()->get_tile_shapes(); distributed_tile *in_tile = (distributed_tile*)tmap_.at(in); - result->for_each([&](indices_t out_idx){ + for_each(bcast, [&](indices_t out_idx){ indices_t in_idx = out_idx; for(size_t k = 0; k < in_idx.size(); k++){ if(in_shapes[k] == 1) in_idx[k] = builder_->getInt32(0); } - result->set_value(out_idx, in_tile->get_value(in_idx)); + set_value(bcast, out_idx, in_tile->get_value(in_idx)); }); } @@ -812,17 +775,17 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst*) { throw std::runtime_error("unsupported"); } -void generator::visit_hmma_dot(ir::dot_inst* dot, distributed_tile *TC, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK) { +void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK) { const auto& shapes = dot->get_type()->get_tile_shapes(); - - TA->set_vector_size(4*pack_size_0_); - TB->set_vector_size(4*pack_size_1_); + machine_layout_hmma_884_t* hmma = (machine_layout_hmma_884_t*)machine_layouts_.at(layouts_->get(dot)); + TA->set_vector_size(4*hmma->pack_size_0_); + TB->set_vector_size(4*hmma->pack_size_1_); TA->set_return_mode(true); TB->set_return_mode(true); std::map, std::vector> fcs; - TC->for_each([&](indices_t idx){ + for_each(dot, [&](indices_t idx){ std::vector key(idx.size() - 2); std::copy(idx.begin() + 2, idx.end(), key.begin()); fcs[key].push_back(TD->get_value(idx)); @@ -833,10 +796,6 @@ void generator::visit_hmma_dot(ir::dot_inst* dot, distributed_tile *TC, shared_t Type *fp32_pack8_ty = StructType::get(*ctx_, {fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}); FunctionType *mma_ty = FunctionType::get(fp32_pack8_ty, {fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false); - Value *offset_a_i = offset_a_i_; - Value *offset_a_k = offset_a_k_; - Value *offset_b_j = offset_b_j_; - Value *offset_b_k = offset_b_k_; Value* u_thread_id = tgt_->get_local_id(builder_->GetInsertBlock()->getModule(), *builder_, 0); @@ -849,10 +808,15 @@ void generator::visit_hmma_dot(ir::dot_inst* dot, distributed_tile *TC, shared_t bool is_b_row = is_b_trans ^ (ord_b[ord_b.size() - 2] == 1); + Value *offset_a_i = hmma->offset_a_i_; + Value *offset_a_k = hmma->offset_a_k_; if(is_a_row){ offset_a_i = builder_->CreateAdd(offset_a_i, builder_->CreateURem(u_thread_id, builder_->getInt32(4))); offset_a_k = builder_->getInt32(0); } + + Value *offset_b_j = hmma->offset_b_j_; + Value *offset_b_k = hmma->offset_b_k_; if(!is_b_row){ offset_b_j = builder_->CreateAdd(offset_b_j, builder_->CreateURem(u_thread_id, builder_->getInt32(4))); offset_b_k = builder_->getInt32(0); @@ -881,33 +845,33 @@ void generator::visit_hmma_dot(ir::dot_inst* dot, distributed_tile *TC, shared_t for(auto& x: fcs){ std::vector& fc = x.second; - for(unsigned pack_i = 0; pack_i < num_packs_0_; pack_i++) - for(unsigned pack_j = 0; pack_j < num_packs_1_; pack_j++){ + for(unsigned pack_i = 0; pack_i < hmma->num_packs_0_; pack_i++) + for(unsigned pack_j = 0; pack_j < hmma->num_packs_1_; pack_j++){ for(unsigned K = 0; K < NK; K += 4){ Value *_K = builder_->getInt32(K); - Value *current_offset_a_i = builder_->CreateAdd(offset_a_i, builder_->getInt32(pack_i*stride_rep_i*pack_size_0_)); - Value *current_offset_b_i = builder_->CreateAdd(offset_b_j, builder_->getInt32(pack_j*stride_rep_j*pack_size_1_)); + Value *current_offset_a_i = builder_->CreateAdd(offset_a_i, builder_->getInt32(pack_i*stride_rep_i*hmma->pack_size_0_)); + Value *current_offset_b_i = builder_->CreateAdd(offset_b_j, builder_->getInt32(pack_j*stride_rep_j*hmma->pack_size_1_)); indices_t idx_a = {current_offset_a_i, builder_->CreateAdd(offset_a_k, _K)}; indices_t idx_b = {builder_->CreateAdd(offset_b_k, _K), current_offset_b_i}; idx_a.insert(idx_a.end(), x.first.begin(), x.first.end()); idx_b.insert(idx_b.end(), x.first.begin(), x.first.end()); Value *ha = TA->get_value(idx_a); Value *hb = TB->get_value(idx_b); - for(unsigned ii = 0; ii < pack_size_0_; ii++) - for(unsigned jj = 0; jj < pack_size_1_; jj++){ - Value *ha0 = builder_->CreateBitCast(builder_->CreateExtractElement(ha, builder_->getInt32(ii*pack_size_0_ + 0)), fp16x2_ty); - Value *ha1 = builder_->CreateBitCast(builder_->CreateExtractElement(ha, builder_->getInt32(ii*pack_size_0_ + 1)), fp16x2_ty); - Value *hb0 = builder_->CreateBitCast(builder_->CreateExtractElement(hb, builder_->getInt32(jj*pack_size_0_ + 0)), fp16x2_ty); - Value *hb1 = builder_->CreateBitCast(builder_->CreateExtractElement(hb, builder_->getInt32(jj*pack_size_0_ + 1)), fp16x2_ty); + for(unsigned ii = 0; ii < hmma->pack_size_0_; ii++) + for(unsigned jj = 0; jj < hmma->pack_size_1_; jj++){ + Value *ha0 = builder_->CreateBitCast(builder_->CreateExtractElement(ha, builder_->getInt32(ii*hmma->pack_size_0_ + 0)), fp16x2_ty); + Value *ha1 = builder_->CreateBitCast(builder_->CreateExtractElement(ha, builder_->getInt32(ii*hmma->pack_size_0_ + 1)), fp16x2_ty); + Value *hb0 = builder_->CreateBitCast(builder_->CreateExtractElement(hb, builder_->getInt32(jj*hmma->pack_size_0_ + 0)), fp16x2_ty); + Value *hb1 = builder_->CreateBitCast(builder_->CreateExtractElement(hb, builder_->getInt32(jj*hmma->pack_size_0_ + 1)), fp16x2_ty); std::vector idx = { - (pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 0)*ld_fc, - (pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 1)*ld_fc, - (pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 0)*ld_fc, - (pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 1)*ld_fc, - (pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 2)*ld_fc, - (pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 3)*ld_fc, - (pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 2)*ld_fc, - (pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 3)*ld_fc + (pack_i*2*hmma->pack_size_0_ + ii*2 + 0) + (pack_j*4*hmma->pack_size_1_ + jj*4 + 0)*ld_fc, + (pack_i*2*hmma->pack_size_0_ + ii*2 + 0) + (pack_j*4*hmma->pack_size_1_ + jj*4 + 1)*ld_fc, + (pack_i*2*hmma->pack_size_0_ + ii*2 + 1) + (pack_j*4*hmma->pack_size_1_ + jj*4 + 0)*ld_fc, + (pack_i*2*hmma->pack_size_0_ + ii*2 + 1) + (pack_j*4*hmma->pack_size_1_ + jj*4 + 1)*ld_fc, + (pack_i*2*hmma->pack_size_0_ + ii*2 + 0) + (pack_j*4*hmma->pack_size_1_ + jj*4 + 2)*ld_fc, + (pack_i*2*hmma->pack_size_0_ + ii*2 + 0) + (pack_j*4*hmma->pack_size_1_ + jj*4 + 3)*ld_fc, + (pack_i*2*hmma->pack_size_0_ + ii*2 + 1) + (pack_j*4*hmma->pack_size_1_ + jj*4 + 2)*ld_fc, + (pack_i*2*hmma->pack_size_0_ + ii*2 + 1) + (pack_j*4*hmma->pack_size_1_ + jj*4 + 3)*ld_fc }; Value *nc = builder_->CreateCall(mma_fn, {ha0, ha1, hb0, hb1, fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]], fc[idx[4]], fc[idx[5]], fc[idx[6]], fc[idx[7]]}); fc[idx[0]] = builder_->CreateExtractValue(nc, {0}); @@ -925,23 +889,23 @@ void generator::visit_hmma_dot(ir::dot_inst* dot, distributed_tile *TC, shared_t // write back unsigned i = 0; - TC->for_each([&](indices_t idx){ + for_each(dot, [&](indices_t idx){ std::vector key(idx.size() - 2); std::copy(idx.begin() + 2, idx.end(), key.begin()); if(i >= fcs.at(key).size()) i = 0; - TC->set_value(idx, fcs.at(key)[i++]); + set_value(dot, idx, fcs.at(key)[i++]); }); TA->set_return_mode(false); TB->set_return_mode(false); } -void generator::visit_scanline_dot(ir::dot_inst* dot, distributed_tile *TC, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK, +void generator::visit_scanline_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK, Type *c_ty, Function *f_mul_add) { - TA->set_vector_size(TC->axis(0).contiguous); - TB->set_vector_size(TC->axis(1).contiguous); - TC->for_each([&](indices_t idx){ + TA->set_vector_size(axes_.at(a_axes_->get(dot, 0)).contiguous); + TB->set_vector_size(axes_.at(a_axes_->get(dot, 1)).contiguous); + for_each(dot, [&](indices_t idx){ Value *res = TD->get_value(idx); for(unsigned K = 0; K < NK; ++K){ // input indices @@ -961,13 +925,13 @@ void generator::visit_scanline_dot(ir::dot_inst* dot, distributed_tile *TC, shar b = builder_->CreateFPCast(b, c_ty); res = builder_->CreateCall(f_mul_add, {a, b, res}); } - TC->set_value(idx, res); + set_value(dot, idx, res); }); } -void generator::visit_outer_dot(ir::dot_inst*, distributed_tile *TC, distributed_tile *TA, distributed_tile *TB, distributed_tile *TD, unsigned NK, +void generator::visit_outer_dot(ir::dot_inst* dot, distributed_tile *TA, distributed_tile *TB, distributed_tile *TD, unsigned NK, Type *c_ty, Function *f_mul_add) { - TC->for_each([&](indices_t idx){ + for_each(dot, [&](indices_t idx){ Value *res = TD->get_value(idx); indices_t a_idx = {idx[0], builder_->getInt32(0)}; indices_t b_idx = {builder_->getInt32(0), idx[1]}; @@ -980,14 +944,13 @@ void generator::visit_outer_dot(ir::dot_inst*, distributed_tile *TC, distributed if(b->getType() != c_ty) b = builder_->CreateFPCast(b, c_ty); res = builder_->CreateCall(f_mul_add, {a, b, res}); - TC->set_value(idx, res); + set_value(dot, idx, res); }); } void generator::visit_dot_inst(ir::dot_inst* dot) { Function *fn = builder_->GetInsertBlock()->getParent(); - distributed_tile* TC = (distributed_tile*)tmap_.at(dot); Module *module = fn->getParent(); ir::value *A = dot->get_operand(0); ir::value *B = dot->get_operand(1); @@ -1004,14 +967,14 @@ void generator::visit_dot_inst(ir::dot_inst* dot) { shared_tile *TA = (shared_tile*)tmap_.at(A); shared_tile *TB = (shared_tile*)tmap_.at(B); if(layouts_->get(dot)->type == analysis::HMMA_884) - visit_hmma_dot(dot, TC, TA, TB, TD, NK); + visit_hmma_dot(dot, TA, TB, TD, NK); else - visit_scanline_dot(dot, TC, TA, TB, TD, NK, c_ty, f_mul_add); + visit_scanline_dot(dot, TA, TB, TD, NK, c_ty, f_mul_add); } else { distributed_tile *TA = (distributed_tile*)tmap_.at(A); distributed_tile *TB = (distributed_tile*)tmap_.at(B); - visit_outer_dot(dot, TC, TA, TB, TD, NK, c_ty, f_mul_add); + visit_outer_dot(dot, TA, TB, TD, NK, c_ty, f_mul_add); } } @@ -1052,15 +1015,14 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) { ir::value *arg = cts->get_operand(0); auto arg_order = layouts_->get(arg)->order; // tiles - shared_tile* result = (shared_tile*)tmap_.at(cts); - distributed_tile* in = (distributed_tile*)tmap_.at(arg); if(x_order == arg_order){ size_t ld = arg_order[0]; vector_size = layouts_->get(arg)->nts.at(ld); } std::map packets; - in->for_each([&](indices_t idx){ + for_each(arg, [&](indices_t idx){ + distributed_tile* in = (distributed_tile*)tmap_.at(arg); unsigned linear = in->get_linear_index(idx); unsigned id = linear / vector_size; Value *in_value = in->get_value(idx); @@ -1068,19 +1030,19 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) { packets[id] = UndefValue::get(VectorType::get(in_value->getType(), vector_size)); packets[id] = builder_->CreateInsertElement(packets.at(id), in_value, linear % vector_size); }); - in->for_each([&](indices_t idx){ + + for_each(arg, [&](indices_t idx){ + distributed_tile* in = (distributed_tile*)tmap_.at(arg); + shared_tile* result = (shared_tile*)tmap_.at(cts); unsigned linear = in->get_linear_index(idx); unsigned id = linear / vector_size; if(linear % vector_size == 0) result->set_value(idx, packets[id]); }); } - void generator::visit_copy_from_shared_inst(ir::copy_from_shared_inst* cfs) { - distributed_tile* result = (distributed_tile*)tmap_.at(cfs); - shared_tile* arg = (shared_tile*)tmap_.at(cfs->get_operand(0)); - result->for_each([&](indices_t idx){ - result->set_value(idx, arg->get_value(idx)); + for_each(cfs, [&](indices_t idx){ + set_value(cfs, idx, get_value(cfs->get_operand(0), idx)); }); } @@ -1090,33 +1052,30 @@ void generator::visit_barrier_inst(ir::barrier_inst*) { } void generator::visit_make_range_dyn(ir::make_range_dyn* x) { - distributed_tile* result = (distributed_tile*)tmap_.at(x); - result->for_each([&](indices_t idx){ + for_each(x, [&](indices_t idx){ assert(idx.size() == 1); BinaryOperator *bin_add = dyn_cast(idx[0]); assert(bin_add); Value *res = bin_add->getOperand(0); - result->set_value(idx, res); + set_value(x, idx, res); }); } void generator::visit_make_range_sta(ir::make_range_sta* x) { - distributed_tile *T = (distributed_tile *)tmap_.at(x); - T->for_each([&](indices_t idx){ + for_each(x, [&](indices_t idx){ assert(idx.size() == 1); BinaryOperator *bin_add = dyn_cast(idx[0]); assert(bin_add); Value *res = bin_add->getOperand(1); assert(isa(res)); - T->set_value(idx, res); + set_value(x, idx, res); }); } void generator::visit_make_range(ir::make_range* x) { - distributed_tile *T = (distributed_tile *)tmap_.at(x); - T->for_each([&](indices_t idx){ + for_each(x, [&](indices_t idx){ assert(idx.size() == 1); - T->set_value(idx, idx[0]); + set_value(x, idx, idx[0]); }); } @@ -1149,18 +1108,17 @@ void generator::visit_alloc_const(ir::alloc_const *alloc) { void generator::visit_function(ir::function* fn) { LLVMContext &ctx = builder_->getContext(); FunctionType *fn_ty = (FunctionType*)type(fn->get_fn_type(), *ctx_); - FunctionType *dst_fn_ty = fn_ty; if(!tgt_->is_gpu()){ - Type *dst_fn_ret_ty = fn_ty->getReturnType(); - std::vector dst_fn_args_ty; + Type *fn_ret_ty = fn_ty->getReturnType(); + std::vector fn_args_ty; for(unsigned i = 0; i < fn_ty->getNumParams(); i++) - dst_fn_args_ty.push_back(fn_ty->getParamType(i)); - dst_fn_args_ty.push_back(builder_->getInt32Ty()); - dst_fn_args_ty.push_back(builder_->getInt32Ty()); - dst_fn_args_ty.push_back(builder_->getInt32Ty()); - dst_fn_ty = FunctionType::get(dst_fn_ret_ty, dst_fn_args_ty, false); + fn_args_ty.push_back(fn_ty->getParamType(i)); + fn_args_ty.push_back(builder_->getInt32Ty()); + fn_args_ty.push_back(builder_->getInt32Ty()); + fn_args_ty.push_back(builder_->getInt32Ty()); + fn_ty = FunctionType::get(fn_ret_ty, fn_args_ty, false); } - Function *ret = Function::Create(dst_fn_ty, Function::ExternalLinkage, fn->get_name(), mod_); + Function *ret = Function::Create(fn_ty, Function::ExternalLinkage, fn->get_name(), mod_); // set attributes for(auto attr_pair: fn->attrs()){ unsigned id = attr_pair.first; @@ -1176,7 +1134,7 @@ void generator::visit_function(ir::function* fn) { ValueAsMetadata::get(builder_->getInt32(num_warps_*32)) }; mod_->getOrInsertNamedMetadata("nvvm.annotations")->addOperand(MDNode::get(ctx, md_args)); - // map parameters + // set arguments for(unsigned i = 0; i < fn->args().size(); i++) vmap_[fn->args()[i]] = &*(ret->arg_begin() + i); // create blocks @@ -1185,15 +1143,22 @@ void generator::visit_function(ir::function* fn) { vmap_[block] = dst_block; } builder_->SetInsertPoint((BasicBlock*)vmap_[fn->blocks()[0]]); - fn_ = ret; + // initialize layouts + for(auto x: layouts_->get_all()) + visit_layout(x.second); + // generate LLVM-IR code + for(ir::basic_block *block: fn->blocks()) + visit_basic_block(block); + // finalize + finalize_function(fn); } + + + + void generator::visit_layout_hmma_884(analysis::layout_hmma_884_t* layout) { - machine_layouts_[layout] = new machine_layout_hmma_884_t(mod_, builder_, tgt_, type(layout->ty->get_scalar_ty(), *ctx_), a_axes_, axes_, - offset_a_i_, offset_a_k_, offset_b_j_, offset_b_k_, - pack_size_0_, pack_size_1_, - num_packs_0_, num_packs_1_, - layout); + machine_layouts_[layout] = new machine_layout_hmma_884_t(mod_, builder_, tgt_, type(layout->ty->get_scalar_ty(), *ctx_), a_axes_, axes_, layout); } void generator::visit_layout_scanline(analysis::layout_scanline_t* layout) { @@ -1205,10 +1170,24 @@ void generator::visit_layout_shared(analysis::layout_shared_t* layout) { machine_layouts_[layout] = new machine_layout_shared_t(mod_, builder_, tgt_, alloc_, sh_mem_ptr_, layout, vmap_, tmap_); } +void generator::visit_basic_block(ir::basic_block * block) { + BasicBlock *parent = (BasicBlock*)vmap_[block]; + builder_->SetInsertPoint(parent); + for(ir::instruction *i: block->get_inst_list()) + visit_value(i); + vmap_[block] = builder_->GetInsertBlock(); +} + +void generator::visit_argument(ir::argument* arg) { + +} + void generator::for_each(ir::value *x, const std::function& fn) { if(!x->get_type()->is_tile_ty()) return fn({}); else { +// if(tmap_.find(x) == tmap_.end()) +// tmap_[x] = machine_layouts_.at(layouts_->get(x))->create(x); if(auto *dt = dynamic_cast(tmap_.at(x))) dt->for_each(fn); } @@ -1313,13 +1292,8 @@ tile *machine_layout_distributed_t::create(ir::value *v) { machine_layout_hmma_884_t::machine_layout_hmma_884_t(Module *mod, Builder *builder, target *tgt, Type *ty, analysis::axes *a_axes, std::map& axes, - Value *&offset_a_i, Value *&offset_a_k, Value *&offset_b_j, Value *&offset_b_k, - unsigned &pack_size_0, unsigned &pack_size_1, - unsigned &num_packs_0, unsigned &num_packs_1, analysis::layout_hmma_884_t* layout) - : machine_layout_distributed_t(mod, builder, tgt, ty, a_axes, axes, layout), - offset_a_i_(offset_a_i), offset_a_k_(offset_a_k), offset_b_j_(offset_b_j), offset_b_k_(offset_b_k), - pack_size_0_(pack_size_0), pack_size_1_(pack_size_1), num_packs_0_(num_packs_0), num_packs_1_(num_packs_1) { + : machine_layout_distributed_t(mod, builder, tgt, ty, a_axes, axes, layout) { Value *warp_size = builder_->getInt32(32); Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0); @@ -1467,34 +1441,18 @@ machine_layout_scanline_t::machine_layout_scanline_t(Module *mod, Builder *build } } -finalizer::finalizer(Builder *builder, std::map& vmap, std::map& tmap) - : builder_(builder), vmap_(vmap), tmap_(tmap) { - +void generator::finalize_function(ir::function* fn) { + // finalize double-buffering + for(const auto& x: layouts_->get_all()) + visit_layout(x.second); + // finalize phi + for(ir::basic_block *block: fn->blocks()) + for(ir::instruction *inst: block->get_inst_list()) + if(auto *phi = dynamic_cast(inst)) + finalize_phi_node(phi); } -void finalizer::for_each(ir::value *x, const std::function& fn) { - if(!x->get_type()->is_tile_ty()) - return fn({}); - else { - if(auto *dt = dynamic_cast(tmap_.at(x))) - dt->for_each(fn); - } -} - -Value* finalizer::get_value(ir::value *x, const indices_t& idx) { - if(x->get_type()->is_tile_ty()) - return tmap_.at(x)->get_value(idx); - return vmap_.at(x); -} - -void finalizer::set_value(ir::value *x, const indices_t& idx, Value* v) { - if(x->get_type()->is_tile_ty()) - tmap_.at(x)->set_value(idx, v); - else - vmap_[x] = v; -} - -void finalizer::visit_phi_node(ir::phi_node* phi) { +void generator::finalize_phi_node(ir::phi_node* phi) { auto it = tmap_.find(phi); if(it != tmap_.end() && dynamic_cast(it->second)) return; @@ -1510,32 +1468,6 @@ void finalizer::visit_phi_node(ir::phi_node* phi) { } -void finalizer::visit_layout_shared(analysis::layout_shared_t* layout) { - if(layout->double_buffer) { - auto info = *layout->double_buffer; - ir::phi_node *phi = info.phi; - PHINode *ptr = (PHINode*)((shared_tile*)tmap_.at(phi))->get_pointer(); - PHINode *offset = (PHINode*)((shared_tile*)tmap_.at(phi))->get_offset(); - for(unsigned n = 0; n < phi->get_num_incoming(); n++){ - ir::basic_block* inc_block = phi->get_incoming_block(n); - ir::value* inc_val = phi->get_incoming_value(n); - BasicBlock *llvm_inc_block = (BasicBlock*)vmap_.at(inc_block); - shared_tile *inc_shared = (shared_tile*)tmap_.at(inc_val); - if(inc_val == info.latch){ - builder_->SetInsertPoint(llvm_inc_block->getTerminator()); - Value *next_offset = builder_->CreateNeg(offset); - offset->addIncoming(next_offset, llvm_inc_block); - } - else { - unsigned num_bytes = layout->ty->get_primitive_size_in_bits() / 8; - offset->addIncoming(builder_->getInt32(layout->size / (2*num_bytes)), llvm_inc_block); - } - ptr->addIncoming(inc_shared->get_pointer(), llvm_inc_block); - } - } -} - - } } diff --git a/lib/ir/function.cc b/lib/ir/function.cc index c15440e9d..84d52df72 100644 --- a/lib/ir/function.cc +++ b/lib/ir/function.cc @@ -25,6 +25,10 @@ unsigned argument::get_arg_no() const { return arg_no_; } +void argument::accept(visitor *v) { + v->visit_argument(this); +} + /* function */ function::function(function_type *ty, linkage_types_t linkage, diff --git a/lib/ir/value.cc b/lib/ir/value.cc index 5dfb0460c..a43aaa05e 100644 --- a/lib/ir/value.cc +++ b/lib/ir/value.cc @@ -32,6 +32,10 @@ void value::replace_all_uses_with(value *target){ throw std::runtime_error("not implemented"); } +void visitor::visit_value(ir::value* v) { + v->accept(this); +} + //===----------------------------------------------------------------------===// // user class @@ -69,5 +73,7 @@ void user::replace_uses_of_with(value *before, value *after) { before->erase_use(this); } + + } }