[GENERAL] Cleaned polymorphic structure of layouts analysis pass
This commit is contained in:
@@ -27,14 +27,14 @@ public:
|
|||||||
allocation(liveness *live)
|
allocation(liveness *live)
|
||||||
: liveness_(live) { }
|
: liveness_(live) { }
|
||||||
// accessors
|
// accessors
|
||||||
bool has_offset(const layout_t *x) const { return offsets_.find(x) != offsets_.end(); }
|
bool has_offset(const data_layout *x) const { return offsets_.find(x) != offsets_.end(); }
|
||||||
unsigned offset(const layout_t *x) const { return offsets_.at(x); }
|
unsigned offset(const data_layout *x) const { return offsets_.at(x); }
|
||||||
unsigned allocated_size() const { return allocated_size_; }
|
unsigned allocated_size() const { return allocated_size_; }
|
||||||
// run
|
// run
|
||||||
void run(ir::module& mod);
|
void run(ir::module& mod);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::map<const layout_t*, unsigned> offsets_;
|
std::map<const data_layout*, unsigned> offsets_;
|
||||||
size_t allocated_size_;
|
size_t allocated_size_;
|
||||||
// dependences
|
// dependences
|
||||||
liveness *liveness_;
|
liveness *liveness_;
|
||||||
|
@@ -22,11 +22,106 @@ namespace analysis{
|
|||||||
|
|
||||||
class axes;
|
class axes;
|
||||||
class align;
|
class align;
|
||||||
|
class layout_visitor;
|
||||||
|
class data_layout;
|
||||||
|
class mma884_layout;
|
||||||
|
class scanline_layout;
|
||||||
|
class shared_layout;
|
||||||
|
|
||||||
enum layout_type_t {
|
|
||||||
HMMA_884,
|
class layout_visitor {
|
||||||
SCANLINE,
|
public:
|
||||||
SHARED
|
virtual void visit_layout(data_layout *);
|
||||||
|
virtual void visit_layout_hmma_884(mma884_layout*) = 0;
|
||||||
|
virtual void visit_layout_scanline(scanline_layout*) = 0;
|
||||||
|
virtual void visit_layout_shared(shared_layout*) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
class data_layout {
|
||||||
|
protected:
|
||||||
|
enum id_t {
|
||||||
|
HMMA_884,
|
||||||
|
SCANLINE,
|
||||||
|
SHARED
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef std::vector<int> axes_t;
|
||||||
|
typedef std::vector<unsigned> shape_t;
|
||||||
|
typedef std::vector<int> order_t;
|
||||||
|
typedef std::vector<ir::value*> values_t;
|
||||||
|
|
||||||
|
private:
|
||||||
|
template<typename T>
|
||||||
|
T* downcast(id_t id) {
|
||||||
|
if(id_ == id)
|
||||||
|
return static_cast<T*>(this);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
data_layout(id_t id,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
const std::vector<unsigned> &shape,
|
||||||
|
const std::vector<ir::value *> &values,
|
||||||
|
analysis::align* align);
|
||||||
|
// visitor
|
||||||
|
virtual void accept(layout_visitor* vst) = 0;
|
||||||
|
// downcast
|
||||||
|
mma884_layout* to_mma884() { return downcast<mma884_layout>(HMMA_884); }
|
||||||
|
scanline_layout* to_scanline() { return downcast<scanline_layout>(SCANLINE); }
|
||||||
|
shared_layout* to_shared() { return downcast<shared_layout>(SHARED); }
|
||||||
|
// accessors
|
||||||
|
size_t get_rank() { return shape_.size(); }
|
||||||
|
const shape_t& get_shape() const { return shape_; }
|
||||||
|
const order_t& get_order() const { return order_; }
|
||||||
|
const values_t& get_values() const { return values_;}
|
||||||
|
int get_axis(size_t k) const { return axes_.at(k); }
|
||||||
|
const int get_order(size_t k) const { return order_.at(k); }
|
||||||
|
// find the position of given axis
|
||||||
|
size_t find_axis(int to_find) const;
|
||||||
|
|
||||||
|
|
||||||
|
private:
|
||||||
|
id_t id_;
|
||||||
|
axes_t axes_;
|
||||||
|
values_t values_;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
order_t order_;
|
||||||
|
shape_t shape_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class mma884_layout: public data_layout {
|
||||||
|
public:
|
||||||
|
mma884_layout(size_t num_warps,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
const std::vector<unsigned>& shapes,
|
||||||
|
const std::vector<ir::value *> &values,
|
||||||
|
analysis::align* align);
|
||||||
|
void accept(layout_visitor* vst) { vst->visit_layout_hmma_884(this); }
|
||||||
|
// accessor
|
||||||
|
int fpw(size_t k) { return fpw_.at(k); }
|
||||||
|
int wpt(size_t k) { return wpt_.at(k); }
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::vector<int> fpw_;
|
||||||
|
std::vector<int> wpt_;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct scanline_layout: public data_layout {
|
||||||
|
scanline_layout(size_t num_warps,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
const std::vector<unsigned>& shape,
|
||||||
|
const std::vector<ir::value *> &values,
|
||||||
|
analysis::align* align);
|
||||||
|
void accept(layout_visitor* vst) { vst->visit_layout_scanline(this); }
|
||||||
|
// accessor
|
||||||
|
int mts(size_t k) { return mts_.at(k); }
|
||||||
|
int nts(size_t k) { return nts_.at(k); }
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::vector<int> mts_;
|
||||||
|
std::vector<int> nts_;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct double_buffer_info_t {
|
struct double_buffer_info_t {
|
||||||
@@ -35,90 +130,33 @@ struct double_buffer_info_t {
|
|||||||
ir::phi_node* phi;
|
ir::phi_node* phi;
|
||||||
};
|
};
|
||||||
|
|
||||||
class layout_visitor;
|
class shared_layout: public data_layout {
|
||||||
class layout_t;
|
private:
|
||||||
class layout_hmma_884_t;
|
static bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator);
|
||||||
class layout_scanline_t;
|
static void extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_t>& res);
|
||||||
class layout_shared_t;
|
|
||||||
|
|
||||||
|
|
||||||
class layout_visitor {
|
|
||||||
public:
|
public:
|
||||||
virtual void visit_layout(layout_t *);
|
shared_layout(const data_layout *arg,
|
||||||
virtual void visit_layout_hmma_884(layout_hmma_884_t*) = 0;
|
const std::vector<int>& axes,
|
||||||
virtual void visit_layout_scanline(layout_scanline_t*) = 0;
|
const std::vector<unsigned>& shapes,
|
||||||
virtual void visit_layout_shared(layout_shared_t*) = 0;
|
const std::vector<ir::value *> &values_,
|
||||||
};
|
ir::type *ty,
|
||||||
|
analysis::align* align);
|
||||||
class layout_hmma_884_t;
|
|
||||||
class layout_scanline_t;
|
|
||||||
class layout_shared_t;
|
|
||||||
|
|
||||||
struct layout_t {
|
|
||||||
layout_t(layout_type_t _type,
|
|
||||||
const std::vector<int>& _axes,
|
|
||||||
const std::vector<unsigned> &_shapes,
|
|
||||||
const std::vector<ir::value *> &_values,
|
|
||||||
ir::type *_ty,
|
|
||||||
analysis::align* align);
|
|
||||||
// visitor
|
|
||||||
virtual void accept(layout_visitor* vst) = 0;
|
|
||||||
// downcast
|
|
||||||
layout_hmma_884_t* to_hmma884();
|
|
||||||
layout_scanline_t* to_scanline();
|
|
||||||
layout_shared_t* to_shared();
|
|
||||||
|
|
||||||
|
|
||||||
layout_type_t type;
|
|
||||||
std::vector<int> axes;
|
|
||||||
std::vector<unsigned> shapes;
|
|
||||||
std::vector<ir::value*> values;
|
|
||||||
std::vector<int> order;
|
|
||||||
ir::type *ty;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct layout_hmma_884_t: public layout_t {
|
|
||||||
layout_hmma_884_t(size_t num_warps,
|
|
||||||
const std::vector<int>& _axes,
|
|
||||||
const std::vector<unsigned>& _shapes,
|
|
||||||
const std::vector<ir::value *> &_values,
|
|
||||||
ir::type *_ty,
|
|
||||||
analysis::align* align);
|
|
||||||
void accept(layout_visitor* vst) { vst->visit_layout_hmma_884(this); }
|
|
||||||
|
|
||||||
std::vector<int> fpw;
|
|
||||||
std::vector<int> wpt;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct layout_scanline_t: public layout_t {
|
|
||||||
layout_scanline_t(size_t num_warps,
|
|
||||||
const std::vector<int>& _axes,
|
|
||||||
const std::vector<unsigned>& _shapes,
|
|
||||||
const std::vector<ir::value *> &values,
|
|
||||||
ir::type *_ty,
|
|
||||||
analysis::align* align);
|
|
||||||
void accept(layout_visitor* vst) { vst->visit_layout_scanline(this); }
|
|
||||||
|
|
||||||
std::vector<int> mts;
|
|
||||||
std::vector<int> nts;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct layout_shared_t: public layout_t {
|
|
||||||
layout_shared_t(const layout_t *arg,
|
|
||||||
const std::vector<int>& _axes,
|
|
||||||
const std::vector<unsigned>& _shapes,
|
|
||||||
const std::vector<ir::value *> &values,
|
|
||||||
ir::type *ty,
|
|
||||||
analysis::align* align);
|
|
||||||
void accept(layout_visitor* vst) { vst->visit_layout_shared(this); }
|
void accept(layout_visitor* vst) { vst->visit_layout_shared(this); }
|
||||||
|
// accessors
|
||||||
|
size_t get_size() { return size_; }
|
||||||
|
ir::type* get_type() { return ty_; }
|
||||||
|
double_buffer_info_t* get_double_buffer() { return double_buffer_.get(); }
|
||||||
|
|
||||||
std::shared_ptr<double_buffer_info_t> double_buffer;
|
private:
|
||||||
size_t size;
|
size_t size_;
|
||||||
|
ir::type *ty_;
|
||||||
|
std::shared_ptr<double_buffer_info_t> double_buffer_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class layout {
|
class layouts {
|
||||||
typedef ir::value* node_t;
|
typedef ir::value* node_t;
|
||||||
typedef std::map <node_t, std::set<node_t>> graph_t;
|
typedef std::map <node_t, std::set<node_t>> graph_t;
|
||||||
|
|
||||||
@@ -127,23 +165,23 @@ private:
|
|||||||
void connect(ir::value *x, ir::value *y);
|
void connect(ir::value *x, ir::value *y);
|
||||||
void make_graph(ir::instruction *i);
|
void make_graph(ir::instruction *i);
|
||||||
|
|
||||||
void init_hmma_tile(layout_t& layout);
|
void init_hmma_tile(data_layout& layouts);
|
||||||
void init_scanline_tile(layout_t &layout);
|
void init_scanline_tile(data_layout &layouts);
|
||||||
|
|
||||||
void create(size_t id, const std::vector<ir::value*>& values);
|
void create(size_t id, const std::vector<ir::value*>& values);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
// constructor
|
// constructor
|
||||||
layout(analysis::axes *axes, analysis::align *align, size_t num_warps);
|
layouts(analysis::axes *axes, analysis::align *align, size_t num_warps);
|
||||||
|
|
||||||
// accessors
|
// accessors
|
||||||
unsigned layout_of(ir::value *value) const;
|
unsigned layout_of(ir::value *value) const { return groups_.at(value); }
|
||||||
const std::vector<ir::value*>& values_of(unsigned id) const;
|
const std::vector<ir::value*>& values_of(unsigned id) const { return values_.at(id); }
|
||||||
size_t num_layouts() const;
|
size_t num_layouts() const { return values_.size();}
|
||||||
layout_t* get(size_t id);
|
data_layout* get(size_t id) { return layouts_.at(id); }
|
||||||
layout_t* get(ir::value *v);
|
data_layout* get(ir::value *v) { return get(layout_of(v));}
|
||||||
std::map<size_t, layout_t*> &get_all();
|
std::map<size_t, data_layout*> &get_all() { return layouts_; }
|
||||||
size_t tmp(ir::instruction* i);
|
size_t tmp(ir::instruction* i) { return tmp_.at((ir::value*)i);}
|
||||||
|
|
||||||
// execution
|
// execution
|
||||||
void run(ir::module &mod);
|
void run(ir::module &mod);
|
||||||
@@ -155,7 +193,7 @@ private:
|
|||||||
tools::graph<ir::value*> graph_;
|
tools::graph<ir::value*> graph_;
|
||||||
std::map<ir::value*, size_t> groups_;
|
std::map<ir::value*, size_t> groups_;
|
||||||
std::map<size_t, std::vector<ir::value*>> values_;
|
std::map<size_t, std::vector<ir::value*>> values_;
|
||||||
std::map<size_t, layout_t*> layouts_;
|
std::map<size_t, data_layout*> layouts_;
|
||||||
std::map<ir::value*, size_t> tmp_;
|
std::map<ir::value*, size_t> tmp_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@@ -23,8 +23,8 @@ namespace analysis{
|
|||||||
typedef unsigned slot_index;
|
typedef unsigned slot_index;
|
||||||
|
|
||||||
class tiles;
|
class tiles;
|
||||||
class layout;
|
class layouts;
|
||||||
class layout_t;
|
class data_layout;
|
||||||
|
|
||||||
struct segment {
|
struct segment {
|
||||||
slot_index start;
|
slot_index start;
|
||||||
@@ -42,20 +42,20 @@ struct segment {
|
|||||||
|
|
||||||
class liveness {
|
class liveness {
|
||||||
private:
|
private:
|
||||||
typedef std::map<layout_shared_t*, segment> intervals_map_t;
|
typedef std::map<shared_layout*, segment> intervals_map_t;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
// constructor
|
// constructor
|
||||||
liveness(layout *l): layouts_(l){ }
|
liveness(layouts *l): layouts_(l){ }
|
||||||
// accessors
|
// accessors
|
||||||
const intervals_map_t& get() const { return intervals_; }
|
const intervals_map_t& get() const { return intervals_; }
|
||||||
segment get(layout_shared_t* v) const { return intervals_.at(v); }
|
segment get(shared_layout* v) const { return intervals_.at(v); }
|
||||||
// run
|
// run
|
||||||
void run(ir::module &mod);
|
void run(ir::module &mod);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// analysis
|
// analysis
|
||||||
layout *layouts_;
|
layouts *layouts_;
|
||||||
intervals_map_t intervals_;
|
intervals_map_t intervals_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@@ -35,7 +35,7 @@ class align;
|
|||||||
class allocation;
|
class allocation;
|
||||||
class cts;
|
class cts;
|
||||||
class axes;
|
class axes;
|
||||||
class layout;
|
class layouts;
|
||||||
}
|
}
|
||||||
// typedef
|
// typedef
|
||||||
typedef llvm::IRBuilder<llvm::ConstantFolder,
|
typedef llvm::IRBuilder<llvm::ConstantFolder,
|
||||||
@@ -50,7 +50,7 @@ typedef llvm::ArrayType ArrayType;
|
|||||||
typedef llvm::Function Function;
|
typedef llvm::Function Function;
|
||||||
typedef std::vector<Value*> indices_t;
|
typedef std::vector<Value*> indices_t;
|
||||||
// forward
|
// forward
|
||||||
class machine_layout_t;
|
class machine_data_layout;
|
||||||
class tile;
|
class tile;
|
||||||
class shared_tile;
|
class shared_tile;
|
||||||
class distributed_tile;
|
class distributed_tile;
|
||||||
@@ -74,13 +74,13 @@ private:
|
|||||||
void visit_outer_dot(ir::dot_inst*, distributed_tile *TA, distributed_tile *TB, distributed_tile *TD, unsigned NK,
|
void visit_outer_dot(ir::dot_inst*, distributed_tile *TA, distributed_tile *TB, distributed_tile *TD, unsigned NK,
|
||||||
Type *c_ty, Function *f_mul_add);
|
Type *c_ty, Function *f_mul_add);
|
||||||
|
|
||||||
void finalize_shared_layout(analysis::layout_shared_t*);
|
void finalize_shared_layout(analysis::shared_layout*);
|
||||||
void finalize_function(ir::function*);
|
void finalize_function(ir::function*);
|
||||||
void finalize_phi_node(ir::phi_node*);
|
void finalize_phi_node(ir::phi_node*);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
generator(analysis::axes *a_axes,
|
generator(analysis::axes *a_axes,
|
||||||
analysis::layout *layouts,
|
analysis::layouts *layouts,
|
||||||
analysis::align *alignment,
|
analysis::align *alignment,
|
||||||
analysis::allocation *alloc,
|
analysis::allocation *alloc,
|
||||||
target *tgt,
|
target *tgt,
|
||||||
@@ -139,9 +139,9 @@ public:
|
|||||||
void visit_basic_block(ir::basic_block*);
|
void visit_basic_block(ir::basic_block*);
|
||||||
void visit_argument(ir::argument*);
|
void visit_argument(ir::argument*);
|
||||||
|
|
||||||
void visit_layout_hmma_884(analysis::layout_hmma_884_t*);
|
void visit_layout_hmma_884(analysis::mma884_layout*);
|
||||||
void visit_layout_scanline(analysis::layout_scanline_t*);
|
void visit_layout_scanline(analysis::scanline_layout*);
|
||||||
void visit_layout_shared(analysis::layout_shared_t*);
|
void visit_layout_shared(analysis::shared_layout*);
|
||||||
|
|
||||||
void visit(ir::module &, llvm::Module &);
|
void visit(ir::module &, llvm::Module &);
|
||||||
|
|
||||||
@@ -150,13 +150,13 @@ private:
|
|||||||
Builder* builder_;
|
Builder* builder_;
|
||||||
Module *mod_;
|
Module *mod_;
|
||||||
|
|
||||||
std::map<const analysis::layout_t*, machine_layout_t*> machine_layouts_;
|
std::map<const analysis::data_layout*, machine_data_layout*> machine_layouts_;
|
||||||
analysis::axes *a_axes_;
|
analysis::axes *a_axes_;
|
||||||
std::map<unsigned, distributed_axis> axes_;
|
std::map<unsigned, distributed_axis> axes_;
|
||||||
std::map<ir::value *, Value *> vmap_;
|
std::map<ir::value *, Value *> vmap_;
|
||||||
std::map<ir::value *, tile *> tmap_;
|
std::map<ir::value *, tile *> tmap_;
|
||||||
target *tgt_;
|
target *tgt_;
|
||||||
analysis::layout *layouts_;
|
analysis::layouts *layouts_;
|
||||||
analysis::align *alignment_;
|
analysis::align *alignment_;
|
||||||
analysis::allocation *alloc_;
|
analysis::allocation *alloc_;
|
||||||
Value *sh_mem_ptr_;
|
Value *sh_mem_ptr_;
|
||||||
|
@@ -36,7 +36,7 @@ class align;
|
|||||||
class allocation;
|
class allocation;
|
||||||
class cts;
|
class cts;
|
||||||
class axes;
|
class axes;
|
||||||
class layout;
|
class layouts;
|
||||||
}
|
}
|
||||||
|
|
||||||
typedef llvm::IRBuilder<llvm::ConstantFolder,
|
typedef llvm::IRBuilder<llvm::ConstantFolder,
|
||||||
@@ -51,7 +51,7 @@ typedef llvm::ArrayType ArrayType;
|
|||||||
typedef llvm::Function Function;
|
typedef llvm::Function Function;
|
||||||
|
|
||||||
class distributed_axis;
|
class distributed_axis;
|
||||||
class machine_layout_t;
|
class machine_data_layout;
|
||||||
class tile;
|
class tile;
|
||||||
class shared_tile;
|
class shared_tile;
|
||||||
class distributed_tile;
|
class distributed_tile;
|
||||||
@@ -64,15 +64,15 @@ namespace triton{
|
|||||||
namespace codegen{
|
namespace codegen{
|
||||||
|
|
||||||
|
|
||||||
class machine_layout_t {
|
class machine_data_layout {
|
||||||
public:
|
public:
|
||||||
virtual tile* create(ir::value *v) = 0;
|
virtual tile* create(ir::value *v) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
class machine_layout_shared_t: public machine_layout_t {
|
class machine_shared_layout: public machine_data_layout {
|
||||||
public:
|
public:
|
||||||
machine_layout_shared_t(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc, Value *&sh_mem_ptr,
|
machine_shared_layout(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc, Value *&sh_mem_ptr,
|
||||||
analysis::layout_shared_t* layout,
|
analysis::shared_layout* layout,
|
||||||
std::map<ir::value *, Value *>& vmap,
|
std::map<ir::value *, Value *>& vmap,
|
||||||
std::map<ir::value *, tile *>& tmap);
|
std::map<ir::value *, tile *>& tmap);
|
||||||
|
|
||||||
@@ -83,7 +83,7 @@ public:
|
|||||||
target *tgt_;
|
target *tgt_;
|
||||||
analysis::allocation* alloc_;
|
analysis::allocation* alloc_;
|
||||||
Value *&sh_mem_ptr_;
|
Value *&sh_mem_ptr_;
|
||||||
analysis::layout_shared_t* layout_;
|
analysis::shared_layout* layout_;
|
||||||
std::map<ir::value *, Value *>& vmap_;
|
std::map<ir::value *, Value *>& vmap_;
|
||||||
std::map<ir::value *, tile *>& tmap_;
|
std::map<ir::value *, tile *>& tmap_;
|
||||||
|
|
||||||
@@ -94,29 +94,28 @@ public:
|
|||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
class machine_layout_distributed_t: public machine_layout_t {
|
class machine_distributed_layout: public machine_data_layout {
|
||||||
public:
|
public:
|
||||||
machine_layout_distributed_t(Module *mod, Builder *builder, target *tgt, Type *ty,
|
machine_distributed_layout(Module *mod, Builder *builder, target *tgt,
|
||||||
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
|
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
|
||||||
analysis::layout_t* layout);
|
analysis::data_layout* layout);
|
||||||
|
|
||||||
tile* create(ir::value *v);
|
tile* create(ir::value *v);
|
||||||
Module *mod_;
|
Module *mod_;
|
||||||
Builder *builder_;
|
Builder *builder_;
|
||||||
target *tgt_;
|
target *tgt_;
|
||||||
Type *ty_;
|
|
||||||
analysis::axes *a_axes_;
|
analysis::axes *a_axes_;
|
||||||
std::map<unsigned, distributed_axis>& axes_;
|
std::map<unsigned, distributed_axis>& axes_;
|
||||||
analysis::layout_t* layout_;
|
analysis::data_layout* layout_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
class machine_layout_hmma_884_t: public machine_layout_distributed_t {
|
class machine_mma884_layout: public machine_distributed_layout {
|
||||||
public:
|
public:
|
||||||
machine_layout_hmma_884_t(Module *mod, Builder *builder,
|
machine_mma884_layout(Module *mod, Builder *builder,
|
||||||
target *tgt, Type *ty,
|
target *tgt,
|
||||||
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
|
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
|
||||||
analysis::layout_hmma_884_t* layout);
|
analysis::mma884_layout* layout);
|
||||||
Value *offset_a_i_, *offset_a_k_;
|
Value *offset_a_i_, *offset_a_k_;
|
||||||
Value *offset_b_j_, *offset_b_k_;
|
Value *offset_b_j_, *offset_b_k_;
|
||||||
unsigned pack_size_0_;
|
unsigned pack_size_0_;
|
||||||
@@ -125,12 +124,12 @@ public:
|
|||||||
unsigned num_packs_1_;
|
unsigned num_packs_1_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class machine_layout_scanline_t: public machine_layout_distributed_t {
|
class machine_scanline_layout: public machine_distributed_layout {
|
||||||
public:
|
public:
|
||||||
machine_layout_scanline_t(Module *mod, Builder *builder,
|
machine_scanline_layout(Module *mod, Builder *builder,
|
||||||
target *tgt, Type *ty,
|
target *tgt,
|
||||||
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
|
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
|
||||||
analysis::layout_scanline_t* layout);
|
analysis::scanline_layout* layout);
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@@ -47,11 +47,11 @@ class align;
|
|||||||
class allocation;
|
class allocation;
|
||||||
class cts;
|
class cts;
|
||||||
class axes;
|
class axes;
|
||||||
class layout;
|
class layouts;
|
||||||
}
|
}
|
||||||
|
|
||||||
class distributed_axis;
|
class distributed_axis;
|
||||||
class machine_layout_t;
|
class machine_data_layout;
|
||||||
class tile;
|
class tile;
|
||||||
class shared_tile;
|
class shared_tile;
|
||||||
class distributed_tile;
|
class distributed_tile;
|
||||||
|
@@ -19,7 +19,7 @@ namespace codegen{
|
|||||||
|
|
||||||
namespace analysis{
|
namespace analysis{
|
||||||
class align;
|
class align;
|
||||||
class layout;
|
class layouts;
|
||||||
class cts;
|
class cts;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -32,12 +32,12 @@ private:
|
|||||||
ir::value* rematerialize(ir::value *v, ir::builder& builder, std::map<ir::value*, ir::value*>& seen);
|
ir::value* rematerialize(ir::value *v, ir::builder& builder, std::map<ir::value*, ir::value*>& seen);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
coalesce(analysis::align* align, triton::codegen::analysis::layout *layouts);
|
coalesce(analysis::align* align, triton::codegen::analysis::layouts *layouts);
|
||||||
void run(ir::module &mod);
|
void run(ir::module &mod);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
analysis::align* align_;
|
analysis::align* align_;
|
||||||
analysis::layout* layout_;
|
analysis::layouts* layout_;
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@@ -17,7 +17,7 @@ namespace analysis{
|
|||||||
|
|
||||||
class allocation;
|
class allocation;
|
||||||
class liveness;
|
class liveness;
|
||||||
class layout;
|
class layouts;
|
||||||
class cts;
|
class cts;
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -41,13 +41,13 @@ private:
|
|||||||
std::set<ir::instruction *> &insert_loc, std::set<triton::ir::value *> &safe_war);
|
std::set<ir::instruction *> &insert_loc, std::set<triton::ir::value *> &safe_war);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
membar(analysis::liveness *liveness, analysis::layout *layouts, analysis::allocation *alloc):
|
membar(analysis::liveness *liveness, analysis::layouts *layouts, analysis::allocation *alloc):
|
||||||
liveness_(liveness), layouts_(layouts), alloc_(alloc) {}
|
liveness_(liveness), layouts_(layouts), alloc_(alloc) {}
|
||||||
void run(ir::module &mod);
|
void run(ir::module &mod);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
analysis::liveness *liveness_;
|
analysis::liveness *liveness_;
|
||||||
analysis::layout *layouts_;
|
analysis::layouts *layouts_;
|
||||||
analysis::allocation *alloc_;
|
analysis::allocation *alloc_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@@ -15,22 +15,22 @@ void allocation::run(ir::module &mod) {
|
|||||||
using std::min;
|
using std::min;
|
||||||
typedef std::multimap<unsigned, segment> triples_map_type;
|
typedef std::multimap<unsigned, segment> triples_map_type;
|
||||||
|
|
||||||
std::vector<layout_shared_t*> I;
|
std::vector<shared_layout*> I;
|
||||||
for(auto x: liveness_->get())
|
for(auto x: liveness_->get())
|
||||||
I.push_back(x.first);
|
I.push_back(x.first);
|
||||||
std::vector<layout_shared_t*> J = I;
|
std::vector<shared_layout*> J = I;
|
||||||
|
|
||||||
triples_map_type H;
|
triples_map_type H;
|
||||||
H.insert({0, segment{0, INT_MAX}});
|
H.insert({0, segment{0, INT_MAX}});
|
||||||
|
|
||||||
std::vector<layout_shared_t*> V;
|
std::vector<shared_layout*> V;
|
||||||
std::map<layout_shared_t*, unsigned> starts;
|
std::map<shared_layout*, unsigned> starts;
|
||||||
while(!J.empty()){
|
while(!J.empty()){
|
||||||
auto h_it = H.begin();
|
auto h_it = H.begin();
|
||||||
unsigned w = h_it->first;
|
unsigned w = h_it->first;
|
||||||
segment xh = h_it->second;
|
segment xh = h_it->second;
|
||||||
H.erase(h_it);
|
H.erase(h_it);
|
||||||
auto j_it = std::find_if(J.begin(), J.end(), [&](layout_shared_t* JJ){
|
auto j_it = std::find_if(J.begin(), J.end(), [&](shared_layout* JJ){
|
||||||
segment xj = liveness_->get(JJ);
|
segment xj = liveness_->get(JJ);
|
||||||
bool res = xj.intersect(xh);
|
bool res = xj.intersect(xh);
|
||||||
for(auto val: H)
|
for(auto val: H)
|
||||||
@@ -38,7 +38,7 @@ void allocation::run(ir::module &mod) {
|
|||||||
return res;
|
return res;
|
||||||
});
|
});
|
||||||
if(j_it != J.end()){
|
if(j_it != J.end()){
|
||||||
unsigned size = (*j_it)->size;
|
unsigned size = (*j_it)->get_size();
|
||||||
segment xj = liveness_->get(*j_it);
|
segment xj = liveness_->get(*j_it);
|
||||||
starts[*j_it] = w;
|
starts[*j_it] = w;
|
||||||
H.insert({w + size, segment{max(xh.start, xj.start), min(xh.end, xj.end)}});
|
H.insert({w + size, segment{max(xh.start, xj.start), min(xh.end, xj.end)}});
|
||||||
@@ -52,14 +52,14 @@ void allocation::run(ir::module &mod) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Build interference graph
|
// Build interference graph
|
||||||
std::map<layout_shared_t*, std::set<layout_shared_t*>> interferences;
|
std::map<shared_layout*, std::set<shared_layout*>> interferences;
|
||||||
for(layout_shared_t* x: V)
|
for(shared_layout* x: V)
|
||||||
for(layout_shared_t* y: V){
|
for(shared_layout* y: V){
|
||||||
if(x == y)
|
if(x == y)
|
||||||
continue;
|
continue;
|
||||||
unsigned X0 = starts[x], Y0 = starts[y];
|
unsigned X0 = starts[x], Y0 = starts[y];
|
||||||
unsigned NX = x->size;
|
unsigned NX = x->get_size();
|
||||||
unsigned NY = y->size;
|
unsigned NY = y->get_size();
|
||||||
segment XS = {X0, X0 + NX};
|
segment XS = {X0, X0 + NX};
|
||||||
segment YS = {Y0, Y0 + NY};
|
segment YS = {Y0, Y0 + NY};
|
||||||
if(liveness_->get(x).intersect(liveness_->get(y))
|
if(liveness_->get(x).intersect(liveness_->get(y))
|
||||||
@@ -68,17 +68,17 @@ void allocation::run(ir::module &mod) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Initialize colors
|
// Initialize colors
|
||||||
std::map<layout_shared_t*, int> colors;
|
std::map<shared_layout*, int> colors;
|
||||||
for(layout_shared_t* X: V)
|
for(shared_layout* X: V)
|
||||||
colors[X] = (X==V[0])?0:-1;
|
colors[X] = (X==V[0])?0:-1;
|
||||||
|
|
||||||
|
|
||||||
// First-fit graph coloring
|
// First-fit graph coloring
|
||||||
std::vector<bool> available(V.size());
|
std::vector<bool> available(V.size());
|
||||||
for(layout_shared_t* x: V){
|
for(shared_layout* x: V){
|
||||||
// Non-neighboring colors are available
|
// Non-neighboring colors are available
|
||||||
std::fill(available.begin(), available.end(), true);
|
std::fill(available.begin(), available.end(), true);
|
||||||
for(layout_shared_t* Y: interferences[x]){
|
for(shared_layout* Y: interferences[x]){
|
||||||
int color = colors[Y];
|
int color = colors[Y];
|
||||||
if(color >= 0)
|
if(color >= 0)
|
||||||
available[color] = false;
|
available[color] = false;
|
||||||
@@ -89,17 +89,17 @@ void allocation::run(ir::module &mod) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Finalize allocation
|
// Finalize allocation
|
||||||
for(layout_shared_t* x: V){
|
for(shared_layout* x: V){
|
||||||
unsigned Adj = 0;
|
unsigned Adj = 0;
|
||||||
for(layout_shared_t* y: interferences[x])
|
for(shared_layout* y: interferences[x])
|
||||||
Adj = std::max<unsigned>(Adj, starts[y] + y->size);
|
Adj = std::max<unsigned>(Adj, starts[y] + y->get_size());
|
||||||
offsets_[x] = starts[x] + colors[x] * Adj;
|
offsets_[x] = starts[x] + colors[x] * Adj;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save maximum size of induced memory space
|
// Save maximum size of induced memory space
|
||||||
allocated_size_ = 0;
|
allocated_size_ = 0;
|
||||||
for(layout_shared_t* x: V)
|
for(shared_layout* x: V)
|
||||||
allocated_size_ = std::max<size_t>(allocated_size_, starts[x] + x->size);
|
allocated_size_ = std::max<size_t>(allocated_size_, starts[x] + x->get_size());
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@@ -12,57 +12,15 @@ namespace triton{
|
|||||||
namespace codegen{
|
namespace codegen{
|
||||||
namespace analysis{
|
namespace analysis{
|
||||||
|
|
||||||
|
/* -------------------------------- *
|
||||||
|
* Helper Functions *
|
||||||
|
* -------------------------------- */
|
||||||
|
|
||||||
// constructor
|
inline unsigned clamp(unsigned x, unsigned lo, unsigned hi) {
|
||||||
layout::layout(analysis::axes *axes, analysis::align *align, size_t num_warps)
|
return std::min(std::max(x, lo), hi);
|
||||||
: axes_(axes), align_(align), num_warps_(num_warps) { }
|
|
||||||
|
|
||||||
// get group id
|
|
||||||
unsigned layout::layout_of(ir::value *value) const
|
|
||||||
{ return groups_.at(value); }
|
|
||||||
|
|
||||||
// get values
|
|
||||||
const std::vector<ir::value*>& layout::values_of(unsigned id) const
|
|
||||||
{ return values_.at(id); }
|
|
||||||
|
|
||||||
// get number of groups
|
|
||||||
size_t layout::num_layouts() const
|
|
||||||
{ return values_.size(); }
|
|
||||||
|
|
||||||
// connect two values
|
|
||||||
void layout::connect(ir::value *x, ir::value *y) {
|
|
||||||
if(x == y)
|
|
||||||
return;
|
|
||||||
if(!x->get_type()->is_tile_ty())
|
|
||||||
return;
|
|
||||||
if(!y->get_type()->is_tile_ty())
|
|
||||||
return;
|
|
||||||
std::vector<int> x_axes = axes_->get(x);
|
|
||||||
std::vector<int> y_axes = axes_->get(y);
|
|
||||||
std::set<int> sx_axes(x_axes.begin(), x_axes.end());
|
|
||||||
std::set<int> sy_axes(y_axes.begin(), y_axes.end());
|
|
||||||
std::set<int> common;
|
|
||||||
std::set_intersection(sx_axes.begin(), sx_axes.end(),
|
|
||||||
sy_axes.begin(), sy_axes.end(),
|
|
||||||
std::inserter(common, common.begin()));
|
|
||||||
graph_.add_edge(x, x);
|
|
||||||
graph_.add_edge(y, y);
|
|
||||||
if(!common.empty())
|
|
||||||
graph_.add_edge(x, y);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// make graph
|
inline bool is_hmma_c(ir::value *v){
|
||||||
void layout::make_graph(ir::instruction *i) {
|
|
||||||
for(ir::value* opx: i->ops())
|
|
||||||
for(ir::value* opy: i->ops()){
|
|
||||||
connect(i, opx);
|
|
||||||
connect(opx, opy);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
// hmma
|
|
||||||
bool is_hmma_c(ir::value *v){
|
|
||||||
bool result = false;
|
bool result = false;
|
||||||
if(auto *x = dynamic_cast<ir::dot_inst*>(v)){
|
if(auto *x = dynamic_cast<ir::dot_inst*>(v)){
|
||||||
ir::value *a = x->get_operand(0);
|
ir::value *a = x->get_operand(0);
|
||||||
@@ -75,23 +33,7 @@ bool is_hmma_c(ir::value *v){
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
layout_t* layout::get(size_t id) {
|
inline void extract_io_use(ir::value *v, std::set<ir::value*>& result) {
|
||||||
return layouts_.at(id);
|
|
||||||
}
|
|
||||||
|
|
||||||
layout_t* layout::get(ir::value *v) {
|
|
||||||
return layouts_.at(groups_.at(v));
|
|
||||||
}
|
|
||||||
|
|
||||||
std::map<size_t, layout_t*>& layout::get_all() {
|
|
||||||
return layouts_;
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t layout::tmp(ir::instruction* i) {
|
|
||||||
return tmp_.at(i);
|
|
||||||
}
|
|
||||||
|
|
||||||
void extract_io_use(ir::value *v, std::set<ir::value*>& result) {
|
|
||||||
for(ir::user* u: v->get_users()){
|
for(ir::user* u: v->get_users()){
|
||||||
auto i = dynamic_cast<ir::io_inst*>(u);
|
auto i = dynamic_cast<ir::io_inst*>(u);
|
||||||
if(i && i->get_pointer_operand() == v)
|
if(i && i->get_pointer_operand() == v)
|
||||||
@@ -99,7 +41,7 @@ void extract_io_use(ir::value *v, std::set<ir::value*>& result) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void extract_dot_use(ir::value *v, ir::value*& result, size_t n) {
|
inline void extract_dot_use(ir::value *v, ir::value*& result, size_t n) {
|
||||||
for(ir::user* u: v->get_users()){
|
for(ir::user* u: v->get_users()){
|
||||||
auto i = dynamic_cast<ir::dot_inst*>(u);
|
auto i = dynamic_cast<ir::dot_inst*>(u);
|
||||||
if(i && i->get_operand(n) == v)
|
if(i && i->get_operand(n) == v)
|
||||||
@@ -107,7 +49,7 @@ void extract_dot_use(ir::value *v, ir::value*& result, size_t n) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n) {
|
inline void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n) {
|
||||||
for(ir::user* u: v->get_users()){
|
for(ir::user* u: v->get_users()){
|
||||||
auto i = dynamic_cast<ir::dot_inst*>(u);
|
auto i = dynamic_cast<ir::dot_inst*>(u);
|
||||||
if(i && is_hmma_c(i) && i->get_operand(n) == v)
|
if(i && is_hmma_c(i) && i->get_operand(n) == v)
|
||||||
@@ -116,7 +58,6 @@ void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline bool is_trans(ir::value *v) {
|
inline bool is_trans(ir::value *v) {
|
||||||
if(dynamic_cast<ir::trans_inst *>(v)) {
|
if(dynamic_cast<ir::trans_inst *>(v)) {
|
||||||
return true;
|
return true;
|
||||||
@@ -131,104 +72,103 @@ inline bool is_trans(ir::value *v) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
void layout_visitor::visit_layout(layout_t *layout) {
|
/* -------------------------------- *
|
||||||
|
* Layout Visitor *
|
||||||
|
* -------------------------------- */
|
||||||
|
|
||||||
|
void layout_visitor::visit_layout(data_layout *layout) {
|
||||||
layout->accept(this);
|
layout->accept(this);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
layout_t::layout_t(layout_type_t _type,
|
/* -------------------------------- *
|
||||||
const std::vector<int> &_axes,
|
* Base Data Layout *
|
||||||
const std::vector<unsigned> &_shapes,
|
* -------------------------------- */
|
||||||
const std::vector<ir::value *> &_values, ir::type *_ty,
|
|
||||||
analysis::align* align): type(_type), axes(_axes), shapes(_shapes), values(_values), ty(_ty) {
|
data_layout::data_layout(id_t id,
|
||||||
|
const std::vector<int> &axes,
|
||||||
|
const std::vector<unsigned> &shape,
|
||||||
|
const std::vector<ir::value *> &values,
|
||||||
|
analysis::align* align): id_(id), axes_(axes), shape_(shape), values_(values) {
|
||||||
// io pointer
|
// io pointer
|
||||||
std::set<ir::value*> ptr;
|
std::set<ir::value*> ptr;
|
||||||
for(ir::value* v: values)
|
for(ir::value* v: values_)
|
||||||
extract_io_use(v, ptr);
|
extract_io_use(v, ptr);
|
||||||
order.resize(axes.size());
|
order_.resize(axes_.size());
|
||||||
std::iota(order.begin(), order.end(), 0);
|
std::iota(order_.begin(), order_.end(), 0);
|
||||||
auto largest = std::max_element(ptr.begin(), ptr.end(), [&](ir::value *x, ir::value *y){
|
auto largest = std::max_element(ptr.begin(), ptr.end(), [&](ir::value *x, ir::value *y){
|
||||||
return x->get_type()->get_tile_rank() < y->get_type()->get_tile_rank();
|
return x->get_type()->get_tile_rank() < y->get_type()->get_tile_rank();
|
||||||
});
|
});
|
||||||
if(*largest){
|
if(*largest){
|
||||||
auto max_contiguous = align->contiguous(*largest);
|
auto max_contiguous = align->contiguous(*largest);
|
||||||
std::sort(order.begin(), order.end(), [&](unsigned a, unsigned b) {
|
std::sort(order_.begin(), order_.end(), [&](unsigned a, unsigned b) {
|
||||||
return max_contiguous[a] > max_contiguous[b];
|
return max_contiguous[a] > max_contiguous[b];
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// downcast
|
size_t data_layout::find_axis(int to_find) const {
|
||||||
layout_hmma_884_t* layout_t::to_hmma884() {
|
auto it = std::find(axes_.begin(), axes_.end(), to_find);
|
||||||
assert(type == HMMA_884);
|
return std::distance(axes_.begin(), it);
|
||||||
return static_cast<layout_hmma_884_t*>(this);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
layout_scanline_t* layout_t::to_scanline() {
|
|
||||||
assert(type == SCANLINE);
|
|
||||||
return static_cast<layout_scanline_t*>(this);
|
|
||||||
}
|
|
||||||
|
|
||||||
layout_shared_t* layout_t::to_shared() {
|
/* -------------------------------- *
|
||||||
assert(type == SHARED);
|
* MMA Layout *
|
||||||
return static_cast<layout_shared_t*>(this);
|
* -------------------------------- */
|
||||||
}
|
|
||||||
|
|
||||||
inline unsigned clamp(unsigned x, unsigned lo, unsigned hi) {
|
mma884_layout::mma884_layout(size_t num_warps,
|
||||||
return std::min(std::max(x, lo), hi);
|
const std::vector<int>& axes,
|
||||||
}
|
const std::vector<unsigned>& shape,
|
||||||
|
const std::vector<ir::value *> &values,
|
||||||
layout_hmma_884_t::layout_hmma_884_t(size_t num_warps,
|
analysis::align* align): data_layout(HMMA_884, axes, shape, values, align) {
|
||||||
const std::vector<int>& _axes,
|
|
||||||
const std::vector<unsigned>& _shapes,
|
|
||||||
const std::vector<ir::value *> &values, ir::type *_ty,
|
|
||||||
analysis::align* align): layout_t(HMMA_884, _axes, _shapes, values, _ty, align) {
|
|
||||||
unsigned shape_0 = shapes[0];
|
|
||||||
unsigned shape_1 = shapes[1];
|
|
||||||
/* fragments per warp */
|
/* fragments per warp */
|
||||||
// try to make things as square as possible to maximize data re-use
|
// try to make things as square as possible to maximize data re-use
|
||||||
fpw = {1, 1, 1};
|
fpw_ = {1, 1, 1};
|
||||||
std::vector<int> fpw_nm1;
|
std::vector<int> fpw_nm1;
|
||||||
unsigned num_fragments = std::min<unsigned>((shape_0/8)*(shape_1/8), 4);
|
unsigned num_fragments = std::min<unsigned>((shape_[0]/8)*(shape_[1]/8), 4);
|
||||||
do {
|
do {
|
||||||
fpw_nm1 = fpw;
|
fpw_nm1 = fpw_;
|
||||||
if(fpw[0]*fpw[1] < num_fragments)
|
if(fpw_[0]*fpw_[1] < num_fragments)
|
||||||
fpw[0] = clamp(fpw[0]*2, 1, shape_0 / 8);
|
fpw_[0] = clamp(fpw_[0]*2, 1, shape_[0] / 8);
|
||||||
if(fpw[0]*fpw[1] < num_fragments)
|
if(fpw_[0]*fpw_[1] < num_fragments)
|
||||||
fpw[1] = clamp(fpw[1]*2, 1, shape_1 / 8);
|
fpw_[1] = clamp(fpw_[1]*2, 1, shape_[1] / 8);
|
||||||
}while(fpw_nm1 != fpw);
|
}while(fpw_nm1 != fpw_);
|
||||||
|
|
||||||
/* warps per tile */
|
/* warps per tile */
|
||||||
// try to make things as square as possible to maximize data re-use
|
// try to make things as square as possible to maximize data re-use
|
||||||
wpt = {1, 1, 1};
|
wpt_ = {1, 1, 1};
|
||||||
std::vector<int> wpt_nm1;
|
std::vector<int> wpt_nm1;
|
||||||
do{
|
do{
|
||||||
wpt_nm1 = wpt;
|
wpt_nm1 = wpt_;
|
||||||
if(wpt[0] * wpt[1] * wpt[2] < num_warps)
|
if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
|
||||||
wpt[0] = clamp(wpt[0]*2, 1, shape_0 / (fpw[0]*8));
|
wpt_[0] = clamp(wpt_[0]*2, 1, shape_[0] / (fpw_[0]*8));
|
||||||
if(wpt[0] * wpt[1] * wpt[2] < num_warps)
|
if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
|
||||||
wpt[1] = clamp(wpt[1]*2, 1, shape_1 / (fpw[1]*8));
|
wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / (fpw_[1]*8));
|
||||||
}while(wpt_nm1 != wpt);
|
}while(wpt_nm1 != wpt_);
|
||||||
|
|
||||||
/* sanity check */
|
/* sanity check */
|
||||||
unsigned effective_num_warps = 1;
|
unsigned effective_num_warps = 1;
|
||||||
for(size_t d = 0; d < shapes.size(); d++)
|
for(size_t d = 0; d < shape.size(); d++)
|
||||||
effective_num_warps *= wpt[d];
|
effective_num_warps *= wpt_[d];
|
||||||
|
|
||||||
if(num_warps != effective_num_warps)
|
if(num_warps != effective_num_warps)
|
||||||
throw std::runtime_error("cannot create a kernel with this amount of warps");
|
throw std::runtime_error("cannot create a kernel with this amount of warps");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/* -------------------------------- *
|
||||||
|
* Scanline Layout *
|
||||||
|
* -------------------------------- */
|
||||||
|
|
||||||
|
scanline_layout::scanline_layout(size_t num_warps,
|
||||||
layout_scanline_t::layout_scanline_t(size_t num_warps,
|
const std::vector<int>& axes,
|
||||||
const std::vector<int>& _axes,
|
const std::vector<unsigned>& shape,
|
||||||
const std::vector<unsigned>& _shapes,
|
const std::vector<ir::value *> &values,
|
||||||
const std::vector<ir::value *> &values, ir::type *_ty,
|
analysis::align* align): data_layout(SCANLINE, axes, shape, values, align){
|
||||||
analysis::align* align): layout_t(SCANLINE, _axes, _shapes, values, _ty, align){
|
unsigned size = std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies<int>());
|
||||||
unsigned size = std::accumulate(shapes.begin(), shapes.end(), 1, std::multiplies<int>());
|
|
||||||
unsigned num_threads = num_warps * 32;
|
unsigned num_threads = num_warps * 32;
|
||||||
nts.resize(shapes.size());
|
nts_.resize(shape_.size());
|
||||||
mts.resize(shapes.size());
|
mts_.resize(shape_.size());
|
||||||
bool is_dot = std::any_of(values.begin(), values.end(),
|
bool is_dot = std::any_of(values.begin(), values.end(),
|
||||||
[&](ir::value* v) { return dynamic_cast<ir::dot_inst*>(v); });
|
[&](ir::value* v) { return dynamic_cast<ir::dot_inst*>(v); });
|
||||||
|
|
||||||
@@ -238,34 +178,39 @@ layout_scanline_t::layout_scanline_t(size_t num_warps,
|
|||||||
if(auto *st = dynamic_cast<ir::store_inst*>(usr))
|
if(auto *st = dynamic_cast<ir::store_inst*>(usr))
|
||||||
ptr = st->get_pointer_operand();
|
ptr = st->get_pointer_operand();
|
||||||
|
|
||||||
unsigned i = order[0];
|
unsigned i = order_[0];
|
||||||
int contiguous = 4;
|
int contiguous = 4;
|
||||||
if(ptr)
|
if(ptr)
|
||||||
contiguous = std::min<int>(align->contiguous(ptr)[i], 4);
|
contiguous = std::min<int>(align->contiguous(ptr)[i], 4);
|
||||||
|
|
||||||
nts[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shapes[i]));
|
nts_[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shape_[i]));
|
||||||
mts[i] = clamp(num_threads, 1, shapes[i] / nts[i]);
|
mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
|
||||||
size /= shapes[i];
|
size /= shape_[i];
|
||||||
num_threads /= mts[i];
|
num_threads /= mts_[i];
|
||||||
if(is_dot)
|
if(is_dot)
|
||||||
nts[order[1]] = clamp(size / num_threads, 1, std::min<int>(4, shapes[order[1]]));
|
nts_[order_[1]] = clamp(size / num_threads, 1, std::min<int>(4, shape_[order_[1]]));
|
||||||
for(size_t d = 1; d < shapes.size(); d++){
|
for(size_t d = 1; d < shape_.size(); d++){
|
||||||
i = order[d];
|
i = order_[d];
|
||||||
if(d > 1 || !is_dot)
|
if(d > 1 || !is_dot)
|
||||||
nts[i] = 1;
|
nts_[i] = 1;
|
||||||
mts[i] = clamp(num_threads, 1, shapes[i] / nts[i]);
|
mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
|
||||||
num_threads = num_threads / mts[i];
|
num_threads = num_threads / mts_[i];
|
||||||
}
|
}
|
||||||
/* sanity check */
|
/* sanity check */
|
||||||
unsigned effective_num_threads = 1;
|
unsigned effective_num_threads = 1;
|
||||||
for(size_t d = 0; d < shapes.size(); d++)
|
for(size_t d = 0; d < shape_.size(); d++)
|
||||||
effective_num_threads *= mts[d];
|
effective_num_threads *= mts_[d];
|
||||||
|
|
||||||
if(num_warps * 32 != effective_num_threads)
|
if(num_warps * 32 != effective_num_threads)
|
||||||
throw std::runtime_error("cannot create a kernel with this amount of warps");
|
throw std::runtime_error("cannot create a kernel with this amount of warps");
|
||||||
}
|
}
|
||||||
|
|
||||||
inline bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
|
|
||||||
|
/* -------------------------------- *
|
||||||
|
* Shared Layout *
|
||||||
|
* -------------------------------- */
|
||||||
|
|
||||||
|
bool shared_layout::is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
|
||||||
if(phi->get_parent() != terminator->get_parent())
|
if(phi->get_parent() != terminator->get_parent())
|
||||||
return false;
|
return false;
|
||||||
if(auto *br = dynamic_cast<ir::cond_branch_inst*>(terminator))
|
if(auto *br = dynamic_cast<ir::cond_branch_inst*>(terminator))
|
||||||
@@ -278,7 +223,7 @@ inline bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
void extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_t>& res) {
|
void shared_layout::extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_t>& res) {
|
||||||
auto* phi = dynamic_cast<ir::phi_node*>(v);
|
auto* phi = dynamic_cast<ir::phi_node*>(v);
|
||||||
if(!phi || phi->get_num_incoming() != 2)
|
if(!phi || phi->get_num_incoming() != 2)
|
||||||
return;
|
return;
|
||||||
@@ -303,22 +248,22 @@ void extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
layout_shared_t::layout_shared_t(const layout_t *arg,
|
shared_layout::shared_layout(const data_layout *arg,
|
||||||
const std::vector<int>& _axes,
|
const std::vector<int>& axes,
|
||||||
const std::vector<unsigned>& _shapes,
|
const std::vector<unsigned>& shape,
|
||||||
const std::vector<ir::value *> &values,
|
const std::vector<ir::value *> &values,
|
||||||
ir::type *ty,
|
ir::type *ty,
|
||||||
analysis::align* align): layout_t(SHARED, _axes, _shapes, values, ty, align) {
|
analysis::align* align): data_layout(SHARED, axes, shape, values, align), ty_(ty) {
|
||||||
|
|
||||||
size = 0;
|
size_ = 0;
|
||||||
|
|
||||||
// double-buffering
|
// double-buffering
|
||||||
for(ir::value *v: values)
|
for(ir::value *v: values)
|
||||||
extract_double_bufferable(v, double_buffer);
|
extract_double_bufferable(v, double_buffer_);
|
||||||
|
|
||||||
// order
|
// order
|
||||||
std::vector<int> arg_order = arg ? arg->order : std::vector<int>{0};
|
std::vector<int> arg_order = arg ? arg->get_order() : std::vector<int>{0};
|
||||||
order = arg_order;
|
order_ = arg_order;
|
||||||
|
|
||||||
ir::value* dot_a = nullptr;
|
ir::value* dot_a = nullptr;
|
||||||
ir::value* dot_b = nullptr;
|
ir::value* dot_b = nullptr;
|
||||||
@@ -330,48 +275,84 @@ layout_shared_t::layout_shared_t(const layout_t *arg,
|
|||||||
extract_hmma_dot_use(v, hmma_dot_a, 0);
|
extract_hmma_dot_use(v, hmma_dot_a, 0);
|
||||||
extract_hmma_dot_use(v, hmma_dot_b, 1);
|
extract_hmma_dot_use(v, hmma_dot_b, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// non-mma ordering
|
||||||
std::vector<int> col = {0, 1};
|
std::vector<int> col = {0, 1};
|
||||||
std::vector<int> row = {1, 0};
|
std::vector<int> row = {1, 0};
|
||||||
for(size_t s = 2; s < shapes.size(); s++){
|
for(size_t s = 2; s < get_rank(); s++){
|
||||||
col.push_back(s);
|
col.push_back(s);
|
||||||
row.push_back(s);
|
row.push_back(s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
bool is_nonhmma_dot_a = dot_a && !hmma_dot_a;
|
bool is_nonhmma_dot_a = dot_a && !hmma_dot_a;
|
||||||
bool is_nonhmma_dot_b = dot_b && !hmma_dot_b;
|
bool is_nonhmma_dot_b = dot_b && !hmma_dot_b;
|
||||||
if(is_nonhmma_dot_a)
|
if(is_nonhmma_dot_a)
|
||||||
order = is_trans(dot_a) ? row : col;
|
order_ = is_trans(dot_a) ? row : col;
|
||||||
else if(is_nonhmma_dot_b)
|
else if(is_nonhmma_dot_b)
|
||||||
order = is_trans(dot_b) ? col : row;
|
order_ = is_trans(dot_b) ? col : row;
|
||||||
// else
|
|
||||||
// order = row;
|
|
||||||
// padding
|
// padding
|
||||||
size_t pad = 0;
|
size_t pad = 0;
|
||||||
if(hmma_dot_a){
|
if(hmma_dot_a){
|
||||||
bool row = is_trans(hmma_dot_a) ^ order[0] != 0;
|
bool row = is_trans(hmma_dot_a) ^ order_[0] != 0;
|
||||||
pad = 24 - shapes[row ? 0 : 1] % 32;
|
pad = 24 - shape_[row ? 0 : 1] % 32;
|
||||||
}
|
}
|
||||||
else if(hmma_dot_b){
|
else if(hmma_dot_b){
|
||||||
bool row = is_trans(hmma_dot_b) ^ order[0] != 0;
|
bool row = is_trans(hmma_dot_b) ^ order_[0] != 0;
|
||||||
pad = 24 - shapes[row ? 1 : 0] % 32;
|
pad = 24 - shape_[row ? 1 : 0] % 32;
|
||||||
}
|
}
|
||||||
else if(order != arg_order) {
|
else if(order_ != arg_order) {
|
||||||
pad = 4;
|
pad = 4;
|
||||||
}
|
}
|
||||||
shapes[order[0]] += pad;
|
shape_[order_[0]] += pad;
|
||||||
|
|
||||||
// size
|
// size
|
||||||
size = ty->get_primitive_size_in_bits() / 8;
|
size_ = ty_->get_primitive_size_in_bits() / 8;
|
||||||
for(auto s: shapes)
|
for(auto s: shape_)
|
||||||
size *= s;
|
size_ *= s;
|
||||||
if(double_buffer)
|
if(double_buffer_)
|
||||||
size *= 2;
|
size_ *= 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// layout factory method
|
/* -------------------------------- *
|
||||||
void layout::create(size_t id, const std::vector<ir::value*>& values) {
|
* ---- Layouts Inference Pass ---- *
|
||||||
|
* -------------------------------- */
|
||||||
|
|
||||||
|
// Records the axes analysis, the alignment analysis and the number of warps
// used by the pass; the constructor body itself does no work.
layouts::layouts(analysis::axes *axes, analysis::align *align, size_t num_warps)
  : axes_(axes),
    align_(align),
    num_warps_(num_warps) {
}
|
||||||
|
|
||||||
|
|
||||||
|
// Registers `x` and `y` in the layout graph and links them when they have at
// least one axis in common.
//
// Nothing happens when both arguments are the same value or when either one
// is not of tile type. Otherwise each value receives a self-edge, and an
// x--y edge is added only if the axes reported by axes_ for the two values
// intersect.
void layouts::connect(ir::value *x, ir::value *y) {
  if(x == y || !x->get_type()->is_tile_ty() || !y->get_type()->is_tile_ty())
    return;
  const std::vector<int> x_axes = axes_->get(x);
  const std::vector<int> y_axes = axes_->get(y);
  // true iff some axis of x also appears among the axes of y
  const bool share_axis =
      std::find_first_of(x_axes.begin(), x_axes.end(),
                         y_axes.begin(), y_axes.end()) != x_axes.end();
  graph_.add_edge(x, x);
  graph_.add_edge(y, y);
  if(share_axis)
    graph_.add_edge(x, y);
}
|
||||||
|
|
||||||
|
// Feeds one instruction into the layout graph: for every ordered pair of
// its operands, connect() is asked to link the instruction with the first
// operand of the pair and then the two operands with each other.
void layouts::make_graph(ir::instruction *i) {
  for(ir::value *lhs : i->ops()) {
    for(ir::value *rhs : i->ops()) {
      connect(i, lhs);
      connect(lhs, rhs);
    }
  }
}
|
||||||
|
|
||||||
|
void layouts::create(size_t id, const std::vector<ir::value*>& values) {
|
||||||
auto it_hmma_c = std::find_if(values.begin(), values.end(), &is_hmma_c);
|
auto it_hmma_c = std::find_if(values.begin(), values.end(), &is_hmma_c);
|
||||||
auto cmp = [](ir::value* x, ir::value *y) {
|
auto cmp = [](ir::value* x, ir::value *y) {
|
||||||
return x->get_type()->get_tile_ranks1() <
|
return x->get_type()->get_tile_ranks1() <
|
||||||
@@ -387,18 +368,18 @@ void layout::create(size_t id, const std::vector<ir::value*>& values) {
|
|||||||
});
|
});
|
||||||
// type
|
// type
|
||||||
if(it_hmma_c != values.end())
|
if(it_hmma_c != values.end())
|
||||||
layouts_[id] = new layout_hmma_884_t(num_warps_, axes, shapes, values, largest->get_type()->get_scalar_ty(), align_);
|
layouts_[id] = new mma884_layout(num_warps_, axes, shapes, values, align_);
|
||||||
else if(it_cts != values.end()){
|
else if(it_cts != values.end()){
|
||||||
ir::copy_to_shared_inst *cts = (ir::copy_to_shared_inst*)*it_cts;
|
ir::copy_to_shared_inst *cts = (ir::copy_to_shared_inst*)*it_cts;
|
||||||
ir::value *arg = cts->get_operand(0);
|
ir::value *arg = cts->get_operand(0);
|
||||||
create(groups_.at(arg), values_.at(groups_.at(arg)));
|
create(groups_.at(arg), values_.at(groups_.at(arg)));
|
||||||
layouts_[id] = new layout_shared_t(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_);
|
layouts_[id] = new shared_layout(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_);
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
layouts_[id] = new layout_scanline_t(num_warps_, axes, shapes, values, largest->get_type()->get_scalar_ty(), align_);
|
layouts_[id] = new scanline_layout(num_warps_, axes, shapes, values, align_);
|
||||||
}
|
}
|
||||||
|
|
||||||
void layout::run(ir::module &mod) {
|
void layouts::run(ir::module &mod) {
|
||||||
// make graph
|
// make graph
|
||||||
graph_.clear();
|
graph_.clear();
|
||||||
ir::for_each_instruction(mod, [this](ir::instruction* i) {
|
ir::for_each_instruction(mod, [this](ir::instruction* i) {
|
||||||
@@ -422,35 +403,35 @@ void layout::run(ir::module &mod) {
|
|||||||
// shape
|
// shape
|
||||||
auto shapes = arg->get_type()->get_tile_shapes();
|
auto shapes = arg->get_type()->get_tile_shapes();
|
||||||
unsigned shape_ax = shapes[axis];
|
unsigned shape_ax = shapes[axis];
|
||||||
layout_scanline_t *layout = get(arg)->to_scanline();
|
scanline_layout *layout = get(arg)->to_scanline();
|
||||||
unsigned per_thread = layout->nts[axis];
|
unsigned per_thread = layout->nts(axis);
|
||||||
unsigned depth = shape_ax / per_thread;
|
unsigned depth = shape_ax / per_thread;
|
||||||
shapes[axis] = depth;
|
shapes[axis] = depth;
|
||||||
// create layout
|
// create layout
|
||||||
layouts_[id] = new layout_shared_t(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_);
|
layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_);
|
||||||
tmp_[red] = id;
|
tmp_[red] = id;
|
||||||
}
|
}
|
||||||
if(auto *recoalasce = dynamic_cast<ir::recoalesce_inst*>(i)){
|
if(auto *recoalasce = dynamic_cast<ir::recoalesce_inst*>(i)){
|
||||||
ir::value *val = recoalasce->get_operand(0);
|
ir::value *val = recoalasce->get_operand(0);
|
||||||
layout_t* in_layout = get(val);
|
mma884_layout* in_layout = get(val)->to_mma884();
|
||||||
layout_t* out_layout = get(i);
|
scanline_layout* out_layout = get(i)->to_scanline();
|
||||||
if(in_layout->type != HMMA_884)
|
if(!in_layout || !out_layout)
|
||||||
return;
|
return;
|
||||||
id++;
|
id++;
|
||||||
ir::type::tile_shapes_t in_shape = val->get_type()->get_tile_shapes();
|
ir::type::tile_shapes_t in_shape = val->get_type()->get_tile_shapes();
|
||||||
ir::type::tile_shapes_t shape(in_shape.size());
|
ir::type::tile_shapes_t shape(in_shape.size());
|
||||||
size_t ld = out_layout->order[0];
|
size_t ld = out_layout->get_order(0);
|
||||||
shape[ld] = in_shape[ld];
|
shape[ld] = in_shape[ld];
|
||||||
for(size_t k = 0; k < in_shape.size(); k++)
|
for(size_t k = 0; k < in_shape.size(); k++)
|
||||||
if(k != ld)
|
if(k != ld)
|
||||||
shape[k] = 4*in_layout->to_hmma884()->fpw[k]*in_layout->to_hmma884()->wpt[k];
|
shape[k] = 4*in_layout->to_mma884()->fpw(k)*in_layout->to_mma884()->wpt(k);
|
||||||
// create layout
|
// create layout
|
||||||
layouts_[id] = new layout_shared_t(out_layout, axes_->get(val), shape, {recoalasce}, val->get_type()->get_scalar_ty(), align_);
|
layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {recoalasce}, val->get_type()->get_scalar_ty(), align_);
|
||||||
tmp_[recoalasce] = id;
|
tmp_[recoalasce] = id;
|
||||||
}
|
}
|
||||||
if(auto *atom = dynamic_cast<ir::atomic_cas_inst*>(i)){
|
if(auto *atom = dynamic_cast<ir::atomic_cas_inst*>(i)){
|
||||||
id++;
|
id++;
|
||||||
layouts_[id] = new layout_shared_t(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), align_);
|
layouts_[id] = new shared_layout(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), align_);
|
||||||
tmp_[atom] = id;
|
tmp_[atom] = id;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
@@ -27,18 +27,18 @@ void liveness::run(ir::module &mod) {
|
|||||||
|
|
||||||
// create live intervals
|
// create live intervals
|
||||||
for(auto &x: layouts_->get_all()) {
|
for(auto &x: layouts_->get_all()) {
|
||||||
if(x.second->type != SHARED)
|
shared_layout* layout = x.second->to_shared();
|
||||||
|
if(!layout)
|
||||||
continue;
|
continue;
|
||||||
layout_shared_t* layout = x.second->to_shared();
|
|
||||||
// users
|
// users
|
||||||
std::set<ir::user*> users;
|
std::set<ir::user*> users;
|
||||||
for(ir::value *v: layout->values){
|
for(ir::value *v: layout->get_values()){
|
||||||
for(ir::user *u: v->get_users())
|
for(ir::user *u: v->get_users())
|
||||||
users.insert(u);
|
users.insert(u);
|
||||||
}
|
}
|
||||||
// compute intervals
|
// compute intervals
|
||||||
unsigned start = INT32_MAX;
|
unsigned start = INT32_MAX;
|
||||||
for(ir::value *v: layout->values)
|
for(ir::value *v: layout->get_values())
|
||||||
if(indices.find(v) != indices.end())
|
if(indices.find(v) != indices.end())
|
||||||
start = std::min(start, indices.at(v));
|
start = std::min(start, indices.at(v));
|
||||||
unsigned end = 0;
|
unsigned end = 0;
|
||||||
|
@@ -174,7 +174,7 @@ inline bool is_trans(ir::value *v) {
|
|||||||
|
|
||||||
|
|
||||||
generator::generator(analysis::axes *a_axes,
|
generator::generator(analysis::axes *a_axes,
|
||||||
analysis::layout *layouts,
|
analysis::layouts *layouts,
|
||||||
analysis::align *alignment,
|
analysis::align *alignment,
|
||||||
analysis::allocation *alloc,
|
analysis::allocation *alloc,
|
||||||
target *tgt,
|
target *tgt,
|
||||||
@@ -295,7 +295,7 @@ void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
|
|||||||
}
|
}
|
||||||
// find vector size
|
// find vector size
|
||||||
ir::value *ptr = x->get_pointer_operand();
|
ir::value *ptr = x->get_pointer_operand();
|
||||||
size_t ld = layouts_->get(ptr)->order[0];
|
size_t ld = layouts_->get(ptr)->get_order(0);
|
||||||
unsigned alignment = std::max<int>(alignment_->get(ptr, ld), 1);
|
unsigned alignment = std::max<int>(alignment_->get(ptr, ld), 1);
|
||||||
|
|
||||||
|
|
||||||
@@ -337,7 +337,7 @@ void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
|
|||||||
void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
|
void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
|
||||||
// find vector size
|
// find vector size
|
||||||
ir::value *ptr = x->get_pointer_operand();
|
ir::value *ptr = x->get_pointer_operand();
|
||||||
size_t ld = layouts_->get(ptr)->order[0];
|
size_t ld = layouts_->get(ptr)->get_order(0);
|
||||||
unsigned alignment = alignment_->get(ptr, ld);
|
unsigned alignment = alignment_->get(ptr, ld);
|
||||||
distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);
|
distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);
|
||||||
distributed_tile *masks = (distributed_tile*)tmap_.at(x->get_mask_operand());
|
distributed_tile *masks = (distributed_tile*)tmap_.at(x->get_mask_operand());
|
||||||
@@ -603,7 +603,7 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst*) {
|
|||||||
|
|
||||||
void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK) {
|
void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK) {
|
||||||
const auto& shapes = dot->get_type()->get_tile_shapes();
|
const auto& shapes = dot->get_type()->get_tile_shapes();
|
||||||
machine_layout_hmma_884_t* hmma = (machine_layout_hmma_884_t*)machine_layouts_.at(layouts_->get(dot));
|
machine_mma884_layout* hmma = (machine_mma884_layout*)machine_layouts_.at(layouts_->get(dot));
|
||||||
TA->set_vector_size(4*hmma->pack_size_0_);
|
TA->set_vector_size(4*hmma->pack_size_0_);
|
||||||
TB->set_vector_size(4*hmma->pack_size_1_);
|
TB->set_vector_size(4*hmma->pack_size_1_);
|
||||||
TA->set_return_mode(true);
|
TA->set_return_mode(true);
|
||||||
@@ -625,8 +625,8 @@ void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *
|
|||||||
|
|
||||||
Value* u_thread_id = tgt_->get_local_id(builder_->GetInsertBlock()->getModule(), *builder_, 0);
|
Value* u_thread_id = tgt_->get_local_id(builder_->GetInsertBlock()->getModule(), *builder_, 0);
|
||||||
|
|
||||||
auto ord_a = layouts_->get(dot->get_operand(0))->order;
|
auto ord_a = layouts_->get(dot->get_operand(0))->get_order();
|
||||||
auto ord_b = layouts_->get(dot->get_operand(1))->order;
|
auto ord_b = layouts_->get(dot->get_operand(1))->get_order();
|
||||||
|
|
||||||
bool is_a_trans = is_trans(dot->get_operand(0));
|
bool is_a_trans = is_trans(dot->get_operand(0));
|
||||||
bool is_b_trans = is_trans(dot->get_operand(1));
|
bool is_b_trans = is_trans(dot->get_operand(1));
|
||||||
@@ -655,14 +655,14 @@ void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *
|
|||||||
"{$8, $9}, "
|
"{$8, $9}, "
|
||||||
"{$10, $11}, "
|
"{$10, $11}, "
|
||||||
"{$0, $1, $2, $3, $4, $5, $6, $7};", "=f,=f,=f,=f,=f,=f,=f,=f,r,r,r,r,0,1,2,3,4,5,6,7", false);
|
"{$0, $1, $2, $3, $4, $5, $6, $7};", "=f,=f,=f,=f,=f,=f,=f,=f,r,r,r,r,0,1,2,3,4,5,6,7", false);
|
||||||
analysis::layout_hmma_884_t* layout = layouts_->get(dot)->to_hmma884();
|
analysis::mma884_layout* layout = layouts_->get(dot)->to_mma884();
|
||||||
|
|
||||||
unsigned fpw_0 = layout->fpw.at(0);
|
unsigned fpw_0 = layout->fpw(0);
|
||||||
unsigned fpw_1 = layout->fpw.at(1);
|
unsigned fpw_1 = layout->fpw(1);
|
||||||
unsigned wts_0 = fpw_0 * 8;
|
unsigned wts_0 = fpw_0 * 8;
|
||||||
unsigned wts_1 = fpw_1 * 8;
|
unsigned wts_1 = fpw_1 * 8;
|
||||||
unsigned wpt_0 = layout->wpt.at(0);
|
unsigned wpt_0 = layout->wpt(0);
|
||||||
unsigned wpt_1 = layout->wpt.at(1);
|
unsigned wpt_1 = layout->wpt(1);
|
||||||
unsigned stride_rep_i = wpt_0 * wts_0;
|
unsigned stride_rep_i = wpt_0 * wts_0;
|
||||||
unsigned stride_rep_j = wpt_1 * wts_1;
|
unsigned stride_rep_j = wpt_1 * wts_1;
|
||||||
unsigned num_rep_i = shapes[0] / stride_rep_i;
|
unsigned num_rep_i = shapes[0] / stride_rep_i;
|
||||||
@@ -792,7 +792,7 @@ void generator::visit_dot_inst(ir::dot_inst* dot) {
|
|||||||
if(NK != 1) {
|
if(NK != 1) {
|
||||||
shared_tile *TA = (shared_tile*)tmap_.at(A);
|
shared_tile *TA = (shared_tile*)tmap_.at(A);
|
||||||
shared_tile *TB = (shared_tile*)tmap_.at(B);
|
shared_tile *TB = (shared_tile*)tmap_.at(B);
|
||||||
if(layouts_->get(dot)->type == analysis::HMMA_884)
|
if(layouts_->get(dot)->to_mma884())
|
||||||
visit_hmma_dot(dot, TA, TB, TD, NK);
|
visit_hmma_dot(dot, TA, TB, TD, NK);
|
||||||
else
|
else
|
||||||
visit_scanline_dot(dot, TA, TB, TD, NK, c_ty, f_mul_add);
|
visit_scanline_dot(dot, TA, TB, TD, NK, c_ty, f_mul_add);
|
||||||
@@ -856,7 +856,7 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
|
|||||||
});
|
});
|
||||||
|
|
||||||
// reduce within blocks
|
// reduce within blocks
|
||||||
machine_layout_t *slayout = machine_layouts_.at(layouts_->get(layouts_->tmp(x)));
|
machine_data_layout *slayout = machine_layouts_.at(layouts_->get(layouts_->tmp(x)));
|
||||||
shared_tile *stile = (shared_tile*)slayout->create(x);
|
shared_tile *stile = (shared_tile*)slayout->create(x);
|
||||||
unsigned depth = stile->get_shapes()[axis];
|
unsigned depth = stile->get_shapes()[axis];
|
||||||
|
|
||||||
@@ -926,31 +926,31 @@ void generator::visit_recoalesce_inst(ir::recoalesce_inst* rc) {
|
|||||||
// pointer to temporary shared memory
|
// pointer to temporary shared memory
|
||||||
Type *ty = llvm_type(rc->get_type()->get_scalar_ty(), *ctx_);
|
Type *ty = llvm_type(rc->get_type()->get_scalar_ty(), *ctx_);
|
||||||
// layouts
|
// layouts
|
||||||
analysis::layout_hmma_884_t* in_layout = layouts_->get(op)->to_hmma884();
|
analysis::mma884_layout* in_layout = layouts_->get(op)->to_mma884();
|
||||||
analysis::layout_scanline_t* out_layout = layouts_->get(rc)->to_scanline();
|
analysis::scanline_layout* out_layout = layouts_->get(rc)->to_scanline();
|
||||||
// machine tiles
|
// machine tiles
|
||||||
distributed_tile *in_dt = (distributed_tile*)(tmap_.at(op));
|
distributed_tile *in_dt = (distributed_tile*)(tmap_.at(op));
|
||||||
distributed_tile *out_dt = (distributed_tile*)(tmap_.at(rc));
|
distributed_tile *out_dt = (distributed_tile*)(tmap_.at(rc));
|
||||||
// WMMA configuration
|
// WMMA configuration
|
||||||
long wmma_pt[3] = { 2, 4, 1};
|
long wmma_pt[3] = { 2, 4, 1};
|
||||||
long wmma[3] = { 8*in_layout->wpt[0]*in_layout->fpw[0],
|
long wmma[3] = { 8*in_layout->wpt(0)*in_layout->fpw(0),
|
||||||
8*in_layout->wpt[1]*in_layout->fpw[1],
|
8*in_layout->wpt(1)*in_layout->fpw(1),
|
||||||
1};
|
1};
|
||||||
// Work per thread for input layout
|
// Work per thread for input layout
|
||||||
long in_pt[3] = { shape[0] / wmma[0],
|
long in_pt[3] = { shape[0] / wmma[0],
|
||||||
shape[1] / wmma[1],
|
shape[1] / wmma[1],
|
||||||
1 };
|
1 };
|
||||||
// Work per thread for output layout
|
// Work per thread for output layout
|
||||||
long out_pt[3] = { shape[0] / out_layout->mts[0],
|
long out_pt[3] = { shape[0] / out_layout->mts(0),
|
||||||
shape[1] / out_layout->mts[1],
|
shape[1] / out_layout->mts(1),
|
||||||
1 };
|
1 };
|
||||||
if(rank > 2){
|
if(rank > 2){
|
||||||
wmma[2] = in_layout->wpt[2]*in_layout->fpw[2];
|
wmma[2] = in_layout->wpt(2)*in_layout->fpw(2);
|
||||||
in_pt[2] = shape[2] / wmma[2];
|
in_pt[2] = shape[2] / wmma[2];
|
||||||
out_pt[2] = shape[2] / out_layout->mts[2];
|
out_pt[2] = shape[2] / out_layout->mts(2);
|
||||||
}
|
}
|
||||||
// Orders
|
// Orders
|
||||||
auto ord = out_layout->order;
|
auto ord = out_layout->get_order();
|
||||||
if(ord.size() < 3)
|
if(ord.size() < 3)
|
||||||
ord.push_back(2);
|
ord.push_back(2);
|
||||||
// pointer lanes
|
// pointer lanes
|
||||||
@@ -1028,13 +1028,13 @@ void generator::visit_recoalesce_inst(ir::recoalesce_inst* rc) {
|
|||||||
void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
|
void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
|
||||||
unsigned vector_size = 1;
|
unsigned vector_size = 1;
|
||||||
ir::value *arg = cts->get_operand(0);
|
ir::value *arg = cts->get_operand(0);
|
||||||
analysis::layout_shared_t* out_layout = layouts_->get(cts)->to_shared();
|
analysis::shared_layout* out_layout = layouts_->get(cts)->to_shared();
|
||||||
analysis::layout_scanline_t* in_layout = layouts_->get(arg)->to_scanline();
|
analysis::scanline_layout* in_layout = layouts_->get(arg)->to_scanline();
|
||||||
auto out_order = out_layout->order;
|
auto out_order = out_layout->get_order();
|
||||||
auto in_order = in_layout->order;
|
auto in_order = in_layout->get_order();
|
||||||
// tiles
|
// tiles
|
||||||
if(out_order == in_order)
|
if(out_order == in_order)
|
||||||
vector_size = in_layout->nts.at(in_order[0]);
|
vector_size = in_layout->nts(in_order[0]);
|
||||||
|
|
||||||
std::map<unsigned, Value*> packets;
|
std::map<unsigned, Value*> packets;
|
||||||
for_each(arg, [&](indices_t idx){
|
for_each(arg, [&](indices_t idx){
|
||||||
@@ -1180,17 +1180,17 @@ void generator::visit_function(ir::function* fn) {
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
void generator::visit_layout_hmma_884(analysis::layout_hmma_884_t* layout) {
|
void generator::visit_layout_hmma_884(analysis::mma884_layout* layout) {
|
||||||
machine_layouts_[layout] = new machine_layout_hmma_884_t(mod_, &*builder_, tgt_, llvm_type(layout->ty->get_scalar_ty(), *ctx_), a_axes_, axes_, layout);
|
machine_layouts_[layout] = new machine_mma884_layout(mod_, &*builder_, tgt_, a_axes_, axes_, layout);
|
||||||
}
|
}
|
||||||
|
|
||||||
void generator::visit_layout_scanline(analysis::layout_scanline_t* layout) {
|
void generator::visit_layout_scanline(analysis::scanline_layout* layout) {
|
||||||
machine_layouts_[layout] = new machine_layout_scanline_t(mod_, &*builder_, tgt_, llvm_type(layout->ty->get_scalar_ty(), *ctx_), a_axes_, axes_, layout);
|
machine_layouts_[layout] = new machine_scanline_layout(mod_, &*builder_, tgt_, a_axes_, axes_, layout);
|
||||||
}
|
}
|
||||||
|
|
||||||
void generator::visit_layout_shared(analysis::layout_shared_t* layout) {
|
void generator::visit_layout_shared(analysis::shared_layout* layout) {
|
||||||
|
|
||||||
machine_layouts_[layout] = new machine_layout_shared_t(mod_, &*builder_, tgt_, alloc_, sh_mem_ptr_, layout, vmap_, tmap_);
|
machine_layouts_[layout] = new machine_shared_layout(mod_, &*builder_, tgt_, alloc_, sh_mem_ptr_, layout, vmap_, tmap_);
|
||||||
}
|
}
|
||||||
|
|
||||||
void generator::visit_basic_block(ir::basic_block * block) {
|
void generator::visit_basic_block(ir::basic_block * block) {
|
||||||
@@ -1230,9 +1230,9 @@ void generator::set_value(ir::value *x, const indices_t& idx, Value* v) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
void generator::finalize_shared_layout(analysis::layout_shared_t *shared) {
|
void generator::finalize_shared_layout(analysis::shared_layout *shared) {
|
||||||
if(shared->double_buffer) {
|
if(shared->get_double_buffer()) {
|
||||||
auto info = *shared->double_buffer;
|
auto info = *shared->get_double_buffer();
|
||||||
ir::phi_node *phi = info.phi;
|
ir::phi_node *phi = info.phi;
|
||||||
PHINode *ptr = (PHINode*)((shared_tile*)tmap_.at(phi))->get_pointer();
|
PHINode *ptr = (PHINode*)((shared_tile*)tmap_.at(phi))->get_pointer();
|
||||||
PHINode *offset = (PHINode*)((shared_tile*)tmap_.at(phi))->get_offset();
|
PHINode *offset = (PHINode*)((shared_tile*)tmap_.at(phi))->get_offset();
|
||||||
@@ -1247,8 +1247,8 @@ void generator::finalize_shared_layout(analysis::layout_shared_t *shared) {
|
|||||||
offset->addIncoming(next_offset, llvm_inc_block);
|
offset->addIncoming(next_offset, llvm_inc_block);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
unsigned num_bytes = shared->ty->get_primitive_size_in_bits() / 8;
|
unsigned num_bytes = shared->get_type()->get_primitive_size_in_bits() / 8;
|
||||||
offset->addIncoming(builder_->getInt32(shared->size / (2*num_bytes)), llvm_inc_block);
|
offset->addIncoming(builder_->getInt32(shared->get_size() / (2*num_bytes)), llvm_inc_block);
|
||||||
}
|
}
|
||||||
ptr->addIncoming(inc_shared->get_pointer(), llvm_inc_block);
|
ptr->addIncoming(inc_shared->get_pointer(), llvm_inc_block);
|
||||||
}
|
}
|
||||||
@@ -1258,7 +1258,7 @@ void generator::finalize_shared_layout(analysis::layout_shared_t *shared) {
|
|||||||
void generator::finalize_function(ir::function *fn) {
|
void generator::finalize_function(ir::function *fn) {
|
||||||
// finalize double-buffering
|
// finalize double-buffering
|
||||||
for(const auto& x: layouts_->get_all())
|
for(const auto& x: layouts_->get_all())
|
||||||
if(auto *shared = dynamic_cast<analysis::layout_shared_t*>(x.second))
|
if(auto *shared = dynamic_cast<analysis::shared_layout*>(x.second))
|
||||||
finalize_shared_layout(shared);
|
finalize_shared_layout(shared);
|
||||||
// finalize phi
|
// finalize phi
|
||||||
for(ir::basic_block *block: fn->blocks())
|
for(ir::basic_block *block: fn->blocks())
|
||||||
|
@@ -71,18 +71,18 @@ inline int32_t ceil(int32_t num, int32_t div){
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
machine_layout_shared_t::machine_layout_shared_t(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc,
|
machine_shared_layout::machine_shared_layout(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc,
|
||||||
Value *&sh_mem_ptr, analysis::layout_shared_t *layout,
|
Value *&sh_mem_ptr, analysis::shared_layout *layout,
|
||||||
std::map<ir::value *, Value *>& vmap,
|
std::map<ir::value *, Value *>& vmap,
|
||||||
std::map<ir::value *, tile *>& tmap)
|
std::map<ir::value *, tile *>& tmap)
|
||||||
: mod_(mod), builder_(builder), tgt_(tgt), alloc_(alloc), sh_mem_ptr_(sh_mem_ptr), layout_(layout), vmap_(vmap), tmap_(tmap) {
|
: mod_(mod), builder_(builder), tgt_(tgt), alloc_(alloc), sh_mem_ptr_(sh_mem_ptr), layout_(layout), vmap_(vmap), tmap_(tmap) {
|
||||||
|
|
||||||
Type* ty = llvm_type(layout_->ty, builder_->getContext());
|
Type* ty = llvm_type(layout_->get_type(), builder_->getContext());
|
||||||
PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr_->getType()->getPointerAddressSpace());
|
PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr_->getType()->getPointerAddressSpace());
|
||||||
// double-buffered
|
// double-buffered
|
||||||
if(layout_->double_buffer) {
|
if(layout_->get_double_buffer()) {
|
||||||
BasicBlock *current = builder_->GetInsertBlock();
|
BasicBlock *current = builder_->GetInsertBlock();
|
||||||
auto info = *layout_->double_buffer;
|
auto info = *layout_->get_double_buffer();
|
||||||
ir::phi_node *phi = info.phi;
|
ir::phi_node *phi = info.phi;
|
||||||
BasicBlock *parent = (BasicBlock*)vmap_.at((ir::value*)(phi->get_parent()));
|
BasicBlock *parent = (BasicBlock*)vmap_.at((ir::value*)(phi->get_parent()));
|
||||||
if(parent->empty())
|
if(parent->empty())
|
||||||
@@ -105,31 +105,31 @@ machine_layout_shared_t::machine_layout_shared_t(Module *mod, Builder *builder,
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
tile* machine_layout_shared_t::create(ir::value *v) {
|
tile* machine_shared_layout::create(ir::value *v) {
|
||||||
auto order = layout_->order;
|
Type* ty = llvm_type(layout_->get_type(), builder_->getContext());
|
||||||
auto shapes = layout_->shapes;
|
auto double_buffer = layout_->get_double_buffer();
|
||||||
Type* ty = llvm_type(layout_->ty, builder_->getContext());
|
// offset
|
||||||
// double-buffered
|
Value *offset = nullptr;
|
||||||
if(layout_->double_buffer) {
|
if(double_buffer && v == double_buffer->phi)
|
||||||
if(v == layout_->double_buffer->phi)
|
offset = offset_;
|
||||||
return new shared_tile(ty, shapes, order, ptr_, *builder_, offset_);
|
// base pointer
|
||||||
if(v == layout_->double_buffer->latch)
|
Value *ptr = ptr_;
|
||||||
return new shared_tile(ty, shapes, order, next_ptr_, *builder_);
|
if(double_buffer && v == double_buffer->latch)
|
||||||
return new shared_tile(ty, shapes, order, pre_ptr_, *builder_);
|
ptr = next_ptr_;
|
||||||
}
|
else if(double_buffer && v == double_buffer->first)
|
||||||
else {
|
ptr = pre_ptr_;
|
||||||
return new shared_tile(ty, shapes, order, ptr_, *builder_);
|
// create tile
|
||||||
}
|
return new shared_tile(ty, layout_->get_shape(), layout_->get_order(), ptr, *builder_, offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
machine_layout_distributed_t::machine_layout_distributed_t(Module *mod, Builder *builder, target *tgt, Type *ty,
|
machine_distributed_layout::machine_distributed_layout(Module *mod, Builder *builder, target *tgt,
|
||||||
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
|
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
|
||||||
analysis::layout_t *layout)
|
analysis::data_layout *layout)
|
||||||
: mod_(mod), builder_(builder), tgt_(tgt), ty_(ty), a_axes_(a_axes), axes_(axes), layout_(layout) {
|
: mod_(mod), builder_(builder), tgt_(tgt), a_axes_(a_axes), axes_(axes), layout_(layout) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
tile *machine_layout_distributed_t::create(ir::value *v) {
|
tile *machine_distributed_layout::create(ir::value *v) {
|
||||||
Type *ty = llvm_type(v->get_type()->get_scalar_ty(), builder_->getContext());
|
Type *ty = llvm_type(v->get_type()->get_scalar_ty(), builder_->getContext());
|
||||||
const auto &shapes = v->get_type()->get_tile_shapes();
|
const auto &shapes = v->get_type()->get_tile_shapes();
|
||||||
size_t rank = shapes.size();
|
size_t rank = shapes.size();
|
||||||
@@ -151,12 +151,10 @@ tile *machine_layout_distributed_t::create(ir::value *v) {
|
|||||||
auto cmp = [&](int x, int y) {
|
auto cmp = [&](int x, int y) {
|
||||||
unsigned axx = a_axes_->get(v, x);
|
unsigned axx = a_axes_->get(v, x);
|
||||||
unsigned axy = a_axes_->get(v, y);
|
unsigned axy = a_axes_->get(v, y);
|
||||||
auto itx = std::find(layout_->axes.begin(), layout_->axes.end(), axx);
|
size_t posx = layout_->find_axis(axx);
|
||||||
auto ity = std::find(layout_->axes.begin(), layout_->axes.end(), axy);
|
size_t posy = layout_->find_axis(axy);
|
||||||
size_t posx = std::distance(layout_->axes.begin(), itx);
|
|
||||||
size_t posy = std::distance(layout_->axes.begin(), ity);
|
|
||||||
if(posx < rank && posy < rank)
|
if(posx < rank && posy < rank)
|
||||||
return layout_->order[posx] < layout_->order[posy];
|
return layout_->get_order(posx) < layout_->get_order(posy);
|
||||||
return false;
|
return false;
|
||||||
};
|
};
|
||||||
std::sort(order.begin(), order.end(), cmp);
|
std::sort(order.begin(), order.end(), cmp);
|
||||||
@@ -164,22 +162,21 @@ tile *machine_layout_distributed_t::create(ir::value *v) {
|
|||||||
return new distributed_tile(ty, shapes, order, axes, *builder_);
|
return new distributed_tile(ty, shapes, order, axes, *builder_);
|
||||||
}
|
}
|
||||||
|
|
||||||
machine_layout_hmma_884_t::machine_layout_hmma_884_t(Module *mod, Builder *builder,
|
machine_mma884_layout::machine_mma884_layout(Module *mod, Builder *builder,
|
||||||
target *tgt, Type *ty, analysis::axes *a_axes,
|
target *tgt, analysis::axes *a_axes,
|
||||||
std::map<unsigned, distributed_axis>& axes,
|
std::map<unsigned, distributed_axis>& axes,
|
||||||
analysis::layout_hmma_884_t* layout)
|
analysis::mma884_layout* layout)
|
||||||
: machine_layout_distributed_t(mod, builder, tgt, ty, a_axes, axes, layout) {
|
: machine_distributed_layout(mod, builder, tgt, a_axes, axes, layout) {
|
||||||
|
|
||||||
Value *warp_size = builder_->getInt32(32);
|
Value *warp_size = builder_->getInt32(32);
|
||||||
Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
|
Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
|
||||||
Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size);
|
Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size);
|
||||||
Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size);
|
Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size);
|
||||||
|
|
||||||
const auto& shapes = layout->shapes;
|
const auto& shape = layout->get_shape();
|
||||||
if(shapes.size() > 3)
|
if(shape.size() > 3)
|
||||||
throw std::runtime_error("unsupported");
|
throw std::runtime_error("unsupported");
|
||||||
|
bool is_batched = shape.size() >= 3;
|
||||||
bool is_batched = shapes.size() >= 3;
|
|
||||||
|
|
||||||
Value *_1 = builder_->getInt32(1);
|
Value *_1 = builder_->getInt32(1);
|
||||||
Value *_2 = builder_->getInt32(2);
|
Value *_2 = builder_->getInt32(2);
|
||||||
@@ -188,13 +185,13 @@ machine_layout_hmma_884_t::machine_layout_hmma_884_t(Module *mod, Builder *build
|
|||||||
Value *_16 = builder_->getInt32(16);
|
Value *_16 = builder_->getInt32(16);
|
||||||
|
|
||||||
// fragments per warp
|
// fragments per warp
|
||||||
unsigned fpw_0 = layout->fpw.at(0);
|
unsigned fpw_0 = layout->fpw(0);
|
||||||
unsigned fpw_1 = layout->fpw.at(1);
|
unsigned fpw_1 = layout->fpw(1);
|
||||||
unsigned fpw_2 = is_batched ? layout->fpw.at(2) : 1;
|
unsigned fpw_2 = is_batched ? layout->fpw(2) : 1;
|
||||||
// warps per tile
|
// warps per tile
|
||||||
unsigned wpt_0 = layout->wpt.at(0);
|
unsigned wpt_0 = layout->wpt(0);
|
||||||
unsigned wpt_1 = layout->wpt.at(1);
|
unsigned wpt_1 = layout->wpt(1);
|
||||||
unsigned wpt_2 = is_batched ? layout->wpt.at(2) : 1;
|
unsigned wpt_2 = is_batched ? layout->wpt(2) : 1;
|
||||||
// hmma warp tile size
|
// hmma warp tile size
|
||||||
unsigned hmma_wts_0 = fpw_0 * 8;
|
unsigned hmma_wts_0 = fpw_0 * 8;
|
||||||
unsigned hmma_wts_1 = fpw_1 * 8;
|
unsigned hmma_wts_1 = fpw_1 * 8;
|
||||||
@@ -204,9 +201,9 @@ machine_layout_hmma_884_t::machine_layout_hmma_884_t(Module *mod, Builder *build
|
|||||||
unsigned hmma_bts_1 = hmma_wts_1 * wpt_1;
|
unsigned hmma_bts_1 = hmma_wts_1 * wpt_1;
|
||||||
unsigned hmma_bts_2 = is_batched ? hmma_wts_2 * wpt_2 : 1;
|
unsigned hmma_bts_2 = is_batched ? hmma_wts_2 * wpt_2 : 1;
|
||||||
// number of repetition
|
// number of repetition
|
||||||
unsigned num_rep_0 = shapes[0] / hmma_bts_0;
|
unsigned num_rep_0 = shape[0] / hmma_bts_0;
|
||||||
unsigned num_rep_1 = shapes[1] / hmma_bts_1;
|
unsigned num_rep_1 = shape[1] / hmma_bts_1;
|
||||||
unsigned num_rep_2 = is_batched ? shapes[2] / hmma_bts_2 : 1;
|
unsigned num_rep_2 = is_batched ? shape[2] / hmma_bts_2 : 1;
|
||||||
// size of each pack (interleaving)
|
// size of each pack (interleaving)
|
||||||
pack_size_0_ = std::min<unsigned>(num_rep_0, 1);
|
pack_size_0_ = std::min<unsigned>(num_rep_0, 1);
|
||||||
pack_size_1_ = std::min<unsigned>(num_rep_1, 1);
|
pack_size_1_ = std::min<unsigned>(num_rep_1, 1);
|
||||||
@@ -275,44 +272,52 @@ machine_layout_hmma_884_t::machine_layout_hmma_884_t(Module *mod, Builder *build
|
|||||||
|
|
||||||
|
|
||||||
/* axes */
|
/* axes */
|
||||||
axes_[layout->axes[0]] = distributed_axis{1, idx_i, warp_id_0};
|
axes_[layout->get_axis(0)] = distributed_axis{1, idx_i, warp_id_0};
|
||||||
axes_[layout->axes[1]] = distributed_axis{1, idx_j, warp_id_1};
|
axes_[layout->get_axis(1)] = distributed_axis{1, idx_j, warp_id_1};
|
||||||
if(is_batched)
|
if(is_batched)
|
||||||
axes_[layout->axes[2]] = distributed_axis{1, idx_z, warp_id_2};
|
axes_[layout->get_axis(2)] = distributed_axis{1, idx_z, warp_id_2};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
machine_layout_scanline_t::machine_layout_scanline_t(Module *mod, Builder *builder,
|
machine_scanline_layout::machine_scanline_layout(Module *mod, Builder *builder,
|
||||||
target *tgt, Type *ty,
|
target *tgt,
|
||||||
analysis::axes *a_axes, std::map<unsigned, distributed_axis> &axes,
|
analysis::axes *a_axes, std::map<unsigned, distributed_axis> &axes,
|
||||||
analysis::layout_scanline_t* layout)
|
analysis::scanline_layout* layout)
|
||||||
: machine_layout_distributed_t(mod, builder, tgt, ty, a_axes, axes, layout) {
|
: machine_distributed_layout(mod, builder, tgt, a_axes, axes, layout) {
|
||||||
|
|
||||||
Value *warp_size = builder_->getInt32(32);
|
Value *warp_size = builder_->getInt32(32);
|
||||||
Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
|
Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
|
||||||
Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size);
|
Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size);
|
||||||
Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size);
|
Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size);
|
||||||
|
|
||||||
auto order = layout->order;
|
auto order = layout->get_order();
|
||||||
const auto& shapes = layout->shapes;
|
const auto& shape = layout->get_shape();
|
||||||
size_t dim = shapes.size();
|
|
||||||
std::vector<int> nts = layout->nts;
|
|
||||||
std::vector<int> mts = layout->mts;
|
|
||||||
Value* full_thread_id = builder_->CreateAdd(builder_->CreateMul(u_warp_id, builder_->getInt32(32)), u_thread_id);
|
Value* full_thread_id = builder_->CreateAdd(builder_->CreateMul(u_warp_id, builder_->getInt32(32)), u_thread_id);
|
||||||
std::vector<Value*> thread_id = delinearize(full_thread_id, order, mts, *builder_);
|
// Delinearize
|
||||||
|
size_t dim = shape.size();
|
||||||
|
std::vector<Value*> thread_id(dim);
|
||||||
|
for(unsigned k = 0; k < dim - 1; k++){
|
||||||
|
Constant *dim_k = builder_->getInt32(layout->mts(order[k]));
|
||||||
|
Value *rem = builder_->CreateURem(full_thread_id, dim_k);
|
||||||
|
full_thread_id = builder_->CreateUDiv(full_thread_id, dim_k);
|
||||||
|
thread_id[order[k]] = rem;
|
||||||
|
}
|
||||||
|
thread_id[order[dim - 1]] = full_thread_id;
|
||||||
// Create axes
|
// Create axes
|
||||||
for(unsigned k = 0; k < dim; k++) {
|
for(unsigned k = 0; k < dim; k++) {
|
||||||
|
int nts = layout->nts(k);
|
||||||
|
int mts = layout->mts(k);
|
||||||
std::string str_k = std::to_string(k);
|
std::string str_k = std::to_string(k);
|
||||||
Value *contiguous_k = builder_->getInt32(nts[k]);
|
Value *contiguous_k = builder_->getInt32(nts);
|
||||||
Value *scaled_thread_id = builder_->CreateMul(thread_id[k], contiguous_k);
|
Value *scaled_thread_id = builder_->CreateMul(thread_id[k], contiguous_k);
|
||||||
unsigned per_block = nts[k] * mts[k];
|
unsigned per_block = nts * mts;
|
||||||
unsigned per_thread = nts[k] * shapes[k] / per_block;
|
unsigned per_thread = nts * shape[k] / per_block;
|
||||||
std::vector<Value*> idx_list(per_thread);
|
std::vector<Value*> idx_list(per_thread);
|
||||||
for(unsigned n = 0 ; n < per_thread; n++){
|
for(unsigned n = 0 ; n < per_thread; n++){
|
||||||
unsigned offset = n / nts[k] * per_block + n % nts[k];
|
unsigned offset = n / nts * per_block + n % nts;
|
||||||
idx_list[n] = builder_->CreateAdd(scaled_thread_id, builder_->getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
|
idx_list[n] = builder_->CreateAdd(scaled_thread_id, builder_->getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
|
||||||
}
|
}
|
||||||
axes_[layout->axes[k]] = distributed_axis{nts[k], idx_list, thread_id[k]};
|
axes_[layout->get_axis(k)] = distributed_axis{nts, idx_list, thread_id[k]};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -12,7 +12,7 @@ namespace triton {
|
|||||||
namespace codegen{
|
namespace codegen{
|
||||||
namespace transform{
|
namespace transform{
|
||||||
|
|
||||||
coalesce::coalesce(analysis::align* align, analysis::layout *layouts)
|
coalesce::coalesce(analysis::align* align, analysis::layouts *layouts)
|
||||||
: align_(align), layout_(layouts) { }
|
: align_(align), layout_(layouts) { }
|
||||||
|
|
||||||
// Find all values that are used as pointer operands in LD/ST
|
// Find all values that are used as pointer operands in LD/ST
|
||||||
@@ -64,8 +64,9 @@ ir::value* coalesce::rematerialize(ir::value *x, ir::builder &builder,
|
|||||||
void coalesce::run(ir::module &mod) {
|
void coalesce::run(ir::module &mod) {
|
||||||
size_t num_groups = layout_->num_layouts();
|
size_t num_groups = layout_->num_layouts();
|
||||||
|
|
||||||
|
|
||||||
for(size_t id = 0; id < num_groups; id++) {
|
for(size_t id = 0; id < num_groups; id++) {
|
||||||
if(layout_->get(id)->type != analysis::HMMA_884)
|
if(!layout_->get(id)->to_mma884())
|
||||||
continue;
|
continue;
|
||||||
// extract memory stores
|
// extract memory stores
|
||||||
const auto& values = layout_->values_of(id);
|
const auto& values = layout_->values_of(id);
|
||||||
@@ -97,7 +98,6 @@ void coalesce::run(ir::module &mod) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// find values to rematerialize
|
// find values to rematerialize
|
||||||
std::vector<ir::io_inst*> remat;
|
std::vector<ir::io_inst*> remat;
|
||||||
for(size_t id = 0; id < num_groups; id++) {
|
for(size_t id = 0; id < num_groups; id++) {
|
||||||
@@ -109,7 +109,7 @@ void coalesce::run(ir::module &mod) {
|
|||||||
// extract leading axes
|
// extract leading axes
|
||||||
std::map<int, std::vector<ir::io_inst*>> axes;
|
std::map<int, std::vector<ir::io_inst*>> axes;
|
||||||
for(ir::io_inst *i: io){
|
for(ir::io_inst *i: io){
|
||||||
if(i->get_pointer_operand()->get_type()->get_tile_ranks1() == layout_->get(id)->axes.size())
|
if(i->get_pointer_operand()->get_type()->get_tile_ranks1() == layout_->get(id)->get_rank())
|
||||||
extract_ld(i, axes);
|
extract_ld(i, axes);
|
||||||
}
|
}
|
||||||
// update list of values to rematerialize
|
// update list of values to rematerialize
|
||||||
|
@@ -35,10 +35,11 @@ void membar::add_reference(ir::value *v, interval_vec_t &res){
|
|||||||
return;
|
return;
|
||||||
if(!i->get_type()->is_tile_ty())
|
if(!i->get_type()->is_tile_ty())
|
||||||
return;
|
return;
|
||||||
if(alloc_->has_offset(layouts_->get(v))){
|
analysis::shared_layout* layout = layouts_->get(v)->to_shared();
|
||||||
unsigned offset = alloc_->offset(layouts_->get(v));
|
assert(layout);
|
||||||
unsigned size = layouts_->get(v)->to_shared()->size;
|
if(alloc_->has_offset(layout)){
|
||||||
res.push_back(interval_t(offset, offset + size));
|
unsigned offset = alloc_->offset(layout);
|
||||||
|
res.push_back(interval_t(offset, offset + layout->get_size()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -119,13 +120,11 @@ void membar::run(ir::module &mod) {
|
|||||||
// without needing synchronization
|
// without needing synchronization
|
||||||
std::set<ir::value*> safe_war;
|
std::set<ir::value*> safe_war;
|
||||||
for(const auto& x: layouts_->get_all()){
|
for(const auto& x: layouts_->get_all()){
|
||||||
if(x.second->type != analysis::SHARED)
|
analysis::shared_layout* layout = x.second->to_shared();
|
||||||
|
if(!layout || !layout->get_double_buffer())
|
||||||
continue;
|
continue;
|
||||||
analysis::layout_shared_t* layout = x.second->to_shared();
|
for(ir::value *v: layout->get_values())
|
||||||
if(!layout->double_buffer)
|
if(v != layout->get_double_buffer()->phi)
|
||||||
continue;
|
|
||||||
for(ir::value *v: layout->values)
|
|
||||||
if(v != layout->double_buffer->phi)
|
|
||||||
safe_war.insert(v);
|
safe_war.insert(v);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -220,7 +220,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
|||||||
codegen::analysis::align align;
|
codegen::analysis::align align;
|
||||||
codegen::analysis::axes axes;
|
codegen::analysis::axes axes;
|
||||||
codegen::transform::disassociate disassociate;
|
codegen::transform::disassociate disassociate;
|
||||||
codegen::analysis::layout layouts(&axes, &align, opt.num_warps);
|
codegen::analysis::layouts layouts(&axes, &align, opt.num_warps);
|
||||||
codegen::analysis::liveness liveness(&layouts);
|
codegen::analysis::liveness liveness(&layouts);
|
||||||
codegen::analysis::allocation allocation(&liveness);
|
codegen::analysis::allocation allocation(&liveness);
|
||||||
codegen::transform::membar barriers(&liveness, &layouts, &allocation);
|
codegen::transform::membar barriers(&liveness, &layouts, &allocation);
|
||||||
@@ -239,7 +239,6 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
|||||||
align.run(module);
|
align.run(module);
|
||||||
cts.run(module);
|
cts.run(module);
|
||||||
axes.run(module);
|
axes.run(module);
|
||||||
// ir::print(module, std::cout);
|
|
||||||
layouts.run(module);
|
layouts.run(module);
|
||||||
coalesce.run(module);
|
coalesce.run(module);
|
||||||
dce.run(module);
|
dce.run(module);
|
||||||
@@ -250,15 +249,14 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
|||||||
dce.run(module);
|
dce.run(module);
|
||||||
align.run(module);
|
align.run(module);
|
||||||
axes.run(module);
|
axes.run(module);
|
||||||
// ir::print(module, std::cout);
|
|
||||||
layouts.run(module);
|
layouts.run(module);
|
||||||
liveness.run(module);
|
liveness.run(module);
|
||||||
allocation.run(module);
|
allocation.run(module);
|
||||||
if(allocation.allocated_size() > context->device()->max_shared_memory())
|
if(allocation.allocated_size() > context->device()->max_shared_memory())
|
||||||
return std::unique_ptr<driver::module>();
|
return std::unique_ptr<driver::module>();
|
||||||
barriers.run(module);
|
barriers.run(module);
|
||||||
// ir::print(module, std::cout);
|
|
||||||
isel.visit(module, *llvm);
|
isel.visit(module, *llvm);
|
||||||
|
|
||||||
// return binary
|
// return binary
|
||||||
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
|
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
|
||||||
// done
|
// done
|
||||||
|
@@ -79,7 +79,7 @@ for N, T, H, S, E in NTHSE:
|
|||||||
|
|
||||||
# 1D Dense convolution
|
# 1D Dense convolution
|
||||||
NCHKR = [
|
NCHKR = [
|
||||||
# (1, 1152, 12602, 512, 3)
|
(1, 1152, 12602, 512, 3)
|
||||||
]
|
]
|
||||||
for N, C, H, K, R in NCHKR:
|
for N, C, H, K, R in NCHKR:
|
||||||
torch_fn = lambda a, b: torch.nn.functional.conv1d(a, b.permute(2, 0, 1))
|
torch_fn = lambda a, b: torch.nn.functional.conv1d(a, b.permute(2, 0, 1))
|
||||||
@@ -92,10 +92,10 @@ for N, C, H, K, R in NCHKR:
|
|||||||
|
|
||||||
# 2D Dense convolution
|
# 2D Dense convolution
|
||||||
NCHWKRS = [
|
NCHWKRS = [
|
||||||
#(8, 64, 128, 128, 768, 3, 3),
|
(8, 64, 128, 128, 768, 3, 3),
|
||||||
#(8, 128, 64, 64, 256, 3, 3),
|
(8, 128, 64, 64, 256, 3, 3),
|
||||||
#(8, 256, 32, 32, 512, 3, 3),
|
(8, 256, 32, 32, 512, 3, 3),
|
||||||
#(8, 512, 32, 32, 1024, 3, 3)
|
(8, 512, 32, 32, 1024, 3, 3)
|
||||||
]
|
]
|
||||||
for N, C, H, W, K, R, S in NCHWKRS:
|
for N, C, H, W, K, R, S in NCHWKRS:
|
||||||
torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2))
|
torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2))
|
||||||
@@ -108,10 +108,10 @@ for N, C, H, W, K, R, S in NCHWKRS:
|
|||||||
|
|
||||||
# 3D Dense Convolution
|
# 3D Dense Convolution
|
||||||
NCDHWKTRS = [
|
NCDHWKTRS = [
|
||||||
#(8, 32, 27, 100, 100, 64, 3, 3, 3),
|
(8, 32, 27, 100, 100, 64, 3, 3, 3),
|
||||||
#(8, 64, 23, 48, 48, 256, 3, 3, 3),
|
(8, 64, 23, 48, 48, 256, 3, 3, 3),
|
||||||
#(8, 256, 19, 22, 22, 640, 3, 3, 3),
|
(8, 256, 19, 22, 22, 640, 3, 3, 3),
|
||||||
#(8, 640, 15, 36, 36, 384, 3, 3, 3)
|
(8, 640, 15, 36, 36, 384, 3, 3, 3)
|
||||||
]
|
]
|
||||||
for N, C, D, H, W, K, T, R, S in NCDHWKTRS:
|
for N, C, D, H, W, K, T, R, S in NCDHWKTRS:
|
||||||
torch_fn = lambda a, b: torch.nn.functional.conv3d(a, b.permute(4, 0, 1, 2, 3))
|
torch_fn = lambda a, b: torch.nn.functional.conv3d(a, b.permute(4, 0, 1, 2, 3))
|
||||||
@@ -168,7 +168,7 @@ for N, C, H, W, K, R, S in NCHWKRS:
|
|||||||
# Benchmark
|
# Benchmark
|
||||||
torch.set_num_threads(1)
|
torch.set_num_threads(1)
|
||||||
for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
|
for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
|
||||||
dtype = torch.cuda.FloatTensor
|
dtype = torch.cuda.HalfTensor
|
||||||
# initialize input tensors
|
# initialize input tensors
|
||||||
a = torch.rand(*a_shape).type(dtype).cuda()
|
a = torch.rand(*a_shape).type(dtype).cuda()
|
||||||
b = torch.rand(*b_shape).type(dtype).cuda()
|
b = torch.rand(*b_shape).type(dtype).cuda()
|
||||||
|
Reference in New Issue
Block a user