[codegen] cleaning-up / formalizing shared-memory passes

This commit is contained in:
Philippe Tillet
2019-09-20 16:01:12 -04:00
parent e35be1ddcf
commit 43d88154bd
26 changed files with 229 additions and 117 deletions

View File

@@ -18,11 +18,11 @@ namespace analysis{
class tiles; class tiles;
class liveness; class liveness;
class meminfo; class cts;
class memalloc { class allocation {
public: public:
memalloc(liveness *live, meminfo *buffer_info, tiles *params) allocation(liveness *live, cts *buffer_info, tiles *params)
: liveness_(live), buffer_info_(buffer_info), tiles_(params){ } : liveness_(live), buffer_info_(buffer_info), tiles_(params){ }
// utilities // utilities
unsigned num_bytes(ir::value *x); unsigned num_bytes(ir::value *x);
@@ -39,7 +39,7 @@ private:
size_t allocated_size_; size_t allocated_size_;
// dependences // dependences
liveness *liveness_; liveness *liveness_;
meminfo *buffer_info_; cts *buffer_info_;
tiles *tiles_; tiles *tiles_;
}; };

View File

@@ -27,7 +27,6 @@ private:
void update_graph_store(ir::instruction *i); void update_graph_store(ir::instruction *i);
void update_graph_reduce(ir::instruction *i); void update_graph_reduce(ir::instruction *i);
void update_graph_reshape(ir::instruction *i); void update_graph_reshape(ir::instruction *i);
void update_graph_splat(ir::instruction *i);
void update_graph_trans(ir::instruction *i); void update_graph_trans(ir::instruction *i);
void update_graph_broadcast(ir::instruction *i); void update_graph_broadcast(ir::instruction *i);
void update_graph_dot(ir::instruction *i); void update_graph_dot(ir::instruction *i);

View File

@@ -16,7 +16,7 @@ namespace analysis{
typedef unsigned slot_index; typedef unsigned slot_index;
class meminfo; class cts;
struct segment { struct segment {
slot_index start; slot_index start;
@@ -44,7 +44,7 @@ public:
public: public:
// constructor // constructor
liveness(meminfo *info): info_(info){ } liveness(cts *info): info_(info){ }
// accessors // accessors
const intervals_map_t& intervals() const { return intervals_; } const intervals_map_t& intervals() const { return intervals_; }
segment get_interval(ir::value* v) const { return intervals_.at(v); } segment get_interval(ir::value* v) const { return intervals_.at(v); }
@@ -52,7 +52,7 @@ public:
void run(ir::module &mod); void run(ir::module &mod);
private: private:
meminfo *info_; cts *info_;
has_storage_map_t has_dedicated_storage_; has_storage_map_t has_dedicated_storage_;
indices_map_t indices_; indices_map_t indices_;
intervals_map_t intervals_; intervals_map_t intervals_;

View File

@@ -0,0 +1,78 @@
#ifndef _TRITON_CODEGEN_INSTRUCTIONS_H_
#define _TRITON_CODEGEN_INSTRUCTIONS_H_
#include "triton/ir/enums.h"
#include <map>
#include <vector>
namespace triton{
namespace codegen{
// Storage class attached to an instruction's result and to each of its
// operands in the `storage_info` table below.
// NOTE(review): the per-enumerator notes are inferred from how the table and
// the copy-to-shared pass use these values; confirm against the code generator.
enum storage_info_t {
// used as the result class of instructions such as stores, branches, return
// and barrier -- presumably "produces no value"
NONE,
// only used for phi nodes (result and both operands) -- presumably "no constraint"
ANY,
// the copy-to-shared pass treats a result tagged SHARED as already living in
// shared memory, and inserts a copy for any operand slot tagged SHARED
SHARED,
// presumably: value is partitioned across threads -- TODO confirm
DISTRIBUTED,
// presumably: every thread holds the same value (used for program ids and
// for the operands of splat/broadcast) -- TODO confirm
REPLICATED
};
// Storage requirements of one instruction kind:
//   .first  -- storage class of the value the instruction produces
//   .second -- storage class expected for operand k, at index k
using inst_storage_info_t = std::pair<storage_info_t, std::vector<storage_info_t>>;

// Table mapping every instruction id to its storage requirements.
// Lookups elsewhere go through std::map::at, which throws std::out_of_range
// for an id that has no entry here.
// NOTE(review): no INST_ATOMIC_* ids are listed although the DCE pass
// switches on them -- confirm whether atomics can reach a pass that indexes
// this table.
static const std::map<ir::value_id_t, inst_storage_info_t> storage_info = {
  // scalars
  { ir::INST_GET_PROGRAM_ID,       {REPLICATED,  {}}},
  { ir::INST_GET_NUM_PROGRAMS,     {REPLICATED,  {}}},
  // scalar/array
  { ir::INST_PHI,                  {ANY,         {ANY, ANY}}},
  { ir::INST_BINOP,                {DISTRIBUTED, {DISTRIBUTED, DISTRIBUTED}}},
  { ir::INST_GETELEMENTPTR,        {DISTRIBUTED, {DISTRIBUTED, DISTRIBUTED}}},
  { ir::INST_SELECT,               {DISTRIBUTED, {DISTRIBUTED, DISTRIBUTED, DISTRIBUTED}}},
  { ir::INST_SQRT,                 {DISTRIBUTED, {DISTRIBUTED}}},
  // cmp
  { ir::INST_ICMP,                 {DISTRIBUTED, {DISTRIBUTED, DISTRIBUTED}}},
  { ir::INST_FCMP,                 {DISTRIBUTED, {DISTRIBUTED, DISTRIBUTED}}},
  // cast
  { ir::INST_CAST_TRUNC,           {DISTRIBUTED, {DISTRIBUTED}}},
  { ir::INST_CAST_ZEXT,            {DISTRIBUTED, {DISTRIBUTED}}},
  { ir::INST_CAST_SEXT,            {DISTRIBUTED, {DISTRIBUTED}}},
  { ir::INST_CAST_FP_TRUNC,        {DISTRIBUTED, {DISTRIBUTED}}},
  { ir::INST_CAST_FP_EXT,          {DISTRIBUTED, {DISTRIBUTED}}},
  { ir::INST_CAST_UI_TO_FP,        {DISTRIBUTED, {DISTRIBUTED}}},
  { ir::INST_CAST_SI_TO_FP,        {DISTRIBUTED, {DISTRIBUTED}}},
  { ir::INST_CAST_FP_TO_UI,        {DISTRIBUTED, {DISTRIBUTED}}},
  { ir::INST_CAST_FP_TO_SI,        {DISTRIBUTED, {DISTRIBUTED}}},
  { ir::INST_CAST_PTR_TO_INT,      {DISTRIBUTED, {DISTRIBUTED}}},
  { ir::INST_CAST_INT_TO_PTR,      {DISTRIBUTED, {DISTRIBUTED}}},
  { ir::INST_CAST_BIT_CAST,        {DISTRIBUTED, {DISTRIBUTED}}},
  { ir::INST_CAST_ADDR_SPACE_CAST, {DISTRIBUTED, {DISTRIBUTED}}},
  // io
  { ir::INST_UNMASKED_LOAD,        {DISTRIBUTED, {DISTRIBUTED}}},
  { ir::INST_MASKED_LOAD,          {DISTRIBUTED, {DISTRIBUTED, DISTRIBUTED}}},
  { ir::INST_UNMASKED_STORE,       {NONE,        {DISTRIBUTED, DISTRIBUTED}}},
  { ir::INST_MASKED_STORE,         {NONE,        {DISTRIBUTED, DISTRIBUTED, DISTRIBUTED}}},
  // retile
  { ir::INST_RESHAPE,              {DISTRIBUTED, {DISTRIBUTED}}},
  { ir::INST_SPLAT,                {DISTRIBUTED, {REPLICATED}}},
  { ir::INST_BROADCAST,            {DISTRIBUTED, {REPLICATED}}},
  { ir::INST_DOWNCAST,             {DISTRIBUTED, {REPLICATED}}},
  // array arithmetic
  { ir::INST_TRANS,                {SHARED,      {DISTRIBUTED}}}, // TODO: not necessarily
  { ir::INST_REDUCE,               {SHARED,      {DISTRIBUTED}}},
  { ir::INST_DOT,                  {DISTRIBUTED, {SHARED, SHARED, DISTRIBUTED}}},
  // terminator
  { ir::INST_RETURN,               {NONE,        {}}},
  { ir::INST_UNCOND_BRANCH,        {NONE,        {}}},
  { ir::INST_COND_BRANCH,          {NONE,        {REPLICATED}}},
  // intrinsics
  { ir::INST_COPY_TO_SHARED,       {SHARED,      {DISTRIBUTED}}},
  { ir::INST_BARRIER,              {NONE,        {}}},
  { ir::INST_MAKE_RANGE_DYN,       {DISTRIBUTED, {}}},
  { ir::INST_MAKE_RANGE_STA,       {DISTRIBUTED, {}}},
  { ir::INST_MAKE_RANGE,           {DISTRIBUTED, {}}}
};
}
}
#endif

View File

@@ -0,0 +1,30 @@
#ifndef _TRITON_CODEGEN_PASS_H_
#define _TRITON_CODEGEN_PASS_H_

#include <list>

namespace triton{

namespace ir{
  class module;
}

namespace codegen{

// Interface implemented by every code-generation pass: a pass is run on an
// IR module through run().
class pass {
public:
  // Passes are handled through `pass*` (see pass_manager), so the base class
  // needs a virtual destructor for deletion through that pointer to be
  // well-defined.
  virtual ~pass() = default;
  // Runs the pass on module `m`. Pure virtual: no definition of pass::run is
  // visible in this change, so a non-pure declaration could not be linked by
  // any subclass.
  virtual void run(ir::module& m) = 0;
};

// Ordered collection of passes.
// NOTE(review): add() and run() are only declared here; presumably add()
// appends `p` and run() invokes each registered pass on `m` in insertion
// order -- confirm once they are implemented.
class pass_manager {
public:
  // Registers pass `p`. The manager stores the raw pointer -- presumably
  // non-owning, so `p` must outlive the manager; TODO confirm.
  void add(pass* p);
  // Runs the registered passes on module `m`.
  void run(ir::module& m);

private:
  std::list<pass*> passes;
};

}
}

#endif

View File

@@ -5,7 +5,7 @@
#include "triton/ir/module.h" #include "triton/ir/module.h"
#include "triton/ir/function.h" #include "triton/ir/function.h"
#include "triton/ir/type.h" #include "triton/ir/type.h"
#include "triton/codegen/analysis/meminfo.h" #include "triton/codegen/transform/cts.h"
namespace llvm{ namespace llvm{
@@ -45,8 +45,8 @@ namespace codegen{
namespace analysis{ namespace analysis{
class tiles; class tiles;
class align; class align;
class memalloc; class allocation;
class meminfo; class cts;
class axes; class axes;
class layout; class layout;
} }
@@ -201,7 +201,7 @@ private:
public: public:
selection(analysis::memalloc *alloc, analysis::tiles *tiles, analysis::meminfo *buffer_info, selection(analysis::allocation *alloc, analysis::tiles *tiles, analysis::cts *buffer_info,
analysis::align *alignment, analysis::axes *axes, analysis::layout *layouts, analysis::align *alignment, analysis::axes *axes, analysis::layout *layouts,
transform::coalesce* reorder, target *tgt, unsigned num_warps) transform::coalesce* reorder, target *tgt, unsigned num_warps)
: alloc_(alloc), tiles_(tiles), buffer_info_(buffer_info), : alloc_(alloc), tiles_(tiles), buffer_info_(buffer_info),
@@ -213,11 +213,11 @@ public:
private: private:
vmap_t vmap_; vmap_t vmap_;
tmap_t tmap_; tmap_t tmap_;
analysis::memalloc *alloc_; analysis::allocation *alloc_;
analysis::tiles *tiles_; analysis::tiles *tiles_;
analysis::axes *a_axes_; analysis::axes *a_axes_;
analysis::layout *layouts_; analysis::layout *layouts_;
analysis::meminfo *buffer_info_; analysis::cts *buffer_info_;
analysis::align *alignment_; analysis::align *alignment_;
transform::coalesce *reorder_; transform::coalesce *reorder_;
target *tgt_; target *tgt_;

View File

@@ -20,7 +20,7 @@ namespace codegen{
namespace analysis{ namespace analysis{
class align; class align;
class layout; class layout;
class meminfo; class cts;
} }
namespace transform{ namespace transform{
@@ -32,13 +32,13 @@ private:
ir::value* rematerialize(ir::value *v, ir::builder& builder, std::map<ir::value*, ir::value*>& seen); ir::value* rematerialize(ir::value *v, ir::builder& builder, std::map<ir::value*, ir::value*>& seen);
public: public:
coalesce(analysis::align* align, triton::codegen::analysis::layout *layouts, analysis::meminfo* mem); coalesce(analysis::align* align, triton::codegen::analysis::layout *layouts, analysis::cts* mem);
void run(ir::module &mod); void run(ir::module &mod);
private: private:
analysis::align* align_; analysis::align* align_;
analysis::layout* layout_; analysis::layout* layout_;
analysis::meminfo* mem_; analysis::cts* mem_;
}; };
} }

View File

@@ -16,7 +16,7 @@ namespace ir {
namespace codegen{ namespace codegen{
namespace analysis{ namespace analysis{
class meminfo { class cts {
public: public:
void run(ir::module &mod); void run(ir::module &mod);
// queries // queries
@@ -25,8 +25,6 @@ public:
bool is_shared(ir::value *x); bool is_shared(ir::value *x);
bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator); bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator);
ir::value *get_reference(ir::value *x); ir::value *get_reference(ir::value *x);
void replace(ir::value* before, ir::value *after);
void copy(ir::value* y, ir::value *x);
private: private:
std::set<ir::value*> shared_; std::set<ir::value*> shared_;

View File

@@ -15,8 +15,8 @@ namespace codegen{
namespace analysis{ namespace analysis{
class memalloc; class allocation;
class meminfo; class cts;
} }
@@ -38,12 +38,12 @@ private:
std::pair<interval_vec_t, interval_vec_t> transfer(ir::basic_block *block, const interval_vec_t &written_to, const interval_vec_t &read_from, std::set<ir::instruction *> &insert_loc); std::pair<interval_vec_t, interval_vec_t> transfer(ir::basic_block *block, const interval_vec_t &written_to, const interval_vec_t &read_from, std::set<ir::instruction *> &insert_loc);
public: public:
membar(analysis::memalloc *alloc, analysis::meminfo *buffer_info): alloc_(alloc), buffer_info_(buffer_info) {} membar(analysis::allocation *alloc, analysis::cts *buffer_info): alloc_(alloc), buffer_info_(buffer_info) {}
void run(ir::module &mod); void run(ir::module &mod);
private: private:
analysis::memalloc *alloc_; analysis::allocation *alloc_;
analysis::meminfo *buffer_info_; analysis::cts *buffer_info_;
}; };

View File

@@ -30,6 +30,7 @@ public:
// Setters // Setters
void set_insert_point(iterator instr); void set_insert_point(iterator instr);
void set_insert_point(instruction* i); void set_insert_point(instruction* i);
void set_insert_point_after(instruction* i);
void set_insert_point(basic_block* block); void set_insert_point(basic_block* block);
basic_block* get_insert_block() { return block_; } basic_block* get_insert_block() { return block_; }
iterator get_insert_point() { return insert_point_;} iterator get_insert_point() { return insert_point_;}

View File

@@ -12,14 +12,14 @@
#include "triton/codegen/selection.h" #include "triton/codegen/selection.h"
#include "triton/codegen/target.h" #include "triton/codegen/target.h"
#include "triton/codegen/analysis/tiles.h" #include "triton/codegen/analysis/tiles.h"
#include "triton/codegen/analysis/memalloc.h" #include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/analysis/liveness.h" #include "triton/codegen/analysis/liveness.h"
#include "triton/codegen/analysis/meminfo.h"
#include "triton/codegen/analysis/align.h" #include "triton/codegen/analysis/align.h"
#include "triton/codegen/transform/dce.h" #include "triton/codegen/transform/dce.h"
#include "triton/codegen/transform/peephole.h" #include "triton/codegen/transform/peephole.h"
#include "triton/codegen/transform/membar.h" #include "triton/codegen/transform/membar.h"
#include "triton/codegen/transform/reassociate.h" #include "triton/codegen/transform/reassociate.h"
#include "triton/codegen/transform/cts.h"
#include "triton/lang/parser.h" #include "triton/lang/parser.h"
#include "triton/runtime/arg.h" #include "triton/runtime/arg.h"

View File

@@ -445,7 +445,6 @@ std::vector<unsigned> align::populate_starting_multiple_default(ir::value* v) {
return add_to_cache(v, {1}, starting_multiple_); return add_to_cache(v, {1}, starting_multiple_);
} }
std::vector<unsigned> align::populate_starting_multiple(ir::value *v){ std::vector<unsigned> align::populate_starting_multiple(ir::value *v){
if(starting_multiple_.find(v) != starting_multiple_.end()) if(starting_multiple_.find(v) != starting_multiple_.end())
return starting_multiple_.at(v); return starting_multiple_.at(v);

View File

@@ -1,7 +1,7 @@
#include <algorithm> #include <algorithm>
#include "triton/codegen/analysis/memalloc.h" #include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/analysis/liveness.h" #include "triton/codegen/analysis/liveness.h"
#include "triton/codegen/analysis/meminfo.h" #include "triton/codegen/transform/cts.h"
#include "triton/codegen/analysis/tiles.h" #include "triton/codegen/analysis/tiles.h"
#include "triton/ir/basic_block.h" #include "triton/ir/basic_block.h"
#include "triton/ir/type.h" #include "triton/ir/type.h"
@@ -13,7 +13,7 @@ namespace triton{
namespace codegen{ namespace codegen{
namespace analysis{ namespace analysis{
unsigned memalloc::is_ld_padded(ir::value *x) { unsigned allocation::is_ld_padded(ir::value *x) {
if(auto *trans = dynamic_cast<ir::trans_inst*>(x)){ if(auto *trans = dynamic_cast<ir::trans_inst*>(x)){
if(trans->get_perm()[0]->get_value() != 0) if(trans->get_perm()[0]->get_value() != 0)
return 4; return 4;
@@ -45,7 +45,7 @@ unsigned memalloc::is_ld_padded(ir::value *x) {
return 0; return 0;
} }
unsigned memalloc::num_bytes(ir::value *x) { unsigned allocation::num_bytes(ir::value *x) {
if(auto *red = dynamic_cast<ir::reduce_inst*>(x)){ if(auto *red = dynamic_cast<ir::reduce_inst*>(x)){
unsigned num_bytes = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; unsigned num_bytes = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
size_t axis = red->get_axis(); size_t axis = red->get_axis();
@@ -73,15 +73,15 @@ unsigned memalloc::num_bytes(ir::value *x) {
return num_bytes; return num_bytes;
} }
void memalloc::run(){
void allocation::run() {
using std::max; using std::max;
using std::min; using std::min;
typedef std::multimap<unsigned, segment> triples_map_type; typedef std::multimap<unsigned, segment> triples_map_type;
std::vector<ir::value *> I; std::vector<ir::value *> I;
for(auto x: liveness_->intervals()){ for(auto x: liveness_->intervals())
I.push_back(x.first); I.push_back(x.first);
}
std::vector<ir::value *> J = I; std::vector<ir::value *> J = I;
triples_map_type H; triples_map_type H;
@@ -137,7 +137,7 @@ void memalloc::run(){
for(ir::value *X: V) for(ir::value *X: V)
colors[X] = (X==V[0])?0:-1; colors[X] = (X==V[0])?0:-1;
// First-fit coloring // First-fit graph coloring
std::vector<bool> available(V.size()); std::vector<bool> available(V.size());
for(ir::value *x: V){ for(ir::value *x: V){
// Non-neighboring colors are available // Non-neighboring colors are available
@@ -158,6 +158,7 @@ void memalloc::run(){
for(ir::value *y: interferences[x]) for(ir::value *y: interferences[x])
Adj = std::max(Adj, starts[y] + num_bytes(y)); Adj = std::max(Adj, starts[y] + num_bytes(y));
offsets_[x] = starts[x] + colors[x] * Adj; offsets_[x] = starts[x] + colors[x] * Adj;
// std::cout << x->get_name() << " " << offsets_[x] << " " << num_bytes(x) << std::endl;
if(buffer_info_->is_double(x)){ if(buffer_info_->is_double(x)){
ir::phi_node *phi = (ir::phi_node*)x; ir::phi_node *phi = (ir::phi_node*)x;
for(unsigned i = 0; i < phi->get_num_incoming(); i++){ for(unsigned i = 0; i < phi->get_num_incoming(); i++){
@@ -167,6 +168,8 @@ void memalloc::run(){
} }
} }
// exit(EXIT_FAILURE);
// Save maximum size of induced memory space // Save maximum size of induced memory space
allocated_size_ = 0; allocated_size_ = 0;
for(auto &x: offsets_){ for(auto &x: offsets_){

View File

@@ -68,11 +68,6 @@ void axes::update_graph_reshape(ir::instruction *i) {
} }
} }
// Graph update for splat instructions: the argument is a scalar, so no edge
// is created and the function intentionally does nothing.
void axes::update_graph_splat(ir::instruction *) {
}
void axes::update_graph_trans(ir::instruction *i) { void axes::update_graph_trans(ir::instruction *i) {
auto *trans = static_cast<ir::trans_inst*>(i); auto *trans = static_cast<ir::trans_inst*>(i);
ir::value *op = trans->get_operand(0); ir::value *op = trans->get_operand(0);
@@ -129,13 +124,14 @@ void axes::update_graph_elementwise(ir::instruction *i) {
void axes::update_graph(ir::instruction *i) { void axes::update_graph(ir::instruction *i) {
switch (i->get_id()) { switch (i->get_id()) {
case ir::INST_REDUCE: return update_graph_reduce(i); case ir::INST_REDUCE: return update_graph_reduce(i);
case ir::INST_RESHAPE: return update_graph_reshape(i); case ir::INST_RESHAPE: return update_graph_reshape(i);
case ir::INST_SPLAT: return update_graph_splat(i); case ir::INST_SPLAT: return;
case ir::INST_TRANS: return update_graph_trans(i); case ir::INST_TRANS: return update_graph_trans(i);
case ir::INST_BROADCAST: return update_graph_broadcast(i); case ir::INST_BROADCAST: return update_graph_broadcast(i);
case ir::INST_DOT: return update_graph_dot(i); case ir::INST_DOT: return update_graph_dot(i);
default: return update_graph_elementwise(i); case ir::INST_COPY_TO_SHARED: return;
default: return update_graph_elementwise(i);
} }
return; return;
} }

View File

@@ -20,10 +20,9 @@ std::set<int> layout::axes_of(ir::value *value) {
rank = ty->get_tile_rank(); rank = ty->get_tile_rank();
// create result // create result
std::set<int> result; std::set<int> result;
for(size_t d = 0; d < rank; d++){ for(size_t d = 0; d < rank; d++)
if(axes_->has_id(value, d)) if(axes_->has_id(value, d))
result.insert(axes_->get_id(value, d)); result.insert(axes_->get_id(value, d));
}
return result; return result;
} }
@@ -54,6 +53,7 @@ const std::vector<ir::value*>& layout::values(unsigned id) const
size_t layout::get_num_groups() const size_t layout::get_num_groups() const
{ return values_.size(); } { return values_.size(); }
// connect two values
void layout::connect(ir::value *x, ir::value *y) { void layout::connect(ir::value *x, ir::value *y) {
if(x == y) if(x == y)
return; return;
@@ -75,6 +75,7 @@ void layout::connect(ir::value *x, ir::value *y) {
} }
} }
// make graph
void layout::make_graph(ir::instruction *i) { void layout::make_graph(ir::instruction *i) {
for(ir::value* opx: i->ops()) for(ir::value* opx: i->ops())
for(ir::value* opy: i->ops()){ for(ir::value* opy: i->ops()){

View File

@@ -1,6 +1,7 @@
#include <iostream> #include <iostream>
#include "triton/codegen/instructions.h"
#include "triton/codegen/analysis/liveness.h" #include "triton/codegen/analysis/liveness.h"
#include "triton/codegen/analysis/meminfo.h" #include "triton/codegen/transform/cts.h"
#include "triton/ir/basic_block.h" #include "triton/ir/basic_block.h"
#include "triton/ir/function.h" #include "triton/ir/function.h"
#include "triton/ir/module.h" #include "triton/ir/module.h"
@@ -25,6 +26,11 @@ void liveness::run(ir::module &mod) {
// Creates live intervals // Creates live intervals
for(auto i: indices_){ for(auto i: indices_){
ir::value *v = i.first; ir::value *v = i.first;
// ir::instruction* instr = dynamic_cast<ir::instruction*>(v);
// if(!instr)
// continue;
// if(storage_info.at(instr->get_id()).first != SHARED)
// continue;
if(!info_->is_shared(v) || info_->get_reference(v)) if(!info_->is_shared(v) || info_->get_reference(v))
continue; continue;
unsigned start = i.second; unsigned start = i.second;

View File

0
lib/codegen/pass.cc Normal file
View File

View File

@@ -4,7 +4,7 @@
#include "triton/codegen/analysis/layout.h" #include "triton/codegen/analysis/layout.h"
#include "triton/codegen/analysis/axes.h" #include "triton/codegen/analysis/axes.h"
#include "triton/codegen/analysis/tiles.h" #include "triton/codegen/analysis/tiles.h"
#include "triton/codegen/analysis/memalloc.h" #include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/analysis/align.h" #include "triton/codegen/analysis/align.h"
#include "triton/codegen/transform/coalesce.h" #include "triton/codegen/transform/coalesce.h"
#include "triton/ir/context.h" #include "triton/ir/context.h"

View File

@@ -7,7 +7,7 @@
#include "triton/ir/instructions.h" #include "triton/ir/instructions.h"
#include "triton/ir/module.h" #include "triton/ir/module.h"
#include "triton/codegen/analysis/layout.h" #include "triton/codegen/analysis/layout.h"
#include "triton/codegen/analysis/meminfo.h" #include "triton/codegen/transform/cts.h"
#include "triton/codegen/analysis/align.h" #include "triton/codegen/analysis/align.h"
#include "triton/codegen/transform/coalesce.h" #include "triton/codegen/transform/coalesce.h"
@@ -15,7 +15,7 @@ namespace triton {
namespace codegen{ namespace codegen{
namespace transform{ namespace transform{
coalesce::coalesce(analysis::align* align, analysis::layout *layouts, analysis::meminfo *mem) coalesce::coalesce(analysis::align* align, analysis::layout *layouts, analysis::cts *mem)
: align_(align), layout_(layouts), mem_(mem) { } : align_(align), layout_(layouts), mem_(mem) { }
// Find all values that are used as pointer operands in LD/ST // Find all values that are used as pointer operands in LD/ST
@@ -102,9 +102,6 @@ void coalesce::run(ir::module &mod) {
r->replace_all_uses_with(cts); r->replace_all_uses_with(cts);
cts->replace_uses_of_with(cts, r); cts->replace_uses_of_with(cts, r);
} }
else{
}
} }
} }

View File

@@ -1,5 +1,7 @@
#include <algorithm> #include <algorithm>
#include "triton/codegen/analysis/meminfo.h" #include <iostream>
#include "triton/codegen/transform/cts.h"
#include "triton/codegen/instructions.h"
#include "triton/ir/module.h" #include "triton/ir/module.h"
#include "triton/ir/function.h" #include "triton/ir/function.h"
#include "triton/ir/basic_block.h" #include "triton/ir/basic_block.h"
@@ -12,7 +14,7 @@ namespace codegen{
namespace analysis{ namespace analysis{
// run pass on module // run pass on module
bool meminfo::is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){ bool cts::is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
if(phi->get_parent() != terminator->get_parent()) if(phi->get_parent() != terminator->get_parent())
return false; return false;
if(auto *br = dynamic_cast<ir::cond_branch_inst*>(terminator)) if(auto *br = dynamic_cast<ir::cond_branch_inst*>(terminator))
@@ -24,24 +26,6 @@ bool meminfo::is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
throw std::runtime_error("unreachable"); throw std::runtime_error("unreachable");
} }
// Transfers the bookkeeping recorded for `before` onto `after`:
//  - `before` is dropped from the shared set and `after` is added to it
//    (unconditionally, whether or not `before` was a member);
//  - if `before` has a reference recorded, that entry is re-keyed to `after`
//    (an entry already present for `after` is left untouched by insert).
void meminfo::replace(ir::value* before, ir::value *after) {
  shared_.erase(before);
  shared_.insert(after);
  auto entry = refs_.find(before);
  if(entry == refs_.end())
    return;
  // read the referenced value before the entry is erased
  ir::value *reference = entry->second;
  refs_.erase(entry);
  refs_.insert({after, reference});
}
// Makes `y` inherit whatever is recorded for `x`: membership in the shared
// set, the reference entry (overwriting any existing entry for `y`), and
// membership in the double-buffered set. Properties `x` does not have are
// left unchanged for `y`.
void meminfo::copy(ir::value* y, ir::value *x) {
  if(shared_.count(x) != 0)
    shared_.insert(y);
  auto entry = refs_.find(x);
  if(entry != refs_.end()){
    // read x's reference before touching the entry for y
    ir::value *reference = entry->second;
    refs_[y] = reference;
  }
  if(double_.count(x) != 0)
    double_.insert(y);
}
inline bool get_is_shared(ir::value* v) { inline bool get_is_shared(ir::value* v) {
@@ -62,40 +46,46 @@ inline bool get_is_shared(ir::value* v) {
return false; return false;
} }
void add_copy(ir::value *x, ir::builder &builder) { void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder) {
if(auto phi = dynamic_cast<ir::phi_node*>(x)){ auto *i = dynamic_cast<ir::instruction*>(x);
// not an instruction
if(!i) {
builder.set_insert_point(parent);
ir::value *cts = builder.create_copy_to_shared(x);
parent->replace_uses_of_with(x, cts);
return;
}
// phi node
if(auto* phi = dynamic_cast<ir::phi_node*>(x)) {
for(unsigned i = 0; i < phi->get_num_incoming(); ++i) for(unsigned i = 0; i < phi->get_num_incoming(); ++i)
add_copy(phi->get_incoming_value(i), builder); add_copy(phi, phi->get_incoming_value(i), builder);
} return;
else {
if(get_is_shared(x))
return;
if(auto *i = dynamic_cast<ir::instruction*>(x)){
ir::basic_block* block = i->get_parent();
auto it = std::find(block->begin(), block->end(), i);
builder.set_insert_point(++it);
}
ir::instruction *rx = (ir::instruction*)builder.create_copy_to_shared(x);
x->replace_all_uses_with(rx);
rx->set_operand(0, x);
} }
ir::value_id_t id = i->get_id();
// already in shared memory
if(storage_info.at(id).first == SHARED)
return;
// copy
builder.set_insert_point_after(i);
ir::value *cts = builder.create_copy_to_shared(x);
parent->replace_uses_of_with(x, cts);
} }
void meminfo::run(ir::module &mod) { void cts::run(ir::module &mod) {
shared_.clear(); shared_.clear();
refs_.clear(); refs_.clear();
double_.clear(); double_.clear();
// Add shared copies // Add shared copies
ir::builder &builder = mod.get_builder();
for(ir::function *fn: mod.get_function_list()){ for(ir::function *fn: mod.get_function_list()){
ir::builder builder(mod.get_context());
for(ir::basic_block *block: fn->blocks()) for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i: block->get_inst_list()){ for(ir::instruction *i: block->get_inst_list()){
if(dynamic_cast<ir::dot_inst*>(i)) auto storage = storage_info.at(i->get_id());
if(i->get_operand(1)->get_type()->get_tile_shapes()[1] != 1){ // copy to shared operands when necessary
add_copy(i->get_operand(0), builder); for(size_t k = 0; k < storage.second.size(); k++)
add_copy(i->get_operand(1), builder); if(storage.second[k] == SHARED)
} add_copy(i, i->get_operand(k), builder);
} }
} }
@@ -135,15 +125,15 @@ void meminfo::run(ir::module &mod) {
} }
// query double-buffered status // query double-buffered status
bool meminfo::is_double(ir::value *x) bool cts::is_double(ir::value *x)
{ return double_.find(x) != double_.end(); } { return double_.find(x) != double_.end(); }
// query shared status // query shared status
bool meminfo::is_shared(ir::value *x) bool cts::is_shared(ir::value *x)
{ return shared_.find(x) != shared_.end(); } { return shared_.find(x) != shared_.end(); }
// get reference if any // get reference if any
ir::value *meminfo::get_reference(ir::value *x) ir::value *cts::get_reference(ir::value *x)
{ return refs_[x]; } { return refs_[x]; }

View File

@@ -20,12 +20,22 @@ void dce::run(ir::module &mod) {
// iterate through blocks // iterate through blocks
for(ir::basic_block *block: rpo) for(ir::basic_block *block: rpo)
for(ir::instruction *i: block->get_inst_list()){ for(ir::instruction *i: block->get_inst_list()){
if(dynamic_cast<ir::io_inst*>(i) || dynamic_cast<ir::return_inst*>(i) switch(i->get_id()){
|| dynamic_cast<ir::branch_inst*>(i) || dynamic_cast<ir::cond_branch_inst*>(i) case ir::INST_RETURN:
|| dynamic_cast<ir::atomic_cas_inst*>(i) || dynamic_cast<ir::atomic_exch_inst*>(i) || dynamic_cast<ir::atomic_add_inst*>(i) case ir::INST_UNCOND_BRANCH:
|| dynamic_cast<ir::barrier_inst*>(i)){ case ir::INST_COND_BRANCH:
work_list.push_back(i); case ir::INST_UNMASKED_STORE:
marked.insert(i); case ir::INST_MASKED_STORE:
case ir::INST_ATOMIC_ADD:
case ir::INST_ATOMIC_CAS:
case ir::INST_ATOMIC_EXCH:
case ir::INST_BARRIER: {
work_list.push_back(i);
marked.insert(i);
break;
}
default:
break;
} }
} }
} }

View File

@@ -3,8 +3,8 @@
#include <algorithm> #include <algorithm>
#include "triton/codegen/transform/membar.h" #include "triton/codegen/transform/membar.h"
#include "triton/codegen/analysis/memalloc.h" #include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/analysis/meminfo.h" #include "triton/codegen/transform/cts.h"
#include "triton/ir/module.h" #include "triton/ir/module.h"
#include "triton/ir/function.h" #include "triton/ir/function.h"
#include "triton/ir/basic_block.h" #include "triton/ir/basic_block.h"

View File

@@ -122,7 +122,6 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value,
new_value = builder.create_add(rrhs, builder.create_add(lrhs, lhs), name, cst); new_value = builder.create_add(rrhs, builder.create_add(lrhs, lhs), name, cst);
} }
} }
// extract constant and non-constant // extract constant and non-constant
if(ir::instruction *bin_add = is_bin_add(new_value)){ if(ir::instruction *bin_add = is_bin_add(new_value)){
ir::value *new_lhs = bin_add->get_operand(0); ir::value *new_lhs = bin_add->get_operand(0);
@@ -136,12 +135,9 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value,
noncst = new_lhs; noncst = new_lhs;
} }
} }
// clean-up if some re-ordering happened // clean-up if some re-ordering happened
if(old_value != new_value){ if(old_value != new_value)
old_value->replace_all_uses_with(new_value); old_value->replace_all_uses_with(new_value);
}
return new_value; return new_value;
} }

View File

@@ -27,6 +27,13 @@ void builder::set_insert_point(instruction* i){
} }
// Moves the insertion point to the position immediately following
// instruction `i` inside i's parent block.
// If `i` is the last instruction of its block (or, defensively, cannot be
// found in its parent's instruction list), the insertion point becomes the
// end of that block; the past-the-end iterator is never incremented nor
// handed to the iterator overload.
void builder::set_insert_point_after(instruction* i){
  basic_block *parent = i->get_parent();
  auto it = std::find(parent->begin(), parent->end(), i);
  // incrementing end() is undefined behaviour, so only step past a real hit
  if(it != parent->end())
    ++it;
  if(it == parent->end()){
    // block overload sets both the current block and insert point = end()
    set_insert_point(parent);
    return;
  }
  block_ = parent;
  set_insert_point(it);
}
void builder::set_insert_point(basic_block *block){ void builder::set_insert_point(basic_block *block){
block_ = block; block_ = block;
insert_point_ = block->end(); insert_point_ = block->end();

View File

@@ -199,14 +199,14 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
llvm::LLVMContext ctx; llvm::LLVMContext ctx;
std::unique_ptr<llvm::Module> llvm(new llvm::Module(module.get_name(), ctx)); std::unique_ptr<llvm::Module> llvm(new llvm::Module(module.get_name(), ctx));
// create passes // create passes
codegen::analysis::meminfo shmem_info; codegen::analysis::cts shmem_info;
codegen::analysis::align align; codegen::analysis::align align;
codegen::analysis::liveness shmem_liveness(&shmem_info); codegen::analysis::liveness shmem_liveness(&shmem_info);
codegen::analysis::axes axes; codegen::analysis::axes axes;
codegen::analysis::layout layouts(&axes); codegen::analysis::layout layouts(&axes);
codegen::transform::coalesce coalesce(&align, &layouts, &shmem_info); codegen::transform::coalesce coalesce(&align, &layouts, &shmem_info);
codegen::analysis::tiles tiles(opt.num_warps, &align, &axes, &layouts); codegen::analysis::tiles tiles(opt.num_warps, &align, &axes, &layouts);
codegen::analysis::memalloc shmem_allocation(&shmem_liveness, &shmem_info, &tiles); codegen::analysis::allocation shmem_allocation(&shmem_liveness, &shmem_info, &tiles);
codegen::transform::membar shmem_barriers(&shmem_allocation, &shmem_info); codegen::transform::membar shmem_barriers(&shmem_allocation, &shmem_info);
codegen::transform::dce dce; codegen::transform::dce dce;
codegen::transform::peephole peephole; codegen::transform::peephole peephole;
@@ -229,6 +229,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
peephole.run(module); peephole.run(module);
shmem_info.run(module); shmem_info.run(module);
shmem_liveness.run(module); shmem_liveness.run(module);
ir::print(module, std::cout);
shmem_allocation.run(); shmem_allocation.run();
if(shmem_allocation.allocated_size() > context->device()->max_shared_memory()) if(shmem_allocation.allocated_size() > context->device()->max_shared_memory())
return std::unique_ptr<driver::module>(); return std::unique_ptr<driver::module>();