[codegen][grid] some cleaning

This commit is contained in:
Philippe Tillet
2019-09-14 13:05:53 -04:00
parent 8ae779206f
commit 66e32b3074
5 changed files with 81 additions and 72 deletions

View File

@@ -4,6 +4,7 @@
#include <map>
#include <set>
#include <vector>
#include <memory>
namespace triton{
@@ -27,6 +28,8 @@ namespace analysis{
class grids {
typedef std::pair<ir::value*, unsigned> node_t;
typedef std::map <node_t, std::set<node_t>> graph_t;
typedef std::shared_ptr<int> param_ptr_t;
typedef std::map<ir::value*, std::map<int, param_ptr_t>> param_map_t;
public:
enum fragment_t{
@@ -39,7 +42,7 @@ private:
void init_c_phi(ir::instruction *i);
void init_c_graph(ir::instruction *v);
fragment_t get_fragmentation_type(node_t x, graph_t &graph);
void connected_components(node_t x, const std::vector<ir::metaparameter *> mps, const std::vector<std::string> prefixes, std::set<node_t> &nodes, graph_t &graph, unsigned group_id);
void connected_components(node_t x, const std::vector<param_ptr_t>& params, const std::vector<param_map_t*>& maps, std::set<node_t> &nodes, graph_t &graph, unsigned group_id);
void create_grids(std::vector<ir::value*> &grids,
std::map<unsigned, triton::ir::value *> &references,
ir::function *fn);
@@ -47,27 +50,36 @@ private:
public:
grids(size_t num_warps, transform::coalesce* reorder);
ir::metaparameter* get_param(ir::value *value, const std::string &key) { return params_[value][key]; }
unsigned get_param_group(ir::value *value, unsigned ax);
fragment_t get_fragment(ir::value *value, unsigned ax);
void copy(ir::value *dst, ir::value *src);
void run(ir::module &mod);
unsigned get_num_threads();
const std::vector<ir::value*> get_grids() const { return grids_; }
int get_mts(ir::value *value, unsigned ax);
int get_nts(ir::value *value, unsigned ax);
int get_fpw(ir::value *value, unsigned ax);
int get_wpt(ir::value *value, unsigned ax);
private:
std::vector<unsigned*> pool_;
transform::coalesce* reorder_;
// number of warps
size_t num_warps_;
// grids
std::vector<ir::value*> grids_;
// grid parameters
param_map_t fpw_;
param_map_t wpt_;
param_map_t mts_;
param_map_t nts_;
// constraints graph
graph_t dependencies_;
std::set<node_t> nodes_;
// fragments
std::map<node_t, fragment_t> fragments_;
std::map<node_t, unsigned> static_params_;
std::map<ir::value*, std::map<std::string, ir::metaparameter*>> params_;
std::map<unsigned, ir::metaparameter*> global_range_sizes_;
std::vector<ir::value*> grids_;
// parameter groups
std::map<ir::value*, std::map<unsigned, unsigned>> groups_;
size_t num_warps_;
transform::coalesce* reorder_;
};

View File

@@ -84,7 +84,6 @@ void grids::init_c_graph(ir::instruction *v) {
bool is_skewed = false;
for(unsigned i = 0; i < shapes.size(); i ++){
if(shapes[i] == 1){
static_params_.insert({{v, i}, 1});
add_constraint({v, i}, {v, i});
}
else if(!is_skewed &&
@@ -125,8 +124,6 @@ void grids::init_c_graph(ir::instruction *v) {
for(unsigned i = 0; i < shapes.size(); i++)
add_constraint({v, i}, {D, i});
for(unsigned i = 2; i < shapes.size(); i++){
if(shapes[i] == 1)
static_params_.insert({{v, i}, 1});
add_constraint({v, i}, {A, i});
add_constraint({v, i}, {B, i});
}
@@ -159,21 +156,15 @@ grids::fragment_t grids::get_fragmentation_type(node_t x, graph_t &graph){
return STRIDED_SCAN;
}
void grids::connected_components(node_t x, const std::vector<ir::metaparameter *> mps, const std::vector<std::string> prefixes, std::set<node_t> &nodes, graph_t &graph, unsigned group_id) {
void grids::connected_components(node_t x, const std::vector<param_ptr_t>& ptr_vec, const std::vector<param_map_t*>& maps, std::set<node_t> &nodes, graph_t &graph, unsigned group_id)
{
groups_[x.first].insert({x.second, group_id});
if(nodes.find(x) != nodes.end()){
nodes.erase(x);
std::string suffix = ".d" + std::to_string(x.second);
for(unsigned i = 0; i < mps.size(); i++)
params_[x.first].insert({prefixes[i] + suffix, mps[i]});
ir::type *ty = x.first->get_type();
if(static_params_.find(x) != static_params_.end()){
for(ir::metaparameter *mp: mps)
mp->set_value(static_params_.at(x));
}
for(const node_t &y: graph[x]){
connected_components(y, mps, prefixes, nodes, graph, group_id);
}
for(unsigned i = 0; i < ptr_vec.size(); i++)
(*maps[i])[x.first][x.second] = ptr_vec[i];
for(const node_t &y: graph[x])
connected_components(y, ptr_vec, maps, nodes, graph, group_id);
}
}
@@ -189,7 +180,10 @@ grids::fragment_t grids::get_fragment(ir::value *value, unsigned ax) {
//TODO: This shouldn't exist!
void grids::copy(ir::value *dst, ir::value *src) {
params_[dst] = params_[src];
mts_[dst] = mts_[src];
nts_[dst] = nts_[src];
fpw_[dst] = fpw_[src];
wpt_[dst] = wpt_[src];
groups_[dst] = groups_[src];
fragments_[{dst, 0}] = fragments_[{src, 0}];
}
@@ -217,17 +211,16 @@ void grids::run(ir::module &mod) {
for(auto x: nodes_)
fragments_[x] = get_fragmentation_type(x, dependencies_);
while(!nodes_.empty()) {
ir::type *ty = mod.get_builder().get_int32_ty();
node_t node = *nodes_.begin();
if(fragments_[node] == STRIDED_SCAN) {
ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 1, 1);
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 1, 1);
connected_components(node, {nts, mts}, {"nts", "mts"}, nodes_, dependencies_, group_id++);
param_ptr_t nts(new int(-1));
param_ptr_t mts(new int(-1));
connected_components(node, {nts, mts}, {&nts_, &mts_}, nodes_, dependencies_, group_id++);
}
else {
ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 1, 1);
ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 1, 1);
connected_components(node, {fpw, wpt}, {"fpw", "wpt"}, nodes_, dependencies_, group_id++);
param_ptr_t fpw(new int(-1));
param_ptr_t wpt(new int(-1));
connected_components(node, {fpw, wpt}, {&fpw_, &wpt_}, nodes_, dependencies_, group_id++);
}
}
}
@@ -267,7 +260,7 @@ void grids::run(ir::module &mod) {
}while(fpw_nm1 != fpw);
// store parameters
for(unsigned d = 0; d < shapes.size(); d++)
params_.at(i).at("fpw.d" + std::to_string(d))->set_value(fpw[d]);
*fpw_[i][d] = fpw[d];
/* warps per tile */
// try to make things as square as possible to maximize data re-use
@@ -282,14 +275,12 @@ void grids::run(ir::module &mod) {
}while(wpt_nm1 != wpt);
// store parameters
for(unsigned d = 0; d < shapes.size(); d++)
params_.at(i).at("wpt.d" + std::to_string(d))->set_value(wpt[d]);
*wpt_[i][d] = wpt[d];
/* sanity check */
unsigned effective_num_warps = 1;
for(size_t d = 0; d < shapes.size(); d++){
std::string str_d = std::to_string(d);
effective_num_warps *= params_.at(i).at("wpt.d" + str_d)->get_value();
}
for(size_t d = 0; d < shapes.size(); d++)
effective_num_warps *= *wpt_[i][d];
if(num_warps_ != effective_num_warps)
throw std::runtime_error("cannot create a kernel with this amount of warps");
@@ -299,28 +290,20 @@ void grids::run(ir::module &mod) {
/* Scan-line */
else{
unsigned ld = order[0];
std::string s_ld = std::to_string(ld);
unsigned current = num_threads;
std::string nts = "nts.d" + s_ld;
std::string mts = "mts.d" + s_ld;
params_.at(i).at(nts)->set_value(clamp(size / num_threads, 1, 4));
params_.at(i).at(mts)->set_value(clamp(current, 1, shapes[ld] / params_.at(i).at(nts)->get_value()));
current = current / params_.at(i).at(mts)->get_value();
*nts_[i][ld] = clamp(size / num_threads, 1, 4);
*mts_[i][ld] = clamp(current, 1, shapes[ld] / *nts_[i][ld]);
current = current / *mts_[i][ld];
for(size_t d = 1; d < shapes.size(); d++){
ld = order[d];
s_ld = std::to_string(ld);
nts = "nts.d" + s_ld;
mts = "mts.d" + s_ld;
params_.at(i).at(nts)->set_value(1);
params_.at(i).at(mts)->set_value(clamp(current, 1, shapes[ld]));
current = current / params_.at(i).at(mts)->get_value();
*nts_[i][ld] = 1;
*mts_[i][ld] = clamp(current, 1, shapes[ld]);
current = current / *mts_[i][ld];
}
/* sanity check */
unsigned effective_num_threads = 1;
for(size_t d = 0; d < shapes.size(); d++){
std::string str_d = std::to_string(d);
effective_num_threads *= params_.at(i).at("mts.d" + str_d)->get_value();
}
for(size_t d = 0; d < shapes.size(); d++)
effective_num_threads *= *mts_[i][d];
if(num_threads != effective_num_threads)
throw std::runtime_error("cannot create a kernel with this amount of warps");
}
@@ -378,6 +361,21 @@ unsigned grids::get_num_threads() {
return num_warps_*32;
}
int grids::get_mts(ir::value *value, unsigned ax) {
return *mts_.at(value).at(ax);
}
int grids::get_nts(ir::value *value, unsigned ax) {
return *nts_.at(value).at(ax);
}
int grids::get_fpw(ir::value *value, unsigned ax) {
return *fpw_.at(value).at(ax);
}
int grids::get_wpt(ir::value *value, unsigned ax) {
return *wpt_.at(value).at(ax);
}
}
}

View File

@@ -57,9 +57,9 @@ unsigned memalloc::get_num_bytes(ir::value *x) {
num_elements *= x;
size_t depth;
if(params_->get_fragment(x, 0) == grids::HMMA_FRAGMENT_C)
depth = params_->get_param(op, "wpt.d" + std::to_string(axis))->get_value();
depth = params_->get_wpt(op, axis);
else
depth = params_->get_param(op, "mts.d" + std::to_string(axis))->get_value();
depth = params_->get_mts(op, axis);
return num_elements * num_bytes * depth;
}
unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8;

View File

@@ -581,9 +581,8 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
std::vector<unsigned> warp_size(dim);
std::vector<unsigned> n_warps(dim);
for(unsigned i = 0; i < shapes.size(); i++){
std::string str_i = std::to_string(i);
contiguous[i] = params_->get_param(v, "nts.d" + str_i)->get_value();
block_size[i] = params_->get_param(v, "mts.d" + str_i)->get_value();
contiguous[i] = params_->get_nts(v, i);
block_size[i] = params_->get_mts(v, i);
}
to_warps(block_size, order, n_warps, warp_size);
std::vector<Value*> thread_id_in_warp = delinearize(u_thread_id, order, warp_size, builder);
@@ -617,13 +616,13 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
Value *_16 = builder.getInt32(16);
// fragments per warp
unsigned fpw_0 = params_->get_param(v, "fpw.d0")->get_value();
unsigned fpw_1 = params_->get_param(v, "fpw.d1")->get_value();
unsigned fpw_2 = is_batched ? params_->get_param(v, "fpw.d2")->get_value() : 1;
unsigned fpw_0 = params_->get_fpw(v, 0);
unsigned fpw_1 = params_->get_fpw(v, 1);
unsigned fpw_2 = is_batched ? params_->get_fpw(v, 2) : 1;
// warps per tile
unsigned wpt_0 = params_->get_param(v, "wpt.d0")->get_value();
unsigned wpt_1 = params_->get_param(v, "wpt.d1")->get_value();
unsigned wpt_2 = is_batched ? params_->get_param(v, "wpt.d2")->get_value() : 1;
unsigned wpt_0 = params_->get_wpt(v, 0);
unsigned wpt_1 = params_->get_wpt(v, 1);
unsigned wpt_2 = is_batched ? params_->get_wpt(v, 2) : 1;
// hmma warp tile size
unsigned hmma_wts_0 = fpw_0 * 8;
unsigned hmma_wts_1 = fpw_1 * 8;
@@ -909,7 +908,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn,
tgt_->add_barrier(module, builder);
builder.CreateStore(result, write_ptr);
// build result
unsigned depth = params_->get_param(op, "wpt.d" + std::to_string(axis))->get_value();
unsigned depth = params_->get_wpt(op, axis);
for(unsigned i = depth/2; i > 0; i >>= 1){
// current indices
indices_t current(write_idx.size(), builder.getInt32(0));
@@ -1076,12 +1075,12 @@ void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn
"{$10, $11}, "
"{$0, $1, $2, $3, $4, $5, $6, $7};", "=f,=f,=f,=f,=f,=f,=f,=f,r,r,r,r,0,1,2,3,4,5,6,7", false);
unsigned fpw_0 = params_->get_param(dot, "fpw.d0")->get_value();
unsigned fpw_1 = params_->get_param(dot, "fpw.d1")->get_value();
unsigned fpw_0 = params_->get_fpw(dot, 0);
unsigned fpw_1 = params_->get_fpw(dot, 1);
unsigned wts_0 = fpw_0 * 8;
unsigned wts_1 = fpw_1 * 8;
unsigned wpt_0 = params_->get_param(dot, "wpt.d0")->get_value();
unsigned wpt_1 = params_->get_param(dot, "wpt.d1")->get_value();
unsigned wpt_0 = params_->get_wpt(dot, 0);
unsigned wpt_1 = params_->get_wpt(dot, 1);
unsigned stride_rep_i = wpt_0 * wts_0;
unsigned stride_rep_j = wpt_1 * wts_1;
unsigned num_rep_i = shapes[0] / stride_rep_i;

View File

@@ -27,7 +27,7 @@ void vectorize::run(ir::module &mod) {
}
if(dynamic_cast<ir::copy_to_shared_inst*>(i)){
ir::value *x = i->get_operand(0);
if(params_->get_param(x, "nts.d0")->get_value() == 1)
if(params_->get_nts(x, 0) == 1)
continue;
builder.set_insert_point(i);
ir::instruction *rx = (ir::instruction*)builder.create_vectorize(x);