[runtime/jit] made auto-tuning silent
@@ -13,7 +13,7 @@ int main() {
   // initialize default compute device
   auto context = triton::driver::backend::contexts::get_default();
   // matrix multiplication parameters
-  int32_t M = 131072, N = 128, K = 128;
+  int32_t M = 32768, N = 128, K = 128;
   std::vector<float> hc(M*N);
   std::vector<float> rc(M*N);
   std::vector<float> ha(M*K);
@@ -33,8 +33,8 @@ int main() {
   stream->write(db, true, 0, hb);
   stream->write(dc, true, 0, hc);
   stream->synchronize();
-  triton::dnn::gemm gemm(M, N, K, AT, BT, "fp32", "fp32", 4, 4);
-  gemm.enqueue(stream, {da, db, dc});
+  triton::dnn::gemm gemm(M, N, K, AT, BT, "fp16", "fp16", 4, 4);
+  gemm.enqueue(stream, {da, db, dc}, true);
   stream->read(dc, true, 0, hc);
   gemm.cpu_ref<float>(rc, ha, hb);
   for(size_t i = 0; i < M*N; i++)
@@ -75,7 +75,7 @@ torch::Tensor shift_common(
 
   triton::driver::cu_buffer c(ctx, (CUdeviceptr)torchc.storage().data(), false);
   // Enqueue
-  shift.enqueue(&stream, {&a, &b, &c});
+  shift.enqueue(&stream, {&a, &b, &c}, true);
   return torchc;
 }
 
@@ -105,7 +105,7 @@ private:
   triton::lang::translation_unit *parse_program(const char *name, const char *src);
 
 public:
-  jit(driver::context* context);
+  jit(driver::context* context, unsigned nthreads = 4);
   ~jit();
   std::vector<unsigned> get_valid(const char *name, const char *src);
   tune_res_t autotune(const char* name, const char* src, benchmark_t benchmark);
@@ -121,6 +121,7 @@ private:
   ir::context triton_context_;
   std::map<std::string, launch_information> launch_info_map_;
   std::shared_ptr<triton::codegen::target> target_;
+  unsigned nthreads_;
 };
 
 }
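The header change above widens the jit constructor to accept a thread count (stored in nthreads_, defaulted to 4), which the auto-tuner later hands to its thread pool. A minimal usage sketch, assuming the same driver-context setup as in the example at the top of this commit; the call itself is illustrative, not part of the diff:

    auto context = triton::driver::backend::contexts::get_default();
    triton::jit jit(context, 8);  // let the auto-tuner compile candidates on 8 threads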
@@ -1,161 +1,98 @@
-#ifndef CONCURRENT_THREADPOOL_H
-#define CONCURRENT_THREADPOOL_H
+#ifndef THREAD_POOL_H
+#define THREAD_POOL_H
 
-#include <atomic>
+#include <vector>
+#include <queue>
+#include <memory>
 #include <thread>
 #include <mutex>
-#include <vector>
-#include <list>
-#include <functional>
 #include <condition_variable>
+#include <future>
 #include <functional>
+#include <stdexcept>
 
-namespace nbsdx {
-namespace concurrent {
-
-/**
- * Simple ThreadPool that creates `ThreadCount` threads upon its creation,
- * and pulls from a queue to get new jobs. The default is 10 threads.
- *
- * This class requires a number of c++11 features be present in your compiler.
- */
-class thread_pool {
-
-  std::vector<std::thread> threads_;
-  std::list<std::function<void(void)>> queue_;
-
-  std::atomic_int jobs_left_;
-  std::atomic_bool bailout_;
-  std::atomic_bool finished_;
-  std::condition_variable job_available_var_;
-  std::condition_variable wait_var_;
-  std::mutex wait_mutex_;
-  std::mutex queue_mutex_;
-  unsigned thread_count_;
-
-  /**
-   * Take the next job in the queue and run it.
-   * Notify the main thread that a job has completed.
-   */
-  void task() {
-    while( !bailout_ ) {
-      next_job()();
-      --jobs_left_;
-      wait_var_.notify_one();
-    }
-  }
-
-  /**
-   * Get the next job; pop the first item in the queue,
-   * otherwise wait for a signal from the main thread.
-   */
-  std::function<void(void)> next_job() {
-    std::function<void(void)> res;
-    std::unique_lock<std::mutex> job_lock( queue_mutex_ );
-
-    // Wait for a job if we don't have any.
-    job_available_var_.wait( job_lock, [this]() ->bool { return queue_.size() || bailout_; } );
-
-    // Get job from the queue
-    if( !bailout_ ) {
-      res = queue_.front();
-      queue_.pop_front();
-    }
-    else { // If we're bailing out, 'inject' a job into the queue to keep jobs_left accurate.
-      res = []{};
-      ++jobs_left_;
-    }
-    return res;
-  }
-
+class ThreadPool {
 public:
-  thread_pool(unsigned thread_count = 4)
-    : jobs_left_( 0 )
-    , bailout_( false )
-    , finished_( false )
-    , thread_count_(thread_count)
-  {
-    threads_.resize(thread_count_);
-    for( unsigned i = 0; i < thread_count_; ++i )
-      threads_[ i ] = std::thread( [this]{ this->task(); } );
-  }
+  ThreadPool(size_t);
+  template<class F, class... Args>
+  auto enqueue(F&& f, Args&&... args)
+    -> std::future<typename std::result_of<F(Args...)>::type>;
+  ~ThreadPool();
+private:
+  // need to keep track of threads so we can join them
+  std::vector< std::thread > workers;
+  // the task queue
+  std::queue< std::function<void()> > tasks;
 
-  /**
-   * JoinAll on deconstruction
-   */
-  ~thread_pool() {
-    join_all();
-  }
-
-  /**
-   * Get the number of threads in this pool
-   */
-  inline unsigned size() const {
-    return thread_count_;
-  }
-
-  /**
-   * Get the number of jobs left in the queue.
-   */
-  inline unsigned jobs_remaining() {
-    std::lock_guard<std::mutex> guard( queue_mutex_ );
-    return queue_.size();
-  }
-
-  /**
-   * Add a new job to the pool. If there are no jobs in the queue,
-   * a thread is woken up to take the job. If all threads are busy,
-   * the job is added to the end of the queue.
-   */
-  void add_job( std::function<void(void)> job ) {
-    std::lock_guard<std::mutex> guard( queue_mutex_ );
-    queue_.emplace_back( job );
-    ++jobs_left_;
-    job_available_var_.notify_one();
-  }
-
-  /**
-   * Join with all threads. Block until all threads have completed.
-   * Params: WaitForAll: If true, will wait for the queue to empty
-   *         before joining with threads. If false, will complete
-   *         current jobs, then inform the threads to exit.
-   * The queue will be empty after this call, and the threads will
-   * be done. After invoking `ThreadPool::JoinAll`, the pool can no
-   * longer be used. If you need the pool to exist past completion
-   * of jobs, look to use `ThreadPool::WaitAll`.
-   */
-  void join_all( bool WaitForAll = true ) {
-    if( !finished_ ) {
-      if( WaitForAll ) {
-        wait_all();
-      }
-
-      // note that we're done, and wake up any thread that's
-      // waiting for a new job
-      bailout_ = true;
-      job_available_var_.notify_all();
-
-      for( auto &x : threads_ )
-        if( x.joinable() )
-          x.join();
-      finished_ = true;
-    }
-  }
-
-  /**
-   * Wait for the pool to empty before continuing.
-   * This does not call `std::thread::join`, it only waits until
-   * all jobs have finshed executing.
-   */
-  void wait_all() {
-    if( jobs_left_ > 0 ) {
-      std::unique_lock<std::mutex> lk( wait_mutex_ );
-      wait_var_.wait( lk, [this]{ return this->jobs_left_ == 0; } );
-      lk.unlock();
-    }
-  }
+  // synchronization
+  std::mutex queue_mutex;
+  std::condition_variable condition;
+  bool stop;
 };
 
-} // namespace concurrent
-} // namespace nbsdx
+// the constructor just launches some amount of workers
+inline ThreadPool::ThreadPool(size_t threads)
+  : stop(false)
+{
+  for(size_t i = 0;i<threads;++i)
+    workers.emplace_back(
+      [this]
+      {
+        for(;;)
+        {
+          std::function<void()> task;
 
-#endif //CONCURRENT_THREADPOOL_H
+          {
+            std::unique_lock<std::mutex> lock(this->queue_mutex);
+            this->condition.wait(lock,
+              [this]{ return this->stop || !this->tasks.empty(); });
+            if(this->stop && this->tasks.empty())
+              return;
+            task = std::move(this->tasks.front());
+            this->tasks.pop();
+          }
+
+          task();
+        }
+      }
+    );
+}
+
+// add new work item to the pool
+template<class F, class... Args>
+auto ThreadPool::enqueue(F&& f, Args&&... args)
+  -> std::future<typename std::result_of<F(Args...)>::type>
+{
+  using return_type = typename std::result_of<F(Args...)>::type;
+
+  auto task = std::make_shared< std::packaged_task<return_type()> >(
+    std::bind(std::forward<F>(f), std::forward<Args>(args)...)
+  );
+
+  std::future<return_type> res = task->get_future();
+  {
+    std::unique_lock<std::mutex> lock(queue_mutex);
+
+    // don't allow enqueueing after stopping the pool
+    if(stop)
+      throw std::runtime_error("enqueue on stopped ThreadPool");
+
+    tasks.emplace([task](){ (*task)(); });
+  }
+  condition.notify_one();
+  return res;
+}
+
+// the destructor joins all threads
+inline ThreadPool::~ThreadPool()
+{
+  {
+    std::unique_lock<std::mutex> lock(queue_mutex);
+    stop = true;
+  }
+  condition.notify_all();
+  for(std::thread &worker: workers)
+    worker.join();
+}
+
+#endif
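The replacement above is the future-based pool: enqueue() wraps the callable in a std::packaged_task, pushes it on the queue, and hands back a std::future, so callers can wait on individual results rather than on a global join_all(). A self-contained usage sketch of the class as defined above (illustrative, not part of the diff):

    #include <iostream>
    int main() {
      ThreadPool pool(4);
      // enqueue returns a std::future<int> for the job's result
      auto result = pool.enqueue([](int x) { return x * x; }, 7);
      std::cout << result.get() << std::endl;  // blocks until a worker runs the job; prints 49
    }  // ~ThreadPool() sets stop, wakes every worker, and joins them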
@@ -223,7 +223,7 @@ void tune::run(ir::module &mod) {
   }
   else {
     ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 2, 2);
-    ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 1, 4);
+    ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 2, 4);
     connected_components(node, {fpw, wpt}, {"fpw", "wpt"}, nodes_, dependencies_, group_id++);
   }
 }
@@ -37,24 +37,18 @@ void loop_nest(std::vector<size_t> const & ranges,
   size_t D = ranges.size();
   std::vector<size_t> values(D, 0);
   // thread pools
-  // nbsdx::concurrent::thread_pool pool(nthreads);
+  ThreadPool pool(nthreads);
   // Start with innermost loop
   size_t i = D - 1;
-  // size_t current = 0;
   while(true){
     //Execute function
-    // pool.add_job([values, &f](){ f(values); });
-    f(values);
+    pool.enqueue([values, &f](){ f(values); });
     //Increment counters
     while(values[i]++ == ranges[i] - 1){
       if(i == 0)
         return;
       values[i--] = 0;
     }
-    // if(current++ >= 1024){
-    //   current = 0;
-    //   pool.join_all();
-    // }
     i = D - 1;
   }
 }
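loop_nest walks the Cartesian product of ranges like an odometer: the innermost counter advances first, and each overflow carries into the next-outer dimension until dimension 0 overflows; with the change above, every visited point is dispatched to the pool instead of being executed inline. A serial sketch of the same traversal (illustrative, not part of the diff):

    #include <cstdio>
    #include <vector>

    void visit_all(const std::vector<size_t>& ranges) {
      size_t D = ranges.size();
      std::vector<size_t> values(D, 0);
      size_t i = D - 1;
      while(true) {
        for(size_t v: values) std::printf("%zu ", v);  // visit the current point
        std::printf("\n");
        while(values[i]++ == ranges[i] - 1) {  // carry: reset dimension i, move outward
          if(i == 0) return;                   // outermost dimension overflowed: done
          values[i--] = 0;
        }
        i = D - 1;                             // resume from the innermost dimension
      }
    }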
@@ -111,8 +105,9 @@ std::unique_ptr<ir::module> jit::make_triton_module(const char * name, triton::i
 }
 
 
-jit::jit(driver::context *context): driver_context_(context),
-                                    target_(context->device()->make_target()) { }
+jit::jit(driver::context *context, unsigned nthreads): driver_context_(context),
+                                                       target_(context->device()->make_target()),
+                                                       nthreads_(nthreads) { }
 
 jit::~jit(){ }
 
@@ -173,7 +168,6 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
     ranges.push_back(mp->get_space());
   // iterate over parameters
   tune_res_t best;
-  size_t nthreads = 4;
   std::mutex mutex;
   loop_nest<unsigned>(ranges, [&](const std::vector<unsigned> params){
     std::map<ir::value*, std::vector<std::string>> errors;
@@ -220,11 +214,12 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
       best.perf = perf;
       best.params = params;
     }
-    for(unsigned p: params)
-      std::cout << p << " " << std::flush;
-    std::cout << perf << " [ " << best.perf << " ] " << std::endl;
+    // for(unsigned p: params)
+    //   std::cout << p << " " << std::flush;
+    // std::cout << perf << " [ " << best.perf << " ] " << std::endl;
   }
-  }, nthreads);
+  }, nthreads_);
+  std::cout << "Autotuning done - Best performance: " << best.perf << std::endl;
   return best;
 }
 