[GENERAL] Cleaned polymorphic structure of layouts analysis pass

This commit is contained in:
Philippe Tillet
2020-01-20 15:15:32 -05:00
parent 382ca2c745
commit 78b98fb7cf
17 changed files with 500 additions and 480 deletions

View File

@@ -12,57 +12,15 @@ namespace triton{
namespace codegen{
namespace analysis{
/* -------------------------------- *
* Helper Functions *
* -------------------------------- */
// constructor
layout::layout(analysis::axes *axes, analysis::align *align, size_t num_warps)
: axes_(axes), align_(align), num_warps_(num_warps) { }
// get group id
unsigned layout::layout_of(ir::value *value) const
{ return groups_.at(value); }
// get values
const std::vector<ir::value*>& layout::values_of(unsigned id) const
{ return values_.at(id); }
// get number of groups
size_t layout::num_layouts() const
{ return values_.size(); }
// connect two values
void layout::connect(ir::value *x, ir::value *y) {
if(x == y)
return;
if(!x->get_type()->is_tile_ty())
return;
if(!y->get_type()->is_tile_ty())
return;
std::vector<int> x_axes = axes_->get(x);
std::vector<int> y_axes = axes_->get(y);
std::set<int> sx_axes(x_axes.begin(), x_axes.end());
std::set<int> sy_axes(y_axes.begin(), y_axes.end());
std::set<int> common;
std::set_intersection(sx_axes.begin(), sx_axes.end(),
sy_axes.begin(), sy_axes.end(),
std::inserter(common, common.begin()));
graph_.add_edge(x, x);
graph_.add_edge(y, y);
if(!common.empty())
graph_.add_edge(x, y);
inline unsigned clamp(unsigned x, unsigned lo, unsigned hi) {
return std::min(std::max(x, lo), hi);
}
// make graph
void layout::make_graph(ir::instruction *i) {
for(ir::value* opx: i->ops())
for(ir::value* opy: i->ops()){
connect(i, opx);
connect(opx, opy);
}
}
// hmma
bool is_hmma_c(ir::value *v){
inline bool is_hmma_c(ir::value *v){
bool result = false;
if(auto *x = dynamic_cast<ir::dot_inst*>(v)){
ir::value *a = x->get_operand(0);
@@ -75,23 +33,7 @@ bool is_hmma_c(ir::value *v){
return result;
}
layout_t* layout::get(size_t id) {
return layouts_.at(id);
}
layout_t* layout::get(ir::value *v) {
return layouts_.at(groups_.at(v));
}
std::map<size_t, layout_t*>& layout::get_all() {
return layouts_;
}
size_t layout::tmp(ir::instruction* i) {
return tmp_.at(i);
}
void extract_io_use(ir::value *v, std::set<ir::value*>& result) {
inline void extract_io_use(ir::value *v, std::set<ir::value*>& result) {
for(ir::user* u: v->get_users()){
auto i = dynamic_cast<ir::io_inst*>(u);
if(i && i->get_pointer_operand() == v)
@@ -99,7 +41,7 @@ void extract_io_use(ir::value *v, std::set<ir::value*>& result) {
}
}
void extract_dot_use(ir::value *v, ir::value*& result, size_t n) {
inline void extract_dot_use(ir::value *v, ir::value*& result, size_t n) {
for(ir::user* u: v->get_users()){
auto i = dynamic_cast<ir::dot_inst*>(u);
if(i && i->get_operand(n) == v)
@@ -107,7 +49,7 @@ void extract_dot_use(ir::value *v, ir::value*& result, size_t n) {
}
}
void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n) {
inline void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n) {
for(ir::user* u: v->get_users()){
auto i = dynamic_cast<ir::dot_inst*>(u);
if(i && is_hmma_c(i) && i->get_operand(n) == v)
@@ -116,7 +58,6 @@ void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n) {
}
inline bool is_trans(ir::value *v) {
if(dynamic_cast<ir::trans_inst *>(v)) {
return true;
@@ -131,104 +72,103 @@ inline bool is_trans(ir::value *v) {
}
void layout_visitor::visit_layout(layout_t *layout) {
/* -------------------------------- *
* Layout Visitor *
* -------------------------------- */
void layout_visitor::visit_layout(data_layout *layout) {
layout->accept(this);
}
layout_t::layout_t(layout_type_t _type,
const std::vector<int> &_axes,
const std::vector<unsigned> &_shapes,
const std::vector<ir::value *> &_values, ir::type *_ty,
analysis::align* align): type(_type), axes(_axes), shapes(_shapes), values(_values), ty(_ty) {
/* -------------------------------- *
* Base Data Layout *
* -------------------------------- */
data_layout::data_layout(id_t id,
const std::vector<int> &axes,
const std::vector<unsigned> &shape,
const std::vector<ir::value *> &values,
analysis::align* align): id_(id), axes_(axes), shape_(shape), values_(values) {
// io pointer
std::set<ir::value*> ptr;
for(ir::value* v: values)
for(ir::value* v: values_)
extract_io_use(v, ptr);
order.resize(axes.size());
std::iota(order.begin(), order.end(), 0);
order_.resize(axes_.size());
std::iota(order_.begin(), order_.end(), 0);
auto largest = std::max_element(ptr.begin(), ptr.end(), [&](ir::value *x, ir::value *y){
return x->get_type()->get_tile_rank() < y->get_type()->get_tile_rank();
});
if(*largest){
auto max_contiguous = align->contiguous(*largest);
std::sort(order.begin(), order.end(), [&](unsigned a, unsigned b) {
std::sort(order_.begin(), order_.end(), [&](unsigned a, unsigned b) {
return max_contiguous[a] > max_contiguous[b];
});
}
}
// downcast
layout_hmma_884_t* layout_t::to_hmma884() {
assert(type == HMMA_884);
return static_cast<layout_hmma_884_t*>(this);
size_t data_layout::find_axis(int to_find) const {
auto it = std::find(axes_.begin(), axes_.end(), to_find);
return std::distance(axes_.begin(), it);
}
layout_scanline_t* layout_t::to_scanline() {
assert(type == SCANLINE);
return static_cast<layout_scanline_t*>(this);
}
layout_shared_t* layout_t::to_shared() {
assert(type == SHARED);
return static_cast<layout_shared_t*>(this);
}
/* -------------------------------- *
* MMA Layout *
* -------------------------------- */
inline unsigned clamp(unsigned x, unsigned lo, unsigned hi) {
return std::min(std::max(x, lo), hi);
}
layout_hmma_884_t::layout_hmma_884_t(size_t num_warps,
const std::vector<int>& _axes,
const std::vector<unsigned>& _shapes,
const std::vector<ir::value *> &values, ir::type *_ty,
analysis::align* align): layout_t(HMMA_884, _axes, _shapes, values, _ty, align) {
unsigned shape_0 = shapes[0];
unsigned shape_1 = shapes[1];
mma884_layout::mma884_layout(size_t num_warps,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values,
analysis::align* align): data_layout(HMMA_884, axes, shape, values, align) {
/* fragments per warp */
// try to make things as square as possible to maximize data re-use
fpw = {1, 1, 1};
fpw_ = {1, 1, 1};
std::vector<int> fpw_nm1;
unsigned num_fragments = std::min<unsigned>((shape_0/8)*(shape_1/8), 4);
unsigned num_fragments = std::min<unsigned>((shape_[0]/8)*(shape_[1]/8), 4);
do {
fpw_nm1 = fpw;
if(fpw[0]*fpw[1] < num_fragments)
fpw[0] = clamp(fpw[0]*2, 1, shape_0 / 8);
if(fpw[0]*fpw[1] < num_fragments)
fpw[1] = clamp(fpw[1]*2, 1, shape_1 / 8);
}while(fpw_nm1 != fpw);
fpw_nm1 = fpw_;
if(fpw_[0]*fpw_[1] < num_fragments)
fpw_[0] = clamp(fpw_[0]*2, 1, shape_[0] / 8);
if(fpw_[0]*fpw_[1] < num_fragments)
fpw_[1] = clamp(fpw_[1]*2, 1, shape_[1] / 8);
}while(fpw_nm1 != fpw_);
/* warps per tile */
// try to make things as square as possible to maximize data re-use
wpt = {1, 1, 1};
wpt_ = {1, 1, 1};
std::vector<int> wpt_nm1;
do{
wpt_nm1 = wpt;
if(wpt[0] * wpt[1] * wpt[2] < num_warps)
wpt[0] = clamp(wpt[0]*2, 1, shape_0 / (fpw[0]*8));
if(wpt[0] * wpt[1] * wpt[2] < num_warps)
wpt[1] = clamp(wpt[1]*2, 1, shape_1 / (fpw[1]*8));
}while(wpt_nm1 != wpt);
wpt_nm1 = wpt_;
if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
wpt_[0] = clamp(wpt_[0]*2, 1, shape_[0] / (fpw_[0]*8));
if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / (fpw_[1]*8));
}while(wpt_nm1 != wpt_);
/* sanity check */
unsigned effective_num_warps = 1;
for(size_t d = 0; d < shapes.size(); d++)
effective_num_warps *= wpt[d];
for(size_t d = 0; d < shape.size(); d++)
effective_num_warps *= wpt_[d];
if(num_warps != effective_num_warps)
throw std::runtime_error("cannot create a kernel with this amount of warps");
}
/* -------------------------------- *
* Scanline Layout *
* -------------------------------- */
layout_scanline_t::layout_scanline_t(size_t num_warps,
const std::vector<int>& _axes,
const std::vector<unsigned>& _shapes,
const std::vector<ir::value *> &values, ir::type *_ty,
analysis::align* align): layout_t(SCANLINE, _axes, _shapes, values, _ty, align){
unsigned size = std::accumulate(shapes.begin(), shapes.end(), 1, std::multiplies<int>());
scanline_layout::scanline_layout(size_t num_warps,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values,
analysis::align* align): data_layout(SCANLINE, axes, shape, values, align){
unsigned size = std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies<int>());
unsigned num_threads = num_warps * 32;
nts.resize(shapes.size());
mts.resize(shapes.size());
nts_.resize(shape_.size());
mts_.resize(shape_.size());
bool is_dot = std::any_of(values.begin(), values.end(),
[&](ir::value* v) { return dynamic_cast<ir::dot_inst*>(v); });
@@ -238,34 +178,39 @@ layout_scanline_t::layout_scanline_t(size_t num_warps,
if(auto *st = dynamic_cast<ir::store_inst*>(usr))
ptr = st->get_pointer_operand();
unsigned i = order[0];
unsigned i = order_[0];
int contiguous = 4;
if(ptr)
contiguous = std::min<int>(align->contiguous(ptr)[i], 4);
nts[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shapes[i]));
mts[i] = clamp(num_threads, 1, shapes[i] / nts[i]);
size /= shapes[i];
num_threads /= mts[i];
nts_[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shape_[i]));
mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
size /= shape_[i];
num_threads /= mts_[i];
if(is_dot)
nts[order[1]] = clamp(size / num_threads, 1, std::min<int>(4, shapes[order[1]]));
for(size_t d = 1; d < shapes.size(); d++){
i = order[d];
nts_[order_[1]] = clamp(size / num_threads, 1, std::min<int>(4, shape_[order_[1]]));
for(size_t d = 1; d < shape_.size(); d++){
i = order_[d];
if(d > 1 || !is_dot)
nts[i] = 1;
mts[i] = clamp(num_threads, 1, shapes[i] / nts[i]);
num_threads = num_threads / mts[i];
nts_[i] = 1;
mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
num_threads = num_threads / mts_[i];
}
/* sanity check */
unsigned effective_num_threads = 1;
for(size_t d = 0; d < shapes.size(); d++)
effective_num_threads *= mts[d];
for(size_t d = 0; d < shape_.size(); d++)
effective_num_threads *= mts_[d];
if(num_warps * 32 != effective_num_threads)
throw std::runtime_error("cannot create a kernel with this amount of warps");
}
inline bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
/* -------------------------------- *
* Shared Layout *
* -------------------------------- */
bool shared_layout::is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
if(phi->get_parent() != terminator->get_parent())
return false;
if(auto *br = dynamic_cast<ir::cond_branch_inst*>(terminator))
@@ -278,7 +223,7 @@ inline bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
}
void extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_t>& res) {
void shared_layout::extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_t>& res) {
auto* phi = dynamic_cast<ir::phi_node*>(v);
if(!phi || phi->get_num_incoming() != 2)
return;
@@ -303,22 +248,22 @@ void extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_
}
layout_shared_t::layout_shared_t(const layout_t *arg,
const std::vector<int>& _axes,
const std::vector<unsigned>& _shapes,
shared_layout::shared_layout(const data_layout *arg,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values,
ir::type *ty,
analysis::align* align): layout_t(SHARED, _axes, _shapes, values, ty, align) {
analysis::align* align): data_layout(SHARED, axes, shape, values, align), ty_(ty) {
size = 0;
size_ = 0;
// double-buffering
for(ir::value *v: values)
extract_double_bufferable(v, double_buffer);
extract_double_bufferable(v, double_buffer_);
// order
std::vector<int> arg_order = arg ? arg->order : std::vector<int>{0};
order = arg_order;
std::vector<int> arg_order = arg ? arg->get_order() : std::vector<int>{0};
order_ = arg_order;
ir::value* dot_a = nullptr;
ir::value* dot_b = nullptr;
@@ -330,48 +275,84 @@ layout_shared_t::layout_shared_t(const layout_t *arg,
extract_hmma_dot_use(v, hmma_dot_a, 0);
extract_hmma_dot_use(v, hmma_dot_b, 1);
}
// non-mma ordering
std::vector<int> col = {0, 1};
std::vector<int> row = {1, 0};
for(size_t s = 2; s < shapes.size(); s++){
for(size_t s = 2; s < get_rank(); s++){
col.push_back(s);
row.push_back(s);
}
bool is_nonhmma_dot_a = dot_a && !hmma_dot_a;
bool is_nonhmma_dot_b = dot_b && !hmma_dot_b;
if(is_nonhmma_dot_a)
order = is_trans(dot_a) ? row : col;
order_ = is_trans(dot_a) ? row : col;
else if(is_nonhmma_dot_b)
order = is_trans(dot_b) ? col : row;
// else
// order = row;
order_ = is_trans(dot_b) ? col : row;
// padding
size_t pad = 0;
if(hmma_dot_a){
bool row = is_trans(hmma_dot_a) ^ order[0] != 0;
pad = 24 - shapes[row ? 0 : 1] % 32;
bool row = is_trans(hmma_dot_a) ^ order_[0] != 0;
pad = 24 - shape_[row ? 0 : 1] % 32;
}
else if(hmma_dot_b){
bool row = is_trans(hmma_dot_b) ^ order[0] != 0;
pad = 24 - shapes[row ? 1 : 0] % 32;
bool row = is_trans(hmma_dot_b) ^ order_[0] != 0;
pad = 24 - shape_[row ? 1 : 0] % 32;
}
else if(order != arg_order) {
else if(order_ != arg_order) {
pad = 4;
}
shapes[order[0]] += pad;
shape_[order_[0]] += pad;
// size
size = ty->get_primitive_size_in_bits() / 8;
for(auto s: shapes)
size *= s;
if(double_buffer)
size *= 2;
size_ = ty_->get_primitive_size_in_bits() / 8;
for(auto s: shape_)
size_ *= s;
if(double_buffer_)
size_ *= 2;
}
// layout factory method
void layout::create(size_t id, const std::vector<ir::value*>& values) {
/* -------------------------------- *
* ---- Layouts Inference Pass ---- *
* -------------------------------- */
layouts::layouts(analysis::axes *axes, analysis::align *align, size_t num_warps)
: axes_(axes), align_(align), num_warps_(num_warps) { }
void layouts::connect(ir::value *x, ir::value *y) {
if(x == y)
return;
if(!x->get_type()->is_tile_ty())
return;
if(!y->get_type()->is_tile_ty())
return;
std::vector<int> x_axes = axes_->get(x);
std::vector<int> y_axes = axes_->get(y);
std::set<int> sx_axes(x_axes.begin(), x_axes.end());
std::set<int> sy_axes(y_axes.begin(), y_axes.end());
std::set<int> common;
std::set_intersection(sx_axes.begin(), sx_axes.end(),
sy_axes.begin(), sy_axes.end(),
std::inserter(common, common.begin()));
graph_.add_edge(x, x);
graph_.add_edge(y, y);
if(!common.empty())
graph_.add_edge(x, y);
}
void layouts::make_graph(ir::instruction *i) {
for(ir::value* opx: i->ops())
for(ir::value* opy: i->ops()){
connect(i, opx);
connect(opx, opy);
}
}
void layouts::create(size_t id, const std::vector<ir::value*>& values) {
auto it_hmma_c = std::find_if(values.begin(), values.end(), &is_hmma_c);
auto cmp = [](ir::value* x, ir::value *y) {
return x->get_type()->get_tile_ranks1() <
@@ -387,18 +368,18 @@ void layout::create(size_t id, const std::vector<ir::value*>& values) {
});
// type
if(it_hmma_c != values.end())
layouts_[id] = new layout_hmma_884_t(num_warps_, axes, shapes, values, largest->get_type()->get_scalar_ty(), align_);
layouts_[id] = new mma884_layout(num_warps_, axes, shapes, values, align_);
else if(it_cts != values.end()){
ir::copy_to_shared_inst *cts = (ir::copy_to_shared_inst*)*it_cts;
ir::value *arg = cts->get_operand(0);
create(groups_.at(arg), values_.at(groups_.at(arg)));
layouts_[id] = new layout_shared_t(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_);
layouts_[id] = new shared_layout(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_);
}
else
layouts_[id] = new layout_scanline_t(num_warps_, axes, shapes, values, largest->get_type()->get_scalar_ty(), align_);
layouts_[id] = new scanline_layout(num_warps_, axes, shapes, values, align_);
}
void layout::run(ir::module &mod) {
void layouts::run(ir::module &mod) {
// make graph
graph_.clear();
ir::for_each_instruction(mod, [this](ir::instruction* i) {
@@ -422,35 +403,35 @@ void layout::run(ir::module &mod) {
// shape
auto shapes = arg->get_type()->get_tile_shapes();
unsigned shape_ax = shapes[axis];
layout_scanline_t *layout = get(arg)->to_scanline();
unsigned per_thread = layout->nts[axis];
scanline_layout *layout = get(arg)->to_scanline();
unsigned per_thread = layout->nts(axis);
unsigned depth = shape_ax / per_thread;
shapes[axis] = depth;
// create layout
layouts_[id] = new layout_shared_t(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_);
layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_);
tmp_[red] = id;
}
if(auto *recoalasce = dynamic_cast<ir::recoalesce_inst*>(i)){
ir::value *val = recoalasce->get_operand(0);
layout_t* in_layout = get(val);
layout_t* out_layout = get(i);
if(in_layout->type != HMMA_884)
mma884_layout* in_layout = get(val)->to_mma884();
scanline_layout* out_layout = get(i)->to_scanline();
if(!in_layout || !out_layout)
return;
id++;
ir::type::tile_shapes_t in_shape = val->get_type()->get_tile_shapes();
ir::type::tile_shapes_t shape(in_shape.size());
size_t ld = out_layout->order[0];
size_t ld = out_layout->get_order(0);
shape[ld] = in_shape[ld];
for(size_t k = 0; k < in_shape.size(); k++)
if(k != ld)
shape[k] = 4*in_layout->to_hmma884()->fpw[k]*in_layout->to_hmma884()->wpt[k];
shape[k] = 4*in_layout->to_mma884()->fpw(k)*in_layout->to_mma884()->wpt(k);
// create layout
layouts_[id] = new layout_shared_t(out_layout, axes_->get(val), shape, {recoalasce}, val->get_type()->get_scalar_ty(), align_);
layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {recoalasce}, val->get_type()->get_scalar_ty(), align_);
tmp_[recoalasce] = id;
}
if(auto *atom = dynamic_cast<ir::atomic_cas_inst*>(i)){
id++;
layouts_[id] = new layout_shared_t(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), align_);
layouts_[id] = new shared_layout(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), align_);
tmp_[atom] = id;
}
});