[jit/autotune] added support for multi-threaded auto-tuning
This commit is contained in:
@@ -14,13 +14,13 @@ int main() {
|
||||
|
||||
// initialize default compute device
|
||||
auto context = triton::driver::backend::contexts::get_default();
|
||||
auto op = triton::dnn::shift::BPROP;
|
||||
auto op = triton::dnn::shift::FPROP;
|
||||
|
||||
// initialization
|
||||
int32_t R = 3, S = 3;
|
||||
int32_t B = 16, F = 4096;
|
||||
int32_t B = 16, F = 512;
|
||||
int32_t H = 16, W = 16;
|
||||
int32_t C = 4096;
|
||||
int32_t C = 512;
|
||||
|
||||
// random shifts
|
||||
std::vector<int32_t> shift_h(C);
|
||||
@@ -66,7 +66,7 @@ int main() {
|
||||
stream->write(db, true, 0, hb);
|
||||
stream->write(dc, true, 0, hc);
|
||||
stream->synchronize();
|
||||
shift.enqueue(stream, {da, db, dc});
|
||||
shift.enqueue(stream, {da, db, dc}, true);
|
||||
// stream->read(dc, true, 0, hc);
|
||||
// shift.cpu_ref(rc.data(), ha.data(), hb.data());
|
||||
// for(size_t i = 0; i < hc.size(); i++)
|
||||
|
@@ -28,6 +28,10 @@ namespace llvm {
|
||||
|
||||
namespace triton {
|
||||
|
||||
namespace lang{
|
||||
class translation_unit;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
class tune;
|
||||
}
|
||||
@@ -97,8 +101,9 @@ public:
|
||||
|
||||
private:
|
||||
std::string compute_data_layout(bool is_64bit = true, bool use_short_pointers = true);
|
||||
std::unique_ptr<llvm::Module> make_llvm_module(triton::ir::module &module, passes_wrapper &passes);
|
||||
std::unique_ptr<ir::module> make_triton_module(const char* name, const char* src);
|
||||
std::unique_ptr<llvm::Module> make_llvm_module(triton::ir::module &module, passes_wrapper &passes, llvm::LLVMContext &context, launch_information &info);
|
||||
std::unique_ptr<ir::module> make_triton_module(const char *name, triton::ir::context &context, triton::lang::translation_unit *program);
|
||||
triton::lang::translation_unit *parse_program(const char *name, const char *src);
|
||||
|
||||
public:
|
||||
jit(driver::context* context);
|
||||
|
@@ -19,6 +19,8 @@
|
||||
#include "llvm/IR/LegacyPassManager.h"
|
||||
#include "llvm/Transforms/Scalar/EarlyCSE.h"
|
||||
#include "llvm/Analysis/LoopPass.h"
|
||||
#include "triton/tools/thread_pool.h"
|
||||
#include <mutex>
|
||||
|
||||
typedef struct yy_buffer_state * YY_BUFFER_STATE;
|
||||
extern int yyparse();
|
||||
@@ -28,14 +30,19 @@ extern triton::lang::translation_unit *ast_root;
|
||||
|
||||
namespace triton {
|
||||
|
||||
void loop_nest(std::vector<size_t> const & ranges, std::function<void(std::vector<size_t> const &)> const & f){
|
||||
void loop_nest(std::vector<size_t> const & ranges,
|
||||
std::function<void(std::vector<size_t> const &)> const & f,
|
||||
size_t nthreads){
|
||||
size_t D = ranges.size();
|
||||
std::vector<size_t> values(D, 0);
|
||||
// thread pools
|
||||
nbsdx::concurrent::thread_pool pool(nthreads);
|
||||
// Start with innermost loop
|
||||
size_t i = D - 1;
|
||||
size_t current = 0;
|
||||
while(true){
|
||||
//Execute function
|
||||
f(values);
|
||||
pool.add_job([values, &f](){ f(values); });
|
||||
//Increment counters
|
||||
while(values[i]++ == ranges[i] - 1){
|
||||
if(i == 0)
|
||||
@@ -47,7 +54,7 @@ void loop_nest(std::vector<size_t> const & ranges, std::function<void(std::vecto
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void loop_nest(std::vector<std::vector<T>> const & iterates, std::function<void(std::vector<T>)> const & f){
|
||||
void loop_nest(std::vector<std::vector<T>> const & iterates, std::function<void(std::vector<T>)> const & f, size_t nthreads){
|
||||
//Ranges to iterate over
|
||||
std::vector<size_t> ranges;
|
||||
for(auto const & x: iterates)
|
||||
@@ -60,17 +67,16 @@ void loop_nest(std::vector<std::vector<T>> const & iterates, std::function<void(
|
||||
f(x);
|
||||
};
|
||||
//Iterate
|
||||
loop_nest(ranges, proxy);
|
||||
loop_nest(ranges, proxy, nthreads);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, passes_wrapper &passes) {
|
||||
llvm::Module* result = new llvm::Module(module.get_name(), llvm_context_);
|
||||
std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, passes_wrapper &passes, llvm::LLVMContext& llvm_context, launch_information& info) {
|
||||
llvm::Module* result = new llvm::Module(module.get_name(), llvm_context);
|
||||
passes.selection.run(module, *result);
|
||||
// launch information
|
||||
launch_information& info = launch_info_map_[result->getName()];
|
||||
info.global_range_size.clear();
|
||||
for(unsigned i = 0; i < passes.tune.get_num_global_range(); i++)
|
||||
info.global_range_size.push_back(passes.tune.get_global_range_size(i));
|
||||
@@ -78,14 +84,18 @@ std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, passes_w
|
||||
return std::unique_ptr<llvm::Module>(result);
|
||||
}
|
||||
|
||||
std::unique_ptr<ir::module> jit::make_triton_module(const char *name, const char *src) {
|
||||
triton::lang::translation_unit *jit::parse_program(const char *name, const char *src) {
|
||||
// create AST from Triton-C source
|
||||
YY_BUFFER_STATE buffer = yy_scan_string(src);
|
||||
yyparse();
|
||||
yy_delete_buffer(buffer);
|
||||
triton::lang::translation_unit *program = ast_root;
|
||||
return program;
|
||||
}
|
||||
|
||||
std::unique_ptr<ir::module> jit::make_triton_module(const char * name, triton::ir::context &context, triton::lang::translation_unit *program) {
|
||||
// create Triton-IR from AST
|
||||
ir::module* module = new ir::module(name, triton_context_);
|
||||
ir::module* module = new ir::module(name, context);
|
||||
program->codegen(module);
|
||||
return std::unique_ptr<ir::module>(module);
|
||||
}
|
||||
@@ -98,7 +108,8 @@ jit::~jit(){ }
|
||||
|
||||
std::vector<unsigned> jit::get_valid(const char *name, const char *src) {
|
||||
// find metaparameters
|
||||
auto ptt_module = make_triton_module(name, src);
|
||||
triton::lang::translation_unit* program = parse_program(name, src);
|
||||
auto ptt_module = make_triton_module(name, triton_context_, program);
|
||||
ir::module &tt_module = *ptt_module;
|
||||
// set parameters
|
||||
passes_wrapper passes(target_.get());
|
||||
@@ -111,6 +122,7 @@ std::vector<unsigned> jit::get_valid(const char *name, const char *src) {
|
||||
ranges.push_back(mp->get_space());
|
||||
// iterate over parameters
|
||||
std::vector<unsigned> result;
|
||||
size_t nthreads = 1;
|
||||
loop_nest<unsigned>(ranges, [&](const std::vector<unsigned> params){
|
||||
if(!result.empty())
|
||||
return;
|
||||
@@ -128,7 +140,7 @@ std::vector<unsigned> jit::get_valid(const char *name, const char *src) {
|
||||
if(!errors.empty())
|
||||
return;
|
||||
result = params;
|
||||
});
|
||||
}, nthreads);
|
||||
if(result.empty())
|
||||
throw std::runtime_error("couldn't find valid parameters");
|
||||
return result;
|
||||
@@ -138,72 +150,77 @@ std::vector<unsigned> jit::get_valid(const char *name, const char *src) {
|
||||
|
||||
jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t benchmark) {
|
||||
// find metaparameters
|
||||
auto ptt_module = make_triton_module(name, src);
|
||||
ir::module &tt_module = *ptt_module;
|
||||
triton::lang::translation_unit* program = parse_program(name, src);
|
||||
auto ptt_module_0 = make_triton_module(name, triton_context_, program);
|
||||
ir::module &tt_module_0 = *ptt_module_0;
|
||||
// set parameters
|
||||
passes_wrapper passes(target_.get());
|
||||
passes.target_independent(tt_module);
|
||||
passes.tune.run(tt_module);
|
||||
auto mps = passes.tune.get_params(tt_module);
|
||||
passes_wrapper passes_0(target_.get());
|
||||
passes_0.target_independent(tt_module_0);
|
||||
passes_0.tune.run(tt_module_0);
|
||||
// create parameter ranges
|
||||
std::vector<std::vector<unsigned>> ranges;
|
||||
auto mps = passes_0.tune.get_params(tt_module_0);
|
||||
for(ir::metaparameter *mp: mps)
|
||||
ranges.push_back(mp->get_space());
|
||||
// iterate over parameters
|
||||
unsigned i;
|
||||
tune_res_t best;
|
||||
size_t nthreads = 4;
|
||||
std::mutex mutex;
|
||||
loop_nest<unsigned>(ranges, [&](const std::vector<unsigned> params){
|
||||
std::map<ir::value*, std::vector<std::string>> errors;
|
||||
i = 0;
|
||||
unsigned i = 0;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
for(ir::metaparameter *mp: mps)
|
||||
mp->set_value(params[i++]);
|
||||
passes.target_independent(tt_module);
|
||||
passes.tune.init(tt_module);
|
||||
passes.tune.check_constraints(errors);
|
||||
// for(auto x: errors)
|
||||
// for(auto err: x.second)
|
||||
// std::cout << err << std::endl;
|
||||
passes_0.tune.init(tt_module_0);
|
||||
passes_0.tune.check_constraints(errors);
|
||||
}
|
||||
if(!errors.empty())
|
||||
return;
|
||||
// Deep copy of the module and tuner
|
||||
auto ptt_module = make_triton_module(name, src);
|
||||
ir::module &tt_module = *ptt_module;
|
||||
// for(unsigned p: params)
|
||||
// std::cout << p << " " << std::flush;
|
||||
passes_wrapper passes(target_.get());
|
||||
passes.target_independent(tt_module);
|
||||
passes.tune.run(tt_module);
|
||||
triton::ir::context triton_context;
|
||||
auto ptt_module_1 = make_triton_module(name, triton_context, program);
|
||||
ir::module &tt_module_1 = *ptt_module_1;
|
||||
// run passes
|
||||
passes_wrapper passes_1(target_.get());
|
||||
passes_1.target_independent(tt_module_1);
|
||||
passes_1.tune.run(tt_module_1);
|
||||
i = 0;
|
||||
for(ir::metaparameter* mp: passes.tune.get_params(tt_module)){
|
||||
for(ir::metaparameter* mp: passes_1.tune.get_params(tt_module_1)){
|
||||
mp->set_value(params[i++]);
|
||||
}
|
||||
passes.tune.init(tt_module);
|
||||
passes.target_dependent(tt_module);
|
||||
passes_1.tune.init(tt_module_1);
|
||||
passes_1.target_dependent(tt_module_1);
|
||||
driver::device* device = driver_context_->device();
|
||||
if(passes.shmem_allocation.get_allocated_size() > device->max_shared_memory())
|
||||
if(passes_1.shmem_allocation.get_allocated_size() > device->max_shared_memory())
|
||||
return;
|
||||
if(passes.tune.get_num_threads() > device->max_threads_per_block())
|
||||
if(passes_1.tune.get_num_threads() > device->max_threads_per_block())
|
||||
return;
|
||||
// Compile
|
||||
auto ll_module = make_llvm_module(tt_module, passes);
|
||||
launch_information info;
|
||||
llvm::LLVMContext llvm_context;
|
||||
auto ll_module = make_llvm_module(tt_module_1, passes_1, llvm_context, info);
|
||||
std::unique_ptr<driver::module> module(driver::module::create(driver_context_, &*ll_module));
|
||||
std::unique_ptr<driver::kernel> kernel(driver::kernel::create(module.get(), name));
|
||||
launch_information info = launch_info_map_.at(name);
|
||||
for(unsigned p: params)
|
||||
std::cout << p << " " << std::flush;
|
||||
// add globals
|
||||
for(auto x: tt_module.globals())
|
||||
for(auto x: tt_module_1.globals())
|
||||
global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value();
|
||||
modules_.insert({name, module.get()});
|
||||
double perf;
|
||||
perf = benchmark(kernel.get(), info);
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
if(perf > best.perf){
|
||||
best.perf = perf;
|
||||
best.params = params;
|
||||
}
|
||||
for(unsigned p: params)
|
||||
std::cout << p << " " << std::flush;
|
||||
std::cout << perf << " [ " << best.perf << " ] " << std::endl;
|
||||
}
|
||||
modules_.erase(name);
|
||||
});
|
||||
}, nthreads);
|
||||
return best;
|
||||
}
|
||||
|
||||
@@ -227,9 +244,9 @@ void jit::add_module(ir::module &tt_module, const std::vector<unsigned> ¶ms)
|
||||
if(errors.size())
|
||||
throw std::runtime_error("invalid parameters");
|
||||
// triton module -> llvm module
|
||||
auto ll_module = make_llvm_module(tt_module, passes);
|
||||
// llvm module -> machine code
|
||||
std::string name = tt_module.get_name();
|
||||
auto ll_module = make_llvm_module(tt_module, passes, llvm_context_, launch_info_map_[name]);
|
||||
// llvm module -> machine code
|
||||
modules_.insert({name, driver::module::create(driver_context_, &*ll_module)});
|
||||
// add globals
|
||||
for(auto x: tt_module.globals())
|
||||
@@ -237,7 +254,8 @@ void jit::add_module(ir::module &tt_module, const std::vector<unsigned> ¶ms)
|
||||
}
|
||||
|
||||
void jit::add_module(const char *name, const char *src, const std::vector<unsigned> ¶ms) {
|
||||
auto ptt_module = make_triton_module(name, src);
|
||||
triton::lang::translation_unit* program = parse_program(name, src);
|
||||
auto ptt_module = make_triton_module(name, triton_context_, program);
|
||||
add_module(*ptt_module, params);
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user