[codegen] cleaned up shared memory and double-buffering logic
This commit is contained in:
@@ -10,6 +10,7 @@ namespace triton{
|
||||
namespace ir{
|
||||
class value;
|
||||
class function;
|
||||
class module;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
@@ -22,8 +23,8 @@ class cts;
|
||||
|
||||
class allocation {
|
||||
public:
|
||||
allocation(liveness *live, cts *buffer_info, tiles *params)
|
||||
: liveness_(live), buffer_info_(buffer_info), tiles_(params){ }
|
||||
allocation(liveness *live, tiles *params)
|
||||
: liveness_(live), tiles_(params){ }
|
||||
// utilities
|
||||
unsigned num_bytes(ir::value *x);
|
||||
unsigned is_ld_padded(ir::value* x);
|
||||
@@ -31,7 +32,7 @@ public:
|
||||
unsigned offset(ir::value *x) const { return offsets_.at(x); }
|
||||
unsigned allocated_size() const { return allocated_size_; }
|
||||
// run
|
||||
void run();
|
||||
void run(ir::module& mod);
|
||||
|
||||
private:
|
||||
std::map<ir::value*, unsigned> offsets_;
|
||||
@@ -39,7 +40,6 @@ private:
|
||||
size_t allocated_size_;
|
||||
// dependences
|
||||
liveness *liveness_;
|
||||
cts *buffer_info_;
|
||||
tiles *tiles_;
|
||||
};
|
||||
|
||||
|
@@ -7,6 +7,7 @@ namespace triton{
|
||||
|
||||
namespace ir{
|
||||
class value;
|
||||
class phi_node;
|
||||
class function;
|
||||
class module;
|
||||
}
|
||||
@@ -31,6 +32,11 @@ struct segment {
|
||||
}
|
||||
};
|
||||
|
||||
struct double_buffer_info_t {
|
||||
ir::value* latch;
|
||||
ir::phi_node* phi;
|
||||
};
|
||||
|
||||
class liveness {
|
||||
private:
|
||||
typedef std::map<ir::value*, slot_index> indices_map_t;
|
||||
@@ -43,19 +49,20 @@ public:
|
||||
using const_iterator = intervals_map_t::const_iterator;
|
||||
|
||||
public:
|
||||
// constructor
|
||||
liveness(cts *info): info_(info){ }
|
||||
// accessors
|
||||
const intervals_map_t& intervals() const { return intervals_; }
|
||||
segment get_interval(ir::value* v) const { return intervals_.at(v); }
|
||||
// double-buffering
|
||||
bool has_double(ir::value *x) const { return double_.find(x) != double_.end(); }
|
||||
double_buffer_info_t get_double(ir::value *x) const { return double_.at(x); }
|
||||
// run
|
||||
void run(ir::module &mod);
|
||||
|
||||
private:
|
||||
cts *info_;
|
||||
has_storage_map_t has_dedicated_storage_;
|
||||
indices_map_t indices_;
|
||||
intervals_map_t intervals_;
|
||||
std::map<ir::value*, double_buffer_info_t> double_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -43,6 +43,7 @@ namespace triton{
|
||||
namespace codegen{
|
||||
|
||||
namespace analysis{
|
||||
class liveness;
|
||||
class tiles;
|
||||
class align;
|
||||
class allocation;
|
||||
@@ -201,10 +202,10 @@ private:
|
||||
|
||||
|
||||
public:
|
||||
selection(analysis::allocation *alloc, analysis::tiles *tiles, analysis::cts *buffer_info,
|
||||
analysis::align *alignment, analysis::axes *axes, analysis::layout *layouts,
|
||||
transform::coalesce* reorder, target *tgt, unsigned num_warps)
|
||||
: alloc_(alloc), tiles_(tiles), buffer_info_(buffer_info),
|
||||
selection(analysis::liveness* liveness, analysis::allocation *alloc, analysis::tiles *tiles,
|
||||
analysis::align *alignment, analysis::axes *axes,
|
||||
analysis::layout *layouts, transform::coalesce* reorder, target *tgt, unsigned num_warps)
|
||||
: liveness_(liveness), alloc_(alloc), tiles_(tiles),
|
||||
alignment_(alignment), a_axes_(axes), layouts_(layouts),
|
||||
reorder_(reorder), tgt_(tgt), num_warps_(num_warps){ }
|
||||
|
||||
@@ -213,11 +214,11 @@ public:
|
||||
private:
|
||||
vmap_t vmap_;
|
||||
tmap_t tmap_;
|
||||
analysis::liveness *liveness_;
|
||||
analysis::allocation *alloc_;
|
||||
analysis::tiles *tiles_;
|
||||
analysis::axes *a_axes_;
|
||||
analysis::layout *layouts_;
|
||||
analysis::cts *buffer_info_;
|
||||
analysis::align *alignment_;
|
||||
transform::coalesce *reorder_;
|
||||
target *tgt_;
|
||||
|
@@ -32,13 +32,12 @@ private:
|
||||
ir::value* rematerialize(ir::value *v, ir::builder& builder, std::map<ir::value*, ir::value*>& seen);
|
||||
|
||||
public:
|
||||
coalesce(analysis::align* align, triton::codegen::analysis::layout *layouts, analysis::cts* mem);
|
||||
coalesce(analysis::align* align, triton::codegen::analysis::layout *layouts);
|
||||
void run(ir::module &mod);
|
||||
|
||||
private:
|
||||
analysis::align* align_;
|
||||
analysis::layout* layout_;
|
||||
analysis::cts* mem_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -19,17 +19,6 @@ namespace analysis{
|
||||
class cts {
|
||||
public:
|
||||
void run(ir::module &mod);
|
||||
// queries
|
||||
bool is_double(ir::value *x);
|
||||
void add_shared(ir::value *v);
|
||||
bool is_shared(ir::value *x);
|
||||
bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator);
|
||||
ir::value *get_reference(ir::value *x);
|
||||
|
||||
private:
|
||||
std::set<ir::value*> shared_;
|
||||
std::set<ir::value*> double_;
|
||||
std::map<ir::value*, ir::value*> refs_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -16,6 +16,7 @@ namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
class allocation;
|
||||
class liveness;
|
||||
class cts;
|
||||
|
||||
}
|
||||
@@ -35,15 +36,17 @@ private:
|
||||
void add_reference(ir::value *v, interval_vec_t &res);
|
||||
void get_read_intervals(ir::instruction *i, interval_vec_t &res);
|
||||
void get_written_intervals(ir::instruction *i, interval_vec_t &res);
|
||||
std::pair<interval_vec_t, interval_vec_t> transfer(ir::basic_block *block, const interval_vec_t &written_to, const interval_vec_t &read_from, std::set<ir::instruction *> &insert_loc);
|
||||
std::pair<interval_vec_t, interval_vec_t> transfer(ir::basic_block *block, const interval_vec_t &written_to, const interval_vec_t &read_from,
|
||||
std::set<ir::instruction *> &insert_loc, std::set<triton::ir::value *> &safe_war);
|
||||
|
||||
public:
|
||||
membar(analysis::allocation *alloc, analysis::cts *buffer_info): alloc_(alloc), buffer_info_(buffer_info) {}
|
||||
membar(analysis::liveness *liveness, analysis::allocation *alloc):
|
||||
liveness_(liveness), alloc_(alloc) {}
|
||||
void run(ir::module &mod);
|
||||
|
||||
private:
|
||||
analysis::liveness *liveness_;
|
||||
analysis::allocation *alloc_;
|
||||
analysis::cts *buffer_info_;
|
||||
};
|
||||
|
||||
|
||||
|
@@ -1,4 +1,5 @@
|
||||
#include <algorithm>
|
||||
#include <climits>
|
||||
#include "triton/codegen/analysis/allocation.h"
|
||||
#include "triton/codegen/analysis/liveness.h"
|
||||
#include "triton/codegen/transform/cts.h"
|
||||
@@ -8,6 +9,7 @@
|
||||
#include "triton/ir/value.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/utils.h"
|
||||
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
@@ -68,13 +70,12 @@ unsigned allocation::num_bytes(ir::value *x) {
|
||||
unsigned ld = x->get_type()->get_tile_shapes()[0];
|
||||
num_bytes += pad * num_bytes / ld;
|
||||
}
|
||||
if(buffer_info_->is_double(x))
|
||||
if(liveness_->has_double(x))
|
||||
num_bytes *= 2;
|
||||
return num_bytes;
|
||||
}
|
||||
|
||||
|
||||
void allocation::run() {
|
||||
void allocation::run(ir::module &mod) {
|
||||
using std::max;
|
||||
using std::min;
|
||||
typedef std::multimap<unsigned, segment> triples_map_type;
|
||||
@@ -85,7 +86,7 @@ void allocation::run() {
|
||||
std::vector<ir::value *> J = I;
|
||||
|
||||
triples_map_type H;
|
||||
H.insert({0, segment{0, 1024}});
|
||||
H.insert({0, segment{0, INT_MAX}});
|
||||
|
||||
std::vector<ir::value *> V;
|
||||
std::map<ir::value *, unsigned> starts;
|
||||
@@ -115,7 +116,6 @@ void allocation::run() {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Build interference graph
|
||||
std::map<ir::value*, std::set<ir::value *>> interferences;
|
||||
for(ir::value *x: V)
|
||||
@@ -137,6 +137,7 @@ void allocation::run() {
|
||||
for(ir::value *X: V)
|
||||
colors[X] = (X==V[0])?0:-1;
|
||||
|
||||
|
||||
// First-fit graph coloring
|
||||
std::vector<bool> available(V.size());
|
||||
for(ir::value *x: V){
|
||||
@@ -158,17 +159,11 @@ void allocation::run() {
|
||||
for(ir::value *y: interferences[x])
|
||||
Adj = std::max(Adj, starts[y] + num_bytes(y));
|
||||
offsets_[x] = starts[x] + colors[x] * Adj;
|
||||
// std::cout << x->get_name() << " " << offsets_[x] << " " << num_bytes(x) << std::endl;
|
||||
if(buffer_info_->is_double(x)){
|
||||
ir::phi_node *phi = (ir::phi_node*)x;
|
||||
for(unsigned i = 0; i < phi->get_num_incoming(); i++){
|
||||
ir::value *inc_val = phi->get_incoming_value(i);
|
||||
offsets_[inc_val] = offsets_[phi];
|
||||
if(liveness_->has_double(x)){
|
||||
auto info = liveness_->get_double(x);
|
||||
offsets_[info.latch] = offsets_[x] + num_bytes(x) / 2;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// exit(EXIT_FAILURE);
|
||||
|
||||
// Save maximum size of induced memory space
|
||||
allocated_size_ = 0;
|
||||
|
@@ -7,13 +7,59 @@
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/value.h"
|
||||
#include "triton/ir/utils.h"
|
||||
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
// Tells whether `terminator` is a latch for the block that owns `phi`.
//
// Returns false when the terminator sits in a different block than the phi,
// or when it is an unconditional branch. For a conditional branch, returns
// true iff one of its two destinations is the phi's own block.
// Throws std::runtime_error("unreachable") for any other terminator kind
// located in the phi's block.
inline bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
  auto *header = phi->get_parent();
  // only a terminator of the phi's own block can loop back onto it
  if(terminator->get_parent() != header)
    return false;
  auto *cond = dynamic_cast<ir::cond_branch_inst*>(terminator);
  if(cond){
    const bool back_on_true = cond->get_true_dest() == header;
    return back_on_true || cond->get_false_dest() == header;
  }
  if(dynamic_cast<ir::uncond_branch_inst*>(terminator))
    return false;
  throw std::runtime_error("unreachable");
}
|
||||
|
||||
// Records a double-buffering candidate for instruction `i` in `result`.
//
// Only phi nodes with exactly two incoming edges are considered, and both
// incoming values must be instructions whose storage_info entry is SHARED.
// For each incoming edge whose predecessor is a loop latch (see
// is_loop_latch), the value arriving on the OTHER edge becomes a key of
// `result`, mapped to {latch-side value, phi}. Existing entries under the
// same key are overwritten.
inline void extract_double_bufferable(ir::instruction *i, std::map<ir::value*, double_buffer_info_t>& result) {
  auto* phi = dynamic_cast<ir::phi_node*>(i);
  if(!phi)
    return;
  if(phi->get_num_incoming() != 2)
    return;
  // classify both predecessors through their last instruction
  ir::basic_block *pred_a = phi->get_incoming_block(0);
  ir::basic_block *pred_b = phi->get_incoming_block(1);
  ir::instruction *term_a = pred_a->get_inst_list().back();
  ir::instruction *term_b = pred_b->get_inst_list().back();
  const bool a_is_latch = is_loop_latch(phi, term_a);
  const bool b_is_latch = is_loop_latch(phi, term_b);
  // both operands must be instructions stored as SHARED
  ir::value *val_a = phi->get_incoming_value(0);
  ir::value *val_b = phi->get_incoming_value(1);
  auto *def_a = dynamic_cast<ir::instruction*>(val_a);
  auto *def_b = dynamic_cast<ir::instruction*>(val_b);
  if(!def_a || !def_b)
    return;
  if(storage_info.at(def_a->get_id()).first != SHARED)
    return;
  if(storage_info.at(def_b->get_id()).first != SHARED)
    return;
  // key = operand from the non-latch edge, payload = operand from the latch
  if(b_is_latch)
    result[val_a] = double_buffer_info_t{val_b, phi};
  if(a_is_latch)
    result[val_b] = double_buffer_info_t{val_a, phi};
}
|
||||
|
||||
|
||||
// Entry point
|
||||
void liveness::run(ir::module &mod) {
|
||||
double_.clear();
|
||||
indices_.clear();
|
||||
intervals_.clear();
|
||||
|
||||
// set of pair of values that can be double-buffered
|
||||
ir::for_each_instruction(mod, [this](ir::instruction* i) {
|
||||
extract_double_bufferable(i, this->double_);
|
||||
});
|
||||
|
||||
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
// Assigns index to each instruction
|
||||
slot_index index = 0;
|
||||
@@ -26,12 +72,10 @@ void liveness::run(ir::module &mod) {
|
||||
// Creates live intervals
|
||||
for(auto i: indices_){
|
||||
ir::value *v = i.first;
|
||||
// ir::instruction* instr = dynamic_cast<ir::instruction*>(v);
|
||||
// if(!instr)
|
||||
// continue;
|
||||
// if(storage_info.at(instr->get_id()).first != SHARED)
|
||||
// continue;
|
||||
if(!info_->is_shared(v) || info_->get_reference(v))
|
||||
ir::instruction* instr = dynamic_cast<ir::instruction*>(v);
|
||||
if(!instr)
|
||||
continue;
|
||||
if(storage_info.at(instr->get_id()).first != SHARED)
|
||||
continue;
|
||||
unsigned start = i.second;
|
||||
unsigned end = start;
|
||||
@@ -41,6 +85,21 @@ void liveness::run(ir::module &mod) {
|
||||
}
|
||||
intervals_[v] = segment{start, end};
|
||||
}
|
||||
// Double-Buffering
|
||||
// Arrays are live throughout the end of the loop
|
||||
auto it = intervals_.begin();
|
||||
while(it != intervals_.end()) {
|
||||
ir::value *x = it->first;
|
||||
auto dit = double_.find(x);
|
||||
if(dit != double_.end()) {
|
||||
ir::value *y = dit->second.latch;
|
||||
unsigned start = intervals_[x].start;
|
||||
unsigned end = intervals_[y].end;
|
||||
intervals_[x] = segment{start, end};
|
||||
intervals_.erase(y);
|
||||
}
|
||||
it++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -1,12 +1,14 @@
|
||||
#include <numeric>
|
||||
#include "triton/codegen/selection.h"
|
||||
#include "triton/codegen/target.h"
|
||||
#include "triton/codegen/analysis/liveness.h"
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
#include "triton/codegen/analysis/axes.h"
|
||||
#include "triton/codegen/analysis/tiles.h"
|
||||
#include "triton/codegen/analysis/allocation.h"
|
||||
#include "triton/codegen/analysis/align.h"
|
||||
#include "triton/codegen/transform/coalesce.h"
|
||||
#include "triton/codegen/instructions.h"
|
||||
#include "triton/ir/context.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
@@ -746,43 +748,32 @@ void selection::create_shared_tile(ir::value *v, IRBuilder<> &builder, Value *sh
|
||||
Type* ty = llvm_type(v->get_type()->get_scalar_ty(), builder.getContext());
|
||||
// shared copy
|
||||
PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr->getType()->getPointerAddressSpace());
|
||||
// phi-node (double-buffering)
|
||||
if(auto *phi = dynamic_cast<ir::phi_node*>(v)) {
|
||||
// double-buffered
|
||||
if(liveness_->has_double(v)) {
|
||||
auto info = liveness_->get_double(v);
|
||||
ir::phi_node *phi = info.phi;
|
||||
BasicBlock *parent = (BasicBlock*)vmap_[phi->get_parent()];
|
||||
unsigned id_pre = 0, id_loop = 1;
|
||||
if(phi->get_incoming_block(0) == phi->get_parent())
|
||||
std::swap(id_pre, id_loop);
|
||||
if(parent->empty())
|
||||
builder.SetInsertPoint(parent);
|
||||
else
|
||||
builder.SetInsertPoint(&*parent->getFirstInsertionPt());
|
||||
// create double-buffered pointer
|
||||
PHINode *ptr = builder.CreatePHI(ptr_ty, 2);
|
||||
PHINode *offset = builder.CreatePHI(builder.getInt32Ty(), 2);
|
||||
// next pointer
|
||||
Value *pre_ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(alloc_->offset(phi)));
|
||||
Value *pre_ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(alloc_->offset(v)));
|
||||
pre_ptr = builder.CreateBitCast(pre_ptr, ptr->getType());
|
||||
Value *next_ptr = builder.CreateGEP(ptr, offset, "next_ptr");
|
||||
tmap_.insert({phi, new shared_tile(ty, shapes, ptr, builder, offset)});
|
||||
for(unsigned i = 0; i < phi->get_num_incoming(); i++) {
|
||||
ir::basic_block* inc_block = phi->get_incoming_block(i);
|
||||
ir::value* inc_value = phi->get_incoming_value(i);
|
||||
ir::instruction* terminator = inc_block->get_inst_list().back();
|
||||
bool is_loop_latch = buffer_info_->is_loop_latch(phi, terminator);
|
||||
tmap_.insert({inc_value, new shared_tile(ty, shapes, is_loop_latch?next_ptr:pre_ptr, builder)});
|
||||
}
|
||||
tmap_.insert({v, new shared_tile(ty, shapes, pre_ptr, builder)});
|
||||
tmap_.insert({info.latch, new shared_tile(ty, shapes, next_ptr, builder)});
|
||||
}
|
||||
else {
|
||||
bool has_phi_user = false;
|
||||
for(ir::user *usr: v->get_users())
|
||||
if(dynamic_cast<ir::phi_node*>(usr))
|
||||
has_phi_user = true;
|
||||
if(!has_phi_user){
|
||||
size_t offset = alloc_->offset(v);
|
||||
Value *ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(offset));
|
||||
ptr = builder.CreateBitCast(ptr, ptr_ty);
|
||||
tmap_.insert({v, new shared_tile(ty, shapes, ptr, builder)});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void selection::create_distributed_tile(ir::value *v, IRBuilder<> &builder) {
|
||||
@@ -827,8 +818,9 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
|
||||
if(auto *user = dynamic_cast<ir::user*>(v))
|
||||
for(ir::value *op: user->ops())
|
||||
create_tile(op, builder, seen, sh_mem_ptr);
|
||||
if(buffer_info_->is_shared(v) && !dynamic_cast<ir::reduce_inst*>(v))
|
||||
create_shared_tile(v, builder, sh_mem_ptr);
|
||||
auto *i = dynamic_cast<ir::instruction*>(v);
|
||||
if(i && storage_info.at(i->get_id()).first == SHARED && !dynamic_cast<ir::reduce_inst*>(v))
|
||||
create_shared_tile(i, builder, sh_mem_ptr);
|
||||
else
|
||||
create_distributed_tile(v, builder);
|
||||
}
|
||||
@@ -1427,7 +1419,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
lower_masked_load(x, ctx, fn, builder);
|
||||
else if(auto *x = dynamic_cast<ir::load_inst*>(ins))
|
||||
lower_load(x, ctx, fn, builder);
|
||||
else if(!buffer_info_->is_shared(ins))
|
||||
else if(!dynamic_cast<shared_tile*>(tmap_.at(ins)))
|
||||
lower_elementwise(ins, ctx, fn, builder);
|
||||
}
|
||||
|
||||
@@ -1556,21 +1548,19 @@ void selection::run(ir::module &src, Module &dst) {
|
||||
}
|
||||
}
|
||||
|
||||
// add phi operands
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *inst: block->get_inst_list())
|
||||
if(auto *phi = dynamic_cast<ir::phi_node*>(inst)){
|
||||
if(buffer_info_->is_double(phi)) {
|
||||
for(ir::instruction *inst: block->get_inst_list()) {
|
||||
if(liveness_->has_double(inst)) {
|
||||
auto info = liveness_->get_double(inst);
|
||||
ir::phi_node *phi = info.phi;
|
||||
PHINode *ptr = (PHINode*)((shared_tile*)tmap_.at(phi))->get_pointer();
|
||||
PHINode *offset = (PHINode*)((shared_tile*)tmap_.at(phi))->get_offset();
|
||||
for(unsigned n = 0; n < phi->get_num_incoming(); n++){
|
||||
ir::basic_block* inc_block = phi->get_incoming_block(n);
|
||||
ir::value* inc_val = phi->get_incoming_value(n);
|
||||
ir::instruction* terminator = inc_block->get_inst_list().back();
|
||||
BasicBlock *llvm_inc_block = last_block.at(inc_block);
|
||||
shared_tile *inc_shared = (shared_tile*)tmap_.at(inc_val);
|
||||
bool is_loop_latch = buffer_info_->is_loop_latch(phi, terminator);
|
||||
if(is_loop_latch){
|
||||
if(inc_val == info.latch){
|
||||
dst_builder.SetInsertPoint(llvm_inc_block->getTerminator());
|
||||
Value *next_offset = dst_builder.CreateNeg(offset);
|
||||
offset->addIncoming(next_offset, llvm_inc_block);
|
||||
@@ -1582,7 +1572,14 @@ void selection::run(ir::module &src, Module &dst) {
|
||||
ptr->addIncoming(inc_shared->get_pointer(), llvm_inc_block);
|
||||
}
|
||||
}
|
||||
else {
|
||||
}
|
||||
|
||||
// add phi operands
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *inst: block->get_inst_list())
|
||||
if(auto *phi = dynamic_cast<ir::phi_node*>(inst)){
|
||||
if(tmap_.find(phi) == tmap_.end() ||
|
||||
!dynamic_cast<shared_tile*>(tmap_.at(phi))) {
|
||||
for(unsigned n = 0; n < phi->get_num_incoming(); n++){
|
||||
ir::value *inc_val = phi->get_incoming_value(n);
|
||||
ir::basic_block *inc_block = phi->get_incoming_block(n);
|
||||
|
@@ -15,8 +15,8 @@ namespace triton {
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
coalesce::coalesce(analysis::align* align, analysis::layout *layouts, analysis::cts *mem)
|
||||
: align_(align), layout_(layouts), mem_(mem) { }
|
||||
coalesce::coalesce(analysis::align* align, analysis::layout *layouts)
|
||||
: align_(align), layout_(layouts) { }
|
||||
|
||||
// Find all values that are used as pointer operands in LD/ST
|
||||
void coalesce::extract_io_use(ir::value *v, std::set<ir::io_inst*>& result) {
|
||||
|
@@ -14,38 +14,6 @@ namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
// run pass on module
|
||||
// Returns true when `terminator` is a conditional branch located in the
// block of `phi` and targeting that same block on either outcome.
// A terminator from another block, or an unconditional branch, yields false;
// any other terminator kind inside the phi's block raises
// std::runtime_error("unreachable").
bool cts::is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
  auto *phi_block = phi->get_parent();
  if(terminator->get_parent() != phi_block)
    return false;
  if(auto *branch = dynamic_cast<ir::cond_branch_inst*>(terminator)){
    if(branch->get_true_dest() == phi_block)
      return true;
    return branch->get_false_dest() == phi_block;
  }
  if(dynamic_cast<ir::uncond_branch_inst*>(terminator))
    return false;
  throw std::runtime_error("unreachable");
}
|
||||
|
||||
|
||||
|
||||
// Decides whether `v` is treated as shared.
//
// atomic_cas, trans, copy_to_shared and reduce instructions are always
// shared. A phi node is shared iff every one of its incoming values is
// shared (a phi without incoming values therefore counts as shared).
// Anything else is not shared.
//
// NOTE(review): the phi case recurses without tracking visited nodes, so a
// cycle made only of phi nodes would recurse without bound.
inline bool get_is_shared(ir::value* v) {
  const bool always_shared = dynamic_cast<ir::atomic_cas_inst*>(v)
                          || dynamic_cast<ir::trans_inst*>(v)
                          || dynamic_cast<ir::copy_to_shared_inst*>(v)
                          || dynamic_cast<ir::reduce_inst*>(v);
  if(always_shared)
    return true;
  auto *phi = dynamic_cast<ir::phi_node*>(v);
  if(!phi)
    return false;
  // stop at the first incoming value that is not shared
  for(unsigned n = 0; n < phi->get_num_incoming(); n++)
    if(!get_is_shared(phi->get_incoming_value(n)))
      return false;
  return true;
}
|
||||
|
||||
void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder) {
|
||||
auto *i = dynamic_cast<ir::instruction*>(x);
|
||||
// not an instruction
|
||||
@@ -72,10 +40,6 @@ void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder) {
|
||||
}
|
||||
|
||||
void cts::run(ir::module &mod) {
|
||||
shared_.clear();
|
||||
refs_.clear();
|
||||
double_.clear();
|
||||
|
||||
// Add shared copies
|
||||
ir::builder &builder = mod.get_builder();
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
@@ -88,55 +52,8 @@ void cts::run(ir::module &mod) {
|
||||
add_copy(i, i->get_operand(k), builder);
|
||||
}
|
||||
}
|
||||
|
||||
// Find which buffers are shared
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i: block->get_inst_list())
|
||||
if(get_is_shared(i))
|
||||
shared_.insert(i);
|
||||
|
||||
// double-buffering
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i: block->get_inst_list()) {
|
||||
if(!i->get_type()->is_tile_ty())
|
||||
continue;
|
||||
// handle phi
|
||||
if(auto *phi = dynamic_cast<ir::phi_node*>(i))
|
||||
if(is_shared(phi)){
|
||||
// determine if the value is in shared memory
|
||||
bool is_double = false;
|
||||
for(unsigned n = 0; n < phi->get_num_incoming(); n++){
|
||||
ir::basic_block *inc_block = phi->get_incoming_block(n);
|
||||
ir::instruction *terminator = inc_block->get_inst_list().back();
|
||||
is_double = is_double || is_loop_latch(phi, terminator);
|
||||
}
|
||||
// add to double-buffered
|
||||
if(is_double)
|
||||
double_.insert(phi);
|
||||
// set references of input
|
||||
for(unsigned n = 0; n < phi->get_num_incoming(); n++){
|
||||
ir::value *inc_val = phi->get_incoming_value(n);
|
||||
refs_[inc_val] = phi;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// query double-buffered status
|
||||
bool cts::is_double(ir::value *x)
|
||||
{ return double_.find(x) != double_.end(); }
|
||||
|
||||
// query shared status
|
||||
bool cts::is_shared(ir::value *x)
|
||||
{ return shared_.find(x) != shared_.end(); }
|
||||
|
||||
// get reference if any
|
||||
ir::value *cts::get_reference(ir::value *x)
|
||||
{ return refs_[x]; }
|
||||
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
@@ -2,8 +2,10 @@
|
||||
#include <set>
|
||||
#include <algorithm>
|
||||
|
||||
#include "triton/codegen/transform/membar.h"
|
||||
#include "triton/codegen/analysis/liveness.h"
|
||||
#include "triton/codegen/analysis/allocation.h"
|
||||
#include "triton/codegen/instructions.h"
|
||||
#include "triton/codegen/transform/membar.h"
|
||||
#include "triton/codegen/transform/cts.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
@@ -31,7 +33,10 @@ bool membar::intersect(const interval_vec_t &X, const interval_vec_t &Y) {
|
||||
}
|
||||
|
||||
void membar::add_reference(ir::value *v, interval_vec_t &res){
|
||||
if(buffer_info_->is_shared(v) && !dynamic_cast<ir::phi_node*>(v)){
|
||||
auto *i = dynamic_cast<ir::instruction*>(v);
|
||||
if(!i)
|
||||
return;
|
||||
if(storage_info.at(i->get_id()).first == SHARED){
|
||||
unsigned offset = alloc_->offset(v);
|
||||
unsigned num_bytes = alloc_->num_bytes(v);
|
||||
res.push_back(interval_t(offset, offset + num_bytes));
|
||||
@@ -79,10 +84,12 @@ std::pair<membar::interval_vec_t,
|
||||
membar::interval_vec_t> membar::transfer(ir::basic_block *block,
|
||||
const interval_vec_t &written_to,
|
||||
const interval_vec_t &read_from,
|
||||
std::set<ir::instruction*>& insert_loc) {
|
||||
std::set<ir::instruction*>& insert_loc,
|
||||
std::set<ir::value*>& safe_war) {
|
||||
ir::basic_block::inst_list_t instructions = block->get_inst_list();
|
||||
interval_vec_t new_written_to = written_to;
|
||||
interval_vec_t new_read_from = read_from;
|
||||
|
||||
for(ir::instruction *i: instructions){
|
||||
interval_vec_t read, written;
|
||||
get_read_intervals(i, read);
|
||||
@@ -90,9 +97,9 @@ std::pair<membar::interval_vec_t,
|
||||
bool read_after_write = intersect(new_written_to, read);
|
||||
bool write_after_read = intersect(new_read_from, written);
|
||||
// double buffering: write and phi-node read won't intersect
|
||||
if(buffer_info_->is_shared(i) &&
|
||||
buffer_info_->is_double(buffer_info_->get_reference(i)))
|
||||
if(safe_war.find(i) != safe_war.end())
|
||||
write_after_read = false;
|
||||
// record hazards
|
||||
if(read_after_write || write_after_read) {
|
||||
insert_loc.insert(i);
|
||||
new_written_to.clear();
|
||||
@@ -106,6 +113,18 @@ std::pair<membar::interval_vec_t,
|
||||
|
||||
void membar::run(ir::module &mod) {
|
||||
ir::builder &builder = mod.get_builder();
|
||||
// extract phi-node associates with double-buffered
|
||||
// shared-memory copies. These can be read from and written to
|
||||
// without needing synchronization
|
||||
std::set<ir::value*> safe_war;
|
||||
ir::for_each_instruction(mod, [&](ir::instruction* i){
|
||||
if(liveness_->has_double(i)){
|
||||
auto info = liveness_->get_double(i);
|
||||
safe_war.insert(i);
|
||||
safe_war.insert(info.latch);
|
||||
}
|
||||
});
|
||||
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
|
||||
std::map<ir::basic_block*, interval_vec_t> written_to;
|
||||
@@ -125,7 +144,7 @@ void membar::run(ir::module &mod) {
|
||||
for(ir::basic_block* pred: block->get_predecessors())
|
||||
pred_read_from.push_back(read_from[pred]);
|
||||
// apply transfer function
|
||||
auto result = transfer(block, join(pred_written_to), join(pred_read_from), insert_locs);
|
||||
auto result = transfer(block, join(pred_written_to), join(pred_read_from), insert_locs, safe_war);
|
||||
written_to[block] = result.first;
|
||||
read_from[block] = result.second;
|
||||
}
|
||||
|
@@ -241,7 +241,7 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
|
||||
cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { }
|
||||
|
||||
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
|
||||
// std::cout << source << std::endl;
|
||||
std::cout << source << std::endl;
|
||||
cu_context::context_switcher ctx_switch(*context);
|
||||
// JIT compile source-code
|
||||
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
|
||||
|
@@ -199,24 +199,24 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
llvm::LLVMContext ctx;
|
||||
std::unique_ptr<llvm::Module> llvm(new llvm::Module(module.get_name(), ctx));
|
||||
// create passes
|
||||
codegen::analysis::cts shmem_info;
|
||||
codegen::analysis::cts cts;
|
||||
codegen::analysis::align align;
|
||||
codegen::analysis::liveness shmem_liveness(&shmem_info);
|
||||
codegen::analysis::liveness shmem_liveness;
|
||||
codegen::analysis::axes axes;
|
||||
codegen::analysis::layout layouts(&axes);
|
||||
codegen::transform::coalesce coalesce(&align, &layouts, &shmem_info);
|
||||
codegen::transform::coalesce coalesce(&align, &layouts);
|
||||
codegen::analysis::tiles tiles(opt.num_warps, &align, &axes, &layouts);
|
||||
codegen::analysis::allocation shmem_allocation(&shmem_liveness, &shmem_info, &tiles);
|
||||
codegen::transform::membar shmem_barriers(&shmem_allocation, &shmem_info);
|
||||
codegen::analysis::allocation shmem_allocation(&shmem_liveness, &tiles);
|
||||
codegen::transform::membar shmem_barriers(&shmem_liveness, &shmem_allocation);
|
||||
codegen::transform::dce dce;
|
||||
codegen::transform::peephole peephole;
|
||||
codegen::transform::reassociate reassociate(&align);
|
||||
codegen::selection selection(&shmem_allocation, &tiles, &shmem_info, &align, &axes, &layouts, &coalesce, target.get(), opt.num_warps);
|
||||
codegen::selection selection(&shmem_liveness, &shmem_allocation, &tiles, &align, &axes, &layouts, &coalesce, target.get(), opt.num_warps);
|
||||
// run passes
|
||||
peephole.run(module);
|
||||
dce.run(module);
|
||||
align.run(module);
|
||||
shmem_info.run(module);
|
||||
cts.run(module);
|
||||
axes.run(module);
|
||||
layouts.run(module);
|
||||
coalesce.run(module);
|
||||
@@ -227,10 +227,10 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
reassociate.run(module);
|
||||
dce.run(module);
|
||||
peephole.run(module);
|
||||
shmem_info.run(module);
|
||||
dce.run(module);
|
||||
cts.run(module);
|
||||
shmem_liveness.run(module);
|
||||
ir::print(module, std::cout);
|
||||
shmem_allocation.run();
|
||||
shmem_allocation.run(module);
|
||||
if(shmem_allocation.allocated_size() > context->device()->max_shared_memory())
|
||||
return std::unique_ptr<driver::module>();
|
||||
shmem_barriers.run(module);
|
||||
|
@@ -45,10 +45,10 @@ std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, i
|
||||
opt.defines.push_back({"TYPE", {ty}});
|
||||
opt.defines.push_back({"AT", {AT?"1":"0"}});
|
||||
opt.defines.push_back({"BT", {BT?"1":"0"}});
|
||||
opt.defines.push_back({"TM", {"64", "128"}});
|
||||
opt.defines.push_back({"TN", {"64", "128"}});
|
||||
opt.defines.push_back({"TM", {"128"}});
|
||||
opt.defines.push_back({"TN", {"128"}});
|
||||
opt.defines.push_back({"TK", {"8"}});
|
||||
opt.num_warps = {2, 4, 8};
|
||||
opt.num_warps = {8};
|
||||
// create function
|
||||
rt::function function(src::dot, opt);
|
||||
// benchmark available libraries
|
||||
|
Reference in New Issue
Block a user