From 9bc6df4fd15b5f77c6c25b489997719669c06c77 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Wed, 9 Oct 2019 15:05:44 -0400 Subject: [PATCH] [codegen] more cleaning --- include/triton/codegen/analysis/layout.h | 21 ++- include/triton/codegen/analysis/liveness.h | 4 +- include/triton/codegen/analysis/tiles.h | 16 +- include/triton/codegen/selection.h | 7 +- lib/codegen/analysis/layout.cc | 127 ++++++++++++- lib/codegen/analysis/liveness.cc | 13 +- lib/codegen/analysis/tiles.cc | 204 +++------------------ lib/codegen/selection.cc | 79 ++++---- lib/codegen/transform/coalesce.cc | 3 +- lib/runtime/function.cc | 4 +- 10 files changed, 226 insertions(+), 252 deletions(-) diff --git a/include/triton/codegen/analysis/layout.h b/include/triton/codegen/analysis/layout.h index f462211e8..a9d2d1a77 100644 --- a/include/triton/codegen/analysis/layout.h +++ b/include/triton/codegen/analysis/layout.h @@ -19,6 +19,20 @@ namespace codegen{ namespace analysis{ class axes; +class align; + +enum layout_type_t { + HMMA_884, + SCANLINE +}; + +struct layout_t { + layout_type_t type; + ir::value *i; + std::vector axes; + std::vector shapes; + std::vector order; +}; class layout { typedef ir::value* node_t; @@ -31,19 +45,24 @@ private: public: // constructor - layout(analysis::axes *axes); + layout(analysis::axes *axes, analysis::align *align); // accessors unsigned layout_of(ir::value *value) const; const std::vector& values_of(unsigned id) const; size_t num_layouts() const; + layout_t get(ir::value *v) const; + const std::map& get_all() const; + // execution void run(ir::module &mod); private: analysis::axes* axes_; + analysis::align* align_; tools::graph graph_; std::map groups_; std::map> values_; + std::map layouts_; }; } diff --git a/include/triton/codegen/analysis/liveness.h b/include/triton/codegen/analysis/liveness.h index f082e1cfa..57fc90b81 100644 --- a/include/triton/codegen/analysis/liveness.h +++ b/include/triton/codegen/analysis/liveness.h @@ -22,6 +22,7 @@ namespace analysis{ typedef unsigned slot_index; class tiles; +class layout; struct segment { slot_index start; @@ -72,7 +73,7 @@ private: public: - liveness(tiles *t): tiles_(t){ } + liveness(tiles *t, layout *l): tiles_(t), layouts_(l){ } // padding unsigned get_pad(ir::value *v) const { return pad_.at(v); } // buffer size @@ -92,6 +93,7 @@ public: private: // analysis tiles *tiles_; + layout *layouts_; // stuff has_storage_map_t has_dedicated_storage_; indices_map_t indices; diff --git a/include/triton/codegen/analysis/tiles.h b/include/triton/codegen/analysis/tiles.h index ca1eb0e90..fdc03cee1 100644 --- a/include/triton/codegen/analysis/tiles.h +++ b/include/triton/codegen/analysis/tiles.h @@ -5,6 +5,7 @@ #include #include #include +#include "triton/codegen/analysis/layout.h" namespace triton{ @@ -25,28 +26,22 @@ class axes; class layout; class align; -enum layout_t { - SCANLINE, - HMMA_C -}; class tiles { typedef std::map> param_map_t; private: - void init_hmma_tile(ir::value *i); - void init_scanline_tile(ir::value *i); + void init_hmma_tile(const layout_t& layout); + void init_scanline_tile(const layout_t& layout); bool is_trans(ir::value *i); public: tiles(size_t num_warps, analysis::align* align, analysis::axes* axes, analysis::layout* layout); void run(ir::module &mod); - layout_t hmma(ir::value *value); int mts(ir::value *value, unsigned ax); int nts(ir::value *value, unsigned ax); int fpw(ir::value *value, unsigned ax); int wpt(ir::value *value, unsigned ax); - std::vector order(ir::value *v); - const std::map& largest(); + private: // dependencies @@ -56,9 +51,6 @@ private: // number of warps size_t num_warps_; // tile properties - std::map largest_; - std::map> order_; - std::map hmma_; std::map fpw_; std::map wpt_; std::map mts_; diff --git a/include/triton/codegen/selection.h b/include/triton/codegen/selection.h index b505a6a29..bb03d3521 100644 --- a/include/triton/codegen/selection.h +++ b/include/triton/codegen/selection.h @@ -5,6 +5,7 @@ #include "triton/ir/module.h" #include "triton/ir/function.h" #include "triton/ir/type.h" +#include "triton/codegen/analysis/layout.h" #include "triton/codegen/transform/cts.h" @@ -171,9 +172,9 @@ private: void create_shared_tile(ir::value *v, Builder &builder, Value *sh_mem_ptr); void create_distributed_tile(ir::value *v, Builder &builder); void create_tile(ir::value *v, Builder &builder, std::set &seen, Value *sh_mem_ptr); - void init_strided_scan_axes(ir::value *i, Builder &builder, Value *u_thread_id, Value *u_warp_id); - void init_hmma_axes(ir::value *i, Builder &builder, Value *u_thread_id, Value *u_warp_id); - void init_axes(ir::value *i, Builder &builder, Value *u_thread_id, Value *u_warp_id); + void init_strided_scan_axes(const analysis::layout_t& layout, Builder &builder, Value *u_thread_id, Value *u_warp_id); + void init_hmma_axes(const analysis::layout_t& layout, Builder &builder, Value *u_thread_id, Value *u_warp_id); + void init_axes(const analysis::layout_t& layout, Builder &builder, Value *u_thread_id, Value *u_warp_id); void init_layouts(ir::function *fn, Builder &builder, Value *sh_mem_ptr); // lower scalar instruction diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc index 2397df489..2a446f3ef 100644 --- a/lib/codegen/analysis/layout.cc +++ b/lib/codegen/analysis/layout.cc @@ -1,6 +1,8 @@ #include #include +#include #include "triton/codegen/analysis/axes.h" +#include "triton/codegen/analysis/align.h" #include "triton/codegen/analysis/layout.h" #include "triton/ir/function.h" #include "triton/ir/module.h" @@ -12,8 +14,8 @@ namespace analysis{ // constructor -layout::layout(analysis::axes *axes) - : axes_(axes) { } +layout::layout(analysis::axes *axes, analysis::align *align) + : axes_(axes), align_(align) { } // get group id unsigned layout::layout_of(ir::value *value) const @@ -56,6 +58,51 @@ void layout::make_graph(ir::instruction *i) { } } +// hmma +bool is_hmma_c(ir::value *v){ + bool result = false; + if(auto *x = dynamic_cast(v)){ + ir::value *a = x->get_operand(0); + ir::type *a_ty = a->get_type(); + ir::value *b = x->get_operand(1); + ir::type *b_ty = b->get_type(); + result = a_ty->get_scalar_ty()->is_half_ty() && + b_ty->get_scalar_ty()->is_half_ty(); + } + return result; +} + +layout_t layout::get(ir::value *v) const { + return layouts_.at(groups_.at(v)); +} + +const std::map& layout::get_all() const { + return layouts_; +} + +void extract_io_use(ir::value *v, std::set& result) { + for(ir::user* u: v->get_users()){ + auto i = dynamic_cast(u); + if(i && i->get_pointer_operand() == v) + result.insert(i); + } +} + + +inline bool is_trans(ir::value *v) { + if(dynamic_cast(v)) { + return true; + } + if(auto *phi = dynamic_cast(v)) { + bool result = true; + for(ir::value *op: phi->ops()) + result = result && is_trans(op); + return result; + } + return false; +} + + void layout::run(ir::module &mod) { // make graph graph_.clear(); @@ -64,6 +111,82 @@ void layout::run(ir::module &mod) { }); // connected components graph_.connected_components(&values_, &groups_); + // create layouts + for(const auto& x: values_) { + bool hmma_c = std::any_of(x.second.begin(), x.second.end(), &is_hmma_c); + layouts_[x.first].type = hmma_c ? HMMA_884 : SCANLINE; + + } + + + /* ---- TO CLEAN ---- */ + + size_t num_groups = num_layouts(); + // helpers + auto rank = [this](ir::value* v) { + int ret = 0; + for(int s: v->get_type()->get_tile_shapes()) + ret += s > 1; + return ret; + }; + + // find out which value is the largest in each group + for(const auto& x: values_) { + auto cmp = [&rank](ir::value* x, ir::value *y) { return rank(x) < rank(y); }; + ir::value *largest = *std::max_element(x.second.begin(), x.second.end(), cmp); + layouts_[x.first].axes = axes_->get(largest); + layouts_[x.first].i = largest; + layouts_[x.first].shapes = largest->get_type()->get_tile_shapes(); + } + + + // find out the layout ordering of a group + for(size_t i = 0; i < num_groups; i++){ + std::set io; + for(ir::value* v: values_of(i)) + extract_io_use(v, io); + auto cmp = [&rank](ir::io_inst* x, ir::io_inst *y) { + return rank(x->get_pointer_operand()) < rank(y->get_pointer_operand()); + }; + auto it = std::max_element(io.begin(), io.end(), cmp); + std::vector order(layouts_[i].axes.size()); + std::iota(order.begin(), order.end(), 0); + if(it != io.end()) { + auto max_contiguous = align_->contiguous((*it)->get_pointer_operand()); + std::sort(order.begin(), order.end(), [&](unsigned a, unsigned b) { + return max_contiguous[a] > max_contiguous[b]; } + ); + } + layouts_[i].order = order; + } + // matrix multiplication optimizations + for(size_t i = 0; i < num_groups; i++){ + std::vector dots; + for(ir::value* v: values_of(i)) + if(auto *x = dynamic_cast(v)) + dots.push_back(x); + for(ir::dot_inst* dot: dots){ + ir::value* a = dot->get_operand(0); + ir::value* b = dot->get_operand(1); + if(get(dot).type == HMMA_884){ + auto a_val = values_of(layout_of(a)); + auto b_val = values_of(layout_of(b)); + for(ir::value *v: a_val) + if(auto *cts = dynamic_cast(v)) + layouts_[layout_of(a)].order = layouts_[layout_of(cts->get_operand(0))].order; + for(ir::value *v: b_val) + if(auto *cts = dynamic_cast(v)) + layouts_[layout_of(b)].order = layouts_[layout_of(cts->get_operand(0))].order; + } + else{ + std::vector col = {0, 1}; + std::vector row = {1, 0}; + layouts_[layout_of(a)].order = is_trans(a) ? row : col; + layouts_[layout_of(b)].order = is_trans(b) ? col : row; + } + } + } + } } diff --git a/lib/codegen/analysis/liveness.cc b/lib/codegen/analysis/liveness.cc index 4ca9bf96f..35c801e8f 100644 --- a/lib/codegen/analysis/liveness.cc +++ b/lib/codegen/analysis/liveness.cc @@ -4,6 +4,7 @@ #include "triton/codegen/instructions.h" #include "triton/codegen/analysis/liveness.h" #include "triton/codegen/analysis/tiles.h" +#include "triton/codegen/analysis/layout.h" #include "triton/codegen/transform/cts.h" #include "triton/ir/basic_block.h" #include "triton/ir/function.h" @@ -89,8 +90,8 @@ bool liveness::do_pad(ir::value *x) { ir::value *b = dot->get_operand(1); size_t a_previous = pad_[a]; size_t b_previous = pad_[b]; - auto a_order = tiles_->order(a); - auto b_order = tiles_->order(b); + auto a_order = layouts_->get(a).order; + auto b_order = layouts_->get(b).order; bool a_row = is_trans(a) ^ (a_order[0] == 1); bool b_row = is_trans(b) ^ (b_order[0] == 1); auto a_shapes = a->get_type()->get_tile_shapes(); @@ -108,9 +109,9 @@ bool liveness::do_pad(ir::value *x) { } // padding for copy to shared if(auto* cts = dynamic_cast(x)) { - auto cts_order = tiles_->order(cts); + auto cts_order = layouts_->get(cts).order; ir::value *arg = cts->get_operand(0); - auto arg_order = tiles_->order(arg); + auto arg_order = layouts_->get(arg).order; size_t previous = pad_[cts]; if(cts_order != arg_order) pad_[cts] = std::max(pad_[cts], 4); @@ -144,7 +145,7 @@ unsigned liveness::num_bytes(ir::value *x) { for(auto x: shapes) num_elements *= x; size_t depth; - if(tiles_->hmma(x)) + if(layouts_->get(x).type == HMMA_884) depth = tiles_->wpt(op, axis); else depth = tiles_->mts(op, axis); @@ -153,7 +154,7 @@ unsigned liveness::num_bytes(ir::value *x) { unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8; unsigned pad = pad_.at(x); if(pad > 0){ - unsigned ld = x->get_type()->get_tile_shapes()[tiles_->order(x)[0]]; + unsigned ld = x->get_type()->get_tile_shapes()[layouts_->get(x).order[0]]; num_bytes += pad * num_bytes / ld; } if(has_double(x)) diff --git a/lib/codegen/analysis/tiles.cc b/lib/codegen/analysis/tiles.cc index dcec28a6b..6a16544e6 100644 --- a/lib/codegen/analysis/tiles.cc +++ b/lib/codegen/analysis/tiles.cc @@ -23,59 +23,7 @@ tiles::tiles(size_t num_warps, analysis::align *align, analysis::axes *axes, ana num_warps_(num_warps), align_(align), axes_(axes), layout_(layout) { } -bool is_hmma_c(ir::value *v){ - bool result = false; - if(auto *x = dynamic_cast(v)){ - ir::value *a = x->get_operand(0); - ir::type *a_ty = a->get_type(); - ir::value *b = x->get_operand(1); - ir::type *b_ty = b->get_type(); - result = a_ty->get_scalar_ty()->is_half_ty() && - b_ty->get_scalar_ty()->is_half_ty(); - } - return result; -} -bool is_hmma_a_col(ir::value* v) { - for(ir::user *u: v->get_users()) - if(is_hmma_c(u)){ - ir::dot_inst* dot = (ir::dot_inst*)u; - if((v == dot->get_operand(0))) - return true; - } -} - -bool is_hmma_a_row(ir::value* v) { - for(ir::user *u: v->get_users()) - if(is_hmma_c(u)){ - ir::dot_inst* dot = (ir::dot_inst*)u; - if((v == dot->get_operand(0))) - return true; - } -} - -bool is_hmma_b_col(ir::value* v) { - for(ir::user *u: v->get_users()) - if(is_hmma_c(u)){ - ir::dot_inst* dot = (ir::dot_inst*)u; - if((v == dot->get_operand(1))) - return true; - } -} - -bool is_hmma_b_row(ir::value* v) { - for(ir::user *u: v->get_users()) - if(is_hmma_c(u)){ - ir::dot_inst* dot = (ir::dot_inst*)u; - if((v == dot->get_operand(1))) - return true; - } -} - - -layout_t tiles::hmma(ir::value *value) { - return hmma_.at(layout_->layout_of(value)); -} int tiles::mts(ir::value *value, unsigned ax) { return mts_.at(axes_->get(value, ax)); @@ -93,24 +41,15 @@ int tiles::wpt(ir::value *value, unsigned ax) { return wpt_.at(axes_->get(value, ax)); } -std::vector tiles::order(ir::value *v) { - auto ret = order_[layout_->layout_of(v)]; - return ret; -} - -const std::map& tiles::largest() { - return largest_; -} - unsigned clamp(unsigned x, unsigned lo, unsigned hi) { return std::min(std::max(x, lo), hi); } -void tiles::init_hmma_tile(ir::value *i) { - auto ord = order(i); - auto shapes = i->get_type()->get_tile_shapes(); +void tiles::init_hmma_tile(const layout_t& layout) { + auto ord = layout.order; + auto shapes = layout.i->get_type()->get_tile_shapes(); unsigned shape_0 = shapes[ord[0]]; unsigned shape_1 = shapes[ord[1]]; /* fragments per warp */ @@ -127,7 +66,7 @@ void tiles::init_hmma_tile(ir::value *i) { }while(fpw_nm1 != fpw); // store parameters for(unsigned d = 0; d < shapes.size(); d++) - fpw_[axes_->get(i, d)] = fpw[d]; + fpw_[layout.axes[d]] = fpw[d]; /* warps per tile */ // try to make things as square as possible to maximize data re-use std::vector wpt = {1, 1, 1}; @@ -141,149 +80,48 @@ void tiles::init_hmma_tile(ir::value *i) { }while(wpt_nm1 != wpt); // store parameters for(unsigned d = 0; d < shapes.size(); d++) - wpt_[axes_->get(i, d)] = wpt[d]; + wpt_[layout.axes[d]] = wpt[d]; /* sanity check */ unsigned effective_num_warps = 1; for(size_t d = 0; d < shapes.size(); d++) - effective_num_warps *= wpt_[axes_->get(i, d)]; + effective_num_warps *= wpt_[layout.axes[d]]; if(num_warps_ != effective_num_warps) throw std::runtime_error("cannot create a kernel with this amount of warps"); } -void tiles::init_scanline_tile(ir::value *i) { - auto ord = order(i); - auto shapes = i->get_type()->get_tile_shapes(); - unsigned size = i->get_type()->get_tile_num_elements(); +void tiles::init_scanline_tile(const layout_t& layout) { + auto ord = layout.order; + auto shapes = layout.shapes; + unsigned size = std::accumulate(shapes.begin(), shapes.end(), 1, std::multiplies()); unsigned ld = ord[0]; unsigned num_threads = num_warps_*32; unsigned current = num_threads; - nts_[axes_->get(i, ld)] = clamp(size / num_threads, 1, 4); - mts_[axes_->get(i, ld)] = clamp(current, 1, shapes[ld] / nts_[axes_->get(i, ld)]); - current = current / mts_[axes_->get(i, ld)]; + nts_[layout.axes[ld]] = clamp(size / num_threads, 1, 4); + mts_[layout.axes[ld]] = clamp(current, 1, shapes[ld] / nts_[layout.axes[ld]]); + current = current / mts_[layout.axes[ld]]; for(size_t d = 1; d < shapes.size(); d++){ ld = ord[d]; - nts_[axes_->get(i, ld)] = 1; - mts_[axes_->get(i, ld)] = clamp(current, 1, shapes[ld]); - current = current / mts_[axes_->get(i, ld)]; + nts_[layout.axes[ld]] = 1; + mts_[layout.axes[ld]] = clamp(current, 1, shapes[ld]); + current = current / mts_[layout.axes[ld]]; } /* sanity check */ unsigned effective_num_threads = 1; for(size_t d = 0; d < shapes.size(); d++) - effective_num_threads *= mts_[axes_->get(i, d)]; + effective_num_threads *= mts_[layout.axes[d]]; // std::cout << num_threads << " " << effective_num_threads << std::endl; if(num_threads != effective_num_threads) throw std::runtime_error("cannot create a kernel with this amount of warps"); } -void extract_io_use(ir::value *v, std::set& result) { - for(ir::user* u: v->get_users()){ - auto i = dynamic_cast(u); - if(i && i->get_pointer_operand() == v) - result.insert(i); - } -} - - -bool tiles::is_trans(ir::value *v) { - if(dynamic_cast(v)) { - return true; - } - if(auto *phi = dynamic_cast(v)) { - bool result = true; - for(ir::value *op: phi->ops()) - result = result && is_trans(op); - return result; - } - return false; -} - - void tiles::run(ir::module &) { - hmma_.clear(); - largest_.clear(); - order_.clear(); - - size_t num_groups = layout_->num_layouts(); - // helpers - auto rank = [](ir::value* v) { - int ret = 0; - for(int s: v->get_type()->get_tile_shapes()) - ret += s > 1; - return ret; - }; - - // find out which groups require hmma layout - for(size_t i = 0; i < num_groups; i++) { - const auto& values = layout_->values_of(i); - bool hmma_c = std::any_of(values.begin(), values.end(), &is_hmma_c); - if(hmma_c) hmma_[i] = HMMA_C; - else hmma_[i] = SCANLINE; - } - - // find out which value is the largest in each group - for(size_t i = 0; i < num_groups; i++) { - const auto& values = layout_->values_of(i); - auto cmp = [&rank](ir::value* x, ir::value *y) { return rank(x) < rank(y); }; - largest_[i] = *std::max_element(values.begin(), values.end(), cmp); - } - - - // find out the layout ordering of a group - for(size_t i = 0; i < num_groups; i++){ - std::set io; - for(ir::value* v: layout_->values_of(i)) - extract_io_use(v, io); - auto cmp = [&rank](ir::io_inst* x, ir::io_inst *y) { - return rank(x->get_pointer_operand()) < rank(y->get_pointer_operand()); - }; - auto it = std::max_element(io.begin(), io.end(), cmp); - std::vector order(rank(largest_[i])); - std::iota(order.begin(), order.end(), 0); - if(it != io.end()) { - auto max_contiguous = align_->contiguous((*it)->get_pointer_operand()); - std::sort(order.begin(), order.end(), [&](unsigned a, unsigned b) { - return max_contiguous[a] > max_contiguous[b]; } - ); - } - order_[i] = order; - } - // matrix multiplication optimizations - for(size_t i = 0; i < num_groups; i++){ - std::vector dots; - for(ir::value* v: layout_->values_of(i)) - if(auto *x = dynamic_cast(v)) - dots.push_back(x); - for(ir::dot_inst* dot: dots){ - ir::value* a = dot->get_operand(0); - ir::value* b = dot->get_operand(1); - if(hmma_.at(layout_->layout_of(dot)) == HMMA_C){ - auto a_val = layout_->values_of(layout_->layout_of(a)); - auto b_val = layout_->values_of(layout_->layout_of(b)); - for(ir::value *v: a_val) - if(auto *cts = dynamic_cast(v)) - order_[layout_->layout_of(a)] = order_[layout_->layout_of(cts->get_operand(0))]; - for(ir::value *v: b_val) - if(auto *cts = dynamic_cast(v)) - order_[layout_->layout_of(b)] = order_[layout_->layout_of(cts->get_operand(0))]; - } - else{ - std::vector col = {0, 1}; - std::vector row = {1, 0}; - order_[layout_->layout_of(a)] = is_trans(a) ? row : col; - order_[layout_->layout_of(b)] = is_trans(b) ? col : row; - } - } - } // tiling parameters - for(auto x: largest_){ - ir::value *i = x.second; - if(!i->get_type()->is_tile_ty()) - continue; + for(auto x: layout_->get_all()){ /* HMMA parameters*/ - if(hmma_[x.first] == HMMA_C) - init_hmma_tile(i); + if(x.second.type == HMMA_884) + init_hmma_tile(x.second); else - init_scanline_tile(i); + init_scanline_tile(x.second); } } diff --git a/lib/codegen/selection.cc b/lib/codegen/selection.cc index b5692844f..1505bcbc6 100644 --- a/lib/codegen/selection.cc +++ b/lib/codegen/selection.cc @@ -577,37 +577,36 @@ inline int32_t ceil(int32_t num, int32_t div){ return (num + div - 1)/div; } -void selection::init_strided_scan_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) { - auto order = tiles_->order(v); - const auto& shapes = v->get_type()->get_tile_shapes(); +void selection::init_strided_scan_axes(const analysis::layout_t& layout, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) { + auto order = layout.order; + const auto& shapes = layout.shapes; size_t dim = shapes.size(); - std::vector contiguous(dim); - std::vector block_size(dim); + std::vector nts(dim); + std::vector mts(dim); for(unsigned i = 0; i < shapes.size(); i++){ - contiguous[i] = tiles_->nts(v, i); - block_size[i] = tiles_->mts(v, i); + nts[i] = tiles_->nts(layout.i, i); + mts[i] = tiles_->mts(layout.i, i); } Value* full_thread_id = builder.CreateAdd(builder.CreateMul(u_warp_id, builder.getInt32(32)), u_thread_id); - std::vector thread_id = delinearize(full_thread_id, order, block_size, builder); + std::vector thread_id = delinearize(full_thread_id, order, mts, builder); // Create axes for(unsigned k = 0; k < dim; k++) { std::string str_k = std::to_string(k); - Value *contiguous_k = builder.getInt32(contiguous[k]); + Value *contiguous_k = builder.getInt32(nts[k]); Value *scaled_thread_id = builder.CreateMul(thread_id[k], contiguous_k); - unsigned per_block = contiguous[k] * block_size[k]; - unsigned per_thread = contiguous[k] * shapes[k] / per_block; + unsigned per_block = nts[k] * mts[k]; + unsigned per_thread = nts[k] * shapes[k] / per_block; std::vector idx_list(per_thread); for(unsigned n = 0 ; n < per_thread; n++){ - unsigned offset = n / contiguous[k] * per_block + n % contiguous[k]; + unsigned offset = n / nts[k] * per_block + n % nts[k]; idx_list[n] = builder.CreateAdd(scaled_thread_id, builder.getInt32(offset), "idx_" + str_k + "_" + std::to_string(n)); } - axes_[a_axes_->get(v, k)] = distributed_axis{contiguous[k], idx_list, thread_id[k]}; + axes_[layout.axes[k]] = distributed_axis{nts[k], idx_list, thread_id[k]}; } } -void selection::init_hmma_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) { -// auto order = reorder_->get_order(v); - const auto& shapes = v->get_type()->get_tile_shapes(); +void selection::init_hmma_axes(const analysis::layout_t& layout, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) { + const auto& shapes = layout.shapes; if(shapes.size() > 3) throw std::runtime_error("unsupported"); bool is_batched = shapes.size() >= 3; @@ -619,13 +618,13 @@ void selection::init_hmma_axes(ir::value *v, IRBuilder<> &builder, Value *u_thre Value *_16 = builder.getInt32(16); // fragments per warp - unsigned fpw_0 = tiles_->fpw(v, 0); - unsigned fpw_1 = tiles_->fpw(v, 1); - unsigned fpw_2 = is_batched ? tiles_->fpw(v, 2) : 1; + unsigned fpw_0 = tiles_->fpw(layout.i, 0); + unsigned fpw_1 = tiles_->fpw(layout.i, 1); + unsigned fpw_2 = is_batched ? tiles_->fpw(layout.i, 2) : 1; // warps per tile - unsigned wpt_0 = tiles_->wpt(v, 0); - unsigned wpt_1 = tiles_->wpt(v, 1); - unsigned wpt_2 = is_batched ? tiles_->wpt(v, 2) : 1; + unsigned wpt_0 = tiles_->wpt(layout.i, 0); + unsigned wpt_1 = tiles_->wpt(layout.i, 1); + unsigned wpt_2 = is_batched ? tiles_->wpt(layout.i, 2) : 1; // hmma warp tile size unsigned hmma_wts_0 = fpw_0 * 8; unsigned hmma_wts_1 = fpw_1 * 8; @@ -706,18 +705,18 @@ void selection::init_hmma_axes(ir::value *v, IRBuilder<> &builder, Value *u_thre /* axes */ - axes_[a_axes_->get(v, 0)] = distributed_axis{1, idx_i, warp_id_0}; - axes_[a_axes_->get(v, 1)] = distributed_axis{1, idx_j, warp_id_1}; + axes_[layout.axes[0]] = distributed_axis{1, idx_i, warp_id_0}; + axes_[layout.axes[1]] = distributed_axis{1, idx_j, warp_id_1}; if(is_batched) - axes_[a_axes_->get(v, 2)] = distributed_axis{1, idx_z, warp_id_2}; + axes_[layout.axes[2]] = distributed_axis{1, idx_z, warp_id_2}; } -void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) { - if(tiles_->hmma(v) == analysis::HMMA_C) - init_hmma_axes(v, builder, u_thread_id, u_warp_id); +void selection::init_axes(const analysis::layout_t& layout, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) { + if(layout.type == analysis::HMMA_884) + init_hmma_axes(layout, builder, u_thread_id, u_warp_id); else - init_strided_scan_axes(v, builder, u_thread_id, u_warp_id); + init_strided_scan_axes(layout, builder, u_thread_id, u_warp_id); } /* ------------------- @@ -727,7 +726,7 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id void selection::create_shared_tile(ir::value *v, IRBuilder<> &builder, Value *sh_mem_ptr) { if(tmap_.find(v) != tmap_.end()) return; - auto order = tiles_->order(v); + auto order = layouts_->get(v).order; auto shapes = v->get_type()->get_tile_shapes(); unsigned pad = liveness_->get_pad(v); if(pad > 0) @@ -777,7 +776,7 @@ void selection::create_distributed_tile(ir::value *v, IRBuilder<> &builder) { axes[d].values = {builder.getInt32(0)}; } } - distributed_tile *T = new distributed_tile(ty, shapes, tiles_->order(v), axes, builder, false); + distributed_tile *T = new distributed_tile(ty, shapes, layouts_->get(v).order, axes, builder, false); bool is_inserted = tmap_.insert({v, T}).second; // constant range if(is_inserted && dynamic_cast(v)){ @@ -820,7 +819,7 @@ void selection::init_layouts(ir::function *fn, IRBuilder<> &builder, Value *sh_m Value *u_thread_warp_id = builder.CreateURem(u_thread_id, warp_size); Value *u_warp_id = builder.CreateUDiv(u_thread_id, warp_size); // create grid - for(auto x: tiles_->largest()) + for(auto x: layouts_->get_all()) init_axes(x.second, builder, u_thread_warp_id, u_warp_id); // create tile std::set seen; @@ -868,7 +867,7 @@ void selection::lower_masked_store(ir::masked_store_inst *x, LLVMContext &ctx, F void selection::lower_store(ir::store_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) { distributed_tile* ptrs = (distributed_tile*)tmap_.at(x->get_pointer_operand()); tile *scalars = tmap_.at(x->get_value_operand()); -// size_t ld = tiles_->order(x->get_pointer_operand())[0]; +// size_t ld = layouts_->order(x->get_pointer_operand())[0]; // unsigned vector_size = 2; // // vectorize pointers // std::map ptr_packets; @@ -1015,9 +1014,9 @@ void selection::lower_broadcast(ir::broadcast_inst *x, LLVMContext &ctx, Functio void selection::lower_copy_to_shared(ir::copy_to_shared_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) { unsigned vector_size = 1; - auto x_order = tiles_->order(x); + auto x_order = layouts_->get(x).order; ir::value *arg = x->get_operand(0); - auto arg_order = tiles_->order(arg); + auto arg_order = layouts_->get(arg).order; // tiles shared_tile* result = (shared_tile*)tmap_.at(x); distributed_tile* in = (distributed_tile*)tmap_.at(arg); @@ -1092,8 +1091,8 @@ void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn Value* u_thread_id = tgt_->get_local_id(builder.GetInsertBlock()->getModule(), builder, 0); - auto ord_a = tiles_->order(dot->get_operand(0)); - auto ord_b = tiles_->order(dot->get_operand(1)); + auto ord_a = layouts_->get(dot->get_operand(0)).order; + auto ord_b = layouts_->get(dot->get_operand(1)).order; bool is_a_trans = is_trans(dot->get_operand(0)); bool is_b_trans = is_trans(dot->get_operand(1)); @@ -1255,7 +1254,7 @@ void selection::lower_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn, IRB if(NK != 1) { shared_tile *TA = (shared_tile*)tmap_.at(A); shared_tile *TB = (shared_tile*)tmap_.at(B); - if(tiles_->hmma(dot) == analysis::HMMA_C) + if(layouts_->get(dot).type == analysis::HMMA_884) lower_hmma_dot(dot, ctx, fn, builder, TC, TA, TB, TD, NK); else lower_scanline_dot(dot, ctx, fn, builder, TC, TA, TB, TD, NK, c_ty, f_mul_add); @@ -1271,7 +1270,7 @@ void selection::lower_masked_load(ir::masked_load_inst *x, LLVMContext &ctx, Fun // find vector size distributed_tile* result = (distributed_tile*)tmap_.at(x); ir::value *ptr = x->get_pointer_operand(); - size_t ld = tiles_->order(ptr)[0]; + size_t ld = layouts_->get(ptr).order[0]; unsigned alignment = alignment_->get(ptr, ld); unsigned vector_size = std::min(result->axis(ld).contiguous, alignment); distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr); @@ -1343,7 +1342,7 @@ void selection::lower_load(ir::load_inst *x, LLVMContext &ctx, Function *fn, IRB distributed_tile* result = (distributed_tile*)tmap_.at(x); // find vector size ir::value *ptr = x->get_pointer_operand(); - size_t ld = tiles_->order(ptr)[0]; + size_t ld = layouts_->get(ptr).order[0]; unsigned alignment = alignment_->get(ptr, ld); unsigned vector_size = std::min(result->axis(ld).contiguous, alignment); distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr); diff --git a/lib/codegen/transform/coalesce.cc b/lib/codegen/transform/coalesce.cc index c5d356d31..8db94ed45 100644 --- a/lib/codegen/transform/coalesce.cc +++ b/lib/codegen/transform/coalesce.cc @@ -83,8 +83,7 @@ void coalesce::run(ir::module &mod) { if(axes.empty()) continue; for(auto it = ++axes.rbegin(); it != axes.rend(); it++) - remat.insert(remat.begin(), - it->second.begin(), it->second.end()); + remat.insert(remat.begin(), it->second.begin(), it->second.end()); } // rematerialize values for(ir::io_inst *r: remat) { diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index 6af64c105..806f003d1 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -202,9 +202,9 @@ std::unique_ptr function::make_bin(ir::module &module, driver::c // create passes codegen::analysis::align align; codegen::analysis::axes axes; - codegen::analysis::layout layouts(&axes); + codegen::analysis::layout layouts(&axes, &align); codegen::analysis::tiles tiles(opt.num_warps, &align, &axes, &layouts); - codegen::analysis::liveness liveness(&tiles); + codegen::analysis::liveness liveness(&tiles, &layouts); codegen::analysis::allocation allocation(&liveness, &tiles); codegen::transform::membar barriers(&liveness, &allocation); codegen::transform::dce dce;