[codegen] more cleaning
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "triton/tools/graph.h"
|
||||
|
||||
namespace triton{
|
||||
|
||||
@@ -19,10 +20,8 @@ namespace analysis{
|
||||
|
||||
class axes {
|
||||
typedef std::pair<ir::value*, unsigned> node_t;
|
||||
typedef std::map <node_t, std::set<node_t>> graph_t;
|
||||
|
||||
private:
|
||||
void add_constraint(node_t x, node_t y);
|
||||
// update graph
|
||||
void update_graph_store(ir::instruction *i);
|
||||
void update_graph_reduce(ir::instruction *i);
|
||||
@@ -32,21 +31,15 @@ private:
|
||||
void update_graph_dot(ir::instruction *i);
|
||||
void update_graph_elementwise(ir::instruction *i);
|
||||
void update_graph(ir::instruction *i);
|
||||
// connected components
|
||||
void connected_components(node_t x, std::set<node_t> &nodes, graph_t &graph, unsigned group_id);
|
||||
|
||||
public:
|
||||
axes();
|
||||
void run(ir::module &mod);
|
||||
unsigned get_id(ir::value *value, unsigned ax);
|
||||
bool has_id(ir::value *value, unsigned ax);
|
||||
unsigned get_id(ir::value *value, unsigned dim);
|
||||
|
||||
private:
|
||||
// constraints graph
|
||||
graph_t dependencies_;
|
||||
std::set<node_t> nodes_;
|
||||
// parameter groups
|
||||
std::map<ir::value*, std::map<unsigned, unsigned>> groups_;
|
||||
tools::graph<node_t> graph_;
|
||||
std::map<node_t, size_t> axes_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -5,6 +5,7 @@
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "triton/tools/graph.h"
|
||||
|
||||
namespace triton{
|
||||
|
||||
@@ -27,29 +28,24 @@ private:
|
||||
// graph creation
|
||||
void connect(ir::value *x, ir::value *y);
|
||||
void make_graph(ir::instruction *i);
|
||||
// connected components
|
||||
void connected_components(node_t x, std::set<node_t> &nodes, graph_t &graph, unsigned id);
|
||||
// list the axes of the given value
|
||||
std::set<int> axes_of(ir::value *value);
|
||||
|
||||
public:
|
||||
// constructor
|
||||
layout(analysis::axes *axes);
|
||||
// run the passes
|
||||
// accessors
|
||||
unsigned layout_of(ir::value *value) const;
|
||||
const std::vector<ir::value*>& values_of(unsigned id) const;
|
||||
size_t num_layouts() const;
|
||||
// execution
|
||||
void run(ir::module &mod);
|
||||
// get the layout ID of the given value
|
||||
unsigned id(ir::value *value) const;
|
||||
// get the values associates with the given ID
|
||||
const std::vector<ir::value*>& values(unsigned id) const;
|
||||
// get number of groups
|
||||
size_t get_num_groups() const;
|
||||
|
||||
private:
|
||||
analysis::axes* axes_;
|
||||
graph_t dependencies_;
|
||||
std::set<node_t> nodes_;
|
||||
std::map<ir::value*, unsigned> groups_;
|
||||
std::map<unsigned, std::vector<ir::value*>> values_;
|
||||
tools::graph<ir::value*> graph_;
|
||||
std::map<ir::value*, size_t> groups_;
|
||||
std::map<size_t, std::vector<ir::value*>> values_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -4,6 +4,7 @@
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include "triton/tools/graph.h"
|
||||
|
||||
namespace triton{
|
||||
|
||||
@@ -41,7 +42,7 @@ struct double_buffer_info_t {
|
||||
};
|
||||
|
||||
struct buffer_t {
|
||||
unsigned id;
|
||||
size_t id;
|
||||
size_t size;
|
||||
bool operator<(buffer_t other) const { return id < other.id; }
|
||||
};
|
||||
@@ -63,7 +64,6 @@ public:
|
||||
|
||||
|
||||
private:
|
||||
void connected_components(node_t x, std::set<node_t> &nodes, graph_t &graph, buffer_t *buffer);
|
||||
void extract_double_bufferable(ir::instruction *i);
|
||||
void extract_buffers(ir::instruction *i);
|
||||
void get_parents(ir::instruction *i, std::vector<ir::value *>& res);
|
||||
@@ -98,11 +98,8 @@ private:
|
||||
intervals_map_t intervals_;
|
||||
std::map<ir::value*, double_buffer_info_t> double_;
|
||||
std::map<ir::value*, size_t> pad_;
|
||||
std::map<ir::value*, std::vector<ir::value*>> parents_;
|
||||
// graph
|
||||
std::set<node_t> nodes_;
|
||||
graph_t graph_;
|
||||
std::vector<buffer_t*> buffers_;
|
||||
// buffers
|
||||
tools::graph<node_t> graph_;
|
||||
std::map<ir::value*, buffer_t*> groups_;
|
||||
std::map<buffer_t*, std::vector<ir::value*>> values_;
|
||||
};
|
||||
|
@@ -211,10 +211,10 @@ private:
|
||||
public:
|
||||
selection(analysis::liveness* liveness, analysis::allocation *alloc, analysis::tiles *tiles,
|
||||
analysis::align *alignment, analysis::axes *axes,
|
||||
analysis::layout *layouts, transform::coalesce* reorder, target *tgt, unsigned num_warps)
|
||||
analysis::layout *layouts, target *tgt, unsigned num_warps)
|
||||
: liveness_(liveness), alloc_(alloc), tiles_(tiles),
|
||||
alignment_(alignment), a_axes_(axes), layouts_(layouts),
|
||||
reorder_(reorder), tgt_(tgt), num_warps_(num_warps){ }
|
||||
tgt_(tgt), num_warps_(num_warps){ }
|
||||
|
||||
void run(ir::module &src, Module &dst);
|
||||
|
||||
@@ -227,7 +227,6 @@ private:
|
||||
analysis::axes *a_axes_;
|
||||
analysis::layout *layouts_;
|
||||
analysis::align *alignment_;
|
||||
transform::coalesce *reorder_;
|
||||
target *tgt_;
|
||||
std::map<unsigned, distributed_axis> axes_;
|
||||
Value *sh_mem_ptr_;
|
||||
|
@@ -38,7 +38,7 @@ inline double bench(std::function<void()> const & op, driver::stream * stream)
|
||||
double total_time = 0;
|
||||
op();
|
||||
stream->synchronize();
|
||||
while(total_time*1e-9 < 1e-3){
|
||||
while(total_time*1e-9 < 1e-2){
|
||||
float norm = 1;
|
||||
// normalize clock if possible to reduce noise in auto-tuning
|
||||
if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(stream->context()->device()))
|
||||
|
@@ -270,9 +270,9 @@ std::vector<unsigned> align::populate_max_contiguous_binop(ir::binary_operator*
|
||||
}
|
||||
if(x->is_int_add_sub()){
|
||||
unsigned lvalue = 1, rvalue = 1;
|
||||
if(lhs_cst_info[d].num_cst)
|
||||
if(lhs_cst_info[d].num_cst > 0)
|
||||
lvalue = gcd(rhs_max_contiguous[d], lhs_cst_info[d].num_cst);
|
||||
if(rhs_cst_info[d].num_cst)
|
||||
if(rhs_cst_info[d].num_cst > 0)
|
||||
rvalue = gcd(lhs_max_contiguous[d], rhs_cst_info[d].num_cst);
|
||||
value = std::max(lvalue, rvalue);
|
||||
}
|
||||
|
@@ -16,22 +16,6 @@ namespace analysis{
|
||||
|
||||
axes::axes() {}
|
||||
|
||||
void axes::add_constraint(node_t x, node_t y) {
|
||||
size_t shape_x = 1;
|
||||
size_t shape_y = 1;
|
||||
if(x.first->get_type()->is_tile_ty())
|
||||
shape_x = x.first->get_type()->get_tile_shapes()[x.second];
|
||||
if(y.first->get_type()->is_tile_ty())
|
||||
shape_y = y.first->get_type()->get_tile_shapes()[y.second];
|
||||
if(shape_x == 1 && shape_y == 1)
|
||||
return;
|
||||
dependencies_[x].insert(y);
|
||||
dependencies_[y].insert(x);
|
||||
nodes_.insert(x);
|
||||
nodes_.insert(y);
|
||||
}
|
||||
|
||||
|
||||
void axes::update_graph_reduce(ir::instruction *i) {
|
||||
auto* red = static_cast<ir::reduce_inst*>(i);
|
||||
unsigned axis = red->get_axis();
|
||||
@@ -41,7 +25,7 @@ void axes::update_graph_reduce(ir::instruction *i) {
|
||||
for(unsigned d = 0; d < in_shapes.size(); d++){
|
||||
if(d == axis)
|
||||
continue;
|
||||
add_constraint({i, current++}, {arg, d});
|
||||
graph_.add_edge({i, current++}, {arg, d});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -59,9 +43,9 @@ void axes::update_graph_reshape(ir::instruction *i) {
|
||||
bool same_shape = res_shapes[d] == op_shapes[current];
|
||||
// either add edge between axis or just add a node in the graph
|
||||
if(!is_skewed && same_shape)
|
||||
add_constraint({i, d}, {op, current++});
|
||||
graph_.add_edge({i, d}, {op, current++});
|
||||
else
|
||||
add_constraint({i, d}, {i, d});
|
||||
graph_.add_edge({i, d}, {i, d});
|
||||
// reshaping is skewed
|
||||
if(res_shapes[d] > 1 && !same_shape)
|
||||
is_skewed = true;
|
||||
@@ -74,7 +58,7 @@ void axes::update_graph_trans(ir::instruction *i) {
|
||||
auto perm = trans->get_perm();
|
||||
// add edge between axis perm[d] and axis d
|
||||
for(unsigned d = 0; d < perm.size(); d++)
|
||||
add_constraint({i, perm[d]}, {op, d});
|
||||
graph_.add_edge({i, perm[d]}, {op, d});
|
||||
}
|
||||
|
||||
void axes::update_graph_broadcast(ir::instruction *i) {
|
||||
@@ -86,7 +70,7 @@ void axes::update_graph_broadcast(ir::instruction *i) {
|
||||
// add edge between non-broadcast axes
|
||||
for(unsigned d = 0; d < shapes.size(); d ++)
|
||||
if(op_shapes[d] == shapes[d])
|
||||
add_constraint({i, d}, {op, d});
|
||||
graph_.add_edge({i, d}, {op, d});
|
||||
}
|
||||
|
||||
void axes::update_graph_dot(ir::instruction *i) {
|
||||
@@ -97,11 +81,11 @@ void axes::update_graph_dot(ir::instruction *i) {
|
||||
ir::value *D = dot->get_operand(2);
|
||||
// add edges between result and accumulator
|
||||
for(unsigned d = 0; d < shapes.size(); d++)
|
||||
add_constraint({dot, d}, {D, d});
|
||||
graph_.add_edge({dot, d}, {D, d});
|
||||
// add edge for batch dimension
|
||||
for(unsigned d = 2; d < shapes.size(); d++){
|
||||
add_constraint({dot, d}, {A, d});
|
||||
add_constraint({dot, d}, {B, d});
|
||||
graph_.add_edge({dot, d}, {A, d});
|
||||
graph_.add_edge({dot, d}, {B, d});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -116,8 +100,8 @@ void axes::update_graph_elementwise(ir::instruction *i) {
|
||||
for(ir::value* opx: i->ops())
|
||||
for(ir::value* opy: i->ops()){
|
||||
if(!i->get_type()->is_void_ty())
|
||||
add_constraint({i, d}, {opx, d});
|
||||
add_constraint({opx, d}, {opy, d});
|
||||
graph_.add_edge({i, d}, {opx, d});
|
||||
graph_.add_edge({opx, d}, {opy, d});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -136,41 +120,19 @@ void axes::update_graph(ir::instruction *i) {
|
||||
return;
|
||||
}
|
||||
|
||||
void axes::connected_components(node_t x, std::set<node_t> &nodes, graph_t &graph, unsigned group_id) {
|
||||
groups_[x.first].insert({x.second, group_id});
|
||||
if(nodes.find(x) != nodes.end()){
|
||||
nodes.erase(x);
|
||||
for(const node_t &y: graph[x])
|
||||
connected_components(y, nodes, graph, group_id);
|
||||
}
|
||||
}
|
||||
|
||||
unsigned axes::get_id(ir::value *value, unsigned ax) {
|
||||
unsigned result = groups_.at(value).at(ax);
|
||||
return result;
|
||||
unsigned axes::get_id(ir::value *value, unsigned dim) {
|
||||
return axes_.at({value, dim});
|
||||
}
|
||||
|
||||
bool axes::has_id(ir::value *value, unsigned ax) {
|
||||
auto it = groups_.find(value);
|
||||
if(it == groups_.end())
|
||||
return false;
|
||||
auto iit = it->second.find(ax);
|
||||
if(iit == it->second.end())
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
void axes::run(ir::module &mod) {
|
||||
nodes_.clear();
|
||||
dependencies_.clear();
|
||||
groups_.clear();
|
||||
// make graph
|
||||
ir::for_each_instruction(mod, [this](ir::instruction *x) { update_graph(x); });
|
||||
// connected components
|
||||
unsigned group_id = 0;
|
||||
while(!nodes_.empty())
|
||||
connected_components(*nodes_.begin(), nodes_, dependencies_, group_id++);
|
||||
graph_.clear();
|
||||
ir::for_each_instruction(mod, [this](ir::instruction *x) {
|
||||
update_graph(x);
|
||||
});
|
||||
// find connected components
|
||||
graph_.connected_components(nullptr, &axes_);
|
||||
}
|
||||
|
||||
}
|
||||
|
@@ -21,36 +21,24 @@ std::set<int> layout::axes_of(ir::value *value) {
|
||||
// create result
|
||||
std::set<int> result;
|
||||
for(size_t d = 0; d < rank; d++)
|
||||
if(axes_->has_id(value, d))
|
||||
result.insert(axes_->get_id(value, d));
|
||||
result.insert(axes_->get_id(value, d));
|
||||
return result;
|
||||
}
|
||||
|
||||
// connected components
|
||||
void layout::connected_components(node_t x, std::set<node_t> &nodes, graph_t &graph, unsigned group_id) {
|
||||
groups_[x] = group_id;
|
||||
values_[group_id].push_back(x);
|
||||
if(nodes.find(x) != nodes.end()){
|
||||
nodes.erase(x);
|
||||
for(const node_t &y: graph[x])
|
||||
connected_components(y, nodes, graph, group_id);
|
||||
}
|
||||
}
|
||||
|
||||
// constructor
|
||||
layout::layout(analysis::axes *axes)
|
||||
: axes_(axes) { }
|
||||
|
||||
// get group id
|
||||
unsigned layout::id(ir::value *value) const
|
||||
unsigned layout::layout_of(ir::value *value) const
|
||||
{ return groups_.at(value); }
|
||||
|
||||
// get values
|
||||
const std::vector<ir::value*>& layout::values(unsigned id) const
|
||||
const std::vector<ir::value*>& layout::values_of(unsigned id) const
|
||||
{ return values_.at(id); }
|
||||
|
||||
// get number of groups
|
||||
size_t layout::get_num_groups() const
|
||||
size_t layout::num_layouts() const
|
||||
{ return values_.size(); }
|
||||
|
||||
// connect two values
|
||||
@@ -67,12 +55,8 @@ void layout::connect(ir::value *x, ir::value *y) {
|
||||
std::set_intersection(x_axes.begin(), x_axes.end(),
|
||||
y_axes.begin(), y_axes.end(),
|
||||
std::inserter(common, common.begin()));
|
||||
if(!common.empty()){
|
||||
nodes_.insert(x);
|
||||
nodes_.insert(y);
|
||||
dependencies_[x].insert(y);
|
||||
dependencies_[y].insert(x);
|
||||
}
|
||||
if(!common.empty())
|
||||
graph_.add_edge(x, y);
|
||||
}
|
||||
|
||||
// make graph
|
||||
@@ -84,19 +68,16 @@ void layout::make_graph(ir::instruction *i) {
|
||||
}
|
||||
}
|
||||
|
||||
// run
|
||||
void layout::run(ir::module &mod) {
|
||||
nodes_.clear();
|
||||
dependencies_.clear();
|
||||
groups_.clear();
|
||||
values_.clear();
|
||||
// make graph
|
||||
ir::for_each_instruction(mod, [this](ir::instruction* i) { make_graph(i); });
|
||||
graph_.clear();
|
||||
ir::for_each_instruction(mod, [this](ir::instruction* i) {
|
||||
make_graph(i);
|
||||
});
|
||||
// connected components
|
||||
unsigned group_id = 0;
|
||||
while(!nodes_.empty()){
|
||||
connected_components(*nodes_.begin(), nodes_, dependencies_, group_id++);
|
||||
}
|
||||
values_.clear();
|
||||
groups_.clear();
|
||||
graph_.connected_components(&values_, &groups_);
|
||||
}
|
||||
|
||||
}
|
||||
|
@@ -53,42 +53,20 @@ void liveness::extract_double_bufferable(ir::instruction *i) {
|
||||
void liveness::make_graph(ir::instruction *i) {
|
||||
if(has_double(i)){
|
||||
ir::value *latch = double_[i].latch;
|
||||
nodes_.insert(i);
|
||||
nodes_.insert(latch);
|
||||
graph_[i].insert(latch);
|
||||
graph_[latch].insert(i);
|
||||
graph_.add_edge(i, latch);
|
||||
}
|
||||
if(i->get_id() == ir::INST_PHI){
|
||||
ir::phi_node* phi = (ir::phi_node*)i;
|
||||
for(ir::value* op: phi->ops()){
|
||||
if(storage_info.at(i->get_id()).first == SHARED){
|
||||
graph_.add_edge(i, i);
|
||||
for(ir::value* op: i->ops()){
|
||||
auto* iop = dynamic_cast<ir::instruction*>(op);
|
||||
if(!iop || storage_info.at(iop->get_id()).first != SHARED)
|
||||
continue;
|
||||
nodes_.insert(phi);
|
||||
nodes_.insert(op);
|
||||
graph_[phi].insert(op);
|
||||
graph_[op].insert(phi);
|
||||
graph_.add_edge(i, op);
|
||||
}
|
||||
}
|
||||
if(i->get_id() == ir::INST_TRANS){
|
||||
nodes_.insert(i);
|
||||
nodes_.insert(i->get_operand(0));
|
||||
graph_[i].insert(i->get_operand(0));
|
||||
graph_[i->get_operand(0)].insert(i);
|
||||
}
|
||||
}
|
||||
|
||||
// connected components
|
||||
void liveness::connected_components(node_t x, std::set<node_t> &nodes, graph_t &graph, buffer_t* buffer) {
|
||||
groups_[x] = buffer;
|
||||
values_[buffer].push_back(x);
|
||||
if(nodes.find(x) != nodes.end()){
|
||||
nodes.erase(x);
|
||||
for(const node_t &y: graph[x])
|
||||
connected_components(y, nodes, graph, buffer);
|
||||
}
|
||||
}
|
||||
|
||||
bool is_trans(ir::value *v) {
|
||||
if(dynamic_cast<ir::trans_inst *>(v)) {
|
||||
return true;
|
||||
@@ -121,12 +99,14 @@ bool liveness::do_pad(ir::value *x) {
|
||||
pad_[b] = std::max<int>(pad_[b], (24 - b_shapes[b_row ? 1 : 0]) % 32);
|
||||
return a_previous != pad_[a] || b_previous != pad_[b];
|
||||
}
|
||||
// padding for trans
|
||||
if(auto* trans = dynamic_cast<ir::trans_inst*>(x)) {
|
||||
ir::value *op = trans->get_operand(0);
|
||||
size_t previous = pad_[op];
|
||||
pad_[op] = std::max(pad_[op], pad_[x]);
|
||||
return previous != pad_[op];
|
||||
}
|
||||
// padding for copy to shared
|
||||
if(auto* cts = dynamic_cast<ir::copy_to_shared_inst*>(x)) {
|
||||
auto cts_order = tiles_->order(cts);
|
||||
ir::value *arg = cts->get_operand(0);
|
||||
@@ -187,7 +167,7 @@ void liveness::run(ir::module &mod) {
|
||||
indices.clear();
|
||||
pad_.clear();
|
||||
intervals_.clear();
|
||||
parents_.clear();
|
||||
graph_.clear();
|
||||
|
||||
// Create set of pair of values that can be double-buffered
|
||||
ir::for_each_instruction(mod, [this](ir::instruction* i) {
|
||||
@@ -209,12 +189,16 @@ void liveness::run(ir::module &mod) {
|
||||
});
|
||||
|
||||
// connected components
|
||||
unsigned group_id = 0;
|
||||
while(!nodes_.empty()){
|
||||
buffer_t* buffer = new buffer_t{group_id++};
|
||||
connected_components(*nodes_.begin(), nodes_, graph_, buffer);
|
||||
for(ir::value *v: values_.at(buffer))
|
||||
tools::graph<node_t>::cmap_t cmap;
|
||||
tools::graph<node_t>::nmap_t nmap;
|
||||
graph_.connected_components(&cmap, &nmap);
|
||||
for(auto x: cmap) {
|
||||
buffer_t* buffer = new buffer_t{x.first};
|
||||
values_[buffer] = x.second;
|
||||
for(ir::value *v: x.second){
|
||||
buffer->size = std::max<int>(buffer->size, num_bytes(v));
|
||||
groups_[v] = buffer;
|
||||
}
|
||||
}
|
||||
|
||||
// Assigns index to each instruction
|
||||
@@ -245,6 +229,8 @@ void liveness::run(ir::module &mod) {
|
||||
intervals_[x.first] = segment{start, end};
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
@@ -74,7 +74,7 @@ bool is_hmma_b_row(ir::value* v) {
|
||||
|
||||
|
||||
layout_t tiles::hmma(ir::value *value) {
|
||||
return hmma_.at(layout_->id(value));
|
||||
return hmma_.at(layout_->layout_of(value));
|
||||
}
|
||||
|
||||
int tiles::mts(ir::value *value, unsigned ax) {
|
||||
@@ -94,7 +94,7 @@ int tiles::wpt(ir::value *value, unsigned ax) {
|
||||
}
|
||||
|
||||
std::vector<int> tiles::order(ir::value *v) {
|
||||
auto ret = order_[layout_->id(v)];
|
||||
auto ret = order_[layout_->layout_of(v)];
|
||||
return ret;
|
||||
}
|
||||
|
||||
@@ -201,7 +201,9 @@ bool tiles::is_trans(ir::value *v) {
|
||||
void tiles::run(ir::module &) {
|
||||
hmma_.clear();
|
||||
largest_.clear();
|
||||
size_t num_groups = layout_->get_num_groups();
|
||||
order_.clear();
|
||||
|
||||
size_t num_groups = layout_->num_layouts();
|
||||
// helpers
|
||||
auto rank = [](ir::value* v) {
|
||||
ir::type *ty = v->get_type();
|
||||
@@ -213,7 +215,7 @@ void tiles::run(ir::module &) {
|
||||
};
|
||||
// find out which groups require hmma layout
|
||||
for(size_t i = 0; i < num_groups; i++) {
|
||||
const auto& values = layout_->values(i);
|
||||
const auto& values = layout_->values_of(i);
|
||||
bool hmma_c = std::any_of(values.begin(), values.end(), &is_hmma_c);
|
||||
if(hmma_c) hmma_[i] = HMMA_C;
|
||||
else hmma_[i] = SCANLINE;
|
||||
@@ -221,7 +223,7 @@ void tiles::run(ir::module &) {
|
||||
}
|
||||
// find out which value is the largest in each group
|
||||
for(size_t i = 0; i < num_groups; i++) {
|
||||
const auto& values = layout_->values(i);
|
||||
const auto& values = layout_->values_of(i);
|
||||
auto cmp = [&rank](ir::value* x, ir::value *y) { return rank(x) < rank(y); };
|
||||
largest_[i] = *std::max_element(values.begin(), values.end(), cmp);
|
||||
}
|
||||
@@ -230,7 +232,7 @@ void tiles::run(ir::module &) {
|
||||
// find out the layout ordering of a group
|
||||
for(size_t i = 0; i < num_groups; i++){
|
||||
std::set<ir::io_inst*> io;
|
||||
for(ir::value* v: layout_->values(i))
|
||||
for(ir::value* v: layout_->values_of(i))
|
||||
extract_io_use(v, io);
|
||||
auto cmp = [&rank](ir::io_inst* x, ir::io_inst *y) {
|
||||
return rank(x->get_pointer_operand()) < rank(y->get_pointer_operand());
|
||||
@@ -249,27 +251,27 @@ void tiles::run(ir::module &) {
|
||||
// matrix multiplication optimizations
|
||||
for(size_t i = 0; i < num_groups; i++){
|
||||
std::vector<ir::dot_inst*> dots;
|
||||
for(ir::value* v: layout_->values(i))
|
||||
for(ir::value* v: layout_->values_of(i))
|
||||
if(auto *x = dynamic_cast<ir::dot_inst*>(v))
|
||||
dots.push_back(x);
|
||||
for(ir::dot_inst* dot: dots){
|
||||
ir::value* a = dot->get_operand(0);
|
||||
ir::value* b = dot->get_operand(1);
|
||||
if(hmma_.at(layout_->id(dot)) == HMMA_C){
|
||||
auto a_val = layout_->values(layout_->id(a));
|
||||
auto b_val = layout_->values(layout_->id(b));
|
||||
if(hmma_.at(layout_->layout_of(dot)) == HMMA_C){
|
||||
auto a_val = layout_->values_of(layout_->layout_of(a));
|
||||
auto b_val = layout_->values_of(layout_->layout_of(b));
|
||||
for(ir::value *v: a_val)
|
||||
if(auto *cts = dynamic_cast<ir::copy_to_shared_inst*>(v))
|
||||
order_[layout_->id(a)] = order_[layout_->id(cts->get_operand(0))];
|
||||
order_[layout_->layout_of(a)] = order_[layout_->layout_of(cts->get_operand(0))];
|
||||
for(ir::value *v: b_val)
|
||||
if(auto *cts = dynamic_cast<ir::copy_to_shared_inst*>(v))
|
||||
order_[layout_->id(b)] = order_[layout_->id(cts->get_operand(0))];
|
||||
order_[layout_->layout_of(b)] = order_[layout_->layout_of(cts->get_operand(0))];
|
||||
}
|
||||
else{
|
||||
std::vector<int> col = {0, 1};
|
||||
std::vector<int> row = {1, 0};
|
||||
order_[layout_->id(a)] = is_trans(a) ? row : col;
|
||||
order_[layout_->id(b)] = is_trans(b) ? col : row;
|
||||
order_[layout_->layout_of(a)] = is_trans(a) ? row : col;
|
||||
order_[layout_->layout_of(b)] = is_trans(b) ? col : row;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -67,10 +67,10 @@ ir::value* coalesce::rematerialize(ir::value *x, ir::builder &builder,
|
||||
|
||||
void coalesce::run(ir::module &mod) {
|
||||
// find values to rematerialize
|
||||
size_t num_groups = layout_->get_num_groups();
|
||||
size_t num_groups = layout_->num_layouts();
|
||||
std::vector<ir::io_inst*> remat;
|
||||
for(size_t id = 0; id < num_groups; id++) {
|
||||
const auto& values = layout_->values(id);
|
||||
const auto& values = layout_->values_of(id);
|
||||
// extract pointers used in ld/st operations
|
||||
std::set<ir::io_inst*> io;
|
||||
for(ir::value *v: values)
|
||||
|
@@ -269,7 +269,10 @@ void reassociate::run(ir::module &mod) {
|
||||
it++;
|
||||
builder.set_insert_point(*it);
|
||||
}
|
||||
ir::value *neg_off = builder.create_neg(off);
|
||||
ir::value *_0 = builder.get_int32(0);
|
||||
if(off->get_type()->is_tile_ty())
|
||||
_0 = builder.create_splat(_0, off->get_type()->get_tile_shapes());
|
||||
ir::value *neg_off = builder.create_sub(_0, off);
|
||||
ir::value *pz_dyn = builder.create_gep(pz, {neg_off});
|
||||
phi_dyn->add_incoming(pz_dyn, phi->get_incoming_block(idx_z));
|
||||
infos[phi_sta].dyn_ptr = phi_dyn;
|
||||
|
@@ -200,11 +200,9 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
llvm::LLVMContext ctx;
|
||||
std::unique_ptr<llvm::Module> llvm(new llvm::Module(module.get_name(), ctx));
|
||||
// create passes
|
||||
codegen::transform::cts cts;
|
||||
codegen::analysis::align align;
|
||||
codegen::analysis::axes axes;
|
||||
codegen::analysis::layout layouts(&axes);
|
||||
codegen::transform::coalesce coalesce(&align, &layouts);
|
||||
codegen::analysis::tiles tiles(opt.num_warps, &align, &axes, &layouts);
|
||||
codegen::analysis::liveness liveness(&tiles);
|
||||
codegen::analysis::allocation allocation(&liveness, &tiles);
|
||||
@@ -212,11 +210,12 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
codegen::transform::dce dce;
|
||||
codegen::transform::peephole peephole;
|
||||
codegen::transform::reassociate reassociate(&align);
|
||||
codegen::selection selection(&liveness, &allocation, &tiles, &align, &axes, &layouts, &coalesce, target.get(), opt.num_warps);
|
||||
codegen::transform::coalesce coalesce(&align, &layouts);
|
||||
codegen::transform::cts cts;
|
||||
codegen::selection selection(&liveness, &allocation, &tiles, &align, &axes, &layouts, target.get(), opt.num_warps);
|
||||
// run passes
|
||||
peephole.run(module);
|
||||
dce.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
align.run(module);
|
||||
cts.run(module);
|
||||
axes.run(module);
|
||||
@@ -225,11 +224,15 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
dce.run(module);
|
||||
align.run(module);
|
||||
dce.run(module);
|
||||
tiles.run(module);
|
||||
reassociate.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
// exit(EXIT_FAILURE);
|
||||
dce.run(module);
|
||||
cts.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
align.run(module);
|
||||
axes.run(module);
|
||||
layouts.run(module);
|
||||
tiles.run(module);
|
||||
liveness.run(module);
|
||||
allocation.run(module);
|
||||
if(allocation.allocated_size() > context->device()->max_shared_memory())
|
||||
@@ -238,10 +241,9 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
dce.run(module);
|
||||
axes.run(module);
|
||||
layouts.run(module);
|
||||
align.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
tiles.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
align.run(module);
|
||||
tiles.run(module);
|
||||
selection.run(module, *llvm);
|
||||
// return binary
|
||||
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
|
||||
|
@@ -34,7 +34,7 @@ int main() {
|
||||
for(const auto& c: configs){
|
||||
std::tie(ord, AT, BT, M, N, K) = c;
|
||||
std::cout << "// " << c << std::flush;
|
||||
for(auto perf: bench_dot(stream, HALF, AT, BT, M, N, K, ord, ord))
|
||||
for(auto perf: bench_dot(stream, FLOAT, AT, BT, M, N, K, ord, ord))
|
||||
std::cout << ", " << perf << std::flush;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
@@ -109,10 +109,10 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT,
|
||||
opt.num_warps = {nwarp};
|
||||
}
|
||||
if(mode == BENCH) {
|
||||
opt.defines.push_back({"TM", {"128"}});
|
||||
opt.defines.push_back({"TN", {"128"}});
|
||||
opt.defines.push_back({"TK", {"16"}});
|
||||
opt.num_warps = {4};
|
||||
opt.defines.push_back({"TM", {"64", "128"}});
|
||||
opt.defines.push_back({"TN", {"64", "128"}});
|
||||
opt.defines.push_back({"TK", {"8"}});
|
||||
opt.num_warps = {2, 4, 8};
|
||||
}
|
||||
|
||||
// kernels
|
||||
|
@@ -13,7 +13,7 @@ int main() {
|
||||
for(int TM: std::vector<int>{32, 64})
|
||||
for(int TN: std::vector<int>{32, 64})
|
||||
for(int TK: std::vector<int>{8})
|
||||
for(int nwarps: std::vector<int>{1, 2, 4, 8})
|
||||
for(int nwarps: std::vector<int>{1, 4})
|
||||
for(bool AT: std::array<bool, 2>{false, true})
|
||||
for(bool BT: std::array<bool, 2>{false, true}){
|
||||
configs.push_back(config_t{FLOAT, AT, BT, 128, 128, 128, TM, TN, TK, nwarps});
|
||||
@@ -29,7 +29,6 @@ int main() {
|
||||
std::cout << " Pass! " << std::endl;
|
||||
else{
|
||||
std::cout << " Fail! " << std::endl;
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user