[codegen] more cleaning

This commit is contained in:
Philippe Tillet
2019-10-11 23:40:27 -04:00
parent ee3803b577
commit 7d77f34db0
7 changed files with 109 additions and 200 deletions

View File

@@ -11,8 +11,10 @@ namespace triton{
namespace ir{
class value;
class type;
class module;
class instruction;
class phi_node;
}
namespace codegen{
@@ -27,6 +29,13 @@ enum layout_type_t {
SHARED
};
struct double_buffer_info_t {
ir::value* first;
ir::value* latch;
ir::phi_node* phi;
};
struct layout_t {
layout_t(layout_type_t _type,
const std::vector<int>& _axes,
@@ -41,6 +50,9 @@ struct layout_t {
std::vector<int> order;
size_t id;
size_t size;
std::shared_ptr<double_buffer_info_t> double_buffer;
ir::type *ty;
size_t pad;
std::vector<int> mts;
std::vector<int> nts;
std::vector<int> fpw;
@@ -70,6 +82,7 @@ struct layout_shared_t: public layout_t {
const std::vector<int>& _axes,
const std::vector<unsigned>& _shapes,
const std::vector<ir::value *> &values,
ir::type *ty,
size_t _id,
analysis::align* align);
};

View File

@@ -4,6 +4,7 @@
#include <map>
#include <set>
#include <vector>
#include "triton/codegen/analysis/layout.h"
#include "triton/tools/graph.h"
namespace triton{
@@ -38,60 +39,25 @@ struct segment {
}
};
struct double_buffer_info_t {
ir::value* latch;
ir::phi_node* phi;
};
class liveness {
private:
typedef std::map<ir::value*, slot_index> indices_map_t;
typedef std::map<layout_t*, segment> intervals_map_t;
typedef std::map<ir::value*, bool> has_storage_map_t;
typedef ir::value* node_t;
typedef std::map <node_t, std::set<node_t>> graph_t;
public:
// Intervals iterators
using iterator = intervals_map_t::iterator;
using const_iterator = intervals_map_t::const_iterator;
private:
void extract_double_bufferable(ir::instruction *i);
void extract_buffers(ir::instruction *i);
void get_parents(ir::instruction *i, std::vector<ir::value *>& res);
void make_graph(ir::instruction *i);
bool do_pad(ir::value *x);
public:
liveness(layout *l): layouts_(l){ }
// padding
unsigned get_pad(ir::value *v) const { return pad_.at(v); }
// buffer size
unsigned num_bytes(ir::value *x);
// accessors
const intervals_map_t& intervals() const { return intervals_; }
segment get_interval(layout_t* v) const { return intervals_.at(v); }
// double-buffering
bool has_double(ir::value *x) const { return double_.find(x) != double_.end(); }
double_buffer_info_t get_double(ir::value *x) const { return double_.at(x); }
// run
void run(ir::module &mod);
private:
// analysis
layout *layouts_;
// stuff
has_storage_map_t has_dedicated_storage_;
indices_map_t indices;
intervals_map_t intervals_;
std::map<ir::value*, double_buffer_info_t> double_;
std::map<ir::value*, size_t> pad_;
};
}

View File

@@ -102,10 +102,10 @@ void allocation::run(ir::module &mod) {
// create offsets
for(ir::value *v: x->values){
offsets_[v] = starts[x] + colors[x] * Adj;
if(liveness_->has_double(v)){
auto info = liveness_->get_double(v);
offsets_[info.latch] = offsets_[v] + x->size / 2;
}
}
if(x->double_buffer){
auto info = *x->double_buffer;
offsets_[info.latch] = offsets_[info.first] + x->size / 2;
}
}

View File

@@ -4,6 +4,7 @@
#include "triton/codegen/analysis/axes.h"
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/instructions.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
#include "triton/ir/utils.h"
@@ -187,6 +188,20 @@ layout_hmma_884_t::layout_hmma_884_t(size_t num_warps,
throw std::runtime_error("cannot create a kernel with this amount of warps");
}
inline bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
if(phi->get_parent() != terminator->get_parent())
return false;
if(auto *br = dynamic_cast<ir::cond_branch_inst*>(terminator))
return br->get_true_dest() == phi->get_parent()
|| br->get_false_dest() == phi->get_parent();
else if(dynamic_cast<ir::uncond_branch_inst*>(terminator))
return false;
else
throw std::runtime_error("unreachable");
}
layout_scanline_t::layout_scanline_t(size_t num_warps,
const std::vector<int>& _axes,
const std::vector<unsigned>& _shapes,
@@ -215,17 +230,49 @@ layout_scanline_t::layout_scanline_t(size_t num_warps,
throw std::runtime_error("cannot create a kernel with this amount of warps");
}
void extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_t>& res) {
auto* phi = dynamic_cast<ir::phi_node*>(v);
if(!phi || phi->get_num_incoming() != 2)
return;
ir::basic_block *block_0 = phi->get_incoming_block(0);
ir::basic_block *block_1 = phi->get_incoming_block(1);
ir::instruction *terminator_0 = block_0->get_inst_list().back();
ir::instruction *terminator_1 = block_1->get_inst_list().back();
bool is_latch_0 = is_loop_latch(phi, terminator_0);
bool is_latch_1 = is_loop_latch(phi, terminator_1);
ir::value *value_0 = phi->get_incoming_value(0);
ir::value *value_1 = phi->get_incoming_value(1);
ir::instruction *i_0 = dynamic_cast<ir::instruction*>(value_0);
ir::instruction *i_1 = dynamic_cast<ir::instruction*>(value_1);
if(!i_0 || !i_1 ||
storage_info.at(i_0->get_id()).first != codegen::SHARED ||
storage_info.at(i_1->get_id()).first != codegen::SHARED)
return;
if(is_latch_1)
res.reset(new double_buffer_info_t{value_0, value_1, phi});
if(is_latch_0)
res.reset(new double_buffer_info_t{value_1, value_0, phi});
}
layout_shared_t::layout_shared_t(const layout_t *arg,
const std::vector<int>& _axes,
const std::vector<unsigned>& _shapes,
const std::vector<ir::value *> &values,
ir::type *ty,
size_t _id,
analysis::align* align): layout_t(SHARED, _axes, _shapes, values, _id, align) {
this->ty = ty;
size = 0;
// double-buffering
for(ir::value *v: values)
extract_double_bufferable(v, double_buffer);
// order
if(arg->type == SCANLINE)
order = arg->order;
ir::value* dot_a = nullptr;
ir::value* dot_b = nullptr;
ir::value* hmma_dot_a = nullptr;
@@ -238,10 +285,35 @@ layout_shared_t::layout_shared_t(const layout_t *arg,
}
std::vector<int> col = {0, 1};
std::vector<int> row = {1, 0};
if(dot_a && !hmma_dot_a)
bool is_nonhmma_dot_a = dot_a && !hmma_dot_a;
bool is_nonhmma_dot_b = dot_b && !hmma_dot_b;
if(is_nonhmma_dot_a)
order = is_trans(dot_a) ? row : col;
if(dot_b && !hmma_dot_b)
if(is_nonhmma_dot_b)
order = is_trans(dot_b) ? col : row;
// padding
pad = 0;
if(hmma_dot_a){
bool row = is_trans(hmma_dot_a) ^ order[0] == 1;
pad = 24 - shapes[row ? 0: 1] % 32;
}
else if(hmma_dot_b){
bool row = is_trans(hmma_dot_b) ^ order[0] == 1;
pad = 24 - shapes[row ? 1 : 0] % 32;
}
else if(order != arg->order) {
pad = 16;
}
// size
auto shape = this->shapes;
shape[order[0]] += pad;
size = ty->get_primitive_size_in_bits() / 8;
for(auto s: shape)
size *= s;
if(double_buffer)
size *= 2;
}
void layout::create(size_t id, const std::vector<ir::value*>& values) {
@@ -263,7 +335,7 @@ void layout::create(size_t id, const std::vector<ir::value*>& values) {
ir::copy_to_shared_inst *cts = (ir::copy_to_shared_inst*)*it_cts;
ir::value *arg = cts->get_operand(0);
create(groups_.at(arg), values_.at(groups_.at(arg)));
layouts_[id] = new layout_shared_t(get(arg), axes, shapes, values, id, align_);
layouts_[id] = new layout_shared_t(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), id, align_);
}
else
layouts_[id] = new layout_scanline_t(num_warps_, axes, shapes, values, id, align_);

View File

@@ -16,150 +16,12 @@ namespace triton{
namespace codegen{
namespace analysis{
inline bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
if(phi->get_parent() != terminator->get_parent())
return false;
if(auto *br = dynamic_cast<ir::cond_branch_inst*>(terminator))
return br->get_true_dest() == phi->get_parent()
|| br->get_false_dest() == phi->get_parent();
else if(dynamic_cast<ir::uncond_branch_inst*>(terminator))
return false;
else
throw std::runtime_error("unreachable");
}
void liveness::extract_double_bufferable(ir::instruction *i) {
auto* phi = dynamic_cast<ir::phi_node*>(i);
if(!phi || phi->get_num_incoming() != 2)
return;
ir::basic_block *block_0 = phi->get_incoming_block(0);
ir::basic_block *block_1 = phi->get_incoming_block(1);
ir::instruction *terminator_0 = block_0->get_inst_list().back();
ir::instruction *terminator_1 = block_1->get_inst_list().back();
bool is_latch_0 = is_loop_latch(phi, terminator_0);
bool is_latch_1 = is_loop_latch(phi, terminator_1);
ir::value *value_0 = phi->get_incoming_value(0);
ir::value *value_1 = phi->get_incoming_value(1);
ir::instruction *i_0 = dynamic_cast<ir::instruction*>(value_0);
ir::instruction *i_1 = dynamic_cast<ir::instruction*>(value_1);
if(!i_0 || !i_1 || storage_info.at(i_0->get_id()).first != codegen::SHARED || storage_info.at(i_1->get_id()).first != codegen::SHARED)
return;
if(is_latch_1)
double_[value_0] = double_buffer_info_t{value_1, phi};
if(is_latch_0)
double_[value_1] = double_buffer_info_t{value_0, phi};
}
// connected components
bool is_trans(ir::value *v) {
if(dynamic_cast<ir::trans_inst *>(v)) {
return true;
}
if(auto *phi = dynamic_cast<ir::instruction *>(v)) {
bool result = true;
for(ir::value *op: phi->ops())
result = result && is_trans(op);
return result;
}
return false;
}
bool liveness::do_pad(ir::value *x) {
// alignment for matrix product
if(auto* dot = dynamic_cast<ir::dot_inst*>(x)) {
// a
ir::value *a = dot->get_operand(0);
ir::value *b = dot->get_operand(1);
size_t a_previous = pad_[a];
size_t b_previous = pad_[b];
auto a_order = layouts_->get(a)->order;
auto b_order = layouts_->get(b)->order;
bool a_row = is_trans(a) ^ (a_order[0] == 1);
bool b_row = is_trans(b) ^ (b_order[0] == 1);
auto a_shapes = a->get_type()->get_tile_shapes();
auto b_shapes = b->get_type()->get_tile_shapes();
pad_[a] = std::max<int>(pad_[a], (24 - a_shapes[a_row ? 0 : 1]) % 32);
pad_[b] = std::max<int>(pad_[b], (24 - b_shapes[b_row ? 1 : 0]) % 32);
return a_previous != pad_[a] || b_previous != pad_[b];
}
// padding for trans
if(auto* trans = dynamic_cast<ir::trans_inst*>(x)) {
ir::value *op = trans->get_operand(0);
size_t previous = pad_[op];
pad_[op] = std::max(pad_[op], pad_[x]);
return previous != pad_[op];
}
// padding for copy to shared
if(auto* cts = dynamic_cast<ir::copy_to_shared_inst*>(x)) {
auto cts_order = layouts_->get(cts)->order;
ir::value *arg = cts->get_operand(0);
auto arg_order = layouts_->get(arg)->order;
size_t previous = pad_[cts];
if(cts_order != arg_order)
pad_[cts] = std::max<int>(pad_[cts], 4);
return pad_[cts] != previous;
}
// padding for phi-nodes
if(auto* phi = dynamic_cast<ir::phi_node*>(x)) {
bool has_changed = false;
for(unsigned i = 0; i < phi->get_num_incoming(); i++){
ir::value* op = phi->get_operand(i);
size_t previous = pad_[op];
pad_[op] = std::max(pad_[op], pad_[phi]);
has_changed |= previous != pad_[op];
}
return has_changed;
}
// default -- no padding
size_t previous = pad_[x];
pad_[x] = std::max<int>(previous, 0);
return pad_[x] != previous;
}
unsigned liveness::num_bytes(ir::value *x) {
unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8;
unsigned pad = pad_.at(x);
if(pad > 0){
unsigned ld = x->get_type()->get_tile_shapes()[layouts_->get(x)->order[0]];
num_bytes += pad * num_bytes / ld;
}
if(has_double(x))
num_bytes *= 2;
return num_bytes;
}
// Entry point
void liveness::run(ir::module &mod) {
double_.clear();
indices.clear();
pad_.clear();
intervals_.clear();
// Create set of pair of values that can be double-buffered
ir::for_each_instruction(mod, [this](ir::instruction* i) {
this->extract_double_bufferable(i);
});
// Padding information
bool has_changed;
do{
has_changed = false;
ir::for_each_value(mod, [this, &has_changed](ir::value* v){
has_changed |= this->do_pad(v);
});
}while(has_changed);
// connected components
for(auto &x: layouts_->get_all()) {
layout_t*& layout = x.second;
if(layout->type != SHARED)
continue;
for(ir::value *v: layout->values)
layout->size = std::max<int>(layout->size, num_bytes(v));
}
// Assigns index to each instruction
for(ir::function *fn: mod.get_function_list()){

View File

@@ -710,15 +710,15 @@ void selection::create_shared_tile(ir::value *v, IRBuilder<> &builder, Value *sh
return;
auto order = layouts_->get(v)->order;
auto shapes = v->get_type()->get_tile_shapes();
unsigned pad = liveness_->get_pad(v);
unsigned pad = layouts_->get(v)->pad;
if(pad > 0)
shapes[order[0]] += pad;
Type* ty = llvm_type(v->get_type()->get_scalar_ty(), builder.getContext());
// shared copy
PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr->getType()->getPointerAddressSpace());
// double-buffered
if(liveness_->has_double(v)) {
auto info = liveness_->get_double(v);
if(layouts_->get(v)->double_buffer) {
auto info = *layouts_->get(v)->double_buffer;
ir::phi_node *phi = info.phi;
BasicBlock *parent = (BasicBlock*)vmap_[phi->get_parent()];
if(parent->empty())
@@ -1532,10 +1532,9 @@ void selection::run(ir::module &src, Module &dst) {
}
}
// finalize double-buffering
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *inst: block->get_inst_list()) {
if(liveness_->has_double(inst)) {
auto info = liveness_->get_double(inst);
for(const auto& x: layouts_->get_all()) {
if(x.second->double_buffer) {
auto info = *x.second->double_buffer;
ir::phi_node *phi = info.phi;
PHINode *ptr = (PHINode*)((shared_tile*)tmap_.at(phi))->get_pointer();
PHINode *offset = (PHINode*)((shared_tile*)tmap_.at(phi))->get_offset();
@@ -1550,8 +1549,8 @@ void selection::run(ir::module &src, Module &dst) {
offset->addIncoming(next_offset, llvm_inc_block);
}
else {
unsigned num_bytes = inst->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
offset->addIncoming(dst_builder.getInt32(layouts_->get(inst)->size / (2*num_bytes)), llvm_inc_block);
unsigned num_bytes = x.second->ty->get_primitive_size_in_bits() / 8;
offset->addIncoming(dst_builder.getInt32(x.second->size / (2*num_bytes)), llvm_inc_block);
}
ptr->addIncoming(inc_shared->get_pointer(), llvm_inc_block);
}

View File

@@ -120,18 +120,15 @@ void membar::run(ir::module &mod) {
// shared-memory copies. These can be read from and written to
// without needing synchronization
std::set<ir::value*> safe_war;
ir::for_each_instruction(mod, [&](ir::instruction* i){
if(liveness_->has_double(i)){
auto info = liveness_->get_double(i);
safe_war.insert(i);
for(const auto& x: layouts_->get_all()){
if(x.second->double_buffer){
auto info = *x.second->double_buffer;
safe_war.insert(info.first);
safe_war.insert(info.latch);
auto *trans = dynamic_cast<ir::trans_inst*>(info.latch);
if(trans)
safe_war.insert(trans->get_operand(0));
}
if(i->get_id() == ir::INST_TRANS)
safe_war.insert(i);
});
}
for(ir::function *fn: mod.get_function_list()){
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);