ugh
This commit is contained in:
@@ -30,6 +30,7 @@ private:
|
||||
void update_graph_broadcast(ir::instruction *i);
|
||||
void update_graph_dot(ir::instruction *i);
|
||||
void update_graph_elementwise(ir::instruction *i);
|
||||
void update_graph_no_edge(ir::instruction *i);
|
||||
void update_graph(ir::instruction *i);
|
||||
|
||||
public:
|
||||
|
@@ -23,19 +23,24 @@ class align;
|
||||
|
||||
enum layout_type_t {
|
||||
HMMA_884,
|
||||
SCANLINE
|
||||
SCANLINE,
|
||||
SHARED
|
||||
};
|
||||
|
||||
struct layout_t {
|
||||
layout_t(layout_type_t _type,
|
||||
const std::vector<int>& _axes,
|
||||
const std::vector<unsigned> &_shapes,
|
||||
const std::vector<ir::value *> &values,
|
||||
const std::vector<ir::value *> &_values,
|
||||
size_t _id,
|
||||
analysis::align* align);
|
||||
layout_type_t type;
|
||||
std::vector<int> axes;
|
||||
std::vector<unsigned> shapes;
|
||||
std::vector<ir::value*> values;
|
||||
std::vector<int> order;
|
||||
size_t id;
|
||||
size_t size;
|
||||
std::vector<int> mts;
|
||||
std::vector<int> nts;
|
||||
std::vector<int> fpw;
|
||||
@@ -46,7 +51,8 @@ struct layout_hmma_884_t: public layout_t {
|
||||
layout_hmma_884_t(size_t num_warps,
|
||||
const std::vector<int>& _axes,
|
||||
const std::vector<unsigned>& _shapes,
|
||||
const std::vector<ir::value *> &values,
|
||||
const std::vector<ir::value *> &_values,
|
||||
size_t _id,
|
||||
analysis::align* align);
|
||||
};
|
||||
|
||||
@@ -55,9 +61,20 @@ struct layout_scanline_t: public layout_t {
|
||||
const std::vector<int>& _axes,
|
||||
const std::vector<unsigned>& _shapes,
|
||||
const std::vector<ir::value *> &values,
|
||||
size_t _id,
|
||||
analysis::align* align);
|
||||
};
|
||||
|
||||
// Layout specialization derived from layout_t. It declares no members of its
// own, only a constructor.
//
// Constructor parameters:
//   arg     - another layout this one is derived from.
//             NOTE(review): presumably the layout of the value being copied
//             into shared memory - confirm at the call site.
//   _axes   - axes of the layout.
//   _shapes - shapes of the layout.
//   values  - IR values covered by this layout.
//   _id     - identifier of the layout.
//   align   - alignment analysis handed to the base class.
struct layout_shared_t: public layout_t {
|
||||
layout_shared_t(const layout_t *arg,
|
||||
const std::vector<int>& _axes,
|
||||
const std::vector<unsigned>& _shapes,
|
||||
const std::vector<ir::value *> &values,
|
||||
size_t _id,
|
||||
analysis::align* align);
|
||||
};
|
||||
|
||||
|
||||
class layout {
|
||||
typedef ir::value* node_t;
|
||||
typedef std::map <node_t, std::set<node_t>> graph_t;
|
||||
@@ -70,6 +87,8 @@ private:
|
||||
void init_hmma_tile(layout_t& layout);
|
||||
void init_scanline_tile(layout_t &layout);
|
||||
|
||||
void create(size_t id, const std::vector<ir::value*>& values);
|
||||
|
||||
public:
|
||||
// constructor
|
||||
layout(analysis::axes *axes, analysis::align *align, size_t num_warps);
|
||||
|
@@ -23,6 +23,7 @@ typedef unsigned slot_index;
|
||||
|
||||
class tiles;
|
||||
class layout;
|
||||
class layout_t;
|
||||
|
||||
struct segment {
|
||||
slot_index start;
|
||||
@@ -42,18 +43,11 @@ struct double_buffer_info_t {
|
||||
ir::phi_node* phi;
|
||||
};
|
||||
|
||||
struct buffer_t {
|
||||
size_t id;
|
||||
size_t size;
|
||||
bool operator<(buffer_t other) const {
|
||||
return id < other.id;
|
||||
}
|
||||
};
|
||||
|
||||
class liveness {
|
||||
private:
|
||||
typedef std::map<ir::value*, slot_index> indices_map_t;
|
||||
typedef std::map<buffer_t*, segment> intervals_map_t;
|
||||
typedef std::map<layout_t*, segment> intervals_map_t;
|
||||
typedef std::map<ir::value*, bool> has_storage_map_t;
|
||||
typedef ir::value* node_t;
|
||||
typedef std::map <node_t, std::set<node_t>> graph_t;
|
||||
@@ -82,10 +76,7 @@ public:
|
||||
unsigned num_bytes(ir::value *x);
|
||||
// accessors
|
||||
const intervals_map_t& intervals() const { return intervals_; }
|
||||
segment get_interval(buffer_t* v) const { return intervals_.at(v); }
|
||||
// buffers
|
||||
buffer_t* get_buffer(ir::value *v) const { return groups_.at(v); }
|
||||
std::vector<ir::value*> get_values(buffer_t* x) const { return values_.at(x); }
|
||||
segment get_interval(layout_t* v) const { return intervals_.at(v); }
|
||||
// double-buffering
|
||||
bool has_double(ir::value *x) const { return double_.find(x) != double_.end(); }
|
||||
double_buffer_info_t get_double(ir::value *x) const { return double_.at(x); }
|
||||
@@ -101,10 +92,6 @@ private:
|
||||
intervals_map_t intervals_;
|
||||
std::map<ir::value*, double_buffer_info_t> double_;
|
||||
std::map<ir::value*, size_t> pad_;
|
||||
// buffers
|
||||
tools::graph<node_t> graph_;
|
||||
std::map<ir::value*, buffer_t*> groups_;
|
||||
std::map<buffer_t*, std::vector<ir::value*>> values_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -66,6 +66,7 @@ static const std::map<ir::value_id_t, inst_storage_info_t> storage_info = {
|
||||
|
||||
// intrinsics
|
||||
{ ir::INST_COPY_TO_SHARED, {SHARED, {DISTRIBUTED}}},
|
||||
{ ir::INST_COPY_FROM_SHARED, {DISTRIBUTED, {SHARED}}},
|
||||
{ ir::INST_BARRIER, {NONE, {}}},
|
||||
{ ir::INST_MAKE_RANGE_DYN, {DISTRIBUTED, {}}},
|
||||
{ ir::INST_MAKE_RANGE_STA, {DISTRIBUTED, {}}},
|
||||
|
@@ -189,6 +189,7 @@ private:
|
||||
void lower_splat(ir::splat_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||
void lower_broadcast(ir::broadcast_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||
void lower_copy_to_shared(ir::copy_to_shared_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||
void lower_copy_from_shared(ir::copy_from_shared_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||
void lower_trans(ir::trans_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||
// matrix multiply
|
||||
void lower_hmma_dot(ir::dot_inst *x, LLVMContext &ctx, Function *fn, Builder &builder,
|
||||
|
@@ -17,6 +17,7 @@ namespace analysis{
|
||||
|
||||
class allocation;
|
||||
class liveness;
|
||||
class layout;
|
||||
class cts;
|
||||
|
||||
}
|
||||
@@ -40,12 +41,13 @@ private:
|
||||
std::set<ir::instruction *> &insert_loc, std::set<triton::ir::value *> &safe_war);
|
||||
|
||||
public:
|
||||
membar(analysis::liveness *liveness, analysis::allocation *alloc):
|
||||
liveness_(liveness), alloc_(alloc) {}
|
||||
membar(analysis::liveness *liveness, analysis::layout *layouts, analysis::allocation *alloc):
|
||||
liveness_(liveness), layouts_(layouts), alloc_(alloc) {}
|
||||
void run(ir::module &mod);
|
||||
|
||||
private:
|
||||
analysis::liveness *liveness_;
|
||||
analysis::layout *layouts_;
|
||||
analysis::allocation *alloc_;
|
||||
};
|
||||
|
||||
|
@@ -141,6 +141,7 @@ public:
|
||||
value *create_select(value *pred, value *if_value, value *else_value, const std::string &name = "");
|
||||
// Intrinsics
|
||||
value *create_copy_to_shared(value *arg, const std::string &name = "");
|
||||
value *create_copy_from_shared(value *arg, const std::string &name = "");
|
||||
value *create_barrier(const std::string &name = "");
|
||||
|
||||
private:
|
||||
|
@@ -133,6 +133,7 @@ enum value_id_t: unsigned {
|
||||
INST_DOT,
|
||||
// intrinsics
|
||||
INST_COPY_TO_SHARED,
|
||||
INST_COPY_FROM_SHARED,
|
||||
INST_BARRIER,
|
||||
INST_MAKE_RANGE_DYN,
|
||||
INST_MAKE_RANGE_STA,
|
||||
|
@@ -678,6 +678,17 @@ public:
|
||||
_TRITON_DEFINE_CLONE(copy_to_shared_inst)
|
||||
};
|
||||
|
||||
// Unary IR instruction whose textual representation is "copy_from_shared".
// Instances are obtained through the static create() factory.
class copy_from_shared_inst: public unary_inst{
|
||||
private:
|
||||
// Reuse the constructors of unary_inst.
using unary_inst::unary_inst;
|
||||
// Name reported for this instruction.
std::string repr_impl() const { return "copy_from_shared"; }
|
||||
|
||||
public:
|
||||
// Factory: builds the instruction for operand `arg`, with an optional name
// and an optional `next` instruction.
// NOTE(review): `next` is presumably the insertion point - confirm against
// the instruction base class.
static copy_from_shared_inst* create(value *arg, const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
// NOTE(review): macro defined elsewhere; presumably supplies the clone
// implementation for this class - confirm.
_TRITON_DEFINE_CLONE(copy_from_shared_inst)
|
||||
};
|
||||
|
||||
class barrier_inst: public instruction{
|
||||
private:
|
||||
barrier_inst(context &ctx, const std::string &name, instruction *next);
|
||||
|
@@ -1,5 +1,6 @@
|
||||
#include <algorithm>
|
||||
#include <climits>
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
#include "triton/codegen/analysis/allocation.h"
|
||||
#include "triton/codegen/analysis/liveness.h"
|
||||
#include "triton/codegen/transform/cts.h"
|
||||
@@ -20,22 +21,22 @@ void allocation::run(ir::module &mod) {
|
||||
using std::min;
|
||||
typedef std::multimap<unsigned, segment> triples_map_type;
|
||||
|
||||
std::vector<buffer_t*> I;
|
||||
std::vector<layout_t*> I;
|
||||
for(auto x: liveness_->intervals())
|
||||
I.push_back(x.first);
|
||||
std::vector<buffer_t*> J = I;
|
||||
std::vector<layout_t*> J = I;
|
||||
|
||||
triples_map_type H;
|
||||
H.insert({0, segment{0, INT_MAX}});
|
||||
|
||||
std::vector<buffer_t*> V;
|
||||
std::map<buffer_t*, unsigned> starts;
|
||||
std::vector<layout_t*> V;
|
||||
std::map<layout_t*, unsigned> starts;
|
||||
while(!J.empty()){
|
||||
auto h_it = H.begin();
|
||||
unsigned w = h_it->first;
|
||||
segment xh = h_it->second;
|
||||
H.erase(h_it);
|
||||
auto j_it = std::find_if(J.begin(), J.end(), [&](buffer_t* JJ){
|
||||
auto j_it = std::find_if(J.begin(), J.end(), [&](layout_t* JJ){
|
||||
segment xj = liveness_->get_interval(JJ);
|
||||
bool res = xj.intersect(xh);
|
||||
for(auto val: H)
|
||||
@@ -57,9 +58,9 @@ void allocation::run(ir::module &mod) {
|
||||
}
|
||||
|
||||
// Build interference graph
|
||||
std::map<buffer_t*, std::set<buffer_t*>> interferences;
|
||||
for(buffer_t* x: V)
|
||||
for(buffer_t* y: V){
|
||||
std::map<layout_t*, std::set<layout_t*>> interferences;
|
||||
for(layout_t* x: V)
|
||||
for(layout_t* y: V){
|
||||
if(x->id == y->id)
|
||||
continue;
|
||||
unsigned X0 = starts[x], Y0 = starts[y];
|
||||
@@ -73,17 +74,17 @@ void allocation::run(ir::module &mod) {
|
||||
}
|
||||
|
||||
// Initialize colors
|
||||
std::map<buffer_t*, int> colors;
|
||||
for(buffer_t* X: V)
|
||||
std::map<layout_t*, int> colors;
|
||||
for(layout_t* X: V)
|
||||
colors[X] = (X->id==V[0]->id)?0:-1;
|
||||
|
||||
|
||||
// First-fit graph coloring
|
||||
std::vector<bool> available(V.size());
|
||||
for(buffer_t* x: V){
|
||||
for(layout_t* x: V){
|
||||
// Non-neighboring colors are available
|
||||
std::fill(available.begin(), available.end(), true);
|
||||
for(buffer_t* Y: interferences[x]){
|
||||
for(layout_t* Y: interferences[x]){
|
||||
int color = colors[Y];
|
||||
if(color >= 0)
|
||||
available[color] = false;
|
||||
@@ -94,12 +95,12 @@ void allocation::run(ir::module &mod) {
|
||||
}
|
||||
|
||||
// Finalize allocation
|
||||
for(buffer_t* x: V){
|
||||
for(layout_t* x: V){
|
||||
unsigned Adj = 0;
|
||||
for(buffer_t* y: interferences[x])
|
||||
for(layout_t* y: interferences[x])
|
||||
Adj = std::max<unsigned>(Adj, starts[y] + y->size);
|
||||
// create offsets
|
||||
for(ir::value *v: liveness_->get_values(x)){
|
||||
for(ir::value *v: x->values){
|
||||
offsets_[v] = starts[x] + colors[x] * Adj;
|
||||
if(liveness_->has_double(v)){
|
||||
auto info = liveness_->get_double(v);
|
||||
@@ -110,7 +111,7 @@ void allocation::run(ir::module &mod) {
|
||||
|
||||
// Save maximum size of induced memory space
|
||||
allocated_size_ = 0;
|
||||
for(buffer_t* x: V)
|
||||
for(layout_t* x: V)
|
||||
allocated_size_ = std::max<size_t>(allocated_size_, starts[x] + x->size);
|
||||
}
|
||||
|
||||
|
@@ -105,17 +105,23 @@ void axes::update_graph_elementwise(ir::instruction *i) {
|
||||
}
|
||||
}
|
||||
|
||||
// Adds instruction `i` to the axes graph without linking it to any other
// value: for every dimension d below the tile rank of i's type, a self-edge
// {i, d} -> {i, d} is inserted into graph_.
// NOTE(review): presumably the self-edge only serves to make the node exist
// in the graph - confirm against the graph implementation.
void axes::update_graph_no_edge(ir::instruction *i) {
|
||||
auto rank = i->get_type()->get_tile_rank();
|
||||
for(unsigned d = 0; d < rank; d++)
|
||||
graph_.add_edge({i, d}, {i, d});
|
||||
}
|
||||
|
||||
void axes::update_graph(ir::instruction *i) {
|
||||
switch (i->get_id()) {
|
||||
case ir::INST_REDUCE: return update_graph_reduce(i);
|
||||
case ir::INST_RESHAPE: return update_graph_reshape(i);
|
||||
case ir::INST_SPLAT: return;
|
||||
case ir::INST_TRANS: return update_graph_trans(i);
|
||||
case ir::INST_BROADCAST: return update_graph_broadcast(i);
|
||||
case ir::INST_DOT: return update_graph_dot(i);
|
||||
case ir::INST_COPY_TO_SHARED: return;
|
||||
default: return update_graph_elementwise(i);
|
||||
case ir::INST_REDUCE: return update_graph_reduce(i);
|
||||
case ir::INST_RESHAPE: return update_graph_reshape(i);
|
||||
case ir::INST_SPLAT: return update_graph_no_edge(i);;
|
||||
case ir::INST_TRANS: return update_graph_trans(i);
|
||||
case ir::INST_BROADCAST: return update_graph_broadcast(i);
|
||||
case ir::INST_DOT: return update_graph_dot(i);
|
||||
case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i);;
|
||||
case ir::INST_COPY_FROM_SHARED: return update_graph_no_edge(i);
|
||||
default: return update_graph_elementwise(i);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
@@ -45,6 +45,8 @@ void layout::connect(ir::value *x, ir::value *y) {
|
||||
std::set_intersection(sx_axes.begin(), sx_axes.end(),
|
||||
sy_axes.begin(), sy_axes.end(),
|
||||
std::inserter(common, common.begin()));
|
||||
graph_.add_edge(x, x);
|
||||
graph_.add_edge(y, y);
|
||||
if(!common.empty())
|
||||
graph_.add_edge(x, y);
|
||||
}
|
||||
@@ -89,6 +91,23 @@ void extract_io_use(ir::value *v, std::set<ir::value*>& result) {
|
||||
}
|
||||
}
|
||||
|
||||
// Checks whether `v` is used as operand number `n` of a dot instruction.
// If any user of `v` is an ir::dot_inst whose n-th operand is `v`, `result`
// is set to `v` itself (not to the dot instruction). Otherwise `result` is
// left unchanged, so callers can pre-initialize it to detect "no such use".
void extract_dot_use(ir::value *v, ir::value*& result, size_t n) {
|
||||
for(ir::user* u: v->get_users()){
|
||||
// Null when the user is not a dot instruction.
auto i = dynamic_cast<ir::dot_inst*>(u);
|
||||
if(i && i->get_operand(n) == v)
|
||||
result = v;
|
||||
}
|
||||
}
|
||||
|
||||
// Same scan as a plain dot-use check, with one extra condition: the dot
// instruction must also satisfy is_hmma_c(). If any user of `v` is such an
// ir::dot_inst and its n-th operand is `v`, `result` is set to `v` itself;
// otherwise `result` is left unchanged.
void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n) {
|
||||
for(ir::user* u: v->get_users()){
|
||||
// Null when the user is not a dot instruction.
auto i = dynamic_cast<ir::dot_inst*>(u);
|
||||
if(i && is_hmma_c(i) && i->get_operand(n) == v)
|
||||
result = v;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline bool is_trans(ir::value *v) {
|
||||
if(dynamic_cast<ir::trans_inst *>(v)) {
|
||||
@@ -108,14 +127,14 @@ inline bool is_trans(ir::value *v) {
|
||||
layout_t::layout_t(layout_type_t _type,
|
||||
const std::vector<int> &_axes,
|
||||
const std::vector<unsigned> &_shapes,
|
||||
const std::vector<ir::value *> &values,
|
||||
analysis::align* align): type(_type), axes(_axes), shapes(_shapes) {
|
||||
const std::vector<ir::value *> &_values,
|
||||
size_t _id,
|
||||
analysis::align* align): type(_type), axes(_axes), shapes(_shapes), values(_values), id(_id) {
|
||||
// io pointer
|
||||
std::set<ir::value*> ptr;
|
||||
for(ir::value* v: values)
|
||||
extract_io_use(v, ptr);
|
||||
size_t rank = axes.size();
|
||||
std::vector<int> order(rank);
|
||||
order.resize(axes.size());
|
||||
std::iota(order.begin(), order.end(), 0);
|
||||
for(ir::value *v: ptr){
|
||||
auto max_contiguous = align->contiguous(v);
|
||||
@@ -123,7 +142,6 @@ layout_t::layout_t(layout_type_t _type,
|
||||
return max_contiguous[a] > max_contiguous[b];
|
||||
});
|
||||
}
|
||||
this->order = order;
|
||||
}
|
||||
|
||||
inline unsigned clamp(unsigned x, unsigned lo, unsigned hi) {
|
||||
@@ -133,8 +151,8 @@ inline unsigned clamp(unsigned x, unsigned lo, unsigned hi) {
|
||||
layout_hmma_884_t::layout_hmma_884_t(size_t num_warps,
|
||||
const std::vector<int>& _axes,
|
||||
const std::vector<unsigned>& _shapes,
|
||||
const std::vector<ir::value *> &values,
|
||||
analysis::align* align): layout_t(HMMA_884, _axes, _shapes, values, align) {
|
||||
const std::vector<ir::value *> &values, size_t _id,
|
||||
analysis::align* align): layout_t(HMMA_884, _axes, _shapes, values, _id, align) {
|
||||
|
||||
unsigned shape_0 = shapes[order[0]];
|
||||
unsigned shape_1 = shapes[order[1]];
|
||||
@@ -173,7 +191,8 @@ layout_scanline_t::layout_scanline_t(size_t num_warps,
|
||||
const std::vector<int>& _axes,
|
||||
const std::vector<unsigned>& _shapes,
|
||||
const std::vector<ir::value *> &values,
|
||||
analysis::align* align): layout_t(SCANLINE, _axes, _shapes, values, align){
|
||||
size_t _id,
|
||||
analysis::align* align): layout_t(SCANLINE, _axes, _shapes, values, _id, align){
|
||||
unsigned size = std::accumulate(shapes.begin(), shapes.end(), 1, std::multiplies<int>());
|
||||
unsigned num_threads = num_warps * 32;
|
||||
nts.resize(shapes.size());
|
||||
@@ -196,6 +215,58 @@ layout_scanline_t::layout_scanline_t(size_t num_warps,
|
||||
throw std::runtime_error("cannot create a kernel with this amount of warps");
|
||||
}
|
||||
|
||||
// Builds a layout of type SHARED through the layout_t base constructor and
// then adjusts `order`:
//   1. if `arg` is a SCANLINE layout, `order` is copied from it;
//   2. if one of `values` is operand 0 of a dot instruction, and none is
//      operand 0 of a dot satisfying is_hmma_c, `order` becomes {1, 0} when
//      is_trans() holds for that value and {0, 1} otherwise;
//   3. the same test for operand 1 picks the opposite order, and because it
//      runs last it wins when both 2 and 3 apply.
// `arg` is dereferenced unconditionally, so it must not be null.
layout_shared_t::layout_shared_t(const layout_t *arg,
|
||||
const std::vector<int>& _axes,
|
||||
const std::vector<unsigned>& _shapes,
|
||||
const std::vector<ir::value *> &values,
|
||||
size_t _id,
|
||||
analysis::align* align): layout_t(SHARED, _axes, _shapes, values, _id, align) {
|
||||
|
||||
if(arg->type == SCANLINE)
|
||||
order = arg->order;
|
||||
|
||||
// Each pointer stays null unless some value has the matching kind of use.
// When several values match, the last one visited is kept.
ir::value* dot_a = nullptr;
|
||||
ir::value* dot_b = nullptr;
|
||||
ir::value* hmma_dot_a = nullptr;
|
||||
ir::value* hmma_dot_b = nullptr;
|
||||
for(ir::value* v: values){
|
||||
extract_dot_use(v, dot_a, 0);
|
||||
extract_dot_use(v, dot_b, 1);
|
||||
extract_hmma_dot_use(v, hmma_dot_a, 0);
|
||||
extract_hmma_dot_use(v, hmma_dot_b, 1);
|
||||
}
|
||||
// NOTE(review): the names suggest column-/row-major orders for rank-2
// tiles; both vectors are hard-coded to two entries - confirm higher ranks
// cannot reach this point.
std::vector<int> col = {0, 1};
|
||||
std::vector<int> row = {1, 0};
|
||||
if(dot_a && !hmma_dot_a)
|
||||
order = is_trans(dot_a) ? row : col;
|
||||
if(dot_b && !hmma_dot_b)
|
||||
order = is_trans(dot_b) ? col : row;
|
||||
}
|
||||
|
||||
// Creates the layout object for group `id`, made of `values`, and stores it
// in layouts_[id]. The concrete type is chosen as follows:
//   - layout_hmma_884_t if any value satisfies is_hmma_c;
//   - otherwise layout_shared_t if any value is a copy_to_shared_inst;
//   - otherwise layout_scanline_t.
// Axes and shapes are taken from the value whose type reports the largest
// get_tile_ranks1().
// `values` must not be empty: the result of std::max_element is dereferenced
// without a check.
void layout::create(size_t id, const std::vector<ir::value*>& values) {
|
||||
auto it_hmma_c = std::find_if(values.begin(), values.end(), &is_hmma_c);
|
||||
auto cmp = [](ir::value* x, ir::value *y) {
|
||||
return x->get_type()->get_tile_ranks1() <
|
||||
y->get_type()->get_tile_ranks1();
|
||||
};
|
||||
ir::value *largest = *std::max_element(values.begin(), values.end(), cmp);
|
||||
const auto& axes = axes_->get(largest);
|
||||
const auto& shapes = largest->get_type()->get_tile_shapes();
|
||||
// First copy_to_shared instruction in the group, if any.
auto it_cts = std::find_if(values.begin(), values.end(), [](ir::value* v) {
|
||||
return dynamic_cast<ir::copy_to_shared_inst*>(v);
|
||||
});
|
||||
// type
|
||||
if(it_hmma_c != values.end())
|
||||
layouts_[id] = new layout_hmma_884_t(num_warps_, axes, shapes, values, id, align_);
|
||||
else if(it_cts != values.end()){
|
||||
ir::copy_to_shared_inst *cts = (ir::copy_to_shared_inst*)*it_cts;
|
||||
ir::value *arg = cts->get_operand(0);
|
||||
// The shared layout is built from the layout of the copied operand, so
// that operand's group is created first.
// NOTE(review): this runs even if that group already has an entry in
// layouts_, which is then overwritten with a new object. Whether the
// previous object is released is not visible here - confirm.
create(groups_.at(arg), values_.at(groups_.at(arg)));
|
||||
layouts_[id] = new layout_shared_t(get(arg), axes, shapes, values, id, align_);
|
||||
}
|
||||
else
|
||||
layouts_[id] = new layout_scanline_t(num_warps_, axes, shapes, values, id, align_);
|
||||
}
|
||||
|
||||
void layout::run(ir::module &mod) {
|
||||
// make graph
|
||||
@@ -208,50 +279,8 @@ void layout::run(ir::module &mod) {
|
||||
graph_.connected_components(&values_, &groups_);
|
||||
|
||||
// create layouts
|
||||
for(const auto& x: values_) {
|
||||
bool hmma_c = std::any_of(x.second.begin(), x.second.end(), &is_hmma_c);
|
||||
auto cmp = [](ir::value* x, ir::value *y) {
|
||||
return x->get_type()->get_tile_ranks1() <
|
||||
y->get_type()->get_tile_ranks1();
|
||||
};
|
||||
ir::value *largest = *std::max_element(x.second.begin(), x.second.end(), cmp);
|
||||
const auto& axes = axes_->get(largest);
|
||||
const auto& shapes = largest->get_type()->get_tile_shapes();
|
||||
// type
|
||||
if(hmma_c)
|
||||
layouts_[x.first] = new layout_hmma_884_t(num_warps_, axes, shapes, x.second, align_);
|
||||
else
|
||||
layouts_[x.first] = new layout_scanline_t(num_warps_, axes, shapes, x.second, align_);
|
||||
}
|
||||
|
||||
|
||||
// matrix multiplication optimizations
|
||||
for(const auto& x: values_) {
|
||||
std::vector<ir::dot_inst*> dots;
|
||||
for(ir::value* v: x.second)
|
||||
if(auto *x = dynamic_cast<ir::dot_inst*>(v))
|
||||
dots.push_back(x);
|
||||
for(ir::dot_inst* dot: dots){
|
||||
ir::value* a = dot->get_operand(0);
|
||||
ir::value* b = dot->get_operand(1);
|
||||
if(get(dot)->type == HMMA_884){
|
||||
auto a_val = values_of(layout_of(a));
|
||||
auto b_val = values_of(layout_of(b));
|
||||
for(ir::value *v: a_val)
|
||||
if(auto *cts = dynamic_cast<ir::copy_to_shared_inst*>(v))
|
||||
layouts_[layout_of(a)]->order = layouts_[layout_of(cts->get_operand(0))]->order;
|
||||
for(ir::value *v: b_val)
|
||||
if(auto *cts = dynamic_cast<ir::copy_to_shared_inst*>(v))
|
||||
layouts_[layout_of(b)]->order = layouts_[layout_of(cts->get_operand(0))]->order;
|
||||
}
|
||||
else{
|
||||
std::vector<int> col = {0, 1};
|
||||
std::vector<int> row = {1, 0};
|
||||
layouts_[layout_of(a)]->order = is_trans(a) ? row : col;
|
||||
layouts_[layout_of(b)]->order = is_trans(b) ? col : row;
|
||||
}
|
||||
}
|
||||
}
|
||||
for(const auto& x: values_)
|
||||
create(x.first, x.second);
|
||||
}
|
||||
|
||||
}
|
||||
|
@@ -42,7 +42,7 @@ void liveness::extract_double_bufferable(ir::instruction *i) {
|
||||
ir::value *value_1 = phi->get_incoming_value(1);
|
||||
ir::instruction *i_0 = dynamic_cast<ir::instruction*>(value_0);
|
||||
ir::instruction *i_1 = dynamic_cast<ir::instruction*>(value_1);
|
||||
if(!i_0 || !i_1 || storage_info.at(i_0->get_id()).first != SHARED || storage_info.at(i_1->get_id()).first != SHARED)
|
||||
if(!i_0 || !i_1 || storage_info.at(i_0->get_id()).first != codegen::SHARED || storage_info.at(i_1->get_id()).first != codegen::SHARED)
|
||||
return;
|
||||
if(is_latch_1)
|
||||
double_[value_0] = double_buffer_info_t{value_1, phi};
|
||||
@@ -50,21 +50,6 @@ void liveness::extract_double_bufferable(ir::instruction *i) {
|
||||
double_[value_1] = double_buffer_info_t{value_0, phi};
|
||||
}
|
||||
|
||||
void liveness::make_graph(ir::instruction *i) {
|
||||
if(has_double(i)){
|
||||
ir::value *latch = double_[i].latch;
|
||||
graph_.add_edge(i, latch);
|
||||
}
|
||||
if(storage_info.at(i->get_id()).first == SHARED){
|
||||
graph_.add_edge(i, i);
|
||||
for(ir::value* op: i->ops()){
|
||||
auto* iop = dynamic_cast<ir::instruction*>(op);
|
||||
if(!iop || storage_info.at(iop->get_id()).first != SHARED)
|
||||
continue;
|
||||
graph_.add_edge(i, op);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// connected components
|
||||
bool is_trans(ir::value *v) {
|
||||
@@ -151,7 +136,6 @@ void liveness::run(ir::module &mod) {
|
||||
indices.clear();
|
||||
pad_.clear();
|
||||
intervals_.clear();
|
||||
graph_.clear();
|
||||
|
||||
// Create set of pair of values that can be double-buffered
|
||||
ir::for_each_instruction(mod, [this](ir::instruction* i) {
|
||||
@@ -167,22 +151,14 @@ void liveness::run(ir::module &mod) {
|
||||
});
|
||||
}while(has_changed);
|
||||
|
||||
// Create buffer dependency graph
|
||||
ir::for_each_instruction(mod, [this](ir::instruction* i) {
|
||||
this->make_graph(i);
|
||||
});
|
||||
|
||||
// connected components
|
||||
tools::graph<node_t>::cmap_t cmap;
|
||||
tools::graph<node_t>::nmap_t nmap;
|
||||
graph_.connected_components(&cmap, &nmap);
|
||||
for(auto x: cmap) {
|
||||
buffer_t* buffer = new buffer_t{x.first};
|
||||
values_[buffer] = x.second;
|
||||
for(ir::value *v: x.second){
|
||||
buffer->size = std::max<int>(buffer->size, num_bytes(v));
|
||||
groups_[v] = buffer;
|
||||
}
|
||||
for(auto &x: layouts_->get_all()) {
|
||||
layout_t* layout = x.second;
|
||||
if(layout->type != SHARED)
|
||||
continue;
|
||||
for(ir::value *v: layout->values)
|
||||
layout->size = std::max<int>(layout->size, num_bytes(v));
|
||||
}
|
||||
|
||||
// Assigns index to each instruction
|
||||
@@ -195,22 +171,25 @@ void liveness::run(ir::module &mod) {
|
||||
}
|
||||
}
|
||||
|
||||
for(auto x: values_) {
|
||||
for(auto &x: layouts_->get_all()) {
|
||||
layout_t* layout = x.second;
|
||||
if(layout->type != SHARED)
|
||||
continue;
|
||||
// users
|
||||
std::set<ir::value*> values;
|
||||
for(ir::value *v: x.second){
|
||||
values.insert(v);
|
||||
std::set<ir::value*> users;
|
||||
for(ir::value *v: layout->values){
|
||||
users.insert(v);
|
||||
for(ir::user *u: v->get_users())
|
||||
values.insert(u);
|
||||
users.insert(u);
|
||||
}
|
||||
// compute intervals
|
||||
unsigned start = INT32_MAX;
|
||||
unsigned end = 0;
|
||||
for(ir::value *u: values){
|
||||
for(ir::value *u: users){
|
||||
start = std::min(start, indices.at(u));
|
||||
end = std::max(end, indices.at(u));
|
||||
}
|
||||
intervals_[x.first] = segment{start, end};
|
||||
intervals_[layout] = segment{start, end};
|
||||
}
|
||||
|
||||
|
||||
|
@@ -486,21 +486,7 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir
|
||||
return (Instruction*)res;
|
||||
}
|
||||
if(ir::atomic_add_inst* ii = dynamic_cast<ir::atomic_add_inst*>(inst)){
|
||||
// Value *ptr = value(ii->get_operand(0));
|
||||
// Value *val = value(ii->get_operand(1));
|
||||
// Value *atom_f_add = nullptr;
|
||||
// if(val->getType()->isFloatTy())
|
||||
// atom_f_add = Intrinsic::getDeclaration(builder.GetInsertBlock()->getModule(), Intrinsic::nvvm_atomic_load_add_f32, {ptr->getType()});
|
||||
// else if(val->getType()->isHalfTy()){
|
||||
// Type *fp16 = Type::getHalfTy(ctx);
|
||||
|
||||
// FunctionType *atom_ty = FunctionType::get(fp16, {fp16->getPointerTo(), fp16}, false);
|
||||
// atom_f_add = InlineAsm::get(atom_ty, " atom.relaxed.global.gpu.add.noftz.f16 $0, [$1], $2;", "=h,l,h", true);
|
||||
// }
|
||||
// if(atom_f_add == nullptr)
|
||||
throw std::runtime_error("unsupported");
|
||||
// Value *res = builder.CreateCall(atom_f_add, {ptr, val});
|
||||
// return (Instruction*)res;
|
||||
}
|
||||
if(ir::sqrt_inst* ii = dynamic_cast<ir::sqrt_inst*>(inst)){
|
||||
Value *val = value(ii->get_operand(0));
|
||||
@@ -711,7 +697,7 @@ void selection::init_hmma_axes(const analysis::layout_t& layout, IRBuilder<> &bu
|
||||
void selection::init_axes(const analysis::layout_t& layout, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
|
||||
if(layout.type == analysis::HMMA_884)
|
||||
init_hmma_axes(layout, builder, u_thread_id, u_warp_id);
|
||||
else
|
||||
else if(layout.type == analysis::SCANLINE)
|
||||
init_strided_scan_axes(layout, builder, u_thread_id, u_warp_id);
|
||||
}
|
||||
|
||||
@@ -801,7 +787,7 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
|
||||
for(ir::value *op: user->ops())
|
||||
create_tile(op, builder, seen, sh_mem_ptr);
|
||||
auto *i = dynamic_cast<ir::instruction*>(v);
|
||||
if(i && storage_info.at(i->get_id()).first == SHARED && !dynamic_cast<ir::reduce_inst*>(v))
|
||||
if(i && layouts_->get(i)->type == analysis::SHARED && !dynamic_cast<ir::reduce_inst*>(v))
|
||||
create_shared_tile(i, builder, sh_mem_ptr);
|
||||
else
|
||||
create_distributed_tile(v, builder);
|
||||
@@ -1018,7 +1004,7 @@ void selection::lower_copy_to_shared(ir::copy_to_shared_inst *x, LLVMContext &ct
|
||||
distributed_tile* in = (distributed_tile*)tmap_.at(arg);
|
||||
if(x_order == arg_order){
|
||||
size_t ld = arg_order[0];
|
||||
vector_size = std::min(layouts_->get(x)->nts.at(ld), layouts_->get(arg)->nts.at(ld));
|
||||
vector_size = layouts_->get(arg)->nts.at(ld);
|
||||
}
|
||||
|
||||
std::map<unsigned, Value*> packets;
|
||||
@@ -1038,6 +1024,15 @@ void selection::lower_copy_to_shared(ir::copy_to_shared_inst *x, LLVMContext &ct
|
||||
});
|
||||
}
|
||||
|
||||
// Lowers a copy_from_shared instruction: for every index of the tile
// registered for `x`, the value found at the same index in the tile
// registered for x's operand 0 is stored into the result tile.
// `ctx`, `fn` and `builder` are not used by this function.
void selection::lower_copy_from_shared(ir::copy_from_shared_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
|
||||
// Both tiles must already be present in tmap_. The casts are unchecked:
// the result is treated as a distributed tile and the operand as a shared
// tile.
distributed_tile* result = (distributed_tile*)tmap_.at(x);
|
||||
shared_tile* arg = (shared_tile*)tmap_.at(x->get_operand(0));
|
||||
|
||||
result->for_each([&](indices_t idx){
|
||||
result->set_value(idx, arg->get_value(idx));
|
||||
});
|
||||
}
|
||||
|
||||
void selection::lower_trans(ir::trans_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
|
||||
shared_tile* in = (shared_tile*)tmap_.at(x->get_operand(0));
|
||||
shared_tile* out = new shared_tile(in->get_ty(), in->get_shapes(), in->get_order(), in->get_pointer(), builder, in->get_offset(), x->get_perm());
|
||||
@@ -1399,6 +1394,8 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
lower_broadcast(x, ctx, fn, builder);
|
||||
else if(auto *x = dynamic_cast<ir::copy_to_shared_inst*>(ins))
|
||||
lower_copy_to_shared(x, ctx, fn, builder);
|
||||
else if(auto *x = dynamic_cast<ir::copy_from_shared_inst*>(ins))
|
||||
lower_copy_from_shared(x, ctx, fn, builder);
|
||||
else if(auto* x = dynamic_cast<ir::trans_inst*>(ins))
|
||||
lower_trans(x, ctx, fn, builder);
|
||||
else if(auto x = dynamic_cast<ir::dot_inst*>(ins))
|
||||
@@ -1554,7 +1551,7 @@ void selection::run(ir::module &src, Module &dst) {
|
||||
}
|
||||
else {
|
||||
unsigned num_bytes = inst->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
|
||||
offset->addIncoming(dst_builder.getInt32(liveness_->get_buffer(inst)->size / (2*num_bytes)), llvm_inc_block);
|
||||
offset->addIncoming(dst_builder.getInt32(layouts_->get(inst)->size / (2*num_bytes)), llvm_inc_block);
|
||||
}
|
||||
ptr->addIncoming(inc_shared->get_pointer(), llvm_inc_block);
|
||||
}
|
||||
|
@@ -12,30 +12,45 @@ namespace triton {
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
// Returns true when `v` is an ir::instruction whose storage_info entry has
// codegen::SHARED as its first field. Returns false when `v` is not an
// instruction.
// The lookup uses .at(), so an instruction id without an entry in
// storage_info is not tolerated.
inline bool is_shared(ir::value *v) {
|
||||
auto *i = dynamic_cast<ir::instruction*>(v);
|
||||
if(!i)
|
||||
return false;
|
||||
return storage_info.at(i->get_id()).first == codegen::SHARED;
|
||||
}
|
||||
|
||||
// run pass on module
|
||||
void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder) {
|
||||
void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared) {
|
||||
auto *i = dynamic_cast<ir::instruction*>(x);
|
||||
// not an instruction
|
||||
if(!i) {
|
||||
builder.set_insert_point(parent);
|
||||
ir::value *cts = builder.create_copy_to_shared(x);
|
||||
parent->replace_uses_of_with(x, cts);
|
||||
ir::value *copy;
|
||||
if(to_shared)
|
||||
copy = builder.create_copy_to_shared(x);
|
||||
else
|
||||
copy = builder.create_copy_from_shared(x);
|
||||
parent->replace_uses_of_with(x, copy);
|
||||
return;
|
||||
}
|
||||
// phi node
|
||||
if(auto* phi = dynamic_cast<ir::phi_node*>(x)) {
|
||||
for(unsigned i = 0; i < phi->get_num_incoming(); ++i)
|
||||
add_copy(phi, phi->get_incoming_value(i), builder);
|
||||
add_copy(phi, phi->get_incoming_value(i), builder, to_shared);
|
||||
return;
|
||||
}
|
||||
ir::value_id_t id = i->get_id();
|
||||
// already in shared memory
|
||||
if(storage_info.at(id).first == SHARED)
|
||||
if(to_shared && storage_info.at(id).first == SHARED)
|
||||
return;
|
||||
// copy
|
||||
builder.set_insert_point_after(i);
|
||||
ir::value *cts = builder.create_copy_to_shared(x);
|
||||
parent->replace_uses_of_with(x, cts);
|
||||
ir::value *copy;
|
||||
if(to_shared)
|
||||
copy = builder.create_copy_to_shared(x);
|
||||
else
|
||||
copy = builder.create_copy_from_shared(x);
|
||||
parent->replace_uses_of_with(x, copy);
|
||||
}
|
||||
|
||||
void cts::run(ir::module &mod) {
|
||||
@@ -45,10 +60,16 @@ void cts::run(ir::module &mod) {
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
auto storage = storage_info.at(i->get_id());
|
||||
// copy to shared operands when necessary
|
||||
// copy to shared operands
|
||||
for(size_t k = 0; k < storage.second.size(); k++)
|
||||
if(storage.second[k] == SHARED)
|
||||
add_copy(i, i->get_operand(k), builder);
|
||||
add_copy(i, i->get_operand(k), builder, true);
|
||||
// copy from shared operands
|
||||
for(size_t k = 0; k < storage.second.size(); k++)
|
||||
if(storage.second[k] == DISTRIBUTED &&
|
||||
is_shared(i->get_operand(k))){
|
||||
add_copy(i, i->get_operand(k), builder, false);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -3,6 +3,7 @@
|
||||
#include <algorithm>
|
||||
|
||||
#include "triton/codegen/analysis/liveness.h"
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
#include "triton/codegen/analysis/allocation.h"
|
||||
#include "triton/codegen/instructions.h"
|
||||
#include "triton/codegen/transform/membar.h"
|
||||
@@ -38,7 +39,7 @@ void membar::add_reference(ir::value *v, interval_vec_t &res){
|
||||
return;
|
||||
if(alloc_->has_offset(v)){
|
||||
unsigned offset = alloc_->offset(v);
|
||||
unsigned size = liveness_->get_buffer(v)->size;
|
||||
unsigned size = layouts_->get(v)->size;
|
||||
res.push_back(interval_t(offset, offset + size));
|
||||
}
|
||||
}
|
||||
|
@@ -241,7 +241,7 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
|
||||
cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { }
|
||||
|
||||
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
|
||||
// std::cout << source << std::endl;
|
||||
std::cout << source << std::endl;
|
||||
cu_context::context_switcher ctx(*context);
|
||||
// JIT compile source-code
|
||||
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
|
||||
|
@@ -347,6 +347,10 @@ value *builder::create_copy_to_shared(value *arg, const std::string &name) {
|
||||
return insert(copy_to_shared_inst::create(arg, name));
|
||||
}
|
||||
|
||||
// Creates a copy_from_shared instruction for `arg` with the given name,
// passes it to insert(), and returns the result of insert().
value *builder::create_copy_from_shared(value *arg, const std::string &name) {
|
||||
return insert(copy_from_shared_inst::create(arg, name));
|
||||
}
|
||||
|
||||
value *builder::create_barrier(const std::string &name) {
|
||||
return insert(barrier_inst::create(ctx_, name));
|
||||
}
|
||||
|
@@ -731,12 +731,20 @@ instruction* atomic_add_inst::create(value *ptr, value *val, const std::string &
|
||||
//===----------------------------------------------------------------------===//
|
||||
// intrinsic instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// copy to shared
|
||||
copy_to_shared_inst* copy_to_shared_inst::create(value *arg, const std::string &name,
|
||||
instruction *next) {
|
||||
return new copy_to_shared_inst(arg->get_type(), INST_COPY_TO_SHARED, arg, name, next);
|
||||
}
|
||||
|
||||
// copy from shared
|
||||
// Factory for copy_from_shared_inst. The new instruction has the same type
// as `arg`, the id INST_COPY_FROM_SHARED, and `arg` as its operand. `name`
// and `next` are forwarded unchanged to the constructor.
// The object is allocated with `new`.
// NOTE(review): ownership of the returned pointer is not visible here -
// confirm who releases it.
copy_from_shared_inst* copy_from_shared_inst::create(value *arg, const std::string &name,
|
||||
instruction *next) {
|
||||
return new copy_from_shared_inst(arg->get_type(), INST_COPY_FROM_SHARED, arg, name, next);
|
||||
}
|
||||
|
||||
|
||||
// barrier
|
||||
barrier_inst::barrier_inst(context &ctx, const std::string &name,
|
||||
instruction *next)
|
||||
|
@@ -211,7 +211,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
codegen::analysis::layout layouts(&axes, &align, opt.num_warps);
|
||||
codegen::analysis::liveness liveness(&layouts);
|
||||
codegen::analysis::allocation allocation(&liveness);
|
||||
codegen::transform::membar barriers(&liveness, &allocation);
|
||||
codegen::transform::membar barriers(&liveness, &layouts, &allocation);
|
||||
codegen::transform::dce dce;
|
||||
codegen::transform::peephole peephole;
|
||||
codegen::transform::reassociate reassociate(&align);
|
||||
@@ -230,11 +230,11 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
align.run(module);
|
||||
dce.run(module);
|
||||
reassociate.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
dce.run(module);
|
||||
cts.run(module);
|
||||
align.run(module);
|
||||
axes.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
layouts.run(module);
|
||||
liveness.run(module);
|
||||
allocation.run(module);
|
||||
@@ -245,6 +245,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
align.run(module);
|
||||
axes.run(module);
|
||||
layouts.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
selection.run(module, *llvm);
|
||||
// return binary
|
||||
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
|
||||
|
Reference in New Issue
Block a user