[jit/autotune] added support for multi-threaded auto-tuning
This commit is contained in:
@@ -14,13 +14,13 @@ int main() {
|
|||||||
|
|
||||||
// initialize default compute device
|
// initialize default compute device
|
||||||
auto context = triton::driver::backend::contexts::get_default();
|
auto context = triton::driver::backend::contexts::get_default();
|
||||||
auto op = triton::dnn::shift::BPROP;
|
auto op = triton::dnn::shift::FPROP;
|
||||||
|
|
||||||
// initialization
|
// initialization
|
||||||
int32_t R = 3, S = 3;
|
int32_t R = 3, S = 3;
|
||||||
int32_t B = 16, F = 4096;
|
int32_t B = 16, F = 512;
|
||||||
int32_t H = 16, W = 16;
|
int32_t H = 16, W = 16;
|
||||||
int32_t C = 4096;
|
int32_t C = 512;
|
||||||
|
|
||||||
// random shifts
|
// random shifts
|
||||||
std::vector<int32_t> shift_h(C);
|
std::vector<int32_t> shift_h(C);
|
||||||
@@ -66,7 +66,7 @@ int main() {
|
|||||||
stream->write(db, true, 0, hb);
|
stream->write(db, true, 0, hb);
|
||||||
stream->write(dc, true, 0, hc);
|
stream->write(dc, true, 0, hc);
|
||||||
stream->synchronize();
|
stream->synchronize();
|
||||||
shift.enqueue(stream, {da, db, dc});
|
shift.enqueue(stream, {da, db, dc}, true);
|
||||||
// stream->read(dc, true, 0, hc);
|
// stream->read(dc, true, 0, hc);
|
||||||
// shift.cpu_ref(rc.data(), ha.data(), hb.data());
|
// shift.cpu_ref(rc.data(), ha.data(), hb.data());
|
||||||
// for(size_t i = 0; i < hc.size(); i++)
|
// for(size_t i = 0; i < hc.size(); i++)
|
||||||
|
@@ -28,6 +28,10 @@ namespace llvm {
|
|||||||
|
|
||||||
namespace triton {
|
namespace triton {
|
||||||
|
|
||||||
|
namespace lang{
|
||||||
|
class translation_unit;
|
||||||
|
}
|
||||||
|
|
||||||
namespace codegen{
|
namespace codegen{
|
||||||
class tune;
|
class tune;
|
||||||
}
|
}
|
||||||
@@ -97,8 +101,9 @@ public:
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
std::string compute_data_layout(bool is_64bit = true, bool use_short_pointers = true);
|
std::string compute_data_layout(bool is_64bit = true, bool use_short_pointers = true);
|
||||||
std::unique_ptr<llvm::Module> make_llvm_module(triton::ir::module &module, passes_wrapper &passes);
|
std::unique_ptr<llvm::Module> make_llvm_module(triton::ir::module &module, passes_wrapper &passes, llvm::LLVMContext &context, launch_information &info);
|
||||||
std::unique_ptr<ir::module> make_triton_module(const char* name, const char* src);
|
std::unique_ptr<ir::module> make_triton_module(const char *name, triton::ir::context &context, triton::lang::translation_unit *program);
|
||||||
|
triton::lang::translation_unit *parse_program(const char *name, const char *src);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
jit(driver::context* context);
|
jit(driver::context* context);
|
||||||
|
@@ -19,6 +19,8 @@
|
|||||||
#include "llvm/IR/LegacyPassManager.h"
|
#include "llvm/IR/LegacyPassManager.h"
|
||||||
#include "llvm/Transforms/Scalar/EarlyCSE.h"
|
#include "llvm/Transforms/Scalar/EarlyCSE.h"
|
||||||
#include "llvm/Analysis/LoopPass.h"
|
#include "llvm/Analysis/LoopPass.h"
|
||||||
|
#include "triton/tools/thread_pool.h"
|
||||||
|
#include <mutex>
|
||||||
|
|
||||||
typedef struct yy_buffer_state * YY_BUFFER_STATE;
|
typedef struct yy_buffer_state * YY_BUFFER_STATE;
|
||||||
extern int yyparse();
|
extern int yyparse();
|
||||||
@@ -28,14 +30,19 @@ extern triton::lang::translation_unit *ast_root;
|
|||||||
|
|
||||||
namespace triton {
|
namespace triton {
|
||||||
|
|
||||||
void loop_nest(std::vector<size_t> const & ranges, std::function<void(std::vector<size_t> const &)> const & f){
|
void loop_nest(std::vector<size_t> const & ranges,
|
||||||
|
std::function<void(std::vector<size_t> const &)> const & f,
|
||||||
|
size_t nthreads){
|
||||||
size_t D = ranges.size();
|
size_t D = ranges.size();
|
||||||
std::vector<size_t> values(D, 0);
|
std::vector<size_t> values(D, 0);
|
||||||
|
// thread pools
|
||||||
|
nbsdx::concurrent::thread_pool pool(nthreads);
|
||||||
// Start with innermost loop
|
// Start with innermost loop
|
||||||
size_t i = D - 1;
|
size_t i = D - 1;
|
||||||
|
size_t current = 0;
|
||||||
while(true){
|
while(true){
|
||||||
//Execute function
|
//Execute function
|
||||||
f(values);
|
pool.add_job([values, &f](){ f(values); });
|
||||||
//Increment counters
|
//Increment counters
|
||||||
while(values[i]++ == ranges[i] - 1){
|
while(values[i]++ == ranges[i] - 1){
|
||||||
if(i == 0)
|
if(i == 0)
|
||||||
@@ -47,7 +54,7 @@ void loop_nest(std::vector<size_t> const & ranges, std::function<void(std::vecto
|
|||||||
}
|
}
|
||||||
|
|
||||||
template<class T>
|
template<class T>
|
||||||
void loop_nest(std::vector<std::vector<T>> const & iterates, std::function<void(std::vector<T>)> const & f){
|
void loop_nest(std::vector<std::vector<T>> const & iterates, std::function<void(std::vector<T>)> const & f, size_t nthreads){
|
||||||
//Ranges to iterate over
|
//Ranges to iterate over
|
||||||
std::vector<size_t> ranges;
|
std::vector<size_t> ranges;
|
||||||
for(auto const & x: iterates)
|
for(auto const & x: iterates)
|
||||||
@@ -60,17 +67,16 @@ void loop_nest(std::vector<std::vector<T>> const & iterates, std::function<void(
|
|||||||
f(x);
|
f(x);
|
||||||
};
|
};
|
||||||
//Iterate
|
//Iterate
|
||||||
loop_nest(ranges, proxy);
|
loop_nest(ranges, proxy, nthreads);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, passes_wrapper &passes) {
|
std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, passes_wrapper &passes, llvm::LLVMContext& llvm_context, launch_information& info) {
|
||||||
llvm::Module* result = new llvm::Module(module.get_name(), llvm_context_);
|
llvm::Module* result = new llvm::Module(module.get_name(), llvm_context);
|
||||||
passes.selection.run(module, *result);
|
passes.selection.run(module, *result);
|
||||||
// launch information
|
// launch information
|
||||||
launch_information& info = launch_info_map_[result->getName()];
|
|
||||||
info.global_range_size.clear();
|
info.global_range_size.clear();
|
||||||
for(unsigned i = 0; i < passes.tune.get_num_global_range(); i++)
|
for(unsigned i = 0; i < passes.tune.get_num_global_range(); i++)
|
||||||
info.global_range_size.push_back(passes.tune.get_global_range_size(i));
|
info.global_range_size.push_back(passes.tune.get_global_range_size(i));
|
||||||
@@ -78,14 +84,18 @@ std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, passes_w
|
|||||||
return std::unique_ptr<llvm::Module>(result);
|
return std::unique_ptr<llvm::Module>(result);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<ir::module> jit::make_triton_module(const char *name, const char *src) {
|
triton::lang::translation_unit *jit::parse_program(const char *name, const char *src) {
|
||||||
// create AST from Triton-C source
|
// create AST from Triton-C source
|
||||||
YY_BUFFER_STATE buffer = yy_scan_string(src);
|
YY_BUFFER_STATE buffer = yy_scan_string(src);
|
||||||
yyparse();
|
yyparse();
|
||||||
yy_delete_buffer(buffer);
|
yy_delete_buffer(buffer);
|
||||||
triton::lang::translation_unit *program = ast_root;
|
triton::lang::translation_unit *program = ast_root;
|
||||||
|
return program;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<ir::module> jit::make_triton_module(const char * name, triton::ir::context &context, triton::lang::translation_unit *program) {
|
||||||
// create Triton-IR from AST
|
// create Triton-IR from AST
|
||||||
ir::module* module = new ir::module(name, triton_context_);
|
ir::module* module = new ir::module(name, context);
|
||||||
program->codegen(module);
|
program->codegen(module);
|
||||||
return std::unique_ptr<ir::module>(module);
|
return std::unique_ptr<ir::module>(module);
|
||||||
}
|
}
|
||||||
@@ -98,7 +108,8 @@ jit::~jit(){ }
|
|||||||
|
|
||||||
std::vector<unsigned> jit::get_valid(const char *name, const char *src) {
|
std::vector<unsigned> jit::get_valid(const char *name, const char *src) {
|
||||||
// find metaparameters
|
// find metaparameters
|
||||||
auto ptt_module = make_triton_module(name, src);
|
triton::lang::translation_unit* program = parse_program(name, src);
|
||||||
|
auto ptt_module = make_triton_module(name, triton_context_, program);
|
||||||
ir::module &tt_module = *ptt_module;
|
ir::module &tt_module = *ptt_module;
|
||||||
// set parameters
|
// set parameters
|
||||||
passes_wrapper passes(target_.get());
|
passes_wrapper passes(target_.get());
|
||||||
@@ -111,6 +122,7 @@ std::vector<unsigned> jit::get_valid(const char *name, const char *src) {
|
|||||||
ranges.push_back(mp->get_space());
|
ranges.push_back(mp->get_space());
|
||||||
// iterate over parameters
|
// iterate over parameters
|
||||||
std::vector<unsigned> result;
|
std::vector<unsigned> result;
|
||||||
|
size_t nthreads = 1;
|
||||||
loop_nest<unsigned>(ranges, [&](const std::vector<unsigned> params){
|
loop_nest<unsigned>(ranges, [&](const std::vector<unsigned> params){
|
||||||
if(!result.empty())
|
if(!result.empty())
|
||||||
return;
|
return;
|
||||||
@@ -128,7 +140,7 @@ std::vector<unsigned> jit::get_valid(const char *name, const char *src) {
|
|||||||
if(!errors.empty())
|
if(!errors.empty())
|
||||||
return;
|
return;
|
||||||
result = params;
|
result = params;
|
||||||
});
|
}, nthreads);
|
||||||
if(result.empty())
|
if(result.empty())
|
||||||
throw std::runtime_error("couldn't find valid parameters");
|
throw std::runtime_error("couldn't find valid parameters");
|
||||||
return result;
|
return result;
|
||||||
@@ -138,72 +150,77 @@ std::vector<unsigned> jit::get_valid(const char *name, const char *src) {
|
|||||||
|
|
||||||
jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t benchmark) {
|
jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t benchmark) {
|
||||||
// find metaparameters
|
// find metaparameters
|
||||||
auto ptt_module = make_triton_module(name, src);
|
triton::lang::translation_unit* program = parse_program(name, src);
|
||||||
ir::module &tt_module = *ptt_module;
|
auto ptt_module_0 = make_triton_module(name, triton_context_, program);
|
||||||
|
ir::module &tt_module_0 = *ptt_module_0;
|
||||||
// set parameters
|
// set parameters
|
||||||
passes_wrapper passes(target_.get());
|
passes_wrapper passes_0(target_.get());
|
||||||
passes.target_independent(tt_module);
|
passes_0.target_independent(tt_module_0);
|
||||||
passes.tune.run(tt_module);
|
passes_0.tune.run(tt_module_0);
|
||||||
auto mps = passes.tune.get_params(tt_module);
|
|
||||||
// create parameter ranges
|
// create parameter ranges
|
||||||
std::vector<std::vector<unsigned>> ranges;
|
std::vector<std::vector<unsigned>> ranges;
|
||||||
|
auto mps = passes_0.tune.get_params(tt_module_0);
|
||||||
for(ir::metaparameter *mp: mps)
|
for(ir::metaparameter *mp: mps)
|
||||||
ranges.push_back(mp->get_space());
|
ranges.push_back(mp->get_space());
|
||||||
// iterate over parameters
|
// iterate over parameters
|
||||||
unsigned i;
|
|
||||||
tune_res_t best;
|
tune_res_t best;
|
||||||
|
size_t nthreads = 4;
|
||||||
|
std::mutex mutex;
|
||||||
loop_nest<unsigned>(ranges, [&](const std::vector<unsigned> params){
|
loop_nest<unsigned>(ranges, [&](const std::vector<unsigned> params){
|
||||||
std::map<ir::value*, std::vector<std::string>> errors;
|
std::map<ir::value*, std::vector<std::string>> errors;
|
||||||
i = 0;
|
unsigned i = 0;
|
||||||
|
{
|
||||||
|
std::lock_guard<std::mutex> lock(mutex);
|
||||||
for(ir::metaparameter *mp: mps)
|
for(ir::metaparameter *mp: mps)
|
||||||
mp->set_value(params[i++]);
|
mp->set_value(params[i++]);
|
||||||
passes.target_independent(tt_module);
|
passes_0.tune.init(tt_module_0);
|
||||||
passes.tune.init(tt_module);
|
passes_0.tune.check_constraints(errors);
|
||||||
passes.tune.check_constraints(errors);
|
}
|
||||||
// for(auto x: errors)
|
|
||||||
// for(auto err: x.second)
|
|
||||||
// std::cout << err << std::endl;
|
|
||||||
if(!errors.empty())
|
if(!errors.empty())
|
||||||
return;
|
return;
|
||||||
// Deep copy of the module and tuner
|
// Deep copy of the module and tuner
|
||||||
auto ptt_module = make_triton_module(name, src);
|
triton::ir::context triton_context;
|
||||||
ir::module &tt_module = *ptt_module;
|
auto ptt_module_1 = make_triton_module(name, triton_context, program);
|
||||||
// for(unsigned p: params)
|
ir::module &tt_module_1 = *ptt_module_1;
|
||||||
// std::cout << p << " " << std::flush;
|
// run passes
|
||||||
passes_wrapper passes(target_.get());
|
passes_wrapper passes_1(target_.get());
|
||||||
passes.target_independent(tt_module);
|
passes_1.target_independent(tt_module_1);
|
||||||
passes.tune.run(tt_module);
|
passes_1.tune.run(tt_module_1);
|
||||||
i = 0;
|
i = 0;
|
||||||
for(ir::metaparameter* mp: passes.tune.get_params(tt_module)){
|
for(ir::metaparameter* mp: passes_1.tune.get_params(tt_module_1)){
|
||||||
mp->set_value(params[i++]);
|
mp->set_value(params[i++]);
|
||||||
}
|
}
|
||||||
passes.tune.init(tt_module);
|
passes_1.tune.init(tt_module_1);
|
||||||
passes.target_dependent(tt_module);
|
passes_1.target_dependent(tt_module_1);
|
||||||
driver::device* device = driver_context_->device();
|
driver::device* device = driver_context_->device();
|
||||||
if(passes.shmem_allocation.get_allocated_size() > device->max_shared_memory())
|
if(passes_1.shmem_allocation.get_allocated_size() > device->max_shared_memory())
|
||||||
return;
|
return;
|
||||||
if(passes.tune.get_num_threads() > device->max_threads_per_block())
|
if(passes_1.tune.get_num_threads() > device->max_threads_per_block())
|
||||||
return;
|
return;
|
||||||
// Compile
|
// Compile
|
||||||
auto ll_module = make_llvm_module(tt_module, passes);
|
launch_information info;
|
||||||
|
llvm::LLVMContext llvm_context;
|
||||||
|
auto ll_module = make_llvm_module(tt_module_1, passes_1, llvm_context, info);
|
||||||
std::unique_ptr<driver::module> module(driver::module::create(driver_context_, &*ll_module));
|
std::unique_ptr<driver::module> module(driver::module::create(driver_context_, &*ll_module));
|
||||||
std::unique_ptr<driver::kernel> kernel(driver::kernel::create(module.get(), name));
|
std::unique_ptr<driver::kernel> kernel(driver::kernel::create(module.get(), name));
|
||||||
launch_information info = launch_info_map_.at(name);
|
|
||||||
for(unsigned p: params)
|
|
||||||
std::cout << p << " " << std::flush;
|
|
||||||
// add globals
|
// add globals
|
||||||
for(auto x: tt_module.globals())
|
for(auto x: tt_module_1.globals())
|
||||||
global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value();
|
global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value();
|
||||||
modules_.insert({name, module.get()});
|
modules_.insert({name, module.get()});
|
||||||
double perf;
|
double perf;
|
||||||
perf = benchmark(kernel.get(), info);
|
perf = benchmark(kernel.get(), info);
|
||||||
|
{
|
||||||
|
std::lock_guard<std::mutex> lock(mutex);
|
||||||
if(perf > best.perf){
|
if(perf > best.perf){
|
||||||
best.perf = perf;
|
best.perf = perf;
|
||||||
best.params = params;
|
best.params = params;
|
||||||
}
|
}
|
||||||
|
for(unsigned p: params)
|
||||||
|
std::cout << p << " " << std::flush;
|
||||||
std::cout << perf << " [ " << best.perf << " ] " << std::endl;
|
std::cout << perf << " [ " << best.perf << " ] " << std::endl;
|
||||||
|
}
|
||||||
modules_.erase(name);
|
modules_.erase(name);
|
||||||
});
|
}, nthreads);
|
||||||
return best;
|
return best;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -227,9 +244,9 @@ void jit::add_module(ir::module &tt_module, const std::vector<unsigned> ¶ms)
|
|||||||
if(errors.size())
|
if(errors.size())
|
||||||
throw std::runtime_error("invalid parameters");
|
throw std::runtime_error("invalid parameters");
|
||||||
// triton module -> llvm module
|
// triton module -> llvm module
|
||||||
auto ll_module = make_llvm_module(tt_module, passes);
|
|
||||||
// llvm module -> machine code
|
|
||||||
std::string name = tt_module.get_name();
|
std::string name = tt_module.get_name();
|
||||||
|
auto ll_module = make_llvm_module(tt_module, passes, llvm_context_, launch_info_map_[name]);
|
||||||
|
// llvm module -> machine code
|
||||||
modules_.insert({name, driver::module::create(driver_context_, &*ll_module)});
|
modules_.insert({name, driver::module::create(driver_context_, &*ll_module)});
|
||||||
// add globals
|
// add globals
|
||||||
for(auto x: tt_module.globals())
|
for(auto x: tt_module.globals())
|
||||||
@@ -237,7 +254,8 @@ void jit::add_module(ir::module &tt_module, const std::vector<unsigned> ¶ms)
|
|||||||
}
|
}
|
||||||
|
|
||||||
void jit::add_module(const char *name, const char *src, const std::vector<unsigned> ¶ms) {
|
void jit::add_module(const char *name, const char *src, const std::vector<unsigned> ¶ms) {
|
||||||
auto ptt_module = make_triton_module(name, src);
|
triton::lang::translation_unit* program = parse_program(name, src);
|
||||||
|
auto ptt_module = make_triton_module(name, triton_context_, program);
|
||||||
add_module(*ptt_module, params);
|
add_module(*ptt_module, params);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user