From a3f76b6eb1515f20eb3a51abe7b53b450e359208 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Wed, 9 Oct 2019 18:17:48 -0400 Subject: [PATCH] [codegen] more cleaning --- include/triton/codegen/analysis/allocation.h | 5 +- include/triton/codegen/analysis/layout.h | 15 ++- include/triton/codegen/analysis/liveness.h | 3 +- include/triton/codegen/analysis/tiles.h | 65 ---------- include/triton/codegen/selection.h | 5 +- include/triton/runtime/function.h | 9 -- include/triton/tools/bench.hpp | 2 +- lib/codegen/analysis/allocation.cc | 1 - lib/codegen/analysis/layout.cc | 116 ++++++++++++++--- lib/codegen/analysis/liveness.cc | 5 +- lib/codegen/analysis/tiles.cc | 130 ------------------- lib/codegen/selection.cc | 29 ++--- lib/runtime/function.cc | 27 ++-- 13 files changed, 142 insertions(+), 270 deletions(-) delete mode 100644 include/triton/codegen/analysis/tiles.h delete mode 100644 lib/codegen/analysis/tiles.cc diff --git a/include/triton/codegen/analysis/allocation.h b/include/triton/codegen/analysis/allocation.h index b23f11964..858152150 100644 --- a/include/triton/codegen/analysis/allocation.h +++ b/include/triton/codegen/analysis/allocation.h @@ -24,8 +24,8 @@ class cts; class allocation { public: - allocation(liveness *live, tiles *params) - : liveness_(live), tiles_(params){ } + allocation(liveness *live) + : liveness_(live) { } // accessors bool has_offset(ir::value *x) const { return offsets_.find(x) != offsets_.end(); } unsigned offset(ir::value *x) const { return offsets_.at(x); } @@ -38,7 +38,6 @@ private: size_t allocated_size_; // dependences liveness *liveness_; - tiles *tiles_; }; } diff --git a/include/triton/codegen/analysis/layout.h b/include/triton/codegen/analysis/layout.h index a9d2d1a77..bec751659 100644 --- a/include/triton/codegen/analysis/layout.h +++ b/include/triton/codegen/analysis/layout.h @@ -28,10 +28,13 @@ enum layout_type_t { struct layout_t { layout_type_t type; - ir::value *i; std::vector axes; std::vector shapes; std::vector order; + std::map mts; + std::map nts; + std::map fpw; + std::map wpt; }; class layout { @@ -43,15 +46,18 @@ private: void connect(ir::value *x, ir::value *y); void make_graph(ir::instruction *i); + void init_hmma_tile(layout_t& layout); + void init_scanline_tile(layout_t &layout); + public: // constructor - layout(analysis::axes *axes, analysis::align *align); + layout(analysis::axes *axes, analysis::align *align, size_t num_warps); // accessors unsigned layout_of(ir::value *value) const; const std::vector& values_of(unsigned id) const; size_t num_layouts() const; - layout_t get(ir::value *v) const; - const std::map& get_all() const; + const layout_t& get(ir::value *v) const; + std::map &get_all(); // execution void run(ir::module &mod); @@ -59,6 +65,7 @@ public: private: analysis::axes* axes_; analysis::align* align_; + size_t num_warps_; tools::graph graph_; std::map groups_; std::map> values_; diff --git a/include/triton/codegen/analysis/liveness.h b/include/triton/codegen/analysis/liveness.h index 57fc90b81..b23463f06 100644 --- a/include/triton/codegen/analysis/liveness.h +++ b/include/triton/codegen/analysis/liveness.h @@ -73,7 +73,7 @@ private: public: - liveness(tiles *t, layout *l): tiles_(t), layouts_(l){ } + liveness(layout *l): layouts_(l){ } // padding unsigned get_pad(ir::value *v) const { return pad_.at(v); } // buffer size @@ -92,7 +92,6 @@ public: private: // analysis - tiles *tiles_; layout *layouts_; // stuff has_storage_map_t has_dedicated_storage_; diff --git a/include/triton/codegen/analysis/tiles.h b/include/triton/codegen/analysis/tiles.h deleted file mode 100644 index fdc03cee1..000000000 --- a/include/triton/codegen/analysis/tiles.h +++ /dev/null @@ -1,65 +0,0 @@ -#ifndef _TRITON_CODEGEN_ANALYSIS_TILES_H_ -#define _TRITON_CODEGEN_ANALYSIS_TILES_H_ - -#include -#include -#include -#include -#include "triton/codegen/analysis/layout.h" - -namespace triton{ - -namespace ir{ - class value; - class module; - class instruction; - class function; - class metaparameter; - class constant_int; -} - -namespace codegen{ - -namespace analysis{ - -class axes; -class layout; -class align; - - -class tiles { - typedef std::map> param_map_t; -private: - void init_hmma_tile(const layout_t& layout); - void init_scanline_tile(const layout_t& layout); - bool is_trans(ir::value *i); - -public: - tiles(size_t num_warps, analysis::align* align, analysis::axes* axes, analysis::layout* layout); - void run(ir::module &mod); - int mts(ir::value *value, unsigned ax); - int nts(ir::value *value, unsigned ax); - int fpw(ir::value *value, unsigned ax); - int wpt(ir::value *value, unsigned ax); - - -private: - // dependencies - analysis::align* align_; - analysis::layout* layout_; - analysis::axes* axes_; - // number of warps - size_t num_warps_; - // tile properties - std::map fpw_; - std::map wpt_; - std::map mts_; - std::map nts_; -}; - - -} -} -} - -#endif diff --git a/include/triton/codegen/selection.h b/include/triton/codegen/selection.h index bb03d3521..b20bc6d51 100644 --- a/include/triton/codegen/selection.h +++ b/include/triton/codegen/selection.h @@ -210,10 +210,10 @@ private: public: - selection(analysis::liveness* liveness, analysis::allocation *alloc, analysis::tiles *tiles, + selection(analysis::liveness* liveness, analysis::allocation *alloc, analysis::align *alignment, analysis::axes *axes, analysis::layout *layouts, target *tgt, unsigned num_warps) - : liveness_(liveness), alloc_(alloc), tiles_(tiles), + : liveness_(liveness), alloc_(alloc), alignment_(alignment), a_axes_(axes), layouts_(layouts), tgt_(tgt), num_warps_(num_warps){ } @@ -224,7 +224,6 @@ private: tmap_t tmap_; analysis::liveness *liveness_; analysis::allocation *alloc_; - analysis::tiles *tiles_; analysis::axes *a_axes_; analysis::layout *layouts_; analysis::align *alignment_; diff --git a/include/triton/runtime/function.h b/include/triton/runtime/function.h index 88de3825c..c12f9c6ca 100644 --- a/include/triton/runtime/function.h +++ b/include/triton/runtime/function.h @@ -11,15 +11,6 @@ // codegen #include "triton/codegen/selection.h" #include "triton/codegen/target.h" -#include "triton/codegen/analysis/tiles.h" -#include "triton/codegen/analysis/allocation.h" -#include "triton/codegen/analysis/liveness.h" -#include "triton/codegen/analysis/align.h" -#include "triton/codegen/transform/dce.h" -#include "triton/codegen/transform/peephole.h" -#include "triton/codegen/transform/membar.h" -#include "triton/codegen/transform/reassociate.h" -#include "triton/codegen/transform/cts.h" #include "triton/lang/parser.h" #include "triton/runtime/arg.h" diff --git a/include/triton/tools/bench.hpp b/include/triton/tools/bench.hpp index 48a4ab972..554b3bcc3 100644 --- a/include/triton/tools/bench.hpp +++ b/include/triton/tools/bench.hpp @@ -38,7 +38,7 @@ inline double bench(std::function const & op, driver::stream * stream) double total_time = 0; op(); stream->synchronize(); - while(total_time*1e-9 < 1e-2){ + while(total_time*1e-9 < 1e-3){ float norm = 1; // normalize clock if possible to reduce noise in auto-tuning if(auto cu_device = dynamic_cast(stream->context()->device())) diff --git a/lib/codegen/analysis/allocation.cc b/lib/codegen/analysis/allocation.cc index 91ca0868f..0fde814f3 100644 --- a/lib/codegen/analysis/allocation.cc +++ b/lib/codegen/analysis/allocation.cc @@ -3,7 +3,6 @@ #include "triton/codegen/analysis/allocation.h" #include "triton/codegen/analysis/liveness.h" #include "triton/codegen/transform/cts.h" -#include "triton/codegen/analysis/tiles.h" #include "triton/ir/basic_block.h" #include "triton/ir/type.h" #include "triton/ir/value.h" diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc index 2a446f3ef..4b52f9e3b 100644 --- a/lib/codegen/analysis/layout.cc +++ b/lib/codegen/analysis/layout.cc @@ -14,8 +14,8 @@ namespace analysis{ // constructor -layout::layout(analysis::axes *axes, analysis::align *align) - : axes_(axes), align_(align) { } +layout::layout(analysis::axes *axes, analysis::align *align, size_t num_warps) + : axes_(axes), align_(align), num_warps_(num_warps) { } // get group id unsigned layout::layout_of(ir::value *value) const @@ -72,19 +72,19 @@ bool is_hmma_c(ir::value *v){ return result; } -layout_t layout::get(ir::value *v) const { +const layout_t &layout::get(ir::value *v) const { return layouts_.at(groups_.at(v)); } -const std::map& layout::get_all() const { +std::map& layout::get_all() { return layouts_; } -void extract_io_use(ir::value *v, std::set& result) { +void extract_io_use(ir::value *v, std::set& result) { for(ir::user* u: v->get_users()){ auto i = dynamic_cast(u); if(i && i->get_pointer_operand() == v) - result.insert(i); + result.insert(v); } } @@ -102,6 +102,75 @@ inline bool is_trans(ir::value *v) { return false; } +inline unsigned clamp(unsigned x, unsigned lo, unsigned hi) { + return std::min(std::max(x, lo), hi); +} + +void layout::init_hmma_tile(layout_t& layout) { + auto ord = layout.order; + auto shapes = layout.shapes; + unsigned shape_0 = shapes[ord[0]]; + unsigned shape_1 = shapes[ord[1]]; + /* fragments per warp */ + // try to make things as square as possible to maximize data re-use + std::vector fpw = {1, 1, 1}; + std::vector fpw_nm1; + unsigned num_fragments = std::min((shape_0/8)*(shape_1/8), 4); + do { + fpw_nm1 = fpw; + if(fpw[0]*fpw[1] < num_fragments) + fpw[0] = clamp(fpw[0]*2, 1, shape_0 / 8); + if(fpw[0]*fpw[1] < num_fragments) + fpw[1] = clamp(fpw[1]*2, 1, shape_1 / 8); + }while(fpw_nm1 != fpw); + // store parameters + for(unsigned d = 0; d < shapes.size(); d++) + layout.fpw[d] = fpw[d]; + /* warps per tile */ + // try to make things as square as possible to maximize data re-use + std::vector wpt = {1, 1, 1}; + std::vector wpt_nm1; + do{ + wpt_nm1 = wpt; + if(wpt[0] * wpt[1] * wpt[2] < num_warps_) + wpt[0] = clamp(wpt[0]*2, 1, shape_0 / (fpw[0]*8)); + if(wpt[0] * wpt[1] * wpt[2] < num_warps_) + wpt[1] = clamp(wpt[1]*2, 1, shape_1 / (fpw[1]*8)); + }while(wpt_nm1 != wpt); + // store parameters + for(unsigned d = 0; d < shapes.size(); d++) + layout.wpt[d] = wpt[d]; + /* sanity check */ + unsigned effective_num_warps = 1; + for(size_t d = 0; d < shapes.size(); d++) + effective_num_warps *= layout.wpt[d]; + if(num_warps_ != effective_num_warps) + throw std::runtime_error("cannot create a kernel with this amount of warps"); +} + +void layout::init_scanline_tile(layout_t& layout) { + auto ord = layout.order; + auto shapes = layout.shapes; + unsigned size = std::accumulate(shapes.begin(), shapes.end(), 1, std::multiplies()); + unsigned ld = ord[0]; + unsigned num_threads = num_warps_*32; + unsigned current = num_threads; + layout.nts[ld] = clamp(size / num_threads, 1, 4); + layout.mts[ld] = clamp(current, 1, shapes[ld] / layout.nts[ld]); + current = current / layout.mts[ld]; + for(size_t d = 1; d < shapes.size(); d++){ + ld = ord[d]; + layout.nts[ld] = 1; + layout.mts[ld] = clamp(current, 1, shapes[ld]); + current = current / layout.mts[ld]; + } + /* sanity check */ + unsigned effective_num_threads = 1; + for(size_t d = 0; d < shapes.size(); d++) + effective_num_threads *= layout.mts[d]; + if(num_threads != effective_num_threads) + throw std::runtime_error("cannot create a kernel with this amount of warps"); +} void layout::run(ir::module &mod) { // make graph @@ -114,8 +183,8 @@ void layout::run(ir::module &mod) { // create layouts for(const auto& x: values_) { bool hmma_c = std::any_of(x.second.begin(), x.second.end(), &is_hmma_c); + // type layouts_[x.first].type = hmma_c ? HMMA_884 : SCANLINE; - } @@ -130,35 +199,32 @@ void layout::run(ir::module &mod) { return ret; }; - // find out which value is the largest in each group + // find out axes for each layout for(const auto& x: values_) { auto cmp = [&rank](ir::value* x, ir::value *y) { return rank(x) < rank(y); }; ir::value *largest = *std::max_element(x.second.begin(), x.second.end(), cmp); layouts_[x.first].axes = axes_->get(largest); - layouts_[x.first].i = largest; layouts_[x.first].shapes = largest->get_type()->get_tile_shapes(); } // find out the layout ordering of a group - for(size_t i = 0; i < num_groups; i++){ - std::set io; - for(ir::value* v: values_of(i)) - extract_io_use(v, io); - auto cmp = [&rank](ir::io_inst* x, ir::io_inst *y) { - return rank(x->get_pointer_operand()) < rank(y->get_pointer_operand()); - }; - auto it = std::max_element(io.begin(), io.end(), cmp); - std::vector order(layouts_[i].axes.size()); + for(const auto& x: values_) { + std::set ptr; + for(ir::value* v: x.second) + extract_io_use(v, ptr); + size_t rank = layouts_[x.first].axes.size(); + std::vector order(rank); std::iota(order.begin(), order.end(), 0); - if(it != io.end()) { - auto max_contiguous = align_->contiguous((*it)->get_pointer_operand()); + for(ir::value *v: ptr){ + auto max_contiguous = align_->contiguous(v); std::sort(order.begin(), order.end(), [&](unsigned a, unsigned b) { return max_contiguous[a] > max_contiguous[b]; } ); } - layouts_[i].order = order; + layouts_[x.first].order = order; } + // matrix multiplication optimizations for(size_t i = 0; i < num_groups; i++){ std::vector dots; @@ -187,6 +253,14 @@ void layout::run(ir::module &mod) { } } + // tiling parameters + for(auto& x: layouts_){ + /* HMMA parameters*/ + if(x.second.type == HMMA_884) + init_hmma_tile(x.second); + else + init_scanline_tile(x.second); + } } } diff --git a/lib/codegen/analysis/liveness.cc b/lib/codegen/analysis/liveness.cc index 35c801e8f..d85271553 100644 --- a/lib/codegen/analysis/liveness.cc +++ b/lib/codegen/analysis/liveness.cc @@ -3,7 +3,6 @@ #include #include "triton/codegen/instructions.h" #include "triton/codegen/analysis/liveness.h" -#include "triton/codegen/analysis/tiles.h" #include "triton/codegen/analysis/layout.h" #include "triton/codegen/transform/cts.h" #include "triton/ir/basic_block.h" @@ -146,9 +145,9 @@ unsigned liveness::num_bytes(ir::value *x) { num_elements *= x; size_t depth; if(layouts_->get(x).type == HMMA_884) - depth = tiles_->wpt(op, axis); + depth = layouts_->get(op).wpt.at(axis); else - depth = tiles_->mts(op, axis); + depth = layouts_->get(op).mts.at(axis); return num_elements * num_bytes * depth; } unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8; diff --git a/lib/codegen/analysis/tiles.cc b/lib/codegen/analysis/tiles.cc deleted file mode 100644 index 6a16544e6..000000000 --- a/lib/codegen/analysis/tiles.cc +++ /dev/null @@ -1,130 +0,0 @@ -#include -#include -#include -#include "triton/codegen/analysis/align.h" -#include "triton/codegen/analysis/axes.h" -#include "triton/codegen/analysis/tiles.h" -#include "triton/codegen/analysis/layout.h" -#include "triton/ir/instructions.h" -#include "triton/ir/type.h" -#include "triton/ir/module.h" -#include "triton/ir/function.h" -#include "triton/ir/context_impl.h" -#include "triton/ir/constant.h" -#include "triton/driver/device.h" - - - -namespace triton{ -namespace codegen{ -namespace analysis{ - -tiles::tiles(size_t num_warps, analysis::align *align, analysis::axes *axes, analysis::layout *layout): - num_warps_(num_warps), align_(align), axes_(axes), layout_(layout) -{ } - - - -int tiles::mts(ir::value *value, unsigned ax) { - return mts_.at(axes_->get(value, ax)); -} - -int tiles::nts(ir::value *value, unsigned ax) { - return nts_.at(axes_->get(value, ax)); -} - -int tiles::fpw(ir::value *value, unsigned ax) { - return fpw_.at(axes_->get(value, ax)); -} - -int tiles::wpt(ir::value *value, unsigned ax) { - return wpt_.at(axes_->get(value, ax)); -} - - -unsigned clamp(unsigned x, unsigned lo, unsigned hi) { - return std::min(std::max(x, lo), hi); -} - - -void tiles::init_hmma_tile(const layout_t& layout) { - auto ord = layout.order; - auto shapes = layout.i->get_type()->get_tile_shapes(); - unsigned shape_0 = shapes[ord[0]]; - unsigned shape_1 = shapes[ord[1]]; - /* fragments per warp */ - // try to make things as square as possible to maximize data re-use - std::vector fpw = {1, 1, 1}; - std::vector fpw_nm1; - unsigned num_fragments = std::min((shape_0/8)*(shape_1/8), 4); - do { - fpw_nm1 = fpw; - if(fpw[0]*fpw[1] < num_fragments) - fpw[0] = clamp(fpw[0]*2, 1, shape_0 / 8); - if(fpw[0]*fpw[1] < num_fragments) - fpw[1] = clamp(fpw[1]*2, 1, shape_1 / 8); - }while(fpw_nm1 != fpw); - // store parameters - for(unsigned d = 0; d < shapes.size(); d++) - fpw_[layout.axes[d]] = fpw[d]; - /* warps per tile */ - // try to make things as square as possible to maximize data re-use - std::vector wpt = {1, 1, 1}; - std::vector wpt_nm1; - do{ - wpt_nm1 = wpt; - if(wpt[0] * wpt[1] * wpt[2] < num_warps_) - wpt[0] = clamp(wpt[0]*2, 1, shape_0 / (fpw[0]*8)); - if(wpt[0] * wpt[1] * wpt[2] < num_warps_) - wpt[1] = clamp(wpt[1]*2, 1, shape_1 / (fpw[1]*8)); - }while(wpt_nm1 != wpt); - // store parameters - for(unsigned d = 0; d < shapes.size(); d++) - wpt_[layout.axes[d]] = wpt[d]; - /* sanity check */ - unsigned effective_num_warps = 1; - for(size_t d = 0; d < shapes.size(); d++) - effective_num_warps *= wpt_[layout.axes[d]]; - if(num_warps_ != effective_num_warps) - throw std::runtime_error("cannot create a kernel with this amount of warps"); -} - -void tiles::init_scanline_tile(const layout_t& layout) { - auto ord = layout.order; - auto shapes = layout.shapes; - unsigned size = std::accumulate(shapes.begin(), shapes.end(), 1, std::multiplies()); - unsigned ld = ord[0]; - unsigned num_threads = num_warps_*32; - unsigned current = num_threads; - nts_[layout.axes[ld]] = clamp(size / num_threads, 1, 4); - mts_[layout.axes[ld]] = clamp(current, 1, shapes[ld] / nts_[layout.axes[ld]]); - current = current / mts_[layout.axes[ld]]; - for(size_t d = 1; d < shapes.size(); d++){ - ld = ord[d]; - nts_[layout.axes[ld]] = 1; - mts_[layout.axes[ld]] = clamp(current, 1, shapes[ld]); - current = current / mts_[layout.axes[ld]]; - } - /* sanity check */ - unsigned effective_num_threads = 1; - for(size_t d = 0; d < shapes.size(); d++) - effective_num_threads *= mts_[layout.axes[d]]; -// std::cout << num_threads << " " << effective_num_threads << std::endl; - if(num_threads != effective_num_threads) - throw std::runtime_error("cannot create a kernel with this amount of warps"); -} - -void tiles::run(ir::module &) { - // tiling parameters - for(auto x: layout_->get_all()){ - /* HMMA parameters*/ - if(x.second.type == HMMA_884) - init_hmma_tile(x.second); - else - init_scanline_tile(x.second); - } -} - -} -} -} diff --git a/lib/codegen/selection.cc b/lib/codegen/selection.cc index 1505bcbc6..ee5b55f08 100644 --- a/lib/codegen/selection.cc +++ b/lib/codegen/selection.cc @@ -4,7 +4,6 @@ #include "triton/codegen/analysis/liveness.h" #include "triton/codegen/analysis/layout.h" #include "triton/codegen/analysis/axes.h" -#include "triton/codegen/analysis/tiles.h" #include "triton/codegen/analysis/allocation.h" #include "triton/codegen/analysis/align.h" #include "triton/codegen/transform/coalesce.h" @@ -584,8 +583,8 @@ void selection::init_strided_scan_axes(const analysis::layout_t& layout, IRBuild std::vector nts(dim); std::vector mts(dim); for(unsigned i = 0; i < shapes.size(); i++){ - nts[i] = tiles_->nts(layout.i, i); - mts[i] = tiles_->mts(layout.i, i); + nts[i] = layout.nts.at(i); + mts[i] = layout.mts.at(i); } Value* full_thread_id = builder.CreateAdd(builder.CreateMul(u_warp_id, builder.getInt32(32)), u_thread_id); std::vector thread_id = delinearize(full_thread_id, order, mts, builder); @@ -618,13 +617,13 @@ void selection::init_hmma_axes(const analysis::layout_t& layout, IRBuilder<> &bu Value *_16 = builder.getInt32(16); // fragments per warp - unsigned fpw_0 = tiles_->fpw(layout.i, 0); - unsigned fpw_1 = tiles_->fpw(layout.i, 1); - unsigned fpw_2 = is_batched ? tiles_->fpw(layout.i, 2) : 1; + unsigned fpw_0 = layout.fpw.at(0); + unsigned fpw_1 = layout.fpw.at(1); + unsigned fpw_2 = is_batched ? layout.fpw.at(2) : 1; // warps per tile - unsigned wpt_0 = tiles_->wpt(layout.i, 0); - unsigned wpt_1 = tiles_->wpt(layout.i, 1); - unsigned wpt_2 = is_batched ? tiles_->wpt(layout.i, 2) : 1; + unsigned wpt_0 = layout.wpt.at(0); + unsigned wpt_1 = layout.wpt.at(1); + unsigned wpt_2 = is_batched ? layout.wpt.at(2) : 1; // hmma warp tile size unsigned hmma_wts_0 = fpw_0 * 8; unsigned hmma_wts_1 = fpw_1 * 8; @@ -933,7 +932,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn, tgt_->add_barrier(module, builder); builder.CreateStore(result, write_ptr); // build result - unsigned depth = tiles_->wpt(op, axis); + unsigned depth = layouts_->get(op).wpt.at(axis); for(unsigned i = depth/2; i > 0; i >>= 1){ // current indices indices_t current(write_idx.size(), builder.getInt32(0)); @@ -1022,7 +1021,7 @@ void selection::lower_copy_to_shared(ir::copy_to_shared_inst *x, LLVMContext &ct distributed_tile* in = (distributed_tile*)tmap_.at(arg); if(x_order == arg_order){ size_t ld = arg_order[0]; - vector_size = std::min(tiles_->nts(x, ld),tiles_->nts(arg, ld)); + vector_size = std::min(layouts_->get(x).nts.at(ld), layouts_->get(arg).nts.at(ld)); } std::map packets; @@ -1118,12 +1117,12 @@ void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn "{$10, $11}, " "{$0, $1, $2, $3, $4, $5, $6, $7};", "=f,=f,=f,=f,=f,=f,=f,=f,r,r,r,r,0,1,2,3,4,5,6,7", false); - unsigned fpw_0 = tiles_->fpw(dot, 0); - unsigned fpw_1 = tiles_->fpw(dot, 1); + unsigned fpw_0 = layouts_->get(dot).fpw.at(0); + unsigned fpw_1 = layouts_->get(dot).fpw.at(1); unsigned wts_0 = fpw_0 * 8; unsigned wts_1 = fpw_1 * 8; - unsigned wpt_0 = tiles_->wpt(dot, 0); - unsigned wpt_1 = tiles_->wpt(dot, 1); + unsigned wpt_0 = layouts_->get(dot).wpt.at(0); + unsigned wpt_1 = layouts_->get(dot).wpt.at(1); unsigned stride_rep_i = wpt_0 * wts_0; unsigned stride_rep_j = wpt_1 * wts_1; unsigned num_rep_i = shapes[0] / stride_rep_i; diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index 806f003d1..9578a3acb 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -4,11 +4,17 @@ #include #include #include "triton/codegen/analysis/axes.h" -#include "triton/codegen/analysis/layout.h" -#include "triton/codegen/analysis/tiles.h" +#include "triton/codegen/analysis/allocation.h" +#include "triton/codegen/analysis/liveness.h" +#include "triton/codegen/analysis/align.h" +#include "triton/codegen/transform/coalesce.h" +#include "triton/codegen/transform/dce.h" +#include "triton/codegen/transform/peephole.h" +#include "triton/codegen/transform/membar.h" +#include "triton/codegen/transform/reassociate.h" +#include "triton/codegen/transform/cts.h" #include "triton/codegen/selection.h" #include "triton/runtime/function.h" -#include "triton/codegen/transform/coalesce.h" #include "triton/lang/cpp.h" #include "triton/lang/parser.h" #include "triton/lang/code_gen.h" @@ -202,17 +208,16 @@ std::unique_ptr function::make_bin(ir::module &module, driver::c // create passes codegen::analysis::align align; codegen::analysis::axes axes; - codegen::analysis::layout layouts(&axes, &align); - codegen::analysis::tiles tiles(opt.num_warps, &align, &axes, &layouts); - codegen::analysis::liveness liveness(&tiles, &layouts); - codegen::analysis::allocation allocation(&liveness, &tiles); + codegen::analysis::layout layouts(&axes, &align, opt.num_warps); + codegen::analysis::liveness liveness(&layouts); + codegen::analysis::allocation allocation(&liveness); codegen::transform::membar barriers(&liveness, &allocation); codegen::transform::dce dce; codegen::transform::peephole peephole; codegen::transform::reassociate reassociate(&align); codegen::transform::coalesce coalesce(&align, &layouts); codegen::transform::cts cts; - codegen::selection selection(&liveness, &allocation, &tiles, &align, &axes, &layouts, target.get(), opt.num_warps); + codegen::selection selection(&liveness, &allocation, &align, &axes, &layouts, target.get(), opt.num_warps); // run passes peephole.run(module); dce.run(module); @@ -226,24 +231,20 @@ std::unique_ptr function::make_bin(ir::module &module, driver::c dce.run(module); reassociate.run(module); // ir::print(module, std::cout); -// exit(EXIT_FAILURE); dce.run(module); cts.run(module); align.run(module); axes.run(module); layouts.run(module); - tiles.run(module); liveness.run(module); allocation.run(module); if(allocation.allocated_size() > context->device()->max_shared_memory()) return std::unique_ptr(); barriers.run(module); dce.run(module); + align.run(module); axes.run(module); layouts.run(module); -// ir::print(module, std::cout); - align.run(module); - tiles.run(module); selection.run(module, *llvm); // return binary std::unique_ptr res(driver::module::create(context, std::move(llvm)));