From 434f65737f4ba2d7692500a3a46edcee3ea532fe Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Mon, 15 Jul 2019 12:35:53 -0700
Subject: [PATCH] [runtime] put jit::launch_info in another file

---
 examples/cpp/shift.cpp               |   8 +-
 include/triton/dnn/base.h            |   4 +-
 include/triton/dnn/batchnorm.h       |   4 +-
 include/triton/dnn/conv.h            |   3 +-
 include/triton/dnn/gemm.h            |   3 +-
 include/triton/dnn/shift.h           |   3 +-
 include/triton/runtime/jit.h         |  11 +-
 include/triton/runtime/launch_info.h |  19 ++++
 include/triton/tools/thread_pool.h   | 161 +++++++++++++++++++++++++++
 lib/dnn/base.cpp                     |  20 ++--
 lib/dnn/batchnorm.cpp                |   9 +-
 lib/dnn/conv.cpp                     |   7 +-
 lib/dnn/gemm.cpp                     |   9 +-
 lib/dnn/shift.cpp                    |   6 +-
 lib/runtime/jit.cpp                  |  21 ++--
 15 files changed, 227 insertions(+), 61 deletions(-)
 create mode 100644 include/triton/runtime/launch_info.h
 create mode 100644 include/triton/tools/thread_pool.h

diff --git a/examples/cpp/shift.cpp b/examples/cpp/shift.cpp
index 6941dfa0d..33ded064e 100644
--- a/examples/cpp/shift.cpp
+++ b/examples/cpp/shift.cpp
@@ -14,13 +14,13 @@ int main() {
 
   // initialize default compute device
   auto context = triton::driver::backend::contexts::get_default();
-  auto op = triton::dnn::shift::FPROP;
+  auto op = triton::dnn::shift::WGRAD;
 
   // initialization
   int32_t R = 3, S = 3;
-  int32_t B = 16, F = 512;
-  int32_t H = 16, W = 16;
-  int32_t C = 512;
+  int32_t B = 32, F = 128;
+  int32_t H = 28, W = 28;
+  int32_t C = 128;
 
   // random shifts
   std::vector<int32_t> shift_h(C);
diff --git a/include/triton/dnn/base.h b/include/triton/dnn/base.h
index 7aeab2a14..e8ba1c47e 100644
--- a/include/triton/dnn/base.h
+++ b/include/triton/dnn/base.h
@@ -25,6 +25,7 @@
 
 #include "triton/driver/stream.h"
 #include "triton/driver/kernel.h"
+#include "triton/runtime/launch_info.h"
 
 namespace triton{
 namespace dnn{
@@ -45,8 +46,7 @@ private:
   // enqueue
   virtual void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
                             std::vector<driver::buffer*> args,
-                            const std::vector<unsigned>& ranges,
-                            size_t nthreads) = 0;
+                            triton::runtime::launch_information info) = 0;
   // number of flops
   virtual size_t num_flops() const = 0;
   // comparison for maps
diff --git a/include/triton/dnn/batchnorm.h b/include/triton/dnn/batchnorm.h
index df2a2df30..496e19ae4 100644
--- a/include/triton/dnn/batchnorm.h
+++ b/include/triton/dnn/batchnorm.h
@@ -40,7 +40,7 @@ private:
   // enqueue
   void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
                     std::vector<driver::buffer*> args,
-                    const std::vector<unsigned> &ranges, size_t nthreads);
+                    triton::runtime::launch_information info);
   // number of flops
   size_t num_flops() const;
   // comparison for maps
@@ -72,7 +72,7 @@ private:
   // enqueue
   void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
                     std::vector<driver::buffer*> args,
-                    const std::vector<unsigned> &ranges, size_t nthreads);
+                    runtime::launch_information info);
   // number of flops
   size_t num_flops() const;
   // comparison for maps
diff --git a/include/triton/dnn/conv.h b/include/triton/dnn/conv.h
index 67d621050..1b6f2d778 100644
--- a/include/triton/dnn/conv.h
+++ b/include/triton/dnn/conv.h
@@ -33,8 +33,7 @@ private:
                 driver::buffer *bias);
   void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
                     std::vector<driver::buffer*> args,
-                    const std::vector<unsigned>& ranges,
-                    size_t nthreads);
+                    triton::runtime::launch_information info);
   // number of flops
   size_t num_flops() const;
   // comparison for maps
diff --git a/include/triton/dnn/gemm.h b/include/triton/dnn/gemm.h
index 26ed7d68a..8348edf3e 100644
--- a/include/triton/dnn/gemm.h
+++ b/include/triton/dnn/gemm.h
@@ -13,8 +13,7 @@ private:
   // enqueue
   void enqueue_impl(driver::stream *stream,
                     driver::kernel *kernel, std::vector<driver::buffer*> args,
-                    const std::vector<unsigned>& ranges,
-                    size_t nthreads);
+                    triton::runtime::launch_information info);
   // number of flops
   size_t num_flops() const;
   // comparison for maps
diff --git a/include/triton/dnn/shift.h b/include/triton/dnn/shift.h
index ec4ffc753..8f33aee66 100644
--- a/include/triton/dnn/shift.h
+++ b/include/triton/dnn/shift.h
@@ -54,8 +54,7 @@ private:
   void init_impl(driver::stream *stream, driver::cu_module *module);
   void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
                     std::vector<driver::buffer*> args,
-                    const std::vector<unsigned>& ranges,
-                    size_t nthreads);
+                    triton::runtime::launch_information info);
 
 public:
 
diff --git a/include/triton/runtime/jit.h b/include/triton/runtime/jit.h
index 9b0f75f96..c594eccd8 100644
--- a/include/triton/runtime/jit.h
+++ b/include/triton/runtime/jit.h
@@ -20,6 +20,7 @@
 #include "triton/codegen/alignment_info.h"
 #include "triton/codegen/target.h"
 #include "triton/codegen/vectorize.h"
+#include "triton/runtime/launch_info.h"
 #include <functional>
 
 namespace llvm {
@@ -42,12 +43,10 @@ class context;
 class metaparameter;
 }
 
+namespace runtime{
+
 class jit {
 public:
-  struct launch_information{
-    std::vector<unsigned> global_range_size;
-    unsigned num_threads;
-  };
   typedef std::function<double(driver::kernel*, launch_information)> benchmark_t;
 
   struct tune_res_t{
@@ -114,7 +113,6 @@ public:
   void add_module(const char* name, const char* src, const std::vector<unsigned>& params = {});
   driver::kernel* get_function(const char* name);
   launch_information get_launch_info(const char* name);
-  unsigned get_int(const char* name);
 
 private:
   std::map<std::string, driver::module*> modules_;
@@ -122,11 +120,10 @@ private:
   llvm::LLVMContext llvm_context_;
   ir::context triton_context_;
   std::map<std::string, launch_information> launch_info_map_;
-  std::map<std::string, unsigned> global_ints_;
   std::shared_ptr<triton::codegen::target> target_;
 };
 
-
+}
 }
 
 #endif
diff --git a/include/triton/runtime/launch_info.h b/include/triton/runtime/launch_info.h
new file mode 100644
index 000000000..a6a0ddb5b
--- /dev/null
+++ b/include/triton/runtime/launch_info.h
@@ -0,0 +1,19 @@
+#ifndef TRITON_INCLUDE_RUNTIME_LAUNCH_INFO_H
+#define TRITON_INCLUDE_RUNTIME_LAUNCH_INFO_H
+
+#include <vector>
+#include <map>
+
+namespace triton{
+namespace runtime{
+
+struct launch_information{
+  std::vector<unsigned> global_range_size;
+  unsigned num_threads;
+  std::map<std::string, unsigned> globals;
+};
+
+}
+}
+
+#endif
diff --git a/include/triton/tools/thread_pool.h b/include/triton/tools/thread_pool.h
new file mode 100644
index 000000000..5d01511e1
--- /dev/null
+++ b/include/triton/tools/thread_pool.h
@@ -0,0 +1,161 @@
+#ifndef CONCURRENT_THREADPOOL_H
+#define CONCURRENT_THREADPOOL_H
+
+#include <thread>
+#include <atomic>
+#include <vector>
+#include <list>
+#include <functional>
+#include <mutex>
+#include <condition_variable>
+
+namespace nbsdx {
+namespace concurrent {
+
+/**
+ *  Simple ThreadPool that creates `thread_count` threads upon its creation,
+ *  and pulls from a queue to get new jobs. The default is 4 threads.
+ *
+ *  This class requires a number of C++11 features be present in your compiler.
+ */
+class thread_pool {
+
+  std::vector<std::thread> threads_;
+  std::list<std::function<void(void)>> queue_;
+
+  std::atomic_int jobs_left_;
+  std::atomic_bool bailout_;
+  std::atomic_bool finished_;
+  std::condition_variable job_available_var_;
+  std::condition_variable wait_var_;
+  std::mutex wait_mutex_;
+  std::mutex queue_mutex_;
+  unsigned thread_count_;
+
+  /**
+   *  Take the next job in the queue and run it.
+   *  Notify the main thread that a job has completed.
+   */
+  void task() {
+    while( !bailout_ ) {
+      next_job()();
+      --jobs_left_;
+      wait_var_.notify_one();
+    }
+  }
+
+  /**
+   *  Get the next job; pop the first item in the queue,
+   *  otherwise wait for a signal from the main thread.
+   */
+  std::function<void(void)> next_job() {
+    std::function<void(void)> res;
+    std::unique_lock<std::mutex> job_lock( queue_mutex_ );
+
+    // Wait for a job if we don't have any.
+    job_available_var_.wait( job_lock, [this]() ->bool { return queue_.size() || bailout_; } );
+
+    // Get job from the queue
+    if( !bailout_ ) {
+      res = queue_.front();
+      queue_.pop_front();
+    }
+    else { // If we're bailing out, 'inject' a job into the queue to keep jobs_left accurate.
+      res = []{};
+      ++jobs_left_;
+    }
+    return res;
+  }
+
+public:
+  thread_pool(unsigned thread_count = 4)
+    : jobs_left_( 0 )
+    , bailout_( false )
+    , finished_( false )
+    , thread_count_(thread_count)
+  {
+    threads_.resize(thread_count_);
+    for( unsigned i = 0; i < thread_count_; ++i )
+      threads_[ i ] = std::thread( [this]{ this->task(); } );
+  }
+
+  /**
+   *  join_all on destruction
+   */
+  ~thread_pool() {
+    join_all();
+  }
+
+  /**
+   *  Get the number of threads in this pool
+   */
+  inline unsigned size() const {
+    return thread_count_;
+  }
+
+  /**
+   *  Get the number of jobs left in the queue.
+   */
+  inline unsigned jobs_remaining() {
+    std::lock_guard<std::mutex> guard( queue_mutex_ );
+    return queue_.size();
+  }
+
+  /**
+   *  Add a new job to the pool. If there are no jobs in the queue,
+   *  a thread is woken up to take the job. If all threads are busy,
+   *  the job is added to the end of the queue.
+   */
+  void add_job( std::function<void(void)> job ) {
+    std::lock_guard<std::mutex> guard( queue_mutex_ );
+    queue_.emplace_back( job );
+    ++jobs_left_;
+    job_available_var_.notify_one();
+  }
+
+  /**
+   *  Join with all threads. Block until all threads have completed.
+   *  Params: WaitForAll: If true, will wait for the queue to empty
+   *          before joining with threads. If false, will complete
+   *          current jobs, then inform the threads to exit.
+   *  The queue will be empty after this call, and the threads will
+   *  be done. After invoking `thread_pool::join_all`, the pool can no
+   *  longer be used. If you need the pool to exist past completion
+   *  of jobs, look to use `thread_pool::wait_all`.
+   */
+  void join_all( bool WaitForAll = true ) {
+    if( !finished_ ) {
+      if( WaitForAll ) {
+        wait_all();
+      }
+
+      // note that we're done, and wake up any thread that's
+      // waiting for a new job
+      bailout_ = true;
+      job_available_var_.notify_all();
+
+      for( auto &x : threads_ )
+        if( x.joinable() )
+          x.join();
+      finished_ = true;
+    }
+  }
+
+  /**
+   *  Wait for the pool to empty before continuing.
+   *  This does not call `std::thread::join`, it only waits until
+   *  all jobs have finished executing.
+   */
+  void wait_all() {
+    if( jobs_left_ > 0 ) {
+      std::unique_lock<std::mutex> lk( wait_mutex_ );
+      wait_var_.wait( lk, [this]{ return this->jobs_left_ == 0; } );
+      lk.unlock();
+    }
+  }
+};
+
+} // namespace concurrent
+} // namespace nbsdx
+
+#endif //CONCURRENT_THREADPOOL_H
diff --git a/lib/dnn/base.cpp b/lib/dnn/base.cpp
index b3bf6c05a..f5e2af0b2 100644
--- a/lib/dnn/base.cpp
+++ b/lib/dnn/base.cpp
@@ -23,29 +23,30 @@ base::base(const std::string& name)
   : name_(name) { }
 
 void base::enqueue(driver::stream *stream, std::vector<driver::buffer*> args, bool autotune) {
-  static std::map<base*, std::unique_ptr<triton::jit>, cmp_recompile> m_jit;
+  namespace rt = triton::runtime;
+  static std::map<base*, std::unique_ptr<rt::jit>, cmp_recompile> m_jit;
   driver::context* ctx = stream->context();
-  triton::jit* jit;
+  rt::jit* jit;
   /* the current template has not already been compiled */
   if(m_jit.find(this) == m_jit.end()) {
-    jit = m_jit.emplace(this->clone(), new triton::jit(ctx)).first->second.get();
+    jit = m_jit.emplace(this->clone(), new rt::jit(ctx)).first->second.get();
    std::ostringstream oss;
    triton_c_src(oss);
    std::string src = oss.str();
    auto benchmark = [&](triton::driver::kernel* kernel,
-                        triton::jit::launch_information info) {
+                        rt::launch_information info) {
      // launch info
      unsigned nthreads = info.num_threads;
      init_impl(stream, (triton::driver::cu_module*)kernel->module());
-     enqueue_impl(stream, kernel, args, info.global_range_size, nthreads);
+     enqueue_impl(stream, kernel, args, info);
      stream->synchronize();
-     double ts = triton::tools::bench([&](){ enqueue_impl(stream, kernel, args, info.global_range_size, nthreads); },
+     double ts = triton::tools::bench([&](){ enqueue_impl(stream, kernel, args, info); },
                                      [&](){ stream->synchronize(); }, ctx->device());
      return num_flops() / ts * 1e-3;
    };
    // auto-tune and save result
    if(autotune) {
-     triton::jit::tune_res_t best = jit->autotune(name_.c_str(), src.c_str(), benchmark);
+     rt::jit::tune_res_t best = jit->autotune(name_.c_str(), src.c_str(), benchmark);
      jit->add_module(name_.c_str(), src.c_str(), best.params);
    }
    else {
@@ -60,10 +61,9 @@ void base::enqueue(driver::stream *stream, std::vector<driver::buffer*> args, b
 
   /* get launch parameters */
   driver::kernel* kernel = jit->get_function(name_.c_str());
-  triton::jit::launch_information info = jit->get_launch_info(name_.c_str());
+  rt::launch_information info = jit->get_launch_info(name_.c_str());
   /* launch */
-  enqueue_impl(stream, kernel, args,
-               info.global_range_size, info.num_threads);
+  enqueue_impl(stream, kernel, args, info);
 }
 
 }
diff --git a/lib/dnn/batchnorm.cpp b/lib/dnn/batchnorm.cpp
index 54bb9c16e..34275a931 100644
--- a/lib/dnn/batchnorm.cpp
+++ b/lib/dnn/batchnorm.cpp
@@ -54,8 +54,7 @@ base* batchnorm_forward::clone() const {
 
 void batchnorm_forward::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
                                      std::vector<driver::buffer*> args,
-                                     const std::vector<unsigned>&,
-                                     size_t nthreads)
+                                     runtime::launch_information info)
 {
   driver::buffer *y = args[0], *m = args[1], *v = args[2];
   driver::buffer *x = args[3], *g = args[4], *b = args[5];
@@ -69,7 +68,7 @@ void batchnorm_forward::enqueue_impl(driver::stream *stream, driver::kernel *ker
   kernel->setArg(6, DHWB_);
   kernel->setArg(7, rcpDHWB_);
   kernel->setArg(8, eps_);
-  stream->enqueue(kernel, grid, {nthreads, 1, 1});
+  stream->enqueue(kernel, grid, {info.num_threads, 1, 1});
 }
 
 void batchnorm_forward::triton_c_src(std::ostream &os) const {
@@ -154,7 +153,7 @@ base* batchnorm_backward::clone() const {
 
 void batchnorm_backward::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
                                       std::vector<driver::buffer*> args,
-                                      const std::vector<unsigned> &, size_t nthreads) {
+                                      runtime::launch_information info) {
   driver::buffer *dx = args[0], *dg = args[1], *db = args[2], *dy = args[3];
   driver::buffer *x = args[4], *g = args[5], *m = args[6], *v = args[7];
   std::array<size_t, 3> grid = {1, (size_t)C_, 1};
@@ -169,7 +168,7 @@ void batchnorm_backward::enqueue_impl(driver::stream *stream, driver::kernel *ke
   kernel->setArg(8, (int32_t)(D_*H_*W_*B_));
   kernel->setArg(9, (float)1/(D_*H_*W_*B_));
   kernel->setArg(10, eps_);
-  stream->enqueue(kernel, grid, {nthreads, 1, 1});
+  stream->enqueue(kernel, grid, {info.num_threads, 1, 1});
 }
 
 void batchnorm_backward::triton_c_src(std::ostream &os) const {
diff --git a/lib/dnn/conv.cpp b/lib/dnn/conv.cpp
index 011cd7a53..c20701a4b 100644
--- a/lib/dnn/conv.cpp
+++ b/lib/dnn/conv.cpp
@@ -365,10 +365,9 @@ void conv::set_arg(driver::kernel *kernel,
 
 void conv::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
                         std::vector<driver::buffer*> args,
-                        const std::vector<unsigned>& ranges,
-                        size_t nthreads) {
+                        runtime::launch_information info) {
   driver::buffer *a = args[0], *b = args[1], *c = args[2], *bias = args[3];
-  unsigned TM = ranges[0], TN = ranges[1];
+  unsigned TM = info.global_range_size[0], TN = info.global_range_size[1];
   unsigned GZ = 1;
   set_arg(kernel, a, b, c, bias);
   std::array<size_t, 3> grid = {1};
@@ -411,7 +410,7 @@ void conv::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
   kernel->setArg(38, (pad_w_ + (1 - upsample_w_)*off_uw)/upsample_w_);
   kernel->setArg(39, (off_uh + pad_h_) % upsample_h_);
   kernel->setArg(40, (off_uw + pad_w_) % upsample_w_);
-  stream->enqueue(kernel, grid, {nthreads, 1, 1});
+  stream->enqueue(kernel, grid, {info.num_threads, 1, 1});
 }
 
 }
diff --git a/lib/dnn/gemm.cpp b/lib/dnn/gemm.cpp
index 6ea1a8c21..139062db8 100644
--- a/lib/dnn/gemm.cpp
+++ b/lib/dnn/gemm.cpp
@@ -47,11 +47,10 @@ void gemm::init_impl(driver::stream* stream, driver::cu_module *) {
 
 void gemm::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
                         std::vector<driver::buffer*> args,
-                        const std::vector<unsigned>& ranges,
-                        size_t nthreads) {
+                        runtime::launch_information info) {
   driver::buffer *a = args[0], *b = args[1], *c = args[2];
-  unsigned TM = ranges[0];
-  unsigned TN = ranges[1];
+  unsigned TM = info.global_range_size[0];
+  unsigned TN = info.global_range_size[1];
   unsigned grid_0 = (M_ + TM - 1)/TM;
   unsigned grid_1 = (N_ + TN - 1)/TN;
   unsigned grid_2 = 1;
@@ -68,7 +67,7 @@ void gemm::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
   kernel->setArg(9, locks_);
   kernel->setArg(10, grid_0);
   kernel->setArg(11, grid_1);
-  stream->enqueue(kernel, grid, {nthreads, 1, 1});
+  stream->enqueue(kernel, grid, {info.num_threads, 1, 1});
 }
 
 std::vector<unsigned> gemm::default_params() {
diff --git a/lib/dnn/shift.cpp b/lib/dnn/shift.cpp
index 872189c89..cc6dccc4d 100644
--- a/lib/dnn/shift.cpp
+++ b/lib/dnn/shift.cpp
@@ -199,7 +199,7 @@ void shift::init_impl(driver::stream *stream, driver::cu_module *module) {
 
 void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
                          std::vector<driver::buffer*> args,
-                         const std::vector<unsigned> &ranges, size_t nthreads) {
+                         runtime::launch_information info) {
   driver::buffer *a = args[0], *b = args[1], *c = args[2];
   kernel->setArg(0, a);
   kernel->setArg(1, b);
@@ -228,13 +228,13 @@ void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
   kernel->setArg(24, BW_);
   kernel->setArg(25, CH_);
   kernel->setArg(26, CW_);
-  unsigned TM = ranges[0], TN = ranges[1];
+  unsigned TM = info.global_range_size[0], TN = info.global_range_size[1];
   std::array<size_t, 3> grid = {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1};
   if(op_ == BPROP){
     size_t c_nbytes = (c_ty_ == "fp16") ? 2 : 4;
     ((driver::cu_buffer*)c)->set_zero(stream, AH_*AW_*B_*C_*c_nbytes);
   }
-  stream->enqueue(kernel, grid, {nthreads, 1, 1});
+  stream->enqueue(kernel, grid, {info.num_threads, 1, 1});
 }
 
 void shift::triton_c_src(std::ostream &os) const {
diff --git a/lib/runtime/jit.cpp b/lib/runtime/jit.cpp
index 51f3ed916..b55680a21 100644
--- a/lib/runtime/jit.cpp
+++ b/lib/runtime/jit.cpp
@@ -29,6 +29,7 @@ extern void yy_delete_buffer(YY_BUFFER_STATE buffer);
 extern triton::lang::translation_unit *ast_root;
 
 namespace triton {
+namespace runtime{
 
 void loop_nest(std::vector<size_t> const & ranges,
                std::function<void(std::vector<size_t> const &)> const & f,
@@ -80,6 +81,10 @@ std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, passes_w
   info.global_range_size.clear();
   for(unsigned i = 0; i < passes.tune.get_num_global_range(); i++)
     info.global_range_size.push_back(passes.tune.get_global_range_size(i));
+  // add globals
+  for(auto x: module.globals())
+    info.globals[x.first] = ((ir::metaparameter*)x.second)->get_value();
+  // number of threads
   info.num_threads = passes.tune.get_num_threads();
   return std::unique_ptr<llvm::Module>(result);
 }
@@ -164,7 +169,7 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
     ranges.push_back(mp->get_space());
   // iterate over parameters
   tune_res_t best;
-  size_t nthreads = 4;
+  size_t nthreads = 1;
   std::mutex mutex;
   loop_nest(ranges, [&](const std::vector<size_t> params){
     std::map<ir::value*, std::vector<std::string>> errors;
@@ -203,10 +208,6 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
       auto ll_module = make_llvm_module(tt_module_1, passes_1, llvm_context, info);
       std::unique_ptr<driver::module> module(driver::module::create(driver_context_, &*ll_module));
       std::unique_ptr<driver::kernel> kernel(driver::kernel::create(module.get(), name));
-      // add globals
-      for(auto x: tt_module_1.globals())
-        global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value();
-      modules_.insert({name, module.get()});
       double perf;
       perf = benchmark(kernel.get(), info);
       {
@@ -219,7 +220,6 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
         std::cout << p << " " << std::flush;
         std::cout << perf << " [ " << best.perf << " ] " << std::endl;
       }
-      modules_.erase(name);
   }, nthreads);
   return best;
 }
@@ -248,9 +248,6 @@ void jit::add_module(ir::module &tt_module, const std::vector<unsigned> &params)
   auto ll_module = make_llvm_module(tt_module, passes, llvm_context_, launch_info_map_[name]);
   // llvm module -> machine code
   modules_.insert({name, driver::module::create(driver_context_, &*ll_module)});
-  // add globals
-  for(auto x: tt_module.globals())
-    global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value();
 }
 
 void jit::add_module(const char *name, const char *src, const std::vector<unsigned> &params) {
@@ -263,12 +260,10 @@ driver::kernel *jit::get_function(const char *name) {
   return driver::kernel::create(modules_.at(name), name);
 }
 
-jit::launch_information jit::get_launch_info(const char *name) {
+launch_information jit::get_launch_info(const char *name) {
   return launch_info_map_.at(name);
 }
 
-unsigned jit::get_int(const char *name){
-  return global_ints_.at(name);
-}
 
 }
+}
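
After this patch, call sites read everything they need for a launch from the relocated triton::runtime::launch_information (tile sizes, thread count, and the globals that jit::get_int used to expose) instead of a separate range vector and thread count. A minimal consumer-side sketch, assembled from the declarations in the headers above; the kernel name "matmul", the meta-parameter list {16, 16, 8}, the global name "GZ", and the problem sizes M and N are made up for illustration, and kernel argument setup is omitted:

    #include <array>
    #include "triton/runtime/jit.h"
    #include "triton/runtime/launch_info.h"

    void run(triton::driver::context* context, triton::driver::stream* stream,
             const char* src, unsigned M, unsigned N) {
      triton::runtime::jit jit(context);
      // compile the Triton-C source under one (illustrative) set of meta-parameters
      jit.add_module("matmul", src, {16, 16, 8});
      triton::driver::kernel* kernel = jit.get_function("matmul");
      triton::runtime::launch_information info = jit.get_launch_info("matmul");
      // globals previously read through jit::get_int() now travel with the launch info
      unsigned gz = info.globals.at("GZ");
      // tile sizes come from global_range_size, as in lib/dnn/gemm.cpp above
      unsigned TM = info.global_range_size[0];
      unsigned TN = info.global_range_size[1];
      // (kernel->setArg(...) calls omitted)
      std::array<size_t, 3> grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, gz};
      stream->enqueue(kernel, grid, {info.num_threads, 1, 1});
    }

The new include/triton/tools/thread_pool.h is self-contained and is not referenced by any of the hunks above; a standalone usage sketch of its queue-based API, with an illustrative job count and lambda:

    #include <atomic>
    #include "triton/tools/thread_pool.h"

    int main() {
      nbsdx::concurrent::thread_pool pool(4);   // spawns 4 worker threads
      std::atomic<int> done(0);
      for(int i = 0; i < 100; i++)
        pool.add_job([&done]{ ++done; });       // idle workers pick queued jobs up
      pool.join_all();                          // drain the queue, then join the workers
      return done == 100 ? 0 : 1;
    }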