[jit/autotune] added support for multi-threaded auto-tuning

This commit is contained in:
Philippe Tillet
2019-07-14 21:54:57 -07:00
parent 3e7a3ed67a
commit 3c128fc2e2
3 changed files with 81 additions and 58 deletions

View File

@@ -14,13 +14,13 @@ int main() {
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
auto op = triton::dnn::shift::BPROP;
auto op = triton::dnn::shift::FPROP;
// initialization
int32_t R = 3, S = 3;
int32_t B = 16, F = 4096;
int32_t B = 16, F = 512;
int32_t H = 16, W = 16;
int32_t C = 4096;
int32_t C = 512;
// random shifts
std::vector<int32_t> shift_h(C);
@@ -66,7 +66,7 @@ int main() {
stream->write(db, true, 0, hb);
stream->write(dc, true, 0, hc);
stream->synchronize();
shift.enqueue(stream, {da, db, dc});
shift.enqueue(stream, {da, db, dc}, true);
// stream->read(dc, true, 0, hc);
// shift.cpu_ref(rc.data(), ha.data(), hb.data());
// for(size_t i = 0; i < hc.size(); i++)

View File

@@ -28,6 +28,10 @@ namespace llvm {
namespace triton {
namespace lang{
class translation_unit;
}
namespace codegen{
class tune;
}
@@ -97,8 +101,9 @@ public:
private:
std::string compute_data_layout(bool is_64bit = true, bool use_short_pointers = true);
std::unique_ptr<llvm::Module> make_llvm_module(triton::ir::module &module, passes_wrapper &passes);
std::unique_ptr<ir::module> make_triton_module(const char* name, const char* src);
std::unique_ptr<llvm::Module> make_llvm_module(triton::ir::module &module, passes_wrapper &passes, llvm::LLVMContext &context, launch_information &info);
std::unique_ptr<ir::module> make_triton_module(const char *name, triton::ir::context &context, triton::lang::translation_unit *program);
triton::lang::translation_unit *parse_program(const char *name, const char *src);
public:
jit(driver::context* context);

View File

@@ -19,6 +19,8 @@
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/Transforms/Scalar/EarlyCSE.h"
#include "llvm/Analysis/LoopPass.h"
#include "triton/tools/thread_pool.h"
#include <mutex>
typedef struct yy_buffer_state * YY_BUFFER_STATE;
extern int yyparse();
@@ -28,14 +30,19 @@ extern triton::lang::translation_unit *ast_root;
namespace triton {
void loop_nest(std::vector<size_t> const & ranges, std::function<void(std::vector<size_t> const &)> const & f){
void loop_nest(std::vector<size_t> const & ranges,
std::function<void(std::vector<size_t> const &)> const & f,
size_t nthreads){
size_t D = ranges.size();
std::vector<size_t> values(D, 0);
// thread pools
nbsdx::concurrent::thread_pool pool(nthreads);
// Start with innermost loop
size_t i = D - 1;
size_t current = 0;
while(true){
//Execute function
f(values);
pool.add_job([values, &f](){ f(values); });
//Increment counters
while(values[i]++ == ranges[i] - 1){
if(i == 0)
@@ -47,7 +54,7 @@ void loop_nest(std::vector<size_t> const & ranges, std::function<void(std::vecto
}
template<class T>
void loop_nest(std::vector<std::vector<T>> const & iterates, std::function<void(std::vector<T>)> const & f){
void loop_nest(std::vector<std::vector<T>> const & iterates, std::function<void(std::vector<T>)> const & f, size_t nthreads){
//Ranges to iterate over
std::vector<size_t> ranges;
for(auto const & x: iterates)
@@ -60,17 +67,16 @@ void loop_nest(std::vector<std::vector<T>> const & iterates, std::function<void(
f(x);
};
//Iterate
loop_nest(ranges, proxy);
loop_nest(ranges, proxy, nthreads);
}
std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, passes_wrapper &passes) {
llvm::Module* result = new llvm::Module(module.get_name(), llvm_context_);
std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, passes_wrapper &passes, llvm::LLVMContext& llvm_context, launch_information& info) {
llvm::Module* result = new llvm::Module(module.get_name(), llvm_context);
passes.selection.run(module, *result);
// launch information
launch_information& info = launch_info_map_[result->getName()];
info.global_range_size.clear();
for(unsigned i = 0; i < passes.tune.get_num_global_range(); i++)
info.global_range_size.push_back(passes.tune.get_global_range_size(i));
@@ -78,14 +84,18 @@ std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, passes_w
return std::unique_ptr<llvm::Module>(result);
}
std::unique_ptr<ir::module> jit::make_triton_module(const char *name, const char *src) {
triton::lang::translation_unit *jit::parse_program(const char *name, const char *src) {
// create AST from Triton-C source
YY_BUFFER_STATE buffer = yy_scan_string(src);
yyparse();
yy_delete_buffer(buffer);
triton::lang::translation_unit *program = ast_root;
return program;
}
std::unique_ptr<ir::module> jit::make_triton_module(const char * name, triton::ir::context &context, triton::lang::translation_unit *program) {
// create Triton-IR from AST
ir::module* module = new ir::module(name, triton_context_);
ir::module* module = new ir::module(name, context);
program->codegen(module);
return std::unique_ptr<ir::module>(module);
}
@@ -98,7 +108,8 @@ jit::~jit(){ }
std::vector<unsigned> jit::get_valid(const char *name, const char *src) {
// find metaparameters
auto ptt_module = make_triton_module(name, src);
triton::lang::translation_unit* program = parse_program(name, src);
auto ptt_module = make_triton_module(name, triton_context_, program);
ir::module &tt_module = *ptt_module;
// set parameters
passes_wrapper passes(target_.get());
@@ -111,6 +122,7 @@ std::vector<unsigned> jit::get_valid(const char *name, const char *src) {
ranges.push_back(mp->get_space());
// iterate over parameters
std::vector<unsigned> result;
size_t nthreads = 1;
loop_nest<unsigned>(ranges, [&](const std::vector<unsigned> params){
if(!result.empty())
return;
@@ -128,7 +140,7 @@ std::vector<unsigned> jit::get_valid(const char *name, const char *src) {
if(!errors.empty())
return;
result = params;
});
}, nthreads);
if(result.empty())
throw std::runtime_error("couldn't find valid parameters");
return result;
@@ -138,72 +150,77 @@ std::vector<unsigned> jit::get_valid(const char *name, const char *src) {
jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t benchmark) {
// find metaparameters
auto ptt_module = make_triton_module(name, src);
ir::module &tt_module = *ptt_module;
triton::lang::translation_unit* program = parse_program(name, src);
auto ptt_module_0 = make_triton_module(name, triton_context_, program);
ir::module &tt_module_0 = *ptt_module_0;
// set parameters
passes_wrapper passes(target_.get());
passes.target_independent(tt_module);
passes.tune.run(tt_module);
auto mps = passes.tune.get_params(tt_module);
passes_wrapper passes_0(target_.get());
passes_0.target_independent(tt_module_0);
passes_0.tune.run(tt_module_0);
// create parameter ranges
std::vector<std::vector<unsigned>> ranges;
auto mps = passes_0.tune.get_params(tt_module_0);
for(ir::metaparameter *mp: mps)
ranges.push_back(mp->get_space());
// iterate over parameters
unsigned i;
tune_res_t best;
size_t nthreads = 4;
std::mutex mutex;
loop_nest<unsigned>(ranges, [&](const std::vector<unsigned> params){
std::map<ir::value*, std::vector<std::string>> errors;
i = 0;
for(ir::metaparameter *mp: mps)
mp->set_value(params[i++]);
passes.target_independent(tt_module);
passes.tune.init(tt_module);
passes.tune.check_constraints(errors);
// for(auto x: errors)
// for(auto err: x.second)
// std::cout << err << std::endl;
unsigned i = 0;
{
std::lock_guard<std::mutex> lock(mutex);
for(ir::metaparameter *mp: mps)
mp->set_value(params[i++]);
passes_0.tune.init(tt_module_0);
passes_0.tune.check_constraints(errors);
}
if(!errors.empty())
return;
// Deep copy of the module and tuner
auto ptt_module = make_triton_module(name, src);
ir::module &tt_module = *ptt_module;
// for(unsigned p: params)
// std::cout << p << " " << std::flush;
passes_wrapper passes(target_.get());
passes.target_independent(tt_module);
passes.tune.run(tt_module);
triton::ir::context triton_context;
auto ptt_module_1 = make_triton_module(name, triton_context, program);
ir::module &tt_module_1 = *ptt_module_1;
// run passes
passes_wrapper passes_1(target_.get());
passes_1.target_independent(tt_module_1);
passes_1.tune.run(tt_module_1);
i = 0;
for(ir::metaparameter* mp: passes.tune.get_params(tt_module)){
for(ir::metaparameter* mp: passes_1.tune.get_params(tt_module_1)){
mp->set_value(params[i++]);
}
passes.tune.init(tt_module);
passes.target_dependent(tt_module);
passes_1.tune.init(tt_module_1);
passes_1.target_dependent(tt_module_1);
driver::device* device = driver_context_->device();
if(passes.shmem_allocation.get_allocated_size() > device->max_shared_memory())
if(passes_1.shmem_allocation.get_allocated_size() > device->max_shared_memory())
return;
if(passes.tune.get_num_threads() > device->max_threads_per_block())
if(passes_1.tune.get_num_threads() > device->max_threads_per_block())
return;
// Compile
auto ll_module = make_llvm_module(tt_module, passes);
launch_information info;
llvm::LLVMContext llvm_context;
auto ll_module = make_llvm_module(tt_module_1, passes_1, llvm_context, info);
std::unique_ptr<driver::module> module(driver::module::create(driver_context_, &*ll_module));
std::unique_ptr<driver::kernel> kernel(driver::kernel::create(module.get(), name));
launch_information info = launch_info_map_.at(name);
for(unsigned p: params)
std::cout << p << " " << std::flush;
// add globals
for(auto x: tt_module.globals())
for(auto x: tt_module_1.globals())
global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value();
modules_.insert({name, module.get()});
double perf;
perf = benchmark(kernel.get(), info);
if(perf > best.perf){
best.perf = perf;
best.params = params;
{
std::lock_guard<std::mutex> lock(mutex);
if(perf > best.perf){
best.perf = perf;
best.params = params;
}
for(unsigned p: params)
std::cout << p << " " << std::flush;
std::cout << perf << " [ " << best.perf << " ] " << std::endl;
}
std::cout << perf << " [ " << best.perf << " ] " << std::endl;
modules_.erase(name);
});
}, nthreads);
return best;
}
@@ -227,9 +244,9 @@ void jit::add_module(ir::module &tt_module, const std::vector<unsigned> &params)
if(errors.size())
throw std::runtime_error("invalid parameters");
// triton module -> llvm module
auto ll_module = make_llvm_module(tt_module, passes);
// llvm module -> machine code
std::string name = tt_module.get_name();
auto ll_module = make_llvm_module(tt_module, passes, llvm_context_, launch_info_map_[name]);
// llvm module -> machine code
modules_.insert({name, driver::module::create(driver_context_, &*ll_module)});
// add globals
for(auto x: tt_module.globals())
@@ -237,7 +254,8 @@ void jit::add_module(ir::module &tt_module, const std::vector<unsigned> &params)
}
void jit::add_module(const char *name, const char *src, const std::vector<unsigned> &params) {
auto ptt_module = make_triton_module(name, src);
triton::lang::translation_unit* program = parse_program(name, src);
auto ptt_module = make_triton_module(name, triton_context_, program);
add_module(*ptt_module, params);
}