more work on padding
This commit is contained in:
@@ -49,7 +49,7 @@ struct buffer_t {
|
|||||||
class liveness {
|
class liveness {
|
||||||
private:
|
private:
|
||||||
typedef std::map<ir::value*, slot_index> indices_map_t;
|
typedef std::map<ir::value*, slot_index> indices_map_t;
|
||||||
typedef std::map<buffer_t, segment> intervals_map_t;
|
typedef std::map<buffer_t*, segment> intervals_map_t;
|
||||||
typedef std::map<ir::value*, bool> has_storage_map_t;
|
typedef std::map<ir::value*, bool> has_storage_map_t;
|
||||||
typedef ir::value* node_t;
|
typedef ir::value* node_t;
|
||||||
typedef std::map <node_t, std::set<node_t>> graph_t;
|
typedef std::map <node_t, std::set<node_t>> graph_t;
|
||||||
@@ -63,24 +63,26 @@ public:
|
|||||||
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void connected_components(node_t x, std::set<node_t> &nodes, graph_t &graph, unsigned group_id);
|
void connected_components(node_t x, std::set<node_t> &nodes, graph_t &graph, buffer_t *buffer);
|
||||||
void extract_double_bufferable(ir::instruction *i);
|
void extract_double_bufferable(ir::instruction *i);
|
||||||
void extract_buffers(ir::instruction *i);
|
void extract_buffers(ir::instruction *i);
|
||||||
void get_parents(ir::instruction *i, std::vector<ir::value *>& res);
|
void get_parents(ir::instruction *i, std::vector<ir::value *>& res);
|
||||||
void make_graph(ir::instruction *i);
|
void make_graph(ir::instruction *i);
|
||||||
|
bool do_pad(ir::value *x);
|
||||||
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
liveness(tiles *t): tiles_(t){ }
|
liveness(tiles *t): tiles_(t){ }
|
||||||
|
// padding
|
||||||
|
unsigned get_pad(ir::value *v) const { return pad_.at(v); }
|
||||||
// buffer size
|
// buffer size
|
||||||
unsigned is_ld_padded(ir::value *x);
|
|
||||||
unsigned num_bytes(ir::value *x);
|
unsigned num_bytes(ir::value *x);
|
||||||
// accessors
|
// accessors
|
||||||
const intervals_map_t& intervals() const { return intervals_; }
|
const intervals_map_t& intervals() const { return intervals_; }
|
||||||
segment get_interval(buffer_t v) const { return intervals_.at(v); }
|
segment get_interval(buffer_t* v) const { return intervals_.at(v); }
|
||||||
// buffers
|
// buffers
|
||||||
buffer_t get_buffer(ir::value *v) const { return groups_.at(v); }
|
buffer_t* get_buffer(ir::value *v) const { return groups_.at(v); }
|
||||||
std::vector<ir::value*> get_values(buffer_t x) const { return values_.at(x); }
|
std::vector<ir::value*> get_values(buffer_t* x) const { return values_.at(x); }
|
||||||
// double-buffering
|
// double-buffering
|
||||||
bool has_double(ir::value *x) const { return double_.find(x) != double_.end(); }
|
bool has_double(ir::value *x) const { return double_.find(x) != double_.end(); }
|
||||||
double_buffer_info_t get_double(ir::value *x) const { return double_.at(x); }
|
double_buffer_info_t get_double(ir::value *x) const { return double_.at(x); }
|
||||||
@@ -95,12 +97,14 @@ private:
|
|||||||
indices_map_t indices;
|
indices_map_t indices;
|
||||||
intervals_map_t intervals_;
|
intervals_map_t intervals_;
|
||||||
std::map<ir::value*, double_buffer_info_t> double_;
|
std::map<ir::value*, double_buffer_info_t> double_;
|
||||||
|
std::map<ir::value*, size_t> pad_;
|
||||||
std::map<ir::value*, std::vector<ir::value*>> parents_;
|
std::map<ir::value*, std::vector<ir::value*>> parents_;
|
||||||
// graph
|
// graph
|
||||||
std::set<node_t> nodes_;
|
std::set<node_t> nodes_;
|
||||||
graph_t graph_;
|
graph_t graph_;
|
||||||
std::map<ir::value*, buffer_t> groups_;
|
std::vector<buffer_t*> buffers_;
|
||||||
std::map<buffer_t, std::vector<ir::value*>> values_;
|
std::map<ir::value*, buffer_t*> groups_;
|
||||||
|
std::map<buffer_t*, std::vector<ir::value*>> values_;
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@@ -89,7 +89,7 @@ private:
|
|||||||
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
shared_tile(Type* ty, const shapes_t &shapes, const std::vector<int> &order, Value* ptr, Builder &builder, Value* offset = nullptr);
|
shared_tile(Type* ty, const shapes_t &shapes, const std::vector<int> &order, Value* ptr, Builder &builder, Value* offset = nullptr, const std::vector<int>& perm = {});
|
||||||
void set_vector_size(unsigned vector_size);
|
void set_vector_size(unsigned vector_size);
|
||||||
void set_return_mode(bool return_vector);
|
void set_return_mode(bool return_vector);
|
||||||
void set_value(indices_t, Value *);
|
void set_value(indices_t, Value *);
|
||||||
@@ -97,8 +97,9 @@ public:
|
|||||||
Value* get_value(indices_t idx);
|
Value* get_value(indices_t idx);
|
||||||
Value* get_pointer() { return ptr_; }
|
Value* get_pointer() { return ptr_; }
|
||||||
Value* get_offset() { return offset_; }
|
Value* get_offset() { return offset_; }
|
||||||
|
const std::vector<int>& get_perm() { return perm_; }
|
||||||
const std::vector<int>& get_order() { return order_; }
|
const std::vector<int>& get_order() { return order_; }
|
||||||
static Value* shared_offset(Builder& builder, const shapes_t& shapes, const std::vector<int>& order, indices_t idx);
|
static Value* shared_offset(Builder& builder, const shapes_t& shapes, const std::vector<int>& perm, const std::vector<int>& order, indices_t idx);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Value *ptr_;
|
Value *ptr_;
|
||||||
@@ -108,6 +109,7 @@ private:
|
|||||||
std::map<indices_t, Value*> ptr_cache_;
|
std::map<indices_t, Value*> ptr_cache_;
|
||||||
unsigned vector_size_;
|
unsigned vector_size_;
|
||||||
std::vector<int> order_;
|
std::vector<int> order_;
|
||||||
|
std::vector<int> perm_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Distribtued tile
|
// Distribtued tile
|
||||||
|
@@ -135,7 +135,7 @@ public:
|
|||||||
value *create_atomic_exch(value *ptr, value *val, const std::string &name = "");
|
value *create_atomic_exch(value *ptr, value *val, const std::string &name = "");
|
||||||
value *create_atomic_add(value *ptr, value *val, const std::string &name = "");
|
value *create_atomic_add(value *ptr, value *val, const std::string &name = "");
|
||||||
value *create_dot(value *A, value *B, value *C, const std::string &name = "");
|
value *create_dot(value *A, value *B, value *C, const std::string &name = "");
|
||||||
value *create_trans(value *A, const std::vector<constant_int *> &perm = {}, const std::string &name = "");
|
value *create_trans(value *A, const std::vector<int> &perm = {}, const std::string &name = "");
|
||||||
value *create_sqrt(value *A, const std::string &name = "");
|
value *create_sqrt(value *A, const std::string &name = "");
|
||||||
value *create_reduce(value *A, unsigned axis, const std::string &name = "");
|
value *create_reduce(value *A, unsigned axis, const std::string &name = "");
|
||||||
value *create_select(value *pred, value *if_value, value *else_value, const std::string &name = "");
|
value *create_select(value *pred, value *if_value, value *else_value, const std::string &name = "");
|
||||||
|
@@ -591,7 +591,7 @@ public:
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, const std::string &name, instruction *next);
|
dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, const std::string &name, instruction *next);
|
||||||
std::string repr_impl() const { return std::string("dot.") + ((AT_==NoTrans)?"n":"t") + ((BT_==NoTrans)?"n":"t"); }
|
std::string repr_impl() const { return "dot"; }
|
||||||
|
|
||||||
public:
|
public:
|
||||||
static instruction *create(value *A, value *B, value *C, bool AT, bool BT, const std::string &name = "", instruction *next = nullptr);
|
static instruction *create(value *A, value *B, value *C, bool AT, bool BT, const std::string &name = "", instruction *next = nullptr);
|
||||||
@@ -599,13 +599,7 @@ public:
|
|||||||
static instruction* create_nt(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
|
static instruction* create_nt(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
|
||||||
static instruction* create_tn(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
|
static instruction* create_tn(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
|
||||||
static instruction* create_tt(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
|
static instruction* create_tt(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
|
||||||
bool is_a_trans() { return AT_ == Trans; }
|
|
||||||
bool is_b_trans() { return BT_ == Trans; }
|
|
||||||
_TRITON_DEFINE_CLONE(dot_inst)
|
_TRITON_DEFINE_CLONE(dot_inst)
|
||||||
|
|
||||||
private:
|
|
||||||
TransT AT_;
|
|
||||||
TransT BT_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
//class outer_inst: public builtin_inst {
|
//class outer_inst: public builtin_inst {
|
||||||
@@ -617,20 +611,20 @@ private:
|
|||||||
|
|
||||||
class trans_inst: public builtin_inst {
|
class trans_inst: public builtin_inst {
|
||||||
public:
|
public:
|
||||||
ir::type* get_res_ty(ir::type* in, std::vector<constant_int *> perm);
|
ir::type* get_res_ty(ir::type* in, std::vector<int> perm);
|
||||||
std::vector<constant_int*> init_perm(ir::type* ty, const std::vector<constant_int*>& perm);
|
std::vector<int> init_perm(ir::type* ty, const std::vector<int>& perm);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
trans_inst(value *arg, const std::vector<constant_int*>& perm, const std::string& name, instruction* next);
|
trans_inst(value *arg, const std::vector<int>& perm, const std::string& name, instruction* next);
|
||||||
std::string repr_impl() const { return "trans"; }
|
std::string repr_impl() const { return "trans"; }
|
||||||
|
|
||||||
public:
|
public:
|
||||||
static instruction* create(value *arg, const std::vector<constant_int*>& perm = {}, const std::string &name = "", instruction *next = nullptr);
|
static instruction* create(value *arg, const std::vector<int> &perm = {}, const std::string &name = "", instruction *next = nullptr);
|
||||||
const std::vector<constant_int*> get_perm() const;
|
const std::vector<int> get_perm() const;
|
||||||
_TRITON_DEFINE_CLONE(trans_inst)
|
_TRITON_DEFINE_CLONE(trans_inst)
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<constant_int*> perm_;
|
std::vector<int> perm_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class sqrt_inst: public builtin_inst {
|
class sqrt_inst: public builtin_inst {
|
||||||
|
@@ -487,9 +487,6 @@ void align::populate(ir::value *v) {
|
|||||||
populate_is_constant(v);
|
populate_is_constant(v);
|
||||||
populate_starting_multiple(v);
|
populate_starting_multiple(v);
|
||||||
populate_max_contiguous(v);
|
populate_max_contiguous(v);
|
||||||
// std::cout << v->get_name() << std::endl;
|
|
||||||
// if(max_contiguous_[v].size() == 2)
|
|
||||||
// std::cout << max_contiguous_[v][0] << " " << max_contiguous_[v][1] << std::endl;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void align::run(ir::module &mod) {
|
void align::run(ir::module &mod) {
|
||||||
|
@@ -21,22 +21,22 @@ void allocation::run(ir::module &mod) {
|
|||||||
using std::min;
|
using std::min;
|
||||||
typedef std::multimap<unsigned, segment> triples_map_type;
|
typedef std::multimap<unsigned, segment> triples_map_type;
|
||||||
|
|
||||||
std::vector<buffer_t> I;
|
std::vector<buffer_t*> I;
|
||||||
for(auto x: liveness_->intervals())
|
for(auto x: liveness_->intervals())
|
||||||
I.push_back(x.first);
|
I.push_back(x.first);
|
||||||
std::vector<buffer_t> J = I;
|
std::vector<buffer_t*> J = I;
|
||||||
|
|
||||||
triples_map_type H;
|
triples_map_type H;
|
||||||
H.insert({0, segment{0, INT_MAX}});
|
H.insert({0, segment{0, INT_MAX}});
|
||||||
|
|
||||||
std::vector<buffer_t> V;
|
std::vector<buffer_t*> V;
|
||||||
std::map<buffer_t, unsigned> starts;
|
std::map<buffer_t*, unsigned> starts;
|
||||||
while(!J.empty()){
|
while(!J.empty()){
|
||||||
auto h_it = H.begin();
|
auto h_it = H.begin();
|
||||||
unsigned w = h_it->first;
|
unsigned w = h_it->first;
|
||||||
segment xh = h_it->second;
|
segment xh = h_it->second;
|
||||||
H.erase(h_it);
|
H.erase(h_it);
|
||||||
auto j_it = std::find_if(J.begin(), J.end(), [&](buffer_t JJ){
|
auto j_it = std::find_if(J.begin(), J.end(), [&](buffer_t* JJ){
|
||||||
segment xj = liveness_->get_interval(JJ);
|
segment xj = liveness_->get_interval(JJ);
|
||||||
bool res = xj.intersect(xh);
|
bool res = xj.intersect(xh);
|
||||||
for(auto val: H)
|
for(auto val: H)
|
||||||
@@ -44,7 +44,7 @@ void allocation::run(ir::module &mod) {
|
|||||||
return res;
|
return res;
|
||||||
});
|
});
|
||||||
if(j_it != J.end()){
|
if(j_it != J.end()){
|
||||||
unsigned size = j_it->size;
|
unsigned size = (*j_it)->size;
|
||||||
segment xj = liveness_->get_interval(*j_it);
|
segment xj = liveness_->get_interval(*j_it);
|
||||||
starts[*j_it] = w;
|
starts[*j_it] = w;
|
||||||
H.insert({w + size, segment{max(xh.start, xj.start), min(xh.end, xj.end)}});
|
H.insert({w + size, segment{max(xh.start, xj.start), min(xh.end, xj.end)}});
|
||||||
@@ -58,14 +58,14 @@ void allocation::run(ir::module &mod) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Build interference graph
|
// Build interference graph
|
||||||
std::map<buffer_t, std::set<buffer_t>> interferences;
|
std::map<buffer_t*, std::set<buffer_t*>> interferences;
|
||||||
for(buffer_t x: V)
|
for(buffer_t* x: V)
|
||||||
for(buffer_t y: V){
|
for(buffer_t* y: V){
|
||||||
if(x.id == y.id)
|
if(x->id == y->id)
|
||||||
continue;
|
continue;
|
||||||
unsigned X0 = starts[x], Y0 = starts[y];
|
unsigned X0 = starts[x], Y0 = starts[y];
|
||||||
unsigned NX = x.size;
|
unsigned NX = x->size;
|
||||||
unsigned NY = y.size;
|
unsigned NY = y->size;
|
||||||
segment XS = {X0, X0 + NX};
|
segment XS = {X0, X0 + NX};
|
||||||
segment YS = {Y0, Y0 + NY};
|
segment YS = {Y0, Y0 + NY};
|
||||||
if(liveness_->get_interval(x).intersect(liveness_->get_interval(y))
|
if(liveness_->get_interval(x).intersect(liveness_->get_interval(y))
|
||||||
@@ -74,17 +74,17 @@ void allocation::run(ir::module &mod) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Initialize colors
|
// Initialize colors
|
||||||
std::map<buffer_t, int> colors;
|
std::map<buffer_t*, int> colors;
|
||||||
for(buffer_t X: V)
|
for(buffer_t* X: V)
|
||||||
colors[X] = (X.id==V[0].id)?0:-1;
|
colors[X] = (X->id==V[0]->id)?0:-1;
|
||||||
|
|
||||||
|
|
||||||
// First-fit graph coloring
|
// First-fit graph coloring
|
||||||
std::vector<bool> available(V.size());
|
std::vector<bool> available(V.size());
|
||||||
for(buffer_t x: V){
|
for(buffer_t* x: V){
|
||||||
// Non-neighboring colors are available
|
// Non-neighboring colors are available
|
||||||
std::fill(available.begin(), available.end(), true);
|
std::fill(available.begin(), available.end(), true);
|
||||||
for(buffer_t Y: interferences[x]){
|
for(buffer_t* Y: interferences[x]){
|
||||||
int color = colors[Y];
|
int color = colors[Y];
|
||||||
if(color >= 0)
|
if(color >= 0)
|
||||||
available[color] = false;
|
available[color] = false;
|
||||||
@@ -95,25 +95,24 @@ void allocation::run(ir::module &mod) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Finalize allocation
|
// Finalize allocation
|
||||||
for(buffer_t x: V){
|
for(buffer_t* x: V){
|
||||||
unsigned Adj = 0;
|
unsigned Adj = 0;
|
||||||
for(buffer_t y: interferences[x])
|
for(buffer_t* y: interferences[x])
|
||||||
Adj = std::max<unsigned>(Adj, starts[y] + y.size);
|
Adj = std::max<unsigned>(Adj, starts[y] + y->size);
|
||||||
// create offsets
|
// create offsets
|
||||||
for(ir::value *v: liveness_->get_values(x)){
|
for(ir::value *v: liveness_->get_values(x)){
|
||||||
offsets_[v] = starts[x] + colors[x] * Adj;
|
offsets_[v] = starts[x] + colors[x] * Adj;
|
||||||
if(liveness_->has_double(v)){
|
if(liveness_->has_double(v)){
|
||||||
auto info = liveness_->get_double(v);
|
auto info = liveness_->get_double(v);
|
||||||
offsets_[info.latch] = offsets_[v] + x.size / 2;
|
offsets_[info.latch] = offsets_[v] + x->size / 2;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save maximum size of induced memory space
|
// Save maximum size of induced memory space
|
||||||
allocated_size_ = 0;
|
allocated_size_ = 0;
|
||||||
for(auto &x: offsets_){
|
for(buffer_t* x: V)
|
||||||
allocated_size_ = std::max<size_t>(allocated_size_, x.second + liveness_->get_buffer(x.first).size);
|
allocated_size_ = std::max<size_t>(allocated_size_, starts[x] + x->size);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@@ -74,7 +74,7 @@ void axes::update_graph_trans(ir::instruction *i) {
|
|||||||
auto perm = trans->get_perm();
|
auto perm = trans->get_perm();
|
||||||
// add edge between axis perm[d] and axis d
|
// add edge between axis perm[d] and axis d
|
||||||
for(unsigned d = 0; d < perm.size(); d++)
|
for(unsigned d = 0; d < perm.size(); d++)
|
||||||
add_constraint({i, perm[d]->get_value()}, {op, d});
|
add_constraint({i, perm[d]}, {op, d});
|
||||||
}
|
}
|
||||||
|
|
||||||
void axes::update_graph_broadcast(ir::instruction *i) {
|
void axes::update_graph_broadcast(ir::instruction *i) {
|
||||||
|
@@ -58,6 +58,18 @@ void liveness::make_graph(ir::instruction *i) {
|
|||||||
graph_[i].insert(latch);
|
graph_[i].insert(latch);
|
||||||
graph_[latch].insert(i);
|
graph_[latch].insert(i);
|
||||||
}
|
}
|
||||||
|
if(i->get_id() == ir::INST_PHI){
|
||||||
|
ir::phi_node* phi = (ir::phi_node*)i;
|
||||||
|
for(ir::value* op: phi->ops()){
|
||||||
|
auto* iop = dynamic_cast<ir::instruction*>(op);
|
||||||
|
if(!iop || storage_info.at(iop->get_id()).first != SHARED)
|
||||||
|
continue;
|
||||||
|
nodes_.insert(phi);
|
||||||
|
nodes_.insert(op);
|
||||||
|
graph_[phi].insert(op);
|
||||||
|
graph_[op].insert(phi);
|
||||||
|
}
|
||||||
|
}
|
||||||
if(i->get_id() == ir::INST_TRANS){
|
if(i->get_id() == ir::INST_TRANS){
|
||||||
nodes_.insert(i);
|
nodes_.insert(i);
|
||||||
nodes_.insert(i->get_operand(0));
|
nodes_.insert(i->get_operand(0));
|
||||||
@@ -67,39 +79,63 @@ void liveness::make_graph(ir::instruction *i) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// connected components
|
// connected components
|
||||||
void liveness::connected_components(node_t x, std::set<node_t> &nodes, graph_t &graph, unsigned group_id) {
|
void liveness::connected_components(node_t x, std::set<node_t> &nodes, graph_t &graph, buffer_t* buffer) {
|
||||||
buffer_t buffer{group_id, num_bytes(x)};
|
|
||||||
groups_[x] = buffer;
|
groups_[x] = buffer;
|
||||||
values_[buffer].push_back(x);
|
values_[buffer].push_back(x);
|
||||||
if(nodes.find(x) != nodes.end()){
|
if(nodes.find(x) != nodes.end()){
|
||||||
nodes.erase(x);
|
nodes.erase(x);
|
||||||
for(const node_t &y: graph[x])
|
for(const node_t &y: graph[x])
|
||||||
connected_components(y, nodes, graph, group_id);
|
connected_components(y, nodes, graph, buffer);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned liveness::is_ld_padded(ir::value *x) {
|
bool liveness::do_pad(ir::value *x) {
|
||||||
if(auto *trans = dynamic_cast<ir::trans_inst*>(x)){
|
// alignment for matrix product
|
||||||
if(trans->get_perm()[0]->get_value() != 0)
|
if(auto* dot = dynamic_cast<ir::dot_inst*>(x)) {
|
||||||
return 4;
|
|
||||||
}
|
|
||||||
auto order = tiles_->order(x);
|
auto order = tiles_->order(x);
|
||||||
bool is_col_major = order[0] == 0;
|
// a
|
||||||
|
ir::value *a = dot->get_operand(0);\
|
||||||
|
size_t previous_a = pad_[a];
|
||||||
|
bool a_trans = dynamic_cast<ir::trans_inst*>(a);
|
||||||
|
bool a_row = order[0] == 1;
|
||||||
if(tiles_->hmma(x) == HMMA_A_ROW)
|
if(tiles_->hmma(x) == HMMA_A_ROW)
|
||||||
return is_col_major ? 16 : 16;
|
pad_[a] = 16;
|
||||||
if(tiles_->hmma(x) == HMMA_A_COL)
|
else if(tiles_->hmma(x) == HMMA_A_COL)
|
||||||
return is_col_major ? 8 : 8;
|
pad_[a] = 8;
|
||||||
|
else if(a_trans ^ a_row)
|
||||||
|
pad_[a] = 4;
|
||||||
|
else
|
||||||
|
pad_[a] = 0;
|
||||||
|
// b
|
||||||
|
ir::value *b = dot->get_operand(1);
|
||||||
|
size_t previous_b = pad_[b];
|
||||||
|
bool b_trans = dynamic_cast<ir::trans_inst*>(a);
|
||||||
|
bool b_col = order[0] == 0;
|
||||||
if(tiles_->hmma(x) == HMMA_B_COL)
|
if(tiles_->hmma(x) == HMMA_B_COL)
|
||||||
return is_col_major ? 16 : 16;
|
pad_[b] = 16;
|
||||||
if(tiles_->hmma(x) == HMMA_B_ROW)
|
if(tiles_->hmma(x) == HMMA_B_ROW)
|
||||||
return is_col_major ? 8 : 8;
|
pad_[b] = 8;
|
||||||
if(auto* phi = dynamic_cast<ir::phi_node*>(x)) {
|
if(b_trans ^ b_col)
|
||||||
unsigned result = 0;
|
pad_[b] = 4;
|
||||||
for(unsigned i = 0; i < phi->get_num_incoming(); i++)
|
else
|
||||||
result = std::max(result, is_ld_padded(phi->get_incoming_value(i)));
|
pad_[b] = 0;
|
||||||
return result;
|
return previous_a != pad_[a] || previous_b != pad_[b];
|
||||||
}
|
}
|
||||||
return 0;
|
// padding for phi-nodes
|
||||||
|
if(auto* phi = dynamic_cast<ir::phi_node*>(x)) {
|
||||||
|
bool has_changed = false;
|
||||||
|
for(unsigned i = 0; i < phi->get_num_incoming(); i++){
|
||||||
|
ir::value* op = phi->get_operand(i);
|
||||||
|
size_t previous = pad_[op];
|
||||||
|
pad_[op] = std::max(pad_[op], pad_[phi]);
|
||||||
|
has_changed |= previous != pad_[op];
|
||||||
|
}
|
||||||
|
return has_changed;
|
||||||
|
}
|
||||||
|
// default -- no pading
|
||||||
|
size_t previous = pad_[x];
|
||||||
|
pad_[x] = std::max<int>(previous, 0);
|
||||||
|
return pad_[x] != previous;
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned liveness::num_bytes(ir::value *x) {
|
unsigned liveness::num_bytes(ir::value *x) {
|
||||||
@@ -120,7 +156,8 @@ unsigned liveness::num_bytes(ir::value *x) {
|
|||||||
return num_elements * num_bytes * depth;
|
return num_elements * num_bytes * depth;
|
||||||
}
|
}
|
||||||
unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8;
|
unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8;
|
||||||
unsigned pad = is_ld_padded(x);
|
unsigned pad = pad_.at(x);
|
||||||
|
std::cout << x->get_name() << " " << pad << std::endl;
|
||||||
if(pad > 0){
|
if(pad > 0){
|
||||||
unsigned ld = x->get_type()->get_tile_shapes()[tiles_->order(x)[0]];
|
unsigned ld = x->get_type()->get_tile_shapes()[tiles_->order(x)[0]];
|
||||||
num_bytes += pad * num_bytes / ld;
|
num_bytes += pad * num_bytes / ld;
|
||||||
@@ -134,6 +171,7 @@ unsigned liveness::num_bytes(ir::value *x) {
|
|||||||
void liveness::run(ir::module &mod) {
|
void liveness::run(ir::module &mod) {
|
||||||
double_.clear();
|
double_.clear();
|
||||||
indices.clear();
|
indices.clear();
|
||||||
|
pad_.clear();
|
||||||
intervals_.clear();
|
intervals_.clear();
|
||||||
parents_.clear();
|
parents_.clear();
|
||||||
|
|
||||||
@@ -142,6 +180,15 @@ void liveness::run(ir::module &mod) {
|
|||||||
this->extract_double_bufferable(i);
|
this->extract_double_bufferable(i);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Padding information
|
||||||
|
bool has_changed;
|
||||||
|
do{
|
||||||
|
has_changed = false;
|
||||||
|
ir::for_each_value(mod, [this, &has_changed](ir::value* v){
|
||||||
|
has_changed |= this->do_pad(v);
|
||||||
|
});
|
||||||
|
}while(has_changed);
|
||||||
|
|
||||||
// Create buffer dependency graph
|
// Create buffer dependency graph
|
||||||
ir::for_each_instruction(mod, [this](ir::instruction* i) {
|
ir::for_each_instruction(mod, [this](ir::instruction* i) {
|
||||||
this->make_graph(i);
|
this->make_graph(i);
|
||||||
@@ -150,7 +197,10 @@ void liveness::run(ir::module &mod) {
|
|||||||
// connected components
|
// connected components
|
||||||
unsigned group_id = 0;
|
unsigned group_id = 0;
|
||||||
while(!nodes_.empty()){
|
while(!nodes_.empty()){
|
||||||
connected_components(*nodes_.begin(), nodes_, graph_, group_id++);
|
buffer_t* buffer = new buffer_t{group_id++};
|
||||||
|
connected_components(*nodes_.begin(), nodes_, graph_, buffer);
|
||||||
|
for(ir::value *v: values_.at(buffer))
|
||||||
|
buffer->size = std::max<int>(buffer->size, num_bytes(v));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Assigns index to each instruction
|
// Assigns index to each instruction
|
||||||
|
@@ -40,7 +40,7 @@ bool is_hmma_a_col(ir::value* v) {
|
|||||||
for(ir::user *u: v->get_users())
|
for(ir::user *u: v->get_users())
|
||||||
if(is_hmma_c(u)){
|
if(is_hmma_c(u)){
|
||||||
ir::dot_inst* dot = (ir::dot_inst*)u;
|
ir::dot_inst* dot = (ir::dot_inst*)u;
|
||||||
if((v == dot->get_operand(0)) && !dot->is_a_trans())
|
if((v == dot->get_operand(0)))
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -49,7 +49,7 @@ bool is_hmma_a_row(ir::value* v) {
|
|||||||
for(ir::user *u: v->get_users())
|
for(ir::user *u: v->get_users())
|
||||||
if(is_hmma_c(u)){
|
if(is_hmma_c(u)){
|
||||||
ir::dot_inst* dot = (ir::dot_inst*)u;
|
ir::dot_inst* dot = (ir::dot_inst*)u;
|
||||||
if((v == dot->get_operand(0)) && dot->is_a_trans())
|
if((v == dot->get_operand(0)))
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -58,7 +58,7 @@ bool is_hmma_b_col(ir::value* v) {
|
|||||||
for(ir::user *u: v->get_users())
|
for(ir::user *u: v->get_users())
|
||||||
if(is_hmma_c(u)){
|
if(is_hmma_c(u)){
|
||||||
ir::dot_inst* dot = (ir::dot_inst*)u;
|
ir::dot_inst* dot = (ir::dot_inst*)u;
|
||||||
if((v == dot->get_operand(1)) && !dot->is_b_trans())
|
if((v == dot->get_operand(1)))
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -67,7 +67,7 @@ bool is_hmma_b_row(ir::value* v) {
|
|||||||
for(ir::user *u: v->get_users())
|
for(ir::user *u: v->get_users())
|
||||||
if(is_hmma_c(u)){
|
if(is_hmma_c(u)){
|
||||||
ir::dot_inst* dot = (ir::dot_inst*)u;
|
ir::dot_inst* dot = (ir::dot_inst*)u;
|
||||||
if((v == dot->get_operand(1)) && dot->is_b_trans())
|
if((v == dot->get_operand(1)))
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -170,6 +170,7 @@ void tiles::init_scanline_tile(ir::value *i) {
|
|||||||
unsigned effective_num_threads = 1;
|
unsigned effective_num_threads = 1;
|
||||||
for(size_t d = 0; d < shapes.size(); d++)
|
for(size_t d = 0; d < shapes.size(); d++)
|
||||||
effective_num_threads *= mts_[axes_->get_id(i, d)];
|
effective_num_threads *= mts_[axes_->get_id(i, d)];
|
||||||
|
// std::cout << num_threads << " " << effective_num_threads << std::endl;
|
||||||
if(num_threads != effective_num_threads)
|
if(num_threads != effective_num_threads)
|
||||||
throw std::runtime_error("cannot create a kernel with this amount of warps");
|
throw std::runtime_error("cannot create a kernel with this amount of warps");
|
||||||
}
|
}
|
||||||
@@ -219,7 +220,7 @@ void tiles::run(ir::module &) {
|
|||||||
largest_[i] = *std::max_element(values.begin(), values.end(), cmp);
|
largest_[i] = *std::max_element(values.begin(), values.end(), cmp);
|
||||||
}
|
}
|
||||||
|
|
||||||
// find out the order of a group
|
// find out the layout ordering of a group
|
||||||
for(size_t i = 0; i < num_groups; i++){
|
for(size_t i = 0; i < num_groups; i++){
|
||||||
std::set<ir::io_inst*> io;
|
std::set<ir::io_inst*> io;
|
||||||
for(ir::value* v: layout_->values(i))
|
for(ir::value* v: layout_->values(i))
|
||||||
@@ -239,11 +240,6 @@ void tiles::run(ir::module &) {
|
|||||||
order_[i] = order;
|
order_[i] = order;
|
||||||
}
|
}
|
||||||
for(size_t i = 0; i < num_groups; i++){
|
for(size_t i = 0; i < num_groups; i++){
|
||||||
bool is_hmma_op = hmma_[i] == HMMA_A_COL || hmma_[i] == HMMA_A_ROW ||
|
|
||||||
hmma_[i] == HMMA_B_COL || hmma_[i] == HMMA_B_ROW;
|
|
||||||
if(!is_hmma_op)
|
|
||||||
continue;
|
|
||||||
// extract copies to shared memory
|
|
||||||
std::vector<ir::copy_to_shared_inst*> cts;
|
std::vector<ir::copy_to_shared_inst*> cts;
|
||||||
for(ir::value* v: layout_->values(i))
|
for(ir::value* v: layout_->values(i))
|
||||||
if(auto *x = dynamic_cast<ir::copy_to_shared_inst*>(v))
|
if(auto *x = dynamic_cast<ir::copy_to_shared_inst*>(v))
|
||||||
|
@@ -146,26 +146,30 @@ void shared_tile::extract_constant(const indices_t &arg_idx, indices_t &non_cst_
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& shapes, const std::vector<int>& order, indices_t idx) {
|
Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& shapes, const std::vector<int>& perm, const std::vector<int>& order, indices_t idx) {
|
||||||
|
// strides
|
||||||
|
std::vector<Value*> strides(order.size());
|
||||||
|
strides[order[0]] = builder.getInt32(1);
|
||||||
|
for(size_t i = 1; i < idx.size(); i++)
|
||||||
|
strides[order[i]] = builder.CreateMul(strides[order[i-1]], builder.getInt32(shapes[order[i-1]]));
|
||||||
|
// result
|
||||||
Value *result = builder.getInt32(0);
|
Value *result = builder.getInt32(0);
|
||||||
result = builder.CreateAdd(result, idx[order[0]]);
|
for(size_t i = 0; i < strides.size(); i++)
|
||||||
Value *ld = builder.getInt32(shapes[order[0]]);
|
result = builder.CreateAdd(result, builder.CreateMul(idx[perm[i]], strides[i]));
|
||||||
for(size_t i = 1; i < idx.size(); i++) {
|
|
||||||
result = builder.CreateAdd(result, builder.CreateMul(idx[order[i]], ld));
|
|
||||||
if(i < idx.size() - 1){
|
|
||||||
ld = builder.CreateMul(ld, builder.getInt32(shapes[order[i]]));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
shared_tile::shared_tile(Type *ty, const shapes_t &shapes, const std::vector<int>& order, Value *ptr, llvm::IRBuilder<> &builder, Value *offset):
|
shared_tile::shared_tile(Type *ty, const shapes_t &shapes, const std::vector<int>& order, Value *ptr, llvm::IRBuilder<> &builder, Value *offset, const std::vector<int>& perm):
|
||||||
tile(ty, shapes), order_(order), ptr_(ptr), builder_(builder), offset_(offset), vector_size_(1){
|
tile(ty, shapes), order_(order), ptr_(ptr), builder_(builder), offset_(offset), vector_size_(1), perm_(perm){
|
||||||
return_vector_ = false;
|
return_vector_ = false;
|
||||||
|
if(perm_.empty()){
|
||||||
|
perm_.resize(shapes.size());
|
||||||
|
std::iota(perm_.begin(), perm_.end(), 0);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void shared_tile::set_value(indices_t idx, Value *value) {
|
void shared_tile::set_value(indices_t idx, Value *value) {
|
||||||
Value *ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, order_, idx));
|
Value *ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, perm_, order_, idx));
|
||||||
unsigned addr_space = ptr->getType()->getPointerAddressSpace();
|
unsigned addr_space = ptr->getType()->getPointerAddressSpace();
|
||||||
ptr = builder_.CreateBitCast(ptr, value->getType()->getPointerTo(addr_space));
|
ptr = builder_.CreateBitCast(ptr, value->getType()->getPointerTo(addr_space));
|
||||||
builder_.CreateStore(value, ptr);
|
builder_.CreateStore(value, ptr);
|
||||||
@@ -196,7 +200,7 @@ Value* shared_tile::get_value(indices_t idx) {
|
|||||||
// if(isa<Instruction>(non_cst_idx.front())){
|
// if(isa<Instruction>(non_cst_idx.front())){
|
||||||
// builder_.SetInsertPoint((Instruction*)non_cst_idx.front());
|
// builder_.SetInsertPoint((Instruction*)non_cst_idx.front());
|
||||||
// }
|
// }
|
||||||
base_ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, order_, non_cst_idx));
|
base_ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, perm_, order_, non_cst_idx));
|
||||||
if(vector_size_ > 1){
|
if(vector_size_ > 1){
|
||||||
Type *vec_ty = VectorType::get(ty, vector_size);
|
Type *vec_ty = VectorType::get(ty, vector_size);
|
||||||
Type *vec_ptr_ty = PointerType::get(vec_ty, base_ptr->getType()->getPointerAddressSpace());
|
Type *vec_ptr_ty = PointerType::get(vec_ty, base_ptr->getType()->getPointerAddressSpace());
|
||||||
@@ -204,7 +208,7 @@ Value* shared_tile::get_value(indices_t idx) {
|
|||||||
}
|
}
|
||||||
// builder_.SetInsertPoint(store);
|
// builder_.SetInsertPoint(store);
|
||||||
}
|
}
|
||||||
Value *offset = shared_offset(builder_, shapes_, order_, cst_idx);
|
Value *offset = shared_offset(builder_, shapes_, perm_, order_, cst_idx);
|
||||||
Value *div = offset;
|
Value *div = offset;
|
||||||
if(vector_size_ > 1)
|
if(vector_size_ > 1)
|
||||||
div = builder_.CreateUDiv(offset, builder_.getInt32(vector_size_));
|
div = builder_.CreateUDiv(offset, builder_.getInt32(vector_size_));
|
||||||
@@ -725,7 +729,7 @@ void selection::create_shared_tile(ir::value *v, IRBuilder<> &builder, Value *sh
|
|||||||
return;
|
return;
|
||||||
auto order = tiles_->order(v);
|
auto order = tiles_->order(v);
|
||||||
auto shapes = v->get_type()->get_tile_shapes();
|
auto shapes = v->get_type()->get_tile_shapes();
|
||||||
unsigned pad = liveness_->is_ld_padded(v);
|
unsigned pad = liveness_->get_pad(v);
|
||||||
if(pad > 0)
|
if(pad > 0)
|
||||||
shapes[order[0]] += pad;
|
shapes[order[0]] += pad;
|
||||||
Type* ty = llvm_type(v->get_type()->get_scalar_ty(), builder.getContext());
|
Type* ty = llvm_type(v->get_type()->get_scalar_ty(), builder.getContext());
|
||||||
@@ -923,7 +927,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn,
|
|||||||
write_idx.insert(write_idx.begin() + axis, lane);
|
write_idx.insert(write_idx.begin() + axis, lane);
|
||||||
|
|
||||||
// shared memory write pointer
|
// shared memory write pointer
|
||||||
Value *write_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), op_tile->get_order(), write_idx);
|
Value *write_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), {0, 1}, op_tile->get_order(), write_idx);
|
||||||
Value *write_ptr = builder.CreateGEP(base_ptr, write_offset);
|
Value *write_ptr = builder.CreateGEP(base_ptr, write_offset);
|
||||||
|
|
||||||
// initialize shared memory
|
// initialize shared memory
|
||||||
@@ -936,7 +940,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn,
|
|||||||
indices_t current(write_idx.size(), builder.getInt32(0));
|
indices_t current(write_idx.size(), builder.getInt32(0));
|
||||||
current[axis] = builder.getInt32(i);
|
current[axis] = builder.getInt32(i);
|
||||||
// shared memory offset
|
// shared memory offset
|
||||||
Value *read_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), op_tile->get_order(), current);
|
Value *read_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), {0, 1}, op_tile->get_order(), current);
|
||||||
Value *is_active = builder.CreateICmpULT(lane, builder.getInt32(i));
|
Value *is_active = builder.CreateICmpULT(lane, builder.getInt32(i));
|
||||||
read_offset = builder.CreateSelect(is_active, read_offset, builder.getInt32(0));
|
read_offset = builder.CreateSelect(is_active, read_offset, builder.getInt32(0));
|
||||||
// shared memory read pointer
|
// shared memory read pointer
|
||||||
@@ -952,7 +956,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn,
|
|||||||
// result is on the first lane of shared memory
|
// result is on the first lane of shared memory
|
||||||
indices_t final = write_idx;
|
indices_t final = write_idx;
|
||||||
final[axis] = builder.getInt32(0);
|
final[axis] = builder.getInt32(0);
|
||||||
Value *read_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), op_tile->get_order(), final);
|
Value *read_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), {0, 1}, op_tile->get_order(), final);
|
||||||
Value *read_ptr = builder.CreateGEP(base_ptr, read_offset);
|
Value *read_ptr = builder.CreateGEP(base_ptr, read_offset);
|
||||||
tgt_->add_barrier(module, builder);
|
tgt_->add_barrier(module, builder);
|
||||||
result = builder.CreateLoad(read_ptr);
|
result = builder.CreateLoad(read_ptr);
|
||||||
@@ -1041,11 +1045,7 @@ void selection::lower_copy_to_shared(ir::copy_to_shared_inst *x, LLVMContext &ct
|
|||||||
|
|
||||||
void selection::lower_trans(ir::trans_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
|
void selection::lower_trans(ir::trans_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
|
||||||
shared_tile* in = (shared_tile*)tmap_.at(x->get_operand(0));
|
shared_tile* in = (shared_tile*)tmap_.at(x->get_operand(0));
|
||||||
auto in_order = in->get_order();
|
shared_tile* out = new shared_tile(in->get_ty(), in->get_shapes(), in->get_order(), in->get_pointer(), builder, in->get_offset(), x->get_perm());
|
||||||
std::vector<int> order;
|
|
||||||
for(auto p: x->get_perm())
|
|
||||||
order.push_back(in_order[p->get_value()]);
|
|
||||||
shared_tile* out = new shared_tile(in->get_ty(), in->get_shapes(), order, in->get_pointer(), builder, in->get_offset());
|
|
||||||
tmap_[x] = out;
|
tmap_[x] = out;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1082,8 +1082,8 @@ void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn
|
|||||||
auto ord_a = tiles_->order(dot->get_operand(0));
|
auto ord_a = tiles_->order(dot->get_operand(0));
|
||||||
auto ord_b = tiles_->order(dot->get_operand(1));
|
auto ord_b = tiles_->order(dot->get_operand(1));
|
||||||
|
|
||||||
bool is_a_row = dot->is_a_trans() ^ ord_a[ord_a.size() - 2] == 1;
|
bool is_a_row = ord_a[ord_a.size() - 2] == 1;
|
||||||
bool is_b_row = dot->is_b_trans() ^ ord_b[ord_b.size() - 2] == 1;
|
bool is_b_row = ord_b[ord_b.size() - 2] == 1;
|
||||||
|
|
||||||
if(is_a_row){
|
if(is_a_row){
|
||||||
offset_a_i = builder.CreateAdd(offset_a_i, builder.CreateURem(u_thread_id, builder.getInt32(4)));
|
offset_a_i = builder.CreateAdd(offset_a_i, builder.CreateURem(u_thread_id, builder.getInt32(4)));
|
||||||
@@ -1125,10 +1125,6 @@ void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn
|
|||||||
Value *current_offset_b_i = builder.CreateAdd(offset_b_j, builder.getInt32(pack_j*stride_rep_j*pack_size_1_));
|
Value *current_offset_b_i = builder.CreateAdd(offset_b_j, builder.getInt32(pack_j*stride_rep_j*pack_size_1_));
|
||||||
indices_t idx_a = {current_offset_a_i, builder.CreateAdd(offset_a_k, _K)};
|
indices_t idx_a = {current_offset_a_i, builder.CreateAdd(offset_a_k, _K)};
|
||||||
indices_t idx_b = {current_offset_b_i, builder.CreateAdd(offset_b_k, _K)};
|
indices_t idx_b = {current_offset_b_i, builder.CreateAdd(offset_b_k, _K)};
|
||||||
if(dot->is_a_trans())
|
|
||||||
std::swap(idx_a[0], idx_a[1]);
|
|
||||||
if(!dot->is_b_trans())
|
|
||||||
std::swap(idx_b[0], idx_b[1]);
|
|
||||||
idx_a.insert(idx_a.end(), x.first.begin(), x.first.end());
|
idx_a.insert(idx_a.end(), x.first.begin(), x.first.end());
|
||||||
idx_b.insert(idx_b.end(), x.first.begin(), x.first.end());
|
idx_b.insert(idx_b.end(), x.first.begin(), x.first.end());
|
||||||
Value *ha = TA->get_value(idx_a);
|
Value *ha = TA->get_value(idx_a);
|
||||||
@@ -1188,10 +1184,6 @@ void selection::lower_scanline_dot(ir::dot_inst *dot, LLVMContext &ctx, Function
|
|||||||
// input indices
|
// input indices
|
||||||
indices_t a_idx = {idx[0], builder.getInt32(K)};
|
indices_t a_idx = {idx[0], builder.getInt32(K)};
|
||||||
indices_t b_idx = {builder.getInt32(K), idx[1]};
|
indices_t b_idx = {builder.getInt32(K), idx[1]};
|
||||||
if(dot->is_a_trans())
|
|
||||||
std::swap(a_idx[0], a_idx[1]);
|
|
||||||
if(dot->is_b_trans())
|
|
||||||
std::swap(b_idx[0], b_idx[1]);
|
|
||||||
// add batching dimension
|
// add batching dimension
|
||||||
for(size_t i = 2; i < idx.size(); i++){
|
for(size_t i = 2; i < idx.size(); i++){
|
||||||
a_idx.insert(a_idx.end(), idx[i]);
|
a_idx.insert(a_idx.end(), idx[i]);
|
||||||
@@ -1217,9 +1209,7 @@ void selection::lower_outer_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *f
|
|||||||
Value *res = TD->get_value(idx);
|
Value *res = TD->get_value(idx);
|
||||||
indices_t a_idx = {idx[0], builder.getInt32(0)};
|
indices_t a_idx = {idx[0], builder.getInt32(0)};
|
||||||
indices_t b_idx = {builder.getInt32(0), idx[1]};
|
indices_t b_idx = {builder.getInt32(0), idx[1]};
|
||||||
if(dot->is_a_trans())
|
|
||||||
std::swap(a_idx[0], a_idx[1]);
|
std::swap(a_idx[0], a_idx[1]);
|
||||||
if(dot->is_b_trans())
|
|
||||||
std::swap(b_idx[0], b_idx[1]);
|
std::swap(b_idx[0], b_idx[1]);
|
||||||
Value *a = TA->get_value(a_idx);
|
Value *a = TA->get_value(a_idx);
|
||||||
Value *b = TB->get_value(b_idx);
|
Value *b = TB->get_value(b_idx);
|
||||||
@@ -1243,7 +1233,7 @@ void selection::lower_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn, IRB
|
|||||||
Type *c_ty = llvm_type(D->get_type()->get_scalar_ty(), ctx);
|
Type *c_ty = llvm_type(D->get_type()->get_scalar_ty(), ctx);
|
||||||
Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {c_ty});
|
Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {c_ty});
|
||||||
auto A_shapes = A->get_type()->get_tile_shapes();
|
auto A_shapes = A->get_type()->get_tile_shapes();
|
||||||
size_t red_axis = dot->is_a_trans() ? 0 : 1;
|
size_t red_axis = 1;
|
||||||
unsigned NK = A_shapes[red_axis];
|
unsigned NK = A_shapes[red_axis];
|
||||||
|
|
||||||
if(NK != 1) {
|
if(NK != 1) {
|
||||||
@@ -1552,8 +1542,8 @@ void selection::run(ir::module &src, Module &dst) {
|
|||||||
offset->addIncoming(next_offset, llvm_inc_block);
|
offset->addIncoming(next_offset, llvm_inc_block);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
unsigned num_bytes = phi->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
|
unsigned num_bytes = inst->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
|
||||||
offset->addIncoming(dst_builder.getInt32(liveness_->num_bytes(phi)/(num_bytes)), llvm_inc_block);
|
offset->addIncoming(dst_builder.getInt32(liveness_->get_buffer(inst)->size / (2*num_bytes)), llvm_inc_block);
|
||||||
}
|
}
|
||||||
ptr->addIncoming(inc_shared->get_pointer(), llvm_inc_block);
|
ptr->addIncoming(inc_shared->get_pointer(), llvm_inc_block);
|
||||||
}
|
}
|
||||||
|
@@ -38,8 +38,8 @@ void membar::add_reference(ir::value *v, interval_vec_t &res){
|
|||||||
return;
|
return;
|
||||||
if(alloc_->has_offset(v)){
|
if(alloc_->has_offset(v)){
|
||||||
unsigned offset = alloc_->offset(v);
|
unsigned offset = alloc_->offset(v);
|
||||||
unsigned num_bytes = liveness_->num_bytes(v);
|
unsigned size = liveness_->get_buffer(v)->size;
|
||||||
res.push_back(interval_t(offset, offset + num_bytes));
|
res.push_back(interval_t(offset, offset + size));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -8,37 +8,8 @@ namespace codegen{
|
|||||||
namespace transform{
|
namespace transform{
|
||||||
|
|
||||||
|
|
||||||
inline bool is_trans(ir::value *v){
|
|
||||||
auto *x = dynamic_cast<ir::trans_inst*>(v);
|
|
||||||
if(!x)
|
|
||||||
return false;
|
|
||||||
std::vector<ir::constant_int*> perm = x->get_perm();
|
|
||||||
std::vector<ir::constant_int*> ref;
|
|
||||||
ir::type *int32_ty = ir::type::get_int32_ty(v->get_type()->get_context());
|
|
||||||
for(size_t i = 0; i < perm.size(); i++)
|
|
||||||
ref.push_back(ir::constant_int::get(int32_ty, i));
|
|
||||||
std::swap(ref[0], ref[1]);
|
|
||||||
// true is perm == ref
|
|
||||||
return std::equal(perm.begin(), perm.end(), ref.begin());
|
|
||||||
}
|
|
||||||
|
|
||||||
inline bool is_hmma(ir::value *v){
|
|
||||||
bool result = false;
|
|
||||||
if(auto *x = dynamic_cast<ir::dot_inst*>(v)){
|
|
||||||
ir::value *a = x->get_operand(0);
|
|
||||||
ir::type *a_ty = a->get_type();
|
|
||||||
ir::value *b = x->get_operand(1);
|
|
||||||
ir::type *b_ty = b->get_type();
|
|
||||||
// inputs have to be FP16
|
|
||||||
result = a_ty->get_scalar_ty()->is_half_ty() && b_ty->get_scalar_ty()->is_half_ty();
|
|
||||||
// reduction has to be multiple of 4
|
|
||||||
// result = result && ((a_ty->get_tile_shapes()[1]->get_value() % 4) == 0);
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
ir::value* rewrite_trans_phi_impl(ir::value *value, ir::builder &builder,
|
ir::value* rewrite_trans_phi_impl(ir::value *value, ir::builder &builder,
|
||||||
const std::vector<ir::constant_int*>& perm) {
|
const std::vector<int>& perm) {
|
||||||
if(auto phi = dynamic_cast<ir::phi_node*>(value)) {
|
if(auto phi = dynamic_cast<ir::phi_node*>(value)) {
|
||||||
// transpose operands
|
// transpose operands
|
||||||
std::vector<ir::value*> incs;
|
std::vector<ir::value*> incs;
|
||||||
@@ -106,9 +77,7 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
|
|||||||
ir::value *a = dot->get_operand(0);
|
ir::value *a = dot->get_operand(0);
|
||||||
ir::value *b = dot->get_operand(1);
|
ir::value *b = dot->get_operand(1);
|
||||||
builder.set_insert_point(add);
|
builder.set_insert_point(add);
|
||||||
ir::value * new_dot = builder.insert(ir::dot_inst::create(a, b, other,
|
ir::value * new_dot = builder.insert(ir::dot_inst::create_nn(a, b, other, dot->get_name()));
|
||||||
dot->is_a_trans(), dot->is_b_trans(),
|
|
||||||
dot->get_name()));
|
|
||||||
add->replace_all_uses_with(new_dot);
|
add->replace_all_uses_with(new_dot);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@@ -241,7 +241,7 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
|
|||||||
cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { }
|
cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { }
|
||||||
|
|
||||||
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
|
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
|
||||||
// std::cout << source << std::endl;
|
std::cout << source << std::endl;
|
||||||
cu_context::context_switcher ctx(*context);
|
cu_context::context_switcher ctx(*context);
|
||||||
// JIT compile source-code
|
// JIT compile source-code
|
||||||
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
|
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
|
||||||
|
@@ -322,7 +322,7 @@ value *builder::create_dot(value *A, value *B, value *C, const std::string &name
|
|||||||
return insert(dot_inst::create_nn(A, B, C, name));
|
return insert(dot_inst::create_nn(A, B, C, name));
|
||||||
}
|
}
|
||||||
|
|
||||||
value *builder::create_trans(value *A, const std::vector<ir::constant_int*>& perm, const std::string &name) {
|
value *builder::create_trans(value *A, const std::vector<int>& perm, const std::string &name) {
|
||||||
return insert(trans_inst::create(A, perm, name));
|
return insert(trans_inst::create(A, perm, name));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -536,7 +536,7 @@ instruction* downcast_inst::create(value *arg, const std::string &name, instruct
|
|||||||
|
|
||||||
dot_inst::dot_inst(value *A, value *B, value *C, TransT AT, TransT BT,
|
dot_inst::dot_inst(value *A, value *B, value *C, TransT AT, TransT BT,
|
||||||
const std::string &name, instruction *next)
|
const std::string &name, instruction *next)
|
||||||
: builtin_inst(C->get_type(), INST_DOT, 3, name, next), AT_(AT), BT_(BT) {
|
: builtin_inst(C->get_type(), INST_DOT, 3, name, next) {
|
||||||
set_operand(0, A);
|
set_operand(0, A);
|
||||||
set_operand(1, B);
|
set_operand(1, B);
|
||||||
set_operand(2, C);
|
set_operand(2, C);
|
||||||
@@ -574,31 +574,30 @@ instruction *dot_inst::create_tt(value *A, value *B, value *C,
|
|||||||
// trans instructions
|
// trans instructions
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
ir::type* trans_inst::get_res_ty(ir::type* ty, std::vector<constant_int*> perm) {
|
ir::type* trans_inst::get_res_ty(ir::type* ty, std::vector<int> perm) {
|
||||||
// get argument shapes
|
// get argument shapes
|
||||||
ir::tile_type::tile_shapes_t arg_shapes = ty->get_tile_shapes();
|
ir::tile_type::tile_shapes_t arg_shapes = ty->get_tile_shapes();
|
||||||
// permutate argument shapes
|
// permutate argument shapes
|
||||||
perm = init_perm(ty, perm);
|
perm = init_perm(ty, perm);
|
||||||
ir::tile_type::tile_shapes_t res_shapes = arg_shapes;
|
ir::tile_type::tile_shapes_t res_shapes = arg_shapes;
|
||||||
for(size_t i = 0; i < perm.size(); i++)
|
for(size_t i = 0; i < perm.size(); i++)
|
||||||
res_shapes[i] = arg_shapes[perm[i]->get_value()];
|
res_shapes[i] = arg_shapes[perm[i]];
|
||||||
// construct type
|
// construct type
|
||||||
return tile_type::get(ty->get_scalar_ty(), res_shapes);
|
return tile_type::get(ty->get_scalar_ty(), res_shapes);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<constant_int*> trans_inst::init_perm(ir::type* ty, const std::vector<constant_int*>& perm) {
|
std::vector<int> trans_inst::init_perm(ir::type* ty, const std::vector<int>& perm) {
|
||||||
if(!perm.empty())
|
if(!perm.empty())
|
||||||
return perm;
|
return perm;
|
||||||
auto size = ty->get_tile_shapes().size();
|
auto size = ty->get_tile_shapes().size();
|
||||||
ir::type* int32_ty = type::get_int32_ty(ty->get_context());
|
std::vector<int> result;
|
||||||
std::vector<constant_int*> result;
|
result.push_back(size - 1);
|
||||||
result.push_back(ir::constant_int::get(int32_ty, size - 1));
|
|
||||||
for(size_t i = 0; i < size - 1; i++)
|
for(size_t i = 0; i < size - 1; i++)
|
||||||
result.push_back(ir::constant_int::get(int32_ty, i));
|
result.push_back(i);
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
trans_inst::trans_inst(value *arg, const std::vector<constant_int*>& perm, const std::string &name, instruction *next)
|
trans_inst::trans_inst(value *arg, const std::vector<int> &perm, const std::string &name, instruction *next)
|
||||||
: builtin_inst(get_res_ty(arg->get_type(), perm), INST_TRANS, 1, name, next) {
|
: builtin_inst(get_res_ty(arg->get_type(), perm), INST_TRANS, 1, name, next) {
|
||||||
// sanity check
|
// sanity check
|
||||||
perm_ = init_perm(arg->get_type(), perm);
|
perm_ = init_perm(arg->get_type(), perm);
|
||||||
@@ -607,11 +606,11 @@ trans_inst::trans_inst(value *arg, const std::vector<constant_int*>& perm, const
|
|||||||
set_operand(0, arg);
|
set_operand(0, arg);
|
||||||
}
|
}
|
||||||
|
|
||||||
instruction* trans_inst::create(value *arg, const std::vector<constant_int *> &perm, const std::string &name, instruction *next) {
|
instruction* trans_inst::create(value *arg, const std::vector<int> &perm, const std::string &name, instruction *next) {
|
||||||
return new trans_inst(arg, perm, name, next);
|
return new trans_inst(arg, perm, name, next);
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::vector<constant_int*> trans_inst::get_perm() const {
|
const std::vector<int> trans_inst::get_perm() const {
|
||||||
return perm_;
|
return perm_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -229,6 +229,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
|||||||
reassociate.run(module);
|
reassociate.run(module);
|
||||||
dce.run(module);
|
dce.run(module);
|
||||||
cts.run(module);
|
cts.run(module);
|
||||||
|
// ir::print(module, std::cout);
|
||||||
liveness.run(module);
|
liveness.run(module);
|
||||||
allocation.run(module);
|
allocation.run(module);
|
||||||
if(allocation.allocated_size() > context->device()->max_shared_memory())
|
if(allocation.allocated_size() > context->device()->max_shared_memory())
|
||||||
|
@@ -27,9 +27,9 @@ inline rt::function::grid_fn_ty grid2d(size_t M, size_t N) {
|
|||||||
|
|
||||||
|
|
||||||
std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
|
std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
|
||||||
typedef half_float::half NumericT;
|
typedef float NumericT;
|
||||||
std::string ty = "half";
|
std::string ty = "float";
|
||||||
cublasDataType_t cuty = CUDA_R_16F;
|
cublasDataType_t cuty = CUDA_R_32F;
|
||||||
size_t dt_nbytes = sizeof(NumericT);
|
size_t dt_nbytes = sizeof(NumericT);
|
||||||
drv::context* context = stream->context();
|
drv::context* context = stream->context();
|
||||||
// leading dimensions
|
// leading dimensions
|
||||||
@@ -45,9 +45,9 @@ std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, i
|
|||||||
opt.defines.push_back({"TYPE", {ty}});
|
opt.defines.push_back({"TYPE", {ty}});
|
||||||
opt.defines.push_back({"AT", {AT?"1":"0"}});
|
opt.defines.push_back({"AT", {AT?"1":"0"}});
|
||||||
opt.defines.push_back({"BT", {BT?"1":"0"}});
|
opt.defines.push_back({"BT", {BT?"1":"0"}});
|
||||||
opt.defines.push_back({"TM", {"128"}});
|
opt.defines.push_back({"TM", {"64"}});
|
||||||
opt.defines.push_back({"TN", {"128"}});
|
opt.defines.push_back({"TN", {"64"}});
|
||||||
opt.defines.push_back({"TK", {"16"}});
|
opt.defines.push_back({"TK", {"8"}});
|
||||||
opt.num_warps = {4};
|
opt.num_warps = {4};
|
||||||
// create function
|
// create function
|
||||||
rt::function function(src::dot, opt);
|
rt::function function(src::dot, opt);
|
||||||
@@ -79,10 +79,9 @@ int main() {
|
|||||||
// shapes to benchmark
|
// shapes to benchmark
|
||||||
typedef std::tuple<bool, bool, int, int, int> config_t;
|
typedef std::tuple<bool, bool, int, int, int> config_t;
|
||||||
std::vector<config_t> configs;
|
std::vector<config_t> configs;
|
||||||
for(auto x: std::vector<std::array<bool, 2>>{{false, true},
|
for(auto x: std::vector<std::array<bool, 2>>{{false, false}}){
|
||||||
{true, false}, {true, true}}){
|
|
||||||
std::vector<config_t> tmp = {
|
std::vector<config_t> tmp = {
|
||||||
config_t{x[0], x[1], 4096, 4096, 4096}
|
config_t{x[0], x[1], 2048, 2048, 2048}
|
||||||
// config_t{x[0], x[1], 16, 2048, 2048},
|
// config_t{x[0], x[1], 16, 2048, 2048},
|
||||||
// config_t{x[0], x[1], 32, 2048, 2048},
|
// config_t{x[0], x[1], 32, 2048, 2048},
|
||||||
// config_t{x[0], x[1], 64, 2048, 2048},
|
// config_t{x[0], x[1], 64, 2048, 2048},
|
||||||
|
@@ -54,12 +54,12 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
|
|||||||
TYPE a[SHAPE_A] = *pa;
|
TYPE a[SHAPE_A] = *pa;
|
||||||
TYPE b[SHAPE_B] = *pb;
|
TYPE b[SHAPE_B] = *pb;
|
||||||
// reduction loop
|
// reduction loop
|
||||||
for(int k = K; k > 0; k-= TK){
|
for(int k = K; k > TK; k-= TK){
|
||||||
c += USEA @ USEB;
|
c += USEA @ USEB;
|
||||||
pa = pa + TK * STRIDE_AK;
|
pa = pa + TK * STRIDE_AK;
|
||||||
pb = pb + TK * STRIDE_BK;
|
pb = pb + TK * STRIDE_BK;
|
||||||
a = ((bool[SHAPE_A])(k > TK)) ? *pa : 0;
|
a = *pa;
|
||||||
b = ((bool[SHAPE_B])(k > TK)) ? *pb : 0;
|
b = *pb;
|
||||||
}
|
}
|
||||||
// epilogue
|
// epilogue
|
||||||
int rxc[TM] = ridx * TM + 0 ... TM;
|
int rxc[TM] = ridx * TM + 0 ... TM;
|
||||||
|
@@ -139,13 +139,13 @@ int main() {
|
|||||||
// shapes to benchmark
|
// shapes to benchmark
|
||||||
typedef std::tuple<dtype_t, bool, bool, int, int, int, int, int, int, int> config_t;
|
typedef std::tuple<dtype_t, bool, bool, int, int, int, int, int, int, int> config_t;
|
||||||
std::vector<config_t> configs;
|
std::vector<config_t> configs;
|
||||||
for(bool AT: std::array<bool, 2>{false})
|
|
||||||
for(bool BT: std::array<bool, 2>{false})
|
|
||||||
for(int TM: std::vector<int>{32, 64})
|
for(int TM: std::vector<int>{32, 64})
|
||||||
for(int TN: std::vector<int>{32, 64})
|
for(int TN: std::vector<int>{32, 64})
|
||||||
for(int TK: std::vector<int>{16, 32})
|
for(int TK: std::vector<int>{8})
|
||||||
for(int nwarps: std::vector<int>{1, 2, 4, 8}){
|
for(int nwarps: std::vector<int>{1, 2, 4, 8})
|
||||||
configs.push_back(config_t{HALF, AT, BT, 128, 128, 128, TM, TN, TK, nwarps});
|
for(bool AT: std::array<bool, 2>{false, true})
|
||||||
|
for(bool BT: std::array<bool, 2>{false, true}){
|
||||||
|
configs.push_back(config_t{FLOAT, AT, BT, 128, 128, 128, TM, TN, TK, nwarps});
|
||||||
}
|
}
|
||||||
// does the work
|
// does the work
|
||||||
dtype_t dtype;
|
dtype_t dtype;
|
||||||
|
Reference in New Issue
Block a user