[jit] can now infer launch parameters from triton module
This commit is contained in:
@@ -76,14 +76,13 @@ int main() {
|
||||
// b1
|
||||
1, 8, 1
|
||||
};
|
||||
unsigned TM = params[6];
|
||||
unsigned TN = params[10];
|
||||
unsigned nthreads = params[1]*params[2]*params[15]*params[16];
|
||||
|
||||
|
||||
auto context = triton::driver::backend::contexts::get_default();
|
||||
triton::jit jit(context);
|
||||
jit.add_module(src, params);
|
||||
triton::driver::kernel kernel = jit.get_function("matmul");
|
||||
triton::jit::launch_information info = jit.get_launch_info("matmul");
|
||||
|
||||
size_t M = 128, N = 128, K = 128;
|
||||
size_t bound = 8;
|
||||
@@ -112,6 +111,9 @@ int main() {
|
||||
kernel.setArg(4, N);
|
||||
kernel.setArg(5, K);
|
||||
kernel.setArg(6, bound);
|
||||
unsigned TM = info.global_range_size[0];
|
||||
unsigned TN = info.global_range_size[1];
|
||||
unsigned nthreads = info.num_threads;
|
||||
stream.enqueue(kernel, {(M + TM - 1)/TM, (N + TN - 1)/TN, 1}, {nthreads, 1, 1});
|
||||
stream.synchronize();
|
||||
stream.read(dc, true, 0, hc);
|
||||
|
@@ -30,14 +30,16 @@ private:
|
||||
|
||||
|
||||
public:
|
||||
tune();
|
||||
std::vector<ir::metaparameter *> get_params(ir::module& mod);
|
||||
std::map<std::string, ir::metaparameter *> get_params(ir::instruction* i);
|
||||
// Returns the metaparameter stored for `value` under the name `key`.
// Both lookups go through std::map::operator[], so an unknown value or key is
// default-inserted into params_ and a null pointer is returned for it.
ir::metaparameter* get_param(ir::value *value, const std::string &key) {
  auto &named_params = params_[value];
  return named_params[key];
}
|
||||
// Overwrites the named metaparameters of `dst` with a copy of those of `src`.
// std::map::operator[] creates an empty entry for either value that has none yet.
void copy(ir::value *dst, ir::value *src) {
  const auto &src_params = params_[src];
  params_[dst] = src_params;
}
|
||||
bool check_constraints(ir::module &fn, std::map<ir::value *, std::vector<std::string>> &errors);
|
||||
void run(ir::module &mod);
|
||||
ir::metaparameter* get_num_threads();
|
||||
ir::metaparameter* get_global_range_size(unsigned axis);
|
||||
unsigned get_num_global_range();
|
||||
unsigned get_global_range_size(unsigned axis);
|
||||
unsigned get_num_threads();
|
||||
|
||||
private:
|
||||
std::vector<unsigned*> pool_;
|
||||
@@ -45,8 +47,9 @@ private:
|
||||
std::set<node_t> nodes_;
|
||||
std::map<node_t, unsigned> static_params_;
|
||||
std::map<ir::value*, std::map<std::string, ir::metaparameter*>> params_;
|
||||
ir::metaparameter *num_threads_;
|
||||
std::vector<ir::metaparameter*> global_range_sizes_;
|
||||
std::vector<ir::metaparameter*> num_threads_mp_vec_;
|
||||
std::map<unsigned, ir::metaparameter*> global_range_sizes_;
|
||||
unsigned num_global_ranges_;
|
||||
};
|
||||
|
||||
|
||||
|
@@ -20,6 +20,12 @@ class context;
|
||||
}
|
||||
|
||||
class jit {
|
||||
public:
|
||||
// Launch parameters recorded by the JIT for a compiled module.
struct launch_information{
  // One entry per global-range axis, in axis order (axis 0 first).
  std::vector<unsigned> global_range_size;
  // Number of threads to launch, as reported by the tuner.
  unsigned num_threads;
};
|
||||
|
||||
private:
|
||||
void init_llvm();
|
||||
std::string compute_data_layout(bool is64Bit = true, bool UseShortPointers = true);
|
||||
@@ -31,12 +37,14 @@ public:
|
||||
void add_module(ir::module &module, const std::vector<unsigned>& params = {});
|
||||
void add_module(const std::string &src, const std::vector<unsigned>& params = {});
|
||||
driver::kernel get_function(const std::string &name);
|
||||
launch_information get_launch_info(const std::string &name);
|
||||
|
||||
private:
|
||||
std::vector<driver::module> modules_;
|
||||
driver::context driver_context_;
|
||||
llvm::LLVMContext llvm_context_;
|
||||
ir::context triton_context_;
|
||||
std::map<std::string, launch_information> launch_info_map_;
|
||||
};
|
||||
|
||||
|
||||
|
@@ -12,6 +12,8 @@
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
|
||||
// Default constructor: no global range has been recorded yet.
tune::tune()
  : num_global_ranges_(0) {
}
|
||||
|
||||
void tune::add_constraint(node_t x, node_t y) {
|
||||
dependencies_[x].insert(y);
|
||||
dependencies_[y].insert(x);
|
||||
@@ -91,6 +93,11 @@ void tune::connected_components(node_t x, const std::vector<ir::metaparameter *>
|
||||
if(auto mp = dynamic_cast<ir::metaparameter*>(shape))
|
||||
params_[x.first].insert({"shape" + suffix, mp});
|
||||
}
|
||||
if(auto range = dynamic_cast<ir::get_global_range_inst*>(x.first)){
|
||||
unsigned ax = range->get_axis();
|
||||
global_range_sizes_[ax] = params_[x.first].at("shape.d0");
|
||||
num_global_ranges_ = std::max(num_global_ranges_, ax + 1);
|
||||
}
|
||||
if(static_params_.find(x) != static_params_.end()){
|
||||
mps[0]->set_value(static_params_.at(x));
|
||||
mps[1]->set_value(static_params_.at(x));
|
||||
@@ -122,6 +129,7 @@ std::map<std::string, ir::metaparameter *> tune::get_params(ir::instruction* i)
|
||||
|
||||
void tune::run(ir::module &mod) {
|
||||
ir::context &ctx = mod.get_context();
|
||||
// Create metaparameters
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
// Build constraints graph
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
@@ -143,6 +151,19 @@ void tune::run(ir::module &mod) {
|
||||
connected_components(*nodes_.begin(), {mp0, mp1, mp2}, nodes_, dependencies_);
|
||||
}
|
||||
}
|
||||
|
||||
// // Get launch info
|
||||
// for(ir::function *fn: mod.get_function_list()){
|
||||
// std::map<ir::metaparameter*, ir::instruction*> references;
|
||||
// std::vector<ir::instruction*> grids;
|
||||
// create_grids(grids, references, fn);
|
||||
// ir::instruction *first = grids.front();
|
||||
// for(unsigned i = 0; i < first->get_type()->get_tile_shapes().size(); i++){
|
||||
// std::string suffix = ".d" + std::to_string(i);
|
||||
// num_threads_mp_vec_.push_back(params_.at(first).at("p1" + suffix));
|
||||
// num_threads_mp_vec_.push_back(params_.at(first).at("p2" + suffix));
|
||||
// }
|
||||
// }
|
||||
}
|
||||
|
||||
void tune::create_grids(std::vector<ir::instruction*> &grids,
|
||||
@@ -186,6 +207,12 @@ for(ir::function *fn: mod.get_function_list()){
|
||||
std::vector<ir::instruction*> grids;
|
||||
create_grids(grids, references, fn);
|
||||
|
||||
for(unsigned i = 0; i < grids.front()->get_type()->get_tile_shapes().size(); i++){
|
||||
std::string suffix = ".d" + std::to_string(i);
|
||||
num_threads_mp_vec_.push_back(params_.at(grids.front()).at("p1" + suffix));
|
||||
num_threads_mp_vec_.push_back(params_.at(grids.front()).at("p2" + suffix));
|
||||
}
|
||||
|
||||
// number of warps
|
||||
int num_warps = 1;
|
||||
for(size_t k = 0; k < grids.front()->get_type()->get_tile_shapes().size(); k++)
|
||||
@@ -225,5 +252,21 @@ for(ir::function *fn: mod.get_function_list()){
|
||||
}
|
||||
}
|
||||
|
||||
unsigned tune::get_num_global_range() {
|
||||
return num_global_ranges_;
|
||||
}
|
||||
|
||||
unsigned tune::get_global_range_size(unsigned axis) {
|
||||
return global_range_sizes_.at(axis)->get_value();
|
||||
}
|
||||
|
||||
unsigned tune::get_num_threads() {
|
||||
unsigned result = 1;
|
||||
for(ir::metaparameter *mp: num_threads_mp_vec_)
|
||||
result *= mp->get_value();
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
@@ -85,6 +85,11 @@ std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, const st
|
||||
vectorize.run(module);
|
||||
selection.run(module, *result);
|
||||
|
||||
// launch information
|
||||
auto &launch_info_map = launch_info_map_[result->getName()];
|
||||
for(unsigned i = 0; i < tune.get_num_global_range(); i++)
|
||||
launch_info_map.global_range_size.push_back(tune.get_global_range_size(i));
|
||||
launch_info_map.num_threads = tune.get_num_threads();
|
||||
return std::unique_ptr<llvm::Module>(result);
|
||||
}
|
||||
|
||||
@@ -145,5 +150,8 @@ driver::kernel jit::get_function(const std::string &name) {
|
||||
return driver::kernel(modules_.front(), name.c_str());
|
||||
}
|
||||
|
||||
// Returns a copy of the launch information stored under `name`.
// The lookup uses std::map::at, so it throws std::out_of_range when nothing
// was recorded under that name.
// NOTE(review): make_llvm_module stores entries under result->getName() --
// confirm that key matches the name callers pass here.
jit::launch_information jit::get_launch_info(const std::string &name) {
  const launch_information &info = launch_info_map_.at(name);
  return info;
}
|
||||
|
||||
}
|
||||
|
Reference in New Issue
Block a user