[GENERAL] Cleaned polymorphic structure of layouts analysis pass

Philippe Tillet
2020-01-20 15:15:32 -05:00
parent 382ca2c745
commit 78b98fb7cf
17 changed files with 500 additions and 480 deletions
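The layouts analysis pass previously exposed a single tagged struct (layout_t with a public layout_type_t field), public data members, and free helper functions. This commit replaces that with a proper hierarchy: an abstract data_layout base, mma884_layout / scanline_layout / shared_layout subclasses, private members behind accessors, and nullptr-returning to_mma884() / to_scanline() / to_shared() downcasts. The pass itself is renamed layout -> layouts, the machine counterparts become machine_data_layout and friends, and the ir::type* parameter is dropped from the distributed-layout constructors. In client code, enum-tag checks become null-checked downcasts; a condensed before/after taken from the generator changes below:

// before
if(layouts_->get(dot)->type == analysis::HMMA_884)
  visit_hmma_dot(dot, TA, TB, TD, NK);
// after
if(layouts_->get(dot)->to_mma884())
  visit_hmma_dot(dot, TA, TB, TD, NK);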


@@ -27,14 +27,14 @@ public:
allocation(liveness *live)
: liveness_(live) { }
// accessors
bool has_offset(const layout_t *x) const { return offsets_.find(x) != offsets_.end(); }
unsigned offset(const layout_t *x) const { return offsets_.at(x); }
bool has_offset(const data_layout *x) const { return offsets_.find(x) != offsets_.end(); }
unsigned offset(const data_layout *x) const { return offsets_.at(x); }
unsigned allocated_size() const { return allocated_size_; }
// run
void run(ir::module& mod);
private:
std::map<const layout_t*, unsigned> offsets_;
std::map<const data_layout*, unsigned> offsets_;
size_t allocated_size_;
// dependences
liveness *liveness_;


@@ -22,103 +22,141 @@ namespace analysis{
class axes;
class align;
class layout_visitor;
class data_layout;
class mma884_layout;
class scanline_layout;
class shared_layout;
enum layout_type_t { HMMA_884, SCANLINE, SHARED };
class layout_visitor {
public:
virtual void visit_layout(data_layout *);
virtual void visit_layout_hmma_884(mma884_layout*) = 0;
virtual void visit_layout_scanline(scanline_layout*) = 0;
virtual void visit_layout_shared(shared_layout*) = 0;
};
class data_layout {
protected:
enum id_t {
HMMA_884,
SCANLINE,
SHARED
};
typedef std::vector<int> axes_t;
typedef std::vector<unsigned> shape_t;
typedef std::vector<int> order_t;
typedef std::vector<ir::value*> values_t;
private:
template<typename T>
T* downcast(id_t id) {
if(id_ == id)
return static_cast<T*>(this);
return nullptr;
}
public:
data_layout(id_t id,
const std::vector<int>& axes,
const std::vector<unsigned> &shape,
const std::vector<ir::value *> &values,
analysis::align* align);
// visitor
virtual void accept(layout_visitor* vst) = 0;
// downcast
mma884_layout* to_mma884() { return downcast<mma884_layout>(HMMA_884); }
scanline_layout* to_scanline() { return downcast<scanline_layout>(SCANLINE); }
shared_layout* to_shared() { return downcast<shared_layout>(SHARED); }
// accessors
size_t get_rank() { return shape_.size(); }
const shape_t& get_shape() const { return shape_; }
const order_t& get_order() const { return order_; }
const values_t& get_values() const { return values_;}
int get_axis(size_t k) const { return axes_.at(k); }
const int get_order(size_t k) const { return order_.at(k); }
// find the position of given axis
size_t find_axis(int to_find) const;
private:
id_t id_;
axes_t axes_;
values_t values_;
protected:
order_t order_;
shape_t shape_;
};
class mma884_layout: public data_layout {
public:
mma884_layout(size_t num_warps,
const std::vector<int>& axes,
const std::vector<unsigned>& shapes,
const std::vector<ir::value *> &values,
analysis::align* align);
void accept(layout_visitor* vst) { vst->visit_layout_hmma_884(this); }
// accessor
int fpw(size_t k) { return fpw_.at(k); }
int wpt(size_t k) { return wpt_.at(k); }
private:
std::vector<int> fpw_;
std::vector<int> wpt_;
};
struct scanline_layout: public data_layout {
scanline_layout(size_t num_warps,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values,
analysis::align* align);
void accept(layout_visitor* vst) { vst->visit_layout_scanline(this); }
// accessor
int mts(size_t k) { return mts_.at(k); }
int nts(size_t k) { return nts_.at(k); }
private:
std::vector<int> mts_;
std::vector<int> nts_;
};
struct double_buffer_info_t {
ir::value* first;
ir::value* latch;
ir::phi_node* phi;
};
class layout_visitor;
class layout_t;
class layout_hmma_884_t;
class layout_scanline_t;
class layout_shared_t;
class shared_layout: public data_layout {
private:
static bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator);
static void extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_t>& res);
class layout_visitor {
public:
virtual void visit_layout(layout_t *);
virtual void visit_layout_hmma_884(layout_hmma_884_t*) = 0;
virtual void visit_layout_scanline(layout_scanline_t*) = 0;
virtual void visit_layout_shared(layout_shared_t*) = 0;
};
class layout_hmma_884_t;
class layout_scanline_t;
class layout_shared_t;
struct layout_t {
layout_t(layout_type_t _type,
const std::vector<int>& _axes,
const std::vector<unsigned> &_shapes,
const std::vector<ir::value *> &_values,
ir::type *_ty,
analysis::align* align);
// visitor
virtual void accept(layout_visitor* vst) = 0;
// downcast
layout_hmma_884_t* to_hmma884();
layout_scanline_t* to_scanline();
layout_shared_t* to_shared();
layout_type_t type;
std::vector<int> axes;
std::vector<unsigned> shapes;
std::vector<ir::value*> values;
std::vector<int> order;
ir::type *ty;
};
struct layout_hmma_884_t: public layout_t {
layout_hmma_884_t(size_t num_warps,
const std::vector<int>& _axes,
const std::vector<unsigned>& _shapes,
const std::vector<ir::value *> &_values,
ir::type *_ty,
analysis::align* align);
void accept(layout_visitor* vst) { vst->visit_layout_hmma_884(this); }
std::vector<int> fpw;
std::vector<int> wpt;
};
struct layout_scanline_t: public layout_t {
layout_scanline_t(size_t num_warps,
const std::vector<int>& _axes,
const std::vector<unsigned>& _shapes,
const std::vector<ir::value *> &values,
ir::type *_ty,
analysis::align* align);
void accept(layout_visitor* vst) { vst->visit_layout_scanline(this); }
std::vector<int> mts;
std::vector<int> nts;
};
struct layout_shared_t: public layout_t {
layout_shared_t(const layout_t *arg,
const std::vector<int>& _axes,
const std::vector<unsigned>& _shapes,
const std::vector<ir::value *> &values,
shared_layout(const data_layout *arg,
const std::vector<int>& axes,
const std::vector<unsigned>& shapes,
const std::vector<ir::value *> &values_,
ir::type *ty,
analysis::align* align);
void accept(layout_visitor* vst) { vst->visit_layout_shared(this); }
// accessors
size_t get_size() { return size_; }
ir::type* get_type() { return ty_; }
double_buffer_info_t* get_double_buffer() { return double_buffer_.get(); }
std::shared_ptr<double_buffer_info_t> double_buffer;
size_t size;
private:
size_t size_;
ir::type *ty_;
std::shared_ptr<double_buffer_info_t> double_buffer_;
};
class layout {
class layouts {
typedef ir::value* node_t;
typedef std::map <node_t, std::set<node_t>> graph_t;
@@ -127,23 +165,23 @@ private:
void connect(ir::value *x, ir::value *y);
void make_graph(ir::instruction *i);
void init_hmma_tile(layout_t& layout);
void init_scanline_tile(layout_t &layout);
void init_hmma_tile(data_layout& layouts);
void init_scanline_tile(data_layout &layouts);
void create(size_t id, const std::vector<ir::value*>& values);
public:
// constructor
layout(analysis::axes *axes, analysis::align *align, size_t num_warps);
layouts(analysis::axes *axes, analysis::align *align, size_t num_warps);
// accessors
unsigned layout_of(ir::value *value) const;
const std::vector<ir::value*>& values_of(unsigned id) const;
size_t num_layouts() const;
layout_t* get(size_t id);
layout_t* get(ir::value *v);
std::map<size_t, layout_t*> &get_all();
size_t tmp(ir::instruction* i);
unsigned layout_of(ir::value *value) const { return groups_.at(value); }
const std::vector<ir::value*>& values_of(unsigned id) const { return values_.at(id); }
size_t num_layouts() const { return values_.size();}
data_layout* get(size_t id) { return layouts_.at(id); }
data_layout* get(ir::value *v) { return get(layout_of(v));}
std::map<size_t, data_layout*> &get_all() { return layouts_; }
size_t tmp(ir::instruction* i) { return tmp_.at((ir::value*)i);}
// execution
void run(ir::module &mod);
@@ -155,7 +193,7 @@ private:
tools::graph<ir::value*> graph_;
std::map<ir::value*, size_t> groups_;
std::map<size_t, std::vector<ir::value*>> values_;
std::map<size_t, layout_t*> layouts_;
std::map<size_t, data_layout*> layouts_;
std::map<ir::value*, size_t> tmp_;
};
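The id-tagged downcast introduced above deserves a note: each to_xxx() accessor is an RTTI-free alternative to dynamic_cast that returns nullptr on a tag mismatch, so callers can branch with a plain null check instead of comparing the old public type field. A minimal standalone sketch of the pattern, with hypothetical names (circle and square stand in for the layout subclasses):

#include <cstdio>

class circle;
class square;

class shape {
protected:
  enum id_t { CIRCLE, SQUARE };
  shape(id_t id): id_(id) { }
public:
  virtual ~shape() { }
  circle* to_circle();   // nullptr unless the object really is a circle
  square* to_square();   // nullptr unless the object really is a square
private:
  template<typename T>
  T* downcast(id_t id) { return id_ == id ? static_cast<T*>(this) : nullptr; }
  id_t id_;
};

class circle: public shape { public: circle(): shape(CIRCLE) { } };
class square: public shape { public: square(): shape(SQUARE) { } };

// defined after the derived classes so static_cast sees complete types
circle* shape::to_circle() { return downcast<circle>(CIRCLE); }
square* shape::to_square() { return downcast<square>(SQUARE); }

int main() {
  square s;
  shape *p = &s;
  if(p->to_square()) std::printf("square\n"); // taken
  if(p->to_circle()) std::printf("circle\n"); // not taken
}

The same shape recurs throughout the commit, e.g. liveness below now iterates with shared_layout* layout = x.second->to_shared(); if(!layout) continue; instead of checking x.second->type != SHARED.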


@@ -23,8 +23,8 @@ namespace analysis{
typedef unsigned slot_index;
class tiles;
class layout;
class layout_t;
class layouts;
class data_layout;
struct segment {
slot_index start;
@@ -42,20 +42,20 @@ struct segment {
class liveness {
private:
typedef std::map<layout_shared_t*, segment> intervals_map_t;
typedef std::map<shared_layout*, segment> intervals_map_t;
public:
// constructor
liveness(layout *l): layouts_(l){ }
liveness(layouts *l): layouts_(l){ }
// accessors
const intervals_map_t& get() const { return intervals_; }
segment get(layout_shared_t* v) const { return intervals_.at(v); }
segment get(shared_layout* v) const { return intervals_.at(v); }
// run
void run(ir::module &mod);
private:
// analysis
layout *layouts_;
layouts *layouts_;
intervals_map_t intervals_;
};


@@ -35,7 +35,7 @@ class align;
class allocation;
class cts;
class axes;
class layout;
class layouts;
}
// typedef
typedef llvm::IRBuilder<llvm::ConstantFolder,
@@ -50,7 +50,7 @@ typedef llvm::ArrayType ArrayType;
typedef llvm::Function Function;
typedef std::vector<Value*> indices_t;
// forward
class machine_layout_t;
class machine_data_layout;
class tile;
class shared_tile;
class distributed_tile;
@@ -74,13 +74,13 @@ private:
void visit_outer_dot(ir::dot_inst*, distributed_tile *TA, distributed_tile *TB, distributed_tile *TD, unsigned NK,
Type *c_ty, Function *f_mul_add);
void finalize_shared_layout(analysis::layout_shared_t*);
void finalize_shared_layout(analysis::shared_layout*);
void finalize_function(ir::function*);
void finalize_phi_node(ir::phi_node*);
public:
generator(analysis::axes *a_axes,
analysis::layout *layouts,
analysis::layouts *layouts,
analysis::align *alignment,
analysis::allocation *alloc,
target *tgt,
@@ -139,9 +139,9 @@ public:
void visit_basic_block(ir::basic_block*);
void visit_argument(ir::argument*);
void visit_layout_hmma_884(analysis::layout_hmma_884_t*);
void visit_layout_scanline(analysis::layout_scanline_t*);
void visit_layout_shared(analysis::layout_shared_t*);
void visit_layout_hmma_884(analysis::mma884_layout*);
void visit_layout_scanline(analysis::scanline_layout*);
void visit_layout_shared(analysis::shared_layout*);
void visit(ir::module &, llvm::Module &);
@@ -150,13 +150,13 @@ private:
Builder* builder_;
Module *mod_;
std::map<const analysis::layout_t*, machine_layout_t*> machine_layouts_;
std::map<const analysis::data_layout*, machine_data_layout*> machine_layouts_;
analysis::axes *a_axes_;
std::map<unsigned, distributed_axis> axes_;
std::map<ir::value *, Value *> vmap_;
std::map<ir::value *, tile *> tmap_;
target *tgt_;
analysis::layout *layouts_;
analysis::layouts *layouts_;
analysis::align *alignment_;
analysis::allocation *alloc_;
Value *sh_mem_ptr_;


@@ -36,7 +36,7 @@ class align;
class allocation;
class cts;
class axes;
class layout;
class layouts;
}
typedef llvm::IRBuilder<llvm::ConstantFolder,
@@ -51,7 +51,7 @@ typedef llvm::ArrayType ArrayType;
typedef llvm::Function Function;
class distributed_axis;
class machine_layout_t;
class machine_data_layout;
class tile;
class shared_tile;
class distributed_tile;
@@ -64,15 +64,15 @@ namespace triton{
namespace codegen{
class machine_layout_t {
class machine_data_layout {
public:
virtual tile* create(ir::value *v) = 0;
};
class machine_layout_shared_t: public machine_layout_t {
class machine_shared_layout: public machine_data_layout {
public:
machine_layout_shared_t(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc, Value *&sh_mem_ptr,
analysis::layout_shared_t* layout,
machine_shared_layout(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc, Value *&sh_mem_ptr,
analysis::shared_layout* layout,
std::map<ir::value *, Value *>& vmap,
std::map<ir::value *, tile *>& tmap);
@@ -83,7 +83,7 @@ public:
target *tgt_;
analysis::allocation* alloc_;
Value *&sh_mem_ptr_;
analysis::layout_shared_t* layout_;
analysis::shared_layout* layout_;
std::map<ir::value *, Value *>& vmap_;
std::map<ir::value *, tile *>& tmap_;
@@ -94,29 +94,28 @@ public:
};
class machine_layout_distributed_t: public machine_layout_t {
class machine_distributed_layout: public machine_data_layout {
public:
machine_layout_distributed_t(Module *mod, Builder *builder, target *tgt, Type *ty,
machine_distributed_layout(Module *mod, Builder *builder, target *tgt,
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
analysis::layout_t* layout);
analysis::data_layout* layout);
tile* create(ir::value *v);
Module *mod_;
Builder *builder_;
target *tgt_;
Type *ty_;
analysis::axes *a_axes_;
std::map<unsigned, distributed_axis>& axes_;
analysis::layout_t* layout_;
analysis::data_layout* layout_;
};
class machine_layout_hmma_884_t: public machine_layout_distributed_t {
class machine_mma884_layout: public machine_distributed_layout {
public:
machine_layout_hmma_884_t(Module *mod, Builder *builder,
target *tgt, Type *ty,
machine_mma884_layout(Module *mod, Builder *builder,
target *tgt,
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
analysis::layout_hmma_884_t* layout);
analysis::mma884_layout* layout);
Value *offset_a_i_, *offset_a_k_;
Value *offset_b_j_, *offset_b_k_;
unsigned pack_size_0_;
@@ -125,12 +124,12 @@ public:
unsigned num_packs_1_;
};
class machine_layout_scanline_t: public machine_layout_distributed_t {
class machine_scanline_layout: public machine_distributed_layout {
public:
machine_layout_scanline_t(Module *mod, Builder *builder,
target *tgt, Type *ty,
machine_scanline_layout(Module *mod, Builder *builder,
target *tgt,
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
analysis::layout_scanline_t* layout);
analysis::scanline_layout* layout);
};
}


@@ -47,11 +47,11 @@ class align;
class allocation;
class cts;
class axes;
class layout;
class layouts;
}
class distributed_axis;
class machine_layout_t;
class machine_data_layout;
class tile;
class shared_tile;
class distributed_tile;


@@ -19,7 +19,7 @@ namespace codegen{
namespace analysis{
class align;
class layout;
class layouts;
class cts;
}
@@ -32,12 +32,12 @@ private:
ir::value* rematerialize(ir::value *v, ir::builder& builder, std::map<ir::value*, ir::value*>& seen);
public:
coalesce(analysis::align* align, triton::codegen::analysis::layout *layouts);
coalesce(analysis::align* align, triton::codegen::analysis::layouts *layouts);
void run(ir::module &mod);
private:
analysis::align* align_;
analysis::layout* layout_;
analysis::layouts* layout_;
};
}


@@ -17,7 +17,7 @@ namespace analysis{
class allocation;
class liveness;
class layout;
class layouts;
class cts;
}
@@ -41,13 +41,13 @@ private:
std::set<ir::instruction *> &insert_loc, std::set<triton::ir::value *> &safe_war);
public:
membar(analysis::liveness *liveness, analysis::layout *layouts, analysis::allocation *alloc):
membar(analysis::liveness *liveness, analysis::layouts *layouts, analysis::allocation *alloc):
liveness_(liveness), layouts_(layouts), alloc_(alloc) {}
void run(ir::module &mod);
private:
analysis::liveness *liveness_;
analysis::layout *layouts_;
analysis::layouts *layouts_;
analysis::allocation *alloc_;
};


@@ -15,22 +15,22 @@ void allocation::run(ir::module &mod) {
using std::min;
typedef std::multimap<unsigned, segment> triples_map_type;
std::vector<layout_shared_t*> I;
std::vector<shared_layout*> I;
for(auto x: liveness_->get())
I.push_back(x.first);
std::vector<layout_shared_t*> J = I;
std::vector<shared_layout*> J = I;
triples_map_type H;
H.insert({0, segment{0, INT_MAX}});
std::vector<layout_shared_t*> V;
std::map<layout_shared_t*, unsigned> starts;
std::vector<shared_layout*> V;
std::map<shared_layout*, unsigned> starts;
while(!J.empty()){
auto h_it = H.begin();
unsigned w = h_it->first;
segment xh = h_it->second;
H.erase(h_it);
auto j_it = std::find_if(J.begin(), J.end(), [&](layout_shared_t* JJ){
auto j_it = std::find_if(J.begin(), J.end(), [&](shared_layout* JJ){
segment xj = liveness_->get(JJ);
bool res = xj.intersect(xh);
for(auto val: H)
@@ -38,7 +38,7 @@ void allocation::run(ir::module &mod) {
return res;
});
if(j_it != J.end()){
unsigned size = (*j_it)->size;
unsigned size = (*j_it)->get_size();
segment xj = liveness_->get(*j_it);
starts[*j_it] = w;
H.insert({w + size, segment{max(xh.start, xj.start), min(xh.end, xj.end)}});
@@ -52,14 +52,14 @@ void allocation::run(ir::module &mod) {
}
// Build interference graph
std::map<layout_shared_t*, std::set<layout_shared_t*>> interferences;
for(layout_shared_t* x: V)
for(layout_shared_t* y: V){
std::map<shared_layout*, std::set<shared_layout*>> interferences;
for(shared_layout* x: V)
for(shared_layout* y: V){
if(x == y)
continue;
unsigned X0 = starts[x], Y0 = starts[y];
unsigned NX = x->size;
unsigned NY = y->size;
unsigned NX = x->get_size();
unsigned NY = y->get_size();
segment XS = {X0, X0 + NX};
segment YS = {Y0, Y0 + NY};
if(liveness_->get(x).intersect(liveness_->get(y))
@@ -68,17 +68,17 @@ void allocation::run(ir::module &mod) {
}
// Initialize colors
std::map<layout_shared_t*, int> colors;
for(layout_shared_t* X: V)
std::map<shared_layout*, int> colors;
for(shared_layout* X: V)
colors[X] = (X==V[0])?0:-1;
// First-fit graph coloring
std::vector<bool> available(V.size());
for(layout_shared_t* x: V){
for(shared_layout* x: V){
// Non-neighboring colors are available
std::fill(available.begin(), available.end(), true);
for(layout_shared_t* Y: interferences[x]){
for(shared_layout* Y: interferences[x]){
int color = colors[Y];
if(color >= 0)
available[color] = false;
@@ -89,17 +89,17 @@ void allocation::run(ir::module &mod) {
}
// Finalize allocation
for(layout_shared_t* x: V){
for(shared_layout* x: V){
unsigned Adj = 0;
for(layout_shared_t* y: interferences[x])
Adj = std::max<unsigned>(Adj, starts[y] + y->size);
for(shared_layout* y: interferences[x])
Adj = std::max<unsigned>(Adj, starts[y] + y->get_size());
offsets_[x] = starts[x] + colors[x] * Adj;
}
// Save maximum size of induced memory space
allocated_size_ = 0;
for(layout_shared_t* x: V)
allocated_size_ = std::max<size_t>(allocated_size_, starts[x] + x->size);
for(shared_layout* x: V)
allocated_size_ = std::max<size_t>(allocated_size_, starts[x] + x->get_size());
}
}
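For reference, the coloring step above is plain first-fit greedy coloring of the interference graph; here is a toy sketch on a hand-built adjacency list (hypothetical input, not the pass itself, which colors shared_layout nodes keyed by live-range and offset-range overlap):

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  // toy interference graph: node 0 conflicts with nodes 1 and 2
  std::vector<std::vector<int>> adj = {{1, 2}, {0}, {0}};
  std::vector<int> color(adj.size(), -1);
  std::vector<bool> available(adj.size());
  for(size_t x = 0; x < adj.size(); x++){
    // colors not used by any neighbor are available
    std::fill(available.begin(), available.end(), true);
    for(int y: adj[x])
      if(color[y] >= 0)
        available[color[y]] = false;
    // take the smallest available color
    color[x] = std::find(available.begin(), available.end(), true) - available.begin();
  }
  std::printf("%d %d %d\n", color[0], color[1], color[2]); // 0 1 1
}

Nodes 1 and 2 never interfere, so they receive the same color and may overlap in the allocation, which is exactly how non-conflicting shared buffers get packed.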


@@ -12,57 +12,15 @@ namespace triton{
namespace codegen{
namespace analysis{
/* -------------------------------- *
* Helper Functions *
* -------------------------------- */
// constructor
layout::layout(analysis::axes *axes, analysis::align *align, size_t num_warps)
: axes_(axes), align_(align), num_warps_(num_warps) { }
// get group id
unsigned layout::layout_of(ir::value *value) const
{ return groups_.at(value); }
// get values
const std::vector<ir::value*>& layout::values_of(unsigned id) const
{ return values_.at(id); }
// get number of groups
size_t layout::num_layouts() const
{ return values_.size(); }
// connect two values
void layout::connect(ir::value *x, ir::value *y) {
if(x == y)
return;
if(!x->get_type()->is_tile_ty())
return;
if(!y->get_type()->is_tile_ty())
return;
std::vector<int> x_axes = axes_->get(x);
std::vector<int> y_axes = axes_->get(y);
std::set<int> sx_axes(x_axes.begin(), x_axes.end());
std::set<int> sy_axes(y_axes.begin(), y_axes.end());
std::set<int> common;
std::set_intersection(sx_axes.begin(), sx_axes.end(),
sy_axes.begin(), sy_axes.end(),
std::inserter(common, common.begin()));
graph_.add_edge(x, x);
graph_.add_edge(y, y);
if(!common.empty())
graph_.add_edge(x, y);
inline unsigned clamp(unsigned x, unsigned lo, unsigned hi) {
return std::min(std::max(x, lo), hi);
}
// make graph
void layout::make_graph(ir::instruction *i) {
for(ir::value* opx: i->ops())
for(ir::value* opy: i->ops()){
connect(i, opx);
connect(opx, opy);
}
}
// hmma
bool is_hmma_c(ir::value *v){
inline bool is_hmma_c(ir::value *v){
bool result = false;
if(auto *x = dynamic_cast<ir::dot_inst*>(v)){
ir::value *a = x->get_operand(0);
@@ -75,23 +33,7 @@ bool is_hmma_c(ir::value *v){
return result;
}
layout_t* layout::get(size_t id) {
return layouts_.at(id);
}
layout_t* layout::get(ir::value *v) {
return layouts_.at(groups_.at(v));
}
std::map<size_t, layout_t*>& layout::get_all() {
return layouts_;
}
size_t layout::tmp(ir::instruction* i) {
return tmp_.at(i);
}
void extract_io_use(ir::value *v, std::set<ir::value*>& result) {
inline void extract_io_use(ir::value *v, std::set<ir::value*>& result) {
for(ir::user* u: v->get_users()){
auto i = dynamic_cast<ir::io_inst*>(u);
if(i && i->get_pointer_operand() == v)
@@ -99,7 +41,7 @@ void extract_io_use(ir::value *v, std::set<ir::value*>& result) {
}
}
void extract_dot_use(ir::value *v, ir::value*& result, size_t n) {
inline void extract_dot_use(ir::value *v, ir::value*& result, size_t n) {
for(ir::user* u: v->get_users()){
auto i = dynamic_cast<ir::dot_inst*>(u);
if(i && i->get_operand(n) == v)
@@ -107,7 +49,7 @@ void extract_dot_use(ir::value *v, ir::value*& result, size_t n) {
}
}
void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n) {
inline void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n) {
for(ir::user* u: v->get_users()){
auto i = dynamic_cast<ir::dot_inst*>(u);
if(i && is_hmma_c(i) && i->get_operand(n) == v)
@@ -116,7 +58,6 @@ void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n) {
}
inline bool is_trans(ir::value *v) {
if(dynamic_cast<ir::trans_inst *>(v)) {
return true;
@@ -131,104 +72,103 @@ inline bool is_trans(ir::value *v) {
}
void layout_visitor::visit_layout(layout_t *layout) {
/* -------------------------------- *
* Layout Visitor *
* -------------------------------- */
void layout_visitor::visit_layout(data_layout *layout) {
layout->accept(this);
}
layout_t::layout_t(layout_type_t _type,
const std::vector<int> &_axes,
const std::vector<unsigned> &_shapes,
const std::vector<ir::value *> &_values, ir::type *_ty,
analysis::align* align): type(_type), axes(_axes), shapes(_shapes), values(_values), ty(_ty) {
/* -------------------------------- *
* Base Data Layout *
* -------------------------------- */
data_layout::data_layout(id_t id,
const std::vector<int> &axes,
const std::vector<unsigned> &shape,
const std::vector<ir::value *> &values,
analysis::align* align): id_(id), axes_(axes), shape_(shape), values_(values) {
// io pointer
std::set<ir::value*> ptr;
for(ir::value* v: values)
for(ir::value* v: values_)
extract_io_use(v, ptr);
order.resize(axes.size());
std::iota(order.begin(), order.end(), 0);
order_.resize(axes_.size());
std::iota(order_.begin(), order_.end(), 0);
auto largest = std::max_element(ptr.begin(), ptr.end(), [&](ir::value *x, ir::value *y){
return x->get_type()->get_tile_rank() < y->get_type()->get_tile_rank();
});
if(*largest){
auto max_contiguous = align->contiguous(*largest);
std::sort(order.begin(), order.end(), [&](unsigned a, unsigned b) {
std::sort(order_.begin(), order_.end(), [&](unsigned a, unsigned b) {
return max_contiguous[a] > max_contiguous[b];
});
}
}
// downcast
layout_hmma_884_t* layout_t::to_hmma884() {
assert(type == HMMA_884);
return static_cast<layout_hmma_884_t*>(this);
size_t data_layout::find_axis(int to_find) const {
auto it = std::find(axes_.begin(), axes_.end(), to_find);
return std::distance(axes_.begin(), it);
}
layout_scanline_t* layout_t::to_scanline() {
assert(type == SCANLINE);
return static_cast<layout_scanline_t*>(this);
}
layout_shared_t* layout_t::to_shared() {
assert(type == SHARED);
return static_cast<layout_shared_t*>(this);
}
/* -------------------------------- *
* MMA Layout *
* -------------------------------- */
inline unsigned clamp(unsigned x, unsigned lo, unsigned hi) {
return std::min(std::max(x, lo), hi);
}
layout_hmma_884_t::layout_hmma_884_t(size_t num_warps,
const std::vector<int>& _axes,
const std::vector<unsigned>& _shapes,
const std::vector<ir::value *> &values, ir::type *_ty,
analysis::align* align): layout_t(HMMA_884, _axes, _shapes, values, _ty, align) {
unsigned shape_0 = shapes[0];
unsigned shape_1 = shapes[1];
mma884_layout::mma884_layout(size_t num_warps,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values,
analysis::align* align): data_layout(HMMA_884, axes, shape, values, align) {
/* fragments per warp */
// try to make things as square as possible to maximize data re-use
fpw = {1, 1, 1};
fpw_ = {1, 1, 1};
std::vector<int> fpw_nm1;
unsigned num_fragments = std::min<unsigned>((shape_0/8)*(shape_1/8), 4);
unsigned num_fragments = std::min<unsigned>((shape_[0]/8)*(shape_[1]/8), 4);
do {
fpw_nm1 = fpw;
if(fpw[0]*fpw[1] < num_fragments)
fpw[0] = clamp(fpw[0]*2, 1, shape_0 / 8);
if(fpw[0]*fpw[1] < num_fragments)
fpw[1] = clamp(fpw[1]*2, 1, shape_1 / 8);
}while(fpw_nm1 != fpw);
fpw_nm1 = fpw_;
if(fpw_[0]*fpw_[1] < num_fragments)
fpw_[0] = clamp(fpw_[0]*2, 1, shape_[0] / 8);
if(fpw_[0]*fpw_[1] < num_fragments)
fpw_[1] = clamp(fpw_[1]*2, 1, shape_[1] / 8);
}while(fpw_nm1 != fpw_);
/* warps per tile */
// try to make things as square as possible to maximize data re-use
wpt = {1, 1, 1};
wpt_ = {1, 1, 1};
std::vector<int> wpt_nm1;
do{
wpt_nm1 = wpt;
if(wpt[0] * wpt[1] * wpt[2] < num_warps)
wpt[0] = clamp(wpt[0]*2, 1, shape_0 / (fpw[0]*8));
if(wpt[0] * wpt[1] * wpt[2] < num_warps)
wpt[1] = clamp(wpt[1]*2, 1, shape_1 / (fpw[1]*8));
}while(wpt_nm1 != wpt);
wpt_nm1 = wpt_;
if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
wpt_[0] = clamp(wpt_[0]*2, 1, shape_[0] / (fpw_[0]*8));
if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / (fpw_[1]*8));
}while(wpt_nm1 != wpt_);
/* sanity check */
unsigned effective_num_warps = 1;
for(size_t d = 0; d < shapes.size(); d++)
effective_num_warps *= wpt[d];
for(size_t d = 0; d < shape.size(); d++)
effective_num_warps *= wpt_[d];
if(num_warps != effective_num_warps)
throw std::runtime_error("cannot create a kernel with this amount of warps");
}
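The two do/while loops above alternate doubling along each axis until the fragment (resp. warp) budget is met; extracted into a standalone sketch, assuming a hypothetical 128x128 tile and 8 warps:

#include <algorithm>
#include <cstdio>
#include <vector>

unsigned clamp(unsigned x, unsigned lo, unsigned hi) { return std::min(std::max(x, lo), hi); }

int main(){
  unsigned shape[2] = {128, 128}, num_warps = 8;
  // fragments per warp: double each axis in turn, capped at 4 fragments total
  std::vector<int> fpw = {1, 1, 1}, prev;
  unsigned num_fragments = std::min<unsigned>((shape[0]/8)*(shape[1]/8), 4);
  do {
    prev = fpw;
    if(fpw[0]*fpw[1] < num_fragments) fpw[0] = clamp(fpw[0]*2, 1, shape[0]/8);
    if(fpw[0]*fpw[1] < num_fragments) fpw[1] = clamp(fpw[1]*2, 1, shape[1]/8);
  } while(prev != fpw);
  // warps per tile: same strategy until the product covers all requested warps
  std::vector<int> wpt = {1, 1, 1};
  do {
    prev = wpt;
    if(wpt[0]*wpt[1]*wpt[2] < num_warps) wpt[0] = clamp(wpt[0]*2, 1, shape[0]/(fpw[0]*8));
    if(wpt[0]*wpt[1]*wpt[2] < num_warps) wpt[1] = clamp(wpt[1]*2, 1, shape[1]/(fpw[1]*8));
  } while(prev != wpt);
  std::printf("fpw = {%d,%d}, wpt = {%d,%d}\n", fpw[0], fpw[1], wpt[0], wpt[1]);
  // prints fpw = {2,2}, wpt = {4,2}: 4*2 = 8 warps, matching the request
}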
/* -------------------------------- *
* Scanline Layout *
* -------------------------------- */
layout_scanline_t::layout_scanline_t(size_t num_warps,
const std::vector<int>& _axes,
const std::vector<unsigned>& _shapes,
const std::vector<ir::value *> &values, ir::type *_ty,
analysis::align* align): layout_t(SCANLINE, _axes, _shapes, values, _ty, align){
unsigned size = std::accumulate(shapes.begin(), shapes.end(), 1, std::multiplies<int>());
scanline_layout::scanline_layout(size_t num_warps,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values,
analysis::align* align): data_layout(SCANLINE, axes, shape, values, align){
unsigned size = std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies<int>());
unsigned num_threads = num_warps * 32;
nts.resize(shapes.size());
mts.resize(shapes.size());
nts_.resize(shape_.size());
mts_.resize(shape_.size());
bool is_dot = std::any_of(values.begin(), values.end(),
[&](ir::value* v) { return dynamic_cast<ir::dot_inst*>(v); });
@@ -238,34 +178,39 @@ layout_scanline_t::layout_scanline_t(size_t num_warps,
if(auto *st = dynamic_cast<ir::store_inst*>(usr))
ptr = st->get_pointer_operand();
unsigned i = order[0];
unsigned i = order_[0];
int contiguous = 4;
if(ptr)
contiguous = std::min<int>(align->contiguous(ptr)[i], 4);
nts[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shapes[i]));
mts[i] = clamp(num_threads, 1, shapes[i] / nts[i]);
size /= shapes[i];
num_threads /= mts[i];
nts_[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shape_[i]));
mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
size /= shape_[i];
num_threads /= mts_[i];
if(is_dot)
nts[order[1]] = clamp(size / num_threads, 1, std::min<int>(4, shapes[order[1]]));
for(size_t d = 1; d < shapes.size(); d++){
i = order[d];
nts_[order_[1]] = clamp(size / num_threads, 1, std::min<int>(4, shape_[order_[1]]));
for(size_t d = 1; d < shape_.size(); d++){
i = order_[d];
if(d > 1 || !is_dot)
nts[i] = 1;
mts[i] = clamp(num_threads, 1, shapes[i] / nts[i]);
num_threads = num_threads / mts[i];
nts_[i] = 1;
mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
num_threads = num_threads / mts_[i];
}
/* sanity check */
unsigned effective_num_threads = 1;
for(size_t d = 0; d < shapes.size(); d++)
effective_num_threads *= mts[d];
for(size_t d = 0; d < shape_.size(); d++)
effective_num_threads *= mts_[d];
if(num_warps * 32 != effective_num_threads)
throw std::runtime_error("cannot create a kernel with this amount of warps");
}
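Here nts_ (elements per thread, capped by pointer contiguity) is fixed first along the most contiguous axis, then mts_ (threads per axis) absorbs the thread budget; a worked 1-D sketch, assuming shape {128}, 2 warps, and the default contiguity cap of 4:

#include <algorithm>
#include <cstdio>

unsigned clamp(unsigned x, unsigned lo, unsigned hi) { return std::min(std::max(x, lo), hi); }

int main(){
  unsigned shape0 = 128;
  unsigned num_threads = 2 * 32;   // 2 warps
  unsigned contiguous = 4;         // default cap when no io pointer is found
  unsigned nts = clamp(shape0 / num_threads, 1, std::min(contiguous, shape0)); // = 2
  unsigned mts = clamp(num_threads, 1, shape0 / nts);                          // = 64
  std::printf("nts = %u, mts = %u\n", nts, mts);
  // 64 threads each own 2 contiguous elements: 64 * 2 = 128 covers the row exactly,
  // which is what the effective_num_threads sanity check above verifies
}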
inline bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
/* -------------------------------- *
* Shared Layout *
* -------------------------------- */
bool shared_layout::is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
if(phi->get_parent() != terminator->get_parent())
return false;
if(auto *br = dynamic_cast<ir::cond_branch_inst*>(terminator))
@@ -278,7 +223,7 @@ inline bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
}
void extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_t>& res) {
void shared_layout::extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_t>& res) {
auto* phi = dynamic_cast<ir::phi_node*>(v);
if(!phi || phi->get_num_incoming() != 2)
return;
@@ -303,22 +248,22 @@ void extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_
}
layout_shared_t::layout_shared_t(const layout_t *arg,
const std::vector<int>& _axes,
const std::vector<unsigned>& _shapes,
shared_layout::shared_layout(const data_layout *arg,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values,
ir::type *ty,
analysis::align* align): layout_t(SHARED, _axes, _shapes, values, ty, align) {
analysis::align* align): data_layout(SHARED, axes, shape, values, align), ty_(ty) {
size = 0;
size_ = 0;
// double-buffering
for(ir::value *v: values)
extract_double_bufferable(v, double_buffer);
extract_double_bufferable(v, double_buffer_);
// order
std::vector<int> arg_order = arg ? arg->order : std::vector<int>{0};
order = arg_order;
std::vector<int> arg_order = arg ? arg->get_order() : std::vector<int>{0};
order_ = arg_order;
ir::value* dot_a = nullptr;
ir::value* dot_b = nullptr;
@@ -330,48 +275,84 @@ layout_shared_t::layout_shared_t(const layout_t *arg,
extract_hmma_dot_use(v, hmma_dot_a, 0);
extract_hmma_dot_use(v, hmma_dot_b, 1);
}
// non-mma ordering
std::vector<int> col = {0, 1};
std::vector<int> row = {1, 0};
for(size_t s = 2; s < shapes.size(); s++){
for(size_t s = 2; s < get_rank(); s++){
col.push_back(s);
row.push_back(s);
}
bool is_nonhmma_dot_a = dot_a && !hmma_dot_a;
bool is_nonhmma_dot_b = dot_b && !hmma_dot_b;
if(is_nonhmma_dot_a)
order = is_trans(dot_a) ? row : col;
order_ = is_trans(dot_a) ? row : col;
else if(is_nonhmma_dot_b)
order = is_trans(dot_b) ? col : row;
// else
// order = row;
order_ = is_trans(dot_b) ? col : row;
// padding
size_t pad = 0;
if(hmma_dot_a){
bool row = is_trans(hmma_dot_a) ^ order[0] != 0;
pad = 24 - shapes[row ? 0 : 1] % 32;
bool row = is_trans(hmma_dot_a) ^ order_[0] != 0;
pad = 24 - shape_[row ? 0 : 1] % 32;
}
else if(hmma_dot_b){
bool row = is_trans(hmma_dot_b) ^ order[0] != 0;
pad = 24 - shapes[row ? 1 : 0] % 32;
bool row = is_trans(hmma_dot_b) ^ order_[0] != 0;
pad = 24 - shape_[row ? 1 : 0] % 32;
}
else if(order != arg_order) {
else if(order_ != arg_order) {
pad = 4;
}
shapes[order[0]] += pad;
shape_[order_[0]] += pad;
// size
size = ty->get_primitive_size_in_bits() / 8;
for(auto s: shapes)
size *= s;
if(double_buffer)
size *= 2;
size_ = ty_->get_primitive_size_in_bits() / 8;
for(auto s: shape_)
size_ *= s;
if(double_buffer_)
size_ *= 2;
}
// layout factory method
void layout::create(size_t id, const std::vector<ir::value*>& values) {
/* -------------------------------- *
* ---- Layouts Inference Pass ---- *
* -------------------------------- */
layouts::layouts(analysis::axes *axes, analysis::align *align, size_t num_warps)
: axes_(axes), align_(align), num_warps_(num_warps) { }
void layouts::connect(ir::value *x, ir::value *y) {
if(x == y)
return;
if(!x->get_type()->is_tile_ty())
return;
if(!y->get_type()->is_tile_ty())
return;
std::vector<int> x_axes = axes_->get(x);
std::vector<int> y_axes = axes_->get(y);
std::set<int> sx_axes(x_axes.begin(), x_axes.end());
std::set<int> sy_axes(y_axes.begin(), y_axes.end());
std::set<int> common;
std::set_intersection(sx_axes.begin(), sx_axes.end(),
sy_axes.begin(), sy_axes.end(),
std::inserter(common, common.begin()));
graph_.add_edge(x, x);
graph_.add_edge(y, y);
if(!common.empty())
graph_.add_edge(x, y);
}
void layouts::make_graph(ir::instruction *i) {
for(ir::value* opx: i->ops())
for(ir::value* opy: i->ops()){
connect(i, opx);
connect(opx, opy);
}
}
void layouts::create(size_t id, const std::vector<ir::value*>& values) {
auto it_hmma_c = std::find_if(values.begin(), values.end(), &is_hmma_c);
auto cmp = [](ir::value* x, ir::value *y) {
return x->get_type()->get_tile_ranks1() <
@@ -387,18 +368,18 @@ void layout::create(size_t id, const std::vector<ir::value*>& values) {
});
// type
if(it_hmma_c != values.end())
layouts_[id] = new layout_hmma_884_t(num_warps_, axes, shapes, values, largest->get_type()->get_scalar_ty(), align_);
layouts_[id] = new mma884_layout(num_warps_, axes, shapes, values, align_);
else if(it_cts != values.end()){
ir::copy_to_shared_inst *cts = (ir::copy_to_shared_inst*)*it_cts;
ir::value *arg = cts->get_operand(0);
create(groups_.at(arg), values_.at(groups_.at(arg)));
layouts_[id] = new layout_shared_t(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_);
layouts_[id] = new shared_layout(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_);
}
else
layouts_[id] = new layout_scanline_t(num_warps_, axes, shapes, values, largest->get_type()->get_scalar_ty(), align_);
layouts_[id] = new scanline_layout(num_warps_, axes, shapes, values, align_);
}
void layout::run(ir::module &mod) {
void layouts::run(ir::module &mod) {
// make graph
graph_.clear();
ir::for_each_instruction(mod, [this](ir::instruction* i) {
@@ -422,35 +403,35 @@ void layout::run(ir::module &mod) {
// shape
auto shapes = arg->get_type()->get_tile_shapes();
unsigned shape_ax = shapes[axis];
layout_scanline_t *layout = get(arg)->to_scanline();
unsigned per_thread = layout->nts[axis];
scanline_layout *layout = get(arg)->to_scanline();
unsigned per_thread = layout->nts(axis);
unsigned depth = shape_ax / per_thread;
shapes[axis] = depth;
// create layout
layouts_[id] = new layout_shared_t(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_);
layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_);
tmp_[red] = id;
}
if(auto *recoalasce = dynamic_cast<ir::recoalesce_inst*>(i)){
ir::value *val = recoalasce->get_operand(0);
layout_t* in_layout = get(val);
layout_t* out_layout = get(i);
if(in_layout->type != HMMA_884)
mma884_layout* in_layout = get(val)->to_mma884();
scanline_layout* out_layout = get(i)->to_scanline();
if(!in_layout || !out_layout)
return;
id++;
ir::type::tile_shapes_t in_shape = val->get_type()->get_tile_shapes();
ir::type::tile_shapes_t shape(in_shape.size());
size_t ld = out_layout->order[0];
size_t ld = out_layout->get_order(0);
shape[ld] = in_shape[ld];
for(size_t k = 0; k < in_shape.size(); k++)
if(k != ld)
shape[k] = 4*in_layout->to_hmma884()->fpw[k]*in_layout->to_hmma884()->wpt[k];
shape[k] = 4*in_layout->to_mma884()->fpw(k)*in_layout->to_mma884()->wpt(k);
// create layout
layouts_[id] = new layout_shared_t(out_layout, axes_->get(val), shape, {recoalasce}, val->get_type()->get_scalar_ty(), align_);
layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {recoalasce}, val->get_type()->get_scalar_ty(), align_);
tmp_[recoalasce] = id;
}
if(auto *atom = dynamic_cast<ir::atomic_cas_inst*>(i)){
id++;
layouts_[id] = new layout_shared_t(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), align_);
layouts_[id] = new shared_layout(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), align_);
tmp_[atom] = id;
}
});


@@ -27,18 +27,18 @@ void liveness::run(ir::module &mod) {
// create live intervals
for(auto &x: layouts_->get_all()) {
if(x.second->type != SHARED)
shared_layout* layout = x.second->to_shared();
if(!layout)
continue;
layout_shared_t* layout = x.second->to_shared();
// users
std::set<ir::user*> users;
for(ir::value *v: layout->values){
for(ir::value *v: layout->get_values()){
for(ir::user *u: v->get_users())
users.insert(u);
}
// compute intervals
unsigned start = INT32_MAX;
for(ir::value *v: layout->values)
for(ir::value *v: layout->get_values())
if(indices.find(v) != indices.end())
start = std::min(start, indices.at(v));
unsigned end = 0;


@@ -174,7 +174,7 @@ inline bool is_trans(ir::value *v) {
generator::generator(analysis::axes *a_axes,
analysis::layout *layouts,
analysis::layouts *layouts,
analysis::align *alignment,
analysis::allocation *alloc,
target *tgt,
@@ -295,7 +295,7 @@ void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
}
// find vector size
ir::value *ptr = x->get_pointer_operand();
size_t ld = layouts_->get(ptr)->order[0];
size_t ld = layouts_->get(ptr)->get_order(0);
unsigned alignment = std::max<int>(alignment_->get(ptr, ld), 1);
@@ -337,7 +337,7 @@ void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
// find vector size
ir::value *ptr = x->get_pointer_operand();
size_t ld = layouts_->get(ptr)->order[0];
size_t ld = layouts_->get(ptr)->get_order(0);
unsigned alignment = alignment_->get(ptr, ld);
distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);
distributed_tile *masks = (distributed_tile*)tmap_.at(x->get_mask_operand());
@@ -603,7 +603,7 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst*) {
void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK) {
const auto& shapes = dot->get_type()->get_tile_shapes();
machine_layout_hmma_884_t* hmma = (machine_layout_hmma_884_t*)machine_layouts_.at(layouts_->get(dot));
machine_mma884_layout* hmma = (machine_mma884_layout*)machine_layouts_.at(layouts_->get(dot));
TA->set_vector_size(4*hmma->pack_size_0_);
TB->set_vector_size(4*hmma->pack_size_1_);
TA->set_return_mode(true);
@@ -625,8 +625,8 @@ void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *
Value* u_thread_id = tgt_->get_local_id(builder_->GetInsertBlock()->getModule(), *builder_, 0);
auto ord_a = layouts_->get(dot->get_operand(0))->order;
auto ord_b = layouts_->get(dot->get_operand(1))->order;
auto ord_a = layouts_->get(dot->get_operand(0))->get_order();
auto ord_b = layouts_->get(dot->get_operand(1))->get_order();
bool is_a_trans = is_trans(dot->get_operand(0));
bool is_b_trans = is_trans(dot->get_operand(1));
@@ -655,14 +655,14 @@ void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *
"{$8, $9}, "
"{$10, $11}, "
"{$0, $1, $2, $3, $4, $5, $6, $7};", "=f,=f,=f,=f,=f,=f,=f,=f,r,r,r,r,0,1,2,3,4,5,6,7", false);
analysis::layout_hmma_884_t* layout = layouts_->get(dot)->to_hmma884();
analysis::mma884_layout* layout = layouts_->get(dot)->to_mma884();
unsigned fpw_0 = layout->fpw.at(0);
unsigned fpw_1 = layout->fpw.at(1);
unsigned fpw_0 = layout->fpw(0);
unsigned fpw_1 = layout->fpw(1);
unsigned wts_0 = fpw_0 * 8;
unsigned wts_1 = fpw_1 * 8;
unsigned wpt_0 = layout->wpt.at(0);
unsigned wpt_1 = layout->wpt.at(1);
unsigned wpt_0 = layout->wpt(0);
unsigned wpt_1 = layout->wpt(1);
unsigned stride_rep_i = wpt_0 * wts_0;
unsigned stride_rep_j = wpt_1 * wts_1;
unsigned num_rep_i = shapes[0] / stride_rep_i;
@@ -792,7 +792,7 @@ void generator::visit_dot_inst(ir::dot_inst* dot) {
if(NK != 1) {
shared_tile *TA = (shared_tile*)tmap_.at(A);
shared_tile *TB = (shared_tile*)tmap_.at(B);
if(layouts_->get(dot)->type == analysis::HMMA_884)
if(layouts_->get(dot)->to_mma884())
visit_hmma_dot(dot, TA, TB, TD, NK);
else
visit_scanline_dot(dot, TA, TB, TD, NK, c_ty, f_mul_add);
@@ -856,7 +856,7 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
});
// reduce within blocks
machine_layout_t *slayout = machine_layouts_.at(layouts_->get(layouts_->tmp(x)));
machine_data_layout *slayout = machine_layouts_.at(layouts_->get(layouts_->tmp(x)));
shared_tile *stile = (shared_tile*)slayout->create(x);
unsigned depth = stile->get_shapes()[axis];
@@ -926,31 +926,31 @@ void generator::visit_recoalesce_inst(ir::recoalesce_inst* rc) {
// pointer to temporary shared memory
Type *ty = llvm_type(rc->get_type()->get_scalar_ty(), *ctx_);
// layouts
analysis::layout_hmma_884_t* in_layout = layouts_->get(op)->to_hmma884();
analysis::layout_scanline_t* out_layout = layouts_->get(rc)->to_scanline();
analysis::mma884_layout* in_layout = layouts_->get(op)->to_mma884();
analysis::scanline_layout* out_layout = layouts_->get(rc)->to_scanline();
// machine tiles
distributed_tile *in_dt = (distributed_tile*)(tmap_.at(op));
distributed_tile *out_dt = (distributed_tile*)(tmap_.at(rc));
// WMMA configuration
long wmma_pt[3] = { 2, 4, 1};
long wmma[3] = { 8*in_layout->wpt[0]*in_layout->fpw[0],
8*in_layout->wpt[1]*in_layout->fpw[1],
long wmma[3] = { 8*in_layout->wpt(0)*in_layout->fpw(0),
8*in_layout->wpt(1)*in_layout->fpw(1),
1};
// Work per thread for input layout
long in_pt[3] = { shape[0] / wmma[0],
shape[1] / wmma[1],
1 };
// Work per thread for output layout
long out_pt[3] = { shape[0] / out_layout->mts[0],
shape[1] / out_layout->mts[1],
long out_pt[3] = { shape[0] / out_layout->mts(0),
shape[1] / out_layout->mts(1),
1 };
if(rank > 2){
wmma[2] = in_layout->wpt[2]*in_layout->fpw[2];
wmma[2] = in_layout->wpt(2)*in_layout->fpw(2);
in_pt[2] = shape[2] / wmma[2];
out_pt[2] = shape[2] / out_layout->mts[2];
out_pt[2] = shape[2] / out_layout->mts(2);
}
// Orders
auto ord = out_layout->order;
auto ord = out_layout->get_order();
if(ord.size() < 3)
ord.push_back(2);
// pointer lanes
@@ -1028,13 +1028,13 @@ void generator::visit_recoalesce_inst(ir::recoalesce_inst* rc) {
void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
unsigned vector_size = 1;
ir::value *arg = cts->get_operand(0);
analysis::layout_shared_t* out_layout = layouts_->get(cts)->to_shared();
analysis::layout_scanline_t* in_layout = layouts_->get(arg)->to_scanline();
auto out_order = out_layout->order;
auto in_order = in_layout->order;
analysis::shared_layout* out_layout = layouts_->get(cts)->to_shared();
analysis::scanline_layout* in_layout = layouts_->get(arg)->to_scanline();
auto out_order = out_layout->get_order();
auto in_order = in_layout->get_order();
// tiles
if(out_order == in_order)
vector_size = in_layout->nts.at(in_order[0]);
vector_size = in_layout->nts(in_order[0]);
std::map<unsigned, Value*> packets;
for_each(arg, [&](indices_t idx){
@@ -1180,17 +1180,17 @@ void generator::visit_function(ir::function* fn) {
void generator::visit_layout_hmma_884(analysis::layout_hmma_884_t* layout) {
machine_layouts_[layout] = new machine_layout_hmma_884_t(mod_, &*builder_, tgt_, llvm_type(layout->ty->get_scalar_ty(), *ctx_), a_axes_, axes_, layout);
void generator::visit_layout_hmma_884(analysis::mma884_layout* layout) {
machine_layouts_[layout] = new machine_mma884_layout(mod_, &*builder_, tgt_, a_axes_, axes_, layout);
}
void generator::visit_layout_scanline(analysis::layout_scanline_t* layout) {
machine_layouts_[layout] = new machine_layout_scanline_t(mod_, &*builder_, tgt_, llvm_type(layout->ty->get_scalar_ty(), *ctx_), a_axes_, axes_, layout);
void generator::visit_layout_scanline(analysis::scanline_layout* layout) {
machine_layouts_[layout] = new machine_scanline_layout(mod_, &*builder_, tgt_, a_axes_, axes_, layout);
}
void generator::visit_layout_shared(analysis::layout_shared_t* layout) {
void generator::visit_layout_shared(analysis::shared_layout* layout) {
machine_layouts_[layout] = new machine_layout_shared_t(mod_, &*builder_, tgt_, alloc_, sh_mem_ptr_, layout, vmap_, tmap_);
machine_layouts_[layout] = new machine_shared_layout(mod_, &*builder_, tgt_, alloc_, sh_mem_ptr_, layout, vmap_, tmap_);
}
void generator::visit_basic_block(ir::basic_block * block) {
@@ -1230,9 +1230,9 @@ void generator::set_value(ir::value *x, const indices_t& idx, Value* v) {
}
void generator::finalize_shared_layout(analysis::layout_shared_t *shared) {
if(shared->double_buffer) {
auto info = *shared->double_buffer;
void generator::finalize_shared_layout(analysis::shared_layout *shared) {
if(shared->get_double_buffer()) {
auto info = *shared->get_double_buffer();
ir::phi_node *phi = info.phi;
PHINode *ptr = (PHINode*)((shared_tile*)tmap_.at(phi))->get_pointer();
PHINode *offset = (PHINode*)((shared_tile*)tmap_.at(phi))->get_offset();
@@ -1247,8 +1247,8 @@ void generator::finalize_shared_layout(analysis::layout_shared_t *shared) {
offset->addIncoming(next_offset, llvm_inc_block);
}
else {
unsigned num_bytes = shared->ty->get_primitive_size_in_bits() / 8;
offset->addIncoming(builder_->getInt32(shared->size / (2*num_bytes)), llvm_inc_block);
unsigned num_bytes = shared->get_type()->get_primitive_size_in_bits() / 8;
offset->addIncoming(builder_->getInt32(shared->get_size() / (2*num_bytes)), llvm_inc_block);
}
ptr->addIncoming(inc_shared->get_pointer(), llvm_inc_block);
}
@@ -1258,7 +1258,7 @@ void generator::finalize_shared_layout(analysis::layout_shared_t *shared) {
void generator::finalize_function(ir::function *fn) {
// finalize double-buffering
for(const auto& x: layouts_->get_all())
if(auto *shared = dynamic_cast<analysis::layout_shared_t*>(x.second))
if(auto *shared = dynamic_cast<analysis::shared_layout*>(x.second))
finalize_shared_layout(shared);
// finalize phi
for(ir::basic_block *block: fn->blocks())


@@ -71,18 +71,18 @@ inline int32_t ceil(int32_t num, int32_t div){
machine_layout_shared_t::machine_layout_shared_t(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc,
Value *&sh_mem_ptr, analysis::layout_shared_t *layout,
machine_shared_layout::machine_shared_layout(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc,
Value *&sh_mem_ptr, analysis::shared_layout *layout,
std::map<ir::value *, Value *>& vmap,
std::map<ir::value *, tile *>& tmap)
: mod_(mod), builder_(builder), tgt_(tgt), alloc_(alloc), sh_mem_ptr_(sh_mem_ptr), layout_(layout), vmap_(vmap), tmap_(tmap) {
Type* ty = llvm_type(layout_->ty, builder_->getContext());
Type* ty = llvm_type(layout_->get_type(), builder_->getContext());
PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr_->getType()->getPointerAddressSpace());
// double-buffered
if(layout_->double_buffer) {
if(layout_->get_double_buffer()) {
BasicBlock *current = builder_->GetInsertBlock();
auto info = *layout_->double_buffer;
auto info = *layout_->get_double_buffer();
ir::phi_node *phi = info.phi;
BasicBlock *parent = (BasicBlock*)vmap_.at((ir::value*)(phi->get_parent()));
if(parent->empty())
@@ -105,31 +105,31 @@ machine_layout_shared_t::machine_layout_shared_t(Module *mod, Builder *builder,
}
tile* machine_layout_shared_t::create(ir::value *v) {
auto order = layout_->order;
auto shapes = layout_->shapes;
Type* ty = llvm_type(layout_->ty, builder_->getContext());
// double-buffered
if(layout_->double_buffer) {
if(v == layout_->double_buffer->phi)
return new shared_tile(ty, shapes, order, ptr_, *builder_, offset_);
if(v == layout_->double_buffer->latch)
return new shared_tile(ty, shapes, order, next_ptr_, *builder_);
return new shared_tile(ty, shapes, order, pre_ptr_, *builder_);
}
else {
return new shared_tile(ty, shapes, order, ptr_, *builder_);
}
tile* machine_shared_layout::create(ir::value *v) {
Type* ty = llvm_type(layout_->get_type(), builder_->getContext());
auto double_buffer = layout_->get_double_buffer();
// offset
Value *offset = nullptr;
if(double_buffer && v == double_buffer->phi)
offset = offset_;
// base pointer
Value *ptr = ptr_;
if(double_buffer && v == double_buffer->latch)
ptr = next_ptr_;
else if(double_buffer && v == double_buffer->first)
ptr = pre_ptr_;
// create tile
return new shared_tile(ty, layout_->get_shape(), layout_->get_order(), ptr, *builder_, offset);
}
machine_layout_distributed_t::machine_layout_distributed_t(Module *mod, Builder *builder, target *tgt, Type *ty,
machine_distributed_layout::machine_distributed_layout(Module *mod, Builder *builder, target *tgt,
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
analysis::layout_t *layout)
: mod_(mod), builder_(builder), tgt_(tgt), ty_(ty), a_axes_(a_axes), axes_(axes), layout_(layout) {
analysis::data_layout *layout)
: mod_(mod), builder_(builder), tgt_(tgt), a_axes_(a_axes), axes_(axes), layout_(layout) {
}
tile *machine_layout_distributed_t::create(ir::value *v) {
tile *machine_distributed_layout::create(ir::value *v) {
Type *ty = llvm_type(v->get_type()->get_scalar_ty(), builder_->getContext());
const auto &shapes = v->get_type()->get_tile_shapes();
size_t rank = shapes.size();
@@ -151,12 +151,10 @@ tile *machine_layout_distributed_t::create(ir::value *v) {
auto cmp = [&](int x, int y) {
unsigned axx = a_axes_->get(v, x);
unsigned axy = a_axes_->get(v, y);
auto itx = std::find(layout_->axes.begin(), layout_->axes.end(), axx);
auto ity = std::find(layout_->axes.begin(), layout_->axes.end(), axy);
size_t posx = std::distance(layout_->axes.begin(), itx);
size_t posy = std::distance(layout_->axes.begin(), ity);
size_t posx = layout_->find_axis(axx);
size_t posy = layout_->find_axis(axy);
if(posx < rank && posy < rank)
return layout_->order[posx] < layout_->order[posy];
return layout_->get_order(posx) < layout_->get_order(posy);
return false;
};
std::sort(order.begin(), order.end(), cmp);
@@ -164,22 +162,21 @@ tile *machine_layout_distributed_t::create(ir::value *v) {
return new distributed_tile(ty, shapes, order, axes, *builder_);
}
machine_layout_hmma_884_t::machine_layout_hmma_884_t(Module *mod, Builder *builder,
target *tgt, Type *ty, analysis::axes *a_axes,
machine_mma884_layout::machine_mma884_layout(Module *mod, Builder *builder,
target *tgt, analysis::axes *a_axes,
std::map<unsigned, distributed_axis>& axes,
analysis::layout_hmma_884_t* layout)
: machine_layout_distributed_t(mod, builder, tgt, ty, a_axes, axes, layout) {
analysis::mma884_layout* layout)
: machine_distributed_layout(mod, builder, tgt, a_axes, axes, layout) {
Value *warp_size = builder_->getInt32(32);
Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size);
Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size);
const auto& shapes = layout->shapes;
if(shapes.size() > 3)
const auto& shape = layout->get_shape();
if(shape.size() > 3)
throw std::runtime_error("unsupported");
bool is_batched = shapes.size() >= 3;
bool is_batched = shape.size() >= 3;
Value *_1 = builder_->getInt32(1);
Value *_2 = builder_->getInt32(2);
@@ -188,13 +185,13 @@ machine_layout_hmma_884_t::machine_layout_hmma_884_t(Module *mod, Builder *build
Value *_16 = builder_->getInt32(16);
// fragments per warp
unsigned fpw_0 = layout->fpw.at(0);
unsigned fpw_1 = layout->fpw.at(1);
unsigned fpw_2 = is_batched ? layout->fpw.at(2) : 1;
unsigned fpw_0 = layout->fpw(0);
unsigned fpw_1 = layout->fpw(1);
unsigned fpw_2 = is_batched ? layout->fpw(2) : 1;
// warps per tile
unsigned wpt_0 = layout->wpt.at(0);
unsigned wpt_1 = layout->wpt.at(1);
unsigned wpt_2 = is_batched ? layout->wpt.at(2) : 1;
unsigned wpt_0 = layout->wpt(0);
unsigned wpt_1 = layout->wpt(1);
unsigned wpt_2 = is_batched ? layout->wpt(2) : 1;
// hmma warp tile size
unsigned hmma_wts_0 = fpw_0 * 8;
unsigned hmma_wts_1 = fpw_1 * 8;
@@ -204,9 +201,9 @@ machine_layout_hmma_884_t::machine_layout_hmma_884_t(Module *mod, Builder *build
unsigned hmma_bts_1 = hmma_wts_1 * wpt_1;
unsigned hmma_bts_2 = is_batched ? hmma_wts_2 * wpt_2 : 1;
// number of repetition
unsigned num_rep_0 = shapes[0] / hmma_bts_0;
unsigned num_rep_1 = shapes[1] / hmma_bts_1;
unsigned num_rep_2 = is_batched ? shapes[2] / hmma_bts_2 : 1;
unsigned num_rep_0 = shape[0] / hmma_bts_0;
unsigned num_rep_1 = shape[1] / hmma_bts_1;
unsigned num_rep_2 = is_batched ? shape[2] / hmma_bts_2 : 1;
// size of each pack (interleaving)
pack_size_0_ = std::min<unsigned>(num_rep_0, 1);
pack_size_1_ = std::min<unsigned>(num_rep_1, 1);
@@ -275,44 +272,52 @@ machine_layout_hmma_884_t::machine_layout_hmma_884_t(Module *mod, Builder *build
/* axes */
axes_[layout->axes[0]] = distributed_axis{1, idx_i, warp_id_0};
axes_[layout->axes[1]] = distributed_axis{1, idx_j, warp_id_1};
axes_[layout->get_axis(0)] = distributed_axis{1, idx_i, warp_id_0};
axes_[layout->get_axis(1)] = distributed_axis{1, idx_j, warp_id_1};
if(is_batched)
axes_[layout->axes[2]] = distributed_axis{1, idx_z, warp_id_2};
axes_[layout->get_axis(2)] = distributed_axis{1, idx_z, warp_id_2};
}
machine_layout_scanline_t::machine_layout_scanline_t(Module *mod, Builder *builder,
target *tgt, Type *ty,
machine_scanline_layout::machine_scanline_layout(Module *mod, Builder *builder,
target *tgt,
analysis::axes *a_axes, std::map<unsigned, distributed_axis> &axes,
analysis::layout_scanline_t* layout)
: machine_layout_distributed_t(mod, builder, tgt, ty, a_axes, axes, layout) {
analysis::scanline_layout* layout)
: machine_distributed_layout(mod, builder, tgt, a_axes, axes, layout) {
Value *warp_size = builder_->getInt32(32);
Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size);
Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size);
auto order = layout->order;
const auto& shapes = layout->shapes;
size_t dim = shapes.size();
std::vector<int> nts = layout->nts;
std::vector<int> mts = layout->mts;
auto order = layout->get_order();
const auto& shape = layout->get_shape();
Value* full_thread_id = builder_->CreateAdd(builder_->CreateMul(u_warp_id, builder_->getInt32(32)), u_thread_id);
std::vector<Value*> thread_id = delinearize(full_thread_id, order, mts, *builder_);
// Delinearize
size_t dim = shape.size();
std::vector<Value*> thread_id(dim);
for(unsigned k = 0; k < dim - 1; k++){
Constant *dim_k = builder_->getInt32(layout->mts(order[k]));
Value *rem = builder_->CreateURem(full_thread_id, dim_k);
full_thread_id = builder_->CreateUDiv(full_thread_id, dim_k);
thread_id[order[k]] = rem;
}
thread_id[order[dim - 1]] = full_thread_id;
// Create axes
for(unsigned k = 0; k < dim; k++) {
int nts = layout->nts(k);
int mts = layout->mts(k);
std::string str_k = std::to_string(k);
Value *contiguous_k = builder_->getInt32(nts[k]);
Value *contiguous_k = builder_->getInt32(nts);
Value *scaled_thread_id = builder_->CreateMul(thread_id[k], contiguous_k);
unsigned per_block = nts[k] * mts[k];
unsigned per_thread = nts[k] * shapes[k] / per_block;
unsigned per_block = nts * mts;
unsigned per_thread = nts * shape[k] / per_block;
std::vector<Value*> idx_list(per_thread);
for(unsigned n = 0 ; n < per_thread; n++){
unsigned offset = n / nts[k] * per_block + n % nts[k];
unsigned offset = n / nts * per_block + n % nts;
idx_list[n] = builder_->CreateAdd(scaled_thread_id, builder_->getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
}
axes_[layout->axes[k]] = distributed_axis{nts[k], idx_list, thread_id[k]};
axes_[layout->get_axis(k)] = distributed_axis{nts, idx_list, thread_id[k]};
}
}
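The hand-inlined delinearization above (replacing the old delinearize() helper) peels one coordinate per dimension off the flat thread id, most contiguous axis first; a scalar sketch with hypothetical sizes:

#include <cstdio>

int main(){
  int mts[2] = {8, 4};     // threads along each dimension
  int order[2] = {0, 1};   // order[0] is the most contiguous axis
  int dim = 2;
  int tid = 27;            // flat thread id in [0, 32)
  int coord[2];
  for(int k = 0; k < dim - 1; k++){
    coord[order[k]] = tid % mts[order[k]];  // remainder indexes this axis
    tid /= mts[order[k]];                   // quotient moves to the next axis
  }
  coord[order[dim - 1]] = tid;              // leftover id lands on the outermost axis
  std::printf("thread 27 -> (%d, %d)\n", coord[0], coord[1]); // (3, 3): 27 = 3 + 3*8
}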


@@ -12,7 +12,7 @@ namespace triton {
namespace codegen{
namespace transform{
coalesce::coalesce(analysis::align* align, analysis::layout *layouts)
coalesce::coalesce(analysis::align* align, analysis::layouts *layouts)
: align_(align), layout_(layouts) { }
// Find all values that are used as pointer operands in LD/ST
@@ -64,8 +64,9 @@ ir::value* coalesce::rematerialize(ir::value *x, ir::builder &builder,
void coalesce::run(ir::module &mod) {
size_t num_groups = layout_->num_layouts();
for(size_t id = 0; id < num_groups; id++) {
if(layout_->get(id)->type != analysis::HMMA_884)
if(!layout_->get(id)->to_mma884())
continue;
// extract memory stores
const auto& values = layout_->values_of(id);
@@ -97,7 +98,6 @@ void coalesce::run(ir::module &mod) {
}
}
// find values to rematerialize
std::vector<ir::io_inst*> remat;
for(size_t id = 0; id < num_groups; id++) {
@@ -109,7 +109,7 @@ void coalesce::run(ir::module &mod) {
// extract leading axes
std::map<int, std::vector<ir::io_inst*>> axes;
for(ir::io_inst *i: io){
if(i->get_pointer_operand()->get_type()->get_tile_ranks1() == layout_->get(id)->axes.size())
if(i->get_pointer_operand()->get_type()->get_tile_ranks1() == layout_->get(id)->get_rank())
extract_ld(i, axes);
}
// update list of values to rematerialize

View File

@@ -35,10 +35,11 @@ void membar::add_reference(ir::value *v, interval_vec_t &res){
return;
if(!i->get_type()->is_tile_ty())
return;
if(alloc_->has_offset(layouts_->get(v))){
unsigned offset = alloc_->offset(layouts_->get(v));
unsigned size = layouts_->get(v)->to_shared()->size;
res.push_back(interval_t(offset, offset + size));
analysis::shared_layout* layout = layouts_->get(v)->to_shared();
assert(layout);
if(alloc_->has_offset(layout)){
unsigned offset = alloc_->offset(layout);
res.push_back(interval_t(offset, offset + layout->get_size()));
}
}
@@ -119,13 +120,11 @@ void membar::run(ir::module &mod) {
// without needing synchronization
std::set<ir::value*> safe_war;
for(const auto& x: layouts_->get_all()){
if(x.second->type != analysis::SHARED)
analysis::shared_layout* layout = x.second->to_shared();
if(!layout || !layout->get_double_buffer())
continue;
analysis::layout_shared_t* layout = x.second->to_shared();
if(!layout->double_buffer)
continue;
for(ir::value *v: layout->values)
if(v != layout->double_buffer->phi)
for(ir::value *v: layout->get_values())
if(v != layout->get_double_buffer()->phi)
safe_war.insert(v);
}


@@ -220,7 +220,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
codegen::analysis::align align;
codegen::analysis::axes axes;
codegen::transform::disassociate disassociate;
codegen::analysis::layout layouts(&axes, &align, opt.num_warps);
codegen::analysis::layouts layouts(&axes, &align, opt.num_warps);
codegen::analysis::liveness liveness(&layouts);
codegen::analysis::allocation allocation(&liveness);
codegen::transform::membar barriers(&liveness, &layouts, &allocation);
@@ -239,7 +239,6 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
align.run(module);
cts.run(module);
axes.run(module);
// ir::print(module, std::cout);
layouts.run(module);
coalesce.run(module);
dce.run(module);
@@ -250,15 +249,14 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
dce.run(module);
align.run(module);
axes.run(module);
// ir::print(module, std::cout);
layouts.run(module);
liveness.run(module);
allocation.run(module);
if(allocation.allocated_size() > context->device()->max_shared_memory())
return std::unique_ptr<driver::module>();
barriers.run(module);
// ir::print(module, std::cout);
isel.visit(module, *llvm);
// return binary
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
// done


@@ -79,7 +79,7 @@ for N, T, H, S, E in NTHSE:
# 1D Dense convolution
NCHKR = [
# (1, 1152, 12602, 512, 3)
(1, 1152, 12602, 512, 3)
]
for N, C, H, K, R in NCHKR:
torch_fn = lambda a, b: torch.nn.functional.conv1d(a, b.permute(2, 0, 1))
@@ -92,10 +92,10 @@ for N, C, H, K, R in NCHKR:
# 2D Dense convolution
NCHWKRS = [
#(8, 64, 128, 128, 768, 3, 3),
#(8, 128, 64, 64, 256, 3, 3),
#(8, 256, 32, 32, 512, 3, 3),
#(8, 512, 32, 32, 1024, 3, 3)
(8, 64, 128, 128, 768, 3, 3),
(8, 128, 64, 64, 256, 3, 3),
(8, 256, 32, 32, 512, 3, 3),
(8, 512, 32, 32, 1024, 3, 3)
]
for N, C, H, W, K, R, S in NCHWKRS:
torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2))
@@ -108,10 +108,10 @@ for N, C, H, W, K, R, S in NCHWKRS:
# 3D Dense Convolution
NCDHWKTRS = [
#(8, 32, 27, 100, 100, 64, 3, 3, 3),
#(8, 64, 23, 48, 48, 256, 3, 3, 3),
#(8, 256, 19, 22, 22, 640, 3, 3, 3),
#(8, 640, 15, 36, 36, 384, 3, 3, 3)
(8, 32, 27, 100, 100, 64, 3, 3, 3),
(8, 64, 23, 48, 48, 256, 3, 3, 3),
(8, 256, 19, 22, 22, 640, 3, 3, 3),
(8, 640, 15, 36, 36, 384, 3, 3, 3)
]
for N, C, D, H, W, K, T, R, S in NCDHWKTRS:
torch_fn = lambda a, b: torch.nn.functional.conv3d(a, b.permute(4, 0, 1, 2, 3))
@@ -168,7 +168,7 @@ for N, C, H, W, K, R, S in NCHWKRS:
# Benchmark
torch.set_num_threads(1)
for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
dtype = torch.cuda.FloatTensor
dtype = torch.cuda.HalfTensor
# initialize input tensors
a = torch.rand(*a_shape).type(dtype).cuda()
b = torch.rand(*b_shape).type(dtype).cuda()