[jit] can now infer launch parameters from triton module

This commit is contained in:
Philippe Tillet
2019-03-09 14:44:13 -05:00
parent b721202812
commit 9a3537662d
5 changed files with 71 additions and 7 deletions

View File

@@ -76,14 +76,13 @@ int main() {
// b1
1, 8, 1
};
unsigned TM = params[6];
unsigned TN = params[10];
unsigned nthreads = params[1]*params[2]*params[15]*params[16];
auto context = triton::driver::backend::contexts::get_default();
triton::jit jit(context);
jit.add_module(src, params);
triton::driver::kernel kernel = jit.get_function("matmul");
triton::jit::launch_information info = jit.get_launch_info("matmul");
size_t M = 128, N = 128, K = 128;
size_t bound = 8;
@@ -112,6 +111,9 @@ int main() {
kernel.setArg(4, N);
kernel.setArg(5, K);
kernel.setArg(6, bound);
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
stream.enqueue(kernel, {(M + TM - 1)/TM, (N + TN - 1)/TN, 1}, {nthreads, 1, 1});
stream.synchronize();
stream.read(dc, true, 0, hc);

View File

@@ -30,14 +30,16 @@ private:
public:
tune();
std::vector<ir::metaparameter *> get_params(ir::module& mod);
std::map<std::string, ir::metaparameter *> get_params(ir::instruction* i);
ir::metaparameter* get_param(ir::value *value, const std::string &key) { return params_[value][key]; }
void copy(ir::value *dst, ir::value *src) { params_[dst] = params_[src]; }
bool check_constraints(ir::module &fn, std::map<ir::value *, std::vector<std::string>> &errors);
void run(ir::module &mod);
ir::metaparameter* get_num_threads();
ir::metaparameter* get_global_range_size(unsigned axis);
unsigned get_num_global_range();
unsigned get_global_range_size(unsigned axis);
unsigned get_num_threads();
private:
std::vector<unsigned*> pool_;
@@ -45,8 +47,9 @@ private:
std::set<node_t> nodes_;
std::map<node_t, unsigned> static_params_;
std::map<ir::value*, std::map<std::string, ir::metaparameter*>> params_;
ir::metaparameter *num_threads_;
std::vector<ir::metaparameter*> global_range_sizes_;
std::vector<ir::metaparameter*> num_threads_mp_vec_;
std::map<unsigned, ir::metaparameter*> global_range_sizes_;
unsigned num_global_ranges_;
};

View File

@@ -20,6 +20,12 @@ class context;
}
class jit {
public:
struct launch_information{
std::vector<unsigned> global_range_size;
unsigned num_threads;
};
private:
void init_llvm();
std::string compute_data_layout(bool is64Bit = true, bool UseShortPointers = true);
@@ -31,12 +37,14 @@ public:
void add_module(ir::module &module, const std::vector<unsigned>& params = {});
void add_module(const std::string &src, const std::vector<unsigned>& params = {});
driver::kernel get_function(const std::string &name);
launch_information get_launch_info(const std::string &name);
private:
std::vector<driver::module> modules_;
driver::context driver_context_;
llvm::LLVMContext llvm_context_;
ir::context triton_context_;
std::map<std::string, launch_information> launch_info_map_;
};

View File

@@ -12,6 +12,8 @@
namespace triton{
namespace codegen{
tune::tune(): num_global_ranges_(0){ }
void tune::add_constraint(node_t x, node_t y) {
dependencies_[x].insert(y);
dependencies_[y].insert(x);
@@ -91,6 +93,11 @@ void tune::connected_components(node_t x, const std::vector<ir::metaparameter *>
if(auto mp = dynamic_cast<ir::metaparameter*>(shape))
params_[x.first].insert({"shape" + suffix, mp});
}
if(auto range = dynamic_cast<ir::get_global_range_inst*>(x.first)){
unsigned ax = range->get_axis();
global_range_sizes_[ax] = params_[x.first].at("shape.d0");
num_global_ranges_ = std::max(num_global_ranges_, ax + 1);
}
if(static_params_.find(x) != static_params_.end()){
mps[0]->set_value(static_params_.at(x));
mps[1]->set_value(static_params_.at(x));
@@ -122,6 +129,7 @@ std::map<std::string, ir::metaparameter *> tune::get_params(ir::instruction* i)
void tune::run(ir::module &mod) {
ir::context &ctx = mod.get_context();
// Create metaparameters
for(ir::function *fn: mod.get_function_list()){
// Build constraints graph
for(ir::basic_block *block: fn->blocks())
@@ -143,6 +151,19 @@ void tune::run(ir::module &mod) {
connected_components(*nodes_.begin(), {mp0, mp1, mp2}, nodes_, dependencies_);
}
}
// // Get launch info
// for(ir::function *fn: mod.get_function_list()){
// std::map<ir::metaparameter*, ir::instruction*> references;
// std::vector<ir::instruction*> grids;
// create_grids(grids, references, fn);
// ir::instruction *first = grids.front();
// for(unsigned i = 0; i < first->get_type()->get_tile_shapes().size(); i++){
// std::string suffix = ".d" + std::to_string(i);
// num_threads_mp_vec_.push_back(params_.at(first).at("p1" + suffix));
// num_threads_mp_vec_.push_back(params_.at(first).at("p2" + suffix));
// }
// }
}
void tune::create_grids(std::vector<ir::instruction*> &grids,
@@ -186,6 +207,12 @@ for(ir::function *fn: mod.get_function_list()){
std::vector<ir::instruction*> grids;
create_grids(grids, references, fn);
for(unsigned i = 0; i < grids.front()->get_type()->get_tile_shapes().size(); i++){
std::string suffix = ".d" + std::to_string(i);
num_threads_mp_vec_.push_back(params_.at(grids.front()).at("p1" + suffix));
num_threads_mp_vec_.push_back(params_.at(grids.front()).at("p2" + suffix));
}
// number of warps
int num_warps = 1;
for(size_t k = 0; k < grids.front()->get_type()->get_tile_shapes().size(); k++)
@@ -225,5 +252,21 @@ for(ir::function *fn: mod.get_function_list()){
}
}
unsigned tune::get_num_global_range() {
return num_global_ranges_;
}
unsigned tune::get_global_range_size(unsigned axis) {
return global_range_sizes_.at(axis)->get_value();
}
unsigned tune::get_num_threads() {
unsigned result = 1;
for(ir::metaparameter *mp: num_threads_mp_vec_)
result *= mp->get_value();
return result;
}
}
}

View File

@@ -85,6 +85,11 @@ std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, const st
vectorize.run(module);
selection.run(module, *result);
// launch information
auto &launch_info_map = launch_info_map_[result->getName()];
for(unsigned i = 0; i < tune.get_num_global_range(); i++)
launch_info_map.global_range_size.push_back(tune.get_global_range_size(i));
launch_info_map.num_threads = tune.get_num_threads();
return std::unique_ptr<llvm::Module>(result);
}
@@ -145,5 +150,8 @@ driver::kernel jit::get_function(const std::string &name) {
return driver::kernel(modules_.front(), name.c_str());
}
jit::launch_information jit::get_launch_info(const std::string &name) {
return launch_info_map_.at(name);
}
}