From ee387ff567586f65ff2731dedd9dd93e9fc89e51 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sun, 13 Oct 2019 14:43:17 -0400 Subject: [PATCH] more cleaning --- include/triton/codegen/analysis/layout.h | 19 + include/triton/codegen/selection.h | 35 +- include/triton/ir/constant.h | 30 +- include/triton/ir/context_impl.h | 3 - include/triton/ir/function.h | 3 + include/triton/ir/instructions.h | 3 +- include/triton/ir/value.h | 4 + include/triton/ir/visitor.h | 26 +- include/triton/runtime/function.h | 1 - lib/codegen/selection.cc | 432 +++++++++++------------ lib/ir/constant.cc | 21 -- 11 files changed, 277 insertions(+), 300 deletions(-) diff --git a/include/triton/codegen/analysis/layout.h b/include/triton/codegen/analysis/layout.h index 01b65e8d2..096c45ea3 100644 --- a/include/triton/codegen/analysis/layout.h +++ b/include/triton/codegen/analysis/layout.h @@ -35,6 +35,18 @@ struct double_buffer_info_t { ir::phi_node* phi; }; +class layout_visitor; +class layout_hmma_884_t; +class layout_scanline_t; +class layout_shared_t; + + +class layout_visitor { +public: + virtual void visit_layout_hmma_884(layout_hmma_884_t*) = 0; + virtual void visit_layout_scanline(layout_scanline_t*) = 0; + virtual void visit_layout_shared(layout_shared_t*) = 0; +}; struct layout_t { layout_t(layout_type_t _type, @@ -43,6 +55,9 @@ struct layout_t { const std::vector &_values, size_t _id, analysis::align* align); + + virtual void accept(layout_visitor* vst) = 0; + layout_type_t type; std::vector axes; std::vector shapes; @@ -66,6 +81,7 @@ struct layout_hmma_884_t: public layout_t { const std::vector &_values, size_t _id, analysis::align* align); + void accept(layout_visitor* vst) { vst->visit_layout_hmma_884(this); } }; struct layout_scanline_t: public layout_t { @@ -75,6 +91,7 @@ struct layout_scanline_t: public layout_t { const std::vector &values, size_t _id, analysis::align* align); + void accept(layout_visitor* vst) { vst->visit_layout_scanline(this); } }; struct layout_shared_t: public layout_t { @@ -85,9 +102,11 @@ struct layout_shared_t: public layout_t { ir::type *ty, size_t _id, analysis::align* align); + void accept(layout_visitor* vst) { vst->visit_layout_shared(this); } }; + class layout { typedef ir::value* node_t; typedef std::map > graph_t; diff --git a/include/triton/codegen/selection.h b/include/triton/codegen/selection.h index 279c7475e..8d42d9dfe 100644 --- a/include/triton/codegen/selection.h +++ b/include/triton/codegen/selection.h @@ -147,7 +147,7 @@ private: }; -class generator: public ir::visitor { +class generator: public ir::visitor, public analysis::layout_visitor { private: void visit_hmma_dot(ir::dot_inst*, distributed_tile *TC, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK); void visit_scanline_dot(ir::dot_inst*, distributed_tile *TC, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK, Type *c_ty, Function *f_mul_add); @@ -163,7 +163,9 @@ public: generator(LLVMContext *ctx, Function *fn, + Module *dst, Builder *builder, + std::map& axes, std::map& vmap, std::map& tmap, target *tgt, @@ -176,7 +178,7 @@ public: unsigned num_packs_0, unsigned num_packs_1, unsigned pack_size_0, unsigned pack_size_1, unsigned num_warps) - : ctx_(ctx), fn_(fn), builder_(builder), vmap_(vmap), tmap_(tmap), tgt_(tgt), + : ctx_(ctx), fn_(fn), mod_(dst), builder_(builder), axes_(axes), vmap_(vmap), tmap_(tmap), tgt_(tgt), layouts_(layouts), alignment_(alignment), alloc_(alloc), sh_mem_ptr_(sh_mem_ptr), offset_a_i_(offset_a_i), offset_a_k_(offset_a_k), offset_b_j_(offset_b_j), offset_b_k_(offset_b_k), num_packs_0_(num_packs_0), num_packs_1_(num_packs_1), pack_size_0_(pack_size_0), pack_size_1_(pack_size_1), @@ -221,14 +223,27 @@ public: void visit_copy_from_shared_inst(ir::copy_from_shared_inst*); void visit_barrier_inst(ir::barrier_inst*); void visit_make_range_dyn(ir::make_range_dyn*); - void visit_make_range_sta(ir::make_range_sta*); void visit_make_range(ir::make_range*); + void visit_make_range_sta(ir::make_range_sta*); + void visit_undef_value(ir::undef_value*); + void visit_constant_int(ir::constant_int*); + void visit_constant_fp(ir::constant_fp*); + void visit_alloc_const(ir::alloc_const*); + + void visit_function(ir::function*); + + void visit_layout_hmma_884(analysis::layout_hmma_884_t*); + void visit_layout_scanline(analysis::layout_scanline_t*); + void visit_layout_shared(analysis::layout_shared_t*); + private: LLVMContext *ctx_; Function *fn_; Builder *builder_; + Module *mod_; + std::map& axes_; std::map& vmap_; std::map& tmap_; target *tgt_; @@ -249,29 +264,15 @@ class selection{ typedef std::map tmap_t; private: - // utils - Type *make_vector_ty(Type *ty, size_t vector_size); - std::vector extract_shapes(ir::value *v); - // LLVM conversions Type* llvm_type(ir::type *ty, LLVMContext &ctx); - Constant* llvm_constant(ir::constant *cst, LLVMContext &ctx); Value* llvm_alloc_const(ir::alloc_const *v, Module *module, Builder &builder); - ArrayType* llvm_linearized_tile_type(ir::type *ty, LLVMContext &ctx); Function* llvm_fn(ir::function *fn, Builder& builder, Module &dst); Value* alloc_shared(Builder &builder, Module& dst); // grid construction - void create_grids(std::vector &grids, - std::map &references, - ir::function *fn); void create_shared_tile(ir::value *v, Builder &builder, Value *sh_mem_ptr); void create_distributed_tile(ir::value *v, Builder &builder); - void create_tile(ir::value *v, Builder &builder, std::set &seen, Value *sh_mem_ptr); - void init_strided_scan_axes(const analysis::layout_t& layout, Builder &builder, Value *u_thread_id, Value *u_warp_id); - void init_hmma_axes(const analysis::layout_t& layout, Builder &builder, Value *u_thread_id, Value *u_warp_id); - void init_axes(const analysis::layout_t& layout, Builder &builder, Value *u_thread_id, Value *u_warp_id); - void init_layouts(ir::function *fn, Builder &builder, Value *sh_mem_ptr); // lower scalar instruction void lower_value(ir::value *src, Builder &builder, generator* gen, std::set& seen); diff --git a/include/triton/ir/constant.h b/include/triton/ir/constant.h index 0127acae6..671d5e5f0 100644 --- a/include/triton/ir/constant.h +++ b/include/triton/ir/constant.h @@ -6,6 +6,7 @@ #include "enums.h" #include "value.h" #include +#include "visitor.h" namespace triton{ namespace ir{ @@ -32,6 +33,7 @@ private: public: static undef_value* get(type* ty); std::string repr() const { return "undef"; } + void accept(visitor* vst) { vst->visit_undef_value(this); } }; @@ -44,31 +46,13 @@ public: virtual uint64_t get_value() const { return value_; } static constant_int *get(type *ty, uint64_t value); std::string repr() const { return std::to_string(value_); } + void accept(visitor* vst) { vst->visit_constant_int(this); } protected: uint64_t value_; }; -/* Metaparameter (int) */ -class metaparameter: public constant_int { -private: - metaparameter(type *ty, const std::vector& space); - -public: - static metaparameter *create(context &ctx, type *ty, unsigned lo, unsigned hi); - static metaparameter *create(context &ctx, type *ty, const std::vector& space); - void set_value(uint64_t value) { has_value_ = true; value_ = value; } - bool has_value() { return has_value_; } - const std::vector& get_space() { return space_; } - void set_space(const std::vector &space) { space_ = space; } - uint64_t get_value() const { assert(has_value_); return value_; } - std::string repr() const { return has_value_? std::to_string(value_) : "?" ;} -private: - std::vector space_; - bool has_value_; -}; - -/* constant fp */ +/* Constant fp */ class constant_fp: public constant{ constant_fp(type *ty, double value); @@ -79,13 +63,14 @@ public: static constant* get(context &ctx, double v); static constant* get(type *ty, double v); std::string repr() const { return std::to_string(value_); } + void accept(visitor* vst) { vst->visit_constant_fp(this); } private: double value_; }; -/* global value */ +/* Global Value */ class global_value: public constant { public: enum linkage_types_t { @@ -109,7 +94,6 @@ public: linkage_types_t linkage, const std::string &name, unsigned addr_space = 0); std::string repr() const { return get_name(); } - }; /* global variable */ @@ -118,6 +102,8 @@ public: alloc_const(type *ty, constant_int *size, const std::string &name = ""); std::string repr() const { return get_name(); } + void accept(visitor* vst) { vst->visit_alloc_const(this); } + }; diff --git a/include/triton/ir/context_impl.h b/include/triton/ir/context_impl.h index 5995de0d4..a016d1add 100644 --- a/include/triton/ir/context_impl.h +++ b/include/triton/ir/context_impl.h @@ -14,7 +14,6 @@ class constant; class constant_int; class constant_fp; class undef_value; -class metaparameter; /* Context impl */ class context_impl { @@ -36,8 +35,6 @@ public: std::map, constant_fp*> fp_constants_; // undef values std::map uv_constants_; - // Metaparameters - std::vector mp_constants_; }; } diff --git a/include/triton/ir/function.h b/include/triton/ir/function.h index 74af3abe2..8cf11275a 100644 --- a/include/triton/ir/function.h +++ b/include/triton/ir/function.h @@ -112,6 +112,9 @@ public: const attr_map_t &attrs() { return attrs_; } std::set get_attributes(argument* arg) { return attrs_[arg->get_arg_no() + 1]; } + // visitor + void accept(visitor *v) { v->visit_function(this); } + private: module *parent_; bool init_; diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index bbc75c63c..4409d1ccb 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -71,8 +71,6 @@ public: } // instruction id value_id_t get_id() const { return id_; } - // visit - virtual void accept(visitor *v) = 0; private: basic_block *parent_; @@ -759,6 +757,7 @@ public: static make_range_sta *get(make_range* range); make_range* get_range() const; std::string repr() const { return "nv_static_program_idx"; } + _TRITON_DEFINE_ACCEPT(make_range_sta) private: make_range *range_; diff --git a/include/triton/ir/value.h b/include/triton/ir/value.h index 0c2727a38..bf4a4aa9c 100644 --- a/include/triton/ir/value.h +++ b/include/triton/ir/value.h @@ -13,6 +13,7 @@ namespace ir{ class type; class use; class user; +class visitor; //===----------------------------------------------------------------------===// // value class @@ -74,6 +75,9 @@ public: void replace_all_uses_with(value *target); void replace_uses_of_with(value *before, value *after); + // Visitor + virtual void accept(visitor *v) = 0; + private: ops_t ops_; unsigned num_ops_; diff --git a/include/triton/ir/visitor.h b/include/triton/ir/visitor.h index ffe8d734c..e2310e94a 100644 --- a/include/triton/ir/visitor.h +++ b/include/triton/ir/visitor.h @@ -61,10 +61,25 @@ class copy_to_shared_inst; class copy_from_shared_inst; class barrier_inst; class make_range_dyn; -class make_range_sta; class make_range; +class make_range_sta; +class undef_value; +class constant_int; +class constant_fp; +class global_value; +class global_object; +class alloc_const; +class constant_fp; +class undef_value; +class constant_int; +class constant_fp; +class global_value; +class global_object; +class alloc_const; + +class function; class visitor { public: @@ -108,8 +123,15 @@ public: virtual void visit_copy_from_shared_inst(copy_from_shared_inst*) = 0; virtual void visit_barrier_inst(barrier_inst*) = 0; virtual void visit_make_range_dyn(make_range_dyn*) = 0; - virtual void visit_make_range_sta(make_range_sta*) = 0; virtual void visit_make_range(make_range*) = 0; + + virtual void visit_function(function*) = 0; + + virtual void visit_make_range_sta(make_range_sta*) = 0; + virtual void visit_undef_value(undef_value*) = 0; + virtual void visit_constant_int(constant_int*) = 0; + virtual void visit_constant_fp(constant_fp*) = 0; + virtual void visit_alloc_const(alloc_const*) = 0; }; } diff --git a/include/triton/runtime/function.h b/include/triton/runtime/function.h index c12f9c6ca..e312cfded 100644 --- a/include/triton/runtime/function.h +++ b/include/triton/runtime/function.h @@ -43,7 +43,6 @@ namespace ir { class module; class function; class context; -class metaparameter; } namespace runtime{ diff --git a/lib/codegen/selection.cc b/lib/codegen/selection.cc index 9d95e0a41..9a2f1f569 100644 --- a/lib/codegen/selection.cc +++ b/lib/codegen/selection.cc @@ -343,16 +343,6 @@ Type *selection::llvm_type(ir::type *ty, LLVMContext &ctx) { throw std::runtime_error("unknown conversion from ir::type to Type"); } -/* convert ir::constant to Constant */ -Constant *selection::llvm_constant(ir::constant *cst, LLVMContext &ctx) { - Type *dst_ty = llvm_type(cst->get_type()->get_scalar_ty(), ctx); - if(auto* cc = dynamic_cast(cst)) - return ConstantInt::get(dst_ty, cc->get_value()); - if(auto* cc = dynamic_cast(cst)) - return ConstantFP::get(dst_ty, cc->get_value()); - // unknown constant - throw std::runtime_error("unknown conversion from ir::constant to Constant"); -} /* convert ir::alloc_const to llvm::GlobalVariable */ Value* selection::llvm_alloc_const(ir::alloc_const *v, Module *module, IRBuilder<> &builder) { @@ -387,145 +377,6 @@ inline int32_t ceil(int32_t num, int32_t div){ return (num + div - 1)/div; } -void selection::init_strided_scan_axes(const analysis::layout_t& layout, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) { - auto order = layout.order; - const auto& shapes = layout.shapes; - size_t dim = shapes.size(); - std::vector nts = layout.nts; - std::vector mts = layout.mts; - Value* full_thread_id = builder.CreateAdd(builder.CreateMul(u_warp_id, builder.getInt32(32)), u_thread_id); - std::vector thread_id = delinearize(full_thread_id, order, mts, builder); - // Create axes - for(unsigned k = 0; k < dim; k++) { - std::string str_k = std::to_string(k); - Value *contiguous_k = builder.getInt32(nts[k]); - Value *scaled_thread_id = builder.CreateMul(thread_id[k], contiguous_k); - unsigned per_block = nts[k] * mts[k]; - unsigned per_thread = nts[k] * shapes[k] / per_block; - std::vector idx_list(per_thread); - for(unsigned n = 0 ; n < per_thread; n++){ - unsigned offset = n / nts[k] * per_block + n % nts[k]; - idx_list[n] = builder.CreateAdd(scaled_thread_id, builder.getInt32(offset), "idx_" + str_k + "_" + std::to_string(n)); - } - axes_[layout.axes[k]] = distributed_axis{nts[k], idx_list, thread_id[k]}; - } -} - -void selection::init_hmma_axes(const analysis::layout_t& layout, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) { - const auto& shapes = layout.shapes; - if(shapes.size() > 3) - throw std::runtime_error("unsupported"); - - bool is_batched = shapes.size() >= 3; - - Value *_1 = builder.getInt32(1); - Value *_2 = builder.getInt32(2); - Value *_3 = builder.getInt32(3); - Value *_4 = builder.getInt32(4); - Value *_16 = builder.getInt32(16); - - // fragments per warp - unsigned fpw_0 = layout.fpw.at(0); - unsigned fpw_1 = layout.fpw.at(1); - unsigned fpw_2 = is_batched ? layout.fpw.at(2) : 1; - // warps per tile - unsigned wpt_0 = layout.wpt.at(0); - unsigned wpt_1 = layout.wpt.at(1); - unsigned wpt_2 = is_batched ? layout.wpt.at(2) : 1; - // hmma warp tile size - unsigned hmma_wts_0 = fpw_0 * 8; - unsigned hmma_wts_1 = fpw_1 * 8; - unsigned hmma_wts_2 = is_batched ? fpw_2 : 1; - // hmma block tile size - unsigned hmma_bts_0 = hmma_wts_0 * wpt_0; - unsigned hmma_bts_1 = hmma_wts_1 * wpt_1; - unsigned hmma_bts_2 = is_batched ? hmma_wts_2 * wpt_2 : 1; - // number of repetition - unsigned num_rep_0 = shapes[0] / hmma_bts_0; - unsigned num_rep_1 = shapes[1] / hmma_bts_1; - unsigned num_rep_2 = is_batched ? shapes[2] / hmma_bts_2 : 1; - // size of each pack (interleaving) - pack_size_0_ = std::min(num_rep_0, 1); - pack_size_1_ = std::min(num_rep_1, 1); - // number of packs (interleaving) - num_packs_0_ = num_rep_0 / pack_size_0_; - num_packs_1_ = num_rep_1 / pack_size_1_; - - /* intra warp offset */ - // offset of quad in pair - Value *in_pair_off_a = builder.CreateMul(builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), builder.getInt32(4)), - builder.getInt32(fpw_0 * pack_size_0_)); - Value *in_pair_off_b = builder.CreateMul(builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), builder.getInt32(4)), - builder.getInt32(fpw_1 * pack_size_1_)); - - // Quad pair id - Value *pair_a_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4); - Value *pair_b_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4); - pair_a_id = builder.CreateURem(pair_a_id, builder.getInt32(fpw_0)); - pair_b_id = builder.CreateUDiv(pair_b_id, builder.getInt32(fpw_0)); - pair_b_id = builder.CreateURem(pair_b_id, builder.getInt32(fpw_1)); - // Quad pair offset - Value *pair_a_off = builder.CreateMul(pair_a_id, builder.getInt32(4 * pack_size_0_)); - Value *pair_b_off = builder.CreateMul(pair_b_id, builder.getInt32(4 * pack_size_1_)); - - /* inter warp offset */ - Value *warp_id_0 = builder.CreateURem(u_warp_id, builder.getInt32(wpt_0)); - Value *warp_id_12 = builder.CreateUDiv(u_warp_id, builder.getInt32(wpt_0)); - Value *warp_id_1 = builder.CreateURem(warp_id_12, builder.getInt32(wpt_1)); - Value *warp_id_2 = builder.CreateUDiv(warp_id_12, builder.getInt32(wpt_1)); - Value *warp_offset_i = builder.CreateMul(warp_id_0, builder.getInt32(hmma_wts_0 * pack_size_0_)); - Value *warp_offset_j = builder.CreateMul(warp_id_1, builder.getInt32(hmma_wts_1 * pack_size_1_)); - - /* offsets */ - // a offset - offset_a_i_ = builder.CreateAdd(warp_offset_i, builder.CreateAdd(pair_a_off, in_pair_off_a)); - offset_a_k_ = builder.CreateAnd(u_thread_id, _3); - // b offsets - offset_b_j_ = builder.CreateAdd(warp_offset_j, builder.CreateAdd(pair_b_off, in_pair_off_b)); - offset_b_k_ = builder.CreateAnd(u_thread_id, _3); - - // c offsets - Value *offset_c_i = builder.CreateAdd(builder.CreateAnd(u_thread_id, _1), offset_a_i_); - Value *offset_c_j = builder.CreateAdd(builder.CreateAnd(u_thread_id, _2), - builder.CreateAdd(warp_offset_j, pair_b_off)); - - /* indices */ - // i indices - std::vector idx_i; - for(unsigned pack = 0; pack < num_packs_0_; pack++) - for(unsigned ii = 0; ii < pack_size_0_; ii++) - for(unsigned i = 0; i < 2; i++){ - idx_i.push_back(builder.CreateAdd(offset_c_i, builder.getInt32(pack*hmma_bts_0*pack_size_0_ + ii*4 + i*2))); - } - // j indices - std::vector idx_j; - for(unsigned pack = 0; pack < num_packs_1_; pack++) - for(unsigned jj = 0; jj < pack_size_1_; jj++) - for(unsigned j = 0; j < 2; j++){ - idx_j.push_back(builder.CreateAdd(offset_c_j, builder.getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_))); - idx_j.push_back(builder.CreateAdd(offset_c_j, builder.getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_ + 1))); - } - // z indices - std::vector idx_z; - for(unsigned pack = 0; pack < num_rep_2; pack++) - idx_z.push_back(builder.CreateAdd(warp_id_2, builder.getInt32(pack*hmma_bts_2))); - - - /* axes */ - axes_[layout.axes[0]] = distributed_axis{1, idx_i, warp_id_0}; - axes_[layout.axes[1]] = distributed_axis{1, idx_j, warp_id_1}; - if(is_batched) - axes_[layout.axes[2]] = distributed_axis{1, idx_z, warp_id_2}; -} - - -void selection::init_axes(const analysis::layout_t& layout, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) { - if(layout.type == analysis::HMMA_884) - init_hmma_axes(layout, builder, u_thread_id, u_warp_id); - else if(layout.type == analysis::SCANLINE) - init_strided_scan_axes(layout, builder, u_thread_id, u_warp_id); -} - /* ------------------- * ---- Init Tiles ---- * ------------------- */ @@ -549,7 +400,7 @@ void selection::create_shared_tile(ir::value *v, IRBuilder<> &builder, Value *sh if(parent->empty()) builder.SetInsertPoint(parent); else - builder.SetInsertPoint(&*parent->getFirstInsertionPt()); + builder.SetInsertPoint(&*parent->getFirstNonPHI()); // create double-buffered pointer PHINode *ptr = builder.CreatePHI(ptr_ty, 2); PHINode *offset = builder.CreatePHI(builder.getInt32Ty(), 2); @@ -587,41 +438,6 @@ void selection::create_distributed_tile(ir::value *v, IRBuilder<> &builder) { tmap_.insert({v, T}); } -void selection::create_tile(ir::value *v, IRBuilder<> &builder, - std::set &seen, Value *sh_mem_ptr) { - if(!v->get_type()->is_tile_ty() || !seen.insert(v).second) - return; - if(auto *user = dynamic_cast(v)) - for(ir::value *op: user->ops()) - create_tile(op, builder, seen, sh_mem_ptr); - auto *i = dynamic_cast(v); - if(i && layouts_->get(i)->type == analysis::SHARED && !dynamic_cast(v)) - create_shared_tile(i, builder, sh_mem_ptr); - else - create_distributed_tile(v, builder); -} - -void selection::init_layouts(ir::function *fn, IRBuilder<> &builder, Value *sh_mem_ptr){ - // fetch linear ID - Module *mod = builder.GetInsertBlock()->getParent()->getParent(); - Value *warp_size = builder.getInt32(32); - Value* u_thread_id = tgt_->get_local_id(mod, builder, 0); - Value *u_thread_warp_id = builder.CreateURem(u_thread_id, warp_size); - Value *u_warp_id = builder.CreateUDiv(u_thread_id, warp_size); - // create grid - for(auto x: layouts_->get_all()) - init_axes(*x.second, builder, u_thread_warp_id, u_warp_id); - // create tile - std::set seen; - for(ir::basic_block *block: fn->blocks()) - for(ir::instruction *i: block->get_inst_list()){ - if(!i->get_type()->is_tile_ty()) - continue; - create_tile(i, builder, seen, sh_mem_ptr); - } -} - - bool is_trans(ir::value *v) { if(dynamic_cast(v)) { @@ -641,51 +457,34 @@ void selection::lower_value(ir::value *src, IRBuilder<> &builder, generator* gen if(!seen.insert(src).second) return; + BasicBlock *current = builder.GetInsertBlock(); + if(src->get_type()->is_tile_ty()){ + builder.SetInsertPoint(&*builder.GetInsertBlock()->getParent()->begin()); + auto *i = dynamic_cast(src); + if(i && layouts_->get(i)->type == analysis::SHARED && !dynamic_cast(src)){ + create_shared_tile(i, builder, sh_mem_ptr_); + } + else + create_distributed_tile(src, builder); + } + builder.SetInsertPoint(current); + + auto *inst = dynamic_cast(src); if(inst && !dynamic_cast(src)) for(ir::value *op: inst->ops()) lower_value(op, builder, gen, seen); - BasicBlock *current = builder.GetInsertBlock(); + builder.SetInsertPoint(current); auto *phi = dynamic_cast(src); - bool phi_inserted = phi && !current->empty(); - if(phi_inserted && current->getFirstNonPHI()) + if(phi && !current->empty() && current->getFirstNonPHI()) builder.SetInsertPoint(&*current->getFirstNonPHI()); + if(auto *usr = dynamic_cast(src)) + usr->accept(gen); - if(dynamic_cast(src)){ - distributed_tile *T = (distributed_tile *)tmap_.at(src); - T->for_each([&](indices_t idx){ - assert(idx.size() == 1); - T->set_value(idx, idx[0]); - }); - } - else if(dynamic_cast(src)){ - distributed_tile *T = (distributed_tile *)tmap_.at(src); - T->for_each([&](indices_t idx){ - assert(idx.size() == 1); - BinaryOperator *bin_add = dyn_cast(idx[0]); - assert(bin_add); - Value *res = bin_add->getOperand(1); - assert(isa(res)); - T->set_value(idx, res); - }); - } - else if(auto *cst = dynamic_cast(src)){ - vmap_[cst] = llvm_constant(cst, builder.getContext()); - } - else if(inst){ - inst->accept(gen); - } - - if(phi_inserted && current->getFirstNonPHI()) + if(phi && !current->empty() && current->getFirstNonPHI()) builder.SetInsertPoint(current); - -// if(dynamic_cast(src)) -// for(ir::value *op: inst->ops()) -// lower_value(op, builder, seen); - - } /* ---------------------------- @@ -702,12 +501,6 @@ inline llvm::Attribute llvm_attr(llvm::LLVMContext& ctx, ir::attribute attr) { } } -ArrayType* selection::llvm_linearized_tile_type(ir::type *ty, LLVMContext &ctx) { - unsigned size = 1; - for(auto shape: ty->get_tile_shapes()) - size *= shape; - return ArrayType::get(llvm_type(ty->get_scalar_ty(), ctx), size); -} Function* selection::llvm_fn(ir::function *fn, IRBuilder<>& builder, Module& dst) { LLVMContext &ctx = builder.getContext(); @@ -777,6 +570,9 @@ void selection::run(ir::module &src, Module &dst) { for(ir::alloc_const *x: src.allocs()) vmap_[x] = llvm_alloc_const(x, &dst, dst_builder); + // allocate shared memory + sh_mem_ptr_ = alloc_shared(dst_builder, dst); + // iterate over functions std::set seen; @@ -785,14 +581,13 @@ void selection::run(ir::module &src, Module &dst) { // create LLVM function Function *ffn = llvm_fn(fn, dst_builder, dst); - // allocate shared memory - sh_mem_ptr_ = alloc_shared(dst_builder, dst); + // create tile + generator gen(&dst_ctx, ffn, &dst, &dst_builder, axes_, vmap_, tmap_, tgt_, layouts_, alignment_, alloc_, sh_mem_ptr_, + offset_a_i_, offset_a_k_, offset_b_j_, offset_b_k_, num_packs_0_, num_packs_1_, pack_size_0_, pack_size_1_, num_warps_ ); // initialize layouts - init_layouts(fn, dst_builder, sh_mem_ptr_); - - generator gen(&dst_ctx, ffn, &dst_builder, vmap_, tmap_, tgt_, layouts_, alignment_, alloc_, sh_mem_ptr_, - offset_a_i_, offset_a_k_, offset_b_j_, offset_b_k_, num_packs_0_, num_packs_1_, pack_size_0_, pack_size_1_, num_warps_ ); + for(auto x: layouts_->get_all()) + x.second->accept(&gen); // generate LLVM-IR code std::map last_block; @@ -1536,6 +1331,179 @@ Type *generator::type(ir::type *ty) { throw std::runtime_error("unknown conversion from ir::type to Type"); } +void generator::visit_undef_value(ir::undef_value *ud) { + vmap_[ud] = llvm::UndefValue::get(type(ud->get_type())); +} + +void generator::visit_constant_int(ir::constant_int *cst){ + Type *ty = type(cst->get_type()->get_scalar_ty()); + vmap_[cst] = ConstantInt::get(ty, cst->get_value()); +} + +void generator::visit_constant_fp(ir::constant_fp *cst){ + Type *ty = type(cst->get_type()->get_scalar_ty()); + vmap_[cst] = ConstantFP::get(ty, cst->get_value()); +} + +void generator::visit_alloc_const(ir::alloc_const *alloc) { + unsigned size = ((ir::constant_int*)alloc->get_operand(0))->get_value(); + Type *element_ty = type(alloc->get_type()->get_pointer_element_ty()); + Type *array_ty = llvm::ArrayType::get(element_ty, size); + Value *array = new llvm::GlobalVariable(*mod_, array_ty, false, llvm::GlobalVariable::ExternalLinkage, + nullptr, alloc->get_name(), nullptr, llvm::GlobalVariable::NotThreadLocal, 4); + vmap_[alloc] = builder_->CreateBitCast(array, element_ty->getPointerTo(4)); +} + + +void generator::visit_function(ir::function*) { + +} + +void generator::visit_layout_hmma_884(analysis::layout_hmma_884_t* layout) { + Value *warp_size = builder_->getInt32(32); + Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0); + Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size); + Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size); + + const auto& shapes = layout->shapes; + if(shapes.size() > 3) + throw std::runtime_error("unsupported"); + + bool is_batched = shapes.size() >= 3; + + Value *_1 = builder_->getInt32(1); + Value *_2 = builder_->getInt32(2); + Value *_3 = builder_->getInt32(3); + Value *_4 = builder_->getInt32(4); + Value *_16 = builder_->getInt32(16); + + // fragments per warp + unsigned fpw_0 = layout->fpw.at(0); + unsigned fpw_1 = layout->fpw.at(1); + unsigned fpw_2 = is_batched ? layout->fpw.at(2) : 1; + // warps per tile + unsigned wpt_0 = layout->wpt.at(0); + unsigned wpt_1 = layout->wpt.at(1); + unsigned wpt_2 = is_batched ? layout->wpt.at(2) : 1; + // hmma warp tile size + unsigned hmma_wts_0 = fpw_0 * 8; + unsigned hmma_wts_1 = fpw_1 * 8; + unsigned hmma_wts_2 = is_batched ? fpw_2 : 1; + // hmma block tile size + unsigned hmma_bts_0 = hmma_wts_0 * wpt_0; + unsigned hmma_bts_1 = hmma_wts_1 * wpt_1; + unsigned hmma_bts_2 = is_batched ? hmma_wts_2 * wpt_2 : 1; + // number of repetition + unsigned num_rep_0 = shapes[0] / hmma_bts_0; + unsigned num_rep_1 = shapes[1] / hmma_bts_1; + unsigned num_rep_2 = is_batched ? shapes[2] / hmma_bts_2 : 1; + // size of each pack (interleaving) + pack_size_0_ = std::min(num_rep_0, 1); + pack_size_1_ = std::min(num_rep_1, 1); + // number of packs (interleaving) + num_packs_0_ = num_rep_0 / pack_size_0_; + num_packs_1_ = num_rep_1 / pack_size_1_; + + /* intra warp offset */ + // offset of quad in pair + Value *in_pair_off_a = builder_->CreateMul(builder_->CreateUDiv(builder_->CreateAnd(u_thread_id, _16), builder_->getInt32(4)), + builder_->getInt32(fpw_0 * pack_size_0_)); + Value *in_pair_off_b = builder_->CreateMul(builder_->CreateUDiv(builder_->CreateAnd(u_thread_id, _16), builder_->getInt32(4)), + builder_->getInt32(fpw_1 * pack_size_1_)); + + // Quad pair id + Value *pair_a_id = builder_->CreateUDiv(builder_->CreateURem(u_thread_id, _16), _4); + Value *pair_b_id = builder_->CreateUDiv(builder_->CreateURem(u_thread_id, _16), _4); + pair_a_id = builder_->CreateURem(pair_a_id, builder_->getInt32(fpw_0)); + pair_b_id = builder_->CreateUDiv(pair_b_id, builder_->getInt32(fpw_0)); + pair_b_id = builder_->CreateURem(pair_b_id, builder_->getInt32(fpw_1)); + // Quad pair offset + Value *pair_a_off = builder_->CreateMul(pair_a_id, builder_->getInt32(4 * pack_size_0_)); + Value *pair_b_off = builder_->CreateMul(pair_b_id, builder_->getInt32(4 * pack_size_1_)); + + /* inter warp offset */ + Value *warp_id_0 = builder_->CreateURem(u_warp_id, builder_->getInt32(wpt_0)); + Value *warp_id_12 = builder_->CreateUDiv(u_warp_id, builder_->getInt32(wpt_0)); + Value *warp_id_1 = builder_->CreateURem(warp_id_12, builder_->getInt32(wpt_1)); + Value *warp_id_2 = builder_->CreateUDiv(warp_id_12, builder_->getInt32(wpt_1)); + Value *warp_offset_i = builder_->CreateMul(warp_id_0, builder_->getInt32(hmma_wts_0 * pack_size_0_)); + Value *warp_offset_j = builder_->CreateMul(warp_id_1, builder_->getInt32(hmma_wts_1 * pack_size_1_)); + + /* offsets */ + // a offset + offset_a_i_ = builder_->CreateAdd(warp_offset_i, builder_->CreateAdd(pair_a_off, in_pair_off_a)); + offset_a_k_ = builder_->CreateAnd(u_thread_id, _3); + // b offsets + offset_b_j_ = builder_->CreateAdd(warp_offset_j, builder_->CreateAdd(pair_b_off, in_pair_off_b)); + offset_b_k_ = builder_->CreateAnd(u_thread_id, _3); + + // c offsets + Value *offset_c_i = builder_->CreateAdd(builder_->CreateAnd(u_thread_id, _1), offset_a_i_); + Value *offset_c_j = builder_->CreateAdd(builder_->CreateAnd(u_thread_id, _2), + builder_->CreateAdd(warp_offset_j, pair_b_off)); + + /* indices */ + // i indices + std::vector idx_i; + for(unsigned pack = 0; pack < num_packs_0_; pack++) + for(unsigned ii = 0; ii < pack_size_0_; ii++) + for(unsigned i = 0; i < 2; i++){ + idx_i.push_back(builder_->CreateAdd(offset_c_i, builder_->getInt32(pack*hmma_bts_0*pack_size_0_ + ii*4 + i*2))); + } + // j indices + std::vector idx_j; + for(unsigned pack = 0; pack < num_packs_1_; pack++) + for(unsigned jj = 0; jj < pack_size_1_; jj++) + for(unsigned j = 0; j < 2; j++){ + idx_j.push_back(builder_->CreateAdd(offset_c_j, builder_->getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_))); + idx_j.push_back(builder_->CreateAdd(offset_c_j, builder_->getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_ + 1))); + } + // z indices + std::vector idx_z; + for(unsigned pack = 0; pack < num_rep_2; pack++) + idx_z.push_back(builder_->CreateAdd(warp_id_2, builder_->getInt32(pack*hmma_bts_2))); + + + /* axes */ + axes_[layout->axes[0]] = distributed_axis{1, idx_i, warp_id_0}; + axes_[layout->axes[1]] = distributed_axis{1, idx_j, warp_id_1}; + if(is_batched) + axes_[layout->axes[2]] = distributed_axis{1, idx_z, warp_id_2}; +} + +void generator::visit_layout_scanline(analysis::layout_scanline_t* layout) { + Value *warp_size = builder_->getInt32(32); + Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0); + Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size); + Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size); + + auto order = layout->order; + const auto& shapes = layout->shapes; + size_t dim = shapes.size(); + std::vector nts = layout->nts; + std::vector mts = layout->mts; + Value* full_thread_id = builder_->CreateAdd(builder_->CreateMul(u_warp_id, builder_->getInt32(32)), u_thread_id); + std::vector thread_id = delinearize(full_thread_id, order, mts, *builder_); + // Create axes + for(unsigned k = 0; k < dim; k++) { + std::string str_k = std::to_string(k); + Value *contiguous_k = builder_->getInt32(nts[k]); + Value *scaled_thread_id = builder_->CreateMul(thread_id[k], contiguous_k); + unsigned per_block = nts[k] * mts[k]; + unsigned per_thread = nts[k] * shapes[k] / per_block; + std::vector idx_list(per_thread); + for(unsigned n = 0 ; n < per_thread; n++){ + unsigned offset = n / nts[k] * per_block + n % nts[k]; + idx_list[n] = builder_->CreateAdd(scaled_thread_id, builder_->getInt32(offset), "idx_" + str_k + "_" + std::to_string(n)); + } + axes_[layout->axes[k]] = distributed_axis{nts[k], idx_list, thread_id[k]}; + } +} + +void generator::visit_layout_shared(analysis::layout_shared_t*) { + +} + void generator::for_each(ir::value *x, const std::function& fn) { if(!x->get_type()->is_tile_ty()) return fn({}); diff --git a/lib/ir/constant.cc b/lib/ir/constant.cc index 0eff5261e..8a3f1a343 100644 --- a/lib/ir/constant.cc +++ b/lib/ir/constant.cc @@ -76,27 +76,6 @@ constant *constant_fp::get(type *ty, double v){ return result; } -// metaparameter -metaparameter::metaparameter(type *ty, const std::vector &space) - : constant_int(ty, 0), space_(space), has_value_(false){ } - -metaparameter* metaparameter::create(context &ctx, type *ty, unsigned lo, unsigned hi) { - context_impl *impl = ctx.p_impl.get(); - std::vector space; - for(unsigned i = lo; i <= hi; i *= 2) - space.push_back(i); - metaparameter *result = new metaparameter(ty, space); - impl->mp_constants_.push_back(result); - return result; -} - -metaparameter* metaparameter::create(context &ctx, type *ty, const std::vector &space) { - context_impl *impl = ctx.p_impl.get(); - metaparameter *result = new metaparameter(ty, space); - impl->mp_constants_.push_back(result); - return result; -} - // undef value undef_value::undef_value(type *ty)