diff --git a/include/triton/codegen/analysis/memalloc.h b/include/triton/codegen/analysis/allocation.h similarity index 87% rename from include/triton/codegen/analysis/memalloc.h rename to include/triton/codegen/analysis/allocation.h index f50d00b22..a43e93031 100644 --- a/include/triton/codegen/analysis/memalloc.h +++ b/include/triton/codegen/analysis/allocation.h @@ -18,11 +18,11 @@ namespace analysis{ class tiles; class liveness; -class meminfo; +class cts; -class memalloc { +class allocation { public: - memalloc(liveness *live, meminfo *buffer_info, tiles *params) + allocation(liveness *live, cts *buffer_info, tiles *params) : liveness_(live), buffer_info_(buffer_info), tiles_(params){ } // utilities unsigned num_bytes(ir::value *x); @@ -39,7 +39,7 @@ private: size_t allocated_size_; // dependences liveness *liveness_; - meminfo *buffer_info_; + cts *buffer_info_; tiles *tiles_; }; diff --git a/include/triton/codegen/analysis/axes.h b/include/triton/codegen/analysis/axes.h index d22fa5fa8..625d414c6 100644 --- a/include/triton/codegen/analysis/axes.h +++ b/include/triton/codegen/analysis/axes.h @@ -27,7 +27,6 @@ private: void update_graph_store(ir::instruction *i); void update_graph_reduce(ir::instruction *i); void update_graph_reshape(ir::instruction *i); - void update_graph_splat(ir::instruction *i); void update_graph_trans(ir::instruction *i); void update_graph_broadcast(ir::instruction *i); void update_graph_dot(ir::instruction *i); diff --git a/include/triton/codegen/analysis/liveness.h b/include/triton/codegen/analysis/liveness.h index 4b863ff55..df951161c 100644 --- a/include/triton/codegen/analysis/liveness.h +++ b/include/triton/codegen/analysis/liveness.h @@ -16,7 +16,7 @@ namespace analysis{ typedef unsigned slot_index; -class meminfo; +class cts; struct segment { slot_index start; @@ -44,7 +44,7 @@ public: public: // constructor - liveness(meminfo *info): info_(info){ } + liveness(cts *info): info_(info){ } // accessors const intervals_map_t& intervals() const { return intervals_; } segment get_interval(ir::value* v) const { return intervals_.at(v); } @@ -52,7 +52,7 @@ public: void run(ir::module &mod); private: - meminfo *info_; + cts *info_; has_storage_map_t has_dedicated_storage_; indices_map_t indices_; intervals_map_t intervals_; diff --git a/include/triton/codegen/instructions.h b/include/triton/codegen/instructions.h new file mode 100644 index 000000000..cecd716e0 --- /dev/null +++ b/include/triton/codegen/instructions.h @@ -0,0 +1,78 @@ +#ifndef _TRITON_CODEGEN_INSTRUCTIONS_H_ +#define _TRITON_CODEGEN_INSTRUCTIONS_H_ + +#include "triton/ir/enums.h" +#include +#include + +namespace triton{ +namespace codegen{ + + +enum storage_info_t { + NONE, + ANY, + SHARED, + DISTRIBUTED, + REPLICATED +}; + +typedef std::pair> inst_storage_info_t; +static const std::map storage_info = { + // scalars + { ir::INST_GET_PROGRAM_ID, {REPLICATED, {}}}, + { ir::INST_GET_NUM_PROGRAMS, {REPLICATED, {}}}, + // scalar/array + { ir::INST_PHI, {ANY, {ANY, ANY}}}, + { ir::INST_BINOP, {DISTRIBUTED, {DISTRIBUTED, DISTRIBUTED}}}, + { ir::INST_GETELEMENTPTR, {DISTRIBUTED, {DISTRIBUTED, DISTRIBUTED}}}, + { ir::INST_SELECT, {DISTRIBUTED, {DISTRIBUTED, DISTRIBUTED, DISTRIBUTED}}}, + { ir::INST_SQRT, {DISTRIBUTED, {DISTRIBUTED}}}, + // cmp + { ir::INST_ICMP, {DISTRIBUTED, {DISTRIBUTED, DISTRIBUTED}}}, + { ir::INST_FCMP, {DISTRIBUTED, {DISTRIBUTED, DISTRIBUTED}}}, + // cast + { ir::INST_CAST_TRUNC, {DISTRIBUTED, {DISTRIBUTED}}}, + { ir::INST_CAST_ZEXT, {DISTRIBUTED, {DISTRIBUTED}}}, + { ir::INST_CAST_SEXT, {DISTRIBUTED, {DISTRIBUTED}}}, + { ir::INST_CAST_FP_TRUNC, {DISTRIBUTED, {DISTRIBUTED}}}, + { ir::INST_CAST_FP_EXT, {DISTRIBUTED, {DISTRIBUTED}}}, + { ir::INST_CAST_UI_TO_FP, {DISTRIBUTED, {DISTRIBUTED}}}, + { ir::INST_CAST_SI_TO_FP, {DISTRIBUTED, {DISTRIBUTED}}}, + { ir::INST_CAST_FP_TO_UI, {DISTRIBUTED, {DISTRIBUTED}}}, + { ir::INST_CAST_FP_TO_SI, {DISTRIBUTED, {DISTRIBUTED}}}, + { ir::INST_CAST_PTR_TO_INT, {DISTRIBUTED, {DISTRIBUTED}}}, + { ir::INST_CAST_INT_TO_PTR, {DISTRIBUTED, {DISTRIBUTED}}}, + { ir::INST_CAST_BIT_CAST, {DISTRIBUTED, {DISTRIBUTED}}}, + { ir::INST_CAST_ADDR_SPACE_CAST, {DISTRIBUTED, {DISTRIBUTED}}}, + // io + { ir::INST_UNMASKED_LOAD, {DISTRIBUTED, {DISTRIBUTED}}}, + { ir::INST_MASKED_LOAD, {DISTRIBUTED, {DISTRIBUTED, DISTRIBUTED}}}, + { ir::INST_UNMASKED_STORE, {NONE , {DISTRIBUTED, DISTRIBUTED}}}, + { ir::INST_MASKED_STORE, {NONE , {DISTRIBUTED, DISTRIBUTED, DISTRIBUTED}}}, + // retile + { ir::INST_RESHAPE, {DISTRIBUTED, {DISTRIBUTED}}}, + { ir::INST_SPLAT, {DISTRIBUTED, {REPLICATED}}}, + { ir::INST_BROADCAST, {DISTRIBUTED, {REPLICATED}}}, + { ir::INST_DOWNCAST, {DISTRIBUTED, {REPLICATED}}}, + // array arithmetic + { ir::INST_TRANS, {SHARED, {DISTRIBUTED}}}, // TODO: not necessarily + { ir::INST_REDUCE, {SHARED, {DISTRIBUTED}}}, + { ir::INST_DOT, {DISTRIBUTED, {SHARED, SHARED, DISTRIBUTED}}}, + // terminator + { ir::INST_RETURN, {NONE, {}}}, + { ir::INST_UNCOND_BRANCH, {NONE, {}}}, + { ir::INST_COND_BRANCH, {NONE, {REPLICATED}}}, + + // intrinsics + { ir::INST_COPY_TO_SHARED, {SHARED, {DISTRIBUTED}}}, + { ir::INST_BARRIER, {NONE, {}}}, + { ir::INST_MAKE_RANGE_DYN, {DISTRIBUTED, {}}}, + { ir::INST_MAKE_RANGE_STA, {DISTRIBUTED, {}}}, + { ir::INST_MAKE_RANGE, {DISTRIBUTED, {}}} +}; + +} +} + +#endif diff --git a/include/triton/codegen/pass.h b/include/triton/codegen/pass.h new file mode 100644 index 000000000..129c02bc6 --- /dev/null +++ b/include/triton/codegen/pass.h @@ -0,0 +1,30 @@ +#ifndef _TRITON_CODEGEN_PASS_H_ +#define _TRITON_CODEGEN_PASS_H_ + +#include + +namespace triton{ + +namespace ir{ + class module; +} + +namespace codegen{ + +class pass { +public: + virtual void run(ir::module& m); +}; + + +class pass_manager { +public: + void add(pass* p); + void run(ir::module& m); + +private: + std::list passes; +}; + +} +} diff --git a/include/triton/codegen/selection.h b/include/triton/codegen/selection.h index 74a617af9..21bd83ee1 100644 --- a/include/triton/codegen/selection.h +++ b/include/triton/codegen/selection.h @@ -5,7 +5,7 @@ #include "triton/ir/module.h" #include "triton/ir/function.h" #include "triton/ir/type.h" -#include "triton/codegen/analysis/meminfo.h" +#include "triton/codegen/transform/cts.h" namespace llvm{ @@ -45,8 +45,8 @@ namespace codegen{ namespace analysis{ class tiles; class align; -class memalloc; -class meminfo; +class allocation; +class cts; class axes; class layout; } @@ -201,7 +201,7 @@ private: public: - selection(analysis::memalloc *alloc, analysis::tiles *tiles, analysis::meminfo *buffer_info, + selection(analysis::allocation *alloc, analysis::tiles *tiles, analysis::cts *buffer_info, analysis::align *alignment, analysis::axes *axes, analysis::layout *layouts, transform::coalesce* reorder, target *tgt, unsigned num_warps) : alloc_(alloc), tiles_(tiles), buffer_info_(buffer_info), @@ -213,11 +213,11 @@ public: private: vmap_t vmap_; tmap_t tmap_; - analysis::memalloc *alloc_; + analysis::allocation *alloc_; analysis::tiles *tiles_; analysis::axes *a_axes_; analysis::layout *layouts_; - analysis::meminfo *buffer_info_; + analysis::cts *buffer_info_; analysis::align *alignment_; transform::coalesce *reorder_; target *tgt_; diff --git a/include/triton/codegen/transform/coalesce.h b/include/triton/codegen/transform/coalesce.h index 3d418fdb5..680f1ccb2 100644 --- a/include/triton/codegen/transform/coalesce.h +++ b/include/triton/codegen/transform/coalesce.h @@ -20,7 +20,7 @@ namespace codegen{ namespace analysis{ class align; class layout; - class meminfo; + class cts; } namespace transform{ @@ -32,13 +32,13 @@ private: ir::value* rematerialize(ir::value *v, ir::builder& builder, std::map& seen); public: - coalesce(analysis::align* align, triton::codegen::analysis::layout *layouts, analysis::meminfo* mem); + coalesce(analysis::align* align, triton::codegen::analysis::layout *layouts, analysis::cts* mem); void run(ir::module &mod); private: analysis::align* align_; analysis::layout* layout_; - analysis::meminfo* mem_; + analysis::cts* mem_; }; } diff --git a/include/triton/codegen/analysis/meminfo.h b/include/triton/codegen/transform/cts.h similarity index 85% rename from include/triton/codegen/analysis/meminfo.h rename to include/triton/codegen/transform/cts.h index f4ad290a6..7b7237f7e 100644 --- a/include/triton/codegen/analysis/meminfo.h +++ b/include/triton/codegen/transform/cts.h @@ -16,7 +16,7 @@ namespace ir { namespace codegen{ namespace analysis{ -class meminfo { +class cts { public: void run(ir::module &mod); // queries @@ -25,8 +25,6 @@ public: bool is_shared(ir::value *x); bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator); ir::value *get_reference(ir::value *x); - void replace(ir::value* before, ir::value *after); - void copy(ir::value* y, ir::value *x); private: std::set shared_; diff --git a/include/triton/codegen/transform/membar.h b/include/triton/codegen/transform/membar.h index 8991ac57d..b4aebc2ce 100644 --- a/include/triton/codegen/transform/membar.h +++ b/include/triton/codegen/transform/membar.h @@ -15,8 +15,8 @@ namespace codegen{ namespace analysis{ -class memalloc; -class meminfo; +class allocation; +class cts; } @@ -38,12 +38,12 @@ private: std::pair transfer(ir::basic_block *block, const interval_vec_t &written_to, const interval_vec_t &read_from, std::set &insert_loc); public: - membar(analysis::memalloc *alloc, analysis::meminfo *buffer_info): alloc_(alloc), buffer_info_(buffer_info) {} + membar(analysis::allocation *alloc, analysis::cts *buffer_info): alloc_(alloc), buffer_info_(buffer_info) {} void run(ir::module &mod); private: - analysis::memalloc *alloc_; - analysis::meminfo *buffer_info_; + analysis::allocation *alloc_; + analysis::cts *buffer_info_; }; diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h index d5707265a..5cf107be3 100644 --- a/include/triton/ir/builder.h +++ b/include/triton/ir/builder.h @@ -30,6 +30,7 @@ public: // Setters void set_insert_point(iterator instr); void set_insert_point(instruction* i); + void set_insert_point_after(instruction* i); void set_insert_point(basic_block* block); basic_block* get_insert_block() { return block_; } iterator get_insert_point() { return insert_point_;} diff --git a/include/triton/runtime/function.h b/include/triton/runtime/function.h index 0eaa9a33d..88de3825c 100644 --- a/include/triton/runtime/function.h +++ b/include/triton/runtime/function.h @@ -12,14 +12,14 @@ #include "triton/codegen/selection.h" #include "triton/codegen/target.h" #include "triton/codegen/analysis/tiles.h" -#include "triton/codegen/analysis/memalloc.h" +#include "triton/codegen/analysis/allocation.h" #include "triton/codegen/analysis/liveness.h" -#include "triton/codegen/analysis/meminfo.h" #include "triton/codegen/analysis/align.h" #include "triton/codegen/transform/dce.h" #include "triton/codegen/transform/peephole.h" #include "triton/codegen/transform/membar.h" #include "triton/codegen/transform/reassociate.h" +#include "triton/codegen/transform/cts.h" #include "triton/lang/parser.h" #include "triton/runtime/arg.h" diff --git a/lib/codegen/analysis/align.cc b/lib/codegen/analysis/align.cc index f84e8d692..ef57e7a4f 100644 --- a/lib/codegen/analysis/align.cc +++ b/lib/codegen/analysis/align.cc @@ -445,7 +445,6 @@ std::vector align::populate_starting_multiple_default(ir::value* v) { return add_to_cache(v, {1}, starting_multiple_); } - std::vector align::populate_starting_multiple(ir::value *v){ if(starting_multiple_.find(v) != starting_multiple_.end()) return starting_multiple_.at(v); diff --git a/lib/codegen/analysis/memalloc.cc b/lib/codegen/analysis/allocation.cc similarity index 92% rename from lib/codegen/analysis/memalloc.cc rename to lib/codegen/analysis/allocation.cc index 7f80824e3..b05b55a4d 100644 --- a/lib/codegen/analysis/memalloc.cc +++ b/lib/codegen/analysis/allocation.cc @@ -1,7 +1,7 @@ #include -#include "triton/codegen/analysis/memalloc.h" +#include "triton/codegen/analysis/allocation.h" #include "triton/codegen/analysis/liveness.h" -#include "triton/codegen/analysis/meminfo.h" +#include "triton/codegen/transform/cts.h" #include "triton/codegen/analysis/tiles.h" #include "triton/ir/basic_block.h" #include "triton/ir/type.h" @@ -13,7 +13,7 @@ namespace triton{ namespace codegen{ namespace analysis{ -unsigned memalloc::is_ld_padded(ir::value *x) { +unsigned allocation::is_ld_padded(ir::value *x) { if(auto *trans = dynamic_cast(x)){ if(trans->get_perm()[0]->get_value() != 0) return 4; @@ -45,7 +45,7 @@ unsigned memalloc::is_ld_padded(ir::value *x) { return 0; } -unsigned memalloc::num_bytes(ir::value *x) { +unsigned allocation::num_bytes(ir::value *x) { if(auto *red = dynamic_cast(x)){ unsigned num_bytes = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; size_t axis = red->get_axis(); @@ -73,15 +73,15 @@ unsigned memalloc::num_bytes(ir::value *x) { return num_bytes; } -void memalloc::run(){ + +void allocation::run() { using std::max; using std::min; typedef std::multimap triples_map_type; std::vector I; - for(auto x: liveness_->intervals()){ + for(auto x: liveness_->intervals()) I.push_back(x.first); - } std::vector J = I; triples_map_type H; @@ -137,7 +137,7 @@ void memalloc::run(){ for(ir::value *X: V) colors[X] = (X==V[0])?0:-1; - // First-fit coloring + // First-fit graph coloring std::vector available(V.size()); for(ir::value *x: V){ // Non-neighboring colors are available @@ -158,6 +158,7 @@ void memalloc::run(){ for(ir::value *y: interferences[x]) Adj = std::max(Adj, starts[y] + num_bytes(y)); offsets_[x] = starts[x] + colors[x] * Adj; +// std::cout << x->get_name() << " " << offsets_[x] << " " << num_bytes(x) << std::endl; if(buffer_info_->is_double(x)){ ir::phi_node *phi = (ir::phi_node*)x; for(unsigned i = 0; i < phi->get_num_incoming(); i++){ @@ -167,6 +168,8 @@ void memalloc::run(){ } } +// exit(EXIT_FAILURE); + // Save maximum size of induced memory space allocated_size_ = 0; for(auto &x: offsets_){ diff --git a/lib/codegen/analysis/axes.cc b/lib/codegen/analysis/axes.cc index 790c8a36b..2c152f439 100644 --- a/lib/codegen/analysis/axes.cc +++ b/lib/codegen/analysis/axes.cc @@ -68,11 +68,6 @@ void axes::update_graph_reshape(ir::instruction *i) { } } -void axes::update_graph_splat(ir::instruction *) { - // argument is scalar so don't make any edge - return; -} - void axes::update_graph_trans(ir::instruction *i) { auto *trans = static_cast(i); ir::value *op = trans->get_operand(0); @@ -129,13 +124,14 @@ void axes::update_graph_elementwise(ir::instruction *i) { void axes::update_graph(ir::instruction *i) { switch (i->get_id()) { - case ir::INST_REDUCE: return update_graph_reduce(i); - case ir::INST_RESHAPE: return update_graph_reshape(i); - case ir::INST_SPLAT: return update_graph_splat(i); - case ir::INST_TRANS: return update_graph_trans(i); - case ir::INST_BROADCAST: return update_graph_broadcast(i); - case ir::INST_DOT: return update_graph_dot(i); - default: return update_graph_elementwise(i); + case ir::INST_REDUCE: return update_graph_reduce(i); + case ir::INST_RESHAPE: return update_graph_reshape(i); + case ir::INST_SPLAT: return; + case ir::INST_TRANS: return update_graph_trans(i); + case ir::INST_BROADCAST: return update_graph_broadcast(i); + case ir::INST_DOT: return update_graph_dot(i); + case ir::INST_COPY_TO_SHARED: return; + default: return update_graph_elementwise(i); } return; } diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc index 77b25e0bb..40c8449ea 100644 --- a/lib/codegen/analysis/layout.cc +++ b/lib/codegen/analysis/layout.cc @@ -20,10 +20,9 @@ std::set layout::axes_of(ir::value *value) { rank = ty->get_tile_rank(); // create result std::set result; - for(size_t d = 0; d < rank; d++){ + for(size_t d = 0; d < rank; d++) if(axes_->has_id(value, d)) result.insert(axes_->get_id(value, d)); - } return result; } @@ -54,6 +53,7 @@ const std::vector& layout::values(unsigned id) const size_t layout::get_num_groups() const { return values_.size(); } +// connect two values void layout::connect(ir::value *x, ir::value *y) { if(x == y) return; @@ -75,6 +75,7 @@ void layout::connect(ir::value *x, ir::value *y) { } } +// make graph void layout::make_graph(ir::instruction *i) { for(ir::value* opx: i->ops()) for(ir::value* opy: i->ops()){ diff --git a/lib/codegen/analysis/liveness.cc b/lib/codegen/analysis/liveness.cc index 088691263..05d29032b 100644 --- a/lib/codegen/analysis/liveness.cc +++ b/lib/codegen/analysis/liveness.cc @@ -1,6 +1,7 @@ #include +#include "triton/codegen/instructions.h" #include "triton/codegen/analysis/liveness.h" -#include "triton/codegen/analysis/meminfo.h" +#include "triton/codegen/transform/cts.h" #include "triton/ir/basic_block.h" #include "triton/ir/function.h" #include "triton/ir/module.h" @@ -25,6 +26,11 @@ void liveness::run(ir::module &mod) { // Creates live intervals for(auto i: indices_){ ir::value *v = i.first; +// ir::instruction* instr = dynamic_cast(v); +// if(!instr) +// continue; +// if(storage_info.at(instr->get_id()).first != SHARED) +// continue; if(!info_->is_shared(v) || info_->get_reference(v)) continue; unsigned start = i.second; diff --git a/lib/codegen/instructions.cc b/lib/codegen/instructions.cc new file mode 100644 index 000000000..e69de29bb diff --git a/lib/codegen/pass.cc b/lib/codegen/pass.cc new file mode 100644 index 000000000..e69de29bb diff --git a/lib/codegen/selection.cc b/lib/codegen/selection.cc index d89a4e1c5..169283e7f 100644 --- a/lib/codegen/selection.cc +++ b/lib/codegen/selection.cc @@ -4,7 +4,7 @@ #include "triton/codegen/analysis/layout.h" #include "triton/codegen/analysis/axes.h" #include "triton/codegen/analysis/tiles.h" -#include "triton/codegen/analysis/memalloc.h" +#include "triton/codegen/analysis/allocation.h" #include "triton/codegen/analysis/align.h" #include "triton/codegen/transform/coalesce.h" #include "triton/ir/context.h" diff --git a/lib/codegen/transform/coalesce.cc b/lib/codegen/transform/coalesce.cc index 455f2fb5d..117bd35df 100644 --- a/lib/codegen/transform/coalesce.cc +++ b/lib/codegen/transform/coalesce.cc @@ -7,7 +7,7 @@ #include "triton/ir/instructions.h" #include "triton/ir/module.h" #include "triton/codegen/analysis/layout.h" -#include "triton/codegen/analysis/meminfo.h" +#include "triton/codegen/transform/cts.h" #include "triton/codegen/analysis/align.h" #include "triton/codegen/transform/coalesce.h" @@ -15,7 +15,7 @@ namespace triton { namespace codegen{ namespace transform{ -coalesce::coalesce(analysis::align* align, analysis::layout *layouts, analysis::meminfo *mem) +coalesce::coalesce(analysis::align* align, analysis::layout *layouts, analysis::cts *mem) : align_(align), layout_(layouts), mem_(mem) { } // Find all values that are used as pointer operands in LD/ST @@ -102,9 +102,6 @@ void coalesce::run(ir::module &mod) { r->replace_all_uses_with(cts); cts->replace_uses_of_with(cts, r); } - else{ - - } } } diff --git a/lib/codegen/analysis/meminfo.cc b/lib/codegen/transform/cts.cc similarity index 64% rename from lib/codegen/analysis/meminfo.cc rename to lib/codegen/transform/cts.cc index be55d6ac7..5a7e16a2d 100644 --- a/lib/codegen/analysis/meminfo.cc +++ b/lib/codegen/transform/cts.cc @@ -1,5 +1,7 @@ #include -#include "triton/codegen/analysis/meminfo.h" +#include +#include "triton/codegen/transform/cts.h" +#include "triton/codegen/instructions.h" #include "triton/ir/module.h" #include "triton/ir/function.h" #include "triton/ir/basic_block.h" @@ -12,7 +14,7 @@ namespace codegen{ namespace analysis{ // run pass on module -bool meminfo::is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){ +bool cts::is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){ if(phi->get_parent() != terminator->get_parent()) return false; if(auto *br = dynamic_cast(terminator)) @@ -24,24 +26,6 @@ bool meminfo::is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){ throw std::runtime_error("unreachable"); } -void meminfo::replace(ir::value* before, ir::value *after) { - shared_.erase(before); - shared_.insert(after); - if(refs_.find(before) != refs_.end()){ - ir::value* v = refs_.at(before); - refs_.erase(before); - refs_.insert({after, v}); - } -} - -void meminfo::copy(ir::value* y, ir::value *x) { - if(shared_.find(x) != shared_.end()) - shared_.insert(y); - if(refs_.find(x) != refs_.end()) - refs_[y] = refs_[x]; - if(double_.find(x) != double_.end()) - double_.insert(y); -} inline bool get_is_shared(ir::value* v) { @@ -62,40 +46,46 @@ inline bool get_is_shared(ir::value* v) { return false; } -void add_copy(ir::value *x, ir::builder &builder) { - if(auto phi = dynamic_cast(x)){ +void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder) { + auto *i = dynamic_cast(x); + // not an instruction + if(!i) { + builder.set_insert_point(parent); + ir::value *cts = builder.create_copy_to_shared(x); + parent->replace_uses_of_with(x, cts); + return; + } + // phi node + if(auto* phi = dynamic_cast(x)) { for(unsigned i = 0; i < phi->get_num_incoming(); ++i) - add_copy(phi->get_incoming_value(i), builder); - } - else { - if(get_is_shared(x)) - return; - if(auto *i = dynamic_cast(x)){ - ir::basic_block* block = i->get_parent(); - auto it = std::find(block->begin(), block->end(), i); - builder.set_insert_point(++it); - } - ir::instruction *rx = (ir::instruction*)builder.create_copy_to_shared(x); - x->replace_all_uses_with(rx); - rx->set_operand(0, x); + add_copy(phi, phi->get_incoming_value(i), builder); + return; } + ir::value_id_t id = i->get_id(); + // already in shared memory + if(storage_info.at(id).first == SHARED) + return; + // copy + builder.set_insert_point_after(i); + ir::value *cts = builder.create_copy_to_shared(x); + parent->replace_uses_of_with(x, cts); } -void meminfo::run(ir::module &mod) { +void cts::run(ir::module &mod) { shared_.clear(); refs_.clear(); double_.clear(); // Add shared copies + ir::builder &builder = mod.get_builder(); for(ir::function *fn: mod.get_function_list()){ - ir::builder builder(mod.get_context()); for(ir::basic_block *block: fn->blocks()) for(ir::instruction *i: block->get_inst_list()){ - if(dynamic_cast(i)) - if(i->get_operand(1)->get_type()->get_tile_shapes()[1] != 1){ - add_copy(i->get_operand(0), builder); - add_copy(i->get_operand(1), builder); - } + auto storage = storage_info.at(i->get_id()); + // copy to shared operands when necessary + for(size_t k = 0; k < storage.second.size(); k++) + if(storage.second[k] == SHARED) + add_copy(i, i->get_operand(k), builder); } } @@ -135,15 +125,15 @@ void meminfo::run(ir::module &mod) { } // query double-buffered status -bool meminfo::is_double(ir::value *x) +bool cts::is_double(ir::value *x) { return double_.find(x) != double_.end(); } // query shared status -bool meminfo::is_shared(ir::value *x) +bool cts::is_shared(ir::value *x) { return shared_.find(x) != shared_.end(); } // get reference if any -ir::value *meminfo::get_reference(ir::value *x) +ir::value *cts::get_reference(ir::value *x) { return refs_[x]; } diff --git a/lib/codegen/transform/dce.cc b/lib/codegen/transform/dce.cc index 18406b4ab..4497f2fde 100644 --- a/lib/codegen/transform/dce.cc +++ b/lib/codegen/transform/dce.cc @@ -20,12 +20,22 @@ void dce::run(ir::module &mod) { // iterate through blocks for(ir::basic_block *block: rpo) for(ir::instruction *i: block->get_inst_list()){ - if(dynamic_cast(i) || dynamic_cast(i) - || dynamic_cast(i) || dynamic_cast(i) - || dynamic_cast(i) || dynamic_cast(i) || dynamic_cast(i) - || dynamic_cast(i)){ - work_list.push_back(i); - marked.insert(i); + switch(i->get_id()){ + case ir::INST_RETURN: + case ir::INST_UNCOND_BRANCH: + case ir::INST_COND_BRANCH: + case ir::INST_UNMASKED_STORE: + case ir::INST_MASKED_STORE: + case ir::INST_ATOMIC_ADD: + case ir::INST_ATOMIC_CAS: + case ir::INST_ATOMIC_EXCH: + case ir::INST_BARRIER: { + work_list.push_back(i); + marked.insert(i); + break; + } + default: + break; } } } diff --git a/lib/codegen/transform/membar.cc b/lib/codegen/transform/membar.cc index b8b029d9a..e77e9c71a 100644 --- a/lib/codegen/transform/membar.cc +++ b/lib/codegen/transform/membar.cc @@ -3,8 +3,8 @@ #include #include "triton/codegen/transform/membar.h" -#include "triton/codegen/analysis/memalloc.h" -#include "triton/codegen/analysis/meminfo.h" +#include "triton/codegen/analysis/allocation.h" +#include "triton/codegen/transform/cts.h" #include "triton/ir/module.h" #include "triton/ir/function.h" #include "triton/ir/basic_block.h" diff --git a/lib/codegen/transform/reassociate.cc b/lib/codegen/transform/reassociate.cc index 38e8c79ed..c2b9d2d4b 100644 --- a/lib/codegen/transform/reassociate.cc +++ b/lib/codegen/transform/reassociate.cc @@ -122,7 +122,6 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value, new_value = builder.create_add(rrhs, builder.create_add(lrhs, lhs), name, cst); } } - // extract constant and non-constant if(ir::instruction *bin_add = is_bin_add(new_value)){ ir::value *new_lhs = bin_add->get_operand(0); @@ -136,12 +135,9 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value, noncst = new_lhs; } } - // clean-up if some re-ordering happened - if(old_value != new_value){ + if(old_value != new_value) old_value->replace_all_uses_with(new_value); - } - return new_value; } diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc index 00450b547..db2080a4d 100644 --- a/lib/ir/builder.cc +++ b/lib/ir/builder.cc @@ -27,6 +27,13 @@ void builder::set_insert_point(instruction* i){ } +void builder::set_insert_point_after(instruction* i){ + block_ = i->get_parent(); + auto it = std::find(block_->begin(), block_->end(), i); + set_insert_point(++it); +} + + void builder::set_insert_point(basic_block *block){ block_ = block; insert_point_ = block->end(); diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index 04977966d..501d62f54 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -199,14 +199,14 @@ std::unique_ptr function::make_bin(ir::module &module, driver::c llvm::LLVMContext ctx; std::unique_ptr llvm(new llvm::Module(module.get_name(), ctx)); // create passes - codegen::analysis::meminfo shmem_info; + codegen::analysis::cts shmem_info; codegen::analysis::align align; codegen::analysis::liveness shmem_liveness(&shmem_info); codegen::analysis::axes axes; codegen::analysis::layout layouts(&axes); codegen::transform::coalesce coalesce(&align, &layouts, &shmem_info); codegen::analysis::tiles tiles(opt.num_warps, &align, &axes, &layouts); - codegen::analysis::memalloc shmem_allocation(&shmem_liveness, &shmem_info, &tiles); + codegen::analysis::allocation shmem_allocation(&shmem_liveness, &shmem_info, &tiles); codegen::transform::membar shmem_barriers(&shmem_allocation, &shmem_info); codegen::transform::dce dce; codegen::transform::peephole peephole; @@ -229,6 +229,7 @@ std::unique_ptr function::make_bin(ir::module &module, driver::c peephole.run(module); shmem_info.run(module); shmem_liveness.run(module); + ir::print(module, std::cout); shmem_allocation.run(); if(shmem_allocation.allocated_size() > context->device()->max_shared_memory()) return std::unique_ptr();