[auto-tuning] much smaller parameters space

This commit is contained in:
Philippe Tillet
2019-08-09 16:57:18 -07:00
parent fd49cdc92b
commit 1400d960a6
20 changed files with 470 additions and 367 deletions

View File

@@ -14,7 +14,8 @@ namespace triton{
namespace codegen{
namespace analysis{
tune::tune(): num_global_ranges_(0){ }
tune::tune() {
}
bool is_hmma(ir::value *v){
bool result = false;
@@ -123,8 +124,8 @@ void tune::init_c_graph(ir::instruction *v) {
for(unsigned i = 2; i < shapes.size(); i++){
if(shapes[i] == one)
static_params_.insert({{v, i}, 1});
// add_constraint({v, i}, {A, i});
// add_constraint({v, i}, {B, i});
add_constraint({v, i}, {A, i});
add_constraint({v, i}, {B, i});
}
}
// Element-wise
@@ -172,11 +173,6 @@ void tune::connected_components(node_t x, const std::vector<ir::metaparameter *>
if(auto mp = dynamic_cast<ir::metaparameter*>(shape))
params_[x.first].insert({"shape" + suffix, mp});
}
// if(auto range = dynamic_cast<ir::get_global_range_inst*>(x.first)){
// unsigned ax = range->get_axis();
// global_range_sizes_[ax] = params_[x.first].at("shape.d0");
// num_global_ranges_ = std::max(num_global_ranges_, ax + 1);
// }
if(static_params_.find(x) != static_params_.end()){
for(ir::metaparameter *mp: mps)
mp->set_value(static_params_.at(x));
@@ -189,21 +185,13 @@ void tune::connected_components(node_t x, const std::vector<ir::metaparameter *>
std::vector<ir::metaparameter *> tune::get_params(ir::module &mod) {
std::vector<ir::metaparameter*> result;
std::set<ir::metaparameter*> seen;
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i : block->get_inst_list())
for(auto &x: params_[i])
if(seen.insert(x.second).second && !x.second->has_value()){
result.push_back(x.second);
}
for(auto x: mod.globals()){
for(auto x: mod.globals()) {
if(auto mp = dynamic_cast<ir::metaparameter*>(x.second))
if(seen.insert(mp).second && !mp->has_value())
result.push_back(mp);
}
num_warps_ = ir::metaparameter::create(mod.get_context(), mod.get_builder().get_int32_ty(), 4, 4);
result.push_back(num_warps_);
return result;
}
@@ -212,7 +200,6 @@ std::map<std::string, ir::metaparameter *> tune::get_params(ir::instruction* i)
}
unsigned tune::get_param_group(ir::value *value, unsigned ax) {
// std::cout << "group? " << value->get_name() << " " << ax << std::endl;
unsigned result = groups_.at(value).at(ax);
return result;
}
@@ -229,139 +216,164 @@ void tune::run(ir::module &mod) {
ir::context &ctx = mod.get_context();
// Create metaparameters
for(ir::function *fn: mod.get_function_list()){
// Build constraints graph
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i : block->get_inst_list())
if(i->has_tile_result_or_op()){
if(i->has_tile_result_or_op())
init_c_graph(i);
}
// Build phi constraints
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i : block->get_inst_list())
if(i->has_tile_result_or_op())
init_c_phi(i);
// Layout parameters
unsigned group_id = 0;
for(auto x: nodes_){
for(auto x: nodes_)
fragments_[x] = get_fragmentation_type(x, dependencies_);
}
while(!nodes_.empty()) {
ir::type *ty = mod.get_builder().get_int32_ty();
node_t node = *nodes_.begin();
if(fragments_[node] == STRIDED_SCAN) {
ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 1, 1);
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 1, 8);
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 1, 1);
connected_components(node, {nts, mts}, {"nts", "mts"}, nodes_, dependencies_, group_id++);
nts->set_value(1);
}
else {
ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 1, 1);
if(node.second == 2)
fpw->set_value(1);
ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 1, 1);
connected_components(node, {fpw, wpt}, {"fpw", "wpt"}, nodes_, dependencies_, group_id++);
}
}
}
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i : block->get_inst_list()){
if(!i->get_type()->is_tile_ty())
continue;
auto shapes = i->get_type()->get_tile_shapes();
if(auto *x = dynamic_cast<ir::load_inst*>(i))
if(fragments_.at({i, 0}) == STRIDED_SCAN){
ir::type *ptr_ty = x->get_pointer_operand()->get_type()->get_scalar_ty();
size_t addr_space = ptr_ty->get_pointer_address_space();
if(addr_space < 4){
ir::type *ty = mod.get_builder().get_int32_ty();
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 1, 1));
*params_.at(i).at("nts.d0") = *tmp;
}
}
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
ir::type *ty = mod.get_builder().get_int32_ty();
// std::unique_ptr<ir::metaparameter> mts_2(ir::metaparameter::create(ctx, ty, 1, 4));
// *params_.at(i->get_operand(0)).at("mts.d2") = *mts_2;
// *params_.at(i->get_operand(1)).at("mts.d2") = *mts_2;
if(fragments_.at({i, 0}) == STRIDED_SCAN){
std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 1, 1));
std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 1, 1));
*params_.at(i).at("nts.d0") = *tmp1;
*params_.at(i).at("nts.d1") = *tmp2;
// for(size_t k = 2; k < shapes.size(); k++)
// if(auto *x = dynamic_cast<ir::metaparameter*>(shapes[k]))
// *params_.at(i).at("mts.d" + std::to_string(k)) = *x;
// else
// params_.at(i).at("mts.d" + std::to_string(k))->set_value(shapes[k]->get_value());
}
else{
// for(size_t k = 2; k < shapes.size(); k++)
// if(auto *x = dynamic_cast<ir::metaparameter*>(shapes[k]))
// *params_.at(i).at("wpt.d" + std::to_string(k)) = *x;
// else
// params_.at(i).at("wpt.d" + std::to_string(k))->set_value(shapes[k]->get_value());
}
}
}
}
void tune::init(ir::module &mod) {
for(ir::function *fn: mod.get_function_list()){
std::map<ir::metaparameter*, ir::instruction*> references;
std::map<unsigned, ir::value*> references;
create_grids(grids_, references, fn);
}
num_threads_ = get_req_num_threads(grids_.front());
}
int num_threads = get_num_threads();
int num_warps = num_warps_->get_value();
auto clamp = [&](int x, int lo, int hi) { return std::min(std::max(x, lo), hi); };
unsigned tune::get_req_num_threads(ir::instruction *i){
if(fragments_.at({i, 0}) == STRIDED_SCAN) {
unsigned result = 1;
for(unsigned k = 0; k < i->get_type()->get_tile_shapes().size(); k++){
std::string suffix = ".d" + std::to_string(k);
result *= params_.at(i).at("mts" + suffix)->get_value();
for(ir::value *i: grids_){
if(!i->get_type()->is_tile_ty())
continue;
auto shapes = i->get_type()->get_tile_shapes();
int shape_0 = shapes[0]->get_value();
int shape_1 = shapes[1]->get_value();
int size = i->get_type()->get_tile_num_elements();
/* HMMA parameters*/
if(fragments_.at({i, 0}) == HMMA_FRAGMENT_C){
/* fragments per warp */
// try to make things as square as possible to maximize data re-use
std::vector<int> fpw = {1, 1, 1};
std::vector<int> fpw_nm1;
int num_fragments = std::min((shape_0/8)*(shape_1/8), 4);
do {
fpw_nm1 = fpw;
if(fpw[0]*fpw[1] < num_fragments)
fpw[0] = clamp(fpw[0]*2, 1, shape_0 / 8);
if(fpw[0]*fpw[1] < num_fragments)
fpw[1] = clamp(fpw[1]*2, 1, shape_1 / 8);
}while(fpw_nm1 != fpw);
// store parameters
for(int d = 0; d < shapes.size(); d++)
params_.at(i).at("fpw.d" + std::to_string(d))->set_value(fpw[d]);
/* warps per tile */
// try to make things as square as possible to maximize data re-use
std::vector<int> wpt = {1, 1, 1};
std::vector<int> wpt_nm1;
do{
wpt_nm1 = wpt;
if(wpt[0] * wpt[1] * wpt[2] < num_warps)
wpt[0] = clamp(wpt[0]*2, 1, shape_0 / (fpw[0]*8));
if(wpt[0] * wpt[1] * wpt[2] < num_warps)
wpt[1] = clamp(wpt[1]*2, 1, shape_1 / (fpw[1]*8));
}while(wpt_nm1 != wpt);
// store parameters
for(int d = 0; d < shapes.size(); d++)
params_.at(i).at("wpt.d" + std::to_string(d))->set_value(wpt[d]);
/* sanity check */
unsigned effective_num_warps = 1;
for(size_t d = 0; d < shapes.size(); d++){
std::string str_d = std::to_string(d);
effective_num_warps *= params_.at(i).at("wpt.d" + str_d)->get_value();
}
assert(num_warps == effective_num_warps);
}
return result;
}
else {
unsigned result = 32;
for(unsigned k = 0; k < i->get_type()->get_tile_shapes().size(); k++){
std::string suffix = ".d" + std::to_string(k);
result *= params_.at(i).at("wpt" + suffix)->get_value();
/* Scan-line */
else{
int shape = shapes[0]->get_value();
int current = num_threads;
params_.at(i).at("nts.d0")->set_value(clamp(size / num_threads, 1, 8));
params_.at(i).at("mts.d0")->set_value(clamp(current, 1, shape / params_.at(i).at("nts.d0")->get_value()));
current = current / params_.at(i).at("mts.d0")->get_value();
for(size_t d = 1; d < shapes.size(); d++){
std::string str_d = std::to_string(d);
shape = shapes[d]->get_value();
params_.at(i).at("nts.d" + str_d)->set_value(1);
params_.at(i).at("mts.d" + str_d)->set_value(clamp(current, 1, shape));
current = current / params_.at(i).at("mts.d" + str_d)->get_value();
}
/* sanity check */
unsigned effective_num_threads = 1;
for(size_t d = 0; d < shapes.size(); d++){
std::string str_d = std::to_string(d);
effective_num_threads *= params_.at(i).at("mts.d" + str_d)->get_value();
}
assert(num_threads == effective_num_threads);
}
return result;
}
}
void tune::create_grids(std::vector<ir::instruction*> &grids,
std::map<ir::metaparameter*, ir::instruction*> &references,
ir::function *fn) {
void tune::create_grids(std::vector<ir::value*> &grids,
std::map<unsigned, ir::value*> &references,
ir::function *fn) {
// get number of dimensions greater than 1
auto get_tile_gt1_dim = [&](ir::value *v){
unsigned result = 0;
auto one = ir::tile_type::make_one(fn->get_fn_type()->get_context());
for(ir::constant_int *shape: v->get_type()->get_tile_shapes()) {
result += (shape != one);
for(ir::constant_int* shape: v->get_type()->get_tile_shapes()) {
result += (shape->get_value() > 1)?shape->get_value():0;
}
return result;
};
// bind references
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i: block->get_inst_list()){
if(!i->get_type()->is_tile_ty())
continue;
for(auto &param: params_.at(i)){
if(param.second->get_value() == 1)
std::set<ir::value*> seen;
std::function<void(ir::value*)> bind_references = [&](ir::value *v)
{
// skip
if(!v->get_type()->is_tile_ty() || !seen.insert(v).second)
return;
// recurse
if(auto *user = dynamic_cast<ir::user*>(v))
for(ir::value *op: user->ops())
bind_references(op);
// bind
const auto& shapes = v->get_type()->get_tile_shapes();
for(size_t d = 0; d < shapes.size(); d++){
if(shapes[d]->get_value() == 1)
continue;
ir::instruction *&r = references[param.second];
if(!r || get_tile_gt1_dim(i) > get_tile_gt1_dim(r))
r = i;
unsigned x = get_param_group(v, d);
ir::value *&r = references[x];
if(!r || get_tile_gt1_dim(v) > get_tile_gt1_dim(r))
r = v;
}
}
};
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i: block->get_inst_list())
bind_references(i);
// create grid
for(auto &ref: references)
if(std::find(grids.begin(), grids.end(), ref.second) == grids.end())
@@ -370,85 +382,11 @@ void tune::create_grids(std::vector<ir::instruction*> &grids,
bool tune::check_constraints(std::map<ir::value *, std::vector<std::string>> &errors) {
using std::to_string;
auto get_num_warps = [&](ir::instruction *i, unsigned axis) {
std::string strk = to_string(axis);
if(fragments_.at({i, axis}) == STRIDED_SCAN){
unsigned mts = params_[i]["mts.d" + strk]->get_value();
unsigned nts = params_[i]["nts.d" + strk]->get_value();
unsigned shape = i->get_type()->get_tile_shapes()[axis]->get_value();
return shape / (mts * nts);
}
else{
return (unsigned)params_[i]["wpt.d" + strk]->get_value();
}
};
// number of warps
ir::instruction *first = grids_.front();
int num_warps = 1;
for(size_t k = 0; k < first->get_type()->get_tile_shapes().size(); k++)
num_warps *= get_num_warps(first, k);
// check constraints
for(ir::instruction *i: grids_){
// std::cout << i->get_name() << std::endl;
ir::type *ty = i->get_type();
const auto &shapes = ty->get_tile_shapes();
// for each dimension, the product of layout components
// must device the shape
for(size_t k = 0; k < shapes.size(); k++) {
std::string strk = to_string(k);
unsigned multiple;
if(fragments_.at({i, 0}) == STRIDED_SCAN) {
ir::metaparameter *mts = params_[i]["mts.d" + strk];
ir::metaparameter *nts = params_[i]["nts.d" + strk];
multiple = mts->get_value()*nts->get_value();
}
else {
ir::metaparameter *fpw = params_[i]["fpw.d" + strk];
ir::metaparameter *wpt = params_[i]["wpt.d" + strk];
multiple = fpw->get_value()*wpt->get_value();
if(k < 2)
multiple *= 8;
}
if(shapes[k]->get_value() % multiple != 0)
errors[i].push_back("for dim " + strk + ": shape (" + to_string(shapes[k]->get_value()) + ")"
" is not a multiple of layout (" + to_string(multiple) + ")");
}
// the product of mma fragments per warp must be 4
if(fragments_.at({i, 0}) == HMMA_FRAGMENT_C){
unsigned prod = 1;
for(size_t k = 0; k < shapes.size(); k++){
prod *= params_[i]["fpw.d" + std::to_string(k)]->get_value();
}
if(prod > 4)
errors[i].push_back("HMMA must have only 4 fragments per warp");
}
int num_threads = get_req_num_threads(i);
if(num_threads % 32 != 0)
errors[i].push_back("number of threads per block (" + to_string(num_threads) + ") must be multiple of warp size");
if(num_threads != num_threads_)
errors[i].push_back("Number of threads must be the same for all tiles (" + to_string(num_threads_) + ")");
}
// for(auto x: errors)
// for(auto e: x.second)
// std::cout << x.first->get_name() << ": " << e << std::endl;
// exit(EXIT_SUCCESS);
return errors.empty();
}
unsigned tune::get_num_global_range() {
return num_global_ranges_;
}
unsigned tune::get_global_range_size(unsigned axis) {
return global_range_sizes_.at(axis)->get_value();
}
unsigned tune::get_num_threads() {
return num_threads_;
return num_warps_->get_value()*32;
}