[codegen] more cleaning

This commit is contained in:
Philippe Tillet
2019-10-07 18:06:54 -04:00
parent 1783d45bef
commit 650c43ca07
16 changed files with 111 additions and 191 deletions

View File

@@ -5,6 +5,7 @@
#include <set> #include <set>
#include <vector> #include <vector>
#include <memory> #include <memory>
#include "triton/tools/graph.h"
namespace triton{ namespace triton{
@@ -19,10 +20,8 @@ namespace analysis{
class axes { class axes {
typedef std::pair<ir::value*, unsigned> node_t; typedef std::pair<ir::value*, unsigned> node_t;
typedef std::map <node_t, std::set<node_t>> graph_t;
private: private:
void add_constraint(node_t x, node_t y);
// update graph // update graph
void update_graph_store(ir::instruction *i); void update_graph_store(ir::instruction *i);
void update_graph_reduce(ir::instruction *i); void update_graph_reduce(ir::instruction *i);
@@ -32,21 +31,15 @@ private:
void update_graph_dot(ir::instruction *i); void update_graph_dot(ir::instruction *i);
void update_graph_elementwise(ir::instruction *i); void update_graph_elementwise(ir::instruction *i);
void update_graph(ir::instruction *i); void update_graph(ir::instruction *i);
// connected components
void connected_components(node_t x, std::set<node_t> &nodes, graph_t &graph, unsigned group_id);
public: public:
axes(); axes();
void run(ir::module &mod); void run(ir::module &mod);
unsigned get_id(ir::value *value, unsigned ax); unsigned get_id(ir::value *value, unsigned dim);
bool has_id(ir::value *value, unsigned ax);
private: private:
// constraints graph tools::graph<node_t> graph_;
graph_t dependencies_; std::map<node_t, size_t> axes_;
std::set<node_t> nodes_;
// parameter groups
std::map<ir::value*, std::map<unsigned, unsigned>> groups_;
}; };
} }

View File

@@ -5,6 +5,7 @@
#include <set> #include <set>
#include <vector> #include <vector>
#include <memory> #include <memory>
#include "triton/tools/graph.h"
namespace triton{ namespace triton{
@@ -27,29 +28,24 @@ private:
// graph creation // graph creation
void connect(ir::value *x, ir::value *y); void connect(ir::value *x, ir::value *y);
void make_graph(ir::instruction *i); void make_graph(ir::instruction *i);
// connected components
void connected_components(node_t x, std::set<node_t> &nodes, graph_t &graph, unsigned id);
// list the axes of the given value // list the axes of the given value
std::set<int> axes_of(ir::value *value); std::set<int> axes_of(ir::value *value);
public: public:
// constructor // constructor
layout(analysis::axes *axes); layout(analysis::axes *axes);
// run the passes // accessors
unsigned layout_of(ir::value *value) const;
const std::vector<ir::value*>& values_of(unsigned id) const;
size_t num_layouts() const;
// execution
void run(ir::module &mod); void run(ir::module &mod);
// get the layout ID of the given value
unsigned id(ir::value *value) const;
// get the values associates with the given ID
const std::vector<ir::value*>& values(unsigned id) const;
// get number of groups
size_t get_num_groups() const;
private: private:
analysis::axes* axes_; analysis::axes* axes_;
graph_t dependencies_; tools::graph<ir::value*> graph_;
std::set<node_t> nodes_; std::map<ir::value*, size_t> groups_;
std::map<ir::value*, unsigned> groups_; std::map<size_t, std::vector<ir::value*>> values_;
std::map<unsigned, std::vector<ir::value*>> values_;
}; };
} }

View File

@@ -4,6 +4,7 @@
#include <map> #include <map>
#include <set> #include <set>
#include <vector> #include <vector>
#include "triton/tools/graph.h"
namespace triton{ namespace triton{
@@ -41,7 +42,7 @@ struct double_buffer_info_t {
}; };
struct buffer_t { struct buffer_t {
unsigned id; size_t id;
size_t size; size_t size;
bool operator<(buffer_t other) const { return id < other.id; } bool operator<(buffer_t other) const { return id < other.id; }
}; };
@@ -63,7 +64,6 @@ public:
private: private:
void connected_components(node_t x, std::set<node_t> &nodes, graph_t &graph, buffer_t *buffer);
void extract_double_bufferable(ir::instruction *i); void extract_double_bufferable(ir::instruction *i);
void extract_buffers(ir::instruction *i); void extract_buffers(ir::instruction *i);
void get_parents(ir::instruction *i, std::vector<ir::value *>& res); void get_parents(ir::instruction *i, std::vector<ir::value *>& res);
@@ -98,11 +98,8 @@ private:
intervals_map_t intervals_; intervals_map_t intervals_;
std::map<ir::value*, double_buffer_info_t> double_; std::map<ir::value*, double_buffer_info_t> double_;
std::map<ir::value*, size_t> pad_; std::map<ir::value*, size_t> pad_;
std::map<ir::value*, std::vector<ir::value*>> parents_; // buffers
// graph tools::graph<node_t> graph_;
std::set<node_t> nodes_;
graph_t graph_;
std::vector<buffer_t*> buffers_;
std::map<ir::value*, buffer_t*> groups_; std::map<ir::value*, buffer_t*> groups_;
std::map<buffer_t*, std::vector<ir::value*>> values_; std::map<buffer_t*, std::vector<ir::value*>> values_;
}; };

View File

@@ -211,10 +211,10 @@ private:
public: public:
selection(analysis::liveness* liveness, analysis::allocation *alloc, analysis::tiles *tiles, selection(analysis::liveness* liveness, analysis::allocation *alloc, analysis::tiles *tiles,
analysis::align *alignment, analysis::axes *axes, analysis::align *alignment, analysis::axes *axes,
analysis::layout *layouts, transform::coalesce* reorder, target *tgt, unsigned num_warps) analysis::layout *layouts, target *tgt, unsigned num_warps)
: liveness_(liveness), alloc_(alloc), tiles_(tiles), : liveness_(liveness), alloc_(alloc), tiles_(tiles),
alignment_(alignment), a_axes_(axes), layouts_(layouts), alignment_(alignment), a_axes_(axes), layouts_(layouts),
reorder_(reorder), tgt_(tgt), num_warps_(num_warps){ } tgt_(tgt), num_warps_(num_warps){ }
void run(ir::module &src, Module &dst); void run(ir::module &src, Module &dst);
@@ -227,7 +227,6 @@ private:
analysis::axes *a_axes_; analysis::axes *a_axes_;
analysis::layout *layouts_; analysis::layout *layouts_;
analysis::align *alignment_; analysis::align *alignment_;
transform::coalesce *reorder_;
target *tgt_; target *tgt_;
std::map<unsigned, distributed_axis> axes_; std::map<unsigned, distributed_axis> axes_;
Value *sh_mem_ptr_; Value *sh_mem_ptr_;

View File

@@ -38,7 +38,7 @@ inline double bench(std::function<void()> const & op, driver::stream * stream)
double total_time = 0; double total_time = 0;
op(); op();
stream->synchronize(); stream->synchronize();
while(total_time*1e-9 < 1e-3){ while(total_time*1e-9 < 1e-2){
float norm = 1; float norm = 1;
// normalize clock if possible to reduce noise in auto-tuning // normalize clock if possible to reduce noise in auto-tuning
if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(stream->context()->device())) if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(stream->context()->device()))

View File

@@ -270,9 +270,9 @@ std::vector<unsigned> align::populate_max_contiguous_binop(ir::binary_operator*
} }
if(x->is_int_add_sub()){ if(x->is_int_add_sub()){
unsigned lvalue = 1, rvalue = 1; unsigned lvalue = 1, rvalue = 1;
if(lhs_cst_info[d].num_cst) if(lhs_cst_info[d].num_cst > 0)
lvalue = gcd(rhs_max_contiguous[d], lhs_cst_info[d].num_cst); lvalue = gcd(rhs_max_contiguous[d], lhs_cst_info[d].num_cst);
if(rhs_cst_info[d].num_cst) if(rhs_cst_info[d].num_cst > 0)
rvalue = gcd(lhs_max_contiguous[d], rhs_cst_info[d].num_cst); rvalue = gcd(lhs_max_contiguous[d], rhs_cst_info[d].num_cst);
value = std::max(lvalue, rvalue); value = std::max(lvalue, rvalue);
} }

View File

@@ -16,22 +16,6 @@ namespace analysis{
axes::axes() {} axes::axes() {}
void axes::add_constraint(node_t x, node_t y) {
size_t shape_x = 1;
size_t shape_y = 1;
if(x.first->get_type()->is_tile_ty())
shape_x = x.first->get_type()->get_tile_shapes()[x.second];
if(y.first->get_type()->is_tile_ty())
shape_y = y.first->get_type()->get_tile_shapes()[y.second];
if(shape_x == 1 && shape_y == 1)
return;
dependencies_[x].insert(y);
dependencies_[y].insert(x);
nodes_.insert(x);
nodes_.insert(y);
}
void axes::update_graph_reduce(ir::instruction *i) { void axes::update_graph_reduce(ir::instruction *i) {
auto* red = static_cast<ir::reduce_inst*>(i); auto* red = static_cast<ir::reduce_inst*>(i);
unsigned axis = red->get_axis(); unsigned axis = red->get_axis();
@@ -41,7 +25,7 @@ void axes::update_graph_reduce(ir::instruction *i) {
for(unsigned d = 0; d < in_shapes.size(); d++){ for(unsigned d = 0; d < in_shapes.size(); d++){
if(d == axis) if(d == axis)
continue; continue;
add_constraint({i, current++}, {arg, d}); graph_.add_edge({i, current++}, {arg, d});
} }
} }
@@ -59,9 +43,9 @@ void axes::update_graph_reshape(ir::instruction *i) {
bool same_shape = res_shapes[d] == op_shapes[current]; bool same_shape = res_shapes[d] == op_shapes[current];
// either add edge between axis or just add a node in the graph // either add edge between axis or just add a node in the graph
if(!is_skewed && same_shape) if(!is_skewed && same_shape)
add_constraint({i, d}, {op, current++}); graph_.add_edge({i, d}, {op, current++});
else else
add_constraint({i, d}, {i, d}); graph_.add_edge({i, d}, {i, d});
// reshaping is skewed // reshaping is skewed
if(res_shapes[d] > 1 && !same_shape) if(res_shapes[d] > 1 && !same_shape)
is_skewed = true; is_skewed = true;
@@ -74,7 +58,7 @@ void axes::update_graph_trans(ir::instruction *i) {
auto perm = trans->get_perm(); auto perm = trans->get_perm();
// add edge between axis perm[d] and axis d // add edge between axis perm[d] and axis d
for(unsigned d = 0; d < perm.size(); d++) for(unsigned d = 0; d < perm.size(); d++)
add_constraint({i, perm[d]}, {op, d}); graph_.add_edge({i, perm[d]}, {op, d});
} }
void axes::update_graph_broadcast(ir::instruction *i) { void axes::update_graph_broadcast(ir::instruction *i) {
@@ -86,7 +70,7 @@ void axes::update_graph_broadcast(ir::instruction *i) {
// add edge between non-broadcast axes // add edge between non-broadcast axes
for(unsigned d = 0; d < shapes.size(); d ++) for(unsigned d = 0; d < shapes.size(); d ++)
if(op_shapes[d] == shapes[d]) if(op_shapes[d] == shapes[d])
add_constraint({i, d}, {op, d}); graph_.add_edge({i, d}, {op, d});
} }
void axes::update_graph_dot(ir::instruction *i) { void axes::update_graph_dot(ir::instruction *i) {
@@ -97,11 +81,11 @@ void axes::update_graph_dot(ir::instruction *i) {
ir::value *D = dot->get_operand(2); ir::value *D = dot->get_operand(2);
// add edges between result and accumulator // add edges between result and accumulator
for(unsigned d = 0; d < shapes.size(); d++) for(unsigned d = 0; d < shapes.size(); d++)
add_constraint({dot, d}, {D, d}); graph_.add_edge({dot, d}, {D, d});
// add edge for batch dimension // add edge for batch dimension
for(unsigned d = 2; d < shapes.size(); d++){ for(unsigned d = 2; d < shapes.size(); d++){
add_constraint({dot, d}, {A, d}); graph_.add_edge({dot, d}, {A, d});
add_constraint({dot, d}, {B, d}); graph_.add_edge({dot, d}, {B, d});
} }
} }
@@ -116,8 +100,8 @@ void axes::update_graph_elementwise(ir::instruction *i) {
for(ir::value* opx: i->ops()) for(ir::value* opx: i->ops())
for(ir::value* opy: i->ops()){ for(ir::value* opy: i->ops()){
if(!i->get_type()->is_void_ty()) if(!i->get_type()->is_void_ty())
add_constraint({i, d}, {opx, d}); graph_.add_edge({i, d}, {opx, d});
add_constraint({opx, d}, {opy, d}); graph_.add_edge({opx, d}, {opy, d});
} }
} }
@@ -136,41 +120,19 @@ void axes::update_graph(ir::instruction *i) {
return; return;
} }
void axes::connected_components(node_t x, std::set<node_t> &nodes, graph_t &graph, unsigned group_id) {
groups_[x.first].insert({x.second, group_id});
if(nodes.find(x) != nodes.end()){
nodes.erase(x);
for(const node_t &y: graph[x])
connected_components(y, nodes, graph, group_id);
}
}
unsigned axes::get_id(ir::value *value, unsigned ax) { unsigned axes::get_id(ir::value *value, unsigned dim) {
unsigned result = groups_.at(value).at(ax); return axes_.at({value, dim});
return result;
} }
bool axes::has_id(ir::value *value, unsigned ax) {
auto it = groups_.find(value);
if(it == groups_.end())
return false;
auto iit = it->second.find(ax);
if(iit == it->second.end())
return false;
return true;
}
void axes::run(ir::module &mod) { void axes::run(ir::module &mod) {
nodes_.clear();
dependencies_.clear();
groups_.clear();
// make graph // make graph
ir::for_each_instruction(mod, [this](ir::instruction *x) { update_graph(x); }); graph_.clear();
// connected components ir::for_each_instruction(mod, [this](ir::instruction *x) {
unsigned group_id = 0; update_graph(x);
while(!nodes_.empty()) });
connected_components(*nodes_.begin(), nodes_, dependencies_, group_id++); // find connected components
graph_.connected_components(nullptr, &axes_);
} }
} }

View File

@@ -21,36 +21,24 @@ std::set<int> layout::axes_of(ir::value *value) {
// create result // create result
std::set<int> result; std::set<int> result;
for(size_t d = 0; d < rank; d++) for(size_t d = 0; d < rank; d++)
if(axes_->has_id(value, d)) result.insert(axes_->get_id(value, d));
result.insert(axes_->get_id(value, d));
return result; return result;
} }
// connected components
void layout::connected_components(node_t x, std::set<node_t> &nodes, graph_t &graph, unsigned group_id) {
groups_[x] = group_id;
values_[group_id].push_back(x);
if(nodes.find(x) != nodes.end()){
nodes.erase(x);
for(const node_t &y: graph[x])
connected_components(y, nodes, graph, group_id);
}
}
// constructor // constructor
layout::layout(analysis::axes *axes) layout::layout(analysis::axes *axes)
: axes_(axes) { } : axes_(axes) { }
// get group id // get group id
unsigned layout::id(ir::value *value) const unsigned layout::layout_of(ir::value *value) const
{ return groups_.at(value); } { return groups_.at(value); }
// get values // get values
const std::vector<ir::value*>& layout::values(unsigned id) const const std::vector<ir::value*>& layout::values_of(unsigned id) const
{ return values_.at(id); } { return values_.at(id); }
// get number of groups // get number of groups
size_t layout::get_num_groups() const size_t layout::num_layouts() const
{ return values_.size(); } { return values_.size(); }
// connect two values // connect two values
@@ -67,12 +55,8 @@ void layout::connect(ir::value *x, ir::value *y) {
std::set_intersection(x_axes.begin(), x_axes.end(), std::set_intersection(x_axes.begin(), x_axes.end(),
y_axes.begin(), y_axes.end(), y_axes.begin(), y_axes.end(),
std::inserter(common, common.begin())); std::inserter(common, common.begin()));
if(!common.empty()){ if(!common.empty())
nodes_.insert(x); graph_.add_edge(x, y);
nodes_.insert(y);
dependencies_[x].insert(y);
dependencies_[y].insert(x);
}
} }
// make graph // make graph
@@ -84,19 +68,16 @@ void layout::make_graph(ir::instruction *i) {
} }
} }
// run
void layout::run(ir::module &mod) { void layout::run(ir::module &mod) {
nodes_.clear();
dependencies_.clear();
groups_.clear();
values_.clear();
// make graph // make graph
ir::for_each_instruction(mod, [this](ir::instruction* i) { make_graph(i); }); graph_.clear();
ir::for_each_instruction(mod, [this](ir::instruction* i) {
make_graph(i);
});
// connected components // connected components
unsigned group_id = 0; values_.clear();
while(!nodes_.empty()){ groups_.clear();
connected_components(*nodes_.begin(), nodes_, dependencies_, group_id++); graph_.connected_components(&values_, &groups_);
}
} }
} }

View File

@@ -53,42 +53,20 @@ void liveness::extract_double_bufferable(ir::instruction *i) {
void liveness::make_graph(ir::instruction *i) { void liveness::make_graph(ir::instruction *i) {
if(has_double(i)){ if(has_double(i)){
ir::value *latch = double_[i].latch; ir::value *latch = double_[i].latch;
nodes_.insert(i); graph_.add_edge(i, latch);
nodes_.insert(latch);
graph_[i].insert(latch);
graph_[latch].insert(i);
} }
if(i->get_id() == ir::INST_PHI){ if(storage_info.at(i->get_id()).first == SHARED){
ir::phi_node* phi = (ir::phi_node*)i; graph_.add_edge(i, i);
for(ir::value* op: phi->ops()){ for(ir::value* op: i->ops()){
auto* iop = dynamic_cast<ir::instruction*>(op); auto* iop = dynamic_cast<ir::instruction*>(op);
if(!iop || storage_info.at(iop->get_id()).first != SHARED) if(!iop || storage_info.at(iop->get_id()).first != SHARED)
continue; continue;
nodes_.insert(phi); graph_.add_edge(i, op);
nodes_.insert(op);
graph_[phi].insert(op);
graph_[op].insert(phi);
} }
} }
if(i->get_id() == ir::INST_TRANS){
nodes_.insert(i);
nodes_.insert(i->get_operand(0));
graph_[i].insert(i->get_operand(0));
graph_[i->get_operand(0)].insert(i);
}
} }
// connected components // connected components
void liveness::connected_components(node_t x, std::set<node_t> &nodes, graph_t &graph, buffer_t* buffer) {
groups_[x] = buffer;
values_[buffer].push_back(x);
if(nodes.find(x) != nodes.end()){
nodes.erase(x);
for(const node_t &y: graph[x])
connected_components(y, nodes, graph, buffer);
}
}
bool is_trans(ir::value *v) { bool is_trans(ir::value *v) {
if(dynamic_cast<ir::trans_inst *>(v)) { if(dynamic_cast<ir::trans_inst *>(v)) {
return true; return true;
@@ -121,12 +99,14 @@ bool liveness::do_pad(ir::value *x) {
pad_[b] = std::max<int>(pad_[b], (24 - b_shapes[b_row ? 1 : 0]) % 32); pad_[b] = std::max<int>(pad_[b], (24 - b_shapes[b_row ? 1 : 0]) % 32);
return a_previous != pad_[a] || b_previous != pad_[b]; return a_previous != pad_[a] || b_previous != pad_[b];
} }
// padding for trans
if(auto* trans = dynamic_cast<ir::trans_inst*>(x)) { if(auto* trans = dynamic_cast<ir::trans_inst*>(x)) {
ir::value *op = trans->get_operand(0); ir::value *op = trans->get_operand(0);
size_t previous = pad_[op]; size_t previous = pad_[op];
pad_[op] = std::max(pad_[op], pad_[x]); pad_[op] = std::max(pad_[op], pad_[x]);
return previous != pad_[op]; return previous != pad_[op];
} }
// padding for copy to shared
if(auto* cts = dynamic_cast<ir::copy_to_shared_inst*>(x)) { if(auto* cts = dynamic_cast<ir::copy_to_shared_inst*>(x)) {
auto cts_order = tiles_->order(cts); auto cts_order = tiles_->order(cts);
ir::value *arg = cts->get_operand(0); ir::value *arg = cts->get_operand(0);
@@ -187,7 +167,7 @@ void liveness::run(ir::module &mod) {
indices.clear(); indices.clear();
pad_.clear(); pad_.clear();
intervals_.clear(); intervals_.clear();
parents_.clear(); graph_.clear();
// Create set of pair of values that can be double-buffered // Create set of pair of values that can be double-buffered
ir::for_each_instruction(mod, [this](ir::instruction* i) { ir::for_each_instruction(mod, [this](ir::instruction* i) {
@@ -209,12 +189,16 @@ void liveness::run(ir::module &mod) {
}); });
// connected components // connected components
unsigned group_id = 0; tools::graph<node_t>::cmap_t cmap;
while(!nodes_.empty()){ tools::graph<node_t>::nmap_t nmap;
buffer_t* buffer = new buffer_t{group_id++}; graph_.connected_components(&cmap, &nmap);
connected_components(*nodes_.begin(), nodes_, graph_, buffer); for(auto x: cmap) {
for(ir::value *v: values_.at(buffer)) buffer_t* buffer = new buffer_t{x.first};
values_[buffer] = x.second;
for(ir::value *v: x.second){
buffer->size = std::max<int>(buffer->size, num_bytes(v)); buffer->size = std::max<int>(buffer->size, num_bytes(v));
groups_[v] = buffer;
}
} }
// Assigns index to each instruction // Assigns index to each instruction
@@ -245,6 +229,8 @@ void liveness::run(ir::module &mod) {
intervals_[x.first] = segment{start, end}; intervals_[x.first] = segment{start, end};
} }
} }
} }

View File

@@ -74,7 +74,7 @@ bool is_hmma_b_row(ir::value* v) {
layout_t tiles::hmma(ir::value *value) { layout_t tiles::hmma(ir::value *value) {
return hmma_.at(layout_->id(value)); return hmma_.at(layout_->layout_of(value));
} }
int tiles::mts(ir::value *value, unsigned ax) { int tiles::mts(ir::value *value, unsigned ax) {
@@ -94,7 +94,7 @@ int tiles::wpt(ir::value *value, unsigned ax) {
} }
std::vector<int> tiles::order(ir::value *v) { std::vector<int> tiles::order(ir::value *v) {
auto ret = order_[layout_->id(v)]; auto ret = order_[layout_->layout_of(v)];
return ret; return ret;
} }
@@ -201,7 +201,9 @@ bool tiles::is_trans(ir::value *v) {
void tiles::run(ir::module &) { void tiles::run(ir::module &) {
hmma_.clear(); hmma_.clear();
largest_.clear(); largest_.clear();
size_t num_groups = layout_->get_num_groups(); order_.clear();
size_t num_groups = layout_->num_layouts();
// helpers // helpers
auto rank = [](ir::value* v) { auto rank = [](ir::value* v) {
ir::type *ty = v->get_type(); ir::type *ty = v->get_type();
@@ -213,7 +215,7 @@ void tiles::run(ir::module &) {
}; };
// find out which groups require hmma layout // find out which groups require hmma layout
for(size_t i = 0; i < num_groups; i++) { for(size_t i = 0; i < num_groups; i++) {
const auto& values = layout_->values(i); const auto& values = layout_->values_of(i);
bool hmma_c = std::any_of(values.begin(), values.end(), &is_hmma_c); bool hmma_c = std::any_of(values.begin(), values.end(), &is_hmma_c);
if(hmma_c) hmma_[i] = HMMA_C; if(hmma_c) hmma_[i] = HMMA_C;
else hmma_[i] = SCANLINE; else hmma_[i] = SCANLINE;
@@ -221,7 +223,7 @@ void tiles::run(ir::module &) {
} }
// find out which value is the largest in each group // find out which value is the largest in each group
for(size_t i = 0; i < num_groups; i++) { for(size_t i = 0; i < num_groups; i++) {
const auto& values = layout_->values(i); const auto& values = layout_->values_of(i);
auto cmp = [&rank](ir::value* x, ir::value *y) { return rank(x) < rank(y); }; auto cmp = [&rank](ir::value* x, ir::value *y) { return rank(x) < rank(y); };
largest_[i] = *std::max_element(values.begin(), values.end(), cmp); largest_[i] = *std::max_element(values.begin(), values.end(), cmp);
} }
@@ -230,7 +232,7 @@ void tiles::run(ir::module &) {
// find out the layout ordering of a group // find out the layout ordering of a group
for(size_t i = 0; i < num_groups; i++){ for(size_t i = 0; i < num_groups; i++){
std::set<ir::io_inst*> io; std::set<ir::io_inst*> io;
for(ir::value* v: layout_->values(i)) for(ir::value* v: layout_->values_of(i))
extract_io_use(v, io); extract_io_use(v, io);
auto cmp = [&rank](ir::io_inst* x, ir::io_inst *y) { auto cmp = [&rank](ir::io_inst* x, ir::io_inst *y) {
return rank(x->get_pointer_operand()) < rank(y->get_pointer_operand()); return rank(x->get_pointer_operand()) < rank(y->get_pointer_operand());
@@ -249,27 +251,27 @@ void tiles::run(ir::module &) {
// matrix multiplication optimizations // matrix multiplication optimizations
for(size_t i = 0; i < num_groups; i++){ for(size_t i = 0; i < num_groups; i++){
std::vector<ir::dot_inst*> dots; std::vector<ir::dot_inst*> dots;
for(ir::value* v: layout_->values(i)) for(ir::value* v: layout_->values_of(i))
if(auto *x = dynamic_cast<ir::dot_inst*>(v)) if(auto *x = dynamic_cast<ir::dot_inst*>(v))
dots.push_back(x); dots.push_back(x);
for(ir::dot_inst* dot: dots){ for(ir::dot_inst* dot: dots){
ir::value* a = dot->get_operand(0); ir::value* a = dot->get_operand(0);
ir::value* b = dot->get_operand(1); ir::value* b = dot->get_operand(1);
if(hmma_.at(layout_->id(dot)) == HMMA_C){ if(hmma_.at(layout_->layout_of(dot)) == HMMA_C){
auto a_val = layout_->values(layout_->id(a)); auto a_val = layout_->values_of(layout_->layout_of(a));
auto b_val = layout_->values(layout_->id(b)); auto b_val = layout_->values_of(layout_->layout_of(b));
for(ir::value *v: a_val) for(ir::value *v: a_val)
if(auto *cts = dynamic_cast<ir::copy_to_shared_inst*>(v)) if(auto *cts = dynamic_cast<ir::copy_to_shared_inst*>(v))
order_[layout_->id(a)] = order_[layout_->id(cts->get_operand(0))]; order_[layout_->layout_of(a)] = order_[layout_->layout_of(cts->get_operand(0))];
for(ir::value *v: b_val) for(ir::value *v: b_val)
if(auto *cts = dynamic_cast<ir::copy_to_shared_inst*>(v)) if(auto *cts = dynamic_cast<ir::copy_to_shared_inst*>(v))
order_[layout_->id(b)] = order_[layout_->id(cts->get_operand(0))]; order_[layout_->layout_of(b)] = order_[layout_->layout_of(cts->get_operand(0))];
} }
else{ else{
std::vector<int> col = {0, 1}; std::vector<int> col = {0, 1};
std::vector<int> row = {1, 0}; std::vector<int> row = {1, 0};
order_[layout_->id(a)] = is_trans(a) ? row : col; order_[layout_->layout_of(a)] = is_trans(a) ? row : col;
order_[layout_->id(b)] = is_trans(b) ? col : row; order_[layout_->layout_of(b)] = is_trans(b) ? col : row;
} }
} }
} }

View File

@@ -67,10 +67,10 @@ ir::value* coalesce::rematerialize(ir::value *x, ir::builder &builder,
void coalesce::run(ir::module &mod) { void coalesce::run(ir::module &mod) {
// find values to rematerialize // find values to rematerialize
size_t num_groups = layout_->get_num_groups(); size_t num_groups = layout_->num_layouts();
std::vector<ir::io_inst*> remat; std::vector<ir::io_inst*> remat;
for(size_t id = 0; id < num_groups; id++) { for(size_t id = 0; id < num_groups; id++) {
const auto& values = layout_->values(id); const auto& values = layout_->values_of(id);
// extract pointers used in ld/st operations // extract pointers used in ld/st operations
std::set<ir::io_inst*> io; std::set<ir::io_inst*> io;
for(ir::value *v: values) for(ir::value *v: values)

View File

@@ -269,7 +269,10 @@ void reassociate::run(ir::module &mod) {
it++; it++;
builder.set_insert_point(*it); builder.set_insert_point(*it);
} }
ir::value *neg_off = builder.create_neg(off); ir::value *_0 = builder.get_int32(0);
if(off->get_type()->is_tile_ty())
_0 = builder.create_splat(_0, off->get_type()->get_tile_shapes());
ir::value *neg_off = builder.create_sub(_0, off);
ir::value *pz_dyn = builder.create_gep(pz, {neg_off}); ir::value *pz_dyn = builder.create_gep(pz, {neg_off});
phi_dyn->add_incoming(pz_dyn, phi->get_incoming_block(idx_z)); phi_dyn->add_incoming(pz_dyn, phi->get_incoming_block(idx_z));
infos[phi_sta].dyn_ptr = phi_dyn; infos[phi_sta].dyn_ptr = phi_dyn;

View File

@@ -200,11 +200,9 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
llvm::LLVMContext ctx; llvm::LLVMContext ctx;
std::unique_ptr<llvm::Module> llvm(new llvm::Module(module.get_name(), ctx)); std::unique_ptr<llvm::Module> llvm(new llvm::Module(module.get_name(), ctx));
// create passes // create passes
codegen::transform::cts cts;
codegen::analysis::align align; codegen::analysis::align align;
codegen::analysis::axes axes; codegen::analysis::axes axes;
codegen::analysis::layout layouts(&axes); codegen::analysis::layout layouts(&axes);
codegen::transform::coalesce coalesce(&align, &layouts);
codegen::analysis::tiles tiles(opt.num_warps, &align, &axes, &layouts); codegen::analysis::tiles tiles(opt.num_warps, &align, &axes, &layouts);
codegen::analysis::liveness liveness(&tiles); codegen::analysis::liveness liveness(&tiles);
codegen::analysis::allocation allocation(&liveness, &tiles); codegen::analysis::allocation allocation(&liveness, &tiles);
@@ -212,11 +210,12 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
codegen::transform::dce dce; codegen::transform::dce dce;
codegen::transform::peephole peephole; codegen::transform::peephole peephole;
codegen::transform::reassociate reassociate(&align); codegen::transform::reassociate reassociate(&align);
codegen::selection selection(&liveness, &allocation, &tiles, &align, &axes, &layouts, &coalesce, target.get(), opt.num_warps); codegen::transform::coalesce coalesce(&align, &layouts);
codegen::transform::cts cts;
codegen::selection selection(&liveness, &allocation, &tiles, &align, &axes, &layouts, target.get(), opt.num_warps);
// run passes // run passes
peephole.run(module); peephole.run(module);
dce.run(module); dce.run(module);
// ir::print(module, std::cout);
align.run(module); align.run(module);
cts.run(module); cts.run(module);
axes.run(module); axes.run(module);
@@ -225,11 +224,15 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
dce.run(module); dce.run(module);
align.run(module); align.run(module);
dce.run(module); dce.run(module);
tiles.run(module);
reassociate.run(module); reassociate.run(module);
// ir::print(module, std::cout);
// exit(EXIT_FAILURE);
dce.run(module); dce.run(module);
cts.run(module); cts.run(module);
// ir::print(module, std::cout); align.run(module);
axes.run(module);
layouts.run(module);
tiles.run(module);
liveness.run(module); liveness.run(module);
allocation.run(module); allocation.run(module);
if(allocation.allocated_size() > context->device()->max_shared_memory()) if(allocation.allocated_size() > context->device()->max_shared_memory())
@@ -238,10 +241,9 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
dce.run(module); dce.run(module);
axes.run(module); axes.run(module);
layouts.run(module); layouts.run(module);
align.run(module);
// ir::print(module, std::cout);
tiles.run(module);
// ir::print(module, std::cout); // ir::print(module, std::cout);
align.run(module);
tiles.run(module);
selection.run(module, *llvm); selection.run(module, *llvm);
// return binary // return binary
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm))); std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));

View File

@@ -34,7 +34,7 @@ int main() {
for(const auto& c: configs){ for(const auto& c: configs){
std::tie(ord, AT, BT, M, N, K) = c; std::tie(ord, AT, BT, M, N, K) = c;
std::cout << "// " << c << std::flush; std::cout << "// " << c << std::flush;
for(auto perf: bench_dot(stream, HALF, AT, BT, M, N, K, ord, ord)) for(auto perf: bench_dot(stream, FLOAT, AT, BT, M, N, K, ord, ord))
std::cout << ", " << perf << std::flush; std::cout << ", " << perf << std::flush;
std::cout << std::endl; std::cout << std::endl;
} }

View File

@@ -109,10 +109,10 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT,
opt.num_warps = {nwarp}; opt.num_warps = {nwarp};
} }
if(mode == BENCH) { if(mode == BENCH) {
opt.defines.push_back({"TM", {"128"}}); opt.defines.push_back({"TM", {"64", "128"}});
opt.defines.push_back({"TN", {"128"}}); opt.defines.push_back({"TN", {"64", "128"}});
opt.defines.push_back({"TK", {"16"}}); opt.defines.push_back({"TK", {"8"}});
opt.num_warps = {4}; opt.num_warps = {2, 4, 8};
} }
// kernels // kernels

View File

@@ -13,7 +13,7 @@ int main() {
for(int TM: std::vector<int>{32, 64}) for(int TM: std::vector<int>{32, 64})
for(int TN: std::vector<int>{32, 64}) for(int TN: std::vector<int>{32, 64})
for(int TK: std::vector<int>{8}) for(int TK: std::vector<int>{8})
for(int nwarps: std::vector<int>{1, 2, 4, 8}) for(int nwarps: std::vector<int>{1, 4})
for(bool AT: std::array<bool, 2>{false, true}) for(bool AT: std::array<bool, 2>{false, true})
for(bool BT: std::array<bool, 2>{false, true}){ for(bool BT: std::array<bool, 2>{false, true}){
configs.push_back(config_t{FLOAT, AT, BT, 128, 128, 128, TM, TN, TK, nwarps}); configs.push_back(config_t{FLOAT, AT, BT, 128, 128, 128, TM, TN, TK, nwarps});
@@ -29,7 +29,6 @@ int main() {
std::cout << " Pass! " << std::endl; std::cout << " Pass! " << std::endl;
else{ else{
std::cout << " Fail! " << std::endl; std::cout << " Fail! " << std::endl;
exit(EXIT_FAILURE);
} }
} }
} }