[codegen][grid] some cleaning
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
namespace triton{
|
||||
|
||||
@@ -27,6 +28,8 @@ namespace analysis{
|
||||
class grids {
|
||||
typedef std::pair<ir::value*, unsigned> node_t;
|
||||
typedef std::map <node_t, std::set<node_t>> graph_t;
|
||||
typedef std::shared_ptr<int> param_ptr_t;
|
||||
typedef std::map<ir::value*, std::map<int, param_ptr_t>> param_map_t;
|
||||
|
||||
public:
|
||||
enum fragment_t{
|
||||
@@ -39,7 +42,7 @@ private:
|
||||
void init_c_phi(ir::instruction *i);
|
||||
void init_c_graph(ir::instruction *v);
|
||||
fragment_t get_fragmentation_type(node_t x, graph_t &graph);
|
||||
void connected_components(node_t x, const std::vector<ir::metaparameter *> mps, const std::vector<std::string> prefixes, std::set<node_t> &nodes, graph_t &graph, unsigned group_id);
|
||||
void connected_components(node_t x, const std::vector<param_ptr_t>& params, const std::vector<param_map_t*>& maps, std::set<node_t> &nodes, graph_t &graph, unsigned group_id);
|
||||
void create_grids(std::vector<ir::value*> &grids,
|
||||
std::map<unsigned, triton::ir::value *> &references,
|
||||
ir::function *fn);
|
||||
@@ -47,27 +50,36 @@ private:
|
||||
|
||||
public:
|
||||
grids(size_t num_warps, transform::coalesce* reorder);
|
||||
ir::metaparameter* get_param(ir::value *value, const std::string &key) { return params_[value][key]; }
|
||||
unsigned get_param_group(ir::value *value, unsigned ax);
|
||||
fragment_t get_fragment(ir::value *value, unsigned ax);
|
||||
void copy(ir::value *dst, ir::value *src);
|
||||
void run(ir::module &mod);
|
||||
unsigned get_num_threads();
|
||||
const std::vector<ir::value*> get_grids() const { return grids_; }
|
||||
int get_mts(ir::value *value, unsigned ax);
|
||||
int get_nts(ir::value *value, unsigned ax);
|
||||
int get_fpw(ir::value *value, unsigned ax);
|
||||
int get_wpt(ir::value *value, unsigned ax);
|
||||
|
||||
private:
|
||||
std::vector<unsigned*> pool_;
|
||||
|
||||
transform::coalesce* reorder_;
|
||||
// number of warps
|
||||
size_t num_warps_;
|
||||
// grids
|
||||
std::vector<ir::value*> grids_;
|
||||
// grid parameters
|
||||
param_map_t fpw_;
|
||||
param_map_t wpt_;
|
||||
param_map_t mts_;
|
||||
param_map_t nts_;
|
||||
// constraints graph
|
||||
graph_t dependencies_;
|
||||
std::set<node_t> nodes_;
|
||||
// fragments
|
||||
std::map<node_t, fragment_t> fragments_;
|
||||
std::map<node_t, unsigned> static_params_;
|
||||
std::map<ir::value*, std::map<std::string, ir::metaparameter*>> params_;
|
||||
std::map<unsigned, ir::metaparameter*> global_range_sizes_;
|
||||
std::vector<ir::value*> grids_;
|
||||
// parameter groups
|
||||
std::map<ir::value*, std::map<unsigned, unsigned>> groups_;
|
||||
size_t num_warps_;
|
||||
transform::coalesce* reorder_;
|
||||
|
||||
};
|
||||
|
||||
|
||||
|
@@ -84,7 +84,6 @@ void grids::init_c_graph(ir::instruction *v) {
|
||||
bool is_skewed = false;
|
||||
for(unsigned i = 0; i < shapes.size(); i ++){
|
||||
if(shapes[i] == 1){
|
||||
static_params_.insert({{v, i}, 1});
|
||||
add_constraint({v, i}, {v, i});
|
||||
}
|
||||
else if(!is_skewed &&
|
||||
@@ -125,8 +124,6 @@ void grids::init_c_graph(ir::instruction *v) {
|
||||
for(unsigned i = 0; i < shapes.size(); i++)
|
||||
add_constraint({v, i}, {D, i});
|
||||
for(unsigned i = 2; i < shapes.size(); i++){
|
||||
if(shapes[i] == 1)
|
||||
static_params_.insert({{v, i}, 1});
|
||||
add_constraint({v, i}, {A, i});
|
||||
add_constraint({v, i}, {B, i});
|
||||
}
|
||||
@@ -159,21 +156,15 @@ grids::fragment_t grids::get_fragmentation_type(node_t x, graph_t &graph){
|
||||
return STRIDED_SCAN;
|
||||
}
|
||||
|
||||
void grids::connected_components(node_t x, const std::vector<ir::metaparameter *> mps, const std::vector<std::string> prefixes, std::set<node_t> &nodes, graph_t &graph, unsigned group_id) {
|
||||
void grids::connected_components(node_t x, const std::vector<param_ptr_t>& ptr_vec, const std::vector<param_map_t*>& maps, std::set<node_t> &nodes, graph_t &graph, unsigned group_id)
|
||||
{
|
||||
groups_[x.first].insert({x.second, group_id});
|
||||
if(nodes.find(x) != nodes.end()){
|
||||
nodes.erase(x);
|
||||
std::string suffix = ".d" + std::to_string(x.second);
|
||||
for(unsigned i = 0; i < mps.size(); i++)
|
||||
params_[x.first].insert({prefixes[i] + suffix, mps[i]});
|
||||
ir::type *ty = x.first->get_type();
|
||||
if(static_params_.find(x) != static_params_.end()){
|
||||
for(ir::metaparameter *mp: mps)
|
||||
mp->set_value(static_params_.at(x));
|
||||
}
|
||||
for(const node_t &y: graph[x]){
|
||||
connected_components(y, mps, prefixes, nodes, graph, group_id);
|
||||
}
|
||||
for(unsigned i = 0; i < ptr_vec.size(); i++)
|
||||
(*maps[i])[x.first][x.second] = ptr_vec[i];
|
||||
for(const node_t &y: graph[x])
|
||||
connected_components(y, ptr_vec, maps, nodes, graph, group_id);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -189,7 +180,10 @@ grids::fragment_t grids::get_fragment(ir::value *value, unsigned ax) {
|
||||
|
||||
//TODO: This shouldn't exist!
|
||||
// Makes `dst` share the layout information already recorded for `src`:
// the per-axis parameter maps, the parameter-group ids and the fragment kind.
// Every lookup below uses std::map::operator[], so a `src` with no recorded
// entry is default-inserted and `dst` receives an empty map rather than an error.
void grids::copy(ir::value *dst, ir::value *src) {
|
||||
// NOTE(review): this listing interleaves pre- and post-change lines of a diff;
// params_ looks like the older string-keyed map that the typed maps below
// replace -- confirm which of these assignments are live.
params_[dst] = params_[src];
|
||||
// The inner maps hold std::shared_ptr<int>, so after these assignments `dst`
// and `src` point at the same integers (a later write through one is seen
// through the other).
mts_[dst] = mts_[src];
|
||||
nts_[dst] = nts_[src];
|
||||
fpw_[dst] = fpw_[src];
|
||||
wpt_[dst] = wpt_[src];
|
||||
groups_[dst] = groups_[src];
|
||||
// Only the axis-0 fragment entry is propagated; other axes of `dst` are left
// as they were.
fragments_[{dst, 0}] = fragments_[{src, 0}];
|
||||
}
|
||||
@@ -217,17 +211,16 @@ void grids::run(ir::module &mod) {
|
||||
for(auto x: nodes_)
|
||||
fragments_[x] = get_fragmentation_type(x, dependencies_);
|
||||
while(!nodes_.empty()) {
|
||||
ir::type *ty = mod.get_builder().get_int32_ty();
|
||||
node_t node = *nodes_.begin();
|
||||
if(fragments_[node] == STRIDED_SCAN) {
|
||||
ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 1, 1);
|
||||
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 1, 1);
|
||||
connected_components(node, {nts, mts}, {"nts", "mts"}, nodes_, dependencies_, group_id++);
|
||||
param_ptr_t nts(new int(-1));
|
||||
param_ptr_t mts(new int(-1));
|
||||
connected_components(node, {nts, mts}, {&nts_, &mts_}, nodes_, dependencies_, group_id++);
|
||||
}
|
||||
else {
|
||||
ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 1, 1);
|
||||
ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 1, 1);
|
||||
connected_components(node, {fpw, wpt}, {"fpw", "wpt"}, nodes_, dependencies_, group_id++);
|
||||
param_ptr_t fpw(new int(-1));
|
||||
param_ptr_t wpt(new int(-1));
|
||||
connected_components(node, {fpw, wpt}, {&fpw_, &wpt_}, nodes_, dependencies_, group_id++);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -267,7 +260,7 @@ void grids::run(ir::module &mod) {
|
||||
}while(fpw_nm1 != fpw);
|
||||
// store parameters
|
||||
for(unsigned d = 0; d < shapes.size(); d++)
|
||||
params_.at(i).at("fpw.d" + std::to_string(d))->set_value(fpw[d]);
|
||||
*fpw_[i][d] = fpw[d];
|
||||
|
||||
/* warps per tile */
|
||||
// try to make things as square as possible to maximize data re-use
|
||||
@@ -282,14 +275,12 @@ void grids::run(ir::module &mod) {
|
||||
}while(wpt_nm1 != wpt);
|
||||
// store parameters
|
||||
for(unsigned d = 0; d < shapes.size(); d++)
|
||||
params_.at(i).at("wpt.d" + std::to_string(d))->set_value(wpt[d]);
|
||||
*wpt_[i][d] = wpt[d];
|
||||
|
||||
/* sanity check */
|
||||
unsigned effective_num_warps = 1;
|
||||
for(size_t d = 0; d < shapes.size(); d++){
|
||||
std::string str_d = std::to_string(d);
|
||||
effective_num_warps *= params_.at(i).at("wpt.d" + str_d)->get_value();
|
||||
}
|
||||
for(size_t d = 0; d < shapes.size(); d++)
|
||||
effective_num_warps *= *wpt_[i][d];
|
||||
|
||||
if(num_warps_ != effective_num_warps)
|
||||
throw std::runtime_error("cannot create a kernel with this amount of warps");
|
||||
@@ -299,28 +290,20 @@ void grids::run(ir::module &mod) {
|
||||
/* Scan-line */
|
||||
else{
|
||||
unsigned ld = order[0];
|
||||
std::string s_ld = std::to_string(ld);
|
||||
unsigned current = num_threads;
|
||||
std::string nts = "nts.d" + s_ld;
|
||||
std::string mts = "mts.d" + s_ld;
|
||||
params_.at(i).at(nts)->set_value(clamp(size / num_threads, 1, 4));
|
||||
params_.at(i).at(mts)->set_value(clamp(current, 1, shapes[ld] / params_.at(i).at(nts)->get_value()));
|
||||
current = current / params_.at(i).at(mts)->get_value();
|
||||
*nts_[i][ld] = clamp(size / num_threads, 1, 4);
|
||||
*mts_[i][ld] = clamp(current, 1, shapes[ld] / *nts_[i][ld]);
|
||||
current = current / *mts_[i][ld];
|
||||
for(size_t d = 1; d < shapes.size(); d++){
|
||||
ld = order[d];
|
||||
s_ld = std::to_string(ld);
|
||||
nts = "nts.d" + s_ld;
|
||||
mts = "mts.d" + s_ld;
|
||||
params_.at(i).at(nts)->set_value(1);
|
||||
params_.at(i).at(mts)->set_value(clamp(current, 1, shapes[ld]));
|
||||
current = current / params_.at(i).at(mts)->get_value();
|
||||
*nts_[i][ld] = 1;
|
||||
*mts_[i][ld] = clamp(current, 1, shapes[ld]);
|
||||
current = current / *mts_[i][ld];
|
||||
}
|
||||
/* sanity check */
|
||||
unsigned effective_num_threads = 1;
|
||||
for(size_t d = 0; d < shapes.size(); d++){
|
||||
std::string str_d = std::to_string(d);
|
||||
effective_num_threads *= params_.at(i).at("mts.d" + str_d)->get_value();
|
||||
}
|
||||
for(size_t d = 0; d < shapes.size(); d++)
|
||||
effective_num_threads *= *mts_[i][d];
|
||||
if(num_threads != effective_num_threads)
|
||||
throw std::runtime_error("cannot create a kernel with this amount of warps");
|
||||
}
|
||||
@@ -378,6 +361,21 @@ unsigned grids::get_num_threads() {
|
||||
return num_warps_*32;
|
||||
}
|
||||
|
||||
// Returns the `mts` parameter recorded for axis `ax` of `value`, read by
// dereferencing the shared int stored in mts_.
// Both lookups go through std::map::at, so a value or axis without an entry
// raises std::out_of_range instead of silently inserting one.
int grids::get_mts(ir::value *value, unsigned ax) {
|
||||
return *mts_.at(value).at(ax);
|
||||
}
|
||||
|
||||
// Returns the `nts` parameter recorded for axis `ax` of `value`, read by
// dereferencing the shared int stored in nts_.
// Both lookups go through std::map::at, so a value or axis without an entry
// raises std::out_of_range instead of silently inserting one.
int grids::get_nts(ir::value *value, unsigned ax) {
|
||||
return *nts_.at(value).at(ax);
|
||||
}
|
||||
|
||||
// Returns the `fpw` parameter recorded for axis `ax` of `value`, read by
// dereferencing the shared int stored in fpw_.
// Both lookups go through std::map::at, so a value or axis without an entry
// raises std::out_of_range instead of silently inserting one.
int grids::get_fpw(ir::value *value, unsigned ax) {
|
||||
return *fpw_.at(value).at(ax);
|
||||
}
|
||||
|
||||
// Returns the `wpt` parameter recorded for axis `ax` of `value`, read by
// dereferencing the shared int stored in wpt_.
// Both lookups go through std::map::at, so a value or axis without an entry
// raises std::out_of_range instead of silently inserting one.
int grids::get_wpt(ir::value *value, unsigned ax) {
|
||||
return *wpt_.at(value).at(ax);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
@@ -57,9 +57,9 @@ unsigned memalloc::get_num_bytes(ir::value *x) {
|
||||
num_elements *= x;
|
||||
size_t depth;
|
||||
if(params_->get_fragment(x, 0) == grids::HMMA_FRAGMENT_C)
|
||||
depth = params_->get_param(op, "wpt.d" + std::to_string(axis))->get_value();
|
||||
depth = params_->get_wpt(op, axis);
|
||||
else
|
||||
depth = params_->get_param(op, "mts.d" + std::to_string(axis))->get_value();
|
||||
depth = params_->get_mts(op, axis);
|
||||
return num_elements * num_bytes * depth;
|
||||
}
|
||||
unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8;
|
||||
|
@@ -581,9 +581,8 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
|
||||
std::vector<unsigned> warp_size(dim);
|
||||
std::vector<unsigned> n_warps(dim);
|
||||
for(unsigned i = 0; i < shapes.size(); i++){
|
||||
std::string str_i = std::to_string(i);
|
||||
contiguous[i] = params_->get_param(v, "nts.d" + str_i)->get_value();
|
||||
block_size[i] = params_->get_param(v, "mts.d" + str_i)->get_value();
|
||||
contiguous[i] = params_->get_nts(v, i);
|
||||
block_size[i] = params_->get_mts(v, i);
|
||||
}
|
||||
to_warps(block_size, order, n_warps, warp_size);
|
||||
std::vector<Value*> thread_id_in_warp = delinearize(u_thread_id, order, warp_size, builder);
|
||||
@@ -617,13 +616,13 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
|
||||
Value *_16 = builder.getInt32(16);
|
||||
|
||||
// fragments per warp
|
||||
unsigned fpw_0 = params_->get_param(v, "fpw.d0")->get_value();
|
||||
unsigned fpw_1 = params_->get_param(v, "fpw.d1")->get_value();
|
||||
unsigned fpw_2 = is_batched ? params_->get_param(v, "fpw.d2")->get_value() : 1;
|
||||
unsigned fpw_0 = params_->get_fpw(v, 0);
|
||||
unsigned fpw_1 = params_->get_fpw(v, 1);
|
||||
unsigned fpw_2 = is_batched ? params_->get_fpw(v, 2) : 1;
|
||||
// warps per tile
|
||||
unsigned wpt_0 = params_->get_param(v, "wpt.d0")->get_value();
|
||||
unsigned wpt_1 = params_->get_param(v, "wpt.d1")->get_value();
|
||||
unsigned wpt_2 = is_batched ? params_->get_param(v, "wpt.d2")->get_value() : 1;
|
||||
unsigned wpt_0 = params_->get_wpt(v, 0);
|
||||
unsigned wpt_1 = params_->get_wpt(v, 1);
|
||||
unsigned wpt_2 = is_batched ? params_->get_wpt(v, 2) : 1;
|
||||
// hmma warp tile size
|
||||
unsigned hmma_wts_0 = fpw_0 * 8;
|
||||
unsigned hmma_wts_1 = fpw_1 * 8;
|
||||
@@ -909,7 +908,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn,
|
||||
tgt_->add_barrier(module, builder);
|
||||
builder.CreateStore(result, write_ptr);
|
||||
// build result
|
||||
unsigned depth = params_->get_param(op, "wpt.d" + std::to_string(axis))->get_value();
|
||||
unsigned depth = params_->get_wpt(op, axis);
|
||||
for(unsigned i = depth/2; i > 0; i >>= 1){
|
||||
// current indices
|
||||
indices_t current(write_idx.size(), builder.getInt32(0));
|
||||
@@ -1076,12 +1075,12 @@ void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn
|
||||
"{$10, $11}, "
|
||||
"{$0, $1, $2, $3, $4, $5, $6, $7};", "=f,=f,=f,=f,=f,=f,=f,=f,r,r,r,r,0,1,2,3,4,5,6,7", false);
|
||||
|
||||
unsigned fpw_0 = params_->get_param(dot, "fpw.d0")->get_value();
|
||||
unsigned fpw_1 = params_->get_param(dot, "fpw.d1")->get_value();
|
||||
unsigned fpw_0 = params_->get_fpw(dot, 0);
|
||||
unsigned fpw_1 = params_->get_fpw(dot, 1);
|
||||
unsigned wts_0 = fpw_0 * 8;
|
||||
unsigned wts_1 = fpw_1 * 8;
|
||||
unsigned wpt_0 = params_->get_param(dot, "wpt.d0")->get_value();
|
||||
unsigned wpt_1 = params_->get_param(dot, "wpt.d1")->get_value();
|
||||
unsigned wpt_0 = params_->get_wpt(dot, 0);
|
||||
unsigned wpt_1 = params_->get_wpt(dot, 1);
|
||||
unsigned stride_rep_i = wpt_0 * wts_0;
|
||||
unsigned stride_rep_j = wpt_1 * wts_1;
|
||||
unsigned num_rep_i = shapes[0] / stride_rep_i;
|
||||
|
@@ -27,7 +27,7 @@ void vectorize::run(ir::module &mod) {
|
||||
}
|
||||
if(dynamic_cast<ir::copy_to_shared_inst*>(i)){
|
||||
ir::value *x = i->get_operand(0);
|
||||
if(params_->get_param(x, "nts.d0")->get_value() == 1)
|
||||
if(params_->get_nts(x, 0) == 1)
|
||||
continue;
|
||||
builder.set_insert_point(i);
|
||||
ir::instruction *rx = (ir::instruction*)builder.create_vectorize(x);
|
||||
|
Reference in New Issue
Block a user