[codegen] cleaned up shared memory and double-buffering logic

This commit is contained in:
Philippe Tillet
2019-09-21 22:21:40 -04:00
parent 43d88154bd
commit 001973630e
15 changed files with 173 additions and 187 deletions

View File

@@ -10,6 +10,7 @@ namespace triton{
namespace ir{ namespace ir{
class value; class value;
class function; class function;
class module;
} }
namespace codegen{ namespace codegen{
@@ -22,8 +23,8 @@ class cts;
class allocation { class allocation {
public: public:
allocation(liveness *live, cts *buffer_info, tiles *params) allocation(liveness *live, tiles *params)
: liveness_(live), buffer_info_(buffer_info), tiles_(params){ } : liveness_(live), tiles_(params){ }
// utilities // utilities
unsigned num_bytes(ir::value *x); unsigned num_bytes(ir::value *x);
unsigned is_ld_padded(ir::value* x); unsigned is_ld_padded(ir::value* x);
@@ -31,7 +32,7 @@ public:
unsigned offset(ir::value *x) const { return offsets_.at(x); } unsigned offset(ir::value *x) const { return offsets_.at(x); }
unsigned allocated_size() const { return allocated_size_; } unsigned allocated_size() const { return allocated_size_; }
// run // run
void run(); void run(ir::module& mod);
private: private:
std::map<ir::value*, unsigned> offsets_; std::map<ir::value*, unsigned> offsets_;
@@ -39,7 +40,6 @@ private:
size_t allocated_size_; size_t allocated_size_;
// dependences // dependences
liveness *liveness_; liveness *liveness_;
cts *buffer_info_;
tiles *tiles_; tiles *tiles_;
}; };

View File

@@ -7,6 +7,7 @@ namespace triton{
namespace ir{ namespace ir{
class value; class value;
class phi_node;
class function; class function;
class module; class module;
} }
@@ -31,6 +32,11 @@ struct segment {
} }
}; };
// Record attached to a value that takes part in double-buffering.
// Instances are built by extract_double_bufferable() from a two-input
// phi node and stored in liveness::double_, keyed by the phi's other
// (non-latch) incoming value.
struct double_buffer_info_t {
// incoming value of `phi` whose incoming block was classified as a loop latch
ir::value* latch;
// the two-input phi node from which this pairing was extracted
ir::phi_node* phi;
};
class liveness { class liveness {
private: private:
typedef std::map<ir::value*, slot_index> indices_map_t; typedef std::map<ir::value*, slot_index> indices_map_t;
@@ -43,19 +49,20 @@ public:
using const_iterator = intervals_map_t::const_iterator; using const_iterator = intervals_map_t::const_iterator;
public: public:
// constructor
liveness(cts *info): info_(info){ }
// accessors // accessors
const intervals_map_t& intervals() const { return intervals_; } const intervals_map_t& intervals() const { return intervals_; }
segment get_interval(ir::value* v) const { return intervals_.at(v); } segment get_interval(ir::value* v) const { return intervals_.at(v); }
// double-buffering
bool has_double(ir::value *x) const { return double_.find(x) != double_.end(); }
double_buffer_info_t get_double(ir::value *x) const { return double_.at(x); }
// run // run
void run(ir::module &mod); void run(ir::module &mod);
private: private:
cts *info_;
has_storage_map_t has_dedicated_storage_; has_storage_map_t has_dedicated_storage_;
indices_map_t indices_; indices_map_t indices_;
intervals_map_t intervals_; intervals_map_t intervals_;
std::map<ir::value*, double_buffer_info_t> double_;
}; };
} }

View File

@@ -43,6 +43,7 @@ namespace triton{
namespace codegen{ namespace codegen{
namespace analysis{ namespace analysis{
class liveness;
class tiles; class tiles;
class align; class align;
class allocation; class allocation;
@@ -201,10 +202,10 @@ private:
public: public:
selection(analysis::allocation *alloc, analysis::tiles *tiles, analysis::cts *buffer_info, selection(analysis::liveness* liveness, analysis::allocation *alloc, analysis::tiles *tiles,
analysis::align *alignment, analysis::axes *axes, analysis::layout *layouts, analysis::align *alignment, analysis::axes *axes,
transform::coalesce* reorder, target *tgt, unsigned num_warps) analysis::layout *layouts, transform::coalesce* reorder, target *tgt, unsigned num_warps)
: alloc_(alloc), tiles_(tiles), buffer_info_(buffer_info), : liveness_(liveness), alloc_(alloc), tiles_(tiles),
alignment_(alignment), a_axes_(axes), layouts_(layouts), alignment_(alignment), a_axes_(axes), layouts_(layouts),
reorder_(reorder), tgt_(tgt), num_warps_(num_warps){ } reorder_(reorder), tgt_(tgt), num_warps_(num_warps){ }
@@ -213,11 +214,11 @@ public:
private: private:
vmap_t vmap_; vmap_t vmap_;
tmap_t tmap_; tmap_t tmap_;
analysis::liveness *liveness_;
analysis::allocation *alloc_; analysis::allocation *alloc_;
analysis::tiles *tiles_; analysis::tiles *tiles_;
analysis::axes *a_axes_; analysis::axes *a_axes_;
analysis::layout *layouts_; analysis::layout *layouts_;
analysis::cts *buffer_info_;
analysis::align *alignment_; analysis::align *alignment_;
transform::coalesce *reorder_; transform::coalesce *reorder_;
target *tgt_; target *tgt_;

View File

@@ -32,13 +32,12 @@ private:
ir::value* rematerialize(ir::value *v, ir::builder& builder, std::map<ir::value*, ir::value*>& seen); ir::value* rematerialize(ir::value *v, ir::builder& builder, std::map<ir::value*, ir::value*>& seen);
public: public:
coalesce(analysis::align* align, triton::codegen::analysis::layout *layouts, analysis::cts* mem); coalesce(analysis::align* align, triton::codegen::analysis::layout *layouts);
void run(ir::module &mod); void run(ir::module &mod);
private: private:
analysis::align* align_; analysis::align* align_;
analysis::layout* layout_; analysis::layout* layout_;
analysis::cts* mem_;
}; };
} }

View File

@@ -19,17 +19,6 @@ namespace analysis{
class cts { class cts {
public: public:
void run(ir::module &mod); void run(ir::module &mod);
// queries
bool is_double(ir::value *x);
void add_shared(ir::value *v);
bool is_shared(ir::value *x);
bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator);
ir::value *get_reference(ir::value *x);
private:
std::set<ir::value*> shared_;
std::set<ir::value*> double_;
std::map<ir::value*, ir::value*> refs_;
}; };
} }

View File

@@ -16,6 +16,7 @@ namespace codegen{
namespace analysis{ namespace analysis{
class allocation; class allocation;
class liveness;
class cts; class cts;
} }
@@ -35,15 +36,17 @@ private:
void add_reference(ir::value *v, interval_vec_t &res); void add_reference(ir::value *v, interval_vec_t &res);
void get_read_intervals(ir::instruction *i, interval_vec_t &res); void get_read_intervals(ir::instruction *i, interval_vec_t &res);
void get_written_intervals(ir::instruction *i, interval_vec_t &res); void get_written_intervals(ir::instruction *i, interval_vec_t &res);
std::pair<interval_vec_t, interval_vec_t> transfer(ir::basic_block *block, const interval_vec_t &written_to, const interval_vec_t &read_from, std::set<ir::instruction *> &insert_loc); std::pair<interval_vec_t, interval_vec_t> transfer(ir::basic_block *block, const interval_vec_t &written_to, const interval_vec_t &read_from,
std::set<ir::instruction *> &insert_loc, std::set<triton::ir::value *> &safe_war);
public: public:
membar(analysis::allocation *alloc, analysis::cts *buffer_info): alloc_(alloc), buffer_info_(buffer_info) {} membar(analysis::liveness *liveness, analysis::allocation *alloc):
liveness_(liveness), alloc_(alloc) {}
void run(ir::module &mod); void run(ir::module &mod);
private: private:
analysis::liveness *liveness_;
analysis::allocation *alloc_; analysis::allocation *alloc_;
analysis::cts *buffer_info_;
}; };

View File

@@ -1,4 +1,5 @@
#include <algorithm> #include <algorithm>
#include <climits>
#include "triton/codegen/analysis/allocation.h" #include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/analysis/liveness.h" #include "triton/codegen/analysis/liveness.h"
#include "triton/codegen/transform/cts.h" #include "triton/codegen/transform/cts.h"
@@ -8,6 +9,7 @@
#include "triton/ir/value.h" #include "triton/ir/value.h"
#include "triton/ir/function.h" #include "triton/ir/function.h"
#include "triton/ir/instructions.h" #include "triton/ir/instructions.h"
#include "triton/ir/utils.h"
namespace triton{ namespace triton{
namespace codegen{ namespace codegen{
@@ -68,13 +70,12 @@ unsigned allocation::num_bytes(ir::value *x) {
unsigned ld = x->get_type()->get_tile_shapes()[0]; unsigned ld = x->get_type()->get_tile_shapes()[0];
num_bytes += pad * num_bytes / ld; num_bytes += pad * num_bytes / ld;
} }
if(buffer_info_->is_double(x)) if(liveness_->has_double(x))
num_bytes *= 2; num_bytes *= 2;
return num_bytes; return num_bytes;
} }
void allocation::run(ir::module &mod) {
void allocation::run() {
using std::max; using std::max;
using std::min; using std::min;
typedef std::multimap<unsigned, segment> triples_map_type; typedef std::multimap<unsigned, segment> triples_map_type;
@@ -85,7 +86,7 @@ void allocation::run() {
std::vector<ir::value *> J = I; std::vector<ir::value *> J = I;
triples_map_type H; triples_map_type H;
H.insert({0, segment{0, 1024}}); H.insert({0, segment{0, INT_MAX}});
std::vector<ir::value *> V; std::vector<ir::value *> V;
std::map<ir::value *, unsigned> starts; std::map<ir::value *, unsigned> starts;
@@ -115,7 +116,6 @@ void allocation::run() {
} }
} }
// Build interference graph // Build interference graph
std::map<ir::value*, std::set<ir::value *>> interferences; std::map<ir::value*, std::set<ir::value *>> interferences;
for(ir::value *x: V) for(ir::value *x: V)
@@ -137,6 +137,7 @@ void allocation::run() {
for(ir::value *X: V) for(ir::value *X: V)
colors[X] = (X==V[0])?0:-1; colors[X] = (X==V[0])?0:-1;
// First-fit graph coloring // First-fit graph coloring
std::vector<bool> available(V.size()); std::vector<bool> available(V.size());
for(ir::value *x: V){ for(ir::value *x: V){
@@ -158,17 +159,11 @@ void allocation::run() {
for(ir::value *y: interferences[x]) for(ir::value *y: interferences[x])
Adj = std::max(Adj, starts[y] + num_bytes(y)); Adj = std::max(Adj, starts[y] + num_bytes(y));
offsets_[x] = starts[x] + colors[x] * Adj; offsets_[x] = starts[x] + colors[x] * Adj;
// std::cout << x->get_name() << " " << offsets_[x] << " " << num_bytes(x) << std::endl; if(liveness_->has_double(x)){
if(buffer_info_->is_double(x)){ auto info = liveness_->get_double(x);
ir::phi_node *phi = (ir::phi_node*)x; offsets_[info.latch] = offsets_[x] + num_bytes(x) / 2;
for(unsigned i = 0; i < phi->get_num_incoming(); i++){
ir::value *inc_val = phi->get_incoming_value(i);
offsets_[inc_val] = offsets_[phi];
} }
} }
}
// exit(EXIT_FAILURE);
// Save maximum size of induced memory space // Save maximum size of induced memory space
allocated_size_ = 0; allocated_size_ = 0;

View File

@@ -7,13 +7,59 @@
#include "triton/ir/module.h" #include "triton/ir/module.h"
#include "triton/ir/instructions.h" #include "triton/ir/instructions.h"
#include "triton/ir/value.h" #include "triton/ir/value.h"
#include "triton/ir/utils.h"
namespace triton{ namespace triton{
namespace codegen{ namespace codegen{
namespace analysis{ namespace analysis{
// Tells whether `terminator` closes a loop back onto the block holding `phi`.
//   - false when `terminator` and `phi` do not live in the same basic block
//   - for a conditional branch: true iff one of its two destinations is
//     the block holding `phi`
//   - for an unconditional branch: always false
//   - any other terminator kind: throws std::runtime_error("unreachable")
inline bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
  auto *phi_block = phi->get_parent();
  // a terminator from another block cannot branch back from the phi's block
  if(phi_block != terminator->get_parent())
    return false;
  auto *cond = dynamic_cast<ir::cond_branch_inst*>(terminator);
  if(cond){
    if(cond->get_true_dest() == phi_block)
      return true;
    return cond->get_false_dest() == phi_block;
  }
  // only conditional / unconditional branches are expected here
  if(!dynamic_cast<ir::uncond_branch_inst*>(terminator))
    throw std::runtime_error("unreachable");
  return false;
}
// Inspects instruction `i` and, when it is a phi node with exactly two
// incoming edges whose incoming values are both instructions with SHARED
// storage (according to storage_info), records the double-buffering pairing
// in `result`:
//   result[non-latch incoming value] = { latch incoming value, phi }
// An edge counts as "latch" when is_loop_latch() accepts the last
// instruction of its incoming block. Nothing is recorded otherwise.
inline void extract_double_bufferable(ir::instruction *i, std::map<ir::value*, double_buffer_info_t>& result) {
  auto* phi = dynamic_cast<ir::phi_node*>(i);
  if(phi == nullptr)
    return;
  if(phi->get_num_incoming() != 2)
    return;
  // gather per-edge information; braced lists are evaluated left to right
  ir::basic_block *pred[2] = { phi->get_incoming_block(0), phi->get_incoming_block(1) };
  ir::instruction *term[2] = { pred[0]->get_inst_list().back(), pred[1]->get_inst_list().back() };
  const bool from_latch[2] = { is_loop_latch(phi, term[0]), is_loop_latch(phi, term[1]) };
  ir::value *incoming[2] = { phi->get_incoming_value(0), phi->get_incoming_value(1) };
  ir::instruction *producer[2] = { dynamic_cast<ir::instruction*>(incoming[0]),
                                   dynamic_cast<ir::instruction*>(incoming[1]) };
  // both incoming values must be instructions ...
  if(producer[0] == nullptr || producer[1] == nullptr)
    return;
  // ... and both must be placed in SHARED storage
  for(ir::instruction *p: producer)
    if(storage_info.at(p->get_id()).first != SHARED)
      return;
  // key = value on side n, latch = value on the opposite side
  for(unsigned n = 0; n < 2; n++)
    if(from_latch[1 - n])
      result[incoming[n]] = double_buffer_info_t{incoming[1 - n], phi};
}
// Entry point // Entry point
void liveness::run(ir::module &mod) { void liveness::run(ir::module &mod) {
double_.clear();
indices_.clear();
intervals_.clear();
// set of pair of values that can be double-buffered
ir::for_each_instruction(mod, [this](ir::instruction* i) {
extract_double_bufferable(i, this->double_);
});
for(ir::function *fn: mod.get_function_list()){ for(ir::function *fn: mod.get_function_list()){
// Assigns index to each instruction // Assigns index to each instruction
slot_index index = 0; slot_index index = 0;
@@ -26,12 +72,10 @@ void liveness::run(ir::module &mod) {
// Creates live intervals // Creates live intervals
for(auto i: indices_){ for(auto i: indices_){
ir::value *v = i.first; ir::value *v = i.first;
// ir::instruction* instr = dynamic_cast<ir::instruction*>(v); ir::instruction* instr = dynamic_cast<ir::instruction*>(v);
// if(!instr) if(!instr)
// continue; continue;
// if(storage_info.at(instr->get_id()).first != SHARED) if(storage_info.at(instr->get_id()).first != SHARED)
// continue;
if(!info_->is_shared(v) || info_->get_reference(v))
continue; continue;
unsigned start = i.second; unsigned start = i.second;
unsigned end = start; unsigned end = start;
@@ -41,6 +85,21 @@ void liveness::run(ir::module &mod) {
} }
intervals_[v] = segment{start, end}; intervals_[v] = segment{start, end};
} }
// Double-Buffering
// Arrays are live throughout the end of the loop
auto it = intervals_.begin();
while(it != intervals_.end()) {
ir::value *x = it->first;
auto dit = double_.find(x);
if(dit != double_.end()) {
ir::value *y = dit->second.latch;
unsigned start = intervals_[x].start;
unsigned end = intervals_[y].end;
intervals_[x] = segment{start, end};
intervals_.erase(y);
}
it++;
}
} }
} }

View File

@@ -1,12 +1,14 @@
#include <numeric> #include <numeric>
#include "triton/codegen/selection.h" #include "triton/codegen/selection.h"
#include "triton/codegen/target.h" #include "triton/codegen/target.h"
#include "triton/codegen/analysis/liveness.h"
#include "triton/codegen/analysis/layout.h" #include "triton/codegen/analysis/layout.h"
#include "triton/codegen/analysis/axes.h" #include "triton/codegen/analysis/axes.h"
#include "triton/codegen/analysis/tiles.h" #include "triton/codegen/analysis/tiles.h"
#include "triton/codegen/analysis/allocation.h" #include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/analysis/align.h" #include "triton/codegen/analysis/align.h"
#include "triton/codegen/transform/coalesce.h" #include "triton/codegen/transform/coalesce.h"
#include "triton/codegen/instructions.h"
#include "triton/ir/context.h" #include "triton/ir/context.h"
#include "triton/ir/module.h" #include "triton/ir/module.h"
#include "triton/ir/function.h" #include "triton/ir/function.h"
@@ -746,43 +748,32 @@ void selection::create_shared_tile(ir::value *v, IRBuilder<> &builder, Value *sh
Type* ty = llvm_type(v->get_type()->get_scalar_ty(), builder.getContext()); Type* ty = llvm_type(v->get_type()->get_scalar_ty(), builder.getContext());
// shared copy // shared copy
PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr->getType()->getPointerAddressSpace()); PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr->getType()->getPointerAddressSpace());
// phi-node (double-buffering) // double-buffered
if(auto *phi = dynamic_cast<ir::phi_node*>(v)) { if(liveness_->has_double(v)) {
auto info = liveness_->get_double(v);
ir::phi_node *phi = info.phi;
BasicBlock *parent = (BasicBlock*)vmap_[phi->get_parent()]; BasicBlock *parent = (BasicBlock*)vmap_[phi->get_parent()];
unsigned id_pre = 0, id_loop = 1;
if(phi->get_incoming_block(0) == phi->get_parent())
std::swap(id_pre, id_loop);
if(parent->empty()) if(parent->empty())
builder.SetInsertPoint(parent); builder.SetInsertPoint(parent);
else else
builder.SetInsertPoint(&*parent->getFirstInsertionPt()); builder.SetInsertPoint(&*parent->getFirstInsertionPt());
// create double-buffered pointer
PHINode *ptr = builder.CreatePHI(ptr_ty, 2); PHINode *ptr = builder.CreatePHI(ptr_ty, 2);
PHINode *offset = builder.CreatePHI(builder.getInt32Ty(), 2); PHINode *offset = builder.CreatePHI(builder.getInt32Ty(), 2);
// next pointer // next pointer
Value *pre_ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(alloc_->offset(phi))); Value *pre_ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(alloc_->offset(v)));
pre_ptr = builder.CreateBitCast(pre_ptr, ptr->getType()); pre_ptr = builder.CreateBitCast(pre_ptr, ptr->getType());
Value *next_ptr = builder.CreateGEP(ptr, offset, "next_ptr"); Value *next_ptr = builder.CreateGEP(ptr, offset, "next_ptr");
tmap_.insert({phi, new shared_tile(ty, shapes, ptr, builder, offset)}); tmap_.insert({phi, new shared_tile(ty, shapes, ptr, builder, offset)});
for(unsigned i = 0; i < phi->get_num_incoming(); i++) { tmap_.insert({v, new shared_tile(ty, shapes, pre_ptr, builder)});
ir::basic_block* inc_block = phi->get_incoming_block(i); tmap_.insert({info.latch, new shared_tile(ty, shapes, next_ptr, builder)});
ir::value* inc_value = phi->get_incoming_value(i);
ir::instruction* terminator = inc_block->get_inst_list().back();
bool is_loop_latch = buffer_info_->is_loop_latch(phi, terminator);
tmap_.insert({inc_value, new shared_tile(ty, shapes, is_loop_latch?next_ptr:pre_ptr, builder)});
}
} }
else { else {
bool has_phi_user = false;
for(ir::user *usr: v->get_users())
if(dynamic_cast<ir::phi_node*>(usr))
has_phi_user = true;
if(!has_phi_user){
size_t offset = alloc_->offset(v); size_t offset = alloc_->offset(v);
Value *ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(offset)); Value *ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(offset));
ptr = builder.CreateBitCast(ptr, ptr_ty); ptr = builder.CreateBitCast(ptr, ptr_ty);
tmap_.insert({v, new shared_tile(ty, shapes, ptr, builder)}); tmap_.insert({v, new shared_tile(ty, shapes, ptr, builder)});
} }
}
} }
void selection::create_distributed_tile(ir::value *v, IRBuilder<> &builder) { void selection::create_distributed_tile(ir::value *v, IRBuilder<> &builder) {
@@ -827,8 +818,9 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
if(auto *user = dynamic_cast<ir::user*>(v)) if(auto *user = dynamic_cast<ir::user*>(v))
for(ir::value *op: user->ops()) for(ir::value *op: user->ops())
create_tile(op, builder, seen, sh_mem_ptr); create_tile(op, builder, seen, sh_mem_ptr);
if(buffer_info_->is_shared(v) && !dynamic_cast<ir::reduce_inst*>(v)) auto *i = dynamic_cast<ir::instruction*>(v);
create_shared_tile(v, builder, sh_mem_ptr); if(i && storage_info.at(i->get_id()).first == SHARED && !dynamic_cast<ir::reduce_inst*>(v))
create_shared_tile(i, builder, sh_mem_ptr);
else else
create_distributed_tile(v, builder); create_distributed_tile(v, builder);
} }
@@ -1427,7 +1419,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
lower_masked_load(x, ctx, fn, builder); lower_masked_load(x, ctx, fn, builder);
else if(auto *x = dynamic_cast<ir::load_inst*>(ins)) else if(auto *x = dynamic_cast<ir::load_inst*>(ins))
lower_load(x, ctx, fn, builder); lower_load(x, ctx, fn, builder);
else if(!buffer_info_->is_shared(ins)) else if(!dynamic_cast<shared_tile*>(tmap_.at(ins)))
lower_elementwise(ins, ctx, fn, builder); lower_elementwise(ins, ctx, fn, builder);
} }
@@ -1556,21 +1548,19 @@ void selection::run(ir::module &src, Module &dst) {
} }
} }
// add phi operands
for(ir::basic_block *block: fn->blocks()) for(ir::basic_block *block: fn->blocks())
for(ir::instruction *inst: block->get_inst_list()) for(ir::instruction *inst: block->get_inst_list()) {
if(auto *phi = dynamic_cast<ir::phi_node*>(inst)){ if(liveness_->has_double(inst)) {
if(buffer_info_->is_double(phi)) { auto info = liveness_->get_double(inst);
ir::phi_node *phi = info.phi;
PHINode *ptr = (PHINode*)((shared_tile*)tmap_.at(phi))->get_pointer(); PHINode *ptr = (PHINode*)((shared_tile*)tmap_.at(phi))->get_pointer();
PHINode *offset = (PHINode*)((shared_tile*)tmap_.at(phi))->get_offset(); PHINode *offset = (PHINode*)((shared_tile*)tmap_.at(phi))->get_offset();
for(unsigned n = 0; n < phi->get_num_incoming(); n++){ for(unsigned n = 0; n < phi->get_num_incoming(); n++){
ir::basic_block* inc_block = phi->get_incoming_block(n); ir::basic_block* inc_block = phi->get_incoming_block(n);
ir::value* inc_val = phi->get_incoming_value(n); ir::value* inc_val = phi->get_incoming_value(n);
ir::instruction* terminator = inc_block->get_inst_list().back();
BasicBlock *llvm_inc_block = last_block.at(inc_block); BasicBlock *llvm_inc_block = last_block.at(inc_block);
shared_tile *inc_shared = (shared_tile*)tmap_.at(inc_val); shared_tile *inc_shared = (shared_tile*)tmap_.at(inc_val);
bool is_loop_latch = buffer_info_->is_loop_latch(phi, terminator); if(inc_val == info.latch){
if(is_loop_latch){
dst_builder.SetInsertPoint(llvm_inc_block->getTerminator()); dst_builder.SetInsertPoint(llvm_inc_block->getTerminator());
Value *next_offset = dst_builder.CreateNeg(offset); Value *next_offset = dst_builder.CreateNeg(offset);
offset->addIncoming(next_offset, llvm_inc_block); offset->addIncoming(next_offset, llvm_inc_block);
@@ -1582,7 +1572,14 @@ void selection::run(ir::module &src, Module &dst) {
ptr->addIncoming(inc_shared->get_pointer(), llvm_inc_block); ptr->addIncoming(inc_shared->get_pointer(), llvm_inc_block);
} }
} }
else { }
// add phi operands
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *inst: block->get_inst_list())
if(auto *phi = dynamic_cast<ir::phi_node*>(inst)){
if(tmap_.find(phi) == tmap_.end() ||
!dynamic_cast<shared_tile*>(tmap_.at(phi))) {
for(unsigned n = 0; n < phi->get_num_incoming(); n++){ for(unsigned n = 0; n < phi->get_num_incoming(); n++){
ir::value *inc_val = phi->get_incoming_value(n); ir::value *inc_val = phi->get_incoming_value(n);
ir::basic_block *inc_block = phi->get_incoming_block(n); ir::basic_block *inc_block = phi->get_incoming_block(n);

View File

@@ -15,8 +15,8 @@ namespace triton {
namespace codegen{ namespace codegen{
namespace transform{ namespace transform{
coalesce::coalesce(analysis::align* align, analysis::layout *layouts, analysis::cts *mem) coalesce::coalesce(analysis::align* align, analysis::layout *layouts)
: align_(align), layout_(layouts), mem_(mem) { } : align_(align), layout_(layouts) { }
// Find all values that are used as pointer operands in LD/ST // Find all values that are used as pointer operands in LD/ST
void coalesce::extract_io_use(ir::value *v, std::set<ir::io_inst*>& result) { void coalesce::extract_io_use(ir::value *v, std::set<ir::io_inst*>& result) {

View File

@@ -14,38 +14,6 @@ namespace codegen{
namespace analysis{ namespace analysis{
// run pass on module // run pass on module
// Returns true when `terminator` is a conditional branch located in the same
// basic block as `phi` and at least one of its destinations is that block.
// Returns false for a terminator in another block or an unconditional
// branch; throws std::runtime_error("unreachable") for any other kind.
bool cts::is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
  auto *phi_block = phi->get_parent();
  if(phi_block != terminator->get_parent())
    return false;
  auto *cond = dynamic_cast<ir::cond_branch_inst*>(terminator);
  if(cond){
    const bool via_true = cond->get_true_dest() == phi_block;
    return via_true || cond->get_false_dest() == phi_block;
  }
  if(dynamic_cast<ir::uncond_branch_inst*>(terminator))
    return false;
  throw std::runtime_error("unreachable");
}
// Classifies a value as "shared":
//   - atomic_cas, trans, copy_to_shared and reduce instructions are shared
//   - a phi node is shared iff every one of its incoming values is shared
//     (checked recursively; a phi with no incoming value yields true)
//   - anything else is not shared
inline bool get_is_shared(ir::value* v) {
  const bool direct = dynamic_cast<ir::atomic_cas_inst*>(v)
                   || dynamic_cast<ir::trans_inst*>(v)
                   || dynamic_cast<ir::copy_to_shared_inst*>(v)
                   || dynamic_cast<ir::reduce_inst*>(v);
  if(direct)
    return true;
  auto *phi = dynamic_cast<ir::phi_node*>(v);
  if(phi == nullptr)
    return false;
  // stop at the first incoming value that is not shared
  for(unsigned n = 0; n < phi->get_num_incoming(); n++)
    if(!get_is_shared(phi->get_incoming_value(n)))
      return false;
  return true;
}
void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder) { void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder) {
auto *i = dynamic_cast<ir::instruction*>(x); auto *i = dynamic_cast<ir::instruction*>(x);
// not an instruction // not an instruction
@@ -72,10 +40,6 @@ void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder) {
} }
void cts::run(ir::module &mod) { void cts::run(ir::module &mod) {
shared_.clear();
refs_.clear();
double_.clear();
// Add shared copies // Add shared copies
ir::builder &builder = mod.get_builder(); ir::builder &builder = mod.get_builder();
for(ir::function *fn: mod.get_function_list()){ for(ir::function *fn: mod.get_function_list()){
@@ -88,55 +52,8 @@ void cts::run(ir::module &mod) {
add_copy(i, i->get_operand(k), builder); add_copy(i, i->get_operand(k), builder);
} }
} }
// Find which buffers are shared
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i: block->get_inst_list())
if(get_is_shared(i))
shared_.insert(i);
// double-buffering
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i: block->get_inst_list()) {
if(!i->get_type()->is_tile_ty())
continue;
// handle phi
if(auto *phi = dynamic_cast<ir::phi_node*>(i))
if(is_shared(phi)){
// determine if the value is in shared memory
bool is_double = false;
for(unsigned n = 0; n < phi->get_num_incoming(); n++){
ir::basic_block *inc_block = phi->get_incoming_block(n);
ir::instruction *terminator = inc_block->get_inst_list().back();
is_double = is_double || is_loop_latch(phi, terminator);
}
// add to double-buffered
if(is_double)
double_.insert(phi);
// set references of input
for(unsigned n = 0; n < phi->get_num_incoming(); n++){
ir::value *inc_val = phi->get_incoming_value(n);
refs_[inc_val] = phi;
}
}
}
} }
// true iff `x` was recorded in the double_ set
bool cts::is_double(ir::value *x) {
  return double_.count(x) != 0;
}
// true iff `x` was recorded in the shared_ set
bool cts::is_shared(ir::value *x) {
  return shared_.count(x) != 0;
}
// Returns the value registered in refs_ for `x`, or nullptr when `x` has no
// entry. Uses find() instead of operator[] so that a query never inserts a
// null placeholder into refs_.
ir::value *cts::get_reference(ir::value *x) {
  auto it = refs_.find(x);
  return it == refs_.end() ? nullptr : it->second;
}
} }
} }

View File

@@ -2,8 +2,10 @@
#include <set> #include <set>
#include <algorithm> #include <algorithm>
#include "triton/codegen/transform/membar.h" #include "triton/codegen/analysis/liveness.h"
#include "triton/codegen/analysis/allocation.h" #include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/instructions.h"
#include "triton/codegen/transform/membar.h"
#include "triton/codegen/transform/cts.h" #include "triton/codegen/transform/cts.h"
#include "triton/ir/module.h" #include "triton/ir/module.h"
#include "triton/ir/function.h" #include "triton/ir/function.h"
@@ -31,7 +33,10 @@ bool membar::intersect(const interval_vec_t &X, const interval_vec_t &Y) {
} }
void membar::add_reference(ir::value *v, interval_vec_t &res){ void membar::add_reference(ir::value *v, interval_vec_t &res){
if(buffer_info_->is_shared(v) && !dynamic_cast<ir::phi_node*>(v)){ auto *i = dynamic_cast<ir::instruction*>(v);
if(!i)
return;
if(storage_info.at(i->get_id()).first == SHARED){
unsigned offset = alloc_->offset(v); unsigned offset = alloc_->offset(v);
unsigned num_bytes = alloc_->num_bytes(v); unsigned num_bytes = alloc_->num_bytes(v);
res.push_back(interval_t(offset, offset + num_bytes)); res.push_back(interval_t(offset, offset + num_bytes));
@@ -79,10 +84,12 @@ std::pair<membar::interval_vec_t,
membar::interval_vec_t> membar::transfer(ir::basic_block *block, membar::interval_vec_t> membar::transfer(ir::basic_block *block,
const interval_vec_t &written_to, const interval_vec_t &written_to,
const interval_vec_t &read_from, const interval_vec_t &read_from,
std::set<ir::instruction*>& insert_loc) { std::set<ir::instruction*>& insert_loc,
std::set<ir::value*>& safe_war) {
ir::basic_block::inst_list_t instructions = block->get_inst_list(); ir::basic_block::inst_list_t instructions = block->get_inst_list();
interval_vec_t new_written_to = written_to; interval_vec_t new_written_to = written_to;
interval_vec_t new_read_from = read_from; interval_vec_t new_read_from = read_from;
for(ir::instruction *i: instructions){ for(ir::instruction *i: instructions){
interval_vec_t read, written; interval_vec_t read, written;
get_read_intervals(i, read); get_read_intervals(i, read);
@@ -90,9 +97,9 @@ std::pair<membar::interval_vec_t,
bool read_after_write = intersect(new_written_to, read); bool read_after_write = intersect(new_written_to, read);
bool write_after_read = intersect(new_read_from, written); bool write_after_read = intersect(new_read_from, written);
// double buffering: write and phi-node read won't intersect // double buffering: write and phi-node read won't intersect
if(buffer_info_->is_shared(i) && if(safe_war.find(i) != safe_war.end())
buffer_info_->is_double(buffer_info_->get_reference(i)))
write_after_read = false; write_after_read = false;
// record hazards
if(read_after_write || write_after_read) { if(read_after_write || write_after_read) {
insert_loc.insert(i); insert_loc.insert(i);
new_written_to.clear(); new_written_to.clear();
@@ -106,6 +113,18 @@ std::pair<membar::interval_vec_t,
void membar::run(ir::module &mod) { void membar::run(ir::module &mod) {
ir::builder &builder = mod.get_builder(); ir::builder &builder = mod.get_builder();
// extract phi-node associates with double-buffered
// shared-memory copies. These can be read from and written to
// without needing synchronization
std::set<ir::value*> safe_war;
ir::for_each_instruction(mod, [&](ir::instruction* i){
if(liveness_->has_double(i)){
auto info = liveness_->get_double(i);
safe_war.insert(i);
safe_war.insert(info.latch);
}
});
for(ir::function *fn: mod.get_function_list()){ for(ir::function *fn: mod.get_function_list()){
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn); std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
std::map<ir::basic_block*, interval_vec_t> written_to; std::map<ir::basic_block*, interval_vec_t> written_to;
@@ -125,7 +144,7 @@ void membar::run(ir::module &mod) {
for(ir::basic_block* pred: block->get_predecessors()) for(ir::basic_block* pred: block->get_predecessors())
pred_read_from.push_back(read_from[pred]); pred_read_from.push_back(read_from[pred]);
// apply transfer function // apply transfer function
auto result = transfer(block, join(pred_written_to), join(pred_read_from), insert_locs); auto result = transfer(block, join(pred_written_to), join(pred_read_from), insert_locs, safe_war);
written_to[block] = result.first; written_to[block] = result.first;
read_from[block] = result.second; read_from[block] = result.second;
} }

View File

@@ -241,7 +241,7 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { } cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { }
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){ cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
// std::cout << source << std::endl; std::cout << source << std::endl;
cu_context::context_switcher ctx_switch(*context); cu_context::context_switcher ctx_switch(*context);
// JIT compile source-code // JIT compile source-code
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER}; CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};

View File

@@ -199,24 +199,24 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
llvm::LLVMContext ctx; llvm::LLVMContext ctx;
std::unique_ptr<llvm::Module> llvm(new llvm::Module(module.get_name(), ctx)); std::unique_ptr<llvm::Module> llvm(new llvm::Module(module.get_name(), ctx));
// create passes // create passes
codegen::analysis::cts shmem_info; codegen::analysis::cts cts;
codegen::analysis::align align; codegen::analysis::align align;
codegen::analysis::liveness shmem_liveness(&shmem_info); codegen::analysis::liveness shmem_liveness;
codegen::analysis::axes axes; codegen::analysis::axes axes;
codegen::analysis::layout layouts(&axes); codegen::analysis::layout layouts(&axes);
codegen::transform::coalesce coalesce(&align, &layouts, &shmem_info); codegen::transform::coalesce coalesce(&align, &layouts);
codegen::analysis::tiles tiles(opt.num_warps, &align, &axes, &layouts); codegen::analysis::tiles tiles(opt.num_warps, &align, &axes, &layouts);
codegen::analysis::allocation shmem_allocation(&shmem_liveness, &shmem_info, &tiles); codegen::analysis::allocation shmem_allocation(&shmem_liveness, &tiles);
codegen::transform::membar shmem_barriers(&shmem_allocation, &shmem_info); codegen::transform::membar shmem_barriers(&shmem_liveness, &shmem_allocation);
codegen::transform::dce dce; codegen::transform::dce dce;
codegen::transform::peephole peephole; codegen::transform::peephole peephole;
codegen::transform::reassociate reassociate(&align); codegen::transform::reassociate reassociate(&align);
codegen::selection selection(&shmem_allocation, &tiles, &shmem_info, &align, &axes, &layouts, &coalesce, target.get(), opt.num_warps); codegen::selection selection(&shmem_liveness, &shmem_allocation, &tiles, &align, &axes, &layouts, &coalesce, target.get(), opt.num_warps);
// run passes // run passes
peephole.run(module); peephole.run(module);
dce.run(module); dce.run(module);
align.run(module); align.run(module);
shmem_info.run(module); cts.run(module);
axes.run(module); axes.run(module);
layouts.run(module); layouts.run(module);
coalesce.run(module); coalesce.run(module);
@@ -227,10 +227,10 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
reassociate.run(module); reassociate.run(module);
dce.run(module); dce.run(module);
peephole.run(module); peephole.run(module);
shmem_info.run(module); dce.run(module);
cts.run(module);
shmem_liveness.run(module); shmem_liveness.run(module);
ir::print(module, std::cout); shmem_allocation.run(module);
shmem_allocation.run();
if(shmem_allocation.allocated_size() > context->device()->max_shared_memory()) if(shmem_allocation.allocated_size() > context->device()->max_shared_memory())
return std::unique_ptr<driver::module>(); return std::unique_ptr<driver::module>();
shmem_barriers.run(module); shmem_barriers.run(module);

View File

@@ -45,10 +45,10 @@ std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, i
opt.defines.push_back({"TYPE", {ty}}); opt.defines.push_back({"TYPE", {ty}});
opt.defines.push_back({"AT", {AT?"1":"0"}}); opt.defines.push_back({"AT", {AT?"1":"0"}});
opt.defines.push_back({"BT", {BT?"1":"0"}}); opt.defines.push_back({"BT", {BT?"1":"0"}});
opt.defines.push_back({"TM", {"64", "128"}}); opt.defines.push_back({"TM", {"128"}});
opt.defines.push_back({"TN", {"64", "128"}}); opt.defines.push_back({"TN", {"128"}});
opt.defines.push_back({"TK", {"8"}}); opt.defines.push_back({"TK", {"8"}});
opt.num_warps = {2, 4, 8}; opt.num_warps = {8};
// create function // create function
rt::function function(src::dot, opt); rt::function function(src::dot, opt);
// benchmark available libraries // benchmark available libraries