[codegen] more cleaning

This commit is contained in:
Philippe Tillet
2019-10-09 15:05:44 -04:00
parent 10ab94d1c5
commit 9bc6df4fd1
10 changed files with 226 additions and 252 deletions

View File

@@ -19,6 +19,20 @@ namespace codegen{
namespace analysis{ namespace analysis{
class axes; class axes;
class align;
// Kind of layout assigned to a layout group by layout::run.
enum layout_type_t {
// chosen when at least one value of the group satisfies is_hmma_c
// (a dot instruction whose two operands both have a half scalar type)
HMMA_884,
// chosen for every other group
SCANLINE
};
// Per-group layout record; one instance is stored for each layout group id.
struct layout_t {
// HMMA_884 or SCANLINE
layout_type_t type;
// representative value of the group: layout::run selects the member whose
// tile shape has the most dimensions greater than 1
ir::value *i;
// result of the axes analysis for the representative value
std::vector<int> axes;
// tile shapes of the representative value
std::vector<unsigned> shapes;
// permutation of dimension indices; consumers read order[0] as the leading dimension
std::vector<int> order;
};
class layout { class layout {
typedef ir::value* node_t; typedef ir::value* node_t;
@@ -31,19 +45,24 @@ private:
public: public:
// constructor // constructor
layout(analysis::axes *axes); layout(analysis::axes *axes, analysis::align *align);
// accessors // accessors
unsigned layout_of(ir::value *value) const; unsigned layout_of(ir::value *value) const;
const std::vector<ir::value*>& values_of(unsigned id) const; const std::vector<ir::value*>& values_of(unsigned id) const;
size_t num_layouts() const; size_t num_layouts() const;
layout_t get(ir::value *v) const;
const std::map<size_t, layout_t>& get_all() const;
// execution // execution
void run(ir::module &mod); void run(ir::module &mod);
private: private:
analysis::axes* axes_; analysis::axes* axes_;
analysis::align* align_;
tools::graph<ir::value*> graph_; tools::graph<ir::value*> graph_;
std::map<ir::value*, size_t> groups_; std::map<ir::value*, size_t> groups_;
std::map<size_t, std::vector<ir::value*>> values_; std::map<size_t, std::vector<ir::value*>> values_;
std::map<size_t, layout_t> layouts_;
}; };
} }

View File

@@ -22,6 +22,7 @@ namespace analysis{
typedef unsigned slot_index; typedef unsigned slot_index;
class tiles; class tiles;
class layout;
struct segment { struct segment {
slot_index start; slot_index start;
@@ -72,7 +73,7 @@ private:
public: public:
liveness(tiles *t): tiles_(t){ } liveness(tiles *t, layout *l): tiles_(t), layouts_(l){ }
// padding // padding
unsigned get_pad(ir::value *v) const { return pad_.at(v); } unsigned get_pad(ir::value *v) const { return pad_.at(v); }
// buffer size // buffer size
@@ -92,6 +93,7 @@ public:
private: private:
// analysis // analysis
tiles *tiles_; tiles *tiles_;
layout *layouts_;
// stuff // stuff
has_storage_map_t has_dedicated_storage_; has_storage_map_t has_dedicated_storage_;
indices_map_t indices; indices_map_t indices;

View File

@@ -5,6 +5,7 @@
#include <set> #include <set>
#include <vector> #include <vector>
#include <memory> #include <memory>
#include "triton/codegen/analysis/layout.h"
namespace triton{ namespace triton{
@@ -25,28 +26,22 @@ class axes;
class layout; class layout;
class align; class align;
// Layout kind recorded per layout group by tiles::run.
enum layout_t {
// default kind; keeps the value 0, which callers that test the enum as a
// boolean (e.g. `if(tiles_->hmma(x))`) rely on
SCANLINE,
// assigned when a value of the group satisfies is_hmma_c
HMMA_C
};
class tiles { class tiles {
typedef std::map<ir::value*, std::map<int, int>> param_map_t; typedef std::map<ir::value*, std::map<int, int>> param_map_t;
private: private:
void init_hmma_tile(ir::value *i); void init_hmma_tile(const layout_t& layout);
void init_scanline_tile(ir::value *i); void init_scanline_tile(const layout_t& layout);
bool is_trans(ir::value *i); bool is_trans(ir::value *i);
public: public:
tiles(size_t num_warps, analysis::align* align, analysis::axes* axes, analysis::layout* layout); tiles(size_t num_warps, analysis::align* align, analysis::axes* axes, analysis::layout* layout);
void run(ir::module &mod); void run(ir::module &mod);
layout_t hmma(ir::value *value);
int mts(ir::value *value, unsigned ax); int mts(ir::value *value, unsigned ax);
int nts(ir::value *value, unsigned ax); int nts(ir::value *value, unsigned ax);
int fpw(ir::value *value, unsigned ax); int fpw(ir::value *value, unsigned ax);
int wpt(ir::value *value, unsigned ax); int wpt(ir::value *value, unsigned ax);
std::vector<int> order(ir::value *v);
const std::map<int, ir::value*>& largest();
private: private:
// dependencies // dependencies
@@ -56,9 +51,6 @@ private:
// number of warps // number of warps
size_t num_warps_; size_t num_warps_;
// tile properties // tile properties
std::map<int, ir::value*> largest_;
std::map<int, std::vector<int>> order_;
std::map<int, layout_t> hmma_;
std::map<int, int> fpw_; std::map<int, int> fpw_;
std::map<int, int> wpt_; std::map<int, int> wpt_;
std::map<int, int> mts_; std::map<int, int> mts_;

View File

@@ -5,6 +5,7 @@
#include "triton/ir/module.h" #include "triton/ir/module.h"
#include "triton/ir/function.h" #include "triton/ir/function.h"
#include "triton/ir/type.h" #include "triton/ir/type.h"
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/transform/cts.h" #include "triton/codegen/transform/cts.h"
@@ -171,9 +172,9 @@ private:
void create_shared_tile(ir::value *v, Builder &builder, Value *sh_mem_ptr); void create_shared_tile(ir::value *v, Builder &builder, Value *sh_mem_ptr);
void create_distributed_tile(ir::value *v, Builder &builder); void create_distributed_tile(ir::value *v, Builder &builder);
void create_tile(ir::value *v, Builder &builder, std::set<ir::value *> &seen, Value *sh_mem_ptr); void create_tile(ir::value *v, Builder &builder, std::set<ir::value *> &seen, Value *sh_mem_ptr);
void init_strided_scan_axes(ir::value *i, Builder &builder, Value *u_thread_id, Value *u_warp_id); void init_strided_scan_axes(const analysis::layout_t& layout, Builder &builder, Value *u_thread_id, Value *u_warp_id);
void init_hmma_axes(ir::value *i, Builder &builder, Value *u_thread_id, Value *u_warp_id); void init_hmma_axes(const analysis::layout_t& layout, Builder &builder, Value *u_thread_id, Value *u_warp_id);
void init_axes(ir::value *i, Builder &builder, Value *u_thread_id, Value *u_warp_id); void init_axes(const analysis::layout_t& layout, Builder &builder, Value *u_thread_id, Value *u_warp_id);
void init_layouts(ir::function *fn, Builder &builder, Value *sh_mem_ptr); void init_layouts(ir::function *fn, Builder &builder, Value *sh_mem_ptr);
// lower scalar instruction // lower scalar instruction

View File

@@ -1,6 +1,8 @@
#include <algorithm> #include <algorithm>
#include <iostream> #include <iostream>
#include <numeric>
#include "triton/codegen/analysis/axes.h" #include "triton/codegen/analysis/axes.h"
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/analysis/layout.h" #include "triton/codegen/analysis/layout.h"
#include "triton/ir/function.h" #include "triton/ir/function.h"
#include "triton/ir/module.h" #include "triton/ir/module.h"
@@ -12,8 +14,8 @@ namespace analysis{
// constructor // constructor
layout::layout(analysis::axes *axes) layout::layout(analysis::axes *axes, analysis::align *align)
: axes_(axes) { } : axes_(axes), align_(align) { }
// get group id // get group id
unsigned layout::layout_of(ir::value *value) const unsigned layout::layout_of(ir::value *value) const
@@ -56,6 +58,51 @@ void layout::make_graph(ir::instruction *i) {
} }
} }
// hmma
// Tells whether `v` is a dot instruction whose two operands both have a
// half scalar type. Any value that is not a dot instruction yields false.
bool is_hmma_c(ir::value *v){
  auto *dot = dynamic_cast<ir::dot_inst*>(v);
  if(dot == nullptr)
    return false;
  // fetch both operand types up front, as the original did
  ir::type *lhs_ty = dot->get_operand(0)->get_type();
  ir::type *rhs_ty = dot->get_operand(1)->get_type();
  if(!lhs_ty->get_scalar_ty()->is_half_ty())
    return false;
  return rhs_ty->get_scalar_ty()->is_half_ty();
}
// Returns a copy of the layout record of the group that `v` belongs to.
// Both lookups use std::map::at, so an unknown value or group throws
// std::out_of_range.
layout_t layout::get(ir::value *v) const {
  const size_t group_id = groups_.at(v);
  return layouts_.at(group_id);
}
// Read-only access to every layout record, keyed by group id.
// The reference points at the member map of this object.
const std::map<size_t, layout_t>& layout::get_all() const {
return layouts_;
}
// Adds to `result` every user of `v` that is an io instruction and uses
// `v` as its pointer operand. Existing entries of `result` are kept.
void extract_io_use(ir::value *v, std::set<ir::io_inst*>& result) {
  for(ir::user* user: v->get_users()){
    auto *io = dynamic_cast<ir::io_inst*>(user);
    // skip users that are not io instructions
    if(io == nullptr)
      continue;
    // skip io instructions that use `v` in another role
    if(io->get_pointer_operand() != v)
      continue;
    result.insert(io);
  }
}
// Recursively decides whether `v` is "transposed":
//  - a trans instruction is;
//  - any other instruction is only if all of its operands are
//    (so an instruction without operands counts as transposed);
//  - a value that is not an instruction is not.
// NOTE(review): recursion follows operands with no visited set; confirm
// callers never pass a value whose operand chain is cyclic.
inline bool is_trans(ir::value *v) {
  if(dynamic_cast<ir::trans_inst *>(v) != nullptr)
    return true;
  auto *inst = dynamic_cast<ir::instruction *>(v);
  if(inst == nullptr)
    return false;
  // stop at the first operand that is not transposed
  for(ir::value *operand: inst->ops()) {
    if(!is_trans(operand))
      return false;
  }
  return true;
}
void layout::run(ir::module &mod) { void layout::run(ir::module &mod) {
// make graph // make graph
graph_.clear(); graph_.clear();
@@ -64,6 +111,82 @@ void layout::run(ir::module &mod) {
}); });
// connected components // connected components
graph_.connected_components(&values_, &groups_); graph_.connected_components(&values_, &groups_);
// create layouts
for(const auto& x: values_) {
bool hmma_c = std::any_of(x.second.begin(), x.second.end(), &is_hmma_c);
layouts_[x.first].type = hmma_c ? HMMA_884 : SCANLINE;
}
/* ---- TO CLEAN ---- */
size_t num_groups = num_layouts();
// helpers
auto rank = [this](ir::value* v) {
int ret = 0;
for(int s: v->get_type()->get_tile_shapes())
ret += s > 1;
return ret;
};
// find out which value is the largest in each group
for(const auto& x: values_) {
auto cmp = [&rank](ir::value* x, ir::value *y) { return rank(x) < rank(y); };
ir::value *largest = *std::max_element(x.second.begin(), x.second.end(), cmp);
layouts_[x.first].axes = axes_->get(largest);
layouts_[x.first].i = largest;
layouts_[x.first].shapes = largest->get_type()->get_tile_shapes();
}
// find out the layout ordering of a group
for(size_t i = 0; i < num_groups; i++){
std::set<ir::io_inst*> io;
for(ir::value* v: values_of(i))
extract_io_use(v, io);
auto cmp = [&rank](ir::io_inst* x, ir::io_inst *y) {
return rank(x->get_pointer_operand()) < rank(y->get_pointer_operand());
};
auto it = std::max_element(io.begin(), io.end(), cmp);
std::vector<int> order(layouts_[i].axes.size());
std::iota(order.begin(), order.end(), 0);
if(it != io.end()) {
auto max_contiguous = align_->contiguous((*it)->get_pointer_operand());
std::sort(order.begin(), order.end(), [&](unsigned a, unsigned b) {
return max_contiguous[a] > max_contiguous[b]; }
);
}
layouts_[i].order = order;
}
// matrix multiplication optimizations
for(size_t i = 0; i < num_groups; i++){
std::vector<ir::dot_inst*> dots;
for(ir::value* v: values_of(i))
if(auto *x = dynamic_cast<ir::dot_inst*>(v))
dots.push_back(x);
for(ir::dot_inst* dot: dots){
ir::value* a = dot->get_operand(0);
ir::value* b = dot->get_operand(1);
if(get(dot).type == HMMA_884){
auto a_val = values_of(layout_of(a));
auto b_val = values_of(layout_of(b));
for(ir::value *v: a_val)
if(auto *cts = dynamic_cast<ir::copy_to_shared_inst*>(v))
layouts_[layout_of(a)].order = layouts_[layout_of(cts->get_operand(0))].order;
for(ir::value *v: b_val)
if(auto *cts = dynamic_cast<ir::copy_to_shared_inst*>(v))
layouts_[layout_of(b)].order = layouts_[layout_of(cts->get_operand(0))].order;
}
else{
std::vector<int> col = {0, 1};
std::vector<int> row = {1, 0};
layouts_[layout_of(a)].order = is_trans(a) ? row : col;
layouts_[layout_of(b)].order = is_trans(b) ? col : row;
}
}
}
} }
} }

View File

@@ -4,6 +4,7 @@
#include "triton/codegen/instructions.h" #include "triton/codegen/instructions.h"
#include "triton/codegen/analysis/liveness.h" #include "triton/codegen/analysis/liveness.h"
#include "triton/codegen/analysis/tiles.h" #include "triton/codegen/analysis/tiles.h"
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/transform/cts.h" #include "triton/codegen/transform/cts.h"
#include "triton/ir/basic_block.h" #include "triton/ir/basic_block.h"
#include "triton/ir/function.h" #include "triton/ir/function.h"
@@ -89,8 +90,8 @@ bool liveness::do_pad(ir::value *x) {
ir::value *b = dot->get_operand(1); ir::value *b = dot->get_operand(1);
size_t a_previous = pad_[a]; size_t a_previous = pad_[a];
size_t b_previous = pad_[b]; size_t b_previous = pad_[b];
auto a_order = tiles_->order(a); auto a_order = layouts_->get(a).order;
auto b_order = tiles_->order(b); auto b_order = layouts_->get(b).order;
bool a_row = is_trans(a) ^ (a_order[0] == 1); bool a_row = is_trans(a) ^ (a_order[0] == 1);
bool b_row = is_trans(b) ^ (b_order[0] == 1); bool b_row = is_trans(b) ^ (b_order[0] == 1);
auto a_shapes = a->get_type()->get_tile_shapes(); auto a_shapes = a->get_type()->get_tile_shapes();
@@ -108,9 +109,9 @@ bool liveness::do_pad(ir::value *x) {
} }
// padding for copy to shared // padding for copy to shared
if(auto* cts = dynamic_cast<ir::copy_to_shared_inst*>(x)) { if(auto* cts = dynamic_cast<ir::copy_to_shared_inst*>(x)) {
auto cts_order = tiles_->order(cts); auto cts_order = layouts_->get(cts).order;
ir::value *arg = cts->get_operand(0); ir::value *arg = cts->get_operand(0);
auto arg_order = tiles_->order(arg); auto arg_order = layouts_->get(arg).order;
size_t previous = pad_[cts]; size_t previous = pad_[cts];
if(cts_order != arg_order) if(cts_order != arg_order)
pad_[cts] = std::max<int>(pad_[cts], 4); pad_[cts] = std::max<int>(pad_[cts], 4);
@@ -144,7 +145,7 @@ unsigned liveness::num_bytes(ir::value *x) {
for(auto x: shapes) for(auto x: shapes)
num_elements *= x; num_elements *= x;
size_t depth; size_t depth;
if(tiles_->hmma(x)) if(layouts_->get(x).type == HMMA_884)
depth = tiles_->wpt(op, axis); depth = tiles_->wpt(op, axis);
else else
depth = tiles_->mts(op, axis); depth = tiles_->mts(op, axis);
@@ -153,7 +154,7 @@ unsigned liveness::num_bytes(ir::value *x) {
unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8; unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8;
unsigned pad = pad_.at(x); unsigned pad = pad_.at(x);
if(pad > 0){ if(pad > 0){
unsigned ld = x->get_type()->get_tile_shapes()[tiles_->order(x)[0]]; unsigned ld = x->get_type()->get_tile_shapes()[layouts_->get(x).order[0]];
num_bytes += pad * num_bytes / ld; num_bytes += pad * num_bytes / ld;
} }
if(has_double(x)) if(has_double(x))

View File

@@ -23,59 +23,7 @@ tiles::tiles(size_t num_warps, analysis::align *align, analysis::axes *axes, ana
num_warps_(num_warps), align_(align), axes_(axes), layout_(layout) num_warps_(num_warps), align_(align), axes_(axes), layout_(layout)
{ } { }
// True only for a dot instruction whose operand 0 and operand 1 both have
// a half scalar type; false for every other value.
bool is_hmma_c(ir::value *v){
  auto *dot = dynamic_cast<ir::dot_inst*>(v);
  if(!dot)
    return false;
  ir::value *lhs = dot->get_operand(0);
  ir::value *rhs = dot->get_operand(1);
  ir::type *lhs_ty = lhs->get_type();
  ir::type *rhs_ty = rhs->get_type();
  const bool lhs_is_half = lhs_ty->get_scalar_ty()->is_half_ty();
  // keep the short-circuit: the second operand is only inspected when the first is half
  return lhs_is_half && rhs_ty->get_scalar_ty()->is_half_ty();
}
// Returns true when `v` is operand 0 of at least one user that satisfies
// is_hmma_c, false otherwise. Despite the name, no row/column ordering is
// inspected.
bool is_hmma_a_col(ir::value* v) {
  for(ir::user *u: v->get_users()){
    if(!is_hmma_c(u))
      continue;
    // is_hmma_c only accepts dot instructions, so the downcast is safe
    auto *dot = static_cast<ir::dot_inst*>(u);
    if(v == dot->get_operand(0))
      return true;
  }
  // the original fell off the end here, which is undefined behaviour for
  // a function returning bool
  return false;
}
// Returns true when `v` is operand 0 of at least one user that satisfies
// is_hmma_c, false otherwise. Despite the name, no row/column ordering is
// inspected.
bool is_hmma_a_row(ir::value* v) {
  for(ir::user *u: v->get_users()){
    if(!is_hmma_c(u))
      continue;
    // is_hmma_c only accepts dot instructions, so the downcast is safe
    auto *dot = static_cast<ir::dot_inst*>(u);
    if(v == dot->get_operand(0))
      return true;
  }
  // the original fell off the end here, which is undefined behaviour for
  // a function returning bool
  return false;
}
// Returns true when `v` is operand 1 of at least one user that satisfies
// is_hmma_c, false otherwise. Despite the name, no row/column ordering is
// inspected.
bool is_hmma_b_col(ir::value* v) {
  for(ir::user *u: v->get_users()){
    if(!is_hmma_c(u))
      continue;
    // is_hmma_c only accepts dot instructions, so the downcast is safe
    auto *dot = static_cast<ir::dot_inst*>(u);
    if(v == dot->get_operand(1))
      return true;
  }
  // the original fell off the end here, which is undefined behaviour for
  // a function returning bool
  return false;
}
// Returns true when `v` is operand 1 of at least one user that satisfies
// is_hmma_c, false otherwise. Despite the name, no row/column ordering is
// inspected.
bool is_hmma_b_row(ir::value* v) {
  for(ir::user *u: v->get_users()){
    if(!is_hmma_c(u))
      continue;
    // is_hmma_c only accepts dot instructions, so the downcast is safe
    auto *dot = static_cast<ir::dot_inst*>(u);
    if(v == dot->get_operand(1))
      return true;
  }
  // the original fell off the end here, which is undefined behaviour for
  // a function returning bool
  return false;
}
// Layout kind recorded for the group that contains `value`.
// std::map::at throws std::out_of_range when the group has no entry.
layout_t tiles::hmma(ir::value *value) {
  const auto group_id = layout_->layout_of(value);
  return hmma_.at(group_id);
}
int tiles::mts(ir::value *value, unsigned ax) { int tiles::mts(ir::value *value, unsigned ax) {
return mts_.at(axes_->get(value, ax)); return mts_.at(axes_->get(value, ax));
@@ -93,24 +41,15 @@ int tiles::wpt(ir::value *value, unsigned ax) {
return wpt_.at(axes_->get(value, ax)); return wpt_.at(axes_->get(value, ax));
} }
// Copy of the dimension ordering stored for the group that contains `v`.
// The lookup uses operator[], so a group without an entry gets an empty
// ordering inserted and returned.
std::vector<int> tiles::order(ir::value *v) {
  const auto group_id = layout_->layout_of(v);
  std::vector<int> ordering = order_[group_id];
  return ordering;
}
// Read-only access to the map from group id to the value that tiles::run
// selected as the largest member of that group.
const std::map<int, ir::value*>& tiles::largest() {
return largest_;
}
unsigned clamp(unsigned x, unsigned lo, unsigned hi) { unsigned clamp(unsigned x, unsigned lo, unsigned hi) {
return std::min(std::max(x, lo), hi); return std::min(std::max(x, lo), hi);
} }
void tiles::init_hmma_tile(ir::value *i) { void tiles::init_hmma_tile(const layout_t& layout) {
auto ord = order(i); auto ord = layout.order;
auto shapes = i->get_type()->get_tile_shapes(); auto shapes = layout.i->get_type()->get_tile_shapes();
unsigned shape_0 = shapes[ord[0]]; unsigned shape_0 = shapes[ord[0]];
unsigned shape_1 = shapes[ord[1]]; unsigned shape_1 = shapes[ord[1]];
/* fragments per warp */ /* fragments per warp */
@@ -127,7 +66,7 @@ void tiles::init_hmma_tile(ir::value *i) {
}while(fpw_nm1 != fpw); }while(fpw_nm1 != fpw);
// store parameters // store parameters
for(unsigned d = 0; d < shapes.size(); d++) for(unsigned d = 0; d < shapes.size(); d++)
fpw_[axes_->get(i, d)] = fpw[d]; fpw_[layout.axes[d]] = fpw[d];
/* warps per tile */ /* warps per tile */
// try to make things as square as possible to maximize data re-use // try to make things as square as possible to maximize data re-use
std::vector<unsigned> wpt = {1, 1, 1}; std::vector<unsigned> wpt = {1, 1, 1};
@@ -141,149 +80,48 @@ void tiles::init_hmma_tile(ir::value *i) {
}while(wpt_nm1 != wpt); }while(wpt_nm1 != wpt);
// store parameters // store parameters
for(unsigned d = 0; d < shapes.size(); d++) for(unsigned d = 0; d < shapes.size(); d++)
wpt_[axes_->get(i, d)] = wpt[d]; wpt_[layout.axes[d]] = wpt[d];
/* sanity check */ /* sanity check */
unsigned effective_num_warps = 1; unsigned effective_num_warps = 1;
for(size_t d = 0; d < shapes.size(); d++) for(size_t d = 0; d < shapes.size(); d++)
effective_num_warps *= wpt_[axes_->get(i, d)]; effective_num_warps *= wpt_[layout.axes[d]];
if(num_warps_ != effective_num_warps) if(num_warps_ != effective_num_warps)
throw std::runtime_error("cannot create a kernel with this amount of warps"); throw std::runtime_error("cannot create a kernel with this amount of warps");
} }
void tiles::init_scanline_tile(ir::value *i) { void tiles::init_scanline_tile(const layout_t& layout) {
auto ord = order(i); auto ord = layout.order;
auto shapes = i->get_type()->get_tile_shapes(); auto shapes = layout.shapes;
unsigned size = i->get_type()->get_tile_num_elements(); unsigned size = std::accumulate(shapes.begin(), shapes.end(), 1, std::multiplies<int>());
unsigned ld = ord[0]; unsigned ld = ord[0];
unsigned num_threads = num_warps_*32; unsigned num_threads = num_warps_*32;
unsigned current = num_threads; unsigned current = num_threads;
nts_[axes_->get(i, ld)] = clamp(size / num_threads, 1, 4); nts_[layout.axes[ld]] = clamp(size / num_threads, 1, 4);
mts_[axes_->get(i, ld)] = clamp(current, 1, shapes[ld] / nts_[axes_->get(i, ld)]); mts_[layout.axes[ld]] = clamp(current, 1, shapes[ld] / nts_[layout.axes[ld]]);
current = current / mts_[axes_->get(i, ld)]; current = current / mts_[layout.axes[ld]];
for(size_t d = 1; d < shapes.size(); d++){ for(size_t d = 1; d < shapes.size(); d++){
ld = ord[d]; ld = ord[d];
nts_[axes_->get(i, ld)] = 1; nts_[layout.axes[ld]] = 1;
mts_[axes_->get(i, ld)] = clamp(current, 1, shapes[ld]); mts_[layout.axes[ld]] = clamp(current, 1, shapes[ld]);
current = current / mts_[axes_->get(i, ld)]; current = current / mts_[layout.axes[ld]];
} }
/* sanity check */ /* sanity check */
unsigned effective_num_threads = 1; unsigned effective_num_threads = 1;
for(size_t d = 0; d < shapes.size(); d++) for(size_t d = 0; d < shapes.size(); d++)
effective_num_threads *= mts_[axes_->get(i, d)]; effective_num_threads *= mts_[layout.axes[d]];
// std::cout << num_threads << " " << effective_num_threads << std::endl; // std::cout << num_threads << " " << effective_num_threads << std::endl;
if(num_threads != effective_num_threads) if(num_threads != effective_num_threads)
throw std::runtime_error("cannot create a kernel with this amount of warps"); throw std::runtime_error("cannot create a kernel with this amount of warps");
} }
// Collects into `result` the io instructions among the users of `v` whose
// pointer operand is `v` itself. Entries already in `result` are kept.
void extract_io_use(ir::value *v, std::set<ir::io_inst*>& result) {
  for(ir::user* candidate: v->get_users()){
    ir::io_inst *io = dynamic_cast<ir::io_inst*>(candidate);
    const bool reads_through_v = io && io->get_pointer_operand() == v;
    if(!reads_through_v)
      continue;
    result.insert(io);
  }
}
// Recursively decides whether `v` is "transposed": a trans instruction is;
// any other instruction is only when every operand is (vacuously true with
// no operands); a value that is not an instruction is not.
bool tiles::is_trans(ir::value *v) {
  if(dynamic_cast<ir::trans_inst *>(v) != nullptr)
    return true;
  auto *inst = dynamic_cast<ir::instruction *>(v);
  if(inst == nullptr)
    return false;
  // the first non-transposed operand settles the answer
  for(ir::value *operand: inst->ops()) {
    if(!is_trans(operand))
      return false;
  }
  return true;
}
void tiles::run(ir::module &) { void tiles::run(ir::module &) {
hmma_.clear();
largest_.clear();
order_.clear();
size_t num_groups = layout_->num_layouts();
// helpers
auto rank = [](ir::value* v) {
int ret = 0;
for(int s: v->get_type()->get_tile_shapes())
ret += s > 1;
return ret;
};
// find out which groups require hmma layout
for(size_t i = 0; i < num_groups; i++) {
const auto& values = layout_->values_of(i);
bool hmma_c = std::any_of(values.begin(), values.end(), &is_hmma_c);
if(hmma_c) hmma_[i] = HMMA_C;
else hmma_[i] = SCANLINE;
}
// find out which value is the largest in each group
for(size_t i = 0; i < num_groups; i++) {
const auto& values = layout_->values_of(i);
auto cmp = [&rank](ir::value* x, ir::value *y) { return rank(x) < rank(y); };
largest_[i] = *std::max_element(values.begin(), values.end(), cmp);
}
// find out the layout ordering of a group
for(size_t i = 0; i < num_groups; i++){
std::set<ir::io_inst*> io;
for(ir::value* v: layout_->values_of(i))
extract_io_use(v, io);
auto cmp = [&rank](ir::io_inst* x, ir::io_inst *y) {
return rank(x->get_pointer_operand()) < rank(y->get_pointer_operand());
};
auto it = std::max_element(io.begin(), io.end(), cmp);
std::vector<int> order(rank(largest_[i]));
std::iota(order.begin(), order.end(), 0);
if(it != io.end()) {
auto max_contiguous = align_->contiguous((*it)->get_pointer_operand());
std::sort(order.begin(), order.end(), [&](unsigned a, unsigned b) {
return max_contiguous[a] > max_contiguous[b]; }
);
}
order_[i] = order;
}
// matrix multiplication optimizations
for(size_t i = 0; i < num_groups; i++){
std::vector<ir::dot_inst*> dots;
for(ir::value* v: layout_->values_of(i))
if(auto *x = dynamic_cast<ir::dot_inst*>(v))
dots.push_back(x);
for(ir::dot_inst* dot: dots){
ir::value* a = dot->get_operand(0);
ir::value* b = dot->get_operand(1);
if(hmma_.at(layout_->layout_of(dot)) == HMMA_C){
auto a_val = layout_->values_of(layout_->layout_of(a));
auto b_val = layout_->values_of(layout_->layout_of(b));
for(ir::value *v: a_val)
if(auto *cts = dynamic_cast<ir::copy_to_shared_inst*>(v))
order_[layout_->layout_of(a)] = order_[layout_->layout_of(cts->get_operand(0))];
for(ir::value *v: b_val)
if(auto *cts = dynamic_cast<ir::copy_to_shared_inst*>(v))
order_[layout_->layout_of(b)] = order_[layout_->layout_of(cts->get_operand(0))];
}
else{
std::vector<int> col = {0, 1};
std::vector<int> row = {1, 0};
order_[layout_->layout_of(a)] = is_trans(a) ? row : col;
order_[layout_->layout_of(b)] = is_trans(b) ? col : row;
}
}
}
// tiling parameters // tiling parameters
for(auto x: largest_){ for(auto x: layout_->get_all()){
ir::value *i = x.second;
if(!i->get_type()->is_tile_ty())
continue;
/* HMMA parameters*/ /* HMMA parameters*/
if(hmma_[x.first] == HMMA_C) if(x.second.type == HMMA_884)
init_hmma_tile(i); init_hmma_tile(x.second);
else else
init_scanline_tile(i); init_scanline_tile(x.second);
} }
} }

View File

@@ -577,37 +577,36 @@ inline int32_t ceil(int32_t num, int32_t div){
return (num + div - 1)/div; return (num + div - 1)/div;
} }
void selection::init_strided_scan_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) { void selection::init_strided_scan_axes(const analysis::layout_t& layout, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
auto order = tiles_->order(v); auto order = layout.order;
const auto& shapes = v->get_type()->get_tile_shapes(); const auto& shapes = layout.shapes;
size_t dim = shapes.size(); size_t dim = shapes.size();
std::vector<unsigned> contiguous(dim); std::vector<unsigned> nts(dim);
std::vector<unsigned> block_size(dim); std::vector<unsigned> mts(dim);
for(unsigned i = 0; i < shapes.size(); i++){ for(unsigned i = 0; i < shapes.size(); i++){
contiguous[i] = tiles_->nts(v, i); nts[i] = tiles_->nts(layout.i, i);
block_size[i] = tiles_->mts(v, i); mts[i] = tiles_->mts(layout.i, i);
} }
Value* full_thread_id = builder.CreateAdd(builder.CreateMul(u_warp_id, builder.getInt32(32)), u_thread_id); Value* full_thread_id = builder.CreateAdd(builder.CreateMul(u_warp_id, builder.getInt32(32)), u_thread_id);
std::vector<Value*> thread_id = delinearize(full_thread_id, order, block_size, builder); std::vector<Value*> thread_id = delinearize(full_thread_id, order, mts, builder);
// Create axes // Create axes
for(unsigned k = 0; k < dim; k++) { for(unsigned k = 0; k < dim; k++) {
std::string str_k = std::to_string(k); std::string str_k = std::to_string(k);
Value *contiguous_k = builder.getInt32(contiguous[k]); Value *contiguous_k = builder.getInt32(nts[k]);
Value *scaled_thread_id = builder.CreateMul(thread_id[k], contiguous_k); Value *scaled_thread_id = builder.CreateMul(thread_id[k], contiguous_k);
unsigned per_block = contiguous[k] * block_size[k]; unsigned per_block = nts[k] * mts[k];
unsigned per_thread = contiguous[k] * shapes[k] / per_block; unsigned per_thread = nts[k] * shapes[k] / per_block;
std::vector<Value*> idx_list(per_thread); std::vector<Value*> idx_list(per_thread);
for(unsigned n = 0 ; n < per_thread; n++){ for(unsigned n = 0 ; n < per_thread; n++){
unsigned offset = n / contiguous[k] * per_block + n % contiguous[k]; unsigned offset = n / nts[k] * per_block + n % nts[k];
idx_list[n] = builder.CreateAdd(scaled_thread_id, builder.getInt32(offset), "idx_" + str_k + "_" + std::to_string(n)); idx_list[n] = builder.CreateAdd(scaled_thread_id, builder.getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
} }
axes_[a_axes_->get(v, k)] = distributed_axis{contiguous[k], idx_list, thread_id[k]}; axes_[layout.axes[k]] = distributed_axis{nts[k], idx_list, thread_id[k]};
} }
} }
void selection::init_hmma_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) { void selection::init_hmma_axes(const analysis::layout_t& layout, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
// auto order = reorder_->get_order(v); const auto& shapes = layout.shapes;
const auto& shapes = v->get_type()->get_tile_shapes();
if(shapes.size() > 3) if(shapes.size() > 3)
throw std::runtime_error("unsupported"); throw std::runtime_error("unsupported");
bool is_batched = shapes.size() >= 3; bool is_batched = shapes.size() >= 3;
@@ -619,13 +618,13 @@ void selection::init_hmma_axes(ir::value *v, IRBuilder<> &builder, Value *u_thre
Value *_16 = builder.getInt32(16); Value *_16 = builder.getInt32(16);
// fragments per warp // fragments per warp
unsigned fpw_0 = tiles_->fpw(v, 0); unsigned fpw_0 = tiles_->fpw(layout.i, 0);
unsigned fpw_1 = tiles_->fpw(v, 1); unsigned fpw_1 = tiles_->fpw(layout.i, 1);
unsigned fpw_2 = is_batched ? tiles_->fpw(v, 2) : 1; unsigned fpw_2 = is_batched ? tiles_->fpw(layout.i, 2) : 1;
// warps per tile // warps per tile
unsigned wpt_0 = tiles_->wpt(v, 0); unsigned wpt_0 = tiles_->wpt(layout.i, 0);
unsigned wpt_1 = tiles_->wpt(v, 1); unsigned wpt_1 = tiles_->wpt(layout.i, 1);
unsigned wpt_2 = is_batched ? tiles_->wpt(v, 2) : 1; unsigned wpt_2 = is_batched ? tiles_->wpt(layout.i, 2) : 1;
// hmma warp tile size // hmma warp tile size
unsigned hmma_wts_0 = fpw_0 * 8; unsigned hmma_wts_0 = fpw_0 * 8;
unsigned hmma_wts_1 = fpw_1 * 8; unsigned hmma_wts_1 = fpw_1 * 8;
@@ -706,18 +705,18 @@ void selection::init_hmma_axes(ir::value *v, IRBuilder<> &builder, Value *u_thre
/* axes */ /* axes */
axes_[a_axes_->get(v, 0)] = distributed_axis{1, idx_i, warp_id_0}; axes_[layout.axes[0]] = distributed_axis{1, idx_i, warp_id_0};
axes_[a_axes_->get(v, 1)] = distributed_axis{1, idx_j, warp_id_1}; axes_[layout.axes[1]] = distributed_axis{1, idx_j, warp_id_1};
if(is_batched) if(is_batched)
axes_[a_axes_->get(v, 2)] = distributed_axis{1, idx_z, warp_id_2}; axes_[layout.axes[2]] = distributed_axis{1, idx_z, warp_id_2};
} }
void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) { void selection::init_axes(const analysis::layout_t& layout, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
if(tiles_->hmma(v) == analysis::HMMA_C) if(layout.type == analysis::HMMA_884)
init_hmma_axes(v, builder, u_thread_id, u_warp_id); init_hmma_axes(layout, builder, u_thread_id, u_warp_id);
else else
init_strided_scan_axes(v, builder, u_thread_id, u_warp_id); init_strided_scan_axes(layout, builder, u_thread_id, u_warp_id);
} }
/* ------------------- /* -------------------
@@ -727,7 +726,7 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
void selection::create_shared_tile(ir::value *v, IRBuilder<> &builder, Value *sh_mem_ptr) { void selection::create_shared_tile(ir::value *v, IRBuilder<> &builder, Value *sh_mem_ptr) {
if(tmap_.find(v) != tmap_.end()) if(tmap_.find(v) != tmap_.end())
return; return;
auto order = tiles_->order(v); auto order = layouts_->get(v).order;
auto shapes = v->get_type()->get_tile_shapes(); auto shapes = v->get_type()->get_tile_shapes();
unsigned pad = liveness_->get_pad(v); unsigned pad = liveness_->get_pad(v);
if(pad > 0) if(pad > 0)
@@ -777,7 +776,7 @@ void selection::create_distributed_tile(ir::value *v, IRBuilder<> &builder) {
axes[d].values = {builder.getInt32(0)}; axes[d].values = {builder.getInt32(0)};
} }
} }
distributed_tile *T = new distributed_tile(ty, shapes, tiles_->order(v), axes, builder, false); distributed_tile *T = new distributed_tile(ty, shapes, layouts_->get(v).order, axes, builder, false);
bool is_inserted = tmap_.insert({v, T}).second; bool is_inserted = tmap_.insert({v, T}).second;
// constant range // constant range
if(is_inserted && dynamic_cast<ir::make_range*>(v)){ if(is_inserted && dynamic_cast<ir::make_range*>(v)){
@@ -820,7 +819,7 @@ void selection::init_layouts(ir::function *fn, IRBuilder<> &builder, Value *sh_m
Value *u_thread_warp_id = builder.CreateURem(u_thread_id, warp_size); Value *u_thread_warp_id = builder.CreateURem(u_thread_id, warp_size);
Value *u_warp_id = builder.CreateUDiv(u_thread_id, warp_size); Value *u_warp_id = builder.CreateUDiv(u_thread_id, warp_size);
// create grid // create grid
for(auto x: tiles_->largest()) for(auto x: layouts_->get_all())
init_axes(x.second, builder, u_thread_warp_id, u_warp_id); init_axes(x.second, builder, u_thread_warp_id, u_warp_id);
// create tile // create tile
std::set<ir::value*> seen; std::set<ir::value*> seen;
@@ -868,7 +867,7 @@ void selection::lower_masked_store(ir::masked_store_inst *x, LLVMContext &ctx, F
void selection::lower_store(ir::store_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) { void selection::lower_store(ir::store_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
distributed_tile* ptrs = (distributed_tile*)tmap_.at(x->get_pointer_operand()); distributed_tile* ptrs = (distributed_tile*)tmap_.at(x->get_pointer_operand());
tile *scalars = tmap_.at(x->get_value_operand()); tile *scalars = tmap_.at(x->get_value_operand());
// size_t ld = tiles_->order(x->get_pointer_operand())[0]; // size_t ld = layouts_->order(x->get_pointer_operand())[0];
// unsigned vector_size = 2; // unsigned vector_size = 2;
// // vectorize pointers // // vectorize pointers
// std::map<unsigned, Value*> ptr_packets; // std::map<unsigned, Value*> ptr_packets;
@@ -1015,9 +1014,9 @@ void selection::lower_broadcast(ir::broadcast_inst *x, LLVMContext &ctx, Functio
void selection::lower_copy_to_shared(ir::copy_to_shared_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) { void selection::lower_copy_to_shared(ir::copy_to_shared_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
unsigned vector_size = 1; unsigned vector_size = 1;
auto x_order = tiles_->order(x); auto x_order = layouts_->get(x).order;
ir::value *arg = x->get_operand(0); ir::value *arg = x->get_operand(0);
auto arg_order = tiles_->order(arg); auto arg_order = layouts_->get(arg).order;
// tiles // tiles
shared_tile* result = (shared_tile*)tmap_.at(x); shared_tile* result = (shared_tile*)tmap_.at(x);
distributed_tile* in = (distributed_tile*)tmap_.at(arg); distributed_tile* in = (distributed_tile*)tmap_.at(arg);
@@ -1092,8 +1091,8 @@ void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn
Value* u_thread_id = tgt_->get_local_id(builder.GetInsertBlock()->getModule(), builder, 0); Value* u_thread_id = tgt_->get_local_id(builder.GetInsertBlock()->getModule(), builder, 0);
auto ord_a = tiles_->order(dot->get_operand(0)); auto ord_a = layouts_->get(dot->get_operand(0)).order;
auto ord_b = tiles_->order(dot->get_operand(1)); auto ord_b = layouts_->get(dot->get_operand(1)).order;
bool is_a_trans = is_trans(dot->get_operand(0)); bool is_a_trans = is_trans(dot->get_operand(0));
bool is_b_trans = is_trans(dot->get_operand(1)); bool is_b_trans = is_trans(dot->get_operand(1));
@@ -1255,7 +1254,7 @@ void selection::lower_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn, IRB
if(NK != 1) { if(NK != 1) {
shared_tile *TA = (shared_tile*)tmap_.at(A); shared_tile *TA = (shared_tile*)tmap_.at(A);
shared_tile *TB = (shared_tile*)tmap_.at(B); shared_tile *TB = (shared_tile*)tmap_.at(B);
if(tiles_->hmma(dot) == analysis::HMMA_C) if(layouts_->get(dot).type == analysis::HMMA_884)
lower_hmma_dot(dot, ctx, fn, builder, TC, TA, TB, TD, NK); lower_hmma_dot(dot, ctx, fn, builder, TC, TA, TB, TD, NK);
else else
lower_scanline_dot(dot, ctx, fn, builder, TC, TA, TB, TD, NK, c_ty, f_mul_add); lower_scanline_dot(dot, ctx, fn, builder, TC, TA, TB, TD, NK, c_ty, f_mul_add);
@@ -1271,7 +1270,7 @@ void selection::lower_masked_load(ir::masked_load_inst *x, LLVMContext &ctx, Fun
// find vector size // find vector size
distributed_tile* result = (distributed_tile*)tmap_.at(x); distributed_tile* result = (distributed_tile*)tmap_.at(x);
ir::value *ptr = x->get_pointer_operand(); ir::value *ptr = x->get_pointer_operand();
size_t ld = tiles_->order(ptr)[0]; size_t ld = layouts_->get(ptr).order[0];
unsigned alignment = alignment_->get(ptr, ld); unsigned alignment = alignment_->get(ptr, ld);
unsigned vector_size = std::min<unsigned>(result->axis(ld).contiguous, alignment); unsigned vector_size = std::min<unsigned>(result->axis(ld).contiguous, alignment);
distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr); distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);
@@ -1343,7 +1342,7 @@ void selection::lower_load(ir::load_inst *x, LLVMContext &ctx, Function *fn, IRB
distributed_tile* result = (distributed_tile*)tmap_.at(x); distributed_tile* result = (distributed_tile*)tmap_.at(x);
// find vector size // find vector size
ir::value *ptr = x->get_pointer_operand(); ir::value *ptr = x->get_pointer_operand();
size_t ld = tiles_->order(ptr)[0]; size_t ld = layouts_->get(ptr).order[0];
unsigned alignment = alignment_->get(ptr, ld); unsigned alignment = alignment_->get(ptr, ld);
unsigned vector_size = std::min<unsigned>(result->axis(ld).contiguous, alignment); unsigned vector_size = std::min<unsigned>(result->axis(ld).contiguous, alignment);
distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr); distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);

View File

@@ -83,8 +83,7 @@ void coalesce::run(ir::module &mod) {
if(axes.empty()) if(axes.empty())
continue; continue;
for(auto it = ++axes.rbegin(); it != axes.rend(); it++) for(auto it = ++axes.rbegin(); it != axes.rend(); it++)
remat.insert(remat.begin(), remat.insert(remat.begin(), it->second.begin(), it->second.end());
it->second.begin(), it->second.end());
} }
// rematerialize values // rematerialize values
for(ir::io_inst *r: remat) { for(ir::io_inst *r: remat) {

View File

@@ -202,9 +202,9 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
// create passes // create passes
codegen::analysis::align align; codegen::analysis::align align;
codegen::analysis::axes axes; codegen::analysis::axes axes;
codegen::analysis::layout layouts(&axes); codegen::analysis::layout layouts(&axes, &align);
codegen::analysis::tiles tiles(opt.num_warps, &align, &axes, &layouts); codegen::analysis::tiles tiles(opt.num_warps, &align, &axes, &layouts);
codegen::analysis::liveness liveness(&tiles); codegen::analysis::liveness liveness(&tiles, &layouts);
codegen::analysis::allocation allocation(&liveness, &tiles); codegen::analysis::allocation allocation(&liveness, &tiles);
codegen::transform::membar barriers(&liveness, &allocation); codegen::transform::membar barriers(&liveness, &allocation);
codegen::transform::dce dce; codegen::transform::dce dce;