diff --git a/include/triton/codegen/analysis/allocation.h b/include/triton/codegen/analysis/allocation.h
index 49f378886..e49f5c591 100644
--- a/include/triton/codegen/analysis/allocation.h
+++ b/include/triton/codegen/analysis/allocation.h
@@ -27,14 +27,14 @@ public:
   allocation(liveness *live)
     : liveness_(live) { }
   // accessors
-  bool has_offset(const layout_t *x) const { return offsets_.find(x) != offsets_.end(); }
-  unsigned offset(const layout_t *x) const { return offsets_.at(x); }
+  bool has_offset(const data_layout *x) const { return offsets_.find(x) != offsets_.end(); }
+  unsigned offset(const data_layout *x) const { return offsets_.at(x); }
   unsigned allocated_size() const { return allocated_size_; }
   // run
   void run(ir::module& mod);
 private:
-  std::map<const layout_t*, unsigned> offsets_;
+  std::map<const data_layout*, unsigned> offsets_;
   size_t allocated_size_;
   // dependences
   liveness *liveness_;
diff --git a/include/triton/codegen/analysis/layout.h b/include/triton/codegen/analysis/layout.h
index 074bfb27c..13ddfafb4 100644
--- a/include/triton/codegen/analysis/layout.h
+++ b/include/triton/codegen/analysis/layout.h
@@ -22,11 +22,106 @@ namespace analysis{

 class axes;
 class align;
+class layout_visitor;
+class data_layout;
+class mma884_layout;
+class scanline_layout;
+class shared_layout;

-enum layout_type_t {
-  HMMA_884,
-  SCANLINE,
-  SHARED
+
+class layout_visitor {
+public:
+  virtual void visit_layout(data_layout *);
+  virtual void visit_layout_hmma_884(mma884_layout*) = 0;
+  virtual void visit_layout_scanline(scanline_layout*) = 0;
+  virtual void visit_layout_shared(shared_layout*) = 0;
+};
+
+class data_layout {
+protected:
+  enum id_t {
+    HMMA_884,
+    SCANLINE,
+    SHARED
+  };
+
+  typedef std::vector<int> axes_t;
+  typedef std::vector<unsigned> shape_t;
+  typedef std::vector<int> order_t;
+  typedef std::vector<ir::value*> values_t;
+
+private:
+  template<class T>
+  T* downcast(id_t id) {
+    if(id_ == id)
+      return static_cast<T*>(this);
+    return nullptr;
+  }
+
+public:
+  data_layout(id_t id,
+              const std::vector<int>& axes,
+              const std::vector<unsigned> &shape,
+              const std::vector<ir::value *> &values,
+              analysis::align* align);
+  // visitor
+  virtual void accept(layout_visitor* vst) = 0;
+  // downcast
+  mma884_layout* to_mma884() { return downcast<mma884_layout>(HMMA_884); }
+  scanline_layout* to_scanline() { return downcast<scanline_layout>(SCANLINE); }
+  shared_layout* to_shared() { return downcast<shared_layout>(SHARED); }
+  // accessors
+  size_t get_rank() { return shape_.size(); }
+  const shape_t& get_shape() const { return shape_; }
+  const order_t& get_order() const { return order_; }
+  const values_t& get_values() const { return values_;}
+  int get_axis(size_t k) const { return axes_.at(k); }
+  const int get_order(size_t k) const { return order_.at(k); }
+  // find the position of given axis
+  size_t find_axis(int to_find) const;
+
+
+private:
+  id_t id_;
+  axes_t axes_;
+  values_t values_;
+
+protected:
+  order_t order_;
+  shape_t shape_;
+};
+
+class mma884_layout: public data_layout {
+public:
+  mma884_layout(size_t num_warps,
+                const std::vector<int>& axes,
+                const std::vector<unsigned>& shapes,
+                const std::vector<ir::value *> &values,
+                analysis::align* align);
+  void accept(layout_visitor* vst) { vst->visit_layout_hmma_884(this); }
+  // accessor
+  int fpw(size_t k) { return fpw_.at(k); }
+  int wpt(size_t k) { return wpt_.at(k); }
+
+private:
+  std::vector<int> fpw_;
+  std::vector<int> wpt_;
+};
+
+struct scanline_layout: public data_layout {
+  scanline_layout(size_t num_warps,
+                  const std::vector<int>& axes,
+                  const std::vector<unsigned>& shape,
+                  const std::vector<ir::value *> &values,
+                  analysis::align* align);
+  void accept(layout_visitor* vst) { vst->visit_layout_scanline(this); }
+  // accessor
+  int mts(size_t k) { return mts_.at(k); }
+  int nts(size_t k) { return nts_.at(k); }
+
+private:
+  std::vector<int> mts_;
+  std::vector<int> nts_;
 };

 struct double_buffer_info_t {
@@ -35,90 +130,33 @@ struct double_buffer_info_t {
   ir::phi_node* phi;
 };

-class layout_visitor;
-class layout_t;
-class layout_hmma_884_t;
-class layout_scanline_t;
-class layout_shared_t;
+class shared_layout: public data_layout {
+private:
+  static bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator);
+  static void extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_t>& res);

-
-class layout_visitor {
 public:
-  virtual void visit_layout(layout_t *);
-  virtual void visit_layout_hmma_884(layout_hmma_884_t*) = 0;
-  virtual void visit_layout_scanline(layout_scanline_t*) = 0;
-  virtual void visit_layout_shared(layout_shared_t*) = 0;
-};
-
-class layout_hmma_884_t;
-class layout_scanline_t;
-class layout_shared_t;
-
-struct layout_t {
-  layout_t(layout_type_t _type,
-           const std::vector<int>& _axes,
-           const std::vector<unsigned> &_shapes,
-           const std::vector<ir::value *> &_values,
-           ir::type *_ty,
-           analysis::align* align);
-  // visitor
-  virtual void accept(layout_visitor* vst) = 0;
-  // downcast
-  layout_hmma_884_t* to_hmma884();
-  layout_scanline_t* to_scanline();
-  layout_shared_t* to_shared();
-
-
-  layout_type_t type;
-  std::vector<int> axes;
-  std::vector<unsigned> shapes;
-  std::vector<ir::value*> values;
-  std::vector<int> order;
-  ir::type *ty;
-};
-
-struct layout_hmma_884_t: public layout_t {
-  layout_hmma_884_t(size_t num_warps,
-                    const std::vector<int>& _axes,
-                    const std::vector<unsigned>& _shapes,
-                    const std::vector<ir::value *> &_values,
-                    ir::type *_ty,
-                    analysis::align* align);
-  void accept(layout_visitor* vst) { vst->visit_layout_hmma_884(this); }
-
-  std::vector<int> fpw;
-  std::vector<int> wpt;
-};
-
-struct layout_scanline_t: public layout_t {
-  layout_scanline_t(size_t num_warps,
-                    const std::vector<int>& _axes,
-                    const std::vector<unsigned>& _shapes,
-                    const std::vector<ir::value *> &values,
-                    ir::type *_ty,
-                    analysis::align* align);
-  void accept(layout_visitor* vst) { vst->visit_layout_scanline(this); }
-
-  std::vector<int> mts;
-  std::vector<int> nts;
-};
-
-struct layout_shared_t: public layout_t {
-  layout_shared_t(const layout_t *arg,
-                  const std::vector<int>& _axes,
-                  const std::vector<unsigned>& _shapes,
-                  const std::vector<ir::value *> &values,
-                  ir::type *ty,
-                  analysis::align* align);
+  shared_layout(const data_layout *arg,
+                const std::vector<int>& axes,
+                const std::vector<unsigned>& shapes,
+                const std::vector<ir::value *> &values_,
+                ir::type *ty,
+                analysis::align* align);
   void accept(layout_visitor* vst) { vst->visit_layout_shared(this); }
+  // accessors
+  size_t get_size() { return size_; }
+  ir::type* get_type() { return ty_; }
+  double_buffer_info_t* get_double_buffer() { return double_buffer_.get(); }

-  std::shared_ptr<double_buffer_info_t> double_buffer;
-  size_t size;
+private:
+  size_t size_;
+  ir::type *ty_;
+  std::shared_ptr<double_buffer_info_t> double_buffer_;
 };

-class layout {
+class layouts {
   typedef ir::value* node_t;
   typedef std::map <node_t, std::set<node_t>> graph_t;

@@ -127,23 +165,23 @@ private:
   void connect(ir::value *x, ir::value *y);
   void make_graph(ir::instruction *i);

-  void init_hmma_tile(layout_t& layout);
-  void init_scanline_tile(layout_t &layout);
+  void init_hmma_tile(data_layout& layouts);
+  void init_scanline_tile(data_layout &layouts);

   void create(size_t id, const std::vector<ir::value*>& values);

 public:
   // constructor
-  layout(analysis::axes *axes, analysis::align *align, size_t num_warps);
+  layouts(analysis::axes *axes, analysis::align *align, size_t num_warps);

   // accessors
-  unsigned layout_of(ir::value *value) const;
-  const std::vector<ir::value*>& values_of(unsigned id) const;
-  size_t num_layouts() const;
-  layout_t* get(size_t id);
-  layout_t* get(ir::value *v);
-  std::map<size_t, layout_t*> &get_all();
-  size_t tmp(ir::instruction* i);
+  unsigned layout_of(ir::value *value) const { return groups_.at(value); }
+  const std::vector<ir::value*>& values_of(unsigned id) const { return values_.at(id); }
+  size_t num_layouts() const { return values_.size();}
+  data_layout* get(size_t id) { return layouts_.at(id); }
+  data_layout* get(ir::value *v) { return get(layout_of(v));}
+  std::map<size_t, data_layout*> &get_all() { return layouts_; }
+  size_t tmp(ir::instruction* i) { return tmp_.at((ir::value*)i);}

   // execution
   void run(ir::module &mod);

@@ -155,7 +193,7 @@ private:
   tools::graph<node_t> graph_;
   std::map<ir::value*, size_t> groups_;
   std::map<size_t, std::vector<ir::value*>> values_;
-  std::map<size_t, layout_t*> layouts_;
+  std::map<size_t, data_layout*> layouts_;
   std::map<ir::value*, size_t> tmp_;
 };

diff --git a/include/triton/codegen/analysis/liveness.h b/include/triton/codegen/analysis/liveness.h
index e0158dc8a..a95d62a06 100644
--- a/include/triton/codegen/analysis/liveness.h
+++ b/include/triton/codegen/analysis/liveness.h
@@ -23,8 +23,8 @@ namespace analysis{
 typedef unsigned slot_index;

 class tiles;
-class layout;
-class layout_t;
+class layouts;
+class data_layout;

 struct segment {
   slot_index start;
@@ -42,20 +42,20 @@ class liveness {
 private:
-  typedef std::map<layout_shared_t*, segment> intervals_map_t;
+  typedef std::map<shared_layout*, segment> intervals_map_t;

 public:
   // constructor
-  liveness(layout *l): layouts_(l){ }
+  liveness(layouts *l): layouts_(l){ }
   // accessors
   const intervals_map_t& get() const { return intervals_; }
-  segment get(layout_shared_t* v) const { return intervals_.at(v); }
+  segment get(shared_layout* v) const { return intervals_.at(v); }
   // run
   void run(ir::module &mod);

 private:
   // analysis
-  layout *layouts_;
+  layouts *layouts_;
   intervals_map_t intervals_;
 };

diff --git a/include/triton/codegen/selection/generator.h b/include/triton/codegen/selection/generator.h
index 1f18bc6e1..8b8c5bf64 100644
--- a/include/triton/codegen/selection/generator.h
+++ b/include/triton/codegen/selection/generator.h
@@ -35,7 +35,7 @@ class align;
 class allocation;
 class cts;
 class axes;
-class layout;
+class layouts;
 }
 // typedef
 typedef llvm::IRBuilder<llvm::ConstantFolder,
                         llvm::IRBuilderDefaultInserter> Builder;
 typedef std::vector<Value*> indices_t;
 // forward
-class machine_layout_t;
+class machine_data_layout;
 class tile;
 class shared_tile;
 class distributed_tile;
@@ -74,13 +74,13 @@ private:
   void visit_outer_dot(ir::dot_inst*, distributed_tile *TA, distributed_tile *TB, distributed_tile *TD, unsigned NK,
                        Type *c_ty, Function *f_mul_add);

-  void finalize_shared_layout(analysis::layout_shared_t*);
+  void finalize_shared_layout(analysis::shared_layout*);
   void finalize_function(ir::function*);
   void finalize_phi_node(ir::phi_node*);

 public:
   generator(analysis::axes *a_axes,
-            analysis::layout *layouts,
+            analysis::layouts *layouts,
             analysis::align *alignment,
             analysis::allocation *alloc,
             target *tgt,
@@ -139,9 +139,9 @@ public:
   void visit_basic_block(ir::basic_block*);
   void visit_argument(ir::argument*);

-  void visit_layout_hmma_884(analysis::layout_hmma_884_t*);
-  void visit_layout_scanline(analysis::layout_scanline_t*);
-  void visit_layout_shared(analysis::layout_shared_t*);
+  void visit_layout_hmma_884(analysis::mma884_layout*);
+  void visit_layout_scanline(analysis::scanline_layout*);
+  void visit_layout_shared(analysis::shared_layout*);

   void visit(ir::module &, llvm::Module &);

@@ -150,13 +150,13 @@ private:
   Builder* builder_;
   Module *mod_;

-  std::map<analysis::layout_t*, machine_layout_t*> machine_layouts_;
+  std::map<analysis::data_layout*, machine_data_layout*> machine_layouts_;
   analysis::axes *a_axes_;
   std::map<unsigned, distributed_axis> axes_;
   std::map<ir::value *, Value *> vmap_;
   std::map<ir::value *, tile *> tmap_;
   target *tgt_;
-  analysis::layout *layouts_;
+  analysis::layouts *layouts_;
   analysis::align *alignment_;
   analysis::allocation *alloc_;
   Value *sh_mem_ptr_;
diff --git a/include/triton/codegen/selection/machine_layout.h b/include/triton/codegen/selection/machine_layout.h
index 5ea34f3f3..5458f15d3 100644
--- a/include/triton/codegen/selection/machine_layout.h
+++ b/include/triton/codegen/selection/machine_layout.h
@@ -36,7 +36,7 @@ class align;
 class allocation;
 class cts;
 class axes;
-class layout;
+class layouts;
 }

 typedef llvm::IRBuilder<llvm::ConstantFolder,
                         llvm::IRBuilderDefaultInserter> Builder;
-  machine_layout_shared_t(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc,
-                          Value *&sh_mem_ptr, analysis::layout_shared_t *layout,
-                          std::map<ir::value *, Value *>& vmap, std::map<ir::value *, tile *>& tmap);
+  machine_shared_layout(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc,
+                        Value *&sh_mem_ptr, analysis::shared_layout *layout,
+                        std::map<ir::value *, Value *>& vmap, std::map<ir::value *, tile *>& tmap);
@@ -83,7 +83,7 @@ public:
   target *tgt_;
   analysis::allocation* alloc_;
   Value *&sh_mem_ptr_;
-  analysis::layout_shared_t* layout_;
+  analysis::shared_layout* layout_;
   std::map<ir::value *, Value *>& vmap_;
   std::map<ir::value *, tile *>& tmap_;

@@ -94,29 +94,28 @@ public:
 };

-class machine_layout_distributed_t: public machine_layout_t {
+class machine_distributed_layout: public machine_data_layout {
 public:
-  machine_layout_distributed_t(Module *mod, Builder *builder, target *tgt, Type *ty,
-                               analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
-                               analysis::layout_t* layout);
+  machine_distributed_layout(Module *mod, Builder *builder, target *tgt,
+                             analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
+                             analysis::data_layout* layout);

   tile* create(ir::value *v);
   Module *mod_;
   Builder *builder_;
   target *tgt_;
-  Type *ty_;
   analysis::axes *a_axes_;
   std::map<unsigned, distributed_axis>& axes_;
-  analysis::layout_t* layout_;
+  analysis::data_layout* layout_;
 };

-class machine_layout_hmma_884_t: public machine_layout_distributed_t {
+class machine_mma884_layout: public machine_distributed_layout {
 public:
-  machine_layout_hmma_884_t(Module *mod, Builder *builder,
-                            target *tgt, Type *ty,
-                            analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
-                            analysis::layout_hmma_884_t* layout);
+  machine_mma884_layout(Module *mod, Builder *builder,
+                        target *tgt,
+                        analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
+                        analysis::mma884_layout* layout);
   Value *offset_a_i_, *offset_a_k_;
   Value *offset_b_j_, *offset_b_k_;
   unsigned pack_size_0_;
@@ -125,12 +124,12 @@ public:
   unsigned num_packs_1_;
 };

-class machine_layout_scanline_t: public machine_layout_distributed_t {
+class machine_scanline_layout: public machine_distributed_layout {
 public:
-  machine_layout_scanline_t(Module *mod, Builder *builder,
-                            target *tgt, Type *ty,
-                            analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
-                            analysis::layout_scanline_t* layout);
+  machine_scanline_layout(Module *mod, Builder *builder,
+                          target *tgt,
+                          analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
+                          analysis::scanline_layout* layout);
 };

 }
diff --git a/include/triton/codegen/selection/machine_value.h b/include/triton/codegen/selection/machine_value.h
index aab1f023a..67c2ed394 100644
--- a/include/triton/codegen/selection/machine_value.h
+++ b/include/triton/codegen/selection/machine_value.h
@@ -47,11 +47,11 @@ class align;
 class allocation;
 class cts;
 class axes;
-class layout;
+class layouts;
 }
 class distributed_axis;
-class machine_layout_t;
+class machine_data_layout;
 class tile;
 class shared_tile;
 class distributed_tile;
diff --git a/include/triton/codegen/transform/coalesce.h b/include/triton/codegen/transform/coalesce.h
index e0ea0ea97..1b15306f1 100644
--- a/include/triton/codegen/transform/coalesce.h
+++ b/include/triton/codegen/transform/coalesce.h
@@ -19,7 +19,7 @@ namespace codegen{

 namespace analysis{
   class align;
-  class layout;
+  class layouts;
   class cts;
 }

@@ -32,12 +32,12 @@ private:
   ir::value* rematerialize(ir::value *v, ir::builder& builder, std::map<ir::value*, ir::value*>& seen);

 public:
-  coalesce(analysis::align* align, triton::codegen::analysis::layout *layouts);
+  coalesce(analysis::align* align, triton::codegen::analysis::layouts *layouts);
   void run(ir::module &mod);

 private:
   analysis::align* align_;
-  analysis::layout* layout_;
+  analysis::layouts* layout_;
 };

 }
diff --git a/include/triton/codegen/transform/membar.h b/include/triton/codegen/transform/membar.h
index 820992da7..015f44f3d 100644
--- a/include/triton/codegen/transform/membar.h
+++ b/include/triton/codegen/transform/membar.h
@@ -17,7 +17,7 @@ namespace analysis{

 class allocation;
 class liveness;
-class layout;
+class layouts;
 class cts;
 }

@@ -41,13 +41,13 @@ private:
                 std::set<ir::instruction*> &insert_loc, std::set<ir::value*> &safe_war);

 public:
-  membar(analysis::liveness *liveness, analysis::layout *layouts, analysis::allocation *alloc):
+  membar(analysis::liveness *liveness, analysis::layouts *layouts, analysis::allocation *alloc):
     liveness_(liveness), layouts_(layouts), alloc_(alloc) {}
   void run(ir::module &mod);

 private:
   analysis::liveness *liveness_;
-  analysis::layout *layouts_;
+  analysis::layouts *layouts_;
   analysis::allocation *alloc_;
 };

diff --git a/lib/codegen/analysis/allocation.cc b/lib/codegen/analysis/allocation.cc
index 0cff27640..d0a66543a 100644
--- a/lib/codegen/analysis/allocation.cc
+++ b/lib/codegen/analysis/allocation.cc
@@ -15,22 +15,22 @@ void allocation::run(ir::module &mod) {
   using std::min;
   typedef std::multimap<unsigned, segment> triples_map_type;

-  std::vector<layout_shared_t*> I;
+  std::vector<shared_layout*> I;
   for(auto x: liveness_->get())
     I.push_back(x.first);
-  std::vector<layout_shared_t*> J = I;
+  std::vector<shared_layout*> J = I;

   triples_map_type H;
   H.insert({0, segment{0, INT_MAX}});

-  std::vector<layout_shared_t*> V;
-  std::map<layout_shared_t*, unsigned> starts;
+  std::vector<shared_layout*> V;
+  std::map<shared_layout*, unsigned> starts;
   while(!J.empty()){
     auto h_it = H.begin();
     unsigned w = h_it->first;
     segment xh = h_it->second;
     H.erase(h_it);
-    auto j_it = std::find_if(J.begin(), J.end(), [&](layout_shared_t* JJ){
+    auto j_it = std::find_if(J.begin(), J.end(), [&](shared_layout* JJ){
       segment xj = liveness_->get(JJ);
       bool res = xj.intersect(xh);
       for(auto val: H)
@@ -38,7 +38,7 @@ void allocation::run(ir::module &mod) {
       return res;
     });
     if(j_it != J.end()){
-      unsigned size = (*j_it)->size;
+      unsigned size = (*j_it)->get_size();
       segment xj = liveness_->get(*j_it);
       starts[*j_it] = w;
       H.insert({w + size, segment{max(xh.start, xj.start), min(xh.end, xj.end)}});
@@ -52,14 +52,14 @@ void allocation::run(ir::module &mod) {
   }

   // Build interference graph
-  std::map<layout_shared_t*, std::set<layout_shared_t*>> interferences;
-  for(layout_shared_t* x: V)
-  for(layout_shared_t* y: V){
+  std::map<shared_layout*, std::set<shared_layout*>> interferences;
+  for(shared_layout* x: V)
+  for(shared_layout* y: V){
     if(x == y)
       continue;
     unsigned X0 = starts[x], Y0 = starts[y];
-    unsigned NX = x->size;
-    unsigned NY = y->size;
+    unsigned NX = x->get_size();
+    unsigned NY = y->get_size();
    segment XS = {X0, X0 + NX};
    segment YS = {Y0, Y0 + NY};
    if(liveness_->get(x).intersect(liveness_->get(y))
@@ -68,17 +68,17 @@ void allocation::run(ir::module &mod) {
   }

   // Initialize colors
-  std::map<layout_shared_t*, int> colors;
-  for(layout_shared_t* X: V)
+  std::map<shared_layout*, int> colors;
+  for(shared_layout* X: V)
     colors[X] = (X==V[0])?0:-1;

   // First-fit graph coloring
   std::vector<bool> available(V.size());
-  for(layout_shared_t* x: V){
+  for(shared_layout* x: V){
     // Non-neighboring colors are available
     std::fill(available.begin(), available.end(), true);
-    for(layout_shared_t* Y: interferences[x]){
+    for(shared_layout* Y: interferences[x]){
       int color = colors[Y];
       if(color >= 0)
         available[color] = false;
@@ -89,17 +89,17 @@ void allocation::run(ir::module &mod) {
   }

   // Finalize allocation
-  for(layout_shared_t* x: V){
+  for(shared_layout* x: V){
     unsigned Adj = 0;
-    for(layout_shared_t* y: interferences[x])
-      Adj = std::max(Adj, starts[y] + y->size);
+    for(shared_layout* y: interferences[x])
+      Adj = std::max(Adj, starts[y] + y->get_size());
     offsets_[x] = starts[x] + colors[x] * Adj;
   }

   // Save maximum size of induced memory space
   allocated_size_ = 0;
-  for(layout_shared_t* x: V)
-    allocated_size_ = std::max<size_t>(allocated_size_, starts[x] + x->size);
+  for(shared_layout* x: V)
+    allocated_size_ = std::max<size_t>(allocated_size_, starts[x] + x->get_size());
 }

 }
diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc
index 2136d4162..8b4a3242a 100644
--- a/lib/codegen/analysis/layout.cc
+++ b/lib/codegen/analysis/layout.cc
@@ -12,57 +12,15 @@ namespace triton{
 namespace codegen{
 namespace analysis{

+/* -------------------------------- *
+ *          Helper Functions        *
+ * -------------------------------- */

-// constructor
-layout::layout(analysis::axes *axes, analysis::align *align, size_t num_warps)
-  : axes_(axes), align_(align), num_warps_(num_warps) { }
-
-// get group id
-unsigned layout::layout_of(ir::value *value) const
-{ return groups_.at(value); }
-
-// get values
-const std::vector<ir::value*>& layout::values_of(unsigned id) const
-{ return values_.at(id); }
-
-// get number of groups
-size_t layout::num_layouts() const
-{ return values_.size(); }
-
-// connect two values
-void layout::connect(ir::value *x, ir::value *y) {
-  if(x == y)
-    return;
-  if(!x->get_type()->is_tile_ty())
-    return;
-  if(!y->get_type()->is_tile_ty())
-    return;
-  std::vector<int> x_axes = axes_->get(x);
-  std::vector<int> y_axes = axes_->get(y);
-  std::set<int> sx_axes(x_axes.begin(), x_axes.end());
-  std::set<int> sy_axes(y_axes.begin(), y_axes.end());
-  std::set<int> common;
-  std::set_intersection(sx_axes.begin(), sx_axes.end(),
-                        sy_axes.begin(), sy_axes.end(),
-                        std::inserter(common, common.begin()));
-  graph_.add_edge(x, x);
-  graph_.add_edge(y, y);
-  if(!common.empty())
-    graph_.add_edge(x, y);
+inline unsigned clamp(unsigned x, unsigned lo, unsigned hi) {
+  return std::min(std::max(x, lo), hi);
 }

-// make graph
-void layout::make_graph(ir::instruction *i) {
-  for(ir::value* opx: i->ops())
-  for(ir::value* opy: i->ops()){
-    connect(i, opx);
-    connect(opx, opy);
-  }
-}
-
-
-// hmma
-bool is_hmma_c(ir::value *v){
+inline bool is_hmma_c(ir::value *v){
   bool result = false;
   if(auto *x = dynamic_cast<ir::dot_inst*>(v)){
     ir::value *a = x->get_operand(0);
     ir::value *b = x->get_operand(1);
@@ -75,23 +33,7 @@ bool is_hmma_c(ir::value *v){
   return result;
 }

-layout_t* layout::get(size_t id) {
-  return layouts_.at(id);
-}
-
-layout_t* layout::get(ir::value *v) {
-  return layouts_.at(groups_.at(v));
-}
-
-std::map<size_t, layout_t*>& layout::get_all() {
-  return layouts_;
-}
-
-size_t layout::tmp(ir::instruction* i) {
-  return tmp_.at(i);
-}
-
-void extract_io_use(ir::value *v, std::set<ir::value*>& result) {
+inline void extract_io_use(ir::value *v, std::set<ir::value*>& result) {
   for(ir::user* u: v->get_users()){
     auto i = dynamic_cast<ir::io_inst*>(u);
     if(i && i->get_pointer_operand() == v)
@@ -99,7 +41,7 @@ void extract_io_use(ir::value *v, std::set<ir::value*>& result) {
   }
 }

-void extract_dot_use(ir::value *v, ir::value*& result, size_t n) {
+inline void extract_dot_use(ir::value *v, ir::value*& result, size_t n) {
   for(ir::user* u: v->get_users()){
     auto i = dynamic_cast<ir::dot_inst*>(u);
     if(i && i->get_operand(n) == v)
@@ -107,7 +49,7 @@ void extract_dot_use(ir::value *v, ir::value*& result, size_t n) {
     result = i;
   }
 }

-void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n) {
+inline void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n) {
   for(ir::user* u: v->get_users()){
     auto i = dynamic_cast<ir::dot_inst*>(u);
     if(i && is_hmma_c(i) && i->get_operand(n) == v)
@@ -116,7 +58,6 @@ void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n) {
     result = i;
   }
 }

-
 inline bool is_trans(ir::value *v) {
   if(dynamic_cast<ir::trans_inst *>(v)) {
     return true;
   }
@@ -131,104 +72,103 @@ inline bool is_trans(ir::value *v) {
 }

-void layout_visitor::visit_layout(layout_t *layout) {
+/* -------------------------------- *
+ *          Layout Visitor          *
+ * -------------------------------- */
+
+void layout_visitor::visit_layout(data_layout *layout) {
   layout->accept(this);
 }

-layout_t::layout_t(layout_type_t _type,
-                   const std::vector<int> &_axes,
-                   const std::vector<unsigned> &_shapes,
-                   const std::vector<ir::value *> &_values, ir::type *_ty,
-                   analysis::align* align): type(_type), axes(_axes), shapes(_shapes), values(_values), ty(_ty) {
+/* -------------------------------- *
+ *        Base Data Layout          *
+ * -------------------------------- */
+
+data_layout::data_layout(id_t id,
+                         const std::vector<int> &axes,
+                         const std::vector<unsigned> &shape,
+                         const std::vector<ir::value *> &values,
+                         analysis::align* align): id_(id), axes_(axes), shape_(shape), values_(values) {
   // io pointer
   std::set<ir::value*> ptr;
-  for(ir::value* v: values)
+  for(ir::value* v: values_)
     extract_io_use(v, ptr);
-  order.resize(axes.size());
-  std::iota(order.begin(), order.end(), 0);
+  order_.resize(axes_.size());
+  std::iota(order_.begin(), order_.end(), 0);
   auto largest = std::max_element(ptr.begin(), ptr.end(), [&](ir::value *x, ir::value *y){
     return x->get_type()->get_tile_rank() < y->get_type()->get_tile_rank();
   });
   if(*largest){
     auto max_contiguous = align->contiguous(*largest);
-    std::sort(order.begin(), order.end(), [&](unsigned a, unsigned b) {
+    std::sort(order_.begin(), order_.end(), [&](unsigned a, unsigned b) {
       return max_contiguous[a] > max_contiguous[b];
     });
   }
 }

-// downcast
-layout_hmma_884_t* layout_t::to_hmma884() {
-  assert(type == HMMA_884);
-  return static_cast<layout_hmma_884_t*>(this);
+size_t data_layout::find_axis(int to_find) const {
+  auto it = std::find(axes_.begin(), axes_.end(), to_find);
+  return std::distance(axes_.begin(), it);
 }
-layout_scanline_t* layout_t::to_scanline() {
-  assert(type == SCANLINE);
-  return static_cast<layout_scanline_t*>(this);
-}
-layout_shared_t* layout_t::to_shared() {
-  assert(type == SHARED);
-  return static_cast<layout_shared_t*>(this);
-}

+/* -------------------------------- *
+ *           MMA Layout             *
+ * -------------------------------- */

-inline unsigned clamp(unsigned x, unsigned lo, unsigned hi) {
-  return std::min(std::max(x, lo), hi);
-}
-
-layout_hmma_884_t::layout_hmma_884_t(size_t num_warps,
-                                     const std::vector<int>& _axes,
-                                     const std::vector<unsigned>& _shapes,
-                                     const std::vector<ir::value *> &values, ir::type *_ty,
-                                     analysis::align* align): layout_t(HMMA_884, _axes, _shapes, values, _ty, align) {
-  unsigned shape_0 = shapes[0];
-  unsigned shape_1 = shapes[1];
+mma884_layout::mma884_layout(size_t num_warps,
+                             const std::vector<int>& axes,
+                             const std::vector<unsigned>& shape,
+                             const std::vector<ir::value *> &values,
+                             analysis::align* align): data_layout(HMMA_884, axes, shape, values, align) {
   /* fragments per warp */
   // try to make things as square as possible to maximize data re-use
-  fpw = {1, 1, 1};
+  fpw_ = {1, 1, 1};
   std::vector<int> fpw_nm1;
-  unsigned num_fragments = std::min<unsigned>((shape_0/8)*(shape_1/8), 4);
-  do {
-    fpw_nm1 = fpw;
-    if(fpw[0]*fpw[1] < num_fragments)
-      fpw[0] = clamp(fpw[0]*2, 1, shape_0 / 8);
-    if(fpw[0]*fpw[1] < num_fragments)
-      fpw[1] = clamp(fpw[1]*2, 1, shape_1 / 8);
-  }while(fpw_nm1 != fpw);
+  unsigned num_fragments = std::min<unsigned>((shape_[0]/8)*(shape_[1]/8), 4);
+  do {
+    fpw_nm1 = fpw_;
+    if(fpw_[0]*fpw_[1] < num_fragments)
+      fpw_[0] = clamp(fpw_[0]*2, 1, shape_[0] / 8);
+    if(fpw_[0]*fpw_[1] < num_fragments)
+      fpw_[1] = clamp(fpw_[1]*2, 1, shape_[1] / 8);
+  }while(fpw_nm1 != fpw_);
+
   /* warps per tile */
   //
try to make things as square as possible to maximize data re-use
-  wpt = {1, 1, 1};
+  wpt_ = {1, 1, 1};
   std::vector<int> wpt_nm1;
-  do{
-    wpt_nm1 = wpt;
-    if(wpt[0] * wpt[1] * wpt[2] < num_warps)
-      wpt[0] = clamp(wpt[0]*2, 1, shape_0 / (fpw[0]*8));
-    if(wpt[0] * wpt[1] * wpt[2] < num_warps)
-      wpt[1] = clamp(wpt[1]*2, 1, shape_1 / (fpw[1]*8));
-  }while(wpt_nm1 != wpt);
+  do{
+    wpt_nm1 = wpt_;
+    if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
+      wpt_[0] = clamp(wpt_[0]*2, 1, shape_[0] / (fpw_[0]*8));
+    if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
+      wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / (fpw_[1]*8));
+  }while(wpt_nm1 != wpt_);

+  /* sanity check */
   unsigned effective_num_warps = 1;
-  for(size_t d = 0; d < shapes.size(); d++)
-    effective_num_warps *= wpt[d];
-
+  for(size_t d = 0; d < shape.size(); d++)
+    effective_num_warps *= wpt_[d];
   if(num_warps != effective_num_warps)
     throw std::runtime_error("cannot create a kernel with this amount of warps");
 }

+/* -------------------------------- *
+ *         Scanline Layout          *
+ * -------------------------------- */

-
-layout_scanline_t::layout_scanline_t(size_t num_warps,
-                                     const std::vector<int>& _axes,
-                                     const std::vector<unsigned>& _shapes,
-                                     const std::vector<ir::value *> &values, ir::type *_ty,
-                                     analysis::align* align): layout_t(SCANLINE, _axes, _shapes, values, _ty, align){
-  unsigned size = std::accumulate(shapes.begin(), shapes.end(), 1, std::multiplies<int>());
+scanline_layout::scanline_layout(size_t num_warps,
+                                 const std::vector<int>& axes,
+                                 const std::vector<unsigned>& shape,
+                                 const std::vector<ir::value *> &values,
+                                 analysis::align* align): data_layout(SCANLINE, axes, shape, values, align){
+  unsigned size = std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies<int>());
   unsigned num_threads = num_warps * 32;
-  nts.resize(shapes.size());
-  mts.resize(shapes.size());
+  nts_.resize(shape_.size());
+  mts_.resize(shape_.size());
   bool is_dot = std::any_of(values.begin(), values.end(),
                             [&](ir::value* v) { return dynamic_cast<ir::dot_inst*>(v); });
@@ -238,34 +178,39 @@ layout_scanline_t::layout_scanline_t(size_t num_warps,
     if(auto *st = dynamic_cast<ir::store_inst*>(usr))
       ptr = st->get_pointer_operand();

-  unsigned i = order[0];
+  unsigned i = order_[0];
   int contiguous = 4;
   if(ptr)
     contiguous = std::min<int>(align->contiguous(ptr)[i], 4);

-  nts[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shapes[i]));
-  mts[i] = clamp(num_threads, 1, shapes[i] / nts[i]);
-  size /= shapes[i];
-  num_threads /= mts[i];
+  nts_[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shape_[i]));
+  mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
+  size /= shape_[i];
+  num_threads /= mts_[i];
   if(is_dot)
-    nts[order[1]] = clamp(size / num_threads, 1, std::min<int>(4, shapes[order[1]]));
-  for(size_t d = 1; d < shapes.size(); d++){
-    i = order[d];
+    nts_[order_[1]] = clamp(size / num_threads, 1, std::min<int>(4, shape_[order_[1]]));
+  for(size_t d = 1; d < shape_.size(); d++){
+    i = order_[d];
     if(d > 1 || !is_dot)
-      nts[i] = 1;
-    mts[i] = clamp(num_threads, 1, shapes[i] / nts[i]);
-    num_threads = num_threads / mts[i];
+      nts_[i] = 1;
+    mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
+    num_threads = num_threads / mts_[i];
   }
   /* sanity check */
   unsigned effective_num_threads = 1;
-  for(size_t d = 0; d < shapes.size(); d++)
-    effective_num_threads *= mts[d];
+  for(size_t d = 0; d < shape_.size(); d++)
+    effective_num_threads *= mts_[d];
   if(num_warps * 32 != effective_num_threads)
     throw std::runtime_error("cannot create a kernel with this amount of warps");
 }

-inline bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
+
+/* -------------------------------- *
+ *          Shared Layout           *
+ * -------------------------------- */

+bool shared_layout::is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
   if(phi->get_parent() != terminator->get_parent())
     return false;
   if(auto *br = dynamic_cast<ir::cond_branch_inst*>(terminator))
@@ -278,7 +223,7 @@ inline bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
 }

-void extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_t>& res) {
+void shared_layout::extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_t>& res) {
   auto* phi = dynamic_cast<ir::phi_node*>(v);
   if(!phi || phi->get_num_incoming() != 2)
     return;
@@ -303,22 +248,22 @@ void extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_t>& res) {

-layout_shared_t::layout_shared_t(const layout_t *arg,
-                                 const std::vector<int>& _axes,
-                                 const std::vector<unsigned>& _shapes,
+shared_layout::shared_layout(const data_layout *arg,
+                             const std::vector<int>& axes,
+                             const std::vector<unsigned>& shape,
                              const std::vector<ir::value *> &values,
                              ir::type *ty,
-                             analysis::align* align): layout_t(SHARED, _axes, _shapes, values, ty, align) {
+                             analysis::align* align): data_layout(SHARED, axes, shape, values, align), ty_(ty) {

-  size = 0;
+  size_ = 0;

   // double-buffering
   for(ir::value *v: values)
-    extract_double_bufferable(v, double_buffer);
+    extract_double_bufferable(v, double_buffer_);

   // order
-  std::vector<int> arg_order = arg ? arg->order : std::vector<int>{0};
-  order = arg_order;
+  std::vector<int> arg_order = arg ? arg->get_order() : std::vector<int>{0};
+  order_ = arg_order;

   ir::value* dot_a = nullptr;
   ir::value* dot_b = nullptr;
@@ -330,48 +275,84 @@ layout_shared_t::layout_shared_t(const layout_t *arg,
     extract_hmma_dot_use(v, hmma_dot_a, 0);
     extract_hmma_dot_use(v, hmma_dot_b, 1);
   }
+
+
+  // non-mma ordering
   std::vector<int> col = {0, 1};
   std::vector<int> row = {1, 0};
-  for(size_t s = 2; s < shapes.size(); s++){
+  for(size_t s = 2; s < get_rank(); s++){
     col.push_back(s);
     row.push_back(s);
   }
-
-
   bool is_nonhmma_dot_a = dot_a && !hmma_dot_a;
   bool is_nonhmma_dot_b = dot_b && !hmma_dot_b;
   if(is_nonhmma_dot_a)
-    order = is_trans(dot_a) ? row : col;
+    order_ = is_trans(dot_a) ? row : col;
   else if(is_nonhmma_dot_b)
-    order = is_trans(dot_b) ? col : row;
-//  else
-//    order = row;
+    order_ = is_trans(dot_b) ? col : row;

+  // padding
   size_t pad = 0;
   if(hmma_dot_a){
-    bool row = is_trans(hmma_dot_a) ^ order[0] != 0;
-    pad = 24 - shapes[row ? 0 : 1] % 32;
+    bool row = is_trans(hmma_dot_a) ^ order_[0] != 0;
+    pad = 24 - shape_[row ? 0 : 1] % 32;
   }
   else if(hmma_dot_b){
-    bool row = is_trans(hmma_dot_b) ^ order[0] != 0;
-    pad = 24 - shapes[row ? 1 : 0] % 32;
+    bool row = is_trans(hmma_dot_b) ^ order_[0] != 0;
+    pad = 24 - shape_[row ? 1 : 0] % 32;
   }
-  else if(order != arg_order) {
+  else if(order_ != arg_order) {
     pad = 4;
   }
-  shapes[order[0]] += pad;
+  shape_[order_[0]] += pad;

   // size
-  size = ty->get_primitive_size_in_bits() / 8;
-  for(auto s: shapes)
-    size *= s;
-  if(double_buffer)
-    size *= 2;
+  size_ = ty_->get_primitive_size_in_bits() / 8;
+  for(auto s: shape_)
+    size_ *= s;
+  if(double_buffer_)
+    size_ *= 2;
 }

-// layout factory method
-void layout::create(size_t id, const std::vector<ir::value*>& values) {
+/* -------------------------------- *
+ * ---- Layouts Inference Pass ---- *
+ * -------------------------------- */

+layouts::layouts(analysis::axes *axes, analysis::align *align, size_t num_warps)
+  : axes_(axes), align_(align), num_warps_(num_warps) { }
+
+
+void layouts::connect(ir::value *x, ir::value *y) {
+  if(x == y)
+    return;
+  if(!x->get_type()->is_tile_ty())
+    return;
+  if(!y->get_type()->is_tile_ty())
+    return;
+  std::vector<int> x_axes = axes_->get(x);
+  std::vector<int> y_axes = axes_->get(y);
+  std::set<int> sx_axes(x_axes.begin(), x_axes.end());
+  std::set<int> sy_axes(y_axes.begin(), y_axes.end());
+  std::set<int> common;
+  std::set_intersection(sx_axes.begin(), sx_axes.end(),
+                        sy_axes.begin(), sy_axes.end(),
+                        std::inserter(common, common.begin()));
+  graph_.add_edge(x, x);
+  graph_.add_edge(y, y);
+  if(!common.empty())
+    graph_.add_edge(x, y);
+}
+
+void layouts::make_graph(ir::instruction *i) {
+  for(ir::value* opx: i->ops())
+  for(ir::value* opy: i->ops()){
+    connect(i, opx);
+    connect(opx, opy);
+  }
+}
+
+void layouts::create(size_t id, const std::vector<ir::value*>& values) {
   auto it_hmma_c = std::find_if(values.begin(), values.end(), &is_hmma_c);
   auto cmp = [](ir::value* x, ir::value *y) {
     return x->get_type()->get_tile_ranks1() <
@@ -387,18 +368,18 @@ void layout::create(size_t id, const std::vector<ir::value*>& values) {
   });
   // type
   if(it_hmma_c != values.end())
-    layouts_[id] = new layout_hmma_884_t(num_warps_, axes, shapes, values, largest->get_type()->get_scalar_ty(), align_);
+    layouts_[id] = new mma884_layout(num_warps_, axes, shapes, values, align_);
   else if(it_cts != values.end()){
     ir::copy_to_shared_inst *cts = (ir::copy_to_shared_inst*)*it_cts;
     ir::value *arg = cts->get_operand(0);
     create(groups_.at(arg), values_.at(groups_.at(arg)));
-    layouts_[id] = new layout_shared_t(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_);
+    layouts_[id] = new shared_layout(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_);
   }
   else
-    layouts_[id] = new layout_scanline_t(num_warps_, axes, shapes, values, largest->get_type()->get_scalar_ty(), align_);
+    layouts_[id] = new scanline_layout(num_warps_, axes, shapes, values, align_);
 }

-void layout::run(ir::module &mod) {
+void layouts::run(ir::module &mod) {
   // make graph
   graph_.clear();
   ir::for_each_instruction(mod, [this](ir::instruction* i) {
@@ -422,35 +403,35 @@ void layout::run(ir::module &mod) {
       // shape
       auto shapes = arg->get_type()->get_tile_shapes();
       unsigned shape_ax = shapes[axis];
-      layout_scanline_t *layout = get(arg)->to_scanline();
-      unsigned per_thread = layout->nts[axis];
+      scanline_layout *layout = get(arg)->to_scanline();
+      unsigned per_thread = layout->nts(axis);
       unsigned depth = shape_ax / per_thread;
       shapes[axis] = depth;
       // create layout
-      layouts_[id] = new layout_shared_t(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_);
+      layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_);
       tmp_[red] = id;
     }
    if(auto *recoalasce = dynamic_cast<ir::recoalesce_inst*>(i)){
      ir::value *val = recoalasce->get_operand(0);
-     layout_t* in_layout = get(val);
-     layout_t* out_layout = get(i);
-     if(in_layout->type != HMMA_884)
+     mma884_layout* in_layout = get(val)->to_mma884();
+     scanline_layout* out_layout = get(i)->to_scanline();
+     if(!in_layout || !out_layout)
        return;
      id++;
      ir::type::tile_shapes_t in_shape = val->get_type()->get_tile_shapes();
      ir::type::tile_shapes_t shape(in_shape.size());
-     size_t ld = out_layout->order[0];
+     size_t ld = out_layout->get_order(0);
      shape[ld] = in_shape[ld];
      for(size_t k = 0; k < in_shape.size(); k++)
        if(k != ld)
-         shape[k] = 4*in_layout->to_hmma884()->fpw[k]*in_layout->to_hmma884()->wpt[k];
+         shape[k] = 4*in_layout->to_mma884()->fpw(k)*in_layout->to_mma884()->wpt(k);
      // create layout
-     layouts_[id] = new layout_shared_t(out_layout, axes_->get(val), shape, {recoalasce}, val->get_type()->get_scalar_ty(), align_);
+     layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {recoalasce}, val->get_type()->get_scalar_ty(), align_);
      tmp_[recoalasce] = id;
    }
    if(auto *atom = dynamic_cast<ir::atomic_cas_inst*>(i)){
      id++;
-     layouts_[id] = new layout_shared_t(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), align_);
+     layouts_[id] = new shared_layout(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), align_);
      tmp_[atom] = id;
    }
  });
diff --git a/lib/codegen/analysis/liveness.cc b/lib/codegen/analysis/liveness.cc
index a4bb41f5e..224a93fc9 100644
--- a/lib/codegen/analysis/liveness.cc
+++ b/lib/codegen/analysis/liveness.cc
@@ -27,18 +27,18 @@ void liveness::run(ir::module &mod) {

   // create live intervals
   for(auto &x: layouts_->get_all()) {
-    if(x.second->type != SHARED)
+    shared_layout* layout = x.second->to_shared();
+    if(!layout)
       continue;
-    layout_shared_t* layout = x.second->to_shared();
     // users
     std::set<ir::user*> users;
-    for(ir::value *v: layout->values){
+    for(ir::value *v: layout->get_values()){
       for(ir::user *u: v->get_users())
         users.insert(u);
     }
     // compute intervals
     unsigned start = INT32_MAX;
-    for(ir::value *v: layout->values)
+    for(ir::value *v: layout->get_values())
       if(indices.find(v) != indices.end())
         start = std::min(start, indices.at(v));
     unsigned end = 0;
diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc
index 5cf964915..ae7d0e876 100644
--- a/lib/codegen/selection/generator.cc
+++ b/lib/codegen/selection/generator.cc
@@ -174,7 +174,7 @@ inline bool is_trans(ir::value *v) {

 generator::generator(analysis::axes *a_axes,
-                    analysis::layout *layouts,
+                    analysis::layouts *layouts,
                     analysis::align *alignment,
                     analysis::allocation *alloc,
                     target *tgt,
@@ -295,7 +295,7 @@ void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
   }
   // find vector size
   ir::value *ptr = x->get_pointer_operand();
-  size_t ld = layouts_->get(ptr)->order[0];
+  size_t ld = layouts_->get(ptr)->get_order(0);
   unsigned alignment = std::max<int>(alignment_->get(ptr, ld), 1);
@@ -337,7 +337,7 @@ void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
 void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
   // find vector size
   ir::value *ptr = x->get_pointer_operand();
-  size_t ld = layouts_->get(ptr)->order[0];
+  size_t ld = layouts_->get(ptr)->get_order(0);
   unsigned alignment = alignment_->get(ptr, ld);
   distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);
   distributed_tile *masks = (distributed_tile*)tmap_.at(x->get_mask_operand());
@@ -603,7 +603,7 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst*) {

 void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile
*TA, shared_tile *TB, distributed_tile *TD, unsigned NK) {
   const auto& shapes = dot->get_type()->get_tile_shapes();

-  machine_layout_hmma_884_t* hmma = (machine_layout_hmma_884_t*)machine_layouts_.at(layouts_->get(dot));
+  machine_mma884_layout* hmma = (machine_mma884_layout*)machine_layouts_.at(layouts_->get(dot));
   TA->set_vector_size(4*hmma->pack_size_0_);
   TB->set_vector_size(4*hmma->pack_size_1_);
   TA->set_return_mode(true);
@@ -625,8 +625,8 @@ void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *

   Value* u_thread_id = tgt_->get_local_id(builder_->GetInsertBlock()->getModule(), *builder_, 0);

-  auto ord_a = layouts_->get(dot->get_operand(0))->order;
-  auto ord_b = layouts_->get(dot->get_operand(1))->order;
+  auto ord_a = layouts_->get(dot->get_operand(0))->get_order();
+  auto ord_b = layouts_->get(dot->get_operand(1))->get_order();

   bool is_a_trans = is_trans(dot->get_operand(0));
   bool is_b_trans = is_trans(dot->get_operand(1));
@@ -655,14 +655,14 @@ void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *
                                            "{$8, $9}, "
                                            "{$10, $11}, "
                                            "{$0, $1, $2, $3, $4, $5, $6, $7};", "=f,=f,=f,=f,=f,=f,=f,=f,r,r,r,r,0,1,2,3,4,5,6,7", false);

-  analysis::layout_hmma_884_t* layout = layouts_->get(dot)->to_hmma884();
+  analysis::mma884_layout* layout = layouts_->get(dot)->to_mma884();

-  unsigned fpw_0 = layout->fpw.at(0);
-  unsigned fpw_1 = layout->fpw.at(1);
+  unsigned fpw_0 = layout->fpw(0);
+  unsigned fpw_1 = layout->fpw(1);
   unsigned wts_0 = fpw_0 * 8;
   unsigned wts_1 = fpw_1 * 8;
-  unsigned wpt_0 = layout->wpt.at(0);
-  unsigned wpt_1 = layout->wpt.at(1);
+  unsigned wpt_0 = layout->wpt(0);
+  unsigned wpt_1 = layout->wpt(1);
   unsigned stride_rep_i = wpt_0 * wts_0;
   unsigned stride_rep_j = wpt_1 * wts_1;
   unsigned num_rep_i = shapes[0] / stride_rep_i;
@@ -792,7 +792,7 @@ void generator::visit_dot_inst(ir::dot_inst* dot) {
   if(NK != 1) {
     shared_tile *TA = (shared_tile*)tmap_.at(A);
     shared_tile *TB = (shared_tile*)tmap_.at(B);
-    if(layouts_->get(dot)->type == analysis::HMMA_884)
+    if(layouts_->get(dot)->to_mma884())
       visit_hmma_dot(dot, TA, TB, TD, NK);
     else
       visit_scanline_dot(dot, TA, TB, TD, NK, c_ty, f_mul_add);
@@ -856,7 +856,7 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
   });

   // reduce within blocks
-  machine_layout_t *slayout = machine_layouts_.at(layouts_->get(layouts_->tmp(x)));
+  machine_data_layout *slayout = machine_layouts_.at(layouts_->get(layouts_->tmp(x)));
   shared_tile *stile = (shared_tile*)slayout->create(x);
   unsigned depth = stile->get_shapes()[axis];
@@ -926,31 +926,31 @@ void generator::visit_recoalesce_inst(ir::recoalesce_inst* rc) {
   // pointer to temporary shared memory
   Type *ty = llvm_type(rc->get_type()->get_scalar_ty(), *ctx_);
   // layouts
-  analysis::layout_hmma_884_t* in_layout = layouts_->get(op)->to_hmma884();
-  analysis::layout_scanline_t* out_layout = layouts_->get(rc)->to_scanline();
+  analysis::mma884_layout* in_layout = layouts_->get(op)->to_mma884();
+  analysis::scanline_layout* out_layout = layouts_->get(rc)->to_scanline();
   // machine tiles
   distributed_tile *in_dt = (distributed_tile*)(tmap_.at(op));
   distributed_tile *out_dt = (distributed_tile*)(tmap_.at(rc));
   // WMMA configuration
   long wmma_pt[3] = { 2, 4, 1};
-  long wmma[3] = { 8*in_layout->wpt[0]*in_layout->fpw[0],
-                   8*in_layout->wpt[1]*in_layout->fpw[1],
+  long wmma[3] = { 8*in_layout->wpt(0)*in_layout->fpw(0),
+                   8*in_layout->wpt(1)*in_layout->fpw(1),
                    1};
   // Work per thread for input layout
   long in_pt[3]  = { shape[0] / wmma[0],
                      shape[1] / wmma[1],
                      1 };
   // Work
per thread for output layout
-  long out_pt[3] = { shape[0] / out_layout->mts[0],
-                     shape[1] / out_layout->mts[1],
+  long out_pt[3] = { shape[0] / out_layout->mts(0),
+                     shape[1] / out_layout->mts(1),
                      1 };
   if(rank > 2){
-    wmma[2] = in_layout->wpt[2]*in_layout->fpw[2];
+    wmma[2] = in_layout->wpt(2)*in_layout->fpw(2);
     in_pt[2] = shape[2] / wmma[2];
-    out_pt[2] = shape[2] / out_layout->mts[2];
+    out_pt[2] = shape[2] / out_layout->mts(2);
   }
   // Orders
-  auto ord = out_layout->order;
+  auto ord = out_layout->get_order();
   if(ord.size() < 3)
     ord.push_back(2);
   // pointer lanes
@@ -1028,13 +1028,13 @@ void generator::visit_recoalesce_inst(ir::recoalesce_inst* rc) {
 void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
   unsigned vector_size = 1;
   ir::value *arg = cts->get_operand(0);
-  analysis::layout_shared_t* out_layout = layouts_->get(cts)->to_shared();
-  analysis::layout_scanline_t* in_layout = layouts_->get(arg)->to_scanline();
-  auto out_order = out_layout->order;
-  auto in_order = in_layout->order;
+  analysis::shared_layout* out_layout = layouts_->get(cts)->to_shared();
+  analysis::scanline_layout* in_layout = layouts_->get(arg)->to_scanline();
+  auto out_order = out_layout->get_order();
+  auto in_order = in_layout->get_order();
   // tiles
   if(out_order == in_order)
-    vector_size = in_layout->nts.at(in_order[0]);
+    vector_size = in_layout->nts(in_order[0]);

   std::map<unsigned, Value*> packets;
   for_each(arg, [&](indices_t idx){
@@ -1180,17 +1180,17 @@ void generator::visit_function(ir::function* fn) {

-void generator::visit_layout_hmma_884(analysis::layout_hmma_884_t* layout) {
-  machine_layouts_[layout] = new machine_layout_hmma_884_t(mod_, &*builder_, tgt_, llvm_type(layout->ty->get_scalar_ty(), *ctx_), a_axes_, axes_, layout);
+void generator::visit_layout_hmma_884(analysis::mma884_layout* layout) {
+  machine_layouts_[layout] = new machine_mma884_layout(mod_, &*builder_, tgt_, a_axes_, axes_, layout);
 }

-void generator::visit_layout_scanline(analysis::layout_scanline_t* layout) {
-  machine_layouts_[layout] = new machine_layout_scanline_t(mod_, &*builder_, tgt_, llvm_type(layout->ty->get_scalar_ty(), *ctx_), a_axes_, axes_, layout);
+void generator::visit_layout_scanline(analysis::scanline_layout* layout) {
+  machine_layouts_[layout] = new machine_scanline_layout(mod_, &*builder_, tgt_, a_axes_, axes_, layout);
 }

-void generator::visit_layout_shared(analysis::layout_shared_t* layout) {
+void generator::visit_layout_shared(analysis::shared_layout* layout) {

-  machine_layouts_[layout] = new machine_layout_shared_t(mod_, &*builder_, tgt_, alloc_, sh_mem_ptr_, layout, vmap_, tmap_);
+  machine_layouts_[layout] = new machine_shared_layout(mod_, &*builder_, tgt_, alloc_, sh_mem_ptr_, layout, vmap_, tmap_);
 }

 void generator::visit_basic_block(ir::basic_block * block) {
@@ -1230,9 +1230,9 @@ void generator::set_value(ir::value *x, const indices_t& idx, Value* v) {
 }

-void generator::finalize_shared_layout(analysis::layout_shared_t *shared) {
-  if(shared->double_buffer) {
-    auto info = *shared->double_buffer;
+void generator::finalize_shared_layout(analysis::shared_layout *shared) {
+  if(shared->get_double_buffer()) {
+    auto info = *shared->get_double_buffer();
     ir::phi_node *phi = info.phi;
     PHINode *ptr = (PHINode*)((shared_tile*)tmap_.at(phi))->get_pointer();
     PHINode *offset = (PHINode*)((shared_tile*)tmap_.at(phi))->get_offset();
@@ -1247,8 +1247,8 @@ void generator::finalize_shared_layout(analysis::layout_shared_t *shared) {
       offset->addIncoming(next_offset, llvm_inc_block);
     }
     else {
-      unsigned num_bytes
= shared->ty->get_primitive_size_in_bits() / 8;
-      offset->addIncoming(builder_->getInt32(shared->size / (2*num_bytes)), llvm_inc_block);
+      unsigned num_bytes = shared->get_type()->get_primitive_size_in_bits() / 8;
+      offset->addIncoming(builder_->getInt32(shared->get_size() / (2*num_bytes)), llvm_inc_block);
     }
     ptr->addIncoming(inc_shared->get_pointer(), llvm_inc_block);
   }
@@ -1258,7 +1258,7 @@ void generator::finalize_function(ir::function *fn) {
   // finalize double-buffering
   for(const auto& x: layouts_->get_all())
-    if(auto *shared = dynamic_cast<analysis::layout_shared_t*>(x.second))
+    if(auto *shared = dynamic_cast<analysis::shared_layout*>(x.second))
       finalize_shared_layout(shared);
   // finalize phi
   for(ir::basic_block *block: fn->blocks())
diff --git a/lib/codegen/selection/machine_layout.cc b/lib/codegen/selection/machine_layout.cc
index d1ea9fa0f..e1fcc8fe6 100644
--- a/lib/codegen/selection/machine_layout.cc
+++ b/lib/codegen/selection/machine_layout.cc
@@ -71,18 +71,18 @@ inline int32_t ceil(int32_t num, int32_t div){

-machine_layout_shared_t::machine_layout_shared_t(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc,
-                                                 Value *&sh_mem_ptr, analysis::layout_shared_t *layout,
+machine_shared_layout::machine_shared_layout(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc,
+                                             Value *&sh_mem_ptr, analysis::shared_layout *layout,
                                              std::map<ir::value *, Value *>& vmap, std::map<ir::value *, tile *>& tmap)
   : mod_(mod), builder_(builder), tgt_(tgt), alloc_(alloc), sh_mem_ptr_(sh_mem_ptr), layout_(layout), vmap_(vmap), tmap_(tmap) {

-  Type* ty = llvm_type(layout_->ty, builder_->getContext());
+  Type* ty = llvm_type(layout_->get_type(), builder_->getContext());
   PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr_->getType()->getPointerAddressSpace());
   // double-buffered
-  if(layout_->double_buffer) {
+  if(layout_->get_double_buffer()) {
     BasicBlock *current = builder_->GetInsertBlock();
-    auto info = *layout_->double_buffer;
+    auto info = *layout_->get_double_buffer();
     ir::phi_node *phi = info.phi;
     BasicBlock *parent = (BasicBlock*)vmap_.at((ir::value*)(phi->get_parent()));
     if(parent->empty())
@@ -105,31 +105,31 @@ machine_layout_shared_t::machine_layout_shared_t(Module *mod, Builder *builder,
 }

-tile* machine_layout_shared_t::create(ir::value *v) {
-  auto order = layout_->order;
-  auto shapes = layout_->shapes;
-  Type* ty = llvm_type(layout_->ty, builder_->getContext());
-  // double-buffered
-  if(layout_->double_buffer) {
-    if(v == layout_->double_buffer->phi)
-      return new shared_tile(ty, shapes, order, ptr_, *builder_, offset_);
-    if(v == layout_->double_buffer->latch)
-      return new shared_tile(ty, shapes, order, next_ptr_, *builder_);
-    return new shared_tile(ty, shapes, order, pre_ptr_, *builder_);
-  }
-  else {
-    return new shared_tile(ty, shapes, order, ptr_, *builder_);
-  }
+tile* machine_shared_layout::create(ir::value *v) {
+  Type* ty = llvm_type(layout_->get_type(), builder_->getContext());
+  auto double_buffer = layout_->get_double_buffer();
+  // offset
+  Value *offset = nullptr;
+  if(double_buffer && v == double_buffer->phi)
+    offset = offset_;
+  // base pointer
+  Value *ptr = ptr_;
+  if(double_buffer && v == double_buffer->latch)
+    ptr = next_ptr_;
+  else if(double_buffer && v == double_buffer->first)
+    ptr = pre_ptr_;
+  // create tile
+  return new shared_tile(ty, layout_->get_shape(), layout_->get_order(), ptr, *builder_, offset);
 }

-machine_layout_distributed_t::machine_layout_distributed_t(Module *mod, Builder *builder, target *tgt, Type *ty,
+machine_distributed_layout::machine_distributed_layout(Module *mod, Builder *builder, target *tgt,
                                                        analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
-                                                       analysis::layout_t *layout)
-  : mod_(mod), builder_(builder), tgt_(tgt), ty_(ty), a_axes_(a_axes), axes_(axes), layout_(layout) {
+                                                       analysis::data_layout *layout)
+  : mod_(mod), builder_(builder), tgt_(tgt), a_axes_(a_axes), axes_(axes), layout_(layout) {

 }

-tile *machine_layout_distributed_t::create(ir::value *v) {
+tile *machine_distributed_layout::create(ir::value *v) {
   Type *ty = llvm_type(v->get_type()->get_scalar_ty(), builder_->getContext());
   const auto &shapes = v->get_type()->get_tile_shapes();
   size_t rank = shapes.size();
@@ -151,12 +151,10 @@ tile *machine_layout_distributed_t::create(ir::value *v) {
   auto cmp = [&](int x, int y) {
     unsigned axx = a_axes_->get(v, x);
     unsigned axy = a_axes_->get(v, y);
-    auto itx = std::find(layout_->axes.begin(), layout_->axes.end(), axx);
-    auto ity = std::find(layout_->axes.begin(), layout_->axes.end(), axy);
-    size_t posx = std::distance(layout_->axes.begin(), itx);
-    size_t posy = std::distance(layout_->axes.begin(), ity);
+    size_t posx = layout_->find_axis(axx);
+    size_t posy = layout_->find_axis(axy);
     if(posx < rank && posy < rank)
-      return layout_->order[posx] < layout_->order[posy];
+      return layout_->get_order(posx) < layout_->get_order(posy);
     return false;
   };
   std::sort(order.begin(), order.end(), cmp);

   return new distributed_tile(ty, shapes, order, axes, *builder_);
 }

-machine_layout_hmma_884_t::machine_layout_hmma_884_t(Module *mod, Builder *builder,
-                          target *tgt, Type *ty, analysis::axes *a_axes,
+machine_mma884_layout::machine_mma884_layout(Module *mod, Builder *builder,
+                          target *tgt, analysis::axes *a_axes,
                           std::map<unsigned, distributed_axis>& axes,
-                          analysis::layout_hmma_884_t* layout)
-  : machine_layout_distributed_t(mod, builder, tgt, ty, a_axes, axes, layout) {
+                          analysis::mma884_layout* layout)
+  : machine_distributed_layout(mod, builder, tgt, a_axes, axes, layout) {

   Value *warp_size = builder_->getInt32(32);
   Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
   Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size);
   Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size);

-  const auto& shapes = layout->shapes;
-  if(shapes.size() > 3)
+  const auto& shape = layout->get_shape();
+  if(shape.size() > 3)
     throw std::runtime_error("unsupported");
-
-  bool is_batched = shapes.size() >= 3;
+  bool is_batched = shape.size() >= 3;

   Value *_1 = builder_->getInt32(1);
   Value *_2 = builder_->getInt32(2);
@@ -188,13 +185,13 @@ machine_layout_hmma_884_t::machine_layout_hmma_884_t(Module *mod, Builder *build
   Value *_16 = builder_->getInt32(16);

   // fragments per warp
-  unsigned fpw_0 = layout->fpw.at(0);
-  unsigned fpw_1 = layout->fpw.at(1);
-  unsigned fpw_2 = is_batched ? layout->fpw.at(2) : 1;
+  unsigned fpw_0 = layout->fpw(0);
+  unsigned fpw_1 = layout->fpw(1);
+  unsigned fpw_2 = is_batched ? layout->fpw(2) : 1;
   // warps per tile
-  unsigned wpt_0 = layout->wpt.at(0);
-  unsigned wpt_1 = layout->wpt.at(1);
-  unsigned wpt_2 = is_batched ? layout->wpt.at(2) : 1;
+  unsigned wpt_0 = layout->wpt(0);
+  unsigned wpt_1 = layout->wpt(1);
+  unsigned wpt_2 = is_batched ?
layout->wpt(2) : 1;
   // hmma warp tile size
   unsigned hmma_wts_0 = fpw_0 * 8;
   unsigned hmma_wts_1 = fpw_1 * 8;
@@ -204,9 +201,9 @@ machine_layout_hmma_884_t::machine_layout_hmma_884_t(Module *mod, Builder *build
   unsigned hmma_bts_1 = hmma_wts_1 * wpt_1;
   unsigned hmma_bts_2 = is_batched ? hmma_wts_2 * wpt_2 : 1;
   // number of repetition
-  unsigned num_rep_0 = shapes[0] / hmma_bts_0;
-  unsigned num_rep_1 = shapes[1] / hmma_bts_1;
-  unsigned num_rep_2 = is_batched ? shapes[2] / hmma_bts_2 : 1;
+  unsigned num_rep_0 = shape[0] / hmma_bts_0;
+  unsigned num_rep_1 = shape[1] / hmma_bts_1;
+  unsigned num_rep_2 = is_batched ? shape[2] / hmma_bts_2 : 1;
   // size of each pack (interleaving)
   pack_size_0_ = std::min<unsigned>(num_rep_0, 1);
   pack_size_1_ = std::min<unsigned>(num_rep_1, 1);
@@ -275,44 +272,52 @@ machine_layout_hmma_884_t::machine_layout_hmma_884_t(Module *mod, Builder *build

   /* axes */
-  axes_[layout->axes[0]] = distributed_axis{1, idx_i, warp_id_0};
-  axes_[layout->axes[1]] = distributed_axis{1, idx_j, warp_id_1};
+  axes_[layout->get_axis(0)] = distributed_axis{1, idx_i, warp_id_0};
+  axes_[layout->get_axis(1)] = distributed_axis{1, idx_j, warp_id_1};
   if(is_batched)
-    axes_[layout->axes[2]] = distributed_axis{1, idx_z, warp_id_2};
+    axes_[layout->get_axis(2)] = distributed_axis{1, idx_z, warp_id_2};
 }

-machine_layout_scanline_t::machine_layout_scanline_t(Module *mod, Builder *builder,
-                          target *tgt, Type *ty,
+machine_scanline_layout::machine_scanline_layout(Module *mod, Builder *builder,
+                          target *tgt,
                           analysis::axes *a_axes, std::map<unsigned, distributed_axis> &axes,
-                          analysis::layout_scanline_t* layout)
-  : machine_layout_distributed_t(mod, builder, tgt, ty, a_axes, axes, layout) {
+                          analysis::scanline_layout* layout)
+  : machine_distributed_layout(mod, builder, tgt, a_axes, axes, layout) {

   Value *warp_size = builder_->getInt32(32);
   Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
   Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size);
   Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size);

-  auto order = layout->order;
-  const auto& shapes = layout->shapes;
-  size_t dim = shapes.size();
-  std::vector<int> nts = layout->nts;
-  std::vector<int> mts = layout->mts;
+  auto order = layout->get_order();
+  const auto& shape = layout->get_shape();
   Value* full_thread_id = builder_->CreateAdd(builder_->CreateMul(u_warp_id, builder_->getInt32(32)), u_thread_id);
-  std::vector<Value*> thread_id = delinearize(full_thread_id, order, mts, *builder_);
+  // Delinearize
+  size_t dim = shape.size();
+  std::vector<Value*> thread_id(dim);
+  for(unsigned k = 0; k < dim - 1; k++){
+    Constant *dim_k = builder_->getInt32(layout->mts(order[k]));
+    Value *rem = builder_->CreateURem(full_thread_id, dim_k);
+    full_thread_id = builder_->CreateUDiv(full_thread_id, dim_k);
+    thread_id[order[k]] = rem;
+  }
+  thread_id[order[dim - 1]] = full_thread_id;
   // Create axes
   for(unsigned k = 0; k < dim; k++) {
+    int nts = layout->nts(k);
+    int mts = layout->mts(k);
     std::string str_k = std::to_string(k);
-    Value *contiguous_k = builder_->getInt32(nts[k]);
+    Value *contiguous_k = builder_->getInt32(nts);
     Value *scaled_thread_id = builder_->CreateMul(thread_id[k], contiguous_k);
-    unsigned per_block  = nts[k] * mts[k];
-    unsigned per_thread = nts[k] * shapes[k] / per_block;
+    unsigned per_block  = nts * mts;
+    unsigned per_thread = nts * shape[k] / per_block;
    std::vector<Value*> idx_list(per_thread);
    for(unsigned n = 0 ; n < per_thread; n++){
-      unsigned offset = n / nts[k] * per_block + n % nts[k];
+      unsigned offset = n / nts * per_block + n % nts;
      idx_list[n] = builder_->CreateAdd(scaled_thread_id, builder_->getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
    }
-    axes_[layout->axes[k]] = distributed_axis{nts[k], idx_list, thread_id[k]};
+    axes_[layout->get_axis(k)] = distributed_axis{nts, idx_list, thread_id[k]};
   }
 }
diff --git a/lib/codegen/transform/coalesce.cc b/lib/codegen/transform/coalesce.cc
index 78c03396f..14a295b00 100644
--- a/lib/codegen/transform/coalesce.cc
+++ b/lib/codegen/transform/coalesce.cc
@@ -12,7 +12,7 @@ namespace triton {
 namespace codegen{
 namespace transform{

-coalesce::coalesce(analysis::align* align, analysis::layout *layouts)
+coalesce::coalesce(analysis::align* align, analysis::layouts *layouts)
   : align_(align), layout_(layouts) { }

 // Find all values that are used as pointer operands in LD/ST
@@ -64,8 +64,9 @@ ir::value* coalesce::rematerialize(ir::value *x, ir::builder &builder,
 void coalesce::run(ir::module &mod) {
   size_t num_groups = layout_->num_layouts();
+
   for(size_t id = 0; id < num_groups; id++) {
-    if(layout_->get(id)->type != analysis::HMMA_884)
+    if(!layout_->get(id)->to_mma884())
       continue;
     // extract memory stores
     const auto& values = layout_->values_of(id);
@@ -97,7 +98,6 @@ void coalesce::run(ir::module &mod) {
     }
   }

-
   // find values to rematerialize
   std::vector<ir::io_inst*> remat;
   for(size_t id = 0; id < num_groups; id++) {
@@ -109,7 +109,7 @@ void coalesce::run(ir::module &mod) {
     // extract leading axes
     std::map<int, std::vector<ir::io_inst*>> axes;
     for(ir::io_inst *i: io){
-      if(i->get_pointer_operand()->get_type()->get_tile_ranks1() == layout_->get(id)->axes.size())
+      if(i->get_pointer_operand()->get_type()->get_tile_ranks1() == layout_->get(id)->get_rank())
         extract_ld(i, axes);
     }
     // update list of values to rematerialize
diff --git a/lib/codegen/transform/membar.cc b/lib/codegen/transform/membar.cc
index 1d9aef055..0a9b0235b 100644
--- a/lib/codegen/transform/membar.cc
+++ b/lib/codegen/transform/membar.cc
@@ -35,10 +35,11 @@ void membar::add_reference(ir::value *v, interval_vec_t &res){
     return;
   if(!i->get_type()->is_tile_ty())
     return;
-  if(alloc_->has_offset(layouts_->get(v))){
-    unsigned offset = alloc_->offset(layouts_->get(v));
-    unsigned size = layouts_->get(v)->to_shared()->size;
-    res.push_back(interval_t(offset, offset + size));
+  analysis::shared_layout* layout = layouts_->get(v)->to_shared();
+  assert(layout);
+  if(alloc_->has_offset(layout)){
+    unsigned offset = alloc_->offset(layout);
+    res.push_back(interval_t(offset, offset + layout->get_size()));
   }
 }

@@ -119,13 +120,11 @@ void membar::run(ir::module &mod) {
   // without needing synchronization
   std::set<ir::value*> safe_war;
   for(const auto& x: layouts_->get_all()){
-    if(x.second->type != analysis::SHARED)
+    analysis::shared_layout* layout = x.second->to_shared();
+    if(!layout || !layout->get_double_buffer())
       continue;
-    analysis::layout_shared_t* layout = x.second->to_shared();
-    if(!layout->double_buffer)
-      continue;
-    for(ir::value *v: layout->values)
-      if(v != layout->double_buffer->phi)
+    for(ir::value *v: layout->get_values())
+      if(v != layout->get_double_buffer()->phi)
         safe_war.insert(v);
   }

diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc
index fe1d77b66..bdd695298 100644
--- a/lib/runtime/function.cc
+++ b/lib/runtime/function.cc
@@ -220,7 +220,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
   codegen::analysis::align align;
   codegen::analysis::axes axes;
   codegen::transform::disassociate disassociate;
-  codegen::analysis::layout layouts(&axes, &align, opt.num_warps);
+  codegen::analysis::layouts layouts(&axes, &align, opt.num_warps);
   codegen::analysis::liveness liveness(&layouts);
   codegen::analysis::allocation allocation(&liveness);
   codegen::transform::membar barriers(&liveness, &layouts, &allocation);
@@ -239,7 +239,6 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
   align.run(module);
   cts.run(module);
   axes.run(module);
-//  ir::print(module, std::cout);
   layouts.run(module);
   coalesce.run(module);
   dce.run(module);
@@ -250,15 +249,14 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
   dce.run(module);
   align.run(module);
   axes.run(module);
-//  ir::print(module, std::cout);
   layouts.run(module);
   liveness.run(module);
   allocation.run(module);
   if(allocation.allocated_size() > context->device()->max_shared_memory())
     return std::unique_ptr<driver::module>();
   barriers.run(module);
-//  ir::print(module, std::cout);
   isel.visit(module, *llvm);
+
   // return binary
   std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
   // done
diff --git a/python/examples/einsum.py b/python/examples/einsum.py
index a3fdba5e0..f61347d0c 100644
--- a/python/examples/einsum.py
+++ b/python/examples/einsum.py
@@ -79,7 +79,7 @@ for N, T, H, S, E in NTHSE:

 # 1D Dense convolution
 NCHKR = [
-   # (1, 1152, 12602, 512, 3)
+   (1, 1152, 12602, 512, 3)
 ]
 for N, C, H, K, R in NCHKR:
   torch_fn = lambda a, b: torch.nn.functional.conv1d(a, b.permute(2, 0, 1))
@@ -92,10 +92,10 @@ for N, C, H, K, R in NCHKR:

 # 2D Dense convolution
 NCHWKRS = [
-   #(8, 64, 128, 128, 768, 3, 3),
-   #(8, 128, 64, 64, 256, 3, 3),
-   #(8, 256, 32, 32, 512, 3, 3),
-   #(8, 512, 32, 32, 1024, 3, 3)
+   (8, 64, 128, 128, 768, 3, 3),
+   (8, 128, 64, 64, 256, 3, 3),
+   (8, 256, 32, 32, 512, 3, 3),
+   (8, 512, 32, 32, 1024, 3, 3)
 ]
 for N, C, H, W, K, R, S in NCHWKRS:
   torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2))
@@ -108,10 +108,10 @@ for N, C, H, W, K, R, S in NCHWKRS:

 # 3D Dense Convolution
 NCDHWKTRS = [
-   #(8, 32, 27, 100, 100, 64, 3, 3, 3),
-   #(8, 64, 23, 48, 48, 256, 3, 3, 3),
-   #(8, 256, 19, 22, 22, 640, 3, 3, 3),
-   #(8, 640, 15, 36, 36, 384, 3, 3, 3)
+   (8, 32, 27, 100, 100, 64, 3, 3, 3),
+   (8, 64, 23, 48, 48, 256, 3, 3, 3),
+   (8, 256, 19, 22, 22, 640, 3, 3, 3),
+   (8, 640, 15, 36, 36, 384, 3, 3, 3)
 ]
 for N, C, D, H, W, K, T, R, S in NCDHWKTRS:
   torch_fn = lambda a, b: torch.nn.functional.conv3d(a, b.permute(4, 0, 1, 2, 3))
@@ -168,7 +168,7 @@ for N, C, H, W, K, R, S in NCHWKRS:
 # Benchmark
 torch.set_num_threads(1)
 for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
-  dtype = torch.cuda.FloatTensor
+  dtype = torch.cuda.HalfTensor
   # initialize input tensors
   a = torch.rand(*a_shape).type(dtype).cuda()
   b = torch.rand(*b_shape).type(dtype).cuda()
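The pattern at the heart of this refactor — replacing `layout->type != HMMA_884` enum checks with checked downcasts that return `nullptr` on mismatch — can be exercised in isolation. The sketch below is illustrative only: it borrows the class names from the diff, but everything else (constructors, the `main` driver) is simplified and not part of the patch.

// downcast_sketch.cc -- minimal sketch of the id-based checked downcast
// adopted by data_layout in this diff (illustrative, not from the patch)
#include <cassert>
#include <cstdio>

class mma884_layout;
class scanline_layout;

class data_layout {
protected:
  // each concrete layout registers its kind once, at construction
  enum id_t { HMMA_884, SCANLINE, SHARED };

private:
  template<class T>
  T* downcast(id_t id) {
    // checked downcast: nullptr on kind mismatch, so the result can be
    // used directly as a condition at call sites
    if(id_ == id)
      return static_cast<T*>(this);
    return nullptr;
  }

public:
  data_layout(id_t id): id_(id) { }
  virtual ~data_layout() { }
  mma884_layout* to_mma884();
  scanline_layout* to_scanline();

private:
  id_t id_;
};

class mma884_layout: public data_layout {
public:
  mma884_layout(): data_layout(HMMA_884) { }
};

class scanline_layout: public data_layout {
public:
  scanline_layout(): data_layout(SCANLINE) { }
};

inline mma884_layout* data_layout::to_mma884() { return downcast<mma884_layout>(HMMA_884); }
inline scanline_layout* data_layout::to_scanline() { return downcast<scanline_layout>(SCANLINE); }

int main() {
  scanline_layout s;
  data_layout* l = &s;
  if(!l->to_mma884())                   // old style: l->type != HMMA_884
    std::puts("not an HMMA-884 layout");
  assert(l->to_scanline() != nullptr);  // old style: enum check + static_cast
  return 0;
}

The payoff is visible throughout the diff: call sites such as `coalesce::run` and `membar::run` collapse a kind check followed by an asserting cast into a single call whose result doubles as the condition.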