[codegen] more progress towards unified dot implementation
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <iostream>
|
||||
#include "triton/codegen/analysis/liveness.h"
|
||||
|
||||
namespace triton{
|
||||
|
||||
@@ -25,10 +26,8 @@ class allocation {
|
||||
public:
|
||||
allocation(liveness *live, tiles *params)
|
||||
: liveness_(live), tiles_(params){ }
|
||||
// utilities
|
||||
unsigned num_bytes(ir::value *x);
|
||||
unsigned is_ld_padded(ir::value* x);
|
||||
// accessors
|
||||
bool has_offset(ir::value *x) const { return offsets_.find(x) != offsets_.end(); }
|
||||
unsigned offset(ir::value *x) const { return offsets_.at(x); }
|
||||
unsigned allocated_size() const { return allocated_size_; }
|
||||
// run
|
||||
@@ -36,7 +35,6 @@ public:
|
||||
|
||||
private:
|
||||
std::map<ir::value*, unsigned> offsets_;
|
||||
std::map<ir::value*, unsigned> num_bytes_;
|
||||
size_t allocated_size_;
|
||||
// dependences
|
||||
liveness *liveness_;
|
||||
|
@@ -2,6 +2,8 @@
|
||||
#define TDL_INCLUDE_IR_CODEGEN_LIVENESS_H
|
||||
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
|
||||
namespace triton{
|
||||
|
||||
@@ -10,6 +12,7 @@ namespace ir{
|
||||
class phi_node;
|
||||
class function;
|
||||
class module;
|
||||
class instruction;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
@@ -17,7 +20,7 @@ namespace analysis{
|
||||
|
||||
typedef unsigned slot_index;
|
||||
|
||||
class cts;
|
||||
class tiles;
|
||||
|
||||
struct segment {
|
||||
slot_index start;
|
||||
@@ -37,21 +40,47 @@ struct double_buffer_info_t {
|
||||
ir::phi_node* phi;
|
||||
};
|
||||
|
||||
// Descriptor of a buffer: a numeric identifier plus a size.
// The ordering below looks at `id` only, so two descriptors that share an id
// are equivalent as keys of an ordered container even if their sizes differ.
struct buffer_t {
  unsigned id;
  size_t size;
  // Strict ordering on the identifier; `size` does not participate.
  bool operator<(buffer_t rhs) const {
    return rhs.id > id;
  }
};
|
||||
|
||||
class liveness {
|
||||
private:
|
||||
typedef std::map<ir::value*, slot_index> indices_map_t;
|
||||
typedef std::map<ir::value*, segment> intervals_map_t;
|
||||
typedef std::map<buffer_t, segment> intervals_map_t;
|
||||
typedef std::map<ir::value*, bool> has_storage_map_t;
|
||||
typedef ir::value* node_t;
|
||||
typedef std::map <node_t, std::set<node_t>> graph_t;
|
||||
|
||||
public:
|
||||
// Intervals iterators
|
||||
using iterator = intervals_map_t::iterator;
|
||||
using const_iterator = intervals_map_t::const_iterator;
|
||||
|
||||
|
||||
|
||||
|
||||
private:
|
||||
void connected_components(node_t x, std::set<node_t> &nodes, graph_t &graph, unsigned group_id);
|
||||
void extract_double_bufferable(ir::instruction *i);
|
||||
void extract_buffers(ir::instruction *i);
|
||||
void get_parents(ir::instruction *i, std::vector<ir::value *>& res);
|
||||
void make_graph(ir::instruction *i);
|
||||
|
||||
|
||||
public:
|
||||
liveness(tiles *t): tiles_(t){ }
|
||||
// buffer size
|
||||
unsigned is_ld_padded(ir::value *x);
|
||||
unsigned num_bytes(ir::value *x);
|
||||
// accessors
|
||||
const intervals_map_t& intervals() const { return intervals_; }
|
||||
segment get_interval(ir::value* v) const { return intervals_.at(v); }
|
||||
const intervals_map_t& intervals() const { return intervals_; }
|
||||
segment get_interval(buffer_t v) const { return intervals_.at(v); }
|
||||
// buffers
|
||||
buffer_t get_buffer(ir::value *v) const { return groups_.at(v); }
|
||||
std::vector<ir::value*> get_values(buffer_t x) const { return values_.at(x); }
|
||||
// double-buffering
|
||||
bool has_double(ir::value *x) const { return double_.find(x) != double_.end(); }
|
||||
double_buffer_info_t get_double(ir::value *x) const { return double_.at(x); }
|
||||
@@ -59,10 +88,19 @@ public:
|
||||
void run(ir::module &mod);
|
||||
|
||||
private:
|
||||
// analysis
|
||||
tiles *tiles_;
|
||||
// stuff
|
||||
has_storage_map_t has_dedicated_storage_;
|
||||
indices_map_t indices_;
|
||||
indices_map_t indices;
|
||||
intervals_map_t intervals_;
|
||||
std::map<ir::value*, double_buffer_info_t> double_;
|
||||
std::map<ir::value*, std::vector<ir::value*>> parents_;
|
||||
// graph
|
||||
std::set<node_t> nodes_;
|
||||
graph_t graph_;
|
||||
std::map<ir::value*, buffer_t> groups_;
|
||||
std::map<buffer_t, std::vector<ir::value*>> values_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -56,7 +56,7 @@ static const std::map<ir::value_id_t, inst_storage_info_t> storage_info = {
|
||||
{ ir::INST_BROADCAST, {DISTRIBUTED, {REPLICATED}}},
|
||||
{ ir::INST_DOWNCAST, {DISTRIBUTED, {REPLICATED}}},
|
||||
// array arithmetic
|
||||
{ ir::INST_TRANS, {SHARED, {DISTRIBUTED}}}, // TODO: not necessarily
|
||||
{ ir::INST_TRANS, {SHARED, {SHARED}}},
|
||||
{ ir::INST_REDUCE, {SHARED, {DISTRIBUTED}}},
|
||||
{ ir::INST_DOT, {DISTRIBUTED, {SHARED, SHARED, DISTRIBUTED}}},
|
||||
// terminator
|
||||
|
@@ -15,79 +15,28 @@ namespace triton{
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
// Returns the padding associated with `x`. The rules are tried in this order:
//  1. a trans_inst whose first permutation entry is non-zero            -> 4
//  2. tiles_->hmma(x) is HMMA_A_ROW or HMMA_B_COL  -> 16 if order[0] == 0, else 8
//  3. tiles_->hmma(x) is HMMA_A_COL or HMMA_B_ROW  -> 8 if order[0] == 0, else 16
//  4. a phi node -> the largest padding among its incoming values (recursive)
//  5. anything else                                                      -> 0
unsigned allocation::is_ld_padded(ir::value *x) {
  auto *trans = dynamic_cast<ir::trans_inst*>(x);
  if(trans && trans->get_perm()[0]->get_value() != 0)
    return 4;

  auto order = tiles_->order(x);
  const bool is_col_major = (order[0] == 0);
  // The two possible answers for HMMA operands, picked by operand kind below.
  const unsigned pad_row_like = is_col_major ? 16 : 8;
  const unsigned pad_col_like = is_col_major ? 8 : 16;
  if(tiles_->hmma(x) == HMMA_A_ROW)
    return pad_row_like;
  if(tiles_->hmma(x) == HMMA_A_COL)
    return pad_col_like;
  if(tiles_->hmma(x) == HMMA_B_COL)
    return pad_row_like;
  if(tiles_->hmma(x) == HMMA_B_ROW)
    return pad_col_like;

  auto *phi = dynamic_cast<ir::phi_node*>(x);
  if(!phi)
    return 0;
  // A phi takes the widest padding required by any of its incoming values.
  unsigned widest = 0;
  for(unsigned k = 0; k < phi->get_num_incoming(); k++)
    widest = std::max(widest, is_ld_padded(phi->get_incoming_value(k)));
  return widest;
}
|
||||
|
||||
// Computes the number of bytes attributed to `x`.
//  - For a reduce_inst: (scalar element size in bytes) * (product of the operand's
//    tile shapes with the reduced axis removed) * depth, where depth is
//    tiles_->wpt(op, axis) when tiles_->hmma(x) is non-zero and tiles_->mts(op, axis)
//    otherwise.
//  - For any other value: the primitive size of its type in bytes, grown by
//    pad * bytes / ld when is_ld_padded(x) is non-zero, then doubled when
//    liveness_->has_double(x) holds.
unsigned allocation::num_bytes(ir::value *x) {
  auto *red = dynamic_cast<ir::reduce_inst*>(x);
  if(red){
    unsigned elt_bytes = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
    size_t axis = red->get_axis();
    ir::value *op = red->get_operand(0);
    // operand shape without the reduced axis
    auto kept_shapes = op->get_type()->get_tile_shapes();
    kept_shapes.erase(kept_shapes.begin() + axis);
    size_t num_elements = 1;
    for(auto dim: kept_shapes)
      num_elements *= dim;
    size_t depth = tiles_->hmma(x) ? static_cast<size_t>(tiles_->wpt(op, axis))
                                   : static_cast<size_t>(tiles_->mts(op, axis));
    return num_elements * elt_bytes * depth;
  }

  unsigned total = x->get_type()->get_primitive_size_in_bits() / 8;
  unsigned pad = is_ld_padded(x);
  if(pad > 0){
    // extent of the dimension selected by order[0]
    unsigned ld = x->get_type()->get_tile_shapes()[tiles_->order(x)[0]];
    total += (pad * total) / ld;
  }
  // double-buffered values take twice the space
  return liveness_->has_double(x) ? total * 2 : total;
}
|
||||
|
||||
void allocation::run(ir::module &mod) {
|
||||
using std::max;
|
||||
using std::min;
|
||||
typedef std::multimap<unsigned, segment> triples_map_type;
|
||||
|
||||
std::vector<ir::value *> I;
|
||||
std::vector<buffer_t> I;
|
||||
for(auto x: liveness_->intervals())
|
||||
I.push_back(x.first);
|
||||
std::vector<ir::value *> J = I;
|
||||
std::vector<buffer_t> J = I;
|
||||
|
||||
triples_map_type H;
|
||||
H.insert({0, segment{0, INT_MAX}});
|
||||
|
||||
std::vector<ir::value *> V;
|
||||
std::map<ir::value *, unsigned> starts;
|
||||
std::vector<buffer_t> V;
|
||||
std::map<buffer_t, unsigned> starts;
|
||||
while(!J.empty()){
|
||||
auto h_it = H.begin();
|
||||
unsigned w = h_it->first;
|
||||
segment xh = h_it->second;
|
||||
H.erase(h_it);
|
||||
auto j_it = std::find_if(J.begin(), J.end(), [&](ir::value *JJ){
|
||||
auto j_it = std::find_if(J.begin(), J.end(), [&](buffer_t JJ){
|
||||
segment xj = liveness_->get_interval(JJ);
|
||||
bool res = xj.intersect(xh);
|
||||
for(auto val: H)
|
||||
@@ -95,7 +44,7 @@ void allocation::run(ir::module &mod) {
|
||||
return res;
|
||||
});
|
||||
if(j_it != J.end()){
|
||||
unsigned size = num_bytes(*j_it);
|
||||
unsigned size = j_it->size;
|
||||
segment xj = liveness_->get_interval(*j_it);
|
||||
starts[*j_it] = w;
|
||||
H.insert({w + size, segment{max(xh.start, xj.start), min(xh.end, xj.end)}});
|
||||
@@ -109,14 +58,14 @@ void allocation::run(ir::module &mod) {
|
||||
}
|
||||
|
||||
// Build interference graph
|
||||
std::map<ir::value*, std::set<ir::value *>> interferences;
|
||||
for(ir::value *x: V)
|
||||
for(ir::value *y: V){
|
||||
if(x == y)
|
||||
std::map<buffer_t, std::set<buffer_t>> interferences;
|
||||
for(buffer_t x: V)
|
||||
for(buffer_t y: V){
|
||||
if(x.id == y.id)
|
||||
continue;
|
||||
unsigned X0 = starts[x], Y0 = starts[y];
|
||||
unsigned NX = num_bytes(x);
|
||||
unsigned NY = num_bytes(y);
|
||||
unsigned NX = x.size;
|
||||
unsigned NY = y.size;
|
||||
segment XS = {X0, X0 + NX};
|
||||
segment YS = {Y0, Y0 + NY};
|
||||
if(liveness_->get_interval(x).intersect(liveness_->get_interval(y))
|
||||
@@ -125,17 +74,17 @@ void allocation::run(ir::module &mod) {
|
||||
}
|
||||
|
||||
// Initialize colors
|
||||
std::map<ir::value *, int> colors;
|
||||
for(ir::value *X: V)
|
||||
colors[X] = (X==V[0])?0:-1;
|
||||
std::map<buffer_t, int> colors;
|
||||
for(buffer_t X: V)
|
||||
colors[X] = (X.id==V[0].id)?0:-1;
|
||||
|
||||
|
||||
// First-fit graph coloring
|
||||
std::vector<bool> available(V.size());
|
||||
for(ir::value *x: V){
|
||||
for(buffer_t x: V){
|
||||
// Non-neighboring colors are available
|
||||
std::fill(available.begin(), available.end(), true);
|
||||
for(ir::value *Y: interferences[x]){
|
||||
for(buffer_t Y: interferences[x]){
|
||||
int color = colors[Y];
|
||||
if(color >= 0)
|
||||
available[color] = false;
|
||||
@@ -146,21 +95,24 @@ void allocation::run(ir::module &mod) {
|
||||
}
|
||||
|
||||
// Finalize allocation
|
||||
for(ir::value *x: V){
|
||||
for(buffer_t x: V){
|
||||
unsigned Adj = 0;
|
||||
for(ir::value *y: interferences[x])
|
||||
Adj = std::max(Adj, starts[y] + num_bytes(y));
|
||||
offsets_[x] = starts[x] + colors[x] * Adj;
|
||||
if(liveness_->has_double(x)){
|
||||
auto info = liveness_->get_double(x);
|
||||
offsets_[info.latch] = offsets_[x] + num_bytes(x) / 2;
|
||||
for(buffer_t y: interferences[x])
|
||||
Adj = std::max<unsigned>(Adj, starts[y] + y.size);
|
||||
// create offsets
|
||||
for(ir::value *v: liveness_->get_values(x)){
|
||||
offsets_[v] = starts[x] + colors[x] * Adj;
|
||||
if(liveness_->has_double(v)){
|
||||
auto info = liveness_->get_double(v);
|
||||
offsets_[info.latch] = offsets_[v] + x.size / 2;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Save maximum size of induced memory space
|
||||
allocated_size_ = 0;
|
||||
for(auto &x: offsets_){
|
||||
allocated_size_ = std::max<size_t>(allocated_size_, x.second + num_bytes(x.first));
|
||||
allocated_size_ = std::max<size_t>(allocated_size_, x.second + liveness_->get_buffer(x.first).size);
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -1,6 +1,9 @@
|
||||
#include <iostream>
|
||||
#include <climits>
|
||||
#include <unordered_set>
|
||||
#include "triton/codegen/instructions.h"
|
||||
#include "triton/codegen/analysis/liveness.h"
|
||||
#include "triton/codegen/analysis/tiles.h"
|
||||
#include "triton/codegen/transform/cts.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/function.h"
|
||||
@@ -25,7 +28,7 @@ inline bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
|
||||
throw std::runtime_error("unreachable");
|
||||
}
|
||||
|
||||
inline void extract_double_bufferable(ir::instruction *i, std::map<ir::value*, double_buffer_info_t>& result) {
|
||||
void liveness::extract_double_bufferable(ir::instruction *i) {
|
||||
auto* phi = dynamic_cast<ir::phi_node*>(i);
|
||||
if(!phi || phi->get_num_incoming() != 2)
|
||||
return;
|
||||
@@ -42,65 +45,142 @@ inline void extract_double_bufferable(ir::instruction *i, std::map<ir::value*, d
|
||||
if(!i_0 || !i_1 || storage_info.at(i_0->get_id()).first != SHARED || storage_info.at(i_1->get_id()).first != SHARED)
|
||||
return;
|
||||
if(is_latch_1)
|
||||
result[value_0] = double_buffer_info_t{value_1, phi};
|
||||
double_[value_0] = double_buffer_info_t{value_1, phi};
|
||||
if(is_latch_0)
|
||||
result[value_1] = double_buffer_info_t{value_0, phi};
|
||||
double_[value_1] = double_buffer_info_t{value_0, phi};
|
||||
}
|
||||
|
||||
// Adds the edges contributed by instruction `i` to the buffer graph:
//  - if `i` has a double-buffering entry, `i` is linked with that entry's latch;
//  - if `i` is a transposition (INST_TRANS), `i` is linked with its first operand.
// Linking registers both values in nodes_ and stores the edge in both directions.
void liveness::make_graph(ir::instruction *i) {
  // registers a and b as nodes and connects them symmetrically
  auto link = [this](ir::value *a, ir::value *b){
    nodes_.insert(a);
    nodes_.insert(b);
    graph_[a].insert(b);
    graph_[b].insert(a);
  };

  auto dbl = double_.find(i);
  if(dbl != double_.end())
    link(i, dbl->second.latch);

  if(i->get_id() == ir::INST_TRANS)
    link(i, i->get_operand(0));
}
|
||||
|
||||
// connected components
|
||||
// Depth-first traversal that assigns `x`, and every node reachable from it through
// `graph`, to the buffer identified by `group_id`.
//
//   x        node to visit
//   nodes    nodes not yet assigned to a group; a node is removed when it is visited
//   graph    adjacency sets
//   group_id identifier stored in the buffer_t of every node of this component
//
// For each visited node, groups_[x] receives a buffer_t {group_id, num_bytes(x)}
// and x is appended to values_ under that buffer.
//
// NOTE(review): buffer_t is ordered by id only, so values_ keeps the key (and thus
// the size) built from the first node visited in the group, while groups_ stores a
// per-node size. Confirm that this is the intended size for the whole group.
void liveness::connected_components(node_t x, std::set<node_t> &nodes, graph_t &graph, unsigned group_id) {
  // make_graph stores every edge in both directions, so each neighbor leads back
  // to a node that was already handled. Such a node is no longer in `nodes`:
  // stop here instead of recording it in values_ a second time.
  auto pending = nodes.find(x);
  if(pending == nodes.end())
    return;
  nodes.erase(pending);

  buffer_t buffer{group_id, num_bytes(x)};
  groups_[x] = buffer;
  values_[buffer].push_back(x);

  for(const node_t &y: graph[x])
    connected_components(y, nodes, graph, group_id);
}
|
||||
|
||||
// Returns the padding associated with `x`. The rules are tried in this order:
//  1. a trans_inst whose first permutation entry is non-zero -> 4
//  2. tiles_->hmma(x) is HMMA_A_ROW or HMMA_B_COL             -> 16
//  3. tiles_->hmma(x) is HMMA_A_COL or HMMA_B_ROW             -> 8
//  4. a phi node -> the largest padding among its incoming values (recursive)
//  5. anything else                                           -> 0
unsigned liveness::is_ld_padded(ir::value *x) {
  auto *trans = dynamic_cast<ir::trans_inst*>(x);
  if(trans && trans->get_perm()[0]->get_value() != 0)
    return 4;

  auto order = tiles_->order(x);
  const bool is_col_major = (order[0] == 0);
  // NOTE(review): both arms of each selection below are the same value, so
  // is_col_major currently has no influence on the result. allocation::is_ld_padded
  // uses different arms (16/8 and 8/16) -- confirm which behaviour is intended.
  const auto kind = tiles_->hmma(x);
  if(kind == HMMA_A_ROW || kind == HMMA_B_COL)
    return is_col_major ? 16 : 16;
  if(kind == HMMA_A_COL || kind == HMMA_B_ROW)
    return is_col_major ? 8 : 8;

  auto *phi = dynamic_cast<ir::phi_node*>(x);
  if(!phi)
    return 0;
  // A phi takes the widest padding required by any of its incoming values.
  unsigned widest = 0;
  for(unsigned k = 0; k < phi->get_num_incoming(); k++)
    widest = std::max(widest, is_ld_padded(phi->get_incoming_value(k)));
  return widest;
}
|
||||
|
||||
// Computes the number of bytes attributed to `x`.
//  - reduce_inst: (scalar element size in bytes) * (number of elements of the
//    operand's tile once the reduced axis is dropped) * depth. depth comes from
//    tiles_->wpt(op, axis) when tiles_->hmma(x) is non-zero, and from
//    tiles_->mts(op, axis) otherwise.
//  - other values: primitive size of the type in bytes, plus pad * bytes / ld when
//    is_ld_padded(x) reports a padding, times two when has_double(x) is true.
unsigned liveness::num_bytes(ir::value *x) {
  if(auto *red = dynamic_cast<ir::reduce_inst*>(x)){
    unsigned scalar_bytes = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
    size_t axis = red->get_axis();
    ir::value *op = red->get_operand(0);
    // drop the reduced axis, then count the remaining elements
    auto remaining = op->get_type()->get_tile_shapes();
    remaining.erase(remaining.begin() + axis);
    size_t count = 1;
    for(auto extent: remaining)
      count *= extent;
    size_t depth;
    if(!tiles_->hmma(x))
      depth = tiles_->mts(op, axis);
    else
      depth = tiles_->wpt(op, axis);
    return count * scalar_bytes * depth;
  }

  unsigned bytes = x->get_type()->get_primitive_size_in_bits() / 8;
  const unsigned pad = is_ld_padded(x);
  if(pad > 0){
    // extent of the dimension selected by order[0]
    const unsigned ld = x->get_type()->get_tile_shapes()[tiles_->order(x)[0]];
    bytes += (pad * bytes) / ld;
  }
  // double-buffered values take twice the space
  if(has_double(x))
    bytes = bytes * 2;
  return bytes;
}
|
||||
|
||||
// Entry point
|
||||
void liveness::run(ir::module &mod) {
|
||||
double_.clear();
|
||||
indices_.clear();
|
||||
indices.clear();
|
||||
intervals_.clear();
|
||||
parents_.clear();
|
||||
|
||||
// set of pair of values that can be double-buffered
|
||||
// Create the set of pairs of values that can be double-buffered
|
||||
ir::for_each_instruction(mod, [this](ir::instruction* i) {
|
||||
extract_double_bufferable(i, this->double_);
|
||||
this->extract_double_bufferable(i);
|
||||
});
|
||||
|
||||
// Create buffer dependency graph
|
||||
ir::for_each_instruction(mod, [this](ir::instruction* i) {
|
||||
this->make_graph(i);
|
||||
});
|
||||
|
||||
// connected components
|
||||
unsigned group_id = 0;
|
||||
while(!nodes_.empty()){
|
||||
connected_components(*nodes_.begin(), nodes_, graph_, group_id++);
|
||||
}
|
||||
|
||||
// Assigns index to each instruction
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
// Assigns index to each instruction
|
||||
slot_index index = 0;
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *instr: block->get_inst_list()){
|
||||
index += 1;
|
||||
indices_.insert({instr, index});
|
||||
}
|
||||
// Liveness analysis
|
||||
// Creates live intervals
|
||||
for(auto i: indices_){
|
||||
ir::value *v = i.first;
|
||||
ir::instruction* instr = dynamic_cast<ir::instruction*>(v);
|
||||
if(!instr)
|
||||
continue;
|
||||
if(storage_info.at(instr->get_id()).first != SHARED)
|
||||
continue;
|
||||
unsigned start = i.second;
|
||||
unsigned end = start;
|
||||
for(ir::value *u: v->get_users()){
|
||||
start = std::min(start, indices_.at(u));
|
||||
end = std::max(end, indices_.at(u));
|
||||
}
|
||||
intervals_[v] = segment{start, end};
|
||||
}
|
||||
// Double-Buffering
|
||||
// Arrays are live throughout the end of the loop
|
||||
auto it = intervals_.begin();
|
||||
while(it != intervals_.end()) {
|
||||
ir::value *x = it->first;
|
||||
auto dit = double_.find(x);
|
||||
if(dit != double_.end()) {
|
||||
ir::value *y = dit->second.latch;
|
||||
unsigned start = intervals_[x].start;
|
||||
unsigned end = intervals_[y].end;
|
||||
intervals_[x] = segment{start, end};
|
||||
intervals_.erase(y);
|
||||
}
|
||||
it++;
|
||||
indices.insert({instr, index});
|
||||
}
|
||||
}
|
||||
|
||||
for(auto x: values_) {
|
||||
// users
|
||||
std::set<ir::value*> values;
|
||||
for(ir::value *v: x.second){
|
||||
values.insert(v);
|
||||
for(ir::user *u: v->get_users())
|
||||
values.insert(u);
|
||||
}
|
||||
// compute intervals
|
||||
unsigned start = INT32_MAX;
|
||||
unsigned end = 0;
|
||||
for(ir::value *u: values){
|
||||
start = std::min(start, indices.at(u));
|
||||
end = std::max(end, indices.at(u));
|
||||
}
|
||||
intervals_[x.first] = segment{start, end};
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
@@ -725,7 +725,7 @@ void selection::create_shared_tile(ir::value *v, IRBuilder<> &builder, Value *sh
|
||||
return;
|
||||
auto order = tiles_->order(v);
|
||||
auto shapes = v->get_type()->get_tile_shapes();
|
||||
unsigned pad = alloc_->is_ld_padded(v);
|
||||
unsigned pad = liveness_->is_ld_padded(v);
|
||||
if(pad > 0)
|
||||
shapes[order[0]] += pad;
|
||||
Type* ty = llvm_type(v->get_type()->get_scalar_ty(), builder.getContext());
|
||||
@@ -1040,15 +1040,13 @@ void selection::lower_copy_to_shared(ir::copy_to_shared_inst *x, LLVMContext &ct
|
||||
}
|
||||
|
||||
void selection::lower_trans(ir::trans_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
|
||||
shared_tile* result = (shared_tile*)tmap_.at(x);
|
||||
distributed_tile* in = (distributed_tile*)tmap_.at(x->get_operand(0));
|
||||
auto perm = x->get_perm();
|
||||
in->for_each([&](indices_t idx){
|
||||
indices_t out_idx(idx.size());
|
||||
for(size_t i = 0; i < idx.size(); i++)
|
||||
out_idx[i] = idx[perm[i]->get_value()];
|
||||
result->set_value(out_idx, in->get_value(idx));
|
||||
});
|
||||
shared_tile* in = (shared_tile*)tmap_.at(x->get_operand(0));
|
||||
auto in_order = in->get_order();
|
||||
std::vector<int> order;
|
||||
for(auto p: x->get_perm())
|
||||
order.push_back(in_order[p->get_value()]);
|
||||
shared_tile* out = new shared_tile(in->get_ty(), in->get_shapes(), order, in->get_pointer(), builder, in->get_offset());
|
||||
tmap_[x] = out;
|
||||
}
|
||||
|
||||
void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn, IRBuilder<> &builder,
|
||||
@@ -1555,7 +1553,7 @@ void selection::run(ir::module &src, Module &dst) {
|
||||
}
|
||||
else {
|
||||
unsigned num_bytes = phi->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
|
||||
offset->addIncoming(dst_builder.getInt32(alloc_->num_bytes(phi)/(num_bytes)), llvm_inc_block);
|
||||
offset->addIncoming(dst_builder.getInt32(liveness_->num_bytes(phi)/(num_bytes)), llvm_inc_block);
|
||||
}
|
||||
ptr->addIncoming(inc_shared->get_pointer(), llvm_inc_block);
|
||||
}
|
||||
|
@@ -36,9 +36,9 @@ void membar::add_reference(ir::value *v, interval_vec_t &res){
|
||||
auto *i = dynamic_cast<ir::instruction*>(v);
|
||||
if(!i)
|
||||
return;
|
||||
if(storage_info.at(i->get_id()).first == SHARED){
|
||||
if(alloc_->has_offset(v)){
|
||||
unsigned offset = alloc_->offset(v);
|
||||
unsigned num_bytes = alloc_->num_bytes(v);
|
||||
unsigned num_bytes = liveness_->num_bytes(v);
|
||||
res.push_back(interval_t(offset, offset + num_bytes));
|
||||
}
|
||||
}
|
||||
@@ -97,8 +97,10 @@ std::pair<membar::interval_vec_t,
|
||||
bool read_after_write = intersect(new_written_to, read);
|
||||
bool write_after_read = intersect(new_read_from, written);
|
||||
// double buffering: write and phi-node read won't intersect
|
||||
if(safe_war.find(i) != safe_war.end())
|
||||
if(safe_war.find(i) != safe_war.end()){
|
||||
write_after_read = false;
|
||||
read_after_write = false;
|
||||
}
|
||||
// record hazards
|
||||
if(read_after_write || write_after_read) {
|
||||
insert_loc.insert(i);
|
||||
@@ -122,7 +124,12 @@ void membar::run(ir::module &mod) {
|
||||
auto info = liveness_->get_double(i);
|
||||
safe_war.insert(i);
|
||||
safe_war.insert(info.latch);
|
||||
auto *trans = dynamic_cast<ir::trans_inst*>(info.latch);
|
||||
if(trans)
|
||||
safe_war.insert(trans->get_operand(0));
|
||||
}
|
||||
if(i->get_id() == ir::INST_TRANS)
|
||||
safe_war.insert(i);
|
||||
});
|
||||
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
@@ -152,9 +159,8 @@ void membar::run(ir::module &mod) {
|
||||
done = (n_inserted_im1 == n_inserted_i);
|
||||
n_inserted_im1 = n_inserted_i;
|
||||
}while(!done);
|
||||
for(ir::instruction* i: insert_locs){
|
||||
for(ir::instruction* i: insert_locs)
|
||||
insert_barrier(i, builder);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -84,70 +84,6 @@ bool peephole::rewrite_trans_phi(ir::instruction* value, ir::builder& builder) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Rewrites `dot` into a dot instruction that carries explicit transposition flags.
//
// Each of A and B is handled the same way:
//  - if its flag (trans_a / trans_b) is already set, the operand is cast to a
//    trans_inst and replaced by that instruction's first operand;
//  - otherwise, if the operand is a trans_inst, a new transposition of its first
//    operand is created with the first two permutation entries swapped, all uses of
//    the old transposition are redirected to it, and the flag is set.
//
// Returns false, without creating a dot, when neither flag ends up set. Otherwise
// inserts dot_inst::create(AA, BB, D, trans_a, trans_b), redirects all uses of
// `dot` to it and returns true.
bool peephole::rewrite_dot_hmma(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b,
                                ir::value *A, ir::value *B, ir::value *D){
  // Applies the per-operand rule above; may set `transposed`.
  auto fold_operand = [&builder](ir::value *operand, bool &transposed) -> ir::value* {
    if(transposed)
      return ((ir::trans_inst*)operand)->get_operand(0);
    auto *T = dynamic_cast<ir::trans_inst*>(operand);
    if(!T)
      return operand;
    std::vector<ir::constant_int*> perm(T->get_perm());
    std::swap(perm[0], perm[1]);
    ir::value *swapped = builder.create_trans(T->get_operand(0), perm);
    T->replace_all_uses_with(swapped);
    transposed = true;
    return swapped;
  };

  // A is processed before B, as the rewrites have side effects on the IR.
  ir::value *AA = fold_operand(A, trans_a);
  ir::value *BB = fold_operand(B, trans_b);

  if(!(trans_a || trans_b))
    return false;

  ir::instruction *dot_atbt = builder.insert(ir::dot_inst::create(AA, BB, D, trans_a, trans_b));
  dot->replace_all_uses_with(dot_atbt);
  return true;
}
|
||||
|
||||
// Replaces `dot` by an NT dot (dot_inst::create_nt) and always returns true.
//
//  - trans_b set:   B is cast to a trans_inst and its first operand is used as the
//                   second operand of the new dot.
//  - trans_b unset: a transposition of B is created with the identity permutation
//                   whose first two entries are swapped, and that transposition is
//                   used as the second operand.
//
// `trans_a` is not read by this function.
bool peephole::rewrite_dot_fp32(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b,
                                ir::value *A, ir::value *B, ir::value *D){
  ir::value *BB = nullptr;
  if(trans_b){
    // dot(op(a), trans(b))
    BB = ((ir::trans_inst*)B)->get_operand(0);
  }
  else{
    // dot(op(a), b): build the permutation (1, 0, 2, ...)
    size_t rank = B->get_type()->get_tile_shapes().size();
    std::vector<ir::constant_int*> perm(rank);
    ir::type *int32_ty = ir::type::get_int32_ty(B->get_type()->get_context());
    for(size_t d = 0; d < rank; d++)
      perm[d] = ir::constant_int::get(int32_ty, d);
    std::swap(perm[0], perm[1]);
    // replace NN -> NT (trans)
    BB = builder.create_trans(B, perm);
  }
  ir::instruction *NT = builder.insert(ir::dot_inst::create_nt(A, BB, D));
  dot->replace_all_uses_with(NT);
  return true;
}
|
||||
|
||||
bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
|
||||
// dot(a, b, 0) + c -> dot(a, b, c)
|
||||
auto add = dynamic_cast<ir::binary_operator*>(value);
|
||||
@@ -176,26 +112,6 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
|
||||
add->replace_all_uses_with(new_dot);
|
||||
return true;
|
||||
}
|
||||
|
||||
// dot(a, b, c)
|
||||
auto dot = dynamic_cast<ir::dot_inst*>(value);
|
||||
if(!dot)
|
||||
return false;
|
||||
builder.set_insert_point(value);
|
||||
ir::value *A = dot->get_operand(0);
|
||||
ir::value *B = dot->get_operand(1);
|
||||
ir::value *D = dot->get_operand(2);
|
||||
bool trans_a = is_trans(A);
|
||||
bool trans_b = is_trans(B);
|
||||
// only consider dot-nn
|
||||
if(dot->is_a_trans() || dot->is_b_trans())
|
||||
return false;
|
||||
// hmma
|
||||
if(is_hmma(dot)){
|
||||
return rewrite_dot_hmma(dot, builder, trans_a, trans_b, A, B, D);
|
||||
}
|
||||
else
|
||||
return rewrite_dot_fp32(dot, builder, trans_a, trans_b, A, B, D);
|
||||
}
|
||||
|
||||
bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){
|
||||
|
@@ -202,11 +202,11 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
// create passes
|
||||
codegen::transform::cts cts;
|
||||
codegen::analysis::align align;
|
||||
codegen::analysis::liveness liveness;
|
||||
codegen::analysis::axes axes;
|
||||
codegen::analysis::layout layouts(&axes);
|
||||
codegen::transform::coalesce coalesce(&align, &layouts);
|
||||
codegen::analysis::tiles tiles(opt.num_warps, &align, &axes, &layouts);
|
||||
codegen::analysis::liveness liveness(&tiles);
|
||||
codegen::analysis::allocation allocation(&liveness, &tiles);
|
||||
codegen::transform::membar barriers(&liveness, &allocation);
|
||||
codegen::transform::dce dce;
|
||||
@@ -235,12 +235,10 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
return std::unique_ptr<driver::module>();
|
||||
barriers.run(module);
|
||||
dce.run(module);
|
||||
dce.run(module);
|
||||
axes.run(module);
|
||||
layouts.run(module);
|
||||
align.run(module);
|
||||
tiles.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
selection.run(module, *llvm);
|
||||
// return binary
|
||||
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
|
||||
|
@@ -79,10 +79,8 @@ int main() {
|
||||
// shapes to benchmark
|
||||
typedef std::tuple<bool, bool, int, int, int> config_t;
|
||||
std::vector<config_t> configs;
|
||||
for(auto x: std::vector<std::array<bool, 2>>{{false, false},
|
||||
{false, true},
|
||||
{true, false},
|
||||
{true, true}}){
|
||||
for(auto x: std::vector<std::array<bool, 2>>{{false, true},
|
||||
{true, false}, {true, true}}){
|
||||
std::vector<config_t> tmp = {
|
||||
config_t{x[0], x[1], 4096, 4096, 4096}
|
||||
// config_t{x[0], x[1], 16, 2048, 2048},
|
||||
|
@@ -4,15 +4,15 @@ namespace src {
|
||||
R"(
|
||||
#if AT == 1
|
||||
#define USEA ^a
|
||||
#define STRIDE_AK 1
|
||||
#define STRIDE_AM lda
|
||||
#define STRIDE_AK lda
|
||||
#define STRIDE_AM 1
|
||||
#define BROADCAST_AK :, newaxis
|
||||
#define BROADCAST_AM newaxis, :
|
||||
#define SHAPE_A TK, TM
|
||||
#else
|
||||
#define USEA a
|
||||
#define STRIDE_AK lda
|
||||
#define STRIDE_AM 1
|
||||
#define STRIDE_AK 1
|
||||
#define STRIDE_AM lda
|
||||
#define BROADCAST_AK newaxis, :
|
||||
#define BROADCAST_AM :, newaxis
|
||||
#define SHAPE_A TM, TK
|
||||
@@ -20,15 +20,15 @@ R"(
|
||||
|
||||
#if BT == 1
|
||||
#define USEB ^b
|
||||
#define STRIDE_BK ldb
|
||||
#define STRIDE_BN 1
|
||||
#define STRIDE_BK 1
|
||||
#define STRIDE_BN ldb
|
||||
#define BROADCAST_BK newaxis, :
|
||||
#define BROADCAST_BN :, newaxis
|
||||
#define SHAPE_B TN, TK
|
||||
#else
|
||||
#define USEB b
|
||||
#define STRIDE_BK 1
|
||||
#define STRIDE_BN ldb
|
||||
#define STRIDE_BK ldb
|
||||
#define STRIDE_BN 1
|
||||
#define BROADCAST_BK :, newaxis
|
||||
#define BROADCAST_BN newaxis, :
|
||||
#define SHAPE_B TK, TN
|
||||
|
@@ -139,8 +139,8 @@ int main() {
|
||||
// shapes to benchmark
|
||||
typedef std::tuple<dtype_t, bool, bool, int, int, int, int, int, int, int> config_t;
|
||||
std::vector<config_t> configs;
|
||||
for(bool AT: std::array<bool, 2>{false, true})
|
||||
for(bool BT: std::array<bool, 2>{false, true})
|
||||
for(bool AT: std::array<bool, 2>{false})
|
||||
for(bool BT: std::array<bool, 2>{false})
|
||||
for(int TM: std::vector<int>{32, 64})
|
||||
for(int TN: std::vector<int>{32, 64})
|
||||
for(int TK: std::vector<int>{16, 32})
|
||||
|
Reference in New Issue
Block a user