[jit] can now infer launch parameters from triton module

This commit is contained in:
Philippe Tillet
2019-03-09 14:44:13 -05:00
parent b721202812
commit 9a3537662d
5 changed files with 71 additions and 7 deletions

View File

@@ -12,6 +12,8 @@
namespace triton{
namespace codegen{
tune::tune(): num_global_ranges_(0){ }
void tune::add_constraint(node_t x, node_t y) {
dependencies_[x].insert(y);
dependencies_[y].insert(x);
@@ -91,6 +93,11 @@ void tune::connected_components(node_t x, const std::vector<ir::metaparameter *>
if(auto mp = dynamic_cast<ir::metaparameter*>(shape))
params_[x.first].insert({"shape" + suffix, mp});
}
if(auto range = dynamic_cast<ir::get_global_range_inst*>(x.first)){
unsigned ax = range->get_axis();
global_range_sizes_[ax] = params_[x.first].at("shape.d0");
num_global_ranges_ = std::max(num_global_ranges_, ax + 1);
}
if(static_params_.find(x) != static_params_.end()){
mps[0]->set_value(static_params_.at(x));
mps[1]->set_value(static_params_.at(x));
@@ -122,6 +129,7 @@ std::map<std::string, ir::metaparameter *> tune::get_params(ir::instruction* i)
void tune::run(ir::module &mod) {
ir::context &ctx = mod.get_context();
// Create metaparameters
for(ir::function *fn: mod.get_function_list()){
// Build constraints graph
for(ir::basic_block *block: fn->blocks())
@@ -143,6 +151,19 @@ void tune::run(ir::module &mod) {
connected_components(*nodes_.begin(), {mp0, mp1, mp2}, nodes_, dependencies_);
}
}
// // Get launch info
// for(ir::function *fn: mod.get_function_list()){
// std::map<ir::metaparameter*, ir::instruction*> references;
// std::vector<ir::instruction*> grids;
// create_grids(grids, references, fn);
// ir::instruction *first = grids.front();
// for(unsigned i = 0; i < first->get_type()->get_tile_shapes().size(); i++){
// std::string suffix = ".d" + std::to_string(i);
// num_threads_mp_vec_.push_back(params_.at(first).at("p1" + suffix));
// num_threads_mp_vec_.push_back(params_.at(first).at("p2" + suffix));
// }
// }
}
void tune::create_grids(std::vector<ir::instruction*> &grids,
@@ -186,6 +207,12 @@ for(ir::function *fn: mod.get_function_list()){
std::vector<ir::instruction*> grids;
create_grids(grids, references, fn);
for(unsigned i = 0; i < grids.front()->get_type()->get_tile_shapes().size(); i++){
std::string suffix = ".d" + std::to_string(i);
num_threads_mp_vec_.push_back(params_.at(grids.front()).at("p1" + suffix));
num_threads_mp_vec_.push_back(params_.at(grids.front()).at("p2" + suffix));
}
// number of warps
int num_warps = 1;
for(size_t k = 0; k < grids.front()->get_type()->get_tile_shapes().size(); k++)
@@ -225,5 +252,21 @@ for(ir::function *fn: mod.get_function_list()){
}
}
unsigned tune::get_num_global_range() {
return num_global_ranges_;
}
unsigned tune::get_global_range_size(unsigned axis) {
return global_range_sizes_.at(axis)->get_value();
}
unsigned tune::get_num_threads() {
unsigned result = 1;
for(ir::metaparameter *mp: num_threads_mp_vec_)
result *= mp->get_value();
return result;
}
}
}