[jit] can now infer launch parameters from triton module
This commit is contained in:
@@ -12,6 +12,8 @@
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
|
||||
tune::tune(): num_global_ranges_(0){ }
|
||||
|
||||
void tune::add_constraint(node_t x, node_t y) {
|
||||
dependencies_[x].insert(y);
|
||||
dependencies_[y].insert(x);
|
||||
@@ -91,6 +93,11 @@ void tune::connected_components(node_t x, const std::vector<ir::metaparameter *>
|
||||
if(auto mp = dynamic_cast<ir::metaparameter*>(shape))
|
||||
params_[x.first].insert({"shape" + suffix, mp});
|
||||
}
|
||||
if(auto range = dynamic_cast<ir::get_global_range_inst*>(x.first)){
|
||||
unsigned ax = range->get_axis();
|
||||
global_range_sizes_[ax] = params_[x.first].at("shape.d0");
|
||||
num_global_ranges_ = std::max(num_global_ranges_, ax + 1);
|
||||
}
|
||||
if(static_params_.find(x) != static_params_.end()){
|
||||
mps[0]->set_value(static_params_.at(x));
|
||||
mps[1]->set_value(static_params_.at(x));
|
||||
@@ -122,6 +129,7 @@ std::map<std::string, ir::metaparameter *> tune::get_params(ir::instruction* i)
|
||||
|
||||
void tune::run(ir::module &mod) {
|
||||
ir::context &ctx = mod.get_context();
|
||||
// Create metaparameters
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
// Build constraints graph
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
@@ -143,6 +151,19 @@ void tune::run(ir::module &mod) {
|
||||
connected_components(*nodes_.begin(), {mp0, mp1, mp2}, nodes_, dependencies_);
|
||||
}
|
||||
}
|
||||
|
||||
// // Get launch info
|
||||
// for(ir::function *fn: mod.get_function_list()){
|
||||
// std::map<ir::metaparameter*, ir::instruction*> references;
|
||||
// std::vector<ir::instruction*> grids;
|
||||
// create_grids(grids, references, fn);
|
||||
// ir::instruction *first = grids.front();
|
||||
// for(unsigned i = 0; i < first->get_type()->get_tile_shapes().size(); i++){
|
||||
// std::string suffix = ".d" + std::to_string(i);
|
||||
// num_threads_mp_vec_.push_back(params_.at(first).at("p1" + suffix));
|
||||
// num_threads_mp_vec_.push_back(params_.at(first).at("p2" + suffix));
|
||||
// }
|
||||
// }
|
||||
}
|
||||
|
||||
void tune::create_grids(std::vector<ir::instruction*> &grids,
|
||||
@@ -186,6 +207,12 @@ for(ir::function *fn: mod.get_function_list()){
|
||||
std::vector<ir::instruction*> grids;
|
||||
create_grids(grids, references, fn);
|
||||
|
||||
for(unsigned i = 0; i < grids.front()->get_type()->get_tile_shapes().size(); i++){
|
||||
std::string suffix = ".d" + std::to_string(i);
|
||||
num_threads_mp_vec_.push_back(params_.at(grids.front()).at("p1" + suffix));
|
||||
num_threads_mp_vec_.push_back(params_.at(grids.front()).at("p2" + suffix));
|
||||
}
|
||||
|
||||
// number of warps
|
||||
int num_warps = 1;
|
||||
for(size_t k = 0; k < grids.front()->get_type()->get_tile_shapes().size(); k++)
|
||||
@@ -225,5 +252,21 @@ for(ir::function *fn: mod.get_function_list()){
|
||||
}
|
||||
}
|
||||
|
||||
unsigned tune::get_num_global_range() {
|
||||
return num_global_ranges_;
|
||||
}
|
||||
|
||||
unsigned tune::get_global_range_size(unsigned axis) {
|
||||
return global_range_sizes_.at(axis)->get_value();
|
||||
}
|
||||
|
||||
unsigned tune::get_num_threads() {
|
||||
unsigned result = 1;
|
||||
for(ir::metaparameter *mp: num_threads_mp_vec_)
|
||||
result *= mp->get_value();
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user