This commit is contained in:
Philippe Tillet
2019-10-11 19:05:54 -04:00
parent 4efd0a3c6b
commit 323c90e431
20 changed files with 237 additions and 166 deletions

View File

@@ -30,6 +30,7 @@ private:
void update_graph_broadcast(ir::instruction *i);
void update_graph_dot(ir::instruction *i);
void update_graph_elementwise(ir::instruction *i);
void update_graph_no_edge(ir::instruction *i);
void update_graph(ir::instruction *i);
public:

View File

@@ -23,19 +23,24 @@ class align;
enum layout_type_t {
HMMA_884,
SCANLINE
SCANLINE,
SHARED
};
struct layout_t {
layout_t(layout_type_t _type,
const std::vector<int>& _axes,
const std::vector<unsigned> &_shapes,
const std::vector<ir::value *> &values,
const std::vector<ir::value *> &_values,
size_t _id,
analysis::align* align);
layout_type_t type;
std::vector<int> axes;
std::vector<unsigned> shapes;
std::vector<ir::value*> values;
std::vector<int> order;
size_t id;
size_t size;
std::vector<int> mts;
std::vector<int> nts;
std::vector<int> fpw;
@@ -46,7 +51,8 @@ struct layout_hmma_884_t: public layout_t {
layout_hmma_884_t(size_t num_warps,
const std::vector<int>& _axes,
const std::vector<unsigned>& _shapes,
const std::vector<ir::value *> &values,
const std::vector<ir::value *> &_values,
size_t _id,
analysis::align* align);
};
@@ -55,9 +61,20 @@ struct layout_scanline_t: public layout_t {
const std::vector<int>& _axes,
const std::vector<unsigned>& _shapes,
const std::vector<ir::value *> &values,
size_t _id,
analysis::align* align);
};
struct layout_shared_t: public layout_t {
layout_shared_t(const layout_t *arg,
const std::vector<int>& _axes,
const std::vector<unsigned>& _shapes,
const std::vector<ir::value *> &values,
size_t _id,
analysis::align* align);
};
class layout {
typedef ir::value* node_t;
typedef std::map <node_t, std::set<node_t>> graph_t;
@@ -70,6 +87,8 @@ private:
void init_hmma_tile(layout_t& layout);
void init_scanline_tile(layout_t &layout);
void create(size_t id, const std::vector<ir::value*>& values);
public:
// constructor
layout(analysis::axes *axes, analysis::align *align, size_t num_warps);

View File

@@ -23,6 +23,7 @@ typedef unsigned slot_index;
class tiles;
class layout;
class layout_t;
struct segment {
slot_index start;
@@ -42,18 +43,11 @@ struct double_buffer_info_t {
ir::phi_node* phi;
};
struct buffer_t {
size_t id;
size_t size;
bool operator<(buffer_t other) const {
return id < other.id;
}
};
class liveness {
private:
typedef std::map<ir::value*, slot_index> indices_map_t;
typedef std::map<buffer_t*, segment> intervals_map_t;
typedef std::map<layout_t*, segment> intervals_map_t;
typedef std::map<ir::value*, bool> has_storage_map_t;
typedef ir::value* node_t;
typedef std::map <node_t, std::set<node_t>> graph_t;
@@ -82,10 +76,7 @@ public:
unsigned num_bytes(ir::value *x);
// accessors
const intervals_map_t& intervals() const { return intervals_; }
segment get_interval(buffer_t* v) const { return intervals_.at(v); }
// buffers
buffer_t* get_buffer(ir::value *v) const { return groups_.at(v); }
std::vector<ir::value*> get_values(buffer_t* x) const { return values_.at(x); }
segment get_interval(layout_t* v) const { return intervals_.at(v); }
// double-buffering
bool has_double(ir::value *x) const { return double_.find(x) != double_.end(); }
double_buffer_info_t get_double(ir::value *x) const { return double_.at(x); }
@@ -101,10 +92,6 @@ private:
intervals_map_t intervals_;
std::map<ir::value*, double_buffer_info_t> double_;
std::map<ir::value*, size_t> pad_;
// buffers
tools::graph<node_t> graph_;
std::map<ir::value*, buffer_t*> groups_;
std::map<buffer_t*, std::vector<ir::value*>> values_;
};
}

View File

@@ -66,6 +66,7 @@ static const std::map<ir::value_id_t, inst_storage_info_t> storage_info = {
// intrinsics
{ ir::INST_COPY_TO_SHARED, {SHARED, {DISTRIBUTED}}},
{ ir::INST_COPY_FROM_SHARED, {DISTRIBUTED, {SHARED}}},
{ ir::INST_BARRIER, {NONE, {}}},
{ ir::INST_MAKE_RANGE_DYN, {DISTRIBUTED, {}}},
{ ir::INST_MAKE_RANGE_STA, {DISTRIBUTED, {}}},

View File

@@ -189,6 +189,7 @@ private:
void lower_splat(ir::splat_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
void lower_broadcast(ir::broadcast_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
void lower_copy_to_shared(ir::copy_to_shared_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
void lower_copy_from_shared(ir::copy_from_shared_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
void lower_trans(ir::trans_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
// matrix multiply
void lower_hmma_dot(ir::dot_inst *x, LLVMContext &ctx, Function *fn, Builder &builder,

View File

@@ -17,6 +17,7 @@ namespace analysis{
class allocation;
class liveness;
class layout;
class cts;
}
@@ -40,12 +41,13 @@ private:
std::set<ir::instruction *> &insert_loc, std::set<triton::ir::value *> &safe_war);
public:
membar(analysis::liveness *liveness, analysis::allocation *alloc):
liveness_(liveness), alloc_(alloc) {}
membar(analysis::liveness *liveness, analysis::layout *layouts, analysis::allocation *alloc):
liveness_(liveness), layouts_(layouts), alloc_(alloc) {}
void run(ir::module &mod);
private:
analysis::liveness *liveness_;
analysis::layout *layouts_;
analysis::allocation *alloc_;
};

View File

@@ -141,6 +141,7 @@ public:
value *create_select(value *pred, value *if_value, value *else_value, const std::string &name = "");
// Intrinsics
value *create_copy_to_shared(value *arg, const std::string &name = "");
value *create_copy_from_shared(value *arg, const std::string &name = "");
value *create_barrier(const std::string &name = "");
private:

View File

@@ -133,6 +133,7 @@ enum value_id_t: unsigned {
INST_DOT,
// intrinsics
INST_COPY_TO_SHARED,
INST_COPY_FROM_SHARED,
INST_BARRIER,
INST_MAKE_RANGE_DYN,
INST_MAKE_RANGE_STA,

View File

@@ -678,6 +678,17 @@ public:
_TRITON_DEFINE_CLONE(copy_to_shared_inst)
};
class copy_from_shared_inst: public unary_inst{
private:
using unary_inst::unary_inst;
std::string repr_impl() const { return "copy_from_shared"; }
public:
static copy_from_shared_inst* create(value *arg, const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(copy_from_shared_inst)
};
class barrier_inst: public instruction{
private:
barrier_inst(context &ctx, const std::string &name, instruction *next);

View File

@@ -1,5 +1,6 @@
#include <algorithm>
#include <climits>
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/analysis/liveness.h"
#include "triton/codegen/transform/cts.h"
@@ -20,22 +21,22 @@ void allocation::run(ir::module &mod) {
using std::min;
typedef std::multimap<unsigned, segment> triples_map_type;
std::vector<buffer_t*> I;
std::vector<layout_t*> I;
for(auto x: liveness_->intervals())
I.push_back(x.first);
std::vector<buffer_t*> J = I;
std::vector<layout_t*> J = I;
triples_map_type H;
H.insert({0, segment{0, INT_MAX}});
std::vector<buffer_t*> V;
std::map<buffer_t*, unsigned> starts;
std::vector<layout_t*> V;
std::map<layout_t*, unsigned> starts;
while(!J.empty()){
auto h_it = H.begin();
unsigned w = h_it->first;
segment xh = h_it->second;
H.erase(h_it);
auto j_it = std::find_if(J.begin(), J.end(), [&](buffer_t* JJ){
auto j_it = std::find_if(J.begin(), J.end(), [&](layout_t* JJ){
segment xj = liveness_->get_interval(JJ);
bool res = xj.intersect(xh);
for(auto val: H)
@@ -57,9 +58,9 @@ void allocation::run(ir::module &mod) {
}
// Build interference graph
std::map<buffer_t*, std::set<buffer_t*>> interferences;
for(buffer_t* x: V)
for(buffer_t* y: V){
std::map<layout_t*, std::set<layout_t*>> interferences;
for(layout_t* x: V)
for(layout_t* y: V){
if(x->id == y->id)
continue;
unsigned X0 = starts[x], Y0 = starts[y];
@@ -73,17 +74,17 @@ void allocation::run(ir::module &mod) {
}
// Initialize colors
std::map<buffer_t*, int> colors;
for(buffer_t* X: V)
std::map<layout_t*, int> colors;
for(layout_t* X: V)
colors[X] = (X->id==V[0]->id)?0:-1;
// First-fit graph coloring
std::vector<bool> available(V.size());
for(buffer_t* x: V){
for(layout_t* x: V){
// Non-neighboring colors are available
std::fill(available.begin(), available.end(), true);
for(buffer_t* Y: interferences[x]){
for(layout_t* Y: interferences[x]){
int color = colors[Y];
if(color >= 0)
available[color] = false;
@@ -94,12 +95,12 @@ void allocation::run(ir::module &mod) {
}
// Finalize allocation
for(buffer_t* x: V){
for(layout_t* x: V){
unsigned Adj = 0;
for(buffer_t* y: interferences[x])
for(layout_t* y: interferences[x])
Adj = std::max<unsigned>(Adj, starts[y] + y->size);
// create offsets
for(ir::value *v: liveness_->get_values(x)){
for(ir::value *v: x->values){
offsets_[v] = starts[x] + colors[x] * Adj;
if(liveness_->has_double(v)){
auto info = liveness_->get_double(v);
@@ -110,7 +111,7 @@ void allocation::run(ir::module &mod) {
// Save maximum size of induced memory space
allocated_size_ = 0;
for(buffer_t* x: V)
for(layout_t* x: V)
allocated_size_ = std::max<size_t>(allocated_size_, starts[x] + x->size);
}

View File

@@ -105,17 +105,23 @@ void axes::update_graph_elementwise(ir::instruction *i) {
}
}
void axes::update_graph_no_edge(ir::instruction *i) {
auto rank = i->get_type()->get_tile_rank();
for(unsigned d = 0; d < rank; d++)
graph_.add_edge({i, d}, {i, d});
}
void axes::update_graph(ir::instruction *i) {
switch (i->get_id()) {
case ir::INST_REDUCE: return update_graph_reduce(i);
case ir::INST_RESHAPE: return update_graph_reshape(i);
case ir::INST_SPLAT: return;
case ir::INST_TRANS: return update_graph_trans(i);
case ir::INST_BROADCAST: return update_graph_broadcast(i);
case ir::INST_DOT: return update_graph_dot(i);
case ir::INST_COPY_TO_SHARED: return;
default: return update_graph_elementwise(i);
case ir::INST_REDUCE: return update_graph_reduce(i);
case ir::INST_RESHAPE: return update_graph_reshape(i);
case ir::INST_SPLAT: return update_graph_no_edge(i);;
case ir::INST_TRANS: return update_graph_trans(i);
case ir::INST_BROADCAST: return update_graph_broadcast(i);
case ir::INST_DOT: return update_graph_dot(i);
case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i);;
case ir::INST_COPY_FROM_SHARED: return update_graph_no_edge(i);
default: return update_graph_elementwise(i);
}
return;
}

View File

@@ -45,6 +45,8 @@ void layout::connect(ir::value *x, ir::value *y) {
std::set_intersection(sx_axes.begin(), sx_axes.end(),
sy_axes.begin(), sy_axes.end(),
std::inserter(common, common.begin()));
graph_.add_edge(x, x);
graph_.add_edge(y, y);
if(!common.empty())
graph_.add_edge(x, y);
}
@@ -89,6 +91,23 @@ void extract_io_use(ir::value *v, std::set<ir::value*>& result) {
}
}
void extract_dot_use(ir::value *v, ir::value*& result, size_t n) {
for(ir::user* u: v->get_users()){
auto i = dynamic_cast<ir::dot_inst*>(u);
if(i && i->get_operand(n) == v)
result = v;
}
}
void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n) {
for(ir::user* u: v->get_users()){
auto i = dynamic_cast<ir::dot_inst*>(u);
if(i && is_hmma_c(i) && i->get_operand(n) == v)
result = v;
}
}
inline bool is_trans(ir::value *v) {
if(dynamic_cast<ir::trans_inst *>(v)) {
@@ -108,14 +127,14 @@ inline bool is_trans(ir::value *v) {
layout_t::layout_t(layout_type_t _type,
const std::vector<int> &_axes,
const std::vector<unsigned> &_shapes,
const std::vector<ir::value *> &values,
analysis::align* align): type(_type), axes(_axes), shapes(_shapes) {
const std::vector<ir::value *> &_values,
size_t _id,
analysis::align* align): type(_type), axes(_axes), shapes(_shapes), values(_values), id(_id) {
// io pointer
std::set<ir::value*> ptr;
for(ir::value* v: values)
extract_io_use(v, ptr);
size_t rank = axes.size();
std::vector<int> order(rank);
order.resize(axes.size());
std::iota(order.begin(), order.end(), 0);
for(ir::value *v: ptr){
auto max_contiguous = align->contiguous(v);
@@ -123,7 +142,6 @@ layout_t::layout_t(layout_type_t _type,
return max_contiguous[a] > max_contiguous[b];
});
}
this->order = order;
}
inline unsigned clamp(unsigned x, unsigned lo, unsigned hi) {
@@ -133,8 +151,8 @@ inline unsigned clamp(unsigned x, unsigned lo, unsigned hi) {
layout_hmma_884_t::layout_hmma_884_t(size_t num_warps,
const std::vector<int>& _axes,
const std::vector<unsigned>& _shapes,
const std::vector<ir::value *> &values,
analysis::align* align): layout_t(HMMA_884, _axes, _shapes, values, align) {
const std::vector<ir::value *> &values, size_t _id,
analysis::align* align): layout_t(HMMA_884, _axes, _shapes, values, _id, align) {
unsigned shape_0 = shapes[order[0]];
unsigned shape_1 = shapes[order[1]];
@@ -173,7 +191,8 @@ layout_scanline_t::layout_scanline_t(size_t num_warps,
const std::vector<int>& _axes,
const std::vector<unsigned>& _shapes,
const std::vector<ir::value *> &values,
analysis::align* align): layout_t(SCANLINE, _axes, _shapes, values, align){
size_t _id,
analysis::align* align): layout_t(SCANLINE, _axes, _shapes, values, _id, align){
unsigned size = std::accumulate(shapes.begin(), shapes.end(), 1, std::multiplies<int>());
unsigned num_threads = num_warps * 32;
nts.resize(shapes.size());
@@ -196,6 +215,58 @@ layout_scanline_t::layout_scanline_t(size_t num_warps,
throw std::runtime_error("cannot create a kernel with this amount of warps");
}
layout_shared_t::layout_shared_t(const layout_t *arg,
const std::vector<int>& _axes,
const std::vector<unsigned>& _shapes,
const std::vector<ir::value *> &values,
size_t _id,
analysis::align* align): layout_t(SHARED, _axes, _shapes, values, _id, align) {
if(arg->type == SCANLINE)
order = arg->order;
ir::value* dot_a = nullptr;
ir::value* dot_b = nullptr;
ir::value* hmma_dot_a = nullptr;
ir::value* hmma_dot_b = nullptr;
for(ir::value* v: values){
extract_dot_use(v, dot_a, 0);
extract_dot_use(v, dot_b, 1);
extract_hmma_dot_use(v, hmma_dot_a, 0);
extract_hmma_dot_use(v, hmma_dot_b, 1);
}
std::vector<int> col = {0, 1};
std::vector<int> row = {1, 0};
if(dot_a && !hmma_dot_a)
order = is_trans(dot_a) ? row : col;
if(dot_b && !hmma_dot_b)
order = is_trans(dot_b) ? col : row;
}
void layout::create(size_t id, const std::vector<ir::value*>& values) {
auto it_hmma_c = std::find_if(values.begin(), values.end(), &is_hmma_c);
auto cmp = [](ir::value* x, ir::value *y) {
return x->get_type()->get_tile_ranks1() <
y->get_type()->get_tile_ranks1();
};
ir::value *largest = *std::max_element(values.begin(), values.end(), cmp);
const auto& axes = axes_->get(largest);
const auto& shapes = largest->get_type()->get_tile_shapes();
auto it_cts = std::find_if(values.begin(), values.end(), [](ir::value* v) {
return dynamic_cast<ir::copy_to_shared_inst*>(v);
});
// type
if(it_hmma_c != values.end())
layouts_[id] = new layout_hmma_884_t(num_warps_, axes, shapes, values, id, align_);
else if(it_cts != values.end()){
ir::copy_to_shared_inst *cts = (ir::copy_to_shared_inst*)*it_cts;
ir::value *arg = cts->get_operand(0);
create(groups_.at(arg), values_.at(groups_.at(arg)));
layouts_[id] = new layout_shared_t(get(arg), axes, shapes, values, id, align_);
}
else
layouts_[id] = new layout_scanline_t(num_warps_, axes, shapes, values, id, align_);
}
void layout::run(ir::module &mod) {
// make graph
@@ -208,50 +279,8 @@ void layout::run(ir::module &mod) {
graph_.connected_components(&values_, &groups_);
// create layouts
for(const auto& x: values_) {
bool hmma_c = std::any_of(x.second.begin(), x.second.end(), &is_hmma_c);
auto cmp = [](ir::value* x, ir::value *y) {
return x->get_type()->get_tile_ranks1() <
y->get_type()->get_tile_ranks1();
};
ir::value *largest = *std::max_element(x.second.begin(), x.second.end(), cmp);
const auto& axes = axes_->get(largest);
const auto& shapes = largest->get_type()->get_tile_shapes();
// type
if(hmma_c)
layouts_[x.first] = new layout_hmma_884_t(num_warps_, axes, shapes, x.second, align_);
else
layouts_[x.first] = new layout_scanline_t(num_warps_, axes, shapes, x.second, align_);
}
// matrix multiplication optimizations
for(const auto& x: values_) {
std::vector<ir::dot_inst*> dots;
for(ir::value* v: x.second)
if(auto *x = dynamic_cast<ir::dot_inst*>(v))
dots.push_back(x);
for(ir::dot_inst* dot: dots){
ir::value* a = dot->get_operand(0);
ir::value* b = dot->get_operand(1);
if(get(dot)->type == HMMA_884){
auto a_val = values_of(layout_of(a));
auto b_val = values_of(layout_of(b));
for(ir::value *v: a_val)
if(auto *cts = dynamic_cast<ir::copy_to_shared_inst*>(v))
layouts_[layout_of(a)]->order = layouts_[layout_of(cts->get_operand(0))]->order;
for(ir::value *v: b_val)
if(auto *cts = dynamic_cast<ir::copy_to_shared_inst*>(v))
layouts_[layout_of(b)]->order = layouts_[layout_of(cts->get_operand(0))]->order;
}
else{
std::vector<int> col = {0, 1};
std::vector<int> row = {1, 0};
layouts_[layout_of(a)]->order = is_trans(a) ? row : col;
layouts_[layout_of(b)]->order = is_trans(b) ? col : row;
}
}
}
for(const auto& x: values_)
create(x.first, x.second);
}
}

View File

@@ -42,7 +42,7 @@ void liveness::extract_double_bufferable(ir::instruction *i) {
ir::value *value_1 = phi->get_incoming_value(1);
ir::instruction *i_0 = dynamic_cast<ir::instruction*>(value_0);
ir::instruction *i_1 = dynamic_cast<ir::instruction*>(value_1);
if(!i_0 || !i_1 || storage_info.at(i_0->get_id()).first != SHARED || storage_info.at(i_1->get_id()).first != SHARED)
if(!i_0 || !i_1 || storage_info.at(i_0->get_id()).first != codegen::SHARED || storage_info.at(i_1->get_id()).first != codegen::SHARED)
return;
if(is_latch_1)
double_[value_0] = double_buffer_info_t{value_1, phi};
@@ -50,21 +50,6 @@ void liveness::extract_double_bufferable(ir::instruction *i) {
double_[value_1] = double_buffer_info_t{value_0, phi};
}
void liveness::make_graph(ir::instruction *i) {
if(has_double(i)){
ir::value *latch = double_[i].latch;
graph_.add_edge(i, latch);
}
if(storage_info.at(i->get_id()).first == SHARED){
graph_.add_edge(i, i);
for(ir::value* op: i->ops()){
auto* iop = dynamic_cast<ir::instruction*>(op);
if(!iop || storage_info.at(iop->get_id()).first != SHARED)
continue;
graph_.add_edge(i, op);
}
}
}
// connected components
bool is_trans(ir::value *v) {
@@ -151,7 +136,6 @@ void liveness::run(ir::module &mod) {
indices.clear();
pad_.clear();
intervals_.clear();
graph_.clear();
// Create set of pair of values that can be double-buffered
ir::for_each_instruction(mod, [this](ir::instruction* i) {
@@ -167,22 +151,14 @@ void liveness::run(ir::module &mod) {
});
}while(has_changed);
// Create buffer dependency graph
ir::for_each_instruction(mod, [this](ir::instruction* i) {
this->make_graph(i);
});
// connected components
tools::graph<node_t>::cmap_t cmap;
tools::graph<node_t>::nmap_t nmap;
graph_.connected_components(&cmap, &nmap);
for(auto x: cmap) {
buffer_t* buffer = new buffer_t{x.first};
values_[buffer] = x.second;
for(ir::value *v: x.second){
buffer->size = std::max<int>(buffer->size, num_bytes(v));
groups_[v] = buffer;
}
for(auto &x: layouts_->get_all()) {
layout_t* layout = x.second;
if(layout->type != SHARED)
continue;
for(ir::value *v: layout->values)
layout->size = std::max<int>(layout->size, num_bytes(v));
}
// Assigns index to each instruction
@@ -195,22 +171,25 @@ void liveness::run(ir::module &mod) {
}
}
for(auto x: values_) {
for(auto &x: layouts_->get_all()) {
layout_t* layout = x.second;
if(layout->type != SHARED)
continue;
// users
std::set<ir::value*> values;
for(ir::value *v: x.second){
values.insert(v);
std::set<ir::value*> users;
for(ir::value *v: layout->values){
users.insert(v);
for(ir::user *u: v->get_users())
values.insert(u);
users.insert(u);
}
// compute intervals
unsigned start = INT32_MAX;
unsigned end = 0;
for(ir::value *u: values){
for(ir::value *u: users){
start = std::min(start, indices.at(u));
end = std::max(end, indices.at(u));
}
intervals_[x.first] = segment{start, end};
intervals_[layout] = segment{start, end};
}

View File

@@ -486,21 +486,7 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir
return (Instruction*)res;
}
if(ir::atomic_add_inst* ii = dynamic_cast<ir::atomic_add_inst*>(inst)){
// Value *ptr = value(ii->get_operand(0));
// Value *val = value(ii->get_operand(1));
// Value *atom_f_add = nullptr;
// if(val->getType()->isFloatTy())
// atom_f_add = Intrinsic::getDeclaration(builder.GetInsertBlock()->getModule(), Intrinsic::nvvm_atomic_load_add_f32, {ptr->getType()});
// else if(val->getType()->isHalfTy()){
// Type *fp16 = Type::getHalfTy(ctx);
// FunctionType *atom_ty = FunctionType::get(fp16, {fp16->getPointerTo(), fp16}, false);
// atom_f_add = InlineAsm::get(atom_ty, " atom.relaxed.global.gpu.add.noftz.f16 $0, [$1], $2;", "=h,l,h", true);
// }
// if(atom_f_add == nullptr)
throw std::runtime_error("unsupported");
// Value *res = builder.CreateCall(atom_f_add, {ptr, val});
// return (Instruction*)res;
}
if(ir::sqrt_inst* ii = dynamic_cast<ir::sqrt_inst*>(inst)){
Value *val = value(ii->get_operand(0));
@@ -711,7 +697,7 @@ void selection::init_hmma_axes(const analysis::layout_t& layout, IRBuilder<> &bu
void selection::init_axes(const analysis::layout_t& layout, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
if(layout.type == analysis::HMMA_884)
init_hmma_axes(layout, builder, u_thread_id, u_warp_id);
else
else if(layout.type == analysis::SCANLINE)
init_strided_scan_axes(layout, builder, u_thread_id, u_warp_id);
}
@@ -801,7 +787,7 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
for(ir::value *op: user->ops())
create_tile(op, builder, seen, sh_mem_ptr);
auto *i = dynamic_cast<ir::instruction*>(v);
if(i && storage_info.at(i->get_id()).first == SHARED && !dynamic_cast<ir::reduce_inst*>(v))
if(i && layouts_->get(i)->type == analysis::SHARED && !dynamic_cast<ir::reduce_inst*>(v))
create_shared_tile(i, builder, sh_mem_ptr);
else
create_distributed_tile(v, builder);
@@ -1018,7 +1004,7 @@ void selection::lower_copy_to_shared(ir::copy_to_shared_inst *x, LLVMContext &ct
distributed_tile* in = (distributed_tile*)tmap_.at(arg);
if(x_order == arg_order){
size_t ld = arg_order[0];
vector_size = std::min(layouts_->get(x)->nts.at(ld), layouts_->get(arg)->nts.at(ld));
vector_size = layouts_->get(arg)->nts.at(ld);
}
std::map<unsigned, Value*> packets;
@@ -1038,6 +1024,15 @@ void selection::lower_copy_to_shared(ir::copy_to_shared_inst *x, LLVMContext &ct
});
}
void selection::lower_copy_from_shared(ir::copy_from_shared_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
distributed_tile* result = (distributed_tile*)tmap_.at(x);
shared_tile* arg = (shared_tile*)tmap_.at(x->get_operand(0));
result->for_each([&](indices_t idx){
result->set_value(idx, arg->get_value(idx));
});
}
void selection::lower_trans(ir::trans_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
shared_tile* in = (shared_tile*)tmap_.at(x->get_operand(0));
shared_tile* out = new shared_tile(in->get_ty(), in->get_shapes(), in->get_order(), in->get_pointer(), builder, in->get_offset(), x->get_perm());
@@ -1399,6 +1394,8 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
lower_broadcast(x, ctx, fn, builder);
else if(auto *x = dynamic_cast<ir::copy_to_shared_inst*>(ins))
lower_copy_to_shared(x, ctx, fn, builder);
else if(auto *x = dynamic_cast<ir::copy_from_shared_inst*>(ins))
lower_copy_from_shared(x, ctx, fn, builder);
else if(auto* x = dynamic_cast<ir::trans_inst*>(ins))
lower_trans(x, ctx, fn, builder);
else if(auto x = dynamic_cast<ir::dot_inst*>(ins))
@@ -1554,7 +1551,7 @@ void selection::run(ir::module &src, Module &dst) {
}
else {
unsigned num_bytes = inst->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
offset->addIncoming(dst_builder.getInt32(liveness_->get_buffer(inst)->size / (2*num_bytes)), llvm_inc_block);
offset->addIncoming(dst_builder.getInt32(layouts_->get(inst)->size / (2*num_bytes)), llvm_inc_block);
}
ptr->addIncoming(inc_shared->get_pointer(), llvm_inc_block);
}

View File

@@ -12,30 +12,45 @@ namespace triton {
namespace codegen{
namespace transform{
inline bool is_shared(ir::value *v) {
auto *i = dynamic_cast<ir::instruction*>(v);
if(!i)
return false;
return storage_info.at(i->get_id()).first == codegen::SHARED;
}
// run pass on module
void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder) {
void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared) {
auto *i = dynamic_cast<ir::instruction*>(x);
// not an instruction
if(!i) {
builder.set_insert_point(parent);
ir::value *cts = builder.create_copy_to_shared(x);
parent->replace_uses_of_with(x, cts);
ir::value *copy;
if(to_shared)
copy = builder.create_copy_to_shared(x);
else
copy = builder.create_copy_from_shared(x);
parent->replace_uses_of_with(x, copy);
return;
}
// phi node
if(auto* phi = dynamic_cast<ir::phi_node*>(x)) {
for(unsigned i = 0; i < phi->get_num_incoming(); ++i)
add_copy(phi, phi->get_incoming_value(i), builder);
add_copy(phi, phi->get_incoming_value(i), builder, to_shared);
return;
}
ir::value_id_t id = i->get_id();
// already in shared memory
if(storage_info.at(id).first == SHARED)
if(to_shared && storage_info.at(id).first == SHARED)
return;
// copy
builder.set_insert_point_after(i);
ir::value *cts = builder.create_copy_to_shared(x);
parent->replace_uses_of_with(x, cts);
ir::value *copy;
if(to_shared)
copy = builder.create_copy_to_shared(x);
else
copy = builder.create_copy_from_shared(x);
parent->replace_uses_of_with(x, copy);
}
void cts::run(ir::module &mod) {
@@ -45,10 +60,16 @@ void cts::run(ir::module &mod) {
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i: block->get_inst_list()){
auto storage = storage_info.at(i->get_id());
// copy to shared operands when necessary
// copy to shared operands
for(size_t k = 0; k < storage.second.size(); k++)
if(storage.second[k] == SHARED)
add_copy(i, i->get_operand(k), builder);
add_copy(i, i->get_operand(k), builder, true);
// copy from shared operands
for(size_t k = 0; k < storage.second.size(); k++)
if(storage.second[k] == DISTRIBUTED &&
is_shared(i->get_operand(k))){
add_copy(i, i->get_operand(k), builder, false);
}
}
}
}

View File

@@ -3,6 +3,7 @@
#include <algorithm>
#include "triton/codegen/analysis/liveness.h"
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/instructions.h"
#include "triton/codegen/transform/membar.h"
@@ -38,7 +39,7 @@ void membar::add_reference(ir::value *v, interval_vec_t &res){
return;
if(alloc_->has_offset(v)){
unsigned offset = alloc_->offset(v);
unsigned size = liveness_->get_buffer(v)->size;
unsigned size = layouts_->get(v)->size;
res.push_back(interval_t(offset, offset + size));
}
}

View File

@@ -241,7 +241,7 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { }
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
// std::cout << source << std::endl;
std::cout << source << std::endl;
cu_context::context_switcher ctx(*context);
// JIT compile source-code
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};

View File

@@ -347,6 +347,10 @@ value *builder::create_copy_to_shared(value *arg, const std::string &name) {
return insert(copy_to_shared_inst::create(arg, name));
}
value *builder::create_copy_from_shared(value *arg, const std::string &name) {
return insert(copy_from_shared_inst::create(arg, name));
}
value *builder::create_barrier(const std::string &name) {
return insert(barrier_inst::create(ctx_, name));
}

View File

@@ -731,12 +731,20 @@ instruction* atomic_add_inst::create(value *ptr, value *val, const std::string &
//===----------------------------------------------------------------------===//
// intrinsic instructions
//===----------------------------------------------------------------------===//
// copy to shared
copy_to_shared_inst* copy_to_shared_inst::create(value *arg, const std::string &name,
instruction *next) {
return new copy_to_shared_inst(arg->get_type(), INST_COPY_TO_SHARED, arg, name, next);
}
// copy from shared
copy_from_shared_inst* copy_from_shared_inst::create(value *arg, const std::string &name,
instruction *next) {
return new copy_from_shared_inst(arg->get_type(), INST_COPY_FROM_SHARED, arg, name, next);
}
// barrier
barrier_inst::barrier_inst(context &ctx, const std::string &name,
instruction *next)

View File

@@ -211,7 +211,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
codegen::analysis::layout layouts(&axes, &align, opt.num_warps);
codegen::analysis::liveness liveness(&layouts);
codegen::analysis::allocation allocation(&liveness);
codegen::transform::membar barriers(&liveness, &allocation);
codegen::transform::membar barriers(&liveness, &layouts, &allocation);
codegen::transform::dce dce;
codegen::transform::peephole peephole;
codegen::transform::reassociate reassociate(&align);
@@ -230,11 +230,11 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
align.run(module);
dce.run(module);
reassociate.run(module);
// ir::print(module, std::cout);
dce.run(module);
cts.run(module);
align.run(module);
axes.run(module);
// ir::print(module, std::cout);
layouts.run(module);
liveness.run(module);
allocation.run(module);
@@ -245,6 +245,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
align.run(module);
axes.run(module);
layouts.run(module);
// ir::print(module, std::cout);
selection.run(module, *llvm);
// return binary
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));