[GENERAL] Cleaned polymorphic structure of layouts analysis pass

This commit is contained in:
Philippe Tillet
2020-01-20 15:15:32 -05:00
parent 382ca2c745
commit 78b98fb7cf
17 changed files with 500 additions and 480 deletions

View File

@@ -27,14 +27,14 @@ public:
allocation(liveness *live) allocation(liveness *live)
: liveness_(live) { } : liveness_(live) { }
// accessors // accessors
bool has_offset(const layout_t *x) const { return offsets_.find(x) != offsets_.end(); } bool has_offset(const data_layout *x) const { return offsets_.find(x) != offsets_.end(); }
unsigned offset(const layout_t *x) const { return offsets_.at(x); } unsigned offset(const data_layout *x) const { return offsets_.at(x); }
unsigned allocated_size() const { return allocated_size_; } unsigned allocated_size() const { return allocated_size_; }
// run // run
void run(ir::module& mod); void run(ir::module& mod);
private: private:
std::map<const layout_t*, unsigned> offsets_; std::map<const data_layout*, unsigned> offsets_;
size_t allocated_size_; size_t allocated_size_;
// dependences // dependences
liveness *liveness_; liveness *liveness_;

View File

@@ -22,11 +22,106 @@ namespace analysis{
class axes; class axes;
class align; class align;
class layout_visitor;
class data_layout;
class mma884_layout;
class scanline_layout;
class shared_layout;
enum layout_type_t {
HMMA_884, class layout_visitor {
SCANLINE, public:
SHARED virtual void visit_layout(data_layout *);
virtual void visit_layout_hmma_884(mma884_layout*) = 0;
virtual void visit_layout_scanline(scanline_layout*) = 0;
virtual void visit_layout_shared(shared_layout*) = 0;
};
class data_layout {
protected:
enum id_t {
HMMA_884,
SCANLINE,
SHARED
};
typedef std::vector<int> axes_t;
typedef std::vector<unsigned> shape_t;
typedef std::vector<int> order_t;
typedef std::vector<ir::value*> values_t;
private:
template<typename T>
T* downcast(id_t id) {
if(id_ == id)
return static_cast<T*>(this);
return nullptr;
}
public:
data_layout(id_t id,
const std::vector<int>& axes,
const std::vector<unsigned> &shape,
const std::vector<ir::value *> &values,
analysis::align* align);
// visitor
virtual void accept(layout_visitor* vst) = 0;
// downcast
mma884_layout* to_mma884() { return downcast<mma884_layout>(HMMA_884); }
scanline_layout* to_scanline() { return downcast<scanline_layout>(SCANLINE); }
shared_layout* to_shared() { return downcast<shared_layout>(SHARED); }
// accessors
size_t get_rank() { return shape_.size(); }
const shape_t& get_shape() const { return shape_; }
const order_t& get_order() const { return order_; }
const values_t& get_values() const { return values_;}
int get_axis(size_t k) const { return axes_.at(k); }
const int get_order(size_t k) const { return order_.at(k); }
// find the position of given axis
size_t find_axis(int to_find) const;
private:
id_t id_;
axes_t axes_;
values_t values_;
protected:
order_t order_;
shape_t shape_;
};
class mma884_layout: public data_layout {
public:
mma884_layout(size_t num_warps,
const std::vector<int>& axes,
const std::vector<unsigned>& shapes,
const std::vector<ir::value *> &values,
analysis::align* align);
void accept(layout_visitor* vst) { vst->visit_layout_hmma_884(this); }
// accessor
int fpw(size_t k) { return fpw_.at(k); }
int wpt(size_t k) { return wpt_.at(k); }
private:
std::vector<int> fpw_;
std::vector<int> wpt_;
};
struct scanline_layout: public data_layout {
scanline_layout(size_t num_warps,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values,
analysis::align* align);
void accept(layout_visitor* vst) { vst->visit_layout_scanline(this); }
// accessor
int mts(size_t k) { return mts_.at(k); }
int nts(size_t k) { return nts_.at(k); }
private:
std::vector<int> mts_;
std::vector<int> nts_;
}; };
struct double_buffer_info_t { struct double_buffer_info_t {
@@ -35,90 +130,33 @@ struct double_buffer_info_t {
ir::phi_node* phi; ir::phi_node* phi;
}; };
class layout_visitor; class shared_layout: public data_layout {
class layout_t; private:
class layout_hmma_884_t; static bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator);
class layout_scanline_t; static void extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_t>& res);
class layout_shared_t;
class layout_visitor {
public: public:
virtual void visit_layout(layout_t *); shared_layout(const data_layout *arg,
virtual void visit_layout_hmma_884(layout_hmma_884_t*) = 0; const std::vector<int>& axes,
virtual void visit_layout_scanline(layout_scanline_t*) = 0; const std::vector<unsigned>& shapes,
virtual void visit_layout_shared(layout_shared_t*) = 0; const std::vector<ir::value *> &values_,
}; ir::type *ty,
analysis::align* align);
class layout_hmma_884_t;
class layout_scanline_t;
class layout_shared_t;
struct layout_t {
layout_t(layout_type_t _type,
const std::vector<int>& _axes,
const std::vector<unsigned> &_shapes,
const std::vector<ir::value *> &_values,
ir::type *_ty,
analysis::align* align);
// visitor
virtual void accept(layout_visitor* vst) = 0;
// downcast
layout_hmma_884_t* to_hmma884();
layout_scanline_t* to_scanline();
layout_shared_t* to_shared();
layout_type_t type;
std::vector<int> axes;
std::vector<unsigned> shapes;
std::vector<ir::value*> values;
std::vector<int> order;
ir::type *ty;
};
struct layout_hmma_884_t: public layout_t {
layout_hmma_884_t(size_t num_warps,
const std::vector<int>& _axes,
const std::vector<unsigned>& _shapes,
const std::vector<ir::value *> &_values,
ir::type *_ty,
analysis::align* align);
void accept(layout_visitor* vst) { vst->visit_layout_hmma_884(this); }
std::vector<int> fpw;
std::vector<int> wpt;
};
struct layout_scanline_t: public layout_t {
layout_scanline_t(size_t num_warps,
const std::vector<int>& _axes,
const std::vector<unsigned>& _shapes,
const std::vector<ir::value *> &values,
ir::type *_ty,
analysis::align* align);
void accept(layout_visitor* vst) { vst->visit_layout_scanline(this); }
std::vector<int> mts;
std::vector<int> nts;
};
struct layout_shared_t: public layout_t {
layout_shared_t(const layout_t *arg,
const std::vector<int>& _axes,
const std::vector<unsigned>& _shapes,
const std::vector<ir::value *> &values,
ir::type *ty,
analysis::align* align);
void accept(layout_visitor* vst) { vst->visit_layout_shared(this); } void accept(layout_visitor* vst) { vst->visit_layout_shared(this); }
// accessors
size_t get_size() { return size_; }
ir::type* get_type() { return ty_; }
double_buffer_info_t* get_double_buffer() { return double_buffer_.get(); }
std::shared_ptr<double_buffer_info_t> double_buffer; private:
size_t size; size_t size_;
ir::type *ty_;
std::shared_ptr<double_buffer_info_t> double_buffer_;
}; };
class layout { class layouts {
typedef ir::value* node_t; typedef ir::value* node_t;
typedef std::map <node_t, std::set<node_t>> graph_t; typedef std::map <node_t, std::set<node_t>> graph_t;
@@ -127,23 +165,23 @@ private:
void connect(ir::value *x, ir::value *y); void connect(ir::value *x, ir::value *y);
void make_graph(ir::instruction *i); void make_graph(ir::instruction *i);
void init_hmma_tile(layout_t& layout); void init_hmma_tile(data_layout& layouts);
void init_scanline_tile(layout_t &layout); void init_scanline_tile(data_layout &layouts);
void create(size_t id, const std::vector<ir::value*>& values); void create(size_t id, const std::vector<ir::value*>& values);
public: public:
// constructor // constructor
layout(analysis::axes *axes, analysis::align *align, size_t num_warps); layouts(analysis::axes *axes, analysis::align *align, size_t num_warps);
// accessors // accessors
unsigned layout_of(ir::value *value) const; unsigned layout_of(ir::value *value) const { return groups_.at(value); }
const std::vector<ir::value*>& values_of(unsigned id) const; const std::vector<ir::value*>& values_of(unsigned id) const { return values_.at(id); }
size_t num_layouts() const; size_t num_layouts() const { return values_.size();}
layout_t* get(size_t id); data_layout* get(size_t id) { return layouts_.at(id); }
layout_t* get(ir::value *v); data_layout* get(ir::value *v) { return get(layout_of(v));}
std::map<size_t, layout_t*> &get_all(); std::map<size_t, data_layout*> &get_all() { return layouts_; }
size_t tmp(ir::instruction* i); size_t tmp(ir::instruction* i) { return tmp_.at((ir::value*)i);}
// execution // execution
void run(ir::module &mod); void run(ir::module &mod);
@@ -155,7 +193,7 @@ private:
tools::graph<ir::value*> graph_; tools::graph<ir::value*> graph_;
std::map<ir::value*, size_t> groups_; std::map<ir::value*, size_t> groups_;
std::map<size_t, std::vector<ir::value*>> values_; std::map<size_t, std::vector<ir::value*>> values_;
std::map<size_t, layout_t*> layouts_; std::map<size_t, data_layout*> layouts_;
std::map<ir::value*, size_t> tmp_; std::map<ir::value*, size_t> tmp_;
}; };

View File

@@ -23,8 +23,8 @@ namespace analysis{
typedef unsigned slot_index; typedef unsigned slot_index;
class tiles; class tiles;
class layout; class layouts;
class layout_t; class data_layout;
struct segment { struct segment {
slot_index start; slot_index start;
@@ -42,20 +42,20 @@ struct segment {
class liveness { class liveness {
private: private:
typedef std::map<layout_shared_t*, segment> intervals_map_t; typedef std::map<shared_layout*, segment> intervals_map_t;
public: public:
// constructor // constructor
liveness(layout *l): layouts_(l){ } liveness(layouts *l): layouts_(l){ }
// accessors // accessors
const intervals_map_t& get() const { return intervals_; } const intervals_map_t& get() const { return intervals_; }
segment get(layout_shared_t* v) const { return intervals_.at(v); } segment get(shared_layout* v) const { return intervals_.at(v); }
// run // run
void run(ir::module &mod); void run(ir::module &mod);
private: private:
// analysis // analysis
layout *layouts_; layouts *layouts_;
intervals_map_t intervals_; intervals_map_t intervals_;
}; };

View File

@@ -35,7 +35,7 @@ class align;
class allocation; class allocation;
class cts; class cts;
class axes; class axes;
class layout; class layouts;
} }
// typedef // typedef
typedef llvm::IRBuilder<llvm::ConstantFolder, typedef llvm::IRBuilder<llvm::ConstantFolder,
@@ -50,7 +50,7 @@ typedef llvm::ArrayType ArrayType;
typedef llvm::Function Function; typedef llvm::Function Function;
typedef std::vector<Value*> indices_t; typedef std::vector<Value*> indices_t;
// forward // forward
class machine_layout_t; class machine_data_layout;
class tile; class tile;
class shared_tile; class shared_tile;
class distributed_tile; class distributed_tile;
@@ -74,13 +74,13 @@ private:
void visit_outer_dot(ir::dot_inst*, distributed_tile *TA, distributed_tile *TB, distributed_tile *TD, unsigned NK, void visit_outer_dot(ir::dot_inst*, distributed_tile *TA, distributed_tile *TB, distributed_tile *TD, unsigned NK,
Type *c_ty, Function *f_mul_add); Type *c_ty, Function *f_mul_add);
void finalize_shared_layout(analysis::layout_shared_t*); void finalize_shared_layout(analysis::shared_layout*);
void finalize_function(ir::function*); void finalize_function(ir::function*);
void finalize_phi_node(ir::phi_node*); void finalize_phi_node(ir::phi_node*);
public: public:
generator(analysis::axes *a_axes, generator(analysis::axes *a_axes,
analysis::layout *layouts, analysis::layouts *layouts,
analysis::align *alignment, analysis::align *alignment,
analysis::allocation *alloc, analysis::allocation *alloc,
target *tgt, target *tgt,
@@ -139,9 +139,9 @@ public:
void visit_basic_block(ir::basic_block*); void visit_basic_block(ir::basic_block*);
void visit_argument(ir::argument*); void visit_argument(ir::argument*);
void visit_layout_hmma_884(analysis::layout_hmma_884_t*); void visit_layout_hmma_884(analysis::mma884_layout*);
void visit_layout_scanline(analysis::layout_scanline_t*); void visit_layout_scanline(analysis::scanline_layout*);
void visit_layout_shared(analysis::layout_shared_t*); void visit_layout_shared(analysis::shared_layout*);
void visit(ir::module &, llvm::Module &); void visit(ir::module &, llvm::Module &);
@@ -150,13 +150,13 @@ private:
Builder* builder_; Builder* builder_;
Module *mod_; Module *mod_;
std::map<const analysis::layout_t*, machine_layout_t*> machine_layouts_; std::map<const analysis::data_layout*, machine_data_layout*> machine_layouts_;
analysis::axes *a_axes_; analysis::axes *a_axes_;
std::map<unsigned, distributed_axis> axes_; std::map<unsigned, distributed_axis> axes_;
std::map<ir::value *, Value *> vmap_; std::map<ir::value *, Value *> vmap_;
std::map<ir::value *, tile *> tmap_; std::map<ir::value *, tile *> tmap_;
target *tgt_; target *tgt_;
analysis::layout *layouts_; analysis::layouts *layouts_;
analysis::align *alignment_; analysis::align *alignment_;
analysis::allocation *alloc_; analysis::allocation *alloc_;
Value *sh_mem_ptr_; Value *sh_mem_ptr_;

View File

@@ -36,7 +36,7 @@ class align;
class allocation; class allocation;
class cts; class cts;
class axes; class axes;
class layout; class layouts;
} }
typedef llvm::IRBuilder<llvm::ConstantFolder, typedef llvm::IRBuilder<llvm::ConstantFolder,
@@ -51,7 +51,7 @@ typedef llvm::ArrayType ArrayType;
typedef llvm::Function Function; typedef llvm::Function Function;
class distributed_axis; class distributed_axis;
class machine_layout_t; class machine_data_layout;
class tile; class tile;
class shared_tile; class shared_tile;
class distributed_tile; class distributed_tile;
@@ -64,15 +64,15 @@ namespace triton{
namespace codegen{ namespace codegen{
class machine_layout_t { class machine_data_layout {
public: public:
virtual tile* create(ir::value *v) = 0; virtual tile* create(ir::value *v) = 0;
}; };
class machine_layout_shared_t: public machine_layout_t { class machine_shared_layout: public machine_data_layout {
public: public:
machine_layout_shared_t(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc, Value *&sh_mem_ptr, machine_shared_layout(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc, Value *&sh_mem_ptr,
analysis::layout_shared_t* layout, analysis::shared_layout* layout,
std::map<ir::value *, Value *>& vmap, std::map<ir::value *, Value *>& vmap,
std::map<ir::value *, tile *>& tmap); std::map<ir::value *, tile *>& tmap);
@@ -83,7 +83,7 @@ public:
target *tgt_; target *tgt_;
analysis::allocation* alloc_; analysis::allocation* alloc_;
Value *&sh_mem_ptr_; Value *&sh_mem_ptr_;
analysis::layout_shared_t* layout_; analysis::shared_layout* layout_;
std::map<ir::value *, Value *>& vmap_; std::map<ir::value *, Value *>& vmap_;
std::map<ir::value *, tile *>& tmap_; std::map<ir::value *, tile *>& tmap_;
@@ -94,29 +94,28 @@ public:
}; };
class machine_layout_distributed_t: public machine_layout_t { class machine_distributed_layout: public machine_data_layout {
public: public:
machine_layout_distributed_t(Module *mod, Builder *builder, target *tgt, Type *ty, machine_distributed_layout(Module *mod, Builder *builder, target *tgt,
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes, analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
analysis::layout_t* layout); analysis::data_layout* layout);
tile* create(ir::value *v); tile* create(ir::value *v);
Module *mod_; Module *mod_;
Builder *builder_; Builder *builder_;
target *tgt_; target *tgt_;
Type *ty_;
analysis::axes *a_axes_; analysis::axes *a_axes_;
std::map<unsigned, distributed_axis>& axes_; std::map<unsigned, distributed_axis>& axes_;
analysis::layout_t* layout_; analysis::data_layout* layout_;
}; };
class machine_layout_hmma_884_t: public machine_layout_distributed_t { class machine_mma884_layout: public machine_distributed_layout {
public: public:
machine_layout_hmma_884_t(Module *mod, Builder *builder, machine_mma884_layout(Module *mod, Builder *builder,
target *tgt, Type *ty, target *tgt,
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes, analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
analysis::layout_hmma_884_t* layout); analysis::mma884_layout* layout);
Value *offset_a_i_, *offset_a_k_; Value *offset_a_i_, *offset_a_k_;
Value *offset_b_j_, *offset_b_k_; Value *offset_b_j_, *offset_b_k_;
unsigned pack_size_0_; unsigned pack_size_0_;
@@ -125,12 +124,12 @@ public:
unsigned num_packs_1_; unsigned num_packs_1_;
}; };
class machine_layout_scanline_t: public machine_layout_distributed_t { class machine_scanline_layout: public machine_distributed_layout {
public: public:
machine_layout_scanline_t(Module *mod, Builder *builder, machine_scanline_layout(Module *mod, Builder *builder,
target *tgt, Type *ty, target *tgt,
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes, analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
analysis::layout_scanline_t* layout); analysis::scanline_layout* layout);
}; };
} }

View File

@@ -47,11 +47,11 @@ class align;
class allocation; class allocation;
class cts; class cts;
class axes; class axes;
class layout; class layouts;
} }
class distributed_axis; class distributed_axis;
class machine_layout_t; class machine_data_layout;
class tile; class tile;
class shared_tile; class shared_tile;
class distributed_tile; class distributed_tile;

View File

@@ -19,7 +19,7 @@ namespace codegen{
namespace analysis{ namespace analysis{
class align; class align;
class layout; class layouts;
class cts; class cts;
} }
@@ -32,12 +32,12 @@ private:
ir::value* rematerialize(ir::value *v, ir::builder& builder, std::map<ir::value*, ir::value*>& seen); ir::value* rematerialize(ir::value *v, ir::builder& builder, std::map<ir::value*, ir::value*>& seen);
public: public:
coalesce(analysis::align* align, triton::codegen::analysis::layout *layouts); coalesce(analysis::align* align, triton::codegen::analysis::layouts *layouts);
void run(ir::module &mod); void run(ir::module &mod);
private: private:
analysis::align* align_; analysis::align* align_;
analysis::layout* layout_; analysis::layouts* layout_;
}; };
} }

View File

@@ -17,7 +17,7 @@ namespace analysis{
class allocation; class allocation;
class liveness; class liveness;
class layout; class layouts;
class cts; class cts;
} }
@@ -41,13 +41,13 @@ private:
std::set<ir::instruction *> &insert_loc, std::set<triton::ir::value *> &safe_war); std::set<ir::instruction *> &insert_loc, std::set<triton::ir::value *> &safe_war);
public: public:
membar(analysis::liveness *liveness, analysis::layout *layouts, analysis::allocation *alloc): membar(analysis::liveness *liveness, analysis::layouts *layouts, analysis::allocation *alloc):
liveness_(liveness), layouts_(layouts), alloc_(alloc) {} liveness_(liveness), layouts_(layouts), alloc_(alloc) {}
void run(ir::module &mod); void run(ir::module &mod);
private: private:
analysis::liveness *liveness_; analysis::liveness *liveness_;
analysis::layout *layouts_; analysis::layouts *layouts_;
analysis::allocation *alloc_; analysis::allocation *alloc_;
}; };

View File

@@ -15,22 +15,22 @@ void allocation::run(ir::module &mod) {
using std::min; using std::min;
typedef std::multimap<unsigned, segment> triples_map_type; typedef std::multimap<unsigned, segment> triples_map_type;
std::vector<layout_shared_t*> I; std::vector<shared_layout*> I;
for(auto x: liveness_->get()) for(auto x: liveness_->get())
I.push_back(x.first); I.push_back(x.first);
std::vector<layout_shared_t*> J = I; std::vector<shared_layout*> J = I;
triples_map_type H; triples_map_type H;
H.insert({0, segment{0, INT_MAX}}); H.insert({0, segment{0, INT_MAX}});
std::vector<layout_shared_t*> V; std::vector<shared_layout*> V;
std::map<layout_shared_t*, unsigned> starts; std::map<shared_layout*, unsigned> starts;
while(!J.empty()){ while(!J.empty()){
auto h_it = H.begin(); auto h_it = H.begin();
unsigned w = h_it->first; unsigned w = h_it->first;
segment xh = h_it->second; segment xh = h_it->second;
H.erase(h_it); H.erase(h_it);
auto j_it = std::find_if(J.begin(), J.end(), [&](layout_shared_t* JJ){ auto j_it = std::find_if(J.begin(), J.end(), [&](shared_layout* JJ){
segment xj = liveness_->get(JJ); segment xj = liveness_->get(JJ);
bool res = xj.intersect(xh); bool res = xj.intersect(xh);
for(auto val: H) for(auto val: H)
@@ -38,7 +38,7 @@ void allocation::run(ir::module &mod) {
return res; return res;
}); });
if(j_it != J.end()){ if(j_it != J.end()){
unsigned size = (*j_it)->size; unsigned size = (*j_it)->get_size();
segment xj = liveness_->get(*j_it); segment xj = liveness_->get(*j_it);
starts[*j_it] = w; starts[*j_it] = w;
H.insert({w + size, segment{max(xh.start, xj.start), min(xh.end, xj.end)}}); H.insert({w + size, segment{max(xh.start, xj.start), min(xh.end, xj.end)}});
@@ -52,14 +52,14 @@ void allocation::run(ir::module &mod) {
} }
// Build interference graph // Build interference graph
std::map<layout_shared_t*, std::set<layout_shared_t*>> interferences; std::map<shared_layout*, std::set<shared_layout*>> interferences;
for(layout_shared_t* x: V) for(shared_layout* x: V)
for(layout_shared_t* y: V){ for(shared_layout* y: V){
if(x == y) if(x == y)
continue; continue;
unsigned X0 = starts[x], Y0 = starts[y]; unsigned X0 = starts[x], Y0 = starts[y];
unsigned NX = x->size; unsigned NX = x->get_size();
unsigned NY = y->size; unsigned NY = y->get_size();
segment XS = {X0, X0 + NX}; segment XS = {X0, X0 + NX};
segment YS = {Y0, Y0 + NY}; segment YS = {Y0, Y0 + NY};
if(liveness_->get(x).intersect(liveness_->get(y)) if(liveness_->get(x).intersect(liveness_->get(y))
@@ -68,17 +68,17 @@ void allocation::run(ir::module &mod) {
} }
// Initialize colors // Initialize colors
std::map<layout_shared_t*, int> colors; std::map<shared_layout*, int> colors;
for(layout_shared_t* X: V) for(shared_layout* X: V)
colors[X] = (X==V[0])?0:-1; colors[X] = (X==V[0])?0:-1;
// First-fit graph coloring // First-fit graph coloring
std::vector<bool> available(V.size()); std::vector<bool> available(V.size());
for(layout_shared_t* x: V){ for(shared_layout* x: V){
// Non-neighboring colors are available // Non-neighboring colors are available
std::fill(available.begin(), available.end(), true); std::fill(available.begin(), available.end(), true);
for(layout_shared_t* Y: interferences[x]){ for(shared_layout* Y: interferences[x]){
int color = colors[Y]; int color = colors[Y];
if(color >= 0) if(color >= 0)
available[color] = false; available[color] = false;
@@ -89,17 +89,17 @@ void allocation::run(ir::module &mod) {
} }
// Finalize allocation // Finalize allocation
for(layout_shared_t* x: V){ for(shared_layout* x: V){
unsigned Adj = 0; unsigned Adj = 0;
for(layout_shared_t* y: interferences[x]) for(shared_layout* y: interferences[x])
Adj = std::max<unsigned>(Adj, starts[y] + y->size); Adj = std::max<unsigned>(Adj, starts[y] + y->get_size());
offsets_[x] = starts[x] + colors[x] * Adj; offsets_[x] = starts[x] + colors[x] * Adj;
} }
// Save maximum size of induced memory space // Save maximum size of induced memory space
allocated_size_ = 0; allocated_size_ = 0;
for(layout_shared_t* x: V) for(shared_layout* x: V)
allocated_size_ = std::max<size_t>(allocated_size_, starts[x] + x->size); allocated_size_ = std::max<size_t>(allocated_size_, starts[x] + x->get_size());
} }
} }

View File

@@ -12,57 +12,15 @@ namespace triton{
namespace codegen{ namespace codegen{
namespace analysis{ namespace analysis{
/* -------------------------------- *
* Helper Functions *
* -------------------------------- */
// constructor inline unsigned clamp(unsigned x, unsigned lo, unsigned hi) {
layout::layout(analysis::axes *axes, analysis::align *align, size_t num_warps) return std::min(std::max(x, lo), hi);
: axes_(axes), align_(align), num_warps_(num_warps) { }
// get group id
unsigned layout::layout_of(ir::value *value) const
{ return groups_.at(value); }
// get values
const std::vector<ir::value*>& layout::values_of(unsigned id) const
{ return values_.at(id); }
// get number of groups
size_t layout::num_layouts() const
{ return values_.size(); }
// connect two values
void layout::connect(ir::value *x, ir::value *y) {
if(x == y)
return;
if(!x->get_type()->is_tile_ty())
return;
if(!y->get_type()->is_tile_ty())
return;
std::vector<int> x_axes = axes_->get(x);
std::vector<int> y_axes = axes_->get(y);
std::set<int> sx_axes(x_axes.begin(), x_axes.end());
std::set<int> sy_axes(y_axes.begin(), y_axes.end());
std::set<int> common;
std::set_intersection(sx_axes.begin(), sx_axes.end(),
sy_axes.begin(), sy_axes.end(),
std::inserter(common, common.begin()));
graph_.add_edge(x, x);
graph_.add_edge(y, y);
if(!common.empty())
graph_.add_edge(x, y);
} }
// make graph inline bool is_hmma_c(ir::value *v){
void layout::make_graph(ir::instruction *i) {
for(ir::value* opx: i->ops())
for(ir::value* opy: i->ops()){
connect(i, opx);
connect(opx, opy);
}
}
// hmma
bool is_hmma_c(ir::value *v){
bool result = false; bool result = false;
if(auto *x = dynamic_cast<ir::dot_inst*>(v)){ if(auto *x = dynamic_cast<ir::dot_inst*>(v)){
ir::value *a = x->get_operand(0); ir::value *a = x->get_operand(0);
@@ -75,23 +33,7 @@ bool is_hmma_c(ir::value *v){
return result; return result;
} }
layout_t* layout::get(size_t id) { inline void extract_io_use(ir::value *v, std::set<ir::value*>& result) {
return layouts_.at(id);
}
layout_t* layout::get(ir::value *v) {
return layouts_.at(groups_.at(v));
}
std::map<size_t, layout_t*>& layout::get_all() {
return layouts_;
}
size_t layout::tmp(ir::instruction* i) {
return tmp_.at(i);
}
void extract_io_use(ir::value *v, std::set<ir::value*>& result) {
for(ir::user* u: v->get_users()){ for(ir::user* u: v->get_users()){
auto i = dynamic_cast<ir::io_inst*>(u); auto i = dynamic_cast<ir::io_inst*>(u);
if(i && i->get_pointer_operand() == v) if(i && i->get_pointer_operand() == v)
@@ -99,7 +41,7 @@ void extract_io_use(ir::value *v, std::set<ir::value*>& result) {
} }
} }
void extract_dot_use(ir::value *v, ir::value*& result, size_t n) { inline void extract_dot_use(ir::value *v, ir::value*& result, size_t n) {
for(ir::user* u: v->get_users()){ for(ir::user* u: v->get_users()){
auto i = dynamic_cast<ir::dot_inst*>(u); auto i = dynamic_cast<ir::dot_inst*>(u);
if(i && i->get_operand(n) == v) if(i && i->get_operand(n) == v)
@@ -107,7 +49,7 @@ void extract_dot_use(ir::value *v, ir::value*& result, size_t n) {
} }
} }
void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n) { inline void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n) {
for(ir::user* u: v->get_users()){ for(ir::user* u: v->get_users()){
auto i = dynamic_cast<ir::dot_inst*>(u); auto i = dynamic_cast<ir::dot_inst*>(u);
if(i && is_hmma_c(i) && i->get_operand(n) == v) if(i && is_hmma_c(i) && i->get_operand(n) == v)
@@ -116,7 +58,6 @@ void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n) {
} }
inline bool is_trans(ir::value *v) { inline bool is_trans(ir::value *v) {
if(dynamic_cast<ir::trans_inst *>(v)) { if(dynamic_cast<ir::trans_inst *>(v)) {
return true; return true;
@@ -131,104 +72,103 @@ inline bool is_trans(ir::value *v) {
} }
void layout_visitor::visit_layout(layout_t *layout) { /* -------------------------------- *
* Layout Visitor *
* -------------------------------- */
void layout_visitor::visit_layout(data_layout *layout) {
layout->accept(this); layout->accept(this);
} }
layout_t::layout_t(layout_type_t _type, /* -------------------------------- *
const std::vector<int> &_axes, * Base Data Layout *
const std::vector<unsigned> &_shapes, * -------------------------------- */
const std::vector<ir::value *> &_values, ir::type *_ty,
analysis::align* align): type(_type), axes(_axes), shapes(_shapes), values(_values), ty(_ty) { data_layout::data_layout(id_t id,
const std::vector<int> &axes,
const std::vector<unsigned> &shape,
const std::vector<ir::value *> &values,
analysis::align* align): id_(id), axes_(axes), shape_(shape), values_(values) {
// io pointer // io pointer
std::set<ir::value*> ptr; std::set<ir::value*> ptr;
for(ir::value* v: values) for(ir::value* v: values_)
extract_io_use(v, ptr); extract_io_use(v, ptr);
order.resize(axes.size()); order_.resize(axes_.size());
std::iota(order.begin(), order.end(), 0); std::iota(order_.begin(), order_.end(), 0);
auto largest = std::max_element(ptr.begin(), ptr.end(), [&](ir::value *x, ir::value *y){ auto largest = std::max_element(ptr.begin(), ptr.end(), [&](ir::value *x, ir::value *y){
return x->get_type()->get_tile_rank() < y->get_type()->get_tile_rank(); return x->get_type()->get_tile_rank() < y->get_type()->get_tile_rank();
}); });
if(*largest){ if(*largest){
auto max_contiguous = align->contiguous(*largest); auto max_contiguous = align->contiguous(*largest);
std::sort(order.begin(), order.end(), [&](unsigned a, unsigned b) { std::sort(order_.begin(), order_.end(), [&](unsigned a, unsigned b) {
return max_contiguous[a] > max_contiguous[b]; return max_contiguous[a] > max_contiguous[b];
}); });
} }
} }
// downcast size_t data_layout::find_axis(int to_find) const {
layout_hmma_884_t* layout_t::to_hmma884() { auto it = std::find(axes_.begin(), axes_.end(), to_find);
assert(type == HMMA_884); return std::distance(axes_.begin(), it);
return static_cast<layout_hmma_884_t*>(this);
} }
layout_scanline_t* layout_t::to_scanline() {
assert(type == SCANLINE);
return static_cast<layout_scanline_t*>(this);
}
layout_shared_t* layout_t::to_shared() { /* -------------------------------- *
assert(type == SHARED); * MMA Layout *
return static_cast<layout_shared_t*>(this); * -------------------------------- */
}
inline unsigned clamp(unsigned x, unsigned lo, unsigned hi) { mma884_layout::mma884_layout(size_t num_warps,
return std::min(std::max(x, lo), hi); const std::vector<int>& axes,
} const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values,
layout_hmma_884_t::layout_hmma_884_t(size_t num_warps, analysis::align* align): data_layout(HMMA_884, axes, shape, values, align) {
const std::vector<int>& _axes,
const std::vector<unsigned>& _shapes,
const std::vector<ir::value *> &values, ir::type *_ty,
analysis::align* align): layout_t(HMMA_884, _axes, _shapes, values, _ty, align) {
unsigned shape_0 = shapes[0];
unsigned shape_1 = shapes[1];
/* fragments per warp */ /* fragments per warp */
// try to make things as square as possible to maximize data re-use // try to make things as square as possible to maximize data re-use
fpw = {1, 1, 1}; fpw_ = {1, 1, 1};
std::vector<int> fpw_nm1; std::vector<int> fpw_nm1;
unsigned num_fragments = std::min<unsigned>((shape_0/8)*(shape_1/8), 4); unsigned num_fragments = std::min<unsigned>((shape_[0]/8)*(shape_[1]/8), 4);
do { do {
fpw_nm1 = fpw; fpw_nm1 = fpw_;
if(fpw[0]*fpw[1] < num_fragments) if(fpw_[0]*fpw_[1] < num_fragments)
fpw[0] = clamp(fpw[0]*2, 1, shape_0 / 8); fpw_[0] = clamp(fpw_[0]*2, 1, shape_[0] / 8);
if(fpw[0]*fpw[1] < num_fragments) if(fpw_[0]*fpw_[1] < num_fragments)
fpw[1] = clamp(fpw[1]*2, 1, shape_1 / 8); fpw_[1] = clamp(fpw_[1]*2, 1, shape_[1] / 8);
}while(fpw_nm1 != fpw); }while(fpw_nm1 != fpw_);
/* warps per tile */ /* warps per tile */
// try to make things as square as possible to maximize data re-use // try to make things as square as possible to maximize data re-use
wpt = {1, 1, 1}; wpt_ = {1, 1, 1};
std::vector<int> wpt_nm1; std::vector<int> wpt_nm1;
do{ do{
wpt_nm1 = wpt; wpt_nm1 = wpt_;
if(wpt[0] * wpt[1] * wpt[2] < num_warps) if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
wpt[0] = clamp(wpt[0]*2, 1, shape_0 / (fpw[0]*8)); wpt_[0] = clamp(wpt_[0]*2, 1, shape_[0] / (fpw_[0]*8));
if(wpt[0] * wpt[1] * wpt[2] < num_warps) if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
wpt[1] = clamp(wpt[1]*2, 1, shape_1 / (fpw[1]*8)); wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / (fpw_[1]*8));
}while(wpt_nm1 != wpt); }while(wpt_nm1 != wpt_);
/* sanity check */ /* sanity check */
unsigned effective_num_warps = 1; unsigned effective_num_warps = 1;
for(size_t d = 0; d < shapes.size(); d++) for(size_t d = 0; d < shape.size(); d++)
effective_num_warps *= wpt[d]; effective_num_warps *= wpt_[d];
if(num_warps != effective_num_warps) if(num_warps != effective_num_warps)
throw std::runtime_error("cannot create a kernel with this amount of warps"); throw std::runtime_error("cannot create a kernel with this amount of warps");
} }
/* -------------------------------- *
* Scanline Layout *
* -------------------------------- */
scanline_layout::scanline_layout(size_t num_warps,
layout_scanline_t::layout_scanline_t(size_t num_warps, const std::vector<int>& axes,
const std::vector<int>& _axes, const std::vector<unsigned>& shape,
const std::vector<unsigned>& _shapes, const std::vector<ir::value *> &values,
const std::vector<ir::value *> &values, ir::type *_ty, analysis::align* align): data_layout(SCANLINE, axes, shape, values, align){
analysis::align* align): layout_t(SCANLINE, _axes, _shapes, values, _ty, align){ unsigned size = std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies<int>());
unsigned size = std::accumulate(shapes.begin(), shapes.end(), 1, std::multiplies<int>());
unsigned num_threads = num_warps * 32; unsigned num_threads = num_warps * 32;
nts.resize(shapes.size()); nts_.resize(shape_.size());
mts.resize(shapes.size()); mts_.resize(shape_.size());
bool is_dot = std::any_of(values.begin(), values.end(), bool is_dot = std::any_of(values.begin(), values.end(),
[&](ir::value* v) { return dynamic_cast<ir::dot_inst*>(v); }); [&](ir::value* v) { return dynamic_cast<ir::dot_inst*>(v); });
@@ -238,34 +178,39 @@ layout_scanline_t::layout_scanline_t(size_t num_warps,
if(auto *st = dynamic_cast<ir::store_inst*>(usr)) if(auto *st = dynamic_cast<ir::store_inst*>(usr))
ptr = st->get_pointer_operand(); ptr = st->get_pointer_operand();
unsigned i = order[0]; unsigned i = order_[0];
int contiguous = 4; int contiguous = 4;
if(ptr) if(ptr)
contiguous = std::min<int>(align->contiguous(ptr)[i], 4); contiguous = std::min<int>(align->contiguous(ptr)[i], 4);
nts[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shapes[i])); nts_[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shape_[i]));
mts[i] = clamp(num_threads, 1, shapes[i] / nts[i]); mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
size /= shapes[i]; size /= shape_[i];
num_threads /= mts[i]; num_threads /= mts_[i];
if(is_dot) if(is_dot)
nts[order[1]] = clamp(size / num_threads, 1, std::min<int>(4, shapes[order[1]])); nts_[order_[1]] = clamp(size / num_threads, 1, std::min<int>(4, shape_[order_[1]]));
for(size_t d = 1; d < shapes.size(); d++){ for(size_t d = 1; d < shape_.size(); d++){
i = order[d]; i = order_[d];
if(d > 1 || !is_dot) if(d > 1 || !is_dot)
nts[i] = 1; nts_[i] = 1;
mts[i] = clamp(num_threads, 1, shapes[i] / nts[i]); mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
num_threads = num_threads / mts[i]; num_threads = num_threads / mts_[i];
} }
/* sanity check */ /* sanity check */
unsigned effective_num_threads = 1; unsigned effective_num_threads = 1;
for(size_t d = 0; d < shapes.size(); d++) for(size_t d = 0; d < shape_.size(); d++)
effective_num_threads *= mts[d]; effective_num_threads *= mts_[d];
if(num_warps * 32 != effective_num_threads) if(num_warps * 32 != effective_num_threads)
throw std::runtime_error("cannot create a kernel with this amount of warps"); throw std::runtime_error("cannot create a kernel with this amount of warps");
} }
inline bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
/* -------------------------------- *
* Shared Layout *
* -------------------------------- */
bool shared_layout::is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
if(phi->get_parent() != terminator->get_parent()) if(phi->get_parent() != terminator->get_parent())
return false; return false;
if(auto *br = dynamic_cast<ir::cond_branch_inst*>(terminator)) if(auto *br = dynamic_cast<ir::cond_branch_inst*>(terminator))
@@ -278,7 +223,7 @@ inline bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
} }
void extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_t>& res) { void shared_layout::extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_t>& res) {
auto* phi = dynamic_cast<ir::phi_node*>(v); auto* phi = dynamic_cast<ir::phi_node*>(v);
if(!phi || phi->get_num_incoming() != 2) if(!phi || phi->get_num_incoming() != 2)
return; return;
@@ -303,22 +248,22 @@ void extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_
} }
layout_shared_t::layout_shared_t(const layout_t *arg, shared_layout::shared_layout(const data_layout *arg,
const std::vector<int>& _axes, const std::vector<int>& axes,
const std::vector<unsigned>& _shapes, const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values, const std::vector<ir::value *> &values,
ir::type *ty, ir::type *ty,
analysis::align* align): layout_t(SHARED, _axes, _shapes, values, ty, align) { analysis::align* align): data_layout(SHARED, axes, shape, values, align), ty_(ty) {
size = 0; size_ = 0;
// double-buffering // double-buffering
for(ir::value *v: values) for(ir::value *v: values)
extract_double_bufferable(v, double_buffer); extract_double_bufferable(v, double_buffer_);
// order // order
std::vector<int> arg_order = arg ? arg->order : std::vector<int>{0}; std::vector<int> arg_order = arg ? arg->get_order() : std::vector<int>{0};
order = arg_order; order_ = arg_order;
ir::value* dot_a = nullptr; ir::value* dot_a = nullptr;
ir::value* dot_b = nullptr; ir::value* dot_b = nullptr;
@@ -330,48 +275,84 @@ layout_shared_t::layout_shared_t(const layout_t *arg,
extract_hmma_dot_use(v, hmma_dot_a, 0); extract_hmma_dot_use(v, hmma_dot_a, 0);
extract_hmma_dot_use(v, hmma_dot_b, 1); extract_hmma_dot_use(v, hmma_dot_b, 1);
} }
// non-mma ordering
std::vector<int> col = {0, 1}; std::vector<int> col = {0, 1};
std::vector<int> row = {1, 0}; std::vector<int> row = {1, 0};
for(size_t s = 2; s < shapes.size(); s++){ for(size_t s = 2; s < get_rank(); s++){
col.push_back(s); col.push_back(s);
row.push_back(s); row.push_back(s);
} }
bool is_nonhmma_dot_a = dot_a && !hmma_dot_a; bool is_nonhmma_dot_a = dot_a && !hmma_dot_a;
bool is_nonhmma_dot_b = dot_b && !hmma_dot_b; bool is_nonhmma_dot_b = dot_b && !hmma_dot_b;
if(is_nonhmma_dot_a) if(is_nonhmma_dot_a)
order = is_trans(dot_a) ? row : col; order_ = is_trans(dot_a) ? row : col;
else if(is_nonhmma_dot_b) else if(is_nonhmma_dot_b)
order = is_trans(dot_b) ? col : row; order_ = is_trans(dot_b) ? col : row;
// else
// order = row;
// padding // padding
size_t pad = 0; size_t pad = 0;
if(hmma_dot_a){ if(hmma_dot_a){
bool row = is_trans(hmma_dot_a) ^ order[0] != 0; bool row = is_trans(hmma_dot_a) ^ order_[0] != 0;
pad = 24 - shapes[row ? 0 : 1] % 32; pad = 24 - shape_[row ? 0 : 1] % 32;
} }
else if(hmma_dot_b){ else if(hmma_dot_b){
bool row = is_trans(hmma_dot_b) ^ order[0] != 0; bool row = is_trans(hmma_dot_b) ^ order_[0] != 0;
pad = 24 - shapes[row ? 1 : 0] % 32; pad = 24 - shape_[row ? 1 : 0] % 32;
} }
else if(order != arg_order) { else if(order_ != arg_order) {
pad = 4; pad = 4;
} }
shapes[order[0]] += pad; shape_[order_[0]] += pad;
// size // size
size = ty->get_primitive_size_in_bits() / 8; size_ = ty_->get_primitive_size_in_bits() / 8;
for(auto s: shapes) for(auto s: shape_)
size *= s; size_ *= s;
if(double_buffer) if(double_buffer_)
size *= 2; size_ *= 2;
} }
// layout factory method /* -------------------------------- *
void layout::create(size_t id, const std::vector<ir::value*>& values) { * ---- Layouts Inference Pass ---- *
* -------------------------------- */
layouts::layouts(analysis::axes *axes, analysis::align *align, size_t num_warps)
: axes_(axes), align_(align), num_warps_(num_warps) { }
void layouts::connect(ir::value *x, ir::value *y) {
if(x == y)
return;
if(!x->get_type()->is_tile_ty())
return;
if(!y->get_type()->is_tile_ty())
return;
std::vector<int> x_axes = axes_->get(x);
std::vector<int> y_axes = axes_->get(y);
std::set<int> sx_axes(x_axes.begin(), x_axes.end());
std::set<int> sy_axes(y_axes.begin(), y_axes.end());
std::set<int> common;
std::set_intersection(sx_axes.begin(), sx_axes.end(),
sy_axes.begin(), sy_axes.end(),
std::inserter(common, common.begin()));
graph_.add_edge(x, x);
graph_.add_edge(y, y);
if(!common.empty())
graph_.add_edge(x, y);
}
void layouts::make_graph(ir::instruction *i) {
for(ir::value* opx: i->ops())
for(ir::value* opy: i->ops()){
connect(i, opx);
connect(opx, opy);
}
}
void layouts::create(size_t id, const std::vector<ir::value*>& values) {
auto it_hmma_c = std::find_if(values.begin(), values.end(), &is_hmma_c); auto it_hmma_c = std::find_if(values.begin(), values.end(), &is_hmma_c);
auto cmp = [](ir::value* x, ir::value *y) { auto cmp = [](ir::value* x, ir::value *y) {
return x->get_type()->get_tile_ranks1() < return x->get_type()->get_tile_ranks1() <
@@ -387,18 +368,18 @@ void layout::create(size_t id, const std::vector<ir::value*>& values) {
}); });
// type // type
if(it_hmma_c != values.end()) if(it_hmma_c != values.end())
layouts_[id] = new layout_hmma_884_t(num_warps_, axes, shapes, values, largest->get_type()->get_scalar_ty(), align_); layouts_[id] = new mma884_layout(num_warps_, axes, shapes, values, align_);
else if(it_cts != values.end()){ else if(it_cts != values.end()){
ir::copy_to_shared_inst *cts = (ir::copy_to_shared_inst*)*it_cts; ir::copy_to_shared_inst *cts = (ir::copy_to_shared_inst*)*it_cts;
ir::value *arg = cts->get_operand(0); ir::value *arg = cts->get_operand(0);
create(groups_.at(arg), values_.at(groups_.at(arg))); create(groups_.at(arg), values_.at(groups_.at(arg)));
layouts_[id] = new layout_shared_t(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_); layouts_[id] = new shared_layout(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_);
} }
else else
layouts_[id] = new layout_scanline_t(num_warps_, axes, shapes, values, largest->get_type()->get_scalar_ty(), align_); layouts_[id] = new scanline_layout(num_warps_, axes, shapes, values, align_);
} }
void layout::run(ir::module &mod) { void layouts::run(ir::module &mod) {
// make graph // make graph
graph_.clear(); graph_.clear();
ir::for_each_instruction(mod, [this](ir::instruction* i) { ir::for_each_instruction(mod, [this](ir::instruction* i) {
@@ -422,35 +403,35 @@ void layout::run(ir::module &mod) {
// shape // shape
auto shapes = arg->get_type()->get_tile_shapes(); auto shapes = arg->get_type()->get_tile_shapes();
unsigned shape_ax = shapes[axis]; unsigned shape_ax = shapes[axis];
layout_scanline_t *layout = get(arg)->to_scanline(); scanline_layout *layout = get(arg)->to_scanline();
unsigned per_thread = layout->nts[axis]; unsigned per_thread = layout->nts(axis);
unsigned depth = shape_ax / per_thread; unsigned depth = shape_ax / per_thread;
shapes[axis] = depth; shapes[axis] = depth;
// create layout // create layout
layouts_[id] = new layout_shared_t(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_); layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_);
tmp_[red] = id; tmp_[red] = id;
} }
if(auto *recoalasce = dynamic_cast<ir::recoalesce_inst*>(i)){ if(auto *recoalasce = dynamic_cast<ir::recoalesce_inst*>(i)){
ir::value *val = recoalasce->get_operand(0); ir::value *val = recoalasce->get_operand(0);
layout_t* in_layout = get(val); mma884_layout* in_layout = get(val)->to_mma884();
layout_t* out_layout = get(i); scanline_layout* out_layout = get(i)->to_scanline();
if(in_layout->type != HMMA_884) if(!in_layout || !out_layout)
return; return;
id++; id++;
ir::type::tile_shapes_t in_shape = val->get_type()->get_tile_shapes(); ir::type::tile_shapes_t in_shape = val->get_type()->get_tile_shapes();
ir::type::tile_shapes_t shape(in_shape.size()); ir::type::tile_shapes_t shape(in_shape.size());
size_t ld = out_layout->order[0]; size_t ld = out_layout->get_order(0);
shape[ld] = in_shape[ld]; shape[ld] = in_shape[ld];
for(size_t k = 0; k < in_shape.size(); k++) for(size_t k = 0; k < in_shape.size(); k++)
if(k != ld) if(k != ld)
shape[k] = 4*in_layout->to_hmma884()->fpw[k]*in_layout->to_hmma884()->wpt[k]; shape[k] = 4*in_layout->to_mma884()->fpw(k)*in_layout->to_mma884()->wpt(k);
// create layout // create layout
layouts_[id] = new layout_shared_t(out_layout, axes_->get(val), shape, {recoalasce}, val->get_type()->get_scalar_ty(), align_); layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {recoalasce}, val->get_type()->get_scalar_ty(), align_);
tmp_[recoalasce] = id; tmp_[recoalasce] = id;
} }
if(auto *atom = dynamic_cast<ir::atomic_cas_inst*>(i)){ if(auto *atom = dynamic_cast<ir::atomic_cas_inst*>(i)){
id++; id++;
layouts_[id] = new layout_shared_t(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), align_); layouts_[id] = new shared_layout(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), align_);
tmp_[atom] = id; tmp_[atom] = id;
} }
}); });

View File

@@ -27,18 +27,18 @@ void liveness::run(ir::module &mod) {
// create live intervals // create live intervals
for(auto &x: layouts_->get_all()) { for(auto &x: layouts_->get_all()) {
if(x.second->type != SHARED) shared_layout* layout = x.second->to_shared();
if(!layout)
continue; continue;
layout_shared_t* layout = x.second->to_shared();
// users // users
std::set<ir::user*> users; std::set<ir::user*> users;
for(ir::value *v: layout->values){ for(ir::value *v: layout->get_values()){
for(ir::user *u: v->get_users()) for(ir::user *u: v->get_users())
users.insert(u); users.insert(u);
} }
// compute intervals // compute intervals
unsigned start = INT32_MAX; unsigned start = INT32_MAX;
for(ir::value *v: layout->values) for(ir::value *v: layout->get_values())
if(indices.find(v) != indices.end()) if(indices.find(v) != indices.end())
start = std::min(start, indices.at(v)); start = std::min(start, indices.at(v));
unsigned end = 0; unsigned end = 0;

View File

@@ -174,7 +174,7 @@ inline bool is_trans(ir::value *v) {
generator::generator(analysis::axes *a_axes, generator::generator(analysis::axes *a_axes,
analysis::layout *layouts, analysis::layouts *layouts,
analysis::align *alignment, analysis::align *alignment,
analysis::allocation *alloc, analysis::allocation *alloc,
target *tgt, target *tgt,
@@ -295,7 +295,7 @@ void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
} }
// find vector size // find vector size
ir::value *ptr = x->get_pointer_operand(); ir::value *ptr = x->get_pointer_operand();
size_t ld = layouts_->get(ptr)->order[0]; size_t ld = layouts_->get(ptr)->get_order(0);
unsigned alignment = std::max<int>(alignment_->get(ptr, ld), 1); unsigned alignment = std::max<int>(alignment_->get(ptr, ld), 1);
@@ -337,7 +337,7 @@ void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
void generator::visit_masked_load_inst(ir::masked_load_inst* x) { void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
// find vector size // find vector size
ir::value *ptr = x->get_pointer_operand(); ir::value *ptr = x->get_pointer_operand();
size_t ld = layouts_->get(ptr)->order[0]; size_t ld = layouts_->get(ptr)->get_order(0);
unsigned alignment = alignment_->get(ptr, ld); unsigned alignment = alignment_->get(ptr, ld);
distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr); distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);
distributed_tile *masks = (distributed_tile*)tmap_.at(x->get_mask_operand()); distributed_tile *masks = (distributed_tile*)tmap_.at(x->get_mask_operand());
@@ -603,7 +603,7 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst*) {
void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK) { void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK) {
const auto& shapes = dot->get_type()->get_tile_shapes(); const auto& shapes = dot->get_type()->get_tile_shapes();
machine_layout_hmma_884_t* hmma = (machine_layout_hmma_884_t*)machine_layouts_.at(layouts_->get(dot)); machine_mma884_layout* hmma = (machine_mma884_layout*)machine_layouts_.at(layouts_->get(dot));
TA->set_vector_size(4*hmma->pack_size_0_); TA->set_vector_size(4*hmma->pack_size_0_);
TB->set_vector_size(4*hmma->pack_size_1_); TB->set_vector_size(4*hmma->pack_size_1_);
TA->set_return_mode(true); TA->set_return_mode(true);
@@ -625,8 +625,8 @@ void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *
Value* u_thread_id = tgt_->get_local_id(builder_->GetInsertBlock()->getModule(), *builder_, 0); Value* u_thread_id = tgt_->get_local_id(builder_->GetInsertBlock()->getModule(), *builder_, 0);
auto ord_a = layouts_->get(dot->get_operand(0))->order; auto ord_a = layouts_->get(dot->get_operand(0))->get_order();
auto ord_b = layouts_->get(dot->get_operand(1))->order; auto ord_b = layouts_->get(dot->get_operand(1))->get_order();
bool is_a_trans = is_trans(dot->get_operand(0)); bool is_a_trans = is_trans(dot->get_operand(0));
bool is_b_trans = is_trans(dot->get_operand(1)); bool is_b_trans = is_trans(dot->get_operand(1));
@@ -655,14 +655,14 @@ void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *
"{$8, $9}, " "{$8, $9}, "
"{$10, $11}, " "{$10, $11}, "
"{$0, $1, $2, $3, $4, $5, $6, $7};", "=f,=f,=f,=f,=f,=f,=f,=f,r,r,r,r,0,1,2,3,4,5,6,7", false); "{$0, $1, $2, $3, $4, $5, $6, $7};", "=f,=f,=f,=f,=f,=f,=f,=f,r,r,r,r,0,1,2,3,4,5,6,7", false);
analysis::layout_hmma_884_t* layout = layouts_->get(dot)->to_hmma884(); analysis::mma884_layout* layout = layouts_->get(dot)->to_mma884();
unsigned fpw_0 = layout->fpw.at(0); unsigned fpw_0 = layout->fpw(0);
unsigned fpw_1 = layout->fpw.at(1); unsigned fpw_1 = layout->fpw(1);
unsigned wts_0 = fpw_0 * 8; unsigned wts_0 = fpw_0 * 8;
unsigned wts_1 = fpw_1 * 8; unsigned wts_1 = fpw_1 * 8;
unsigned wpt_0 = layout->wpt.at(0); unsigned wpt_0 = layout->wpt(0);
unsigned wpt_1 = layout->wpt.at(1); unsigned wpt_1 = layout->wpt(1);
unsigned stride_rep_i = wpt_0 * wts_0; unsigned stride_rep_i = wpt_0 * wts_0;
unsigned stride_rep_j = wpt_1 * wts_1; unsigned stride_rep_j = wpt_1 * wts_1;
unsigned num_rep_i = shapes[0] / stride_rep_i; unsigned num_rep_i = shapes[0] / stride_rep_i;
@@ -792,7 +792,7 @@ void generator::visit_dot_inst(ir::dot_inst* dot) {
if(NK != 1) { if(NK != 1) {
shared_tile *TA = (shared_tile*)tmap_.at(A); shared_tile *TA = (shared_tile*)tmap_.at(A);
shared_tile *TB = (shared_tile*)tmap_.at(B); shared_tile *TB = (shared_tile*)tmap_.at(B);
if(layouts_->get(dot)->type == analysis::HMMA_884) if(layouts_->get(dot)->to_mma884())
visit_hmma_dot(dot, TA, TB, TD, NK); visit_hmma_dot(dot, TA, TB, TD, NK);
else else
visit_scanline_dot(dot, TA, TB, TD, NK, c_ty, f_mul_add); visit_scanline_dot(dot, TA, TB, TD, NK, c_ty, f_mul_add);
@@ -856,7 +856,7 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
}); });
// reduce within blocks // reduce within blocks
machine_layout_t *slayout = machine_layouts_.at(layouts_->get(layouts_->tmp(x))); machine_data_layout *slayout = machine_layouts_.at(layouts_->get(layouts_->tmp(x)));
shared_tile *stile = (shared_tile*)slayout->create(x); shared_tile *stile = (shared_tile*)slayout->create(x);
unsigned depth = stile->get_shapes()[axis]; unsigned depth = stile->get_shapes()[axis];
@@ -926,31 +926,31 @@ void generator::visit_recoalesce_inst(ir::recoalesce_inst* rc) {
// pointer to temporary shared memory // pointer to temporary shared memory
Type *ty = llvm_type(rc->get_type()->get_scalar_ty(), *ctx_); Type *ty = llvm_type(rc->get_type()->get_scalar_ty(), *ctx_);
// layouts // layouts
analysis::layout_hmma_884_t* in_layout = layouts_->get(op)->to_hmma884(); analysis::mma884_layout* in_layout = layouts_->get(op)->to_mma884();
analysis::layout_scanline_t* out_layout = layouts_->get(rc)->to_scanline(); analysis::scanline_layout* out_layout = layouts_->get(rc)->to_scanline();
// machine tiles // machine tiles
distributed_tile *in_dt = (distributed_tile*)(tmap_.at(op)); distributed_tile *in_dt = (distributed_tile*)(tmap_.at(op));
distributed_tile *out_dt = (distributed_tile*)(tmap_.at(rc)); distributed_tile *out_dt = (distributed_tile*)(tmap_.at(rc));
// WMMA configuration // WMMA configuration
long wmma_pt[3] = { 2, 4, 1}; long wmma_pt[3] = { 2, 4, 1};
long wmma[3] = { 8*in_layout->wpt[0]*in_layout->fpw[0], long wmma[3] = { 8*in_layout->wpt(0)*in_layout->fpw(0),
8*in_layout->wpt[1]*in_layout->fpw[1], 8*in_layout->wpt(1)*in_layout->fpw(1),
1}; 1};
// Work per thread for input layout // Work per thread for input layout
long in_pt[3] = { shape[0] / wmma[0], long in_pt[3] = { shape[0] / wmma[0],
shape[1] / wmma[1], shape[1] / wmma[1],
1 }; 1 };
// Work per thread for output layout // Work per thread for output layout
long out_pt[3] = { shape[0] / out_layout->mts[0], long out_pt[3] = { shape[0] / out_layout->mts(0),
shape[1] / out_layout->mts[1], shape[1] / out_layout->mts(1),
1 }; 1 };
if(rank > 2){ if(rank > 2){
wmma[2] = in_layout->wpt[2]*in_layout->fpw[2]; wmma[2] = in_layout->wpt(2)*in_layout->fpw(2);
in_pt[2] = shape[2] / wmma[2]; in_pt[2] = shape[2] / wmma[2];
out_pt[2] = shape[2] / out_layout->mts[2]; out_pt[2] = shape[2] / out_layout->mts(2);
} }
// Orders // Orders
auto ord = out_layout->order; auto ord = out_layout->get_order();
if(ord.size() < 3) if(ord.size() < 3)
ord.push_back(2); ord.push_back(2);
// pointer lanes // pointer lanes
@@ -1028,13 +1028,13 @@ void generator::visit_recoalesce_inst(ir::recoalesce_inst* rc) {
void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) { void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
unsigned vector_size = 1; unsigned vector_size = 1;
ir::value *arg = cts->get_operand(0); ir::value *arg = cts->get_operand(0);
analysis::layout_shared_t* out_layout = layouts_->get(cts)->to_shared(); analysis::shared_layout* out_layout = layouts_->get(cts)->to_shared();
analysis::layout_scanline_t* in_layout = layouts_->get(arg)->to_scanline(); analysis::scanline_layout* in_layout = layouts_->get(arg)->to_scanline();
auto out_order = out_layout->order; auto out_order = out_layout->get_order();
auto in_order = in_layout->order; auto in_order = in_layout->get_order();
// tiles // tiles
if(out_order == in_order) if(out_order == in_order)
vector_size = in_layout->nts.at(in_order[0]); vector_size = in_layout->nts(in_order[0]);
std::map<unsigned, Value*> packets; std::map<unsigned, Value*> packets;
for_each(arg, [&](indices_t idx){ for_each(arg, [&](indices_t idx){
@@ -1180,17 +1180,17 @@ void generator::visit_function(ir::function* fn) {
void generator::visit_layout_hmma_884(analysis::layout_hmma_884_t* layout) { void generator::visit_layout_hmma_884(analysis::mma884_layout* layout) {
machine_layouts_[layout] = new machine_layout_hmma_884_t(mod_, &*builder_, tgt_, llvm_type(layout->ty->get_scalar_ty(), *ctx_), a_axes_, axes_, layout); machine_layouts_[layout] = new machine_mma884_layout(mod_, &*builder_, tgt_, a_axes_, axes_, layout);
} }
void generator::visit_layout_scanline(analysis::layout_scanline_t* layout) { void generator::visit_layout_scanline(analysis::scanline_layout* layout) {
machine_layouts_[layout] = new machine_layout_scanline_t(mod_, &*builder_, tgt_, llvm_type(layout->ty->get_scalar_ty(), *ctx_), a_axes_, axes_, layout); machine_layouts_[layout] = new machine_scanline_layout(mod_, &*builder_, tgt_, a_axes_, axes_, layout);
} }
void generator::visit_layout_shared(analysis::layout_shared_t* layout) { void generator::visit_layout_shared(analysis::shared_layout* layout) {
machine_layouts_[layout] = new machine_layout_shared_t(mod_, &*builder_, tgt_, alloc_, sh_mem_ptr_, layout, vmap_, tmap_); machine_layouts_[layout] = new machine_shared_layout(mod_, &*builder_, tgt_, alloc_, sh_mem_ptr_, layout, vmap_, tmap_);
} }
void generator::visit_basic_block(ir::basic_block * block) { void generator::visit_basic_block(ir::basic_block * block) {
@@ -1230,9 +1230,9 @@ void generator::set_value(ir::value *x, const indices_t& idx, Value* v) {
} }
void generator::finalize_shared_layout(analysis::layout_shared_t *shared) { void generator::finalize_shared_layout(analysis::shared_layout *shared) {
if(shared->double_buffer) { if(shared->get_double_buffer()) {
auto info = *shared->double_buffer; auto info = *shared->get_double_buffer();
ir::phi_node *phi = info.phi; ir::phi_node *phi = info.phi;
PHINode *ptr = (PHINode*)((shared_tile*)tmap_.at(phi))->get_pointer(); PHINode *ptr = (PHINode*)((shared_tile*)tmap_.at(phi))->get_pointer();
PHINode *offset = (PHINode*)((shared_tile*)tmap_.at(phi))->get_offset(); PHINode *offset = (PHINode*)((shared_tile*)tmap_.at(phi))->get_offset();
@@ -1247,8 +1247,8 @@ void generator::finalize_shared_layout(analysis::layout_shared_t *shared) {
offset->addIncoming(next_offset, llvm_inc_block); offset->addIncoming(next_offset, llvm_inc_block);
} }
else { else {
unsigned num_bytes = shared->ty->get_primitive_size_in_bits() / 8; unsigned num_bytes = shared->get_type()->get_primitive_size_in_bits() / 8;
offset->addIncoming(builder_->getInt32(shared->size / (2*num_bytes)), llvm_inc_block); offset->addIncoming(builder_->getInt32(shared->get_size() / (2*num_bytes)), llvm_inc_block);
} }
ptr->addIncoming(inc_shared->get_pointer(), llvm_inc_block); ptr->addIncoming(inc_shared->get_pointer(), llvm_inc_block);
} }
@@ -1258,7 +1258,7 @@ void generator::finalize_shared_layout(analysis::layout_shared_t *shared) {
void generator::finalize_function(ir::function *fn) { void generator::finalize_function(ir::function *fn) {
// finalize double-buffering // finalize double-buffering
for(const auto& x: layouts_->get_all()) for(const auto& x: layouts_->get_all())
if(auto *shared = dynamic_cast<analysis::layout_shared_t*>(x.second)) if(auto *shared = dynamic_cast<analysis::shared_layout*>(x.second))
finalize_shared_layout(shared); finalize_shared_layout(shared);
// finalize phi // finalize phi
for(ir::basic_block *block: fn->blocks()) for(ir::basic_block *block: fn->blocks())

View File

@@ -71,18 +71,18 @@ inline int32_t ceil(int32_t num, int32_t div){
machine_layout_shared_t::machine_layout_shared_t(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc, machine_shared_layout::machine_shared_layout(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc,
Value *&sh_mem_ptr, analysis::layout_shared_t *layout, Value *&sh_mem_ptr, analysis::shared_layout *layout,
std::map<ir::value *, Value *>& vmap, std::map<ir::value *, Value *>& vmap,
std::map<ir::value *, tile *>& tmap) std::map<ir::value *, tile *>& tmap)
: mod_(mod), builder_(builder), tgt_(tgt), alloc_(alloc), sh_mem_ptr_(sh_mem_ptr), layout_(layout), vmap_(vmap), tmap_(tmap) { : mod_(mod), builder_(builder), tgt_(tgt), alloc_(alloc), sh_mem_ptr_(sh_mem_ptr), layout_(layout), vmap_(vmap), tmap_(tmap) {
Type* ty = llvm_type(layout_->ty, builder_->getContext()); Type* ty = llvm_type(layout_->get_type(), builder_->getContext());
PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr_->getType()->getPointerAddressSpace()); PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr_->getType()->getPointerAddressSpace());
// double-buffered // double-buffered
if(layout_->double_buffer) { if(layout_->get_double_buffer()) {
BasicBlock *current = builder_->GetInsertBlock(); BasicBlock *current = builder_->GetInsertBlock();
auto info = *layout_->double_buffer; auto info = *layout_->get_double_buffer();
ir::phi_node *phi = info.phi; ir::phi_node *phi = info.phi;
BasicBlock *parent = (BasicBlock*)vmap_.at((ir::value*)(phi->get_parent())); BasicBlock *parent = (BasicBlock*)vmap_.at((ir::value*)(phi->get_parent()));
if(parent->empty()) if(parent->empty())
@@ -105,31 +105,31 @@ machine_layout_shared_t::machine_layout_shared_t(Module *mod, Builder *builder,
} }
tile* machine_layout_shared_t::create(ir::value *v) { tile* machine_shared_layout::create(ir::value *v) {
auto order = layout_->order; Type* ty = llvm_type(layout_->get_type(), builder_->getContext());
auto shapes = layout_->shapes; auto double_buffer = layout_->get_double_buffer();
Type* ty = llvm_type(layout_->ty, builder_->getContext()); // offset
// double-buffered Value *offset = nullptr;
if(layout_->double_buffer) { if(double_buffer && v == double_buffer->phi)
if(v == layout_->double_buffer->phi) offset = offset_;
return new shared_tile(ty, shapes, order, ptr_, *builder_, offset_); // base pointer
if(v == layout_->double_buffer->latch) Value *ptr = ptr_;
return new shared_tile(ty, shapes, order, next_ptr_, *builder_); if(double_buffer && v == double_buffer->latch)
return new shared_tile(ty, shapes, order, pre_ptr_, *builder_); ptr = next_ptr_;
} else if(double_buffer && v == double_buffer->first)
else { ptr = pre_ptr_;
return new shared_tile(ty, shapes, order, ptr_, *builder_); // create tile
} return new shared_tile(ty, layout_->get_shape(), layout_->get_order(), ptr, *builder_, offset);
} }
machine_layout_distributed_t::machine_layout_distributed_t(Module *mod, Builder *builder, target *tgt, Type *ty, machine_distributed_layout::machine_distributed_layout(Module *mod, Builder *builder, target *tgt,
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes, analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
analysis::layout_t *layout) analysis::data_layout *layout)
: mod_(mod), builder_(builder), tgt_(tgt), ty_(ty), a_axes_(a_axes), axes_(axes), layout_(layout) { : mod_(mod), builder_(builder), tgt_(tgt), a_axes_(a_axes), axes_(axes), layout_(layout) {
} }
tile *machine_layout_distributed_t::create(ir::value *v) { tile *machine_distributed_layout::create(ir::value *v) {
Type *ty = llvm_type(v->get_type()->get_scalar_ty(), builder_->getContext()); Type *ty = llvm_type(v->get_type()->get_scalar_ty(), builder_->getContext());
const auto &shapes = v->get_type()->get_tile_shapes(); const auto &shapes = v->get_type()->get_tile_shapes();
size_t rank = shapes.size(); size_t rank = shapes.size();
@@ -151,12 +151,10 @@ tile *machine_layout_distributed_t::create(ir::value *v) {
auto cmp = [&](int x, int y) { auto cmp = [&](int x, int y) {
unsigned axx = a_axes_->get(v, x); unsigned axx = a_axes_->get(v, x);
unsigned axy = a_axes_->get(v, y); unsigned axy = a_axes_->get(v, y);
auto itx = std::find(layout_->axes.begin(), layout_->axes.end(), axx); size_t posx = layout_->find_axis(axx);
auto ity = std::find(layout_->axes.begin(), layout_->axes.end(), axy); size_t posy = layout_->find_axis(axy);
size_t posx = std::distance(layout_->axes.begin(), itx);
size_t posy = std::distance(layout_->axes.begin(), ity);
if(posx < rank && posy < rank) if(posx < rank && posy < rank)
return layout_->order[posx] < layout_->order[posy]; return layout_->get_order(posx) < layout_->get_order(posy);
return false; return false;
}; };
std::sort(order.begin(), order.end(), cmp); std::sort(order.begin(), order.end(), cmp);
@@ -164,22 +162,21 @@ tile *machine_layout_distributed_t::create(ir::value *v) {
return new distributed_tile(ty, shapes, order, axes, *builder_); return new distributed_tile(ty, shapes, order, axes, *builder_);
} }
machine_layout_hmma_884_t::machine_layout_hmma_884_t(Module *mod, Builder *builder, machine_mma884_layout::machine_mma884_layout(Module *mod, Builder *builder,
target *tgt, Type *ty, analysis::axes *a_axes, target *tgt, analysis::axes *a_axes,
std::map<unsigned, distributed_axis>& axes, std::map<unsigned, distributed_axis>& axes,
analysis::layout_hmma_884_t* layout) analysis::mma884_layout* layout)
: machine_layout_distributed_t(mod, builder, tgt, ty, a_axes, axes, layout) { : machine_distributed_layout(mod, builder, tgt, a_axes, axes, layout) {
Value *warp_size = builder_->getInt32(32); Value *warp_size = builder_->getInt32(32);
Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0); Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size); Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size);
Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size); Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size);
const auto& shapes = layout->shapes; const auto& shape = layout->get_shape();
if(shapes.size() > 3) if(shape.size() > 3)
throw std::runtime_error("unsupported"); throw std::runtime_error("unsupported");
bool is_batched = shape.size() >= 3;
bool is_batched = shapes.size() >= 3;
Value *_1 = builder_->getInt32(1); Value *_1 = builder_->getInt32(1);
Value *_2 = builder_->getInt32(2); Value *_2 = builder_->getInt32(2);
@@ -188,13 +185,13 @@ machine_layout_hmma_884_t::machine_layout_hmma_884_t(Module *mod, Builder *build
Value *_16 = builder_->getInt32(16); Value *_16 = builder_->getInt32(16);
// fragments per warp // fragments per warp
unsigned fpw_0 = layout->fpw.at(0); unsigned fpw_0 = layout->fpw(0);
unsigned fpw_1 = layout->fpw.at(1); unsigned fpw_1 = layout->fpw(1);
unsigned fpw_2 = is_batched ? layout->fpw.at(2) : 1; unsigned fpw_2 = is_batched ? layout->fpw(2) : 1;
// warps per tile // warps per tile
unsigned wpt_0 = layout->wpt.at(0); unsigned wpt_0 = layout->wpt(0);
unsigned wpt_1 = layout->wpt.at(1); unsigned wpt_1 = layout->wpt(1);
unsigned wpt_2 = is_batched ? layout->wpt.at(2) : 1; unsigned wpt_2 = is_batched ? layout->wpt(2) : 1;
// hmma warp tile size // hmma warp tile size
unsigned hmma_wts_0 = fpw_0 * 8; unsigned hmma_wts_0 = fpw_0 * 8;
unsigned hmma_wts_1 = fpw_1 * 8; unsigned hmma_wts_1 = fpw_1 * 8;
@@ -204,9 +201,9 @@ machine_layout_hmma_884_t::machine_layout_hmma_884_t(Module *mod, Builder *build
unsigned hmma_bts_1 = hmma_wts_1 * wpt_1; unsigned hmma_bts_1 = hmma_wts_1 * wpt_1;
unsigned hmma_bts_2 = is_batched ? hmma_wts_2 * wpt_2 : 1; unsigned hmma_bts_2 = is_batched ? hmma_wts_2 * wpt_2 : 1;
// number of repetition // number of repetition
unsigned num_rep_0 = shapes[0] / hmma_bts_0; unsigned num_rep_0 = shape[0] / hmma_bts_0;
unsigned num_rep_1 = shapes[1] / hmma_bts_1; unsigned num_rep_1 = shape[1] / hmma_bts_1;
unsigned num_rep_2 = is_batched ? shapes[2] / hmma_bts_2 : 1; unsigned num_rep_2 = is_batched ? shape[2] / hmma_bts_2 : 1;
// size of each pack (interleaving) // size of each pack (interleaving)
pack_size_0_ = std::min<unsigned>(num_rep_0, 1); pack_size_0_ = std::min<unsigned>(num_rep_0, 1);
pack_size_1_ = std::min<unsigned>(num_rep_1, 1); pack_size_1_ = std::min<unsigned>(num_rep_1, 1);
@@ -275,44 +272,52 @@ machine_layout_hmma_884_t::machine_layout_hmma_884_t(Module *mod, Builder *build
/* axes */ /* axes */
axes_[layout->axes[0]] = distributed_axis{1, idx_i, warp_id_0}; axes_[layout->get_axis(0)] = distributed_axis{1, idx_i, warp_id_0};
axes_[layout->axes[1]] = distributed_axis{1, idx_j, warp_id_1}; axes_[layout->get_axis(1)] = distributed_axis{1, idx_j, warp_id_1};
if(is_batched) if(is_batched)
axes_[layout->axes[2]] = distributed_axis{1, idx_z, warp_id_2}; axes_[layout->get_axis(2)] = distributed_axis{1, idx_z, warp_id_2};
} }
machine_layout_scanline_t::machine_layout_scanline_t(Module *mod, Builder *builder, machine_scanline_layout::machine_scanline_layout(Module *mod, Builder *builder,
target *tgt, Type *ty, target *tgt,
analysis::axes *a_axes, std::map<unsigned, distributed_axis> &axes, analysis::axes *a_axes, std::map<unsigned, distributed_axis> &axes,
analysis::layout_scanline_t* layout) analysis::scanline_layout* layout)
: machine_layout_distributed_t(mod, builder, tgt, ty, a_axes, axes, layout) { : machine_distributed_layout(mod, builder, tgt, a_axes, axes, layout) {
Value *warp_size = builder_->getInt32(32); Value *warp_size = builder_->getInt32(32);
Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0); Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size); Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size);
Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size); Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size);
auto order = layout->order; auto order = layout->get_order();
const auto& shapes = layout->shapes; const auto& shape = layout->get_shape();
size_t dim = shapes.size();
std::vector<int> nts = layout->nts;
std::vector<int> mts = layout->mts;
Value* full_thread_id = builder_->CreateAdd(builder_->CreateMul(u_warp_id, builder_->getInt32(32)), u_thread_id); Value* full_thread_id = builder_->CreateAdd(builder_->CreateMul(u_warp_id, builder_->getInt32(32)), u_thread_id);
std::vector<Value*> thread_id = delinearize(full_thread_id, order, mts, *builder_); // Delinearize
size_t dim = shape.size();
std::vector<Value*> thread_id(dim);
for(unsigned k = 0; k < dim - 1; k++){
Constant *dim_k = builder_->getInt32(layout->mts(order[k]));
Value *rem = builder_->CreateURem(full_thread_id, dim_k);
full_thread_id = builder_->CreateUDiv(full_thread_id, dim_k);
thread_id[order[k]] = rem;
}
thread_id[order[dim - 1]] = full_thread_id;
// Create axes // Create axes
for(unsigned k = 0; k < dim; k++) { for(unsigned k = 0; k < dim; k++) {
int nts = layout->nts(k);
int mts = layout->mts(k);
std::string str_k = std::to_string(k); std::string str_k = std::to_string(k);
Value *contiguous_k = builder_->getInt32(nts[k]); Value *contiguous_k = builder_->getInt32(nts);
Value *scaled_thread_id = builder_->CreateMul(thread_id[k], contiguous_k); Value *scaled_thread_id = builder_->CreateMul(thread_id[k], contiguous_k);
unsigned per_block = nts[k] * mts[k]; unsigned per_block = nts * mts;
unsigned per_thread = nts[k] * shapes[k] / per_block; unsigned per_thread = nts * shape[k] / per_block;
std::vector<Value*> idx_list(per_thread); std::vector<Value*> idx_list(per_thread);
for(unsigned n = 0 ; n < per_thread; n++){ for(unsigned n = 0 ; n < per_thread; n++){
unsigned offset = n / nts[k] * per_block + n % nts[k]; unsigned offset = n / nts * per_block + n % nts;
idx_list[n] = builder_->CreateAdd(scaled_thread_id, builder_->getInt32(offset), "idx_" + str_k + "_" + std::to_string(n)); idx_list[n] = builder_->CreateAdd(scaled_thread_id, builder_->getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
} }
axes_[layout->axes[k]] = distributed_axis{nts[k], idx_list, thread_id[k]}; axes_[layout->get_axis(k)] = distributed_axis{nts, idx_list, thread_id[k]};
} }
} }

View File

@@ -12,7 +12,7 @@ namespace triton {
namespace codegen{ namespace codegen{
namespace transform{ namespace transform{
coalesce::coalesce(analysis::align* align, analysis::layout *layouts) coalesce::coalesce(analysis::align* align, analysis::layouts *layouts)
: align_(align), layout_(layouts) { } : align_(align), layout_(layouts) { }
// Find all values that are used as pointer operands in LD/ST // Find all values that are used as pointer operands in LD/ST
@@ -64,8 +64,9 @@ ir::value* coalesce::rematerialize(ir::value *x, ir::builder &builder,
void coalesce::run(ir::module &mod) { void coalesce::run(ir::module &mod) {
size_t num_groups = layout_->num_layouts(); size_t num_groups = layout_->num_layouts();
for(size_t id = 0; id < num_groups; id++) { for(size_t id = 0; id < num_groups; id++) {
if(layout_->get(id)->type != analysis::HMMA_884) if(!layout_->get(id)->to_mma884())
continue; continue;
// extract memory stores // extract memory stores
const auto& values = layout_->values_of(id); const auto& values = layout_->values_of(id);
@@ -97,7 +98,6 @@ void coalesce::run(ir::module &mod) {
} }
} }
// find values to rematerialize // find values to rematerialize
std::vector<ir::io_inst*> remat; std::vector<ir::io_inst*> remat;
for(size_t id = 0; id < num_groups; id++) { for(size_t id = 0; id < num_groups; id++) {
@@ -109,7 +109,7 @@ void coalesce::run(ir::module &mod) {
// extract leading axes // extract leading axes
std::map<int, std::vector<ir::io_inst*>> axes; std::map<int, std::vector<ir::io_inst*>> axes;
for(ir::io_inst *i: io){ for(ir::io_inst *i: io){
if(i->get_pointer_operand()->get_type()->get_tile_ranks1() == layout_->get(id)->axes.size()) if(i->get_pointer_operand()->get_type()->get_tile_ranks1() == layout_->get(id)->get_rank())
extract_ld(i, axes); extract_ld(i, axes);
} }
// update list of values to rematerialize // update list of values to rematerialize

View File

@@ -35,10 +35,11 @@ void membar::add_reference(ir::value *v, interval_vec_t &res){
return; return;
if(!i->get_type()->is_tile_ty()) if(!i->get_type()->is_tile_ty())
return; return;
if(alloc_->has_offset(layouts_->get(v))){ analysis::shared_layout* layout = layouts_->get(v)->to_shared();
unsigned offset = alloc_->offset(layouts_->get(v)); assert(layout);
unsigned size = layouts_->get(v)->to_shared()->size; if(alloc_->has_offset(layout)){
res.push_back(interval_t(offset, offset + size)); unsigned offset = alloc_->offset(layout);
res.push_back(interval_t(offset, offset + layout->get_size()));
} }
} }
@@ -119,13 +120,11 @@ void membar::run(ir::module &mod) {
// without needing synchronization // without needing synchronization
std::set<ir::value*> safe_war; std::set<ir::value*> safe_war;
for(const auto& x: layouts_->get_all()){ for(const auto& x: layouts_->get_all()){
if(x.second->type != analysis::SHARED) analysis::shared_layout* layout = x.second->to_shared();
if(!layout || !layout->get_double_buffer())
continue; continue;
analysis::layout_shared_t* layout = x.second->to_shared(); for(ir::value *v: layout->get_values())
if(!layout->double_buffer) if(v != layout->get_double_buffer()->phi)
continue;
for(ir::value *v: layout->values)
if(v != layout->double_buffer->phi)
safe_war.insert(v); safe_war.insert(v);
} }

View File

@@ -220,7 +220,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
codegen::analysis::align align; codegen::analysis::align align;
codegen::analysis::axes axes; codegen::analysis::axes axes;
codegen::transform::disassociate disassociate; codegen::transform::disassociate disassociate;
codegen::analysis::layout layouts(&axes, &align, opt.num_warps); codegen::analysis::layouts layouts(&axes, &align, opt.num_warps);
codegen::analysis::liveness liveness(&layouts); codegen::analysis::liveness liveness(&layouts);
codegen::analysis::allocation allocation(&liveness); codegen::analysis::allocation allocation(&liveness);
codegen::transform::membar barriers(&liveness, &layouts, &allocation); codegen::transform::membar barriers(&liveness, &layouts, &allocation);
@@ -239,7 +239,6 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
align.run(module); align.run(module);
cts.run(module); cts.run(module);
axes.run(module); axes.run(module);
// ir::print(module, std::cout);
layouts.run(module); layouts.run(module);
coalesce.run(module); coalesce.run(module);
dce.run(module); dce.run(module);
@@ -250,15 +249,14 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
dce.run(module); dce.run(module);
align.run(module); align.run(module);
axes.run(module); axes.run(module);
// ir::print(module, std::cout);
layouts.run(module); layouts.run(module);
liveness.run(module); liveness.run(module);
allocation.run(module); allocation.run(module);
if(allocation.allocated_size() > context->device()->max_shared_memory()) if(allocation.allocated_size() > context->device()->max_shared_memory())
return std::unique_ptr<driver::module>(); return std::unique_ptr<driver::module>();
barriers.run(module); barriers.run(module);
// ir::print(module, std::cout);
isel.visit(module, *llvm); isel.visit(module, *llvm);
// return binary // return binary
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm))); std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
// done // done

View File

@@ -79,7 +79,7 @@ for N, T, H, S, E in NTHSE:
# 1D Dense convolution # 1D Dense convolution
NCHKR = [ NCHKR = [
# (1, 1152, 12602, 512, 3) (1, 1152, 12602, 512, 3)
] ]
for N, C, H, K, R in NCHKR: for N, C, H, K, R in NCHKR:
torch_fn = lambda a, b: torch.nn.functional.conv1d(a, b.permute(2, 0, 1)) torch_fn = lambda a, b: torch.nn.functional.conv1d(a, b.permute(2, 0, 1))
@@ -92,10 +92,10 @@ for N, C, H, K, R in NCHKR:
# 2D Dense convolution # 2D Dense convolution
NCHWKRS = [ NCHWKRS = [
#(8, 64, 128, 128, 768, 3, 3), (8, 64, 128, 128, 768, 3, 3),
#(8, 128, 64, 64, 256, 3, 3), (8, 128, 64, 64, 256, 3, 3),
#(8, 256, 32, 32, 512, 3, 3), (8, 256, 32, 32, 512, 3, 3),
#(8, 512, 32, 32, 1024, 3, 3) (8, 512, 32, 32, 1024, 3, 3)
] ]
for N, C, H, W, K, R, S in NCHWKRS: for N, C, H, W, K, R, S in NCHWKRS:
torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2)) torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2))
@@ -108,10 +108,10 @@ for N, C, H, W, K, R, S in NCHWKRS:
# 3D Dense Convolution # 3D Dense Convolution
NCDHWKTRS = [ NCDHWKTRS = [
#(8, 32, 27, 100, 100, 64, 3, 3, 3), (8, 32, 27, 100, 100, 64, 3, 3, 3),
#(8, 64, 23, 48, 48, 256, 3, 3, 3), (8, 64, 23, 48, 48, 256, 3, 3, 3),
#(8, 256, 19, 22, 22, 640, 3, 3, 3), (8, 256, 19, 22, 22, 640, 3, 3, 3),
#(8, 640, 15, 36, 36, 384, 3, 3, 3) (8, 640, 15, 36, 36, 384, 3, 3, 3)
] ]
for N, C, D, H, W, K, T, R, S in NCDHWKTRS: for N, C, D, H, W, K, T, R, S in NCDHWKTRS:
torch_fn = lambda a, b: torch.nn.functional.conv3d(a, b.permute(4, 0, 1, 2, 3)) torch_fn = lambda a, b: torch.nn.functional.conv3d(a, b.permute(4, 0, 1, 2, 3))
@@ -168,7 +168,7 @@ for N, C, H, W, K, R, S in NCHWKRS:
# Benchmark # Benchmark
torch.set_num_threads(1) torch.set_num_threads(1)
for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs: for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
dtype = torch.cuda.FloatTensor dtype = torch.cuda.HalfTensor
# initialize input tensors # initialize input tensors
a = torch.rand(*a_shape).type(dtype).cuda() a = torch.rand(*a_shape).type(dtype).cuda()
b = torch.rand(*b_shape).type(dtype).cuda() b = torch.rand(*b_shape).type(dtype).cuda()