[runtime] put jit::launch_info in another file
This commit is contained in:
@@ -14,13 +14,13 @@ int main() {
|
||||
|
||||
// initialize default compute device
|
||||
auto context = triton::driver::backend::contexts::get_default();
|
||||
auto op = triton::dnn::shift::FPROP;
|
||||
auto op = triton::dnn::shift::WGRAD;
|
||||
|
||||
// initialization
|
||||
int32_t R = 3, S = 3;
|
||||
int32_t B = 16, F = 512;
|
||||
int32_t H = 16, W = 16;
|
||||
int32_t C = 512;
|
||||
int32_t B = 32, F = 128;
|
||||
int32_t H = 28, W = 28;
|
||||
int32_t C = 128;
|
||||
|
||||
// random shifts
|
||||
std::vector<int32_t> shift_h(C);
|
||||
|
@@ -25,6 +25,7 @@
|
||||
|
||||
#include "triton/driver/stream.h"
|
||||
#include "triton/driver/kernel.h"
|
||||
#include "triton/runtime/launch_info.h"
|
||||
|
||||
namespace triton{
|
||||
namespace dnn{
|
||||
@@ -45,8 +46,7 @@ private:
|
||||
// enqueue
|
||||
virtual void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
std::vector<driver::buffer*> args,
|
||||
const std::vector<unsigned>& ranges,
|
||||
size_t nthreads) = 0;
|
||||
triton::runtime::launch_information info) = 0;
|
||||
// number of flops
|
||||
virtual size_t num_flops() const = 0;
|
||||
// comparison for maps
|
||||
|
@@ -40,7 +40,7 @@ private:
|
||||
// enqueue
|
||||
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
std::vector<driver::buffer*> args,
|
||||
const std::vector<unsigned> &ranges, size_t nthreads);
|
||||
triton::runtime::launch_information info);
|
||||
// number of flops
|
||||
size_t num_flops() const;
|
||||
// comparison for maps
|
||||
@@ -72,7 +72,7 @@ private:
|
||||
// enqueue
|
||||
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
std::vector<driver::buffer*> args,
|
||||
const std::vector<unsigned> &ranges, size_t nthreads);
|
||||
runtime::launch_information info);
|
||||
// number of flops
|
||||
size_t num_flops() const;
|
||||
// comparison for maps
|
||||
|
@@ -33,8 +33,7 @@ private:
|
||||
driver::buffer *bias);
|
||||
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
std::vector<driver::buffer*> args,
|
||||
const std::vector<unsigned>& ranges,
|
||||
size_t nthreads);
|
||||
triton::runtime::launch_information info);
|
||||
// number of flops
|
||||
size_t num_flops() const;
|
||||
// comparison for maps
|
||||
|
@@ -13,8 +13,7 @@ private:
|
||||
// enqueue
|
||||
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
std::vector<driver::buffer*> args,
|
||||
const std::vector<unsigned>& ranges,
|
||||
size_t nthreads);
|
||||
triton::runtime::launch_information info);
|
||||
// number of flops
|
||||
size_t num_flops() const;
|
||||
// comparison for maps
|
||||
|
@@ -54,8 +54,7 @@ private:
|
||||
void init_impl(driver::stream *stream, driver::cu_module *module);
|
||||
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
std::vector<driver::buffer*> args,
|
||||
const std::vector<unsigned>& ranges,
|
||||
size_t nthreads);
|
||||
triton::runtime::launch_information info);
|
||||
|
||||
public:
|
||||
|
||||
|
@@ -20,6 +20,7 @@
|
||||
#include "triton/codegen/alignment_info.h"
|
||||
#include "triton/codegen/target.h"
|
||||
#include "triton/codegen/vectorize.h"
|
||||
#include "triton/runtime/launch_info.h"
|
||||
#include <functional>
|
||||
|
||||
namespace llvm {
|
||||
@@ -42,12 +43,10 @@ class context;
|
||||
class metaparameter;
|
||||
}
|
||||
|
||||
namespace runtime{
|
||||
|
||||
class jit {
|
||||
public:
|
||||
struct launch_information{
|
||||
std::vector<unsigned> global_range_size;
|
||||
unsigned num_threads;
|
||||
};
|
||||
typedef std::function<double(driver::kernel*, launch_information)> benchmark_t;
|
||||
|
||||
struct tune_res_t{
|
||||
@@ -114,7 +113,6 @@ public:
|
||||
void add_module(const char* name, const char* src, const std::vector<unsigned>& params = {});
|
||||
driver::kernel* get_function(const char* name);
|
||||
launch_information get_launch_info(const char* name);
|
||||
unsigned get_int(const char* name);
|
||||
|
||||
private:
|
||||
std::map<std::string, driver::module*> modules_;
|
||||
@@ -122,11 +120,10 @@ private:
|
||||
llvm::LLVMContext llvm_context_;
|
||||
ir::context triton_context_;
|
||||
std::map<std::string, launch_information> launch_info_map_;
|
||||
std::map<std::string, unsigned> global_ints_;
|
||||
std::shared_ptr<triton::codegen::target> target_;
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
19
include/triton/runtime/launch_info.h
Normal file
19
include/triton/runtime/launch_info.h
Normal file
@@ -0,0 +1,19 @@
|
||||
#ifndef TRITON_INCLUDE_RUNTIME_LAUNCH_INFO_H
|
||||
#define TRITON_INCLUDE_RUNTIME_LAUNCH_INFO_H
|
||||
|
||||
#include <vector>
|
||||
#include <map>
|
||||
|
||||
namespace triton{
|
||||
namespace runtime{
|
||||
|
||||
struct launch_information{
|
||||
std::vector<unsigned> global_range_size;
|
||||
unsigned num_threads;
|
||||
std::map<std::string, unsigned> globals;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
161
include/triton/tools/thread_pool.h
Normal file
161
include/triton/tools/thread_pool.h
Normal file
@@ -0,0 +1,161 @@
|
||||
#ifndef CONCURRENT_THREADPOOL_H
|
||||
#define CONCURRENT_THREADPOOL_H
|
||||
|
||||
#include <atomic>
|
||||
#include <thread>
|
||||
#include <mutex>
|
||||
#include <vector>
|
||||
#include <list>
|
||||
#include <functional>
|
||||
#include <condition_variable>
|
||||
|
||||
namespace nbsdx {
|
||||
namespace concurrent {
|
||||
|
||||
/**
|
||||
* Simple ThreadPool that creates `ThreadCount` threads upon its creation,
|
||||
* and pulls from a queue to get new jobs. The default is 10 threads.
|
||||
*
|
||||
* This class requires a number of c++11 features be present in your compiler.
|
||||
*/
|
||||
class thread_pool {
|
||||
|
||||
std::vector<std::thread> threads_;
|
||||
std::list<std::function<void(void)>> queue_;
|
||||
|
||||
std::atomic_int jobs_left_;
|
||||
std::atomic_bool bailout_;
|
||||
std::atomic_bool finished_;
|
||||
std::condition_variable job_available_var_;
|
||||
std::condition_variable wait_var_;
|
||||
std::mutex wait_mutex_;
|
||||
std::mutex queue_mutex_;
|
||||
unsigned thread_count_;
|
||||
|
||||
/**
|
||||
* Take the next job in the queue and run it.
|
||||
* Notify the main thread that a job has completed.
|
||||
*/
|
||||
void task() {
|
||||
while( !bailout_ ) {
|
||||
next_job()();
|
||||
--jobs_left_;
|
||||
wait_var_.notify_one();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the next job; pop the first item in the queue,
|
||||
* otherwise wait for a signal from the main thread.
|
||||
*/
|
||||
std::function<void(void)> next_job() {
|
||||
std::function<void(void)> res;
|
||||
std::unique_lock<std::mutex> job_lock( queue_mutex_ );
|
||||
|
||||
// Wait for a job if we don't have any.
|
||||
job_available_var_.wait( job_lock, [this]() ->bool { return queue_.size() || bailout_; } );
|
||||
|
||||
// Get job from the queue
|
||||
if( !bailout_ ) {
|
||||
res = queue_.front();
|
||||
queue_.pop_front();
|
||||
}
|
||||
else { // If we're bailing out, 'inject' a job into the queue to keep jobs_left accurate.
|
||||
res = []{};
|
||||
++jobs_left_;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
public:
|
||||
thread_pool(unsigned thread_count = 4)
|
||||
: jobs_left_( 0 )
|
||||
, bailout_( false )
|
||||
, finished_( false )
|
||||
, thread_count_(thread_count)
|
||||
{
|
||||
threads_.resize(thread_count_);
|
||||
for( unsigned i = 0; i < thread_count_; ++i )
|
||||
threads_[ i ] = std::thread( [this]{ this->task(); } );
|
||||
}
|
||||
|
||||
/**
|
||||
* JoinAll on deconstruction
|
||||
*/
|
||||
~thread_pool() {
|
||||
join_all();
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the number of threads in this pool
|
||||
*/
|
||||
inline unsigned size() const {
|
||||
return thread_count_;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the number of jobs left in the queue.
|
||||
*/
|
||||
inline unsigned jobs_remaining() {
|
||||
std::lock_guard<std::mutex> guard( queue_mutex_ );
|
||||
return queue_.size();
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a new job to the pool. If there are no jobs in the queue,
|
||||
* a thread is woken up to take the job. If all threads are busy,
|
||||
* the job is added to the end of the queue.
|
||||
*/
|
||||
void add_job( std::function<void(void)> job ) {
|
||||
std::lock_guard<std::mutex> guard( queue_mutex_ );
|
||||
queue_.emplace_back( job );
|
||||
++jobs_left_;
|
||||
job_available_var_.notify_one();
|
||||
}
|
||||
|
||||
/**
|
||||
* Join with all threads. Block until all threads have completed.
|
||||
* Params: WaitForAll: If true, will wait for the queue to empty
|
||||
* before joining with threads. If false, will complete
|
||||
* current jobs, then inform the threads to exit.
|
||||
* The queue will be empty after this call, and the threads will
|
||||
* be done. After invoking `ThreadPool::JoinAll`, the pool can no
|
||||
* longer be used. If you need the pool to exist past completion
|
||||
* of jobs, look to use `ThreadPool::WaitAll`.
|
||||
*/
|
||||
void join_all( bool WaitForAll = true ) {
|
||||
if( !finished_ ) {
|
||||
if( WaitForAll ) {
|
||||
wait_all();
|
||||
}
|
||||
|
||||
// note that we're done, and wake up any thread that's
|
||||
// waiting for a new job
|
||||
bailout_ = true;
|
||||
job_available_var_.notify_all();
|
||||
|
||||
for( auto &x : threads_ )
|
||||
if( x.joinable() )
|
||||
x.join();
|
||||
finished_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Wait for the pool to empty before continuing.
|
||||
* This does not call `std::thread::join`, it only waits until
|
||||
* all jobs have finshed executing.
|
||||
*/
|
||||
void wait_all() {
|
||||
if( jobs_left_ > 0 ) {
|
||||
std::unique_lock<std::mutex> lk( wait_mutex_ );
|
||||
wait_var_.wait( lk, [this]{ return this->jobs_left_ == 0; } );
|
||||
lk.unlock();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace concurrent
|
||||
} // namespace nbsdx
|
||||
|
||||
#endif //CONCURRENT_THREADPOOL_H
|
@@ -23,29 +23,30 @@ base::base(const std::string& name)
|
||||
: name_(name) { }
|
||||
|
||||
void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, bool autotune) {
|
||||
static std::map<base*, std::unique_ptr<triton::jit>, cmp_recompile> m_jit;
|
||||
namespace rt = triton::runtime;
|
||||
static std::map<base*, std::unique_ptr<rt::jit>, cmp_recompile> m_jit;
|
||||
driver::context* ctx = stream->context();
|
||||
triton::jit* jit;
|
||||
rt::jit* jit;
|
||||
/* the current template has not already been compiled */
|
||||
if(m_jit.find(this) == m_jit.end()) {
|
||||
jit = m_jit.emplace(this->clone(), new triton::jit(ctx)).first->second.get();
|
||||
jit = m_jit.emplace(this->clone(), new rt::jit(ctx)).first->second.get();
|
||||
std::ostringstream oss;
|
||||
triton_c_src(oss);
|
||||
std::string src = oss.str();
|
||||
auto benchmark = [&](triton::driver::kernel* kernel,
|
||||
triton::jit::launch_information info) {
|
||||
rt::launch_information info) {
|
||||
// launch info
|
||||
unsigned nthreads = info.num_threads;
|
||||
init_impl(stream, (triton::driver::cu_module*)kernel->module());
|
||||
enqueue_impl(stream, kernel, args, info.global_range_size, nthreads);
|
||||
enqueue_impl(stream, kernel, args, info);
|
||||
stream->synchronize();
|
||||
double ts = triton::tools::bench([&](){ enqueue_impl(stream, kernel, args, info.global_range_size, nthreads); },
|
||||
double ts = triton::tools::bench([&](){ enqueue_impl(stream, kernel, args, info); },
|
||||
[&](){ stream->synchronize(); }, ctx->device());
|
||||
return num_flops() / ts * 1e-3;
|
||||
};
|
||||
// auto-tune and save result
|
||||
if(autotune) {
|
||||
triton::jit::tune_res_t best = jit->autotune(name_.c_str(), src.c_str(), benchmark);
|
||||
rt::jit::tune_res_t best = jit->autotune(name_.c_str(), src.c_str(), benchmark);
|
||||
jit->add_module(name_.c_str(), src.c_str(), best.params);
|
||||
}
|
||||
else {
|
||||
@@ -60,10 +61,9 @@ void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, b
|
||||
|
||||
/* get launch parameters */
|
||||
driver::kernel* kernel = jit->get_function(name_.c_str());
|
||||
triton::jit::launch_information info = jit->get_launch_info(name_.c_str());
|
||||
rt::launch_information info = jit->get_launch_info(name_.c_str());
|
||||
/* launch */
|
||||
enqueue_impl(stream, kernel, args,
|
||||
info.global_range_size, info.num_threads);
|
||||
enqueue_impl(stream, kernel, args, info);
|
||||
}
|
||||
|
||||
}
|
||||
|
@@ -54,8 +54,7 @@ base* batchnorm_forward::clone() const {
|
||||
|
||||
void batchnorm_forward::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
std::vector<driver::buffer*> args,
|
||||
const std::vector<unsigned>&,
|
||||
size_t nthreads)
|
||||
runtime::launch_information info)
|
||||
{
|
||||
driver::buffer *y = args[0], *m = args[1], *v = args[2];
|
||||
driver::buffer *x = args[3], *g = args[4], *b = args[5];
|
||||
@@ -69,7 +68,7 @@ void batchnorm_forward::enqueue_impl(driver::stream *stream, driver::kernel *ker
|
||||
kernel->setArg(6, DHWB_);
|
||||
kernel->setArg(7, rcpDHWB_);
|
||||
kernel->setArg(8, eps_);
|
||||
stream->enqueue(kernel, grid, {nthreads, 1, 1});
|
||||
stream->enqueue(kernel, grid, {info.num_threads, 1, 1});
|
||||
}
|
||||
|
||||
void batchnorm_forward::triton_c_src(std::ostream &os) const {
|
||||
@@ -154,7 +153,7 @@ base* batchnorm_backward::clone() const {
|
||||
|
||||
void batchnorm_backward::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
std::vector<driver::buffer *> args,
|
||||
const std::vector<unsigned> &, size_t nthreads) {
|
||||
runtime::launch_information info) {
|
||||
driver::buffer *dx = args[0], *dg = args[1], *db = args[2], *dy = args[3];
|
||||
driver::buffer *x = args[4], *g = args[5], *m = args[6], *v = args[7];
|
||||
std::array<size_t, 3> grid = {1, (size_t)C_, 1};
|
||||
@@ -169,7 +168,7 @@ void batchnorm_backward::enqueue_impl(driver::stream *stream, driver::kernel *ke
|
||||
kernel->setArg(8, (int32_t)(D_*H_*W_*B_));
|
||||
kernel->setArg(9, (float)1/(D_*H_*W_*B_));
|
||||
kernel->setArg(10, eps_);
|
||||
stream->enqueue(kernel, grid, {nthreads, 1, 1});
|
||||
stream->enqueue(kernel, grid, {info.num_threads, 1, 1});
|
||||
}
|
||||
|
||||
void batchnorm_backward::triton_c_src(std::ostream &os) const {
|
||||
|
@@ -365,10 +365,9 @@ void conv::set_arg(driver::kernel *kernel,
|
||||
|
||||
void conv::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
std::vector<driver::buffer*> args,
|
||||
const std::vector<unsigned>& ranges,
|
||||
size_t nthreads) {
|
||||
runtime::launch_information info) {
|
||||
driver::buffer *a = args[0], *b = args[1], *c = args[2], *bias = args[3];
|
||||
unsigned TM = ranges[0], TN = ranges[1];
|
||||
unsigned TM = info.global_range_size[0], TN = info.global_range_size[1];
|
||||
unsigned GZ = 1;
|
||||
set_arg(kernel, a, b, c, bias);
|
||||
std::array<size_t, 3> grid = {1};
|
||||
@@ -411,7 +410,7 @@ void conv::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
kernel->setArg(38, (pad_w_ + (1 - upsample_w_)*off_uw)/upsample_w_);
|
||||
kernel->setArg(39, (off_uh + pad_h_) % upsample_h_);
|
||||
kernel->setArg(40, (off_uw + pad_w_) % upsample_w_);
|
||||
stream->enqueue(kernel, grid, {nthreads, 1, 1});
|
||||
stream->enqueue(kernel, grid, {info.num_threads, 1, 1});
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -47,11 +47,10 @@ void gemm::init_impl(driver::stream* stream, driver::cu_module *) {
|
||||
|
||||
void gemm::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
std::vector<driver::buffer*> args,
|
||||
const std::vector<unsigned>& ranges,
|
||||
size_t nthreads) {
|
||||
runtime::launch_information info) {
|
||||
driver::buffer *a = args[0], *b = args[1], *c = args[2];
|
||||
unsigned TM = ranges[0];
|
||||
unsigned TN = ranges[1];
|
||||
unsigned TM = info.global_range_size[0];
|
||||
unsigned TN = info.global_range_size[1];
|
||||
unsigned grid_0 = (M_ + TM - 1)/TM;
|
||||
unsigned grid_1 = (N_ + TN - 1)/TN;
|
||||
unsigned grid_2 = 1;
|
||||
@@ -68,7 +67,7 @@ void gemm::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
kernel->setArg(9, locks_);
|
||||
kernel->setArg(10, grid_0);
|
||||
kernel->setArg(11, grid_1);
|
||||
stream->enqueue(kernel, grid, {nthreads, 1, 1});
|
||||
stream->enqueue(kernel, grid, {info.num_threads, 1, 1});
|
||||
}
|
||||
|
||||
std::vector<unsigned> gemm::default_params() {
|
||||
|
@@ -199,7 +199,7 @@ void shift::init_impl(driver::stream *stream, driver::cu_module *module) {
|
||||
|
||||
void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
std::vector<driver::buffer *> args,
|
||||
const std::vector<unsigned> &ranges, size_t nthreads) {
|
||||
runtime::launch_information info) {
|
||||
driver::buffer *a = args[0], *b = args[1], *c = args[2];
|
||||
kernel->setArg(0, a);
|
||||
kernel->setArg(1, b);
|
||||
@@ -228,13 +228,13 @@ void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
kernel->setArg(24, BW_);
|
||||
kernel->setArg(25, CH_);
|
||||
kernel->setArg(26, CW_);
|
||||
unsigned TM = ranges[0], TN = ranges[1];
|
||||
unsigned TM = info.global_range_size[0], TN = info.global_range_size[1];
|
||||
std::array<size_t, 3> grid = {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1};
|
||||
if(op_ == BPROP){
|
||||
size_t c_nbytes = (c_ty_ == "fp16") ? 2 : 4;
|
||||
((driver::cu_buffer*)c)->set_zero(stream, AH_*AW_*B_*C_*c_nbytes);
|
||||
}
|
||||
stream->enqueue(kernel, grid, {nthreads, 1, 1});
|
||||
stream->enqueue(kernel, grid, {info.num_threads, 1, 1});
|
||||
}
|
||||
|
||||
void shift::triton_c_src(std::ostream &os) const {
|
||||
|
@@ -29,6 +29,7 @@ extern void yy_delete_buffer(YY_BUFFER_STATE buffer);
|
||||
extern triton::lang::translation_unit *ast_root;
|
||||
|
||||
namespace triton {
|
||||
namespace runtime{
|
||||
|
||||
void loop_nest(std::vector<size_t> const & ranges,
|
||||
std::function<void(std::vector<size_t> const &)> const & f,
|
||||
@@ -80,6 +81,10 @@ std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, passes_w
|
||||
info.global_range_size.clear();
|
||||
for(unsigned i = 0; i < passes.tune.get_num_global_range(); i++)
|
||||
info.global_range_size.push_back(passes.tune.get_global_range_size(i));
|
||||
// add globals
|
||||
for(auto x: module.globals())
|
||||
info.globals[x.first] = ((ir::metaparameter*)x.second)->get_value();
|
||||
// number of threads
|
||||
info.num_threads = passes.tune.get_num_threads();
|
||||
return std::unique_ptr<llvm::Module>(result);
|
||||
}
|
||||
@@ -164,7 +169,7 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
|
||||
ranges.push_back(mp->get_space());
|
||||
// iterate over parameters
|
||||
tune_res_t best;
|
||||
size_t nthreads = 4;
|
||||
size_t nthreads = 1;
|
||||
std::mutex mutex;
|
||||
loop_nest<unsigned>(ranges, [&](const std::vector<unsigned> params){
|
||||
std::map<ir::value*, std::vector<std::string>> errors;
|
||||
@@ -203,10 +208,6 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
|
||||
auto ll_module = make_llvm_module(tt_module_1, passes_1, llvm_context, info);
|
||||
std::unique_ptr<driver::module> module(driver::module::create(driver_context_, &*ll_module));
|
||||
std::unique_ptr<driver::kernel> kernel(driver::kernel::create(module.get(), name));
|
||||
// add globals
|
||||
for(auto x: tt_module_1.globals())
|
||||
global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value();
|
||||
modules_.insert({name, module.get()});
|
||||
double perf;
|
||||
perf = benchmark(kernel.get(), info);
|
||||
{
|
||||
@@ -219,7 +220,6 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
|
||||
std::cout << p << " " << std::flush;
|
||||
std::cout << perf << " [ " << best.perf << " ] " << std::endl;
|
||||
}
|
||||
modules_.erase(name);
|
||||
}, nthreads);
|
||||
return best;
|
||||
}
|
||||
@@ -248,9 +248,6 @@ void jit::add_module(ir::module &tt_module, const std::vector<unsigned> ¶ms)
|
||||
auto ll_module = make_llvm_module(tt_module, passes, llvm_context_, launch_info_map_[name]);
|
||||
// llvm module -> machine code
|
||||
modules_.insert({name, driver::module::create(driver_context_, &*ll_module)});
|
||||
// add globals
|
||||
for(auto x: tt_module.globals())
|
||||
global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value();
|
||||
}
|
||||
|
||||
void jit::add_module(const char *name, const char *src, const std::vector<unsigned> ¶ms) {
|
||||
@@ -263,12 +260,10 @@ driver::kernel *jit::get_function(const char *name) {
|
||||
return driver::kernel::create(modules_.at(name), name);
|
||||
}
|
||||
|
||||
jit::launch_information jit::get_launch_info(const char *name) {
|
||||
launch_information jit::get_launch_info(const char *name) {
|
||||
return launch_info_map_.at(name);
|
||||
}
|
||||
|
||||
unsigned jit::get_int(const char *name){
|
||||
return global_ints_.at(name);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user