Driver: moved programs allocation logic to a static variable
This commit is contained in:
@@ -12,6 +12,7 @@
|
||||
#include "isaac/backend/templates/ger.h"
|
||||
#include "isaac/backend/templates/gemv.h"
|
||||
#include "isaac/backend/templates/gemm.h"
|
||||
#include "isaac/driver/program.h"
|
||||
#include "isaac/exception/unknown_datatype.h"
|
||||
#include "isaac/exception/operation_not_supported.h"
|
||||
#include "isaac/model/model.h"
|
||||
@@ -26,14 +27,6 @@ namespace isaac
|
||||
static double time_event(unsigned long sum, driver::Event const & e)
|
||||
{ return sum + e.elapsed_time();}
|
||||
|
||||
|
||||
std::string model::define_extension(std::string const & extensions, std::string const & ext)
|
||||
{
|
||||
if(extensions.find(ext)!=std::string::npos)
|
||||
return std::string("#pragma OPENCL EXTENSION " + ext + " : enable\n");
|
||||
return std::string("");
|
||||
}
|
||||
|
||||
void model::fill_program_name(char* program_name, expressions_tuple const & expressions, binding_policy_t binding_policy)
|
||||
{
|
||||
if (expressions.order()==expressions_tuple::INDEPENDENT)
|
||||
@@ -51,9 +44,9 @@ void model::fill_program_name(char* program_name, expressions_tuple const & expr
|
||||
delete binder;
|
||||
}
|
||||
|
||||
driver::Program& model::init(controller<expressions_tuple> const & expressions)
|
||||
driver::Program const & model::init(controller<expressions_tuple> const & expressions)
|
||||
{
|
||||
driver::Context const & context = expressions.x().context();
|
||||
driver::Context & context = (driver::Context&)expressions.x().context();
|
||||
std::string pname;
|
||||
compilation_options_type const & opt = expressions.compilation_options();
|
||||
if(opt.program_name.empty())
|
||||
@@ -65,24 +58,18 @@ driver::Program& model::init(controller<expressions_tuple> const & expressions)
|
||||
else
|
||||
pname = expressions.compilation_options().program_name;
|
||||
|
||||
driver::Program const * program = driver::ProgramsHandler::find(context, pname);
|
||||
if(program)
|
||||
return *program;
|
||||
|
||||
std::shared_ptr<driver::Program> & program = programs_[context][pname];
|
||||
if(!program)
|
||||
{
|
||||
driver::Device device = queue_.device();
|
||||
std::string extensions = device.extensions();
|
||||
std::string all_extensions = define_extension(extensions, "cl_khr_fp64");
|
||||
|
||||
std::string srcs;
|
||||
for(unsigned int i = 0 ; i < templates_.size() ; ++i){
|
||||
char buffer[16];
|
||||
sprintf(buffer,"%d",i);
|
||||
srcs += templates_[i]->generate(buffer, expressions.x(), device);
|
||||
}
|
||||
srcs += fallback_->generate("fallback", expressions.x(), device);
|
||||
program.reset(new driver::Program(context, all_extensions + srcs));
|
||||
}
|
||||
return *program;
|
||||
std::string srcs;
|
||||
for(unsigned int i = 0 ; i < templates_.size() ; ++i){
|
||||
char buffer[16];
|
||||
sprintf(buffer,"%d",i);
|
||||
srcs += templates_[i]->generate(buffer, expressions.x(), context.device());
|
||||
}
|
||||
srcs += fallback_->generate("fallback", expressions.x(), context.device());
|
||||
return driver::ProgramsHandler::add(context, pname, srcs);
|
||||
}
|
||||
|
||||
model::model(expression_type etype, numeric_type dtype, predictors::random_forest const & predictor, std::vector< std::shared_ptr<templates::base> > const & templates, driver::CommandQueue const & queue) :
|
||||
@@ -95,7 +82,7 @@ model::model(expression_type etype, numeric_type dtype, templates::base const &
|
||||
|
||||
void model::execute(controller<expressions_tuple> const & expr)
|
||||
{
|
||||
driver::Program & program = init(expr);
|
||||
driver::Program const & program = init(expr);
|
||||
std::vector<int_t> x = templates_[0]->input_sizes(expr.x());
|
||||
|
||||
//Specific tuning if requested
|
||||
@@ -280,13 +267,12 @@ model_map_t init_models(driver::CommandQueue & queue)
|
||||
|
||||
model_map_t& models(driver::CommandQueue & queue)
|
||||
{
|
||||
std::map<driver::CommandQueue, model_map_t>::iterator it = models_.find(queue);
|
||||
static std::map<driver::Device, model_map_t> models_;
|
||||
std::map<driver::Device, model_map_t>::iterator it = models_.find(queue.device());
|
||||
if(it == models_.end())
|
||||
return models_.insert(std::make_pair(queue, init_models(queue))).first->second;
|
||||
return models_.insert(std::make_pair(queue.device(), init_models(queue))).first->second;
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::map<std::pair<expression_type, numeric_type>, std::shared_ptr<templates::base> > fallbacks = init_fallback();
|
||||
std::map<driver::CommandQueue, model_map_t> models_;
|
||||
|
||||
}
|
||||
|
Reference in New Issue
Block a user