diff --git a/examples/cpp/shift.cpp b/examples/cpp/shift.cpp index 41c123fef..6941dfa0d 100644 --- a/examples/cpp/shift.cpp +++ b/examples/cpp/shift.cpp @@ -14,13 +14,13 @@ int main() { // initialize default compute device auto context = triton::driver::backend::contexts::get_default(); - auto op = triton::dnn::shift::BPROP; + auto op = triton::dnn::shift::FPROP; // initialization int32_t R = 3, S = 3; - int32_t B = 16, F = 4096; + int32_t B = 16, F = 512; int32_t H = 16, W = 16; - int32_t C = 4096; + int32_t C = 512; // random shifts std::vector shift_h(C); @@ -66,7 +66,7 @@ int main() { stream->write(db, true, 0, hb); stream->write(dc, true, 0, hc); stream->synchronize(); - shift.enqueue(stream, {da, db, dc}); + shift.enqueue(stream, {da, db, dc}, true); // stream->read(dc, true, 0, hc); // shift.cpu_ref(rc.data(), ha.data(), hb.data()); // for(size_t i = 0; i < hc.size(); i++) diff --git a/include/triton/runtime/jit.h b/include/triton/runtime/jit.h index ca5395893..9b0f75f96 100644 --- a/include/triton/runtime/jit.h +++ b/include/triton/runtime/jit.h @@ -28,6 +28,10 @@ namespace llvm { namespace triton { +namespace lang{ +class translation_unit; +} + namespace codegen{ class tune; } @@ -97,8 +101,9 @@ public: private: std::string compute_data_layout(bool is_64bit = true, bool use_short_pointers = true); - std::unique_ptr make_llvm_module(triton::ir::module &module, passes_wrapper &passes); - std::unique_ptr make_triton_module(const char* name, const char* src); + std::unique_ptr make_llvm_module(triton::ir::module &module, passes_wrapper &passes, llvm::LLVMContext &context, launch_information &info); + std::unique_ptr make_triton_module(const char *name, triton::ir::context &context, triton::lang::translation_unit *program); + triton::lang::translation_unit *parse_program(const char *name, const char *src); public: jit(driver::context* context); diff --git a/lib/runtime/jit.cpp b/lib/runtime/jit.cpp index 536ad44b0..51f3ed916 100644 --- a/lib/runtime/jit.cpp +++ b/lib/runtime/jit.cpp @@ -19,6 +19,8 @@ #include "llvm/IR/LegacyPassManager.h" #include "llvm/Transforms/Scalar/EarlyCSE.h" #include "llvm/Analysis/LoopPass.h" +#include "triton/tools/thread_pool.h" +#include typedef struct yy_buffer_state * YY_BUFFER_STATE; extern int yyparse(); @@ -28,14 +30,19 @@ extern triton::lang::translation_unit *ast_root; namespace triton { -void loop_nest(std::vector const & ranges, std::function const &)> const & f){ +void loop_nest(std::vector const & ranges, + std::function const &)> const & f, + size_t nthreads){ size_t D = ranges.size(); std::vector values(D, 0); + // thread pools + nbsdx::concurrent::thread_pool pool(nthreads); // Start with innermost loop size_t i = D - 1; + size_t current = 0; while(true){ //Execute function - f(values); + pool.add_job([values, &f](){ f(values); }); //Increment counters while(values[i]++ == ranges[i] - 1){ if(i == 0) @@ -47,7 +54,7 @@ void loop_nest(std::vector const & ranges, std::function -void loop_nest(std::vector> const & iterates, std::function)> const & f){ +void loop_nest(std::vector> const & iterates, std::function)> const & f, size_t nthreads){ //Ranges to iterate over std::vector ranges; for(auto const & x: iterates) @@ -60,17 +67,16 @@ void loop_nest(std::vector> const & iterates, std::function jit::make_llvm_module(ir::module &module, passes_wrapper &passes) { - llvm::Module* result = new llvm::Module(module.get_name(), llvm_context_); +std::unique_ptr jit::make_llvm_module(ir::module &module, passes_wrapper &passes, llvm::LLVMContext& llvm_context, launch_information& info) { + llvm::Module* result = new llvm::Module(module.get_name(), llvm_context); passes.selection.run(module, *result); // launch information - launch_information& info = launch_info_map_[result->getName()]; info.global_range_size.clear(); for(unsigned i = 0; i < passes.tune.get_num_global_range(); i++) info.global_range_size.push_back(passes.tune.get_global_range_size(i)); @@ -78,14 +84,18 @@ std::unique_ptr jit::make_llvm_module(ir::module &module, passes_w return std::unique_ptr(result); } -std::unique_ptr jit::make_triton_module(const char *name, const char *src) { +triton::lang::translation_unit *jit::parse_program(const char *name, const char *src) { // create AST from Triton-C source YY_BUFFER_STATE buffer = yy_scan_string(src); yyparse(); yy_delete_buffer(buffer); triton::lang::translation_unit *program = ast_root; + return program; +} + +std::unique_ptr jit::make_triton_module(const char * name, triton::ir::context &context, triton::lang::translation_unit *program) { // create Triton-IR from AST - ir::module* module = new ir::module(name, triton_context_); + ir::module* module = new ir::module(name, context); program->codegen(module); return std::unique_ptr(module); } @@ -98,7 +108,8 @@ jit::~jit(){ } std::vector jit::get_valid(const char *name, const char *src) { // find metaparameters - auto ptt_module = make_triton_module(name, src); + triton::lang::translation_unit* program = parse_program(name, src); + auto ptt_module = make_triton_module(name, triton_context_, program); ir::module &tt_module = *ptt_module; // set parameters passes_wrapper passes(target_.get()); @@ -111,6 +122,7 @@ std::vector jit::get_valid(const char *name, const char *src) { ranges.push_back(mp->get_space()); // iterate over parameters std::vector result; + size_t nthreads = 1; loop_nest(ranges, [&](const std::vector params){ if(!result.empty()) return; @@ -128,7 +140,7 @@ std::vector jit::get_valid(const char *name, const char *src) { if(!errors.empty()) return; result = params; - }); + }, nthreads); if(result.empty()) throw std::runtime_error("couldn't find valid parameters"); return result; @@ -138,72 +150,77 @@ std::vector jit::get_valid(const char *name, const char *src) { jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t benchmark) { // find metaparameters - auto ptt_module = make_triton_module(name, src); - ir::module &tt_module = *ptt_module; + triton::lang::translation_unit* program = parse_program(name, src); + auto ptt_module_0 = make_triton_module(name, triton_context_, program); + ir::module &tt_module_0 = *ptt_module_0; // set parameters - passes_wrapper passes(target_.get()); - passes.target_independent(tt_module); - passes.tune.run(tt_module); - auto mps = passes.tune.get_params(tt_module); + passes_wrapper passes_0(target_.get()); + passes_0.target_independent(tt_module_0); + passes_0.tune.run(tt_module_0); // create parameter ranges std::vector> ranges; + auto mps = passes_0.tune.get_params(tt_module_0); for(ir::metaparameter *mp: mps) ranges.push_back(mp->get_space()); // iterate over parameters - unsigned i; tune_res_t best; + size_t nthreads = 4; + std::mutex mutex; loop_nest(ranges, [&](const std::vector params){ std::map> errors; - i = 0; - for(ir::metaparameter *mp: mps) - mp->set_value(params[i++]); - passes.target_independent(tt_module); - passes.tune.init(tt_module); - passes.tune.check_constraints(errors); -// for(auto x: errors) -// for(auto err: x.second) -// std::cout << err << std::endl; + unsigned i = 0; + { + std::lock_guard lock(mutex); + for(ir::metaparameter *mp: mps) + mp->set_value(params[i++]); + passes_0.tune.init(tt_module_0); + passes_0.tune.check_constraints(errors); + } if(!errors.empty()) return; // Deep copy of the module and tuner - auto ptt_module = make_triton_module(name, src); - ir::module &tt_module = *ptt_module; -// for(unsigned p: params) -// std::cout << p << " " << std::flush; - passes_wrapper passes(target_.get()); - passes.target_independent(tt_module); - passes.tune.run(tt_module); + triton::ir::context triton_context; + auto ptt_module_1 = make_triton_module(name, triton_context, program); + ir::module &tt_module_1 = *ptt_module_1; + // run passes + passes_wrapper passes_1(target_.get()); + passes_1.target_independent(tt_module_1); + passes_1.tune.run(tt_module_1); i = 0; - for(ir::metaparameter* mp: passes.tune.get_params(tt_module)){ + for(ir::metaparameter* mp: passes_1.tune.get_params(tt_module_1)){ mp->set_value(params[i++]); } - passes.tune.init(tt_module); - passes.target_dependent(tt_module); + passes_1.tune.init(tt_module_1); + passes_1.target_dependent(tt_module_1); driver::device* device = driver_context_->device(); - if(passes.shmem_allocation.get_allocated_size() > device->max_shared_memory()) + if(passes_1.shmem_allocation.get_allocated_size() > device->max_shared_memory()) return; - if(passes.tune.get_num_threads() > device->max_threads_per_block()) + if(passes_1.tune.get_num_threads() > device->max_threads_per_block()) return; // Compile - auto ll_module = make_llvm_module(tt_module, passes); + launch_information info; + llvm::LLVMContext llvm_context; + auto ll_module = make_llvm_module(tt_module_1, passes_1, llvm_context, info); std::unique_ptr module(driver::module::create(driver_context_, &*ll_module)); std::unique_ptr kernel(driver::kernel::create(module.get(), name)); - launch_information info = launch_info_map_.at(name); - for(unsigned p: params) - std::cout << p << " " << std::flush; // add globals - for(auto x: tt_module.globals()) + for(auto x: tt_module_1.globals()) global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value(); modules_.insert({name, module.get()}); double perf; perf = benchmark(kernel.get(), info); - if(perf > best.perf){ - best.perf = perf; - best.params = params; + { + std::lock_guard lock(mutex); + if(perf > best.perf){ + best.perf = perf; + best.params = params; + } + for(unsigned p: params) + std::cout << p << " " << std::flush; + std::cout << perf << " [ " << best.perf << " ] " << std::endl; } - std::cout << perf << " [ " << best.perf << " ] " << std::endl; modules_.erase(name); - }); + }, nthreads); return best; } @@ -227,9 +244,9 @@ void jit::add_module(ir::module &tt_module, const std::vector ¶ms) if(errors.size()) throw std::runtime_error("invalid parameters"); // triton module -> llvm module - auto ll_module = make_llvm_module(tt_module, passes); - // llvm module -> machine code std::string name = tt_module.get_name(); + auto ll_module = make_llvm_module(tt_module, passes, llvm_context_, launch_info_map_[name]); + // llvm module -> machine code modules_.insert({name, driver::module::create(driver_context_, &*ll_module)}); // add globals for(auto x: tt_module.globals()) @@ -237,7 +254,8 @@ void jit::add_module(ir::module &tt_module, const std::vector ¶ms) } void jit::add_module(const char *name, const char *src, const std::vector ¶ms) { - auto ptt_module = make_triton_module(name, src); + triton::lang::translation_unit* program = parse_program(name, src); + auto ptt_module = make_triton_module(name, triton_context_, program); add_module(*ptt_module, params); }