more work on padding

This commit is contained in:
Philippe Tillet
2019-09-27 22:15:30 -04:00
parent 575dd06be3
commit ed1b2bc563
19 changed files with 191 additions and 191 deletions

View File

@@ -49,7 +49,7 @@ struct buffer_t {
class liveness { class liveness {
private: private:
typedef std::map<ir::value*, slot_index> indices_map_t; typedef std::map<ir::value*, slot_index> indices_map_t;
typedef std::map<buffer_t, segment> intervals_map_t; typedef std::map<buffer_t*, segment> intervals_map_t;
typedef std::map<ir::value*, bool> has_storage_map_t; typedef std::map<ir::value*, bool> has_storage_map_t;
typedef ir::value* node_t; typedef ir::value* node_t;
typedef std::map <node_t, std::set<node_t>> graph_t; typedef std::map <node_t, std::set<node_t>> graph_t;
@@ -63,24 +63,26 @@ public:
private: private:
void connected_components(node_t x, std::set<node_t> &nodes, graph_t &graph, unsigned group_id); void connected_components(node_t x, std::set<node_t> &nodes, graph_t &graph, buffer_t *buffer);
void extract_double_bufferable(ir::instruction *i); void extract_double_bufferable(ir::instruction *i);
void extract_buffers(ir::instruction *i); void extract_buffers(ir::instruction *i);
void get_parents(ir::instruction *i, std::vector<ir::value *>& res); void get_parents(ir::instruction *i, std::vector<ir::value *>& res);
void make_graph(ir::instruction *i); void make_graph(ir::instruction *i);
bool do_pad(ir::value *x);
public: public:
liveness(tiles *t): tiles_(t){ } liveness(tiles *t): tiles_(t){ }
// padding
unsigned get_pad(ir::value *v) const { return pad_.at(v); }
// buffer size // buffer size
unsigned is_ld_padded(ir::value *x);
unsigned num_bytes(ir::value *x); unsigned num_bytes(ir::value *x);
// accessors // accessors
const intervals_map_t& intervals() const { return intervals_; } const intervals_map_t& intervals() const { return intervals_; }
segment get_interval(buffer_t v) const { return intervals_.at(v); } segment get_interval(buffer_t* v) const { return intervals_.at(v); }
// buffers // buffers
buffer_t get_buffer(ir::value *v) const { return groups_.at(v); } buffer_t* get_buffer(ir::value *v) const { return groups_.at(v); }
std::vector<ir::value*> get_values(buffer_t x) const { return values_.at(x); } std::vector<ir::value*> get_values(buffer_t* x) const { return values_.at(x); }
// double-buffering // double-buffering
bool has_double(ir::value *x) const { return double_.find(x) != double_.end(); } bool has_double(ir::value *x) const { return double_.find(x) != double_.end(); }
double_buffer_info_t get_double(ir::value *x) const { return double_.at(x); } double_buffer_info_t get_double(ir::value *x) const { return double_.at(x); }
@@ -95,12 +97,14 @@ private:
indices_map_t indices; indices_map_t indices;
intervals_map_t intervals_; intervals_map_t intervals_;
std::map<ir::value*, double_buffer_info_t> double_; std::map<ir::value*, double_buffer_info_t> double_;
std::map<ir::value*, size_t> pad_;
std::map<ir::value*, std::vector<ir::value*>> parents_; std::map<ir::value*, std::vector<ir::value*>> parents_;
// graph // graph
std::set<node_t> nodes_; std::set<node_t> nodes_;
graph_t graph_; graph_t graph_;
std::map<ir::value*, buffer_t> groups_; std::vector<buffer_t*> buffers_;
std::map<buffer_t, std::vector<ir::value*>> values_; std::map<ir::value*, buffer_t*> groups_;
std::map<buffer_t*, std::vector<ir::value*>> values_;
}; };
} }

View File

@@ -89,7 +89,7 @@ private:
public: public:
shared_tile(Type* ty, const shapes_t &shapes, const std::vector<int> &order, Value* ptr, Builder &builder, Value* offset = nullptr); shared_tile(Type* ty, const shapes_t &shapes, const std::vector<int> &order, Value* ptr, Builder &builder, Value* offset = nullptr, const std::vector<int>& perm = {});
void set_vector_size(unsigned vector_size); void set_vector_size(unsigned vector_size);
void set_return_mode(bool return_vector); void set_return_mode(bool return_vector);
void set_value(indices_t, Value *); void set_value(indices_t, Value *);
@@ -97,8 +97,9 @@ public:
Value* get_value(indices_t idx); Value* get_value(indices_t idx);
Value* get_pointer() { return ptr_; } Value* get_pointer() { return ptr_; }
Value* get_offset() { return offset_; } Value* get_offset() { return offset_; }
const std::vector<int>& get_perm() { return perm_; }
const std::vector<int>& get_order() { return order_; } const std::vector<int>& get_order() { return order_; }
static Value* shared_offset(Builder& builder, const shapes_t& shapes, const std::vector<int>& order, indices_t idx); static Value* shared_offset(Builder& builder, const shapes_t& shapes, const std::vector<int>& perm, const std::vector<int>& order, indices_t idx);
private: private:
Value *ptr_; Value *ptr_;
@@ -108,6 +109,7 @@ private:
std::map<indices_t, Value*> ptr_cache_; std::map<indices_t, Value*> ptr_cache_;
unsigned vector_size_; unsigned vector_size_;
std::vector<int> order_; std::vector<int> order_;
std::vector<int> perm_;
}; };
// Distribtued tile // Distribtued tile

View File

@@ -135,7 +135,7 @@ public:
value *create_atomic_exch(value *ptr, value *val, const std::string &name = ""); value *create_atomic_exch(value *ptr, value *val, const std::string &name = "");
value *create_atomic_add(value *ptr, value *val, const std::string &name = ""); value *create_atomic_add(value *ptr, value *val, const std::string &name = "");
value *create_dot(value *A, value *B, value *C, const std::string &name = ""); value *create_dot(value *A, value *B, value *C, const std::string &name = "");
value *create_trans(value *A, const std::vector<constant_int *> &perm = {}, const std::string &name = ""); value *create_trans(value *A, const std::vector<int> &perm = {}, const std::string &name = "");
value *create_sqrt(value *A, const std::string &name = ""); value *create_sqrt(value *A, const std::string &name = "");
value *create_reduce(value *A, unsigned axis, const std::string &name = ""); value *create_reduce(value *A, unsigned axis, const std::string &name = "");
value *create_select(value *pred, value *if_value, value *else_value, const std::string &name = ""); value *create_select(value *pred, value *if_value, value *else_value, const std::string &name = "");

View File

@@ -591,7 +591,7 @@ public:
private: private:
dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, const std::string &name, instruction *next); dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, const std::string &name, instruction *next);
std::string repr_impl() const { return std::string("dot.") + ((AT_==NoTrans)?"n":"t") + ((BT_==NoTrans)?"n":"t"); } std::string repr_impl() const { return "dot"; }
public: public:
static instruction *create(value *A, value *B, value *C, bool AT, bool BT, const std::string &name = "", instruction *next = nullptr); static instruction *create(value *A, value *B, value *C, bool AT, bool BT, const std::string &name = "", instruction *next = nullptr);
@@ -599,13 +599,7 @@ public:
static instruction* create_nt(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr); static instruction* create_nt(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
static instruction* create_tn(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr); static instruction* create_tn(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
static instruction* create_tt(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr); static instruction* create_tt(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
bool is_a_trans() { return AT_ == Trans; }
bool is_b_trans() { return BT_ == Trans; }
_TRITON_DEFINE_CLONE(dot_inst) _TRITON_DEFINE_CLONE(dot_inst)
private:
TransT AT_;
TransT BT_;
}; };
//class outer_inst: public builtin_inst { //class outer_inst: public builtin_inst {
@@ -617,20 +611,20 @@ private:
class trans_inst: public builtin_inst { class trans_inst: public builtin_inst {
public: public:
ir::type* get_res_ty(ir::type* in, std::vector<constant_int *> perm); ir::type* get_res_ty(ir::type* in, std::vector<int> perm);
std::vector<constant_int*> init_perm(ir::type* ty, const std::vector<constant_int*>& perm); std::vector<int> init_perm(ir::type* ty, const std::vector<int>& perm);
private: private:
trans_inst(value *arg, const std::vector<constant_int*>& perm, const std::string& name, instruction* next); trans_inst(value *arg, const std::vector<int>& perm, const std::string& name, instruction* next);
std::string repr_impl() const { return "trans"; } std::string repr_impl() const { return "trans"; }
public: public:
static instruction* create(value *arg, const std::vector<constant_int*>& perm = {}, const std::string &name = "", instruction *next = nullptr); static instruction* create(value *arg, const std::vector<int> &perm = {}, const std::string &name = "", instruction *next = nullptr);
const std::vector<constant_int*> get_perm() const; const std::vector<int> get_perm() const;
_TRITON_DEFINE_CLONE(trans_inst) _TRITON_DEFINE_CLONE(trans_inst)
private: private:
std::vector<constant_int*> perm_; std::vector<int> perm_;
}; };
class sqrt_inst: public builtin_inst { class sqrt_inst: public builtin_inst {

View File

@@ -487,9 +487,6 @@ void align::populate(ir::value *v) {
populate_is_constant(v); populate_is_constant(v);
populate_starting_multiple(v); populate_starting_multiple(v);
populate_max_contiguous(v); populate_max_contiguous(v);
// std::cout << v->get_name() << std::endl;
// if(max_contiguous_[v].size() == 2)
// std::cout << max_contiguous_[v][0] << " " << max_contiguous_[v][1] << std::endl;
} }
void align::run(ir::module &mod) { void align::run(ir::module &mod) {

View File

@@ -21,22 +21,22 @@ void allocation::run(ir::module &mod) {
using std::min; using std::min;
typedef std::multimap<unsigned, segment> triples_map_type; typedef std::multimap<unsigned, segment> triples_map_type;
std::vector<buffer_t> I; std::vector<buffer_t*> I;
for(auto x: liveness_->intervals()) for(auto x: liveness_->intervals())
I.push_back(x.first); I.push_back(x.first);
std::vector<buffer_t> J = I; std::vector<buffer_t*> J = I;
triples_map_type H; triples_map_type H;
H.insert({0, segment{0, INT_MAX}}); H.insert({0, segment{0, INT_MAX}});
std::vector<buffer_t> V; std::vector<buffer_t*> V;
std::map<buffer_t, unsigned> starts; std::map<buffer_t*, unsigned> starts;
while(!J.empty()){ while(!J.empty()){
auto h_it = H.begin(); auto h_it = H.begin();
unsigned w = h_it->first; unsigned w = h_it->first;
segment xh = h_it->second; segment xh = h_it->second;
H.erase(h_it); H.erase(h_it);
auto j_it = std::find_if(J.begin(), J.end(), [&](buffer_t JJ){ auto j_it = std::find_if(J.begin(), J.end(), [&](buffer_t* JJ){
segment xj = liveness_->get_interval(JJ); segment xj = liveness_->get_interval(JJ);
bool res = xj.intersect(xh); bool res = xj.intersect(xh);
for(auto val: H) for(auto val: H)
@@ -44,7 +44,7 @@ void allocation::run(ir::module &mod) {
return res; return res;
}); });
if(j_it != J.end()){ if(j_it != J.end()){
unsigned size = j_it->size; unsigned size = (*j_it)->size;
segment xj = liveness_->get_interval(*j_it); segment xj = liveness_->get_interval(*j_it);
starts[*j_it] = w; starts[*j_it] = w;
H.insert({w + size, segment{max(xh.start, xj.start), min(xh.end, xj.end)}}); H.insert({w + size, segment{max(xh.start, xj.start), min(xh.end, xj.end)}});
@@ -58,14 +58,14 @@ void allocation::run(ir::module &mod) {
} }
// Build interference graph // Build interference graph
std::map<buffer_t, std::set<buffer_t>> interferences; std::map<buffer_t*, std::set<buffer_t*>> interferences;
for(buffer_t x: V) for(buffer_t* x: V)
for(buffer_t y: V){ for(buffer_t* y: V){
if(x.id == y.id) if(x->id == y->id)
continue; continue;
unsigned X0 = starts[x], Y0 = starts[y]; unsigned X0 = starts[x], Y0 = starts[y];
unsigned NX = x.size; unsigned NX = x->size;
unsigned NY = y.size; unsigned NY = y->size;
segment XS = {X0, X0 + NX}; segment XS = {X0, X0 + NX};
segment YS = {Y0, Y0 + NY}; segment YS = {Y0, Y0 + NY};
if(liveness_->get_interval(x).intersect(liveness_->get_interval(y)) if(liveness_->get_interval(x).intersect(liveness_->get_interval(y))
@@ -74,17 +74,17 @@ void allocation::run(ir::module &mod) {
} }
// Initialize colors // Initialize colors
std::map<buffer_t, int> colors; std::map<buffer_t*, int> colors;
for(buffer_t X: V) for(buffer_t* X: V)
colors[X] = (X.id==V[0].id)?0:-1; colors[X] = (X->id==V[0]->id)?0:-1;
// First-fit graph coloring // First-fit graph coloring
std::vector<bool> available(V.size()); std::vector<bool> available(V.size());
for(buffer_t x: V){ for(buffer_t* x: V){
// Non-neighboring colors are available // Non-neighboring colors are available
std::fill(available.begin(), available.end(), true); std::fill(available.begin(), available.end(), true);
for(buffer_t Y: interferences[x]){ for(buffer_t* Y: interferences[x]){
int color = colors[Y]; int color = colors[Y];
if(color >= 0) if(color >= 0)
available[color] = false; available[color] = false;
@@ -95,25 +95,24 @@ void allocation::run(ir::module &mod) {
} }
// Finalize allocation // Finalize allocation
for(buffer_t x: V){ for(buffer_t* x: V){
unsigned Adj = 0; unsigned Adj = 0;
for(buffer_t y: interferences[x]) for(buffer_t* y: interferences[x])
Adj = std::max<unsigned>(Adj, starts[y] + y.size); Adj = std::max<unsigned>(Adj, starts[y] + y->size);
// create offsets // create offsets
for(ir::value *v: liveness_->get_values(x)){ for(ir::value *v: liveness_->get_values(x)){
offsets_[v] = starts[x] + colors[x] * Adj; offsets_[v] = starts[x] + colors[x] * Adj;
if(liveness_->has_double(v)){ if(liveness_->has_double(v)){
auto info = liveness_->get_double(v); auto info = liveness_->get_double(v);
offsets_[info.latch] = offsets_[v] + x.size / 2; offsets_[info.latch] = offsets_[v] + x->size / 2;
} }
} }
} }
// Save maximum size of induced memory space // Save maximum size of induced memory space
allocated_size_ = 0; allocated_size_ = 0;
for(auto &x: offsets_){ for(buffer_t* x: V)
allocated_size_ = std::max<size_t>(allocated_size_, x.second + liveness_->get_buffer(x.first).size); allocated_size_ = std::max<size_t>(allocated_size_, starts[x] + x->size);
}
} }
} }

View File

@@ -74,7 +74,7 @@ void axes::update_graph_trans(ir::instruction *i) {
auto perm = trans->get_perm(); auto perm = trans->get_perm();
// add edge between axis perm[d] and axis d // add edge between axis perm[d] and axis d
for(unsigned d = 0; d < perm.size(); d++) for(unsigned d = 0; d < perm.size(); d++)
add_constraint({i, perm[d]->get_value()}, {op, d}); add_constraint({i, perm[d]}, {op, d});
} }
void axes::update_graph_broadcast(ir::instruction *i) { void axes::update_graph_broadcast(ir::instruction *i) {

View File

@@ -58,6 +58,18 @@ void liveness::make_graph(ir::instruction *i) {
graph_[i].insert(latch); graph_[i].insert(latch);
graph_[latch].insert(i); graph_[latch].insert(i);
} }
if(i->get_id() == ir::INST_PHI){
ir::phi_node* phi = (ir::phi_node*)i;
for(ir::value* op: phi->ops()){
auto* iop = dynamic_cast<ir::instruction*>(op);
if(!iop || storage_info.at(iop->get_id()).first != SHARED)
continue;
nodes_.insert(phi);
nodes_.insert(op);
graph_[phi].insert(op);
graph_[op].insert(phi);
}
}
if(i->get_id() == ir::INST_TRANS){ if(i->get_id() == ir::INST_TRANS){
nodes_.insert(i); nodes_.insert(i);
nodes_.insert(i->get_operand(0)); nodes_.insert(i->get_operand(0));
@@ -67,39 +79,63 @@ void liveness::make_graph(ir::instruction *i) {
} }
// connected components // connected components
void liveness::connected_components(node_t x, std::set<node_t> &nodes, graph_t &graph, unsigned group_id) { void liveness::connected_components(node_t x, std::set<node_t> &nodes, graph_t &graph, buffer_t* buffer) {
buffer_t buffer{group_id, num_bytes(x)};
groups_[x] = buffer; groups_[x] = buffer;
values_[buffer].push_back(x); values_[buffer].push_back(x);
if(nodes.find(x) != nodes.end()){ if(nodes.find(x) != nodes.end()){
nodes.erase(x); nodes.erase(x);
for(const node_t &y: graph[x]) for(const node_t &y: graph[x])
connected_components(y, nodes, graph, group_id); connected_components(y, nodes, graph, buffer);
} }
} }
unsigned liveness::is_ld_padded(ir::value *x) { bool liveness::do_pad(ir::value *x) {
if(auto *trans = dynamic_cast<ir::trans_inst*>(x)){ // alignment for matrix product
if(trans->get_perm()[0]->get_value() != 0) if(auto* dot = dynamic_cast<ir::dot_inst*>(x)) {
return 4;
}
auto order = tiles_->order(x); auto order = tiles_->order(x);
bool is_col_major = order[0] == 0; // a
ir::value *a = dot->get_operand(0);\
size_t previous_a = pad_[a];
bool a_trans = dynamic_cast<ir::trans_inst*>(a);
bool a_row = order[0] == 1;
if(tiles_->hmma(x) == HMMA_A_ROW) if(tiles_->hmma(x) == HMMA_A_ROW)
return is_col_major ? 16 : 16; pad_[a] = 16;
if(tiles_->hmma(x) == HMMA_A_COL) else if(tiles_->hmma(x) == HMMA_A_COL)
return is_col_major ? 8 : 8; pad_[a] = 8;
else if(a_trans ^ a_row)
pad_[a] = 4;
else
pad_[a] = 0;
// b
ir::value *b = dot->get_operand(1);
size_t previous_b = pad_[b];
bool b_trans = dynamic_cast<ir::trans_inst*>(a);
bool b_col = order[0] == 0;
if(tiles_->hmma(x) == HMMA_B_COL) if(tiles_->hmma(x) == HMMA_B_COL)
return is_col_major ? 16 : 16; pad_[b] = 16;
if(tiles_->hmma(x) == HMMA_B_ROW) if(tiles_->hmma(x) == HMMA_B_ROW)
return is_col_major ? 8 : 8; pad_[b] = 8;
if(auto* phi = dynamic_cast<ir::phi_node*>(x)) { if(b_trans ^ b_col)
unsigned result = 0; pad_[b] = 4;
for(unsigned i = 0; i < phi->get_num_incoming(); i++) else
result = std::max(result, is_ld_padded(phi->get_incoming_value(i))); pad_[b] = 0;
return result; return previous_a != pad_[a] || previous_b != pad_[b];
} }
return 0; // padding for phi-nodes
if(auto* phi = dynamic_cast<ir::phi_node*>(x)) {
bool has_changed = false;
for(unsigned i = 0; i < phi->get_num_incoming(); i++){
ir::value* op = phi->get_operand(i);
size_t previous = pad_[op];
pad_[op] = std::max(pad_[op], pad_[phi]);
has_changed |= previous != pad_[op];
}
return has_changed;
}
// default -- no pading
size_t previous = pad_[x];
pad_[x] = std::max<int>(previous, 0);
return pad_[x] != previous;
} }
unsigned liveness::num_bytes(ir::value *x) { unsigned liveness::num_bytes(ir::value *x) {
@@ -120,7 +156,8 @@ unsigned liveness::num_bytes(ir::value *x) {
return num_elements * num_bytes * depth; return num_elements * num_bytes * depth;
} }
unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8; unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8;
unsigned pad = is_ld_padded(x); unsigned pad = pad_.at(x);
std::cout << x->get_name() << " " << pad << std::endl;
if(pad > 0){ if(pad > 0){
unsigned ld = x->get_type()->get_tile_shapes()[tiles_->order(x)[0]]; unsigned ld = x->get_type()->get_tile_shapes()[tiles_->order(x)[0]];
num_bytes += pad * num_bytes / ld; num_bytes += pad * num_bytes / ld;
@@ -134,6 +171,7 @@ unsigned liveness::num_bytes(ir::value *x) {
void liveness::run(ir::module &mod) { void liveness::run(ir::module &mod) {
double_.clear(); double_.clear();
indices.clear(); indices.clear();
pad_.clear();
intervals_.clear(); intervals_.clear();
parents_.clear(); parents_.clear();
@@ -142,6 +180,15 @@ void liveness::run(ir::module &mod) {
this->extract_double_bufferable(i); this->extract_double_bufferable(i);
}); });
// Padding information
bool has_changed;
do{
has_changed = false;
ir::for_each_value(mod, [this, &has_changed](ir::value* v){
has_changed |= this->do_pad(v);
});
}while(has_changed);
// Create buffer dependency graph // Create buffer dependency graph
ir::for_each_instruction(mod, [this](ir::instruction* i) { ir::for_each_instruction(mod, [this](ir::instruction* i) {
this->make_graph(i); this->make_graph(i);
@@ -150,7 +197,10 @@ void liveness::run(ir::module &mod) {
// connected components // connected components
unsigned group_id = 0; unsigned group_id = 0;
while(!nodes_.empty()){ while(!nodes_.empty()){
connected_components(*nodes_.begin(), nodes_, graph_, group_id++); buffer_t* buffer = new buffer_t{group_id++};
connected_components(*nodes_.begin(), nodes_, graph_, buffer);
for(ir::value *v: values_.at(buffer))
buffer->size = std::max<int>(buffer->size, num_bytes(v));
} }
// Assigns index to each instruction // Assigns index to each instruction

View File

@@ -40,7 +40,7 @@ bool is_hmma_a_col(ir::value* v) {
for(ir::user *u: v->get_users()) for(ir::user *u: v->get_users())
if(is_hmma_c(u)){ if(is_hmma_c(u)){
ir::dot_inst* dot = (ir::dot_inst*)u; ir::dot_inst* dot = (ir::dot_inst*)u;
if((v == dot->get_operand(0)) && !dot->is_a_trans()) if((v == dot->get_operand(0)))
return true; return true;
} }
} }
@@ -49,7 +49,7 @@ bool is_hmma_a_row(ir::value* v) {
for(ir::user *u: v->get_users()) for(ir::user *u: v->get_users())
if(is_hmma_c(u)){ if(is_hmma_c(u)){
ir::dot_inst* dot = (ir::dot_inst*)u; ir::dot_inst* dot = (ir::dot_inst*)u;
if((v == dot->get_operand(0)) && dot->is_a_trans()) if((v == dot->get_operand(0)))
return true; return true;
} }
} }
@@ -58,7 +58,7 @@ bool is_hmma_b_col(ir::value* v) {
for(ir::user *u: v->get_users()) for(ir::user *u: v->get_users())
if(is_hmma_c(u)){ if(is_hmma_c(u)){
ir::dot_inst* dot = (ir::dot_inst*)u; ir::dot_inst* dot = (ir::dot_inst*)u;
if((v == dot->get_operand(1)) && !dot->is_b_trans()) if((v == dot->get_operand(1)))
return true; return true;
} }
} }
@@ -67,7 +67,7 @@ bool is_hmma_b_row(ir::value* v) {
for(ir::user *u: v->get_users()) for(ir::user *u: v->get_users())
if(is_hmma_c(u)){ if(is_hmma_c(u)){
ir::dot_inst* dot = (ir::dot_inst*)u; ir::dot_inst* dot = (ir::dot_inst*)u;
if((v == dot->get_operand(1)) && dot->is_b_trans()) if((v == dot->get_operand(1)))
return true; return true;
} }
} }
@@ -170,6 +170,7 @@ void tiles::init_scanline_tile(ir::value *i) {
unsigned effective_num_threads = 1; unsigned effective_num_threads = 1;
for(size_t d = 0; d < shapes.size(); d++) for(size_t d = 0; d < shapes.size(); d++)
effective_num_threads *= mts_[axes_->get_id(i, d)]; effective_num_threads *= mts_[axes_->get_id(i, d)];
// std::cout << num_threads << " " << effective_num_threads << std::endl;
if(num_threads != effective_num_threads) if(num_threads != effective_num_threads)
throw std::runtime_error("cannot create a kernel with this amount of warps"); throw std::runtime_error("cannot create a kernel with this amount of warps");
} }
@@ -219,7 +220,7 @@ void tiles::run(ir::module &) {
largest_[i] = *std::max_element(values.begin(), values.end(), cmp); largest_[i] = *std::max_element(values.begin(), values.end(), cmp);
} }
// find out the order of a group // find out the layout ordering of a group
for(size_t i = 0; i < num_groups; i++){ for(size_t i = 0; i < num_groups; i++){
std::set<ir::io_inst*> io; std::set<ir::io_inst*> io;
for(ir::value* v: layout_->values(i)) for(ir::value* v: layout_->values(i))
@@ -239,11 +240,6 @@ void tiles::run(ir::module &) {
order_[i] = order; order_[i] = order;
} }
for(size_t i = 0; i < num_groups; i++){ for(size_t i = 0; i < num_groups; i++){
bool is_hmma_op = hmma_[i] == HMMA_A_COL || hmma_[i] == HMMA_A_ROW ||
hmma_[i] == HMMA_B_COL || hmma_[i] == HMMA_B_ROW;
if(!is_hmma_op)
continue;
// extract copies to shared memory
std::vector<ir::copy_to_shared_inst*> cts; std::vector<ir::copy_to_shared_inst*> cts;
for(ir::value* v: layout_->values(i)) for(ir::value* v: layout_->values(i))
if(auto *x = dynamic_cast<ir::copy_to_shared_inst*>(v)) if(auto *x = dynamic_cast<ir::copy_to_shared_inst*>(v))

View File

@@ -146,26 +146,30 @@ void shared_tile::extract_constant(const indices_t &arg_idx, indices_t &non_cst_
} }
Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& shapes, const std::vector<int>& order, indices_t idx) { Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& shapes, const std::vector<int>& perm, const std::vector<int>& order, indices_t idx) {
// strides
std::vector<Value*> strides(order.size());
strides[order[0]] = builder.getInt32(1);
for(size_t i = 1; i < idx.size(); i++)
strides[order[i]] = builder.CreateMul(strides[order[i-1]], builder.getInt32(shapes[order[i-1]]));
// result
Value *result = builder.getInt32(0); Value *result = builder.getInt32(0);
result = builder.CreateAdd(result, idx[order[0]]); for(size_t i = 0; i < strides.size(); i++)
Value *ld = builder.getInt32(shapes[order[0]]); result = builder.CreateAdd(result, builder.CreateMul(idx[perm[i]], strides[i]));
for(size_t i = 1; i < idx.size(); i++) {
result = builder.CreateAdd(result, builder.CreateMul(idx[order[i]], ld));
if(i < idx.size() - 1){
ld = builder.CreateMul(ld, builder.getInt32(shapes[order[i]]));
}
}
return result; return result;
} }
shared_tile::shared_tile(Type *ty, const shapes_t &shapes, const std::vector<int>& order, Value *ptr, llvm::IRBuilder<> &builder, Value *offset): shared_tile::shared_tile(Type *ty, const shapes_t &shapes, const std::vector<int>& order, Value *ptr, llvm::IRBuilder<> &builder, Value *offset, const std::vector<int>& perm):
tile(ty, shapes), order_(order), ptr_(ptr), builder_(builder), offset_(offset), vector_size_(1){ tile(ty, shapes), order_(order), ptr_(ptr), builder_(builder), offset_(offset), vector_size_(1), perm_(perm){
return_vector_ = false; return_vector_ = false;
if(perm_.empty()){
perm_.resize(shapes.size());
std::iota(perm_.begin(), perm_.end(), 0);
}
} }
void shared_tile::set_value(indices_t idx, Value *value) { void shared_tile::set_value(indices_t idx, Value *value) {
Value *ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, order_, idx)); Value *ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, perm_, order_, idx));
unsigned addr_space = ptr->getType()->getPointerAddressSpace(); unsigned addr_space = ptr->getType()->getPointerAddressSpace();
ptr = builder_.CreateBitCast(ptr, value->getType()->getPointerTo(addr_space)); ptr = builder_.CreateBitCast(ptr, value->getType()->getPointerTo(addr_space));
builder_.CreateStore(value, ptr); builder_.CreateStore(value, ptr);
@@ -196,7 +200,7 @@ Value* shared_tile::get_value(indices_t idx) {
// if(isa<Instruction>(non_cst_idx.front())){ // if(isa<Instruction>(non_cst_idx.front())){
// builder_.SetInsertPoint((Instruction*)non_cst_idx.front()); // builder_.SetInsertPoint((Instruction*)non_cst_idx.front());
// } // }
base_ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, order_, non_cst_idx)); base_ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, perm_, order_, non_cst_idx));
if(vector_size_ > 1){ if(vector_size_ > 1){
Type *vec_ty = VectorType::get(ty, vector_size); Type *vec_ty = VectorType::get(ty, vector_size);
Type *vec_ptr_ty = PointerType::get(vec_ty, base_ptr->getType()->getPointerAddressSpace()); Type *vec_ptr_ty = PointerType::get(vec_ty, base_ptr->getType()->getPointerAddressSpace());
@@ -204,7 +208,7 @@ Value* shared_tile::get_value(indices_t idx) {
} }
// builder_.SetInsertPoint(store); // builder_.SetInsertPoint(store);
} }
Value *offset = shared_offset(builder_, shapes_, order_, cst_idx); Value *offset = shared_offset(builder_, shapes_, perm_, order_, cst_idx);
Value *div = offset; Value *div = offset;
if(vector_size_ > 1) if(vector_size_ > 1)
div = builder_.CreateUDiv(offset, builder_.getInt32(vector_size_)); div = builder_.CreateUDiv(offset, builder_.getInt32(vector_size_));
@@ -725,7 +729,7 @@ void selection::create_shared_tile(ir::value *v, IRBuilder<> &builder, Value *sh
return; return;
auto order = tiles_->order(v); auto order = tiles_->order(v);
auto shapes = v->get_type()->get_tile_shapes(); auto shapes = v->get_type()->get_tile_shapes();
unsigned pad = liveness_->is_ld_padded(v); unsigned pad = liveness_->get_pad(v);
if(pad > 0) if(pad > 0)
shapes[order[0]] += pad; shapes[order[0]] += pad;
Type* ty = llvm_type(v->get_type()->get_scalar_ty(), builder.getContext()); Type* ty = llvm_type(v->get_type()->get_scalar_ty(), builder.getContext());
@@ -923,7 +927,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn,
write_idx.insert(write_idx.begin() + axis, lane); write_idx.insert(write_idx.begin() + axis, lane);
// shared memory write pointer // shared memory write pointer
Value *write_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), op_tile->get_order(), write_idx); Value *write_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), {0, 1}, op_tile->get_order(), write_idx);
Value *write_ptr = builder.CreateGEP(base_ptr, write_offset); Value *write_ptr = builder.CreateGEP(base_ptr, write_offset);
// initialize shared memory // initialize shared memory
@@ -936,7 +940,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn,
indices_t current(write_idx.size(), builder.getInt32(0)); indices_t current(write_idx.size(), builder.getInt32(0));
current[axis] = builder.getInt32(i); current[axis] = builder.getInt32(i);
// shared memory offset // shared memory offset
Value *read_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), op_tile->get_order(), current); Value *read_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), {0, 1}, op_tile->get_order(), current);
Value *is_active = builder.CreateICmpULT(lane, builder.getInt32(i)); Value *is_active = builder.CreateICmpULT(lane, builder.getInt32(i));
read_offset = builder.CreateSelect(is_active, read_offset, builder.getInt32(0)); read_offset = builder.CreateSelect(is_active, read_offset, builder.getInt32(0));
// shared memory read pointer // shared memory read pointer
@@ -952,7 +956,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn,
// result is on the first lane of shared memory // result is on the first lane of shared memory
indices_t final = write_idx; indices_t final = write_idx;
final[axis] = builder.getInt32(0); final[axis] = builder.getInt32(0);
Value *read_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), op_tile->get_order(), final); Value *read_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), {0, 1}, op_tile->get_order(), final);
Value *read_ptr = builder.CreateGEP(base_ptr, read_offset); Value *read_ptr = builder.CreateGEP(base_ptr, read_offset);
tgt_->add_barrier(module, builder); tgt_->add_barrier(module, builder);
result = builder.CreateLoad(read_ptr); result = builder.CreateLoad(read_ptr);
@@ -1041,11 +1045,7 @@ void selection::lower_copy_to_shared(ir::copy_to_shared_inst *x, LLVMContext &ct
void selection::lower_trans(ir::trans_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) { void selection::lower_trans(ir::trans_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
shared_tile* in = (shared_tile*)tmap_.at(x->get_operand(0)); shared_tile* in = (shared_tile*)tmap_.at(x->get_operand(0));
auto in_order = in->get_order(); shared_tile* out = new shared_tile(in->get_ty(), in->get_shapes(), in->get_order(), in->get_pointer(), builder, in->get_offset(), x->get_perm());
std::vector<int> order;
for(auto p: x->get_perm())
order.push_back(in_order[p->get_value()]);
shared_tile* out = new shared_tile(in->get_ty(), in->get_shapes(), order, in->get_pointer(), builder, in->get_offset());
tmap_[x] = out; tmap_[x] = out;
} }
@@ -1082,8 +1082,8 @@ void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn
auto ord_a = tiles_->order(dot->get_operand(0)); auto ord_a = tiles_->order(dot->get_operand(0));
auto ord_b = tiles_->order(dot->get_operand(1)); auto ord_b = tiles_->order(dot->get_operand(1));
bool is_a_row = dot->is_a_trans() ^ ord_a[ord_a.size() - 2] == 1; bool is_a_row = ord_a[ord_a.size() - 2] == 1;
bool is_b_row = dot->is_b_trans() ^ ord_b[ord_b.size() - 2] == 1; bool is_b_row = ord_b[ord_b.size() - 2] == 1;
if(is_a_row){ if(is_a_row){
offset_a_i = builder.CreateAdd(offset_a_i, builder.CreateURem(u_thread_id, builder.getInt32(4))); offset_a_i = builder.CreateAdd(offset_a_i, builder.CreateURem(u_thread_id, builder.getInt32(4)));
@@ -1125,10 +1125,6 @@ void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn
Value *current_offset_b_i = builder.CreateAdd(offset_b_j, builder.getInt32(pack_j*stride_rep_j*pack_size_1_)); Value *current_offset_b_i = builder.CreateAdd(offset_b_j, builder.getInt32(pack_j*stride_rep_j*pack_size_1_));
indices_t idx_a = {current_offset_a_i, builder.CreateAdd(offset_a_k, _K)}; indices_t idx_a = {current_offset_a_i, builder.CreateAdd(offset_a_k, _K)};
indices_t idx_b = {current_offset_b_i, builder.CreateAdd(offset_b_k, _K)}; indices_t idx_b = {current_offset_b_i, builder.CreateAdd(offset_b_k, _K)};
if(dot->is_a_trans())
std::swap(idx_a[0], idx_a[1]);
if(!dot->is_b_trans())
std::swap(idx_b[0], idx_b[1]);
idx_a.insert(idx_a.end(), x.first.begin(), x.first.end()); idx_a.insert(idx_a.end(), x.first.begin(), x.first.end());
idx_b.insert(idx_b.end(), x.first.begin(), x.first.end()); idx_b.insert(idx_b.end(), x.first.begin(), x.first.end());
Value *ha = TA->get_value(idx_a); Value *ha = TA->get_value(idx_a);
@@ -1188,10 +1184,6 @@ void selection::lower_scanline_dot(ir::dot_inst *dot, LLVMContext &ctx, Function
// input indices // input indices
indices_t a_idx = {idx[0], builder.getInt32(K)}; indices_t a_idx = {idx[0], builder.getInt32(K)};
indices_t b_idx = {builder.getInt32(K), idx[1]}; indices_t b_idx = {builder.getInt32(K), idx[1]};
if(dot->is_a_trans())
std::swap(a_idx[0], a_idx[1]);
if(dot->is_b_trans())
std::swap(b_idx[0], b_idx[1]);
// add batching dimension // add batching dimension
for(size_t i = 2; i < idx.size(); i++){ for(size_t i = 2; i < idx.size(); i++){
a_idx.insert(a_idx.end(), idx[i]); a_idx.insert(a_idx.end(), idx[i]);
@@ -1217,9 +1209,7 @@ void selection::lower_outer_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *f
Value *res = TD->get_value(idx); Value *res = TD->get_value(idx);
indices_t a_idx = {idx[0], builder.getInt32(0)}; indices_t a_idx = {idx[0], builder.getInt32(0)};
indices_t b_idx = {builder.getInt32(0), idx[1]}; indices_t b_idx = {builder.getInt32(0), idx[1]};
if(dot->is_a_trans())
std::swap(a_idx[0], a_idx[1]); std::swap(a_idx[0], a_idx[1]);
if(dot->is_b_trans())
std::swap(b_idx[0], b_idx[1]); std::swap(b_idx[0], b_idx[1]);
Value *a = TA->get_value(a_idx); Value *a = TA->get_value(a_idx);
Value *b = TB->get_value(b_idx); Value *b = TB->get_value(b_idx);
@@ -1243,7 +1233,7 @@ void selection::lower_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn, IRB
Type *c_ty = llvm_type(D->get_type()->get_scalar_ty(), ctx); Type *c_ty = llvm_type(D->get_type()->get_scalar_ty(), ctx);
Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {c_ty}); Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {c_ty});
auto A_shapes = A->get_type()->get_tile_shapes(); auto A_shapes = A->get_type()->get_tile_shapes();
size_t red_axis = dot->is_a_trans() ? 0 : 1; size_t red_axis = 1;
unsigned NK = A_shapes[red_axis]; unsigned NK = A_shapes[red_axis];
if(NK != 1) { if(NK != 1) {
@@ -1552,8 +1542,8 @@ void selection::run(ir::module &src, Module &dst) {
offset->addIncoming(next_offset, llvm_inc_block); offset->addIncoming(next_offset, llvm_inc_block);
} }
else { else {
unsigned num_bytes = phi->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; unsigned num_bytes = inst->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
offset->addIncoming(dst_builder.getInt32(liveness_->num_bytes(phi)/(num_bytes)), llvm_inc_block); offset->addIncoming(dst_builder.getInt32(liveness_->get_buffer(inst)->size / (2*num_bytes)), llvm_inc_block);
} }
ptr->addIncoming(inc_shared->get_pointer(), llvm_inc_block); ptr->addIncoming(inc_shared->get_pointer(), llvm_inc_block);
} }

View File

@@ -38,8 +38,8 @@ void membar::add_reference(ir::value *v, interval_vec_t &res){
return; return;
if(alloc_->has_offset(v)){ if(alloc_->has_offset(v)){
unsigned offset = alloc_->offset(v); unsigned offset = alloc_->offset(v);
unsigned num_bytes = liveness_->num_bytes(v); unsigned size = liveness_->get_buffer(v)->size;
res.push_back(interval_t(offset, offset + num_bytes)); res.push_back(interval_t(offset, offset + size));
} }
} }

View File

@@ -8,37 +8,8 @@ namespace codegen{
namespace transform{ namespace transform{
inline bool is_trans(ir::value *v){
auto *x = dynamic_cast<ir::trans_inst*>(v);
if(!x)
return false;
std::vector<ir::constant_int*> perm = x->get_perm();
std::vector<ir::constant_int*> ref;
ir::type *int32_ty = ir::type::get_int32_ty(v->get_type()->get_context());
for(size_t i = 0; i < perm.size(); i++)
ref.push_back(ir::constant_int::get(int32_ty, i));
std::swap(ref[0], ref[1]);
// true is perm == ref
return std::equal(perm.begin(), perm.end(), ref.begin());
}
inline bool is_hmma(ir::value *v){
bool result = false;
if(auto *x = dynamic_cast<ir::dot_inst*>(v)){
ir::value *a = x->get_operand(0);
ir::type *a_ty = a->get_type();
ir::value *b = x->get_operand(1);
ir::type *b_ty = b->get_type();
// inputs have to be FP16
result = a_ty->get_scalar_ty()->is_half_ty() && b_ty->get_scalar_ty()->is_half_ty();
// reduction has to be multiple of 4
// result = result && ((a_ty->get_tile_shapes()[1]->get_value() % 4) == 0);
}
return result;
}
ir::value* rewrite_trans_phi_impl(ir::value *value, ir::builder &builder, ir::value* rewrite_trans_phi_impl(ir::value *value, ir::builder &builder,
const std::vector<ir::constant_int*>& perm) { const std::vector<int>& perm) {
if(auto phi = dynamic_cast<ir::phi_node*>(value)) { if(auto phi = dynamic_cast<ir::phi_node*>(value)) {
// transpose operands // transpose operands
std::vector<ir::value*> incs; std::vector<ir::value*> incs;
@@ -106,9 +77,7 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
ir::value *a = dot->get_operand(0); ir::value *a = dot->get_operand(0);
ir::value *b = dot->get_operand(1); ir::value *b = dot->get_operand(1);
builder.set_insert_point(add); builder.set_insert_point(add);
ir::value * new_dot = builder.insert(ir::dot_inst::create(a, b, other, ir::value * new_dot = builder.insert(ir::dot_inst::create_nn(a, b, other, dot->get_name()));
dot->is_a_trans(), dot->is_b_trans(),
dot->get_name()));
add->replace_all_uses_with(new_dot); add->replace_all_uses_with(new_dot);
return true; return true;
} }

View File

@@ -241,7 +241,7 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { } cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { }
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){ cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
// std::cout << source << std::endl; std::cout << source << std::endl;
cu_context::context_switcher ctx(*context); cu_context::context_switcher ctx(*context);
// JIT compile source-code // JIT compile source-code
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER}; CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};

View File

@@ -322,7 +322,7 @@ value *builder::create_dot(value *A, value *B, value *C, const std::string &name
return insert(dot_inst::create_nn(A, B, C, name)); return insert(dot_inst::create_nn(A, B, C, name));
} }
value *builder::create_trans(value *A, const std::vector<ir::constant_int*>& perm, const std::string &name) { value *builder::create_trans(value *A, const std::vector<int>& perm, const std::string &name) {
return insert(trans_inst::create(A, perm, name)); return insert(trans_inst::create(A, perm, name));
} }

View File

@@ -536,7 +536,7 @@ instruction* downcast_inst::create(value *arg, const std::string &name, instruct
dot_inst::dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, dot_inst::dot_inst(value *A, value *B, value *C, TransT AT, TransT BT,
const std::string &name, instruction *next) const std::string &name, instruction *next)
: builtin_inst(C->get_type(), INST_DOT, 3, name, next), AT_(AT), BT_(BT) { : builtin_inst(C->get_type(), INST_DOT, 3, name, next) {
set_operand(0, A); set_operand(0, A);
set_operand(1, B); set_operand(1, B);
set_operand(2, C); set_operand(2, C);
@@ -574,31 +574,30 @@ instruction *dot_inst::create_tt(value *A, value *B, value *C,
// trans instructions // trans instructions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
ir::type* trans_inst::get_res_ty(ir::type* ty, std::vector<constant_int*> perm) { ir::type* trans_inst::get_res_ty(ir::type* ty, std::vector<int> perm) {
// get argument shapes // get argument shapes
ir::tile_type::tile_shapes_t arg_shapes = ty->get_tile_shapes(); ir::tile_type::tile_shapes_t arg_shapes = ty->get_tile_shapes();
// permutate argument shapes // permutate argument shapes
perm = init_perm(ty, perm); perm = init_perm(ty, perm);
ir::tile_type::tile_shapes_t res_shapes = arg_shapes; ir::tile_type::tile_shapes_t res_shapes = arg_shapes;
for(size_t i = 0; i < perm.size(); i++) for(size_t i = 0; i < perm.size(); i++)
res_shapes[i] = arg_shapes[perm[i]->get_value()]; res_shapes[i] = arg_shapes[perm[i]];
// construct type // construct type
return tile_type::get(ty->get_scalar_ty(), res_shapes); return tile_type::get(ty->get_scalar_ty(), res_shapes);
} }
std::vector<constant_int*> trans_inst::init_perm(ir::type* ty, const std::vector<constant_int*>& perm) { std::vector<int> trans_inst::init_perm(ir::type* ty, const std::vector<int>& perm) {
if(!perm.empty()) if(!perm.empty())
return perm; return perm;
auto size = ty->get_tile_shapes().size(); auto size = ty->get_tile_shapes().size();
ir::type* int32_ty = type::get_int32_ty(ty->get_context()); std::vector<int> result;
std::vector<constant_int*> result; result.push_back(size - 1);
result.push_back(ir::constant_int::get(int32_ty, size - 1));
for(size_t i = 0; i < size - 1; i++) for(size_t i = 0; i < size - 1; i++)
result.push_back(ir::constant_int::get(int32_ty, i)); result.push_back(i);
return result; return result;
} }
trans_inst::trans_inst(value *arg, const std::vector<constant_int*>& perm, const std::string &name, instruction *next) trans_inst::trans_inst(value *arg, const std::vector<int> &perm, const std::string &name, instruction *next)
: builtin_inst(get_res_ty(arg->get_type(), perm), INST_TRANS, 1, name, next) { : builtin_inst(get_res_ty(arg->get_type(), perm), INST_TRANS, 1, name, next) {
// sanity check // sanity check
perm_ = init_perm(arg->get_type(), perm); perm_ = init_perm(arg->get_type(), perm);
@@ -607,11 +606,11 @@ trans_inst::trans_inst(value *arg, const std::vector<constant_int*>& perm, const
set_operand(0, arg); set_operand(0, arg);
} }
instruction* trans_inst::create(value *arg, const std::vector<constant_int *> &perm, const std::string &name, instruction *next) { instruction* trans_inst::create(value *arg, const std::vector<int> &perm, const std::string &name, instruction *next) {
return new trans_inst(arg, perm, name, next); return new trans_inst(arg, perm, name, next);
} }
const std::vector<constant_int*> trans_inst::get_perm() const { const std::vector<int> trans_inst::get_perm() const {
return perm_; return perm_;
} }

View File

@@ -229,6 +229,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
reassociate.run(module); reassociate.run(module);
dce.run(module); dce.run(module);
cts.run(module); cts.run(module);
// ir::print(module, std::cout);
liveness.run(module); liveness.run(module);
allocation.run(module); allocation.run(module);
if(allocation.allocated_size() > context->device()->max_shared_memory()) if(allocation.allocated_size() > context->device()->max_shared_memory())

View File

@@ -27,9 +27,9 @@ inline rt::function::grid_fn_ty grid2d(size_t M, size_t N) {
std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){ std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
typedef half_float::half NumericT; typedef float NumericT;
std::string ty = "half"; std::string ty = "float";
cublasDataType_t cuty = CUDA_R_16F; cublasDataType_t cuty = CUDA_R_32F;
size_t dt_nbytes = sizeof(NumericT); size_t dt_nbytes = sizeof(NumericT);
drv::context* context = stream->context(); drv::context* context = stream->context();
// leading dimensions // leading dimensions
@@ -45,9 +45,9 @@ std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, i
opt.defines.push_back({"TYPE", {ty}}); opt.defines.push_back({"TYPE", {ty}});
opt.defines.push_back({"AT", {AT?"1":"0"}}); opt.defines.push_back({"AT", {AT?"1":"0"}});
opt.defines.push_back({"BT", {BT?"1":"0"}}); opt.defines.push_back({"BT", {BT?"1":"0"}});
opt.defines.push_back({"TM", {"128"}}); opt.defines.push_back({"TM", {"64"}});
opt.defines.push_back({"TN", {"128"}}); opt.defines.push_back({"TN", {"64"}});
opt.defines.push_back({"TK", {"16"}}); opt.defines.push_back({"TK", {"8"}});
opt.num_warps = {4}; opt.num_warps = {4};
// create function // create function
rt::function function(src::dot, opt); rt::function function(src::dot, opt);
@@ -79,10 +79,9 @@ int main() {
// shapes to benchmark // shapes to benchmark
typedef std::tuple<bool, bool, int, int, int> config_t; typedef std::tuple<bool, bool, int, int, int> config_t;
std::vector<config_t> configs; std::vector<config_t> configs;
for(auto x: std::vector<std::array<bool, 2>>{{false, true}, for(auto x: std::vector<std::array<bool, 2>>{{false, false}}){
{true, false}, {true, true}}){
std::vector<config_t> tmp = { std::vector<config_t> tmp = {
config_t{x[0], x[1], 4096, 4096, 4096} config_t{x[0], x[1], 2048, 2048, 2048}
// config_t{x[0], x[1], 16, 2048, 2048}, // config_t{x[0], x[1], 16, 2048, 2048},
// config_t{x[0], x[1], 32, 2048, 2048}, // config_t{x[0], x[1], 32, 2048, 2048},
// config_t{x[0], x[1], 64, 2048, 2048}, // config_t{x[0], x[1], 64, 2048, 2048},

View File

@@ -54,12 +54,12 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
TYPE a[SHAPE_A] = *pa; TYPE a[SHAPE_A] = *pa;
TYPE b[SHAPE_B] = *pb; TYPE b[SHAPE_B] = *pb;
// reduction loop // reduction loop
for(int k = K; k > 0; k-= TK){ for(int k = K; k > TK; k-= TK){
c += USEA @ USEB; c += USEA @ USEB;
pa = pa + TK * STRIDE_AK; pa = pa + TK * STRIDE_AK;
pb = pb + TK * STRIDE_BK; pb = pb + TK * STRIDE_BK;
a = ((bool[SHAPE_A])(k > TK)) ? *pa : 0; a = *pa;
b = ((bool[SHAPE_B])(k > TK)) ? *pb : 0; b = *pb;
} }
// epilogue // epilogue
int rxc[TM] = ridx * TM + 0 ... TM; int rxc[TM] = ridx * TM + 0 ... TM;

View File

@@ -139,13 +139,13 @@ int main() {
// shapes to benchmark // shapes to benchmark
typedef std::tuple<dtype_t, bool, bool, int, int, int, int, int, int, int> config_t; typedef std::tuple<dtype_t, bool, bool, int, int, int, int, int, int, int> config_t;
std::vector<config_t> configs; std::vector<config_t> configs;
for(bool AT: std::array<bool, 2>{false})
for(bool BT: std::array<bool, 2>{false})
for(int TM: std::vector<int>{32, 64}) for(int TM: std::vector<int>{32, 64})
for(int TN: std::vector<int>{32, 64}) for(int TN: std::vector<int>{32, 64})
for(int TK: std::vector<int>{16, 32}) for(int TK: std::vector<int>{8})
for(int nwarps: std::vector<int>{1, 2, 4, 8}){ for(int nwarps: std::vector<int>{1, 2, 4, 8})
configs.push_back(config_t{HALF, AT, BT, 128, 128, 128, TM, TN, TK, nwarps}); for(bool AT: std::array<bool, 2>{false, true})
for(bool BT: std::array<bool, 2>{false, true}){
configs.push_back(config_t{FLOAT, AT, BT, 128, 128, 128, TM, TN, TK, nwarps});
} }
// does the work // does the work
dtype_t dtype; dtype_t dtype;