[runtime/jit] made auto-tuning silent

Philippe Tillet
2019-07-16 14:41:38 -07:00
parent 7d1797cd32
commit 28959fe165
6 changed files with 105 additions and 172 deletions

View File

@@ -13,7 +13,7 @@ int main() {
   // initialize default compute device
   auto context = triton::driver::backend::contexts::get_default();
   // matrix multiplication parameters
-  int32_t M = 131072, N = 128, K = 128;
+  int32_t M = 32768, N = 128, K = 128;
   std::vector<float> hc(M*N);
   std::vector<float> rc(M*N);
   std::vector<float> ha(M*K);
@@ -33,8 +33,8 @@ int main() {
   stream->write(db, true, 0, hb);
   stream->write(dc, true, 0, hc);
   stream->synchronize();
-  triton::dnn::gemm gemm(M, N, K, AT, BT, "fp32", "fp32", 4, 4);
-  gemm.enqueue(stream, {da, db, dc});
+  triton::dnn::gemm gemm(M, N, K, AT, BT, "fp16", "fp16", 4, 4);
+  gemm.enqueue(stream, {da, db, dc}, true);
   stream->read(dc, true, 0, hc);
   gemm.cpu_ref<float>(rc, ha, hb);
   for(size_t i = 0; i < M*N; i++)
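
The verification loop is cut off at the hunk boundary above. For reference, a check of roughly this shape typically follows; the relative tolerance is an assumption (fp16 accumulation needs a looser bound than fp32), not something this commit specifies:

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Sketch of the kind of check the truncated loop performs: compare the
    // device result hc against the CPU reference rc element by element.
    // The 1e-2 relative-error bound is assumed, not taken from the commit.
    bool allclose(const std::vector<float>& hc, const std::vector<float>& rc) {
      for(size_t i = 0; i < hc.size(); i++){
        float denom = std::max(std::fabs(rc[i]), 1e-6f);
        if(std::isnan(hc[i]) || std::fabs(hc[i] - rc[i]) / denom > 1e-2f){
          std::printf("mismatch at %zu: %f vs %f\n", i, hc[i], rc[i]);
          return false;
        }
      }
      return true;
    }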

View File

@@ -75,7 +75,7 @@ torch::Tensor shift_common(
   triton::driver::cu_buffer c(ctx, (CUdeviceptr)torchc.storage().data(), false);
   // Enqueue
-  shift.enqueue(&stream, {&a, &b, &c});
+  shift.enqueue(&stream, {&a, &b, &c}, true);
   return torchc;
 }

View File

@@ -105,7 +105,7 @@ private:
   triton::lang::translation_unit *parse_program(const char *name, const char *src);
 public:
-  jit(driver::context* context);
+  jit(driver::context* context, unsigned nthreads = 4);
   ~jit();
   std::vector<unsigned> get_valid(const char *name, const char *src);
   tune_res_t autotune(const char* name, const char* src, benchmark_t benchmark);
@@ -121,6 +121,7 @@ private:
   ir::context triton_context_;
   std::map<std::string, launch_information> launch_info_map_;
   std::shared_ptr<triton::codegen::target> target_;
+  unsigned nthreads_;
 };
 }
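
The header change gives `jit` a configurable worker count for autotuning, defaulting to 4. A hypothetical call site (the `triton::` qualification matches the rest of the tree; `src` and `benchmark` are placeholder names, and `autotune` is the method declared in this same hunk):

    // Sketch only: spread autotuning over 8 threads instead of the default 4.
    triton::jit jit(context, 8);
    triton::jit::tune_res_t best = jit.autotune("gemm", src, benchmark);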

View File

@@ -1,161 +1,98 @@
-#ifndef CONCURRENT_THREADPOOL_H
-#define CONCURRENT_THREADPOOL_H
+#ifndef THREAD_POOL_H
+#define THREAD_POOL_H
-#include <atomic>
+#include <vector>
+#include <queue>
+#include <memory>
 #include <thread>
 #include <mutex>
-#include <vector>
-#include <list>
-#include <functional>
 #include <condition_variable>
+#include <future>
+#include <functional>
+#include <stdexcept>
-namespace nbsdx {
-namespace concurrent {
-/**
- *  Simple ThreadPool that creates `ThreadCount` threads upon its creation,
- *  and pulls from a queue to get new jobs. The default is 10 threads.
- *
- *  This class requires a number of c++11 features be present in your compiler.
- */
-class thread_pool {
-  std::vector<std::thread> threads_;
-  std::list<std::function<void(void)>> queue_;
-  std::atomic_int jobs_left_;
-  std::atomic_bool bailout_;
-  std::atomic_bool finished_;
-  std::condition_variable job_available_var_;
-  std::condition_variable wait_var_;
-  std::mutex wait_mutex_;
-  std::mutex queue_mutex_;
-  unsigned thread_count_;
-  /**
-   *  Take the next job in the queue and run it.
-   *  Notify the main thread that a job has completed.
-   */
-  void task() {
-    while( !bailout_ ) {
-      next_job()();
-      --jobs_left_;
-      wait_var_.notify_one();
-    }
-  }
-  /**
-   *  Get the next job; pop the first item in the queue,
-   *  otherwise wait for a signal from the main thread.
-   */
-  std::function<void(void)> next_job() {
-    std::function<void(void)> res;
-    std::unique_lock<std::mutex> job_lock( queue_mutex_ );
-    // Wait for a job if we don't have any.
-    job_available_var_.wait( job_lock, [this]() ->bool { return queue_.size() || bailout_; } );
-    // Get job from the queue
-    if( !bailout_ ) {
-      res = queue_.front();
-      queue_.pop_front();
-    }
-    else { // If we're bailing out, 'inject' a job into the queue to keep jobs_left accurate.
-      res = []{};
-      ++jobs_left_;
-    }
-    return res;
-  }
+class ThreadPool {
 public:
-  thread_pool(unsigned thread_count = 4)
-    : jobs_left_( 0 )
-    , bailout_( false )
-    , finished_( false )
-    , thread_count_(thread_count)
-  {
-    threads_.resize(thread_count_);
-    for( unsigned i = 0; i < thread_count_; ++i )
-      threads_[ i ] = std::thread( [this]{ this->task(); } );
-  }
+    ThreadPool(size_t);
+    template<class F, class... Args>
+    auto enqueue(F&& f, Args&&... args)
+        -> std::future<typename std::result_of<F(Args...)>::type>;
+    ~ThreadPool();
+private:
+    // need to keep track of threads so we can join them
+    std::vector< std::thread > workers;
+    // the task queue
+    std::queue< std::function<void()> > tasks;
-  /**
-   *  JoinAll on deconstruction
-   */
-  ~thread_pool() {
-    join_all();
-  }
-  /**
-   *  Get the number of threads in this pool
-   */
-  inline unsigned size() const {
-    return thread_count_;
-  }
-  /**
-   *  Get the number of jobs left in the queue.
-   */
-  inline unsigned jobs_remaining() {
-    std::lock_guard<std::mutex> guard( queue_mutex_ );
-    return queue_.size();
-  }
-  /**
-   *  Add a new job to the pool. If there are no jobs in the queue,
-   *  a thread is woken up to take the job. If all threads are busy,
-   *  the job is added to the end of the queue.
-   */
-  void add_job( std::function<void(void)> job ) {
-    std::lock_guard<std::mutex> guard( queue_mutex_ );
-    queue_.emplace_back( job );
-    ++jobs_left_;
-    job_available_var_.notify_one();
-  }
-  /**
-   *  Join with all threads. Block until all threads have completed.
-   *  Params: WaitForAll: If true, will wait for the queue to empty
-   *          before joining with threads. If false, will complete
-   *          current jobs, then inform the threads to exit.
-   *  The queue will be empty after this call, and the threads will
-   *  be done. After invoking `ThreadPool::JoinAll`, the pool can no
-   *  longer be used. If you need the pool to exist past completion
-   *  of jobs, look to use `ThreadPool::WaitAll`.
-   */
-  void join_all( bool WaitForAll = true ) {
-    if( !finished_ ) {
-      if( WaitForAll ) {
-        wait_all();
-      }
-      // note that we're done, and wake up any thread that's
-      // waiting for a new job
-      bailout_ = true;
-      job_available_var_.notify_all();
-      for( auto &x : threads_ )
-        if( x.joinable() )
-          x.join();
-      finished_ = true;
-    }
-  }
-  /**
-   *  Wait for the pool to empty before continuing.
-   *  This does not call `std::thread::join`, it only waits until
-   *  all jobs have finshed executing.
-   */
-  void wait_all() {
-    if( jobs_left_ > 0 ) {
-      std::unique_lock<std::mutex> lk( wait_mutex_ );
-      wait_var_.wait( lk, [this]{ return this->jobs_left_ == 0; } );
-      lk.unlock();
-    }
-  }
+    // synchronization
+    std::mutex queue_mutex;
+    std::condition_variable condition;
+    bool stop;
 };
-} // namespace concurrent
-} // namespace nbsdx
+// the constructor just launches some amount of workers
+inline ThreadPool::ThreadPool(size_t threads)
+    :   stop(false)
+{
+    for(size_t i = 0;i<threads;++i)
+        workers.emplace_back(
+            [this]
+            {
+                for(;;)
+                {
+                    std::function<void()> task;
-#endif //CONCURRENT_THREADPOOL_H
+                    {
+                        std::unique_lock<std::mutex> lock(this->queue_mutex);
+                        this->condition.wait(lock,
+                            [this]{ return this->stop || !this->tasks.empty(); });
+                        if(this->stop && this->tasks.empty())
+                            return;
+                        task = std::move(this->tasks.front());
+                        this->tasks.pop();
+                    }
+                    task();
+                }
+            }
+        );
+}
+// add new work item to the pool
+template<class F, class... Args>
+auto ThreadPool::enqueue(F&& f, Args&&... args)
+    -> std::future<typename std::result_of<F(Args...)>::type>
+{
+    using return_type = typename std::result_of<F(Args...)>::type;
+    auto task = std::make_shared< std::packaged_task<return_type()> >(
+            std::bind(std::forward<F>(f), std::forward<Args>(args)...)
+        );
+    std::future<return_type> res = task->get_future();
+    {
+        std::unique_lock<std::mutex> lock(queue_mutex);
+        // don't allow enqueueing after stopping the pool
+        if(stop)
+            throw std::runtime_error("enqueue on stopped ThreadPool");
+        tasks.emplace([task](){ (*task)(); });
+    }
+    condition.notify_one();
+    return res;
+}
+// the destructor joins all threads
+inline ThreadPool::~ThreadPool()
+{
+    {
+        std::unique_lock<std::mutex> lock(queue_mutex);
+        stop = true;
+    }
+    condition.notify_all();
+    for(std::thread &worker: workers)
+        worker.join();
+}
+#endif
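
The replacement is the familiar futures-based pool: `enqueue` wraps the callable in a `std::packaged_task`, pushes it onto `tasks`, and hands back a `std::future`; the destructor flips `stop`, wakes every worker, and joins them. A minimal standalone usage sketch (the include path is assumed):

    #include <future>
    #include <iostream>
    #include <vector>
    // #include "thread_pool.h"   // assumed location of the header above

    int main() {
      ThreadPool pool(4);                      // spawns 4 worker threads
      std::vector<std::future<int>> results;
      for(int i = 0; i < 8; ++i)               // queue 8 independent tasks
        results.emplace_back(pool.enqueue([i]{ return i * i; }));
      for(auto& r : results)                   // get() blocks until each task is done
        std::cout << r.get() << ' ';
      std::cout << std::endl;
    }                                          // ~ThreadPool() joins all workers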

View File

@@ -223,7 +223,7 @@ void tune::run(ir::module &mod) {
   }
   else {
     ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 2, 2);
-    ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 1, 4);
+    ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 2, 4);
     connected_components(node, {fpw, wpt}, {"fpw", "wpt"}, nodes_, dependencies_, group_id++);
   }
 }

View File

@@ -37,24 +37,18 @@ void loop_nest(std::vector<size_t> const & ranges,
   size_t D = ranges.size();
   std::vector<size_t> values(D, 0);
   // thread pools
-//  nbsdx::concurrent::thread_pool pool(nthreads);
+  ThreadPool pool(nthreads);
   // Start with innermost loop
   size_t i = D - 1;
-//  size_t current = 0;
   while(true){
     //Execute function
-//    pool.add_job([values, &f](){ f(values); });
-    f(values);
-    //Increment counters
+    pool.enqueue([values, &f](){ f(values); });
     while(values[i]++ == ranges[i] - 1){
       if(i == 0)
         return;
       values[i--] = 0;
     }
-//    if(current++ >= 1024){
-//      current = 0;
-//      pool.join_all();
-//    }
     i = D - 1;
   }
 }
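
`loop_nest` now enqueues one `f(values)` per point of the Cartesian product of `ranges` (each task copies `values`, since the counters keep mutating). Because the `ThreadPool` is a local, its destructor joins all workers before `loop_nest` returns, so the call still acts as a barrier. A hypothetical driver:

    #include <cstdio>
    #include <vector>

    // Sketch only: enumerate the 2x3 grid {0,1} x {0,1,2} on 4 threads.
    // Assumes the (ranges, f, nthreads) parameter order visible at the
    // call site in jit::autotune below.
    void example() {
      std::vector<size_t> ranges = {2, 3};
      loop_nest(ranges, [](const std::vector<size_t>& v){
        std::printf("(%zu, %zu)\n", v[0], v[1]);   // one task per grid point
      }, 4);
      // reached only after every enqueued f(values) has completed
    }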
@@ -111,8 +105,9 @@ std::unique_ptr<ir::module> jit::make_triton_module(const char * name, triton::i
 }
-jit::jit(driver::context *context): driver_context_(context),
-                                    target_(context->device()->make_target()) { }
+jit::jit(driver::context *context, unsigned nthreads): driver_context_(context),
+                                                       target_(context->device()->make_target()),
+                                                       nthreads_(nthreads) { }
 jit::~jit(){ }
@@ -173,7 +168,6 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
   ranges.push_back(mp->get_space());
   // iterate over parameters
   tune_res_t best;
-  size_t nthreads = 4;
   std::mutex mutex;
   loop_nest<unsigned>(ranges, [&](const std::vector<unsigned> params){
     std::map<ir::value*, std::vector<std::string>> errors;
@@ -220,11 +214,12 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
       best.perf = perf;
       best.params = params;
     }
-    for(unsigned p: params)
-      std::cout << p << " " << std::flush;
-    std::cout << perf << " [ " << best.perf << " ] " << std::endl;
+//    for(unsigned p: params)
+//      std::cout << p << " " << std::flush;
+//    std::cout << perf << " [ " << best.perf << " ] " << std::endl;
   }
-  }, nthreads);
+  }, nthreads_);
+  std::cout << "Autotuning done - Best performance: " << best.perf << std::endl;
   return best;
 }
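
With the per-configuration printout commented away, the only output left is the final summary line, and concurrent updates of `best` from the pool's workers are presumably serialized by the `mutex` declared earlier in `autotune`. A minimal sketch of that reduction pattern, standard library only (all names here are illustrative, not Triton's):

    #include <mutex>
    #include <vector>

    // Each benchmark task races to publish its result; the lock plays the
    // same role as `mutex` in jit::autotune, and nothing is printed until
    // the caller reports the winner once at the end.
    struct tune_result { double perf = 0; std::vector<unsigned> params; };

    void update_best(tune_result& best, std::mutex& mtx,
                     double perf, const std::vector<unsigned>& params) {
      std::lock_guard<std::mutex> lock(mtx);
      if(perf > best.perf){
        best.perf = perf;
        best.params = params;
      }
    }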