[jit/autotune] added support for multi-threaded auto-tuning

This commit is contained in:
Philippe Tillet
2019-07-14 21:54:57 -07:00
parent 3e7a3ed67a
commit 3c128fc2e2
3 changed files with 81 additions and 58 deletions

View File

@@ -14,13 +14,13 @@ int main() {
// initialize default compute device // initialize default compute device
auto context = triton::driver::backend::contexts::get_default(); auto context = triton::driver::backend::contexts::get_default();
auto op = triton::dnn::shift::BPROP; auto op = triton::dnn::shift::FPROP;
// initialization // initialization
int32_t R = 3, S = 3; int32_t R = 3, S = 3;
int32_t B = 16, F = 4096; int32_t B = 16, F = 512;
int32_t H = 16, W = 16; int32_t H = 16, W = 16;
int32_t C = 4096; int32_t C = 512;
// random shifts // random shifts
std::vector<int32_t> shift_h(C); std::vector<int32_t> shift_h(C);
@@ -66,7 +66,7 @@ int main() {
stream->write(db, true, 0, hb); stream->write(db, true, 0, hb);
stream->write(dc, true, 0, hc); stream->write(dc, true, 0, hc);
stream->synchronize(); stream->synchronize();
shift.enqueue(stream, {da, db, dc}); shift.enqueue(stream, {da, db, dc}, true);
// stream->read(dc, true, 0, hc); // stream->read(dc, true, 0, hc);
// shift.cpu_ref(rc.data(), ha.data(), hb.data()); // shift.cpu_ref(rc.data(), ha.data(), hb.data());
// for(size_t i = 0; i < hc.size(); i++) // for(size_t i = 0; i < hc.size(); i++)

View File

@@ -28,6 +28,10 @@ namespace llvm {
namespace triton { namespace triton {
namespace lang{
class translation_unit;
}
namespace codegen{ namespace codegen{
class tune; class tune;
} }
@@ -97,8 +101,9 @@ public:
private: private:
std::string compute_data_layout(bool is_64bit = true, bool use_short_pointers = true); std::string compute_data_layout(bool is_64bit = true, bool use_short_pointers = true);
std::unique_ptr<llvm::Module> make_llvm_module(triton::ir::module &module, passes_wrapper &passes); std::unique_ptr<llvm::Module> make_llvm_module(triton::ir::module &module, passes_wrapper &passes, llvm::LLVMContext &context, launch_information &info);
std::unique_ptr<ir::module> make_triton_module(const char* name, const char* src); std::unique_ptr<ir::module> make_triton_module(const char *name, triton::ir::context &context, triton::lang::translation_unit *program);
triton::lang::translation_unit *parse_program(const char *name, const char *src);
public: public:
jit(driver::context* context); jit(driver::context* context);

View File

@@ -19,6 +19,8 @@
#include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/LegacyPassManager.h"
#include "llvm/Transforms/Scalar/EarlyCSE.h" #include "llvm/Transforms/Scalar/EarlyCSE.h"
#include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/LoopPass.h"
#include "triton/tools/thread_pool.h"
#include <mutex>
typedef struct yy_buffer_state * YY_BUFFER_STATE; typedef struct yy_buffer_state * YY_BUFFER_STATE;
extern int yyparse(); extern int yyparse();
@@ -28,14 +30,19 @@ extern triton::lang::translation_unit *ast_root;
namespace triton { namespace triton {
void loop_nest(std::vector<size_t> const & ranges, std::function<void(std::vector<size_t> const &)> const & f){ void loop_nest(std::vector<size_t> const & ranges,
std::function<void(std::vector<size_t> const &)> const & f,
size_t nthreads){
size_t D = ranges.size(); size_t D = ranges.size();
std::vector<size_t> values(D, 0); std::vector<size_t> values(D, 0);
// thread pools
nbsdx::concurrent::thread_pool pool(nthreads);
// Start with innermost loop // Start with innermost loop
size_t i = D - 1; size_t i = D - 1;
size_t current = 0;
while(true){ while(true){
//Execute function //Execute function
f(values); pool.add_job([values, &f](){ f(values); });
//Increment counters //Increment counters
while(values[i]++ == ranges[i] - 1){ while(values[i]++ == ranges[i] - 1){
if(i == 0) if(i == 0)
@@ -47,7 +54,7 @@ void loop_nest(std::vector<size_t> const & ranges, std::function<void(std::vecto
} }
template<class T> template<class T>
void loop_nest(std::vector<std::vector<T>> const & iterates, std::function<void(std::vector<T>)> const & f){ void loop_nest(std::vector<std::vector<T>> const & iterates, std::function<void(std::vector<T>)> const & f, size_t nthreads){
//Ranges to iterate over //Ranges to iterate over
std::vector<size_t> ranges; std::vector<size_t> ranges;
for(auto const & x: iterates) for(auto const & x: iterates)
@@ -60,17 +67,16 @@ void loop_nest(std::vector<std::vector<T>> const & iterates, std::function<void(
f(x); f(x);
}; };
//Iterate //Iterate
loop_nest(ranges, proxy); loop_nest(ranges, proxy, nthreads);
} }
std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, passes_wrapper &passes) { std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, passes_wrapper &passes, llvm::LLVMContext& llvm_context, launch_information& info) {
llvm::Module* result = new llvm::Module(module.get_name(), llvm_context_); llvm::Module* result = new llvm::Module(module.get_name(), llvm_context);
passes.selection.run(module, *result); passes.selection.run(module, *result);
// launch information // launch information
launch_information& info = launch_info_map_[result->getName()];
info.global_range_size.clear(); info.global_range_size.clear();
for(unsigned i = 0; i < passes.tune.get_num_global_range(); i++) for(unsigned i = 0; i < passes.tune.get_num_global_range(); i++)
info.global_range_size.push_back(passes.tune.get_global_range_size(i)); info.global_range_size.push_back(passes.tune.get_global_range_size(i));
@@ -78,14 +84,18 @@ std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, passes_w
return std::unique_ptr<llvm::Module>(result); return std::unique_ptr<llvm::Module>(result);
} }
std::unique_ptr<ir::module> jit::make_triton_module(const char *name, const char *src) { triton::lang::translation_unit *jit::parse_program(const char *name, const char *src) {
// create AST from Triton-C source // create AST from Triton-C source
YY_BUFFER_STATE buffer = yy_scan_string(src); YY_BUFFER_STATE buffer = yy_scan_string(src);
yyparse(); yyparse();
yy_delete_buffer(buffer); yy_delete_buffer(buffer);
triton::lang::translation_unit *program = ast_root; triton::lang::translation_unit *program = ast_root;
return program;
}
std::unique_ptr<ir::module> jit::make_triton_module(const char * name, triton::ir::context &context, triton::lang::translation_unit *program) {
// create Triton-IR from AST // create Triton-IR from AST
ir::module* module = new ir::module(name, triton_context_); ir::module* module = new ir::module(name, context);
program->codegen(module); program->codegen(module);
return std::unique_ptr<ir::module>(module); return std::unique_ptr<ir::module>(module);
} }
@@ -98,7 +108,8 @@ jit::~jit(){ }
std::vector<unsigned> jit::get_valid(const char *name, const char *src) { std::vector<unsigned> jit::get_valid(const char *name, const char *src) {
// find metaparameters // find metaparameters
auto ptt_module = make_triton_module(name, src); triton::lang::translation_unit* program = parse_program(name, src);
auto ptt_module = make_triton_module(name, triton_context_, program);
ir::module &tt_module = *ptt_module; ir::module &tt_module = *ptt_module;
// set parameters // set parameters
passes_wrapper passes(target_.get()); passes_wrapper passes(target_.get());
@@ -111,6 +122,7 @@ std::vector<unsigned> jit::get_valid(const char *name, const char *src) {
ranges.push_back(mp->get_space()); ranges.push_back(mp->get_space());
// iterate over parameters // iterate over parameters
std::vector<unsigned> result; std::vector<unsigned> result;
size_t nthreads = 1;
loop_nest<unsigned>(ranges, [&](const std::vector<unsigned> params){ loop_nest<unsigned>(ranges, [&](const std::vector<unsigned> params){
if(!result.empty()) if(!result.empty())
return; return;
@@ -128,7 +140,7 @@ std::vector<unsigned> jit::get_valid(const char *name, const char *src) {
if(!errors.empty()) if(!errors.empty())
return; return;
result = params; result = params;
}); }, nthreads);
if(result.empty()) if(result.empty())
throw std::runtime_error("couldn't find valid parameters"); throw std::runtime_error("couldn't find valid parameters");
return result; return result;
@@ -138,72 +150,77 @@ std::vector<unsigned> jit::get_valid(const char *name, const char *src) {
jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t benchmark) { jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t benchmark) {
// find metaparameters // find metaparameters
auto ptt_module = make_triton_module(name, src); triton::lang::translation_unit* program = parse_program(name, src);
ir::module &tt_module = *ptt_module; auto ptt_module_0 = make_triton_module(name, triton_context_, program);
ir::module &tt_module_0 = *ptt_module_0;
// set parameters // set parameters
passes_wrapper passes(target_.get()); passes_wrapper passes_0(target_.get());
passes.target_independent(tt_module); passes_0.target_independent(tt_module_0);
passes.tune.run(tt_module); passes_0.tune.run(tt_module_0);
auto mps = passes.tune.get_params(tt_module);
// create parameter ranges // create parameter ranges
std::vector<std::vector<unsigned>> ranges; std::vector<std::vector<unsigned>> ranges;
auto mps = passes_0.tune.get_params(tt_module_0);
for(ir::metaparameter *mp: mps) for(ir::metaparameter *mp: mps)
ranges.push_back(mp->get_space()); ranges.push_back(mp->get_space());
// iterate over parameters // iterate over parameters
unsigned i;
tune_res_t best; tune_res_t best;
size_t nthreads = 4;
std::mutex mutex;
loop_nest<unsigned>(ranges, [&](const std::vector<unsigned> params){ loop_nest<unsigned>(ranges, [&](const std::vector<unsigned> params){
std::map<ir::value*, std::vector<std::string>> errors; std::map<ir::value*, std::vector<std::string>> errors;
i = 0; unsigned i = 0;
for(ir::metaparameter *mp: mps) {
mp->set_value(params[i++]); std::lock_guard<std::mutex> lock(mutex);
passes.target_independent(tt_module); for(ir::metaparameter *mp: mps)
passes.tune.init(tt_module); mp->set_value(params[i++]);
passes.tune.check_constraints(errors); passes_0.tune.init(tt_module_0);
// for(auto x: errors) passes_0.tune.check_constraints(errors);
// for(auto err: x.second) }
// std::cout << err << std::endl;
if(!errors.empty()) if(!errors.empty())
return; return;
// Deep copy of the module and tuner // Deep copy of the module and tuner
auto ptt_module = make_triton_module(name, src); triton::ir::context triton_context;
ir::module &tt_module = *ptt_module; auto ptt_module_1 = make_triton_module(name, triton_context, program);
// for(unsigned p: params) ir::module &tt_module_1 = *ptt_module_1;
// std::cout << p << " " << std::flush; // run passes
passes_wrapper passes(target_.get()); passes_wrapper passes_1(target_.get());
passes.target_independent(tt_module); passes_1.target_independent(tt_module_1);
passes.tune.run(tt_module); passes_1.tune.run(tt_module_1);
i = 0; i = 0;
for(ir::metaparameter* mp: passes.tune.get_params(tt_module)){ for(ir::metaparameter* mp: passes_1.tune.get_params(tt_module_1)){
mp->set_value(params[i++]); mp->set_value(params[i++]);
} }
passes.tune.init(tt_module); passes_1.tune.init(tt_module_1);
passes.target_dependent(tt_module); passes_1.target_dependent(tt_module_1);
driver::device* device = driver_context_->device(); driver::device* device = driver_context_->device();
if(passes.shmem_allocation.get_allocated_size() > device->max_shared_memory()) if(passes_1.shmem_allocation.get_allocated_size() > device->max_shared_memory())
return; return;
if(passes.tune.get_num_threads() > device->max_threads_per_block()) if(passes_1.tune.get_num_threads() > device->max_threads_per_block())
return; return;
// Compile // Compile
auto ll_module = make_llvm_module(tt_module, passes); launch_information info;
llvm::LLVMContext llvm_context;
auto ll_module = make_llvm_module(tt_module_1, passes_1, llvm_context, info);
std::unique_ptr<driver::module> module(driver::module::create(driver_context_, &*ll_module)); std::unique_ptr<driver::module> module(driver::module::create(driver_context_, &*ll_module));
std::unique_ptr<driver::kernel> kernel(driver::kernel::create(module.get(), name)); std::unique_ptr<driver::kernel> kernel(driver::kernel::create(module.get(), name));
launch_information info = launch_info_map_.at(name);
for(unsigned p: params)
std::cout << p << " " << std::flush;
// add globals // add globals
for(auto x: tt_module.globals()) for(auto x: tt_module_1.globals())
global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value(); global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value();
modules_.insert({name, module.get()}); modules_.insert({name, module.get()});
double perf; double perf;
perf = benchmark(kernel.get(), info); perf = benchmark(kernel.get(), info);
if(perf > best.perf){ {
best.perf = perf; std::lock_guard<std::mutex> lock(mutex);
best.params = params; if(perf > best.perf){
best.perf = perf;
best.params = params;
}
for(unsigned p: params)
std::cout << p << " " << std::flush;
std::cout << perf << " [ " << best.perf << " ] " << std::endl;
} }
std::cout << perf << " [ " << best.perf << " ] " << std::endl;
modules_.erase(name); modules_.erase(name);
}); }, nthreads);
return best; return best;
} }
@@ -227,9 +244,9 @@ void jit::add_module(ir::module &tt_module, const std::vector<unsigned> &params)
if(errors.size()) if(errors.size())
throw std::runtime_error("invalid parameters"); throw std::runtime_error("invalid parameters");
// triton module -> llvm module // triton module -> llvm module
auto ll_module = make_llvm_module(tt_module, passes);
// llvm module -> machine code
std::string name = tt_module.get_name(); std::string name = tt_module.get_name();
auto ll_module = make_llvm_module(tt_module, passes, llvm_context_, launch_info_map_[name]);
// llvm module -> machine code
modules_.insert({name, driver::module::create(driver_context_, &*ll_module)}); modules_.insert({name, driver::module::create(driver_context_, &*ll_module)});
// add globals // add globals
for(auto x: tt_module.globals()) for(auto x: tt_module.globals())
@@ -237,7 +254,8 @@ void jit::add_module(ir::module &tt_module, const std::vector<unsigned> &params)
} }
void jit::add_module(const char *name, const char *src, const std::vector<unsigned> &params) { void jit::add_module(const char *name, const char *src, const std::vector<unsigned> &params) {
auto ptt_module = make_triton_module(name, src); triton::lang::translation_unit* program = parse_program(name, src);
auto ptt_module = make_triton_module(name, triton_context_, program);
add_module(*ptt_module, params); add_module(*ptt_module, params);
} }