[auto-tuning] much smaller parameters space
This commit is contained in:
@@ -14,7 +14,8 @@ namespace triton{
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
// Default constructor of tune. Two definitions appear back to back here.
// This variant initializes num_global_ranges_ to 0 and has an empty body.
tune::tune(): num_global_ranges_(0){ }
|
||||
// This variant has no member initializers and an empty body.
// NOTE(review): the two definitions look like the before/after lines of the
// same change rather than two coexisting overloads — confirm which is current.
tune::tune() {
|
||||
}
|
||||
|
||||
bool is_hmma(ir::value *v){
|
||||
bool result = false;
|
||||
@@ -123,8 +124,8 @@ void tune::init_c_graph(ir::instruction *v) {
|
||||
for(unsigned i = 2; i < shapes.size(); i++){
|
||||
if(shapes[i] == one)
|
||||
static_params_.insert({{v, i}, 1});
|
||||
// add_constraint({v, i}, {A, i});
|
||||
// add_constraint({v, i}, {B, i});
|
||||
add_constraint({v, i}, {A, i});
|
||||
add_constraint({v, i}, {B, i});
|
||||
}
|
||||
}
|
||||
// Element-wise
|
||||
@@ -172,11 +173,6 @@ void tune::connected_components(node_t x, const std::vector<ir::metaparameter *>
|
||||
if(auto mp = dynamic_cast<ir::metaparameter*>(shape))
|
||||
params_[x.first].insert({"shape" + suffix, mp});
|
||||
}
|
||||
// if(auto range = dynamic_cast<ir::get_global_range_inst*>(x.first)){
|
||||
// unsigned ax = range->get_axis();
|
||||
// global_range_sizes_[ax] = params_[x.first].at("shape.d0");
|
||||
// num_global_ranges_ = std::max(num_global_ranges_, ax + 1);
|
||||
// }
|
||||
if(static_params_.find(x) != static_params_.end()){
|
||||
for(ir::metaparameter *mp: mps)
|
||||
mp->set_value(static_params_.at(x));
|
||||
@@ -189,21 +185,13 @@ void tune::connected_components(node_t x, const std::vector<ir::metaparameter *>
|
||||
// Collects the metaparameters of `mod` that do not have a value yet.
//
// Sources, in order:
//   1. every entry of params_[i] for each instruction i of each basic block of
//      each function in the module;
//   2. every entry of mod.globals() whose value is an ir::metaparameter.
// `seen` guarantees that a given metaparameter is examined only once across
// both sources. After that, a new metaparameter is created, stored in
// num_warps_, and appended to the result unconditionally.
//
// Returns the collected pointers in discovery order, with num_warps_ last.
std::vector<ir::metaparameter *> tune::get_params(ir::module &mod) {
|
||||
std::vector<ir::metaparameter*> result;
|
||||
std::set<ir::metaparameter*> seen;
|
||||
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i : block->get_inst_list())
|
||||
for(auto &x: params_[i])
|
||||
// insert() runs first, so a metaparameter is marked as seen even when it
// already has a value and is therefore not returned.
if(seen.insert(x.second).second && !x.second->has_value()){
|
||||
result.push_back(x.second);
|
||||
}
|
||||
|
||||
// NOTE(review): the loop header below appears twice, differing only in
// whitespace before the brace. This looks like the before/after pair of a
// diff line rather than an intentional nested loop — confirm.
for(auto x: mod.globals()){
|
||||
for(auto x: mod.globals()) {
|
||||
if(auto mp = dynamic_cast<ir::metaparameter*>(x.second))
|
||||
if(seen.insert(mp).second && !mp->has_value())
|
||||
result.push_back(mp);
|
||||
}
|
||||
|
||||
// NOTE(review): the trailing "4, 4" are presumably the lower and upper bound
// of the value range (other calls in this file pass "1, 1" or "1, 8") —
// confirm against ir::metaparameter::create.
num_warps_ = ir::metaparameter::create(mod.get_context(), mod.get_builder().get_int32_ty(), 4, 4);
|
||||
result.push_back(num_warps_);
|
||||
return result;
|
||||
}
|
||||
|
||||
@@ -212,7 +200,6 @@ std::map<std::string, ir::metaparameter *> tune::get_params(ir::instruction* i)
|
||||
}
|
||||
|
||||
// Returns the group id recorded in groups_ for axis `ax` of `value`.
// Both lookups use .at(), so a (value, ax) pair that was never recorded is not
// silently default-inserted.
// NOTE(review): presumably groups_ is a standard container whose .at() throws
// std::out_of_range on a missing key — confirm against its declaration.
unsigned tune::get_param_group(ir::value *value, unsigned ax) {
|
||||
// std::cout << "group? " << value->get_name() << " " << ax << std::endl;
|
||||
unsigned result = groups_.at(value).at(ax);
|
||||
return result;
|
||||
}
|
||||
@@ -229,139 +216,164 @@ void tune::run(ir::module &mod) {
|
||||
ir::context &ctx = mod.get_context();
|
||||
// Create metaparameters
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
|
||||
// Build constraints graph
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i : block->get_inst_list())
|
||||
if(i->has_tile_result_or_op()){
|
||||
if(i->has_tile_result_or_op())
|
||||
init_c_graph(i);
|
||||
}
|
||||
|
||||
// Build phi constraints
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i : block->get_inst_list())
|
||||
if(i->has_tile_result_or_op())
|
||||
init_c_phi(i);
|
||||
|
||||
// Layout parameters
|
||||
unsigned group_id = 0;
|
||||
for(auto x: nodes_){
|
||||
for(auto x: nodes_)
|
||||
fragments_[x] = get_fragmentation_type(x, dependencies_);
|
||||
}
|
||||
while(!nodes_.empty()) {
|
||||
ir::type *ty = mod.get_builder().get_int32_ty();
|
||||
node_t node = *nodes_.begin();
|
||||
if(fragments_[node] == STRIDED_SCAN) {
|
||||
ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 1, 1);
|
||||
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 1, 8);
|
||||
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 1, 1);
|
||||
connected_components(node, {nts, mts}, {"nts", "mts"}, nodes_, dependencies_, group_id++);
|
||||
nts->set_value(1);
|
||||
}
|
||||
else {
|
||||
ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 1, 1);
|
||||
if(node.second == 2)
|
||||
fpw->set_value(1);
|
||||
ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 1, 1);
|
||||
connected_components(node, {fpw, wpt}, {"fpw", "wpt"}, nodes_, dependencies_, group_id++);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i : block->get_inst_list()){
|
||||
if(!i->get_type()->is_tile_ty())
|
||||
continue;
|
||||
auto shapes = i->get_type()->get_tile_shapes();
|
||||
|
||||
if(auto *x = dynamic_cast<ir::load_inst*>(i))
|
||||
if(fragments_.at({i, 0}) == STRIDED_SCAN){
|
||||
ir::type *ptr_ty = x->get_pointer_operand()->get_type()->get_scalar_ty();
|
||||
size_t addr_space = ptr_ty->get_pointer_address_space();
|
||||
if(addr_space < 4){
|
||||
ir::type *ty = mod.get_builder().get_int32_ty();
|
||||
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 1, 1));
|
||||
*params_.at(i).at("nts.d0") = *tmp;
|
||||
}
|
||||
}
|
||||
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
|
||||
ir::type *ty = mod.get_builder().get_int32_ty();
|
||||
// std::unique_ptr<ir::metaparameter> mts_2(ir::metaparameter::create(ctx, ty, 1, 4));
|
||||
// *params_.at(i->get_operand(0)).at("mts.d2") = *mts_2;
|
||||
// *params_.at(i->get_operand(1)).at("mts.d2") = *mts_2;
|
||||
if(fragments_.at({i, 0}) == STRIDED_SCAN){
|
||||
std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 1, 1));
|
||||
std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 1, 1));
|
||||
*params_.at(i).at("nts.d0") = *tmp1;
|
||||
*params_.at(i).at("nts.d1") = *tmp2;
|
||||
// for(size_t k = 2; k < shapes.size(); k++)
|
||||
// if(auto *x = dynamic_cast<ir::metaparameter*>(shapes[k]))
|
||||
// *params_.at(i).at("mts.d" + std::to_string(k)) = *x;
|
||||
// else
|
||||
// params_.at(i).at("mts.d" + std::to_string(k))->set_value(shapes[k]->get_value());
|
||||
}
|
||||
else{
|
||||
// for(size_t k = 2; k < shapes.size(); k++)
|
||||
// if(auto *x = dynamic_cast<ir::metaparameter*>(shapes[k]))
|
||||
// *params_.at(i).at("wpt.d" + std::to_string(k)) = *x;
|
||||
// else
|
||||
// params_.at(i).at("wpt.d" + std::to_string(k))->set_value(shapes[k]->get_value());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void tune::init(ir::module &mod) {
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
std::map<ir::metaparameter*, ir::instruction*> references;
|
||||
std::map<unsigned, ir::value*> references;
|
||||
create_grids(grids_, references, fn);
|
||||
}
|
||||
|
||||
num_threads_ = get_req_num_threads(grids_.front());
|
||||
}
|
||||
int num_threads = get_num_threads();
|
||||
int num_warps = num_warps_->get_value();
|
||||
auto clamp = [&](int x, int lo, int hi) { return std::min(std::max(x, lo), hi); };
|
||||
|
||||
unsigned tune::get_req_num_threads(ir::instruction *i){
|
||||
if(fragments_.at({i, 0}) == STRIDED_SCAN) {
|
||||
unsigned result = 1;
|
||||
for(unsigned k = 0; k < i->get_type()->get_tile_shapes().size(); k++){
|
||||
std::string suffix = ".d" + std::to_string(k);
|
||||
result *= params_.at(i).at("mts" + suffix)->get_value();
|
||||
for(ir::value *i: grids_){
|
||||
if(!i->get_type()->is_tile_ty())
|
||||
continue;
|
||||
auto shapes = i->get_type()->get_tile_shapes();
|
||||
int shape_0 = shapes[0]->get_value();
|
||||
int shape_1 = shapes[1]->get_value();
|
||||
int size = i->get_type()->get_tile_num_elements();
|
||||
/* HMMA parameters*/
|
||||
if(fragments_.at({i, 0}) == HMMA_FRAGMENT_C){
|
||||
|
||||
/* fragments per warp */
|
||||
// try to make things as square as possible to maximize data re-use
|
||||
std::vector<int> fpw = {1, 1, 1};
|
||||
std::vector<int> fpw_nm1;
|
||||
int num_fragments = std::min((shape_0/8)*(shape_1/8), 4);
|
||||
do {
|
||||
fpw_nm1 = fpw;
|
||||
if(fpw[0]*fpw[1] < num_fragments)
|
||||
fpw[0] = clamp(fpw[0]*2, 1, shape_0 / 8);
|
||||
if(fpw[0]*fpw[1] < num_fragments)
|
||||
fpw[1] = clamp(fpw[1]*2, 1, shape_1 / 8);
|
||||
}while(fpw_nm1 != fpw);
|
||||
// store parameters
|
||||
for(int d = 0; d < shapes.size(); d++)
|
||||
params_.at(i).at("fpw.d" + std::to_string(d))->set_value(fpw[d]);
|
||||
|
||||
/* warps per tile */
|
||||
// try to make things as square as possible to maximize data re-use
|
||||
std::vector<int> wpt = {1, 1, 1};
|
||||
std::vector<int> wpt_nm1;
|
||||
do{
|
||||
wpt_nm1 = wpt;
|
||||
if(wpt[0] * wpt[1] * wpt[2] < num_warps)
|
||||
wpt[0] = clamp(wpt[0]*2, 1, shape_0 / (fpw[0]*8));
|
||||
if(wpt[0] * wpt[1] * wpt[2] < num_warps)
|
||||
wpt[1] = clamp(wpt[1]*2, 1, shape_1 / (fpw[1]*8));
|
||||
}while(wpt_nm1 != wpt);
|
||||
// store parameters
|
||||
for(int d = 0; d < shapes.size(); d++)
|
||||
params_.at(i).at("wpt.d" + std::to_string(d))->set_value(wpt[d]);
|
||||
|
||||
/* sanity check */
|
||||
unsigned effective_num_warps = 1;
|
||||
for(size_t d = 0; d < shapes.size(); d++){
|
||||
std::string str_d = std::to_string(d);
|
||||
effective_num_warps *= params_.at(i).at("wpt.d" + str_d)->get_value();
|
||||
}
|
||||
assert(num_warps == effective_num_warps);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
else {
|
||||
unsigned result = 32;
|
||||
for(unsigned k = 0; k < i->get_type()->get_tile_shapes().size(); k++){
|
||||
std::string suffix = ".d" + std::to_string(k);
|
||||
result *= params_.at(i).at("wpt" + suffix)->get_value();
|
||||
|
||||
/* Scan-line */
|
||||
else{
|
||||
int shape = shapes[0]->get_value();
|
||||
int current = num_threads;
|
||||
params_.at(i).at("nts.d0")->set_value(clamp(size / num_threads, 1, 8));
|
||||
params_.at(i).at("mts.d0")->set_value(clamp(current, 1, shape / params_.at(i).at("nts.d0")->get_value()));
|
||||
current = current / params_.at(i).at("mts.d0")->get_value();
|
||||
for(size_t d = 1; d < shapes.size(); d++){
|
||||
std::string str_d = std::to_string(d);
|
||||
shape = shapes[d]->get_value();
|
||||
params_.at(i).at("nts.d" + str_d)->set_value(1);
|
||||
params_.at(i).at("mts.d" + str_d)->set_value(clamp(current, 1, shape));
|
||||
current = current / params_.at(i).at("mts.d" + str_d)->get_value();
|
||||
}
|
||||
/* sanity check */
|
||||
unsigned effective_num_threads = 1;
|
||||
for(size_t d = 0; d < shapes.size(); d++){
|
||||
std::string str_d = std::to_string(d);
|
||||
effective_num_threads *= params_.at(i).at("mts.d" + str_d)->get_value();
|
||||
}
|
||||
assert(num_threads == effective_num_threads);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
void tune::create_grids(std::vector<ir::instruction*> &grids,
|
||||
std::map<ir::metaparameter*, ir::instruction*> &references,
|
||||
ir::function *fn) {
|
||||
|
||||
void tune::create_grids(std::vector<ir::value*> &grids,
|
||||
std::map<unsigned, ir::value*> &references,
|
||||
ir::function *fn) {
|
||||
// get number of dimensions greater than 1
|
||||
auto get_tile_gt1_dim = [&](ir::value *v){
|
||||
unsigned result = 0;
|
||||
auto one = ir::tile_type::make_one(fn->get_fn_type()->get_context());
|
||||
for(ir::constant_int *shape: v->get_type()->get_tile_shapes()) {
|
||||
result += (shape != one);
|
||||
for(ir::constant_int* shape: v->get_type()->get_tile_shapes()) {
|
||||
result += (shape->get_value() > 1)?shape->get_value():0;
|
||||
}
|
||||
return result;
|
||||
};
|
||||
// bind references
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
if(!i->get_type()->is_tile_ty())
|
||||
continue;
|
||||
for(auto &param: params_.at(i)){
|
||||
if(param.second->get_value() == 1)
|
||||
std::set<ir::value*> seen;
|
||||
std::function<void(ir::value*)> bind_references = [&](ir::value *v)
|
||||
{
|
||||
// skip
|
||||
if(!v->get_type()->is_tile_ty() || !seen.insert(v).second)
|
||||
return;
|
||||
// recurse
|
||||
if(auto *user = dynamic_cast<ir::user*>(v))
|
||||
for(ir::value *op: user->ops())
|
||||
bind_references(op);
|
||||
// bind
|
||||
const auto& shapes = v->get_type()->get_tile_shapes();
|
||||
for(size_t d = 0; d < shapes.size(); d++){
|
||||
if(shapes[d]->get_value() == 1)
|
||||
continue;
|
||||
ir::instruction *&r = references[param.second];
|
||||
if(!r || get_tile_gt1_dim(i) > get_tile_gt1_dim(r))
|
||||
r = i;
|
||||
unsigned x = get_param_group(v, d);
|
||||
ir::value *&r = references[x];
|
||||
if(!r || get_tile_gt1_dim(v) > get_tile_gt1_dim(r))
|
||||
r = v;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i: block->get_inst_list())
|
||||
bind_references(i);
|
||||
|
||||
// create grid
|
||||
for(auto &ref: references)
|
||||
if(std::find(grids.begin(), grids.end(), ref.second) == grids.end())
|
||||
@@ -370,85 +382,11 @@ void tune::create_grids(std::vector<ir::instruction*> &grids,
|
||||
|
||||
|
||||
// Validates the layout parameters of every instruction in grids_.
//
// For each violated constraint a human-readable message is appended to
// errors[i], keyed by the offending instruction i. Checks performed per tile:
//   - for every dimension, the shape must be a multiple of the layout size;
//   - for HMMA_FRAGMENT_C tiles, the product of the fpw values must not exceed 4;
//   - the required thread count must be a multiple of 32 and equal num_threads_.
//
// Returns true when `errors` is empty on exit. Entries that were already in
// `errors` before the call therefore also make the result false.
bool tune::check_constraints(std::map<ir::value *, std::vector<std::string>> &errors) {
|
||||
using std::to_string;
|
||||
|
||||
// Per-axis factor used for the warp count below: shape / (mts * nts) for
// STRIDED_SCAN fragments, otherwise the wpt value of that axis.
auto get_num_warps = [&](ir::instruction *i, unsigned axis) {
|
||||
std::string strk = to_string(axis);
|
||||
if(fragments_.at({i, axis}) == STRIDED_SCAN){
|
||||
unsigned mts = params_[i]["mts.d" + strk]->get_value();
|
||||
unsigned nts = params_[i]["nts.d" + strk]->get_value();
|
||||
unsigned shape = i->get_type()->get_tile_shapes()[axis]->get_value();
|
||||
return shape / (mts * nts);
|
||||
}
|
||||
else{
|
||||
return (unsigned)params_[i]["wpt.d" + strk]->get_value();
|
||||
}
|
||||
};
|
||||
|
||||
// number of warps
// NOTE(review): num_warps is accumulated from the first grid entry but is not
// read again anywhere in this function.
|
||||
ir::instruction *first = grids_.front();
|
||||
int num_warps = 1;
|
||||
for(size_t k = 0; k < first->get_type()->get_tile_shapes().size(); k++)
|
||||
num_warps *= get_num_warps(first, k);
|
||||
|
||||
// check constraints
|
||||
for(ir::instruction *i: grids_){
|
||||
// std::cout << i->get_name() << std::endl;
|
||||
ir::type *ty = i->get_type();
|
||||
const auto &shapes = ty->get_tile_shapes();
|
||||
// for each dimension, the product of layout components
|
||||
// must divide the shape
|
||||
for(size_t k = 0; k < shapes.size(); k++) {
|
||||
std::string strk = to_string(k);
|
||||
unsigned multiple;
|
||||
// The fragment kind is looked up on axis 0 for every k, unlike the
// get_num_warps lambda above, which looks it up per axis.
if(fragments_.at({i, 0}) == STRIDED_SCAN) {
|
||||
ir::metaparameter *mts = params_[i]["mts.d" + strk];
|
||||
ir::metaparameter *nts = params_[i]["nts.d" + strk];
|
||||
multiple = mts->get_value()*nts->get_value();
|
||||
}
|
||||
else {
|
||||
ir::metaparameter *fpw = params_[i]["fpw.d" + strk];
|
||||
ir::metaparameter *wpt = params_[i]["wpt.d" + strk];
|
||||
multiple = fpw->get_value()*wpt->get_value();
|
||||
// Only the first two dimensions get the extra factor of 8.
// NOTE(review): presumably 8 is the per-fragment extent along those
// dimensions — confirm.
if(k < 2)
|
||||
multiple *= 8;
|
||||
}
|
||||
if(shapes[k]->get_value() % multiple != 0)
|
||||
errors[i].push_back("for dim " + strk + ": shape (" + to_string(shapes[k]->get_value()) + ")"
|
||||
" is not a multiple of layout (" + to_string(multiple) + ")");
|
||||
}
|
||||
// the product of mma fragments per warp must not exceed 4
|
||||
if(fragments_.at({i, 0}) == HMMA_FRAGMENT_C){
|
||||
unsigned prod = 1;
|
||||
for(size_t k = 0; k < shapes.size(); k++){
|
||||
prod *= params_[i]["fpw.d" + std::to_string(k)]->get_value();
|
||||
}
|
||||
if(prod > 4)
|
||||
errors[i].push_back("HMMA must have only 4 fragments per warp");
|
||||
}
|
||||
// Thread-count checks: must be a multiple of 32 and must match the count
// stored in num_threads_.
int num_threads = get_req_num_threads(i);
|
||||
if(num_threads % 32 != 0)
|
||||
errors[i].push_back("number of threads per block (" + to_string(num_threads) + ") must be multiple of warp size");
|
||||
if(num_threads != num_threads_)
|
||||
errors[i].push_back("Number of threads must be the same for all tiles (" + to_string(num_threads_) + ")");
|
||||
}
|
||||
// for(auto x: errors)
|
||||
// for(auto e: x.second)
|
||||
// std::cout << x.first->get_name() << ": " << e << std::endl;
|
||||
// exit(EXIT_SUCCESS);
|
||||
return errors.empty();
|
||||
}
|
||||
|
||||
// Returns the current value of num_global_ranges_.
unsigned tune::get_num_global_range() {
|
||||
return num_global_ranges_;
|
||||
}
|
||||
|
||||
// Returns the value of the entry stored in global_range_sizes_ for `axis`.
// The lookup uses .at(), so an axis with no recorded entry is not
// default-inserted.
unsigned tune::get_global_range_size(unsigned axis) {
|
||||
return global_range_sizes_.at(axis)->get_value();
|
||||
}
|
||||
|
||||
// Returns the number of threads.
// NOTE(review): two return statements appear in sequence. As written, the first
// (num_threads_) is the one that executes and the second (num_warps_ value
// times 32) is unreachable. This looks like the before/after pair of a diff
// line — confirm which one is current.
unsigned tune::get_num_threads() {
|
||||
return num_threads_;
|
||||
return num_warps_->get_value()*32;
|
||||
}
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user