Driver: moved programs allocation logic to a static variable
This commit is contained in:
@@ -25,7 +25,7 @@ public:
|
|||||||
axpy(axpy::parameters_type const & parameters, binding_policy_t binding_policy = BIND_ALL_UNIQUE);
|
axpy(axpy::parameters_type const & parameters, binding_policy_t binding_policy = BIND_ALL_UNIQUE);
|
||||||
axpy(unsigned int _simd_width, unsigned int _group_size, unsigned int _num_groups, fetching_policy_type _fetching_policy, binding_policy_t binding_policy = BIND_ALL_UNIQUE);
|
axpy(unsigned int _simd_width, unsigned int _group_size, unsigned int _num_groups, fetching_policy_type _fetching_policy, binding_policy_t binding_policy = BIND_ALL_UNIQUE);
|
||||||
std::vector<int_t> input_sizes(expressions_tuple const & expressions) const;
|
std::vector<int_t> input_sizes(expressions_tuple const & expressions) const;
|
||||||
void enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const &);
|
void enqueue(driver::CommandQueue & queue, driver::Program const & program, const char * suffix, base & fallback, controller<expressions_tuple> const &);
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@@ -182,7 +182,7 @@ public:
|
|||||||
virtual ~base();
|
virtual ~base();
|
||||||
std::string generate(const char * suffix, expressions_tuple const & expressions, driver::Device const & device);
|
std::string generate(const char * suffix, expressions_tuple const & expressions, driver::Device const & device);
|
||||||
virtual int is_invalid(expressions_tuple const & expressions, driver::Device const & device) const = 0;
|
virtual int is_invalid(expressions_tuple const & expressions, driver::Device const & device) const = 0;
|
||||||
virtual void enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const & expressions) = 0;
|
virtual void enqueue(driver::CommandQueue & queue, driver::Program const & program, const char * suffix, base & fallback, controller<expressions_tuple> const & expressions) = 0;
|
||||||
virtual std::shared_ptr<base> clone() const = 0;
|
virtual std::shared_ptr<base> clone() const = 0;
|
||||||
private:
|
private:
|
||||||
binding_policy_t binding_policy_;
|
binding_policy_t binding_policy_;
|
||||||
|
@@ -30,7 +30,7 @@ public:
|
|||||||
dot(dot::parameters_type const & parameters, binding_policy_t binding_policy = BIND_ALL_UNIQUE);
|
dot(dot::parameters_type const & parameters, binding_policy_t binding_policy = BIND_ALL_UNIQUE);
|
||||||
dot(unsigned int simd, unsigned int ls, unsigned int ng, fetching_policy_type fetch, binding_policy_t bind = BIND_ALL_UNIQUE);
|
dot(unsigned int simd, unsigned int ls, unsigned int ng, fetching_policy_type fetch, binding_policy_t bind = BIND_ALL_UNIQUE);
|
||||||
std::vector<int_t> input_sizes(expressions_tuple const & expressions) const;
|
std::vector<int_t> input_sizes(expressions_tuple const & expressions) const;
|
||||||
void enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const &);
|
void enqueue(driver::CommandQueue & queue, driver::Program const & program, const char * suffix, base & fallback, controller<expressions_tuple> const &);
|
||||||
private:
|
private:
|
||||||
std::vector< driver::Buffer > tmp_;
|
std::vector< driver::Buffer > tmp_;
|
||||||
std::vector< driver::Buffer > tmpidx_;
|
std::vector< driver::Buffer > tmpidx_;
|
||||||
|
@@ -47,7 +47,7 @@ private:
|
|||||||
int is_invalid_impl(driver::Device const &, expressions_tuple const &) const;
|
int is_invalid_impl(driver::Device const &, expressions_tuple const &) const;
|
||||||
std::string generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const &) const;
|
std::string generate_impl(const char * suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const &) const;
|
||||||
void enqueue_block(driver::CommandQueue & queue, int_t M, int_t N, int_t K, array const & A, array const & B, array const & C,
|
void enqueue_block(driver::CommandQueue & queue, int_t M, int_t N, int_t K, array const & A, array const & B, array const & C,
|
||||||
value_scalar const &alpha, value_scalar const &beta, driver::Program & program, const char * suffix, execution_options_type const & options);
|
value_scalar const &alpha, value_scalar const &beta, driver::Program const & program, const char * suffix, execution_options_type const & options);
|
||||||
array create_slice(array & M, int_t s0_0, int_t s0_1, int_t s1_0, int_t s1_1, bool swap);
|
array create_slice(array & M, int_t s0_0, int_t s0_1, int_t s1_0, int_t s1_1, bool swap);
|
||||||
std::vector<int_t> infos(expressions_tuple const & expressions, isaac::symbolic::preset::gemm::args &arguments) const;
|
std::vector<int_t> infos(expressions_tuple const & expressions, isaac::symbolic::preset::gemm::args &arguments) const;
|
||||||
public:
|
public:
|
||||||
@@ -55,7 +55,7 @@ public:
|
|||||||
std::vector<int_t> input_sizes(expressions_tuple const & expressions) const;
|
std::vector<int_t> input_sizes(expressions_tuple const & expressions) const;
|
||||||
void cleanup(values_holder beta, controller<expressions_tuple> const & ctr, model & fallback,
|
void cleanup(values_holder beta, controller<expressions_tuple> const & ctr, model & fallback,
|
||||||
lhs_rhs_element* eA, lhs_rhs_element* eB, lhs_rhs_element* eC, lhs_rhs_element* ebeta, array const & A, array const & B, array const & C);
|
lhs_rhs_element* eA, lhs_rhs_element* eB, lhs_rhs_element* eC, lhs_rhs_element* ebeta, array const & A, array const & B, array const & C);
|
||||||
void enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const &ctr);
|
void enqueue(driver::CommandQueue & queue, driver::Program const & program, const char * suffix, base & fallback, controller<expressions_tuple> const &ctr);
|
||||||
private:
|
private:
|
||||||
const char A_trans_;
|
const char A_trans_;
|
||||||
const char B_trans_;
|
const char B_trans_;
|
||||||
|
@@ -36,7 +36,7 @@ private:
|
|||||||
std::string generate_impl(const char * suffix, expressions_tuple const &, driver::Device const & device, std::vector<mapping_type> const &) const;
|
std::string generate_impl(const char * suffix, expressions_tuple const &, driver::Device const & device, std::vector<mapping_type> const &) const;
|
||||||
public:
|
public:
|
||||||
virtual std::vector<int_t> input_sizes(expressions_tuple const & expressions) const;
|
virtual std::vector<int_t> input_sizes(expressions_tuple const & expressions) const;
|
||||||
void enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const &);
|
void enqueue(driver::CommandQueue & queue, driver::Program const & program, const char * suffix, base & fallback, controller<expressions_tuple> const &);
|
||||||
private:
|
private:
|
||||||
dot_type dot_type_;
|
dot_type dot_type_;
|
||||||
};
|
};
|
||||||
|
@@ -28,7 +28,7 @@ public:
|
|||||||
ger(parameters_type const & parameters, binding_policy_t binding_policy = BIND_ALL_UNIQUE);
|
ger(parameters_type const & parameters, binding_policy_t binding_policy = BIND_ALL_UNIQUE);
|
||||||
ger(unsigned int simd, unsigned int ls1, unsigned int ls2, unsigned int ng1, unsigned int ng2, fetching_policy_type fetch, binding_policy_t bind = BIND_ALL_UNIQUE);
|
ger(unsigned int simd, unsigned int ls1, unsigned int ls2, unsigned int ng1, unsigned int ng2, fetching_policy_type fetch, binding_policy_t bind = BIND_ALL_UNIQUE);
|
||||||
std::vector<int_t> input_sizes(expressions_tuple const & expressions) const;
|
std::vector<int_t> input_sizes(expressions_tuple const & expressions) const;
|
||||||
void enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const &);
|
void enqueue(driver::CommandQueue & queue, driver::Program const & program, const char * suffix, base & fallback, controller<expressions_tuple> const &);
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@@ -38,6 +38,10 @@ public:
|
|||||||
Device(int ordinal);
|
Device(int ordinal);
|
||||||
#endif
|
#endif
|
||||||
Device(cl_device_id const & device, bool take_ownership = true);
|
Device(cl_device_id const & device, bool take_ownership = true);
|
||||||
|
|
||||||
|
bool operator==(Device const &) const;
|
||||||
|
bool operator<(Device const &) const;
|
||||||
|
|
||||||
backend_type backend() const;
|
backend_type backend() const;
|
||||||
size_t clock_rate() const;
|
size_t clock_rate() const;
|
||||||
unsigned int address_bits() const;
|
unsigned int address_bits() const;
|
||||||
|
@@ -3,7 +3,6 @@
|
|||||||
|
|
||||||
#include "isaac/defines.h"
|
#include "isaac/defines.h"
|
||||||
#include "isaac/driver/common.h"
|
#include "isaac/driver/common.h"
|
||||||
#include "isaac/driver/context.h"
|
|
||||||
#include "isaac/driver/handle.h"
|
#include "isaac/driver/handle.h"
|
||||||
|
|
||||||
namespace isaac
|
namespace isaac
|
||||||
@@ -13,6 +12,7 @@ namespace driver
|
|||||||
{
|
{
|
||||||
|
|
||||||
class Context;
|
class Context;
|
||||||
|
class Device;
|
||||||
|
|
||||||
class ISAACAPI Program
|
class ISAACAPI Program
|
||||||
{
|
{
|
||||||
@@ -22,11 +22,20 @@ public:
|
|||||||
Context const & context() const;
|
Context const & context() const;
|
||||||
private:
|
private:
|
||||||
backend_type backend_;
|
backend_type backend_;
|
||||||
Context context_;
|
Context const & context_;
|
||||||
std::string source_;
|
std::string source_;
|
||||||
HANDLE_TYPE(cl_program, CUmodule) h_;
|
HANDLE_TYPE(cl_program, CUmodule) h_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class ISAACAPI ProgramsHandler
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
static Program const & add(Context const & scontext, std::string const & name, std::string const & src);
|
||||||
|
static Program const * find(Context const & context, std::string const & name);
|
||||||
|
private:
|
||||||
|
static std::map<driver::Context, std::map<std::string, Program> > programs_;
|
||||||
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@@ -20,7 +20,7 @@ namespace isaac
|
|||||||
private:
|
private:
|
||||||
std::string define_extension(std::string const & extensions, std::string const & ext);
|
std::string define_extension(std::string const & extensions, std::string const & ext);
|
||||||
inline void fill_program_name(char* program_name, expressions_tuple const & expressions, binding_policy_t binding_policy);
|
inline void fill_program_name(char* program_name, expressions_tuple const & expressions, binding_policy_t binding_policy);
|
||||||
driver::Program& init(controller<expressions_tuple> const &);
|
driver::Program const & init(controller<expressions_tuple> const &);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
model(expression_type, numeric_type, predictors::random_forest const &, std::vector< std::shared_ptr<templates::base> > const &, driver::CommandQueue const &);
|
model(expression_type, numeric_type, predictors::random_forest const &, std::vector< std::shared_ptr<templates::base> > const &, driver::CommandQueue const &);
|
||||||
@@ -29,15 +29,11 @@ namespace isaac
|
|||||||
void execute(controller<expressions_tuple> const &);
|
void execute(controller<expressions_tuple> const &);
|
||||||
templates_container const & templates() const;
|
templates_container const & templates() const;
|
||||||
|
|
||||||
void test() const
|
|
||||||
{ std::cout << queue_.device().backend() << std::endl;}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
templates_container templates_;
|
templates_container templates_;
|
||||||
template_pointer fallback_;
|
template_pointer fallback_;
|
||||||
std::shared_ptr<predictors::random_forest> predictor_;
|
std::shared_ptr<predictors::random_forest> predictor_;
|
||||||
std::map<std::vector<int_t>, int> hardcoded_;
|
std::map<std::vector<int_t>, int> hardcoded_;
|
||||||
std::map<driver::Context, std::map<std::string, std::shared_ptr<driver::Program> > > programs_;
|
|
||||||
driver::CommandQueue queue_;
|
driver::CommandQueue queue_;
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -47,7 +43,7 @@ namespace isaac
|
|||||||
model_map_t& models(driver::CommandQueue & queue);
|
model_map_t& models(driver::CommandQueue & queue);
|
||||||
|
|
||||||
extern std::map<std::pair<expression_type, numeric_type>, std::shared_ptr<templates::base> > fallbacks;
|
extern std::map<std::pair<expression_type, numeric_type>, std::shared_ptr<templates::base> > fallbacks;
|
||||||
extern std::map<driver::CommandQueue, model_map_t> models_;
|
extern std::map<driver::Device, model_map_t> models_;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -110,7 +110,7 @@ std::vector<int_t> axpy::input_sizes(expressions_tuple const & expressions) cons
|
|||||||
return tools::make_vector<int_t>() << std::max(shape[0], shape[1]);
|
return tools::make_vector<int_t>() << std::max(shape[0], shape[1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
void axpy::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const & controller)
|
void axpy::enqueue(driver::CommandQueue & queue, driver::Program const & program, const char * suffix, base & fallback, controller<expressions_tuple> const & controller)
|
||||||
{
|
{
|
||||||
expressions_tuple const & expressions = controller.x();
|
expressions_tuple const & expressions = controller.x();
|
||||||
//Size
|
//Size
|
||||||
|
@@ -279,7 +279,7 @@ std::vector<int_t> dot::input_sizes(expressions_tuple const & expressions) const
|
|||||||
return tools::make_vector<int_t>() << N;
|
return tools::make_vector<int_t>() << N;
|
||||||
}
|
}
|
||||||
|
|
||||||
void dot::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const & controller)
|
void dot::enqueue(driver::CommandQueue & queue, driver::Program const & program, const char * suffix, base & fallback, controller<expressions_tuple> const & controller)
|
||||||
{
|
{
|
||||||
expressions_tuple const & expressions = controller.x();
|
expressions_tuple const & expressions = controller.x();
|
||||||
|
|
||||||
|
@@ -574,7 +574,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
|||||||
void gemm::enqueue_block(driver::CommandQueue & /*queue*/, int_t M, int_t N, int_t K,
|
void gemm::enqueue_block(driver::CommandQueue & /*queue*/, int_t M, int_t N, int_t K,
|
||||||
array const & A, array const & B, array const & C,
|
array const & A, array const & B, array const & C,
|
||||||
value_scalar const & alpha, value_scalar const & beta,
|
value_scalar const & alpha, value_scalar const & beta,
|
||||||
driver::Program & program, const char * suffix, execution_options_type const & options)
|
driver::Program const & program, const char * suffix, execution_options_type const & options)
|
||||||
{
|
{
|
||||||
using tools::align;
|
using tools::align;
|
||||||
|
|
||||||
@@ -685,7 +685,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
|||||||
return infos(expressions, dummy);
|
return infos(expressions, dummy);
|
||||||
}
|
}
|
||||||
|
|
||||||
void gemm::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback_base, controller<expressions_tuple> const & ctr)
|
void gemm::enqueue(driver::CommandQueue & queue, driver::Program const & program, const char * suffix, base & fallback_base, controller<expressions_tuple> const & ctr)
|
||||||
{
|
{
|
||||||
using namespace tools;
|
using namespace tools;
|
||||||
// std::cout << p_.simd_width << " " << p_.mL << " " << p_.kL << " " << p_.mS << " " << p_.depth << " " << p_.local_size_0 << std::endl;
|
// std::cout << p_.simd_width << " " << p_.mL << " " << p_.kL << " " << p_.mS << " " << p_.depth << " " << p_.local_size_0 << std::endl;
|
||||||
|
@@ -336,7 +336,7 @@ std::vector<int_t> gemv::input_sizes(expressions_tuple const & expressions) cons
|
|||||||
return tools::make_vector<int_t>() << MN.first << MN.second;
|
return tools::make_vector<int_t>() << MN.first << MN.second;
|
||||||
}
|
}
|
||||||
|
|
||||||
void gemv::enqueue(driver::CommandQueue & queue, driver::Program & program, const char * suffix, base & fallback, controller<expressions_tuple> const & controller)
|
void gemv::enqueue(driver::CommandQueue & queue, driver::Program const & program, const char * suffix, base & fallback, controller<expressions_tuple> const & controller)
|
||||||
{
|
{
|
||||||
expressions_tuple const & expressions = controller.x();
|
expressions_tuple const & expressions = controller.x();
|
||||||
driver::Context const & context = expressions.context();
|
driver::Context const & context = expressions.context();
|
||||||
|
@@ -114,7 +114,7 @@ std::vector<int_t> ger::input_sizes(expressions_tuple const & expressions) const
|
|||||||
return tools::make_vector<int_t>() << size.first << size.second;
|
return tools::make_vector<int_t>() << size.first << size.second;
|
||||||
}
|
}
|
||||||
|
|
||||||
void ger::enqueue(driver::CommandQueue & /*queue*/, driver::Program & program, const char * suffix, base &, controller<expressions_tuple> const & controller)
|
void ger::enqueue(driver::CommandQueue & /*queue*/, driver::Program const & program, const char * suffix, base &, controller<expressions_tuple> const & controller)
|
||||||
{
|
{
|
||||||
expressions_tuple const & expressions = controller.x();
|
expressions_tuple const & expressions = controller.x();
|
||||||
char name[32] = {"axpy"};
|
char name[32] = {"axpy"};
|
||||||
|
@@ -48,10 +48,14 @@ Context::Context(Device const & device) : backend_(device.backend_), device_(dev
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool Context::operator==(Context const & other) const
|
bool Context::operator==(Context const & other) const
|
||||||
{ return h_==other.h_; }
|
{
|
||||||
|
return h_==other.h_;
|
||||||
|
}
|
||||||
|
|
||||||
bool Context::operator<(Context const & other) const
|
bool Context::operator<(Context const & other) const
|
||||||
{ return h_<other.h_; }
|
{
|
||||||
|
return h_<other.h_;
|
||||||
|
}
|
||||||
|
|
||||||
Device const & Context::device() const
|
Device const & Context::device() const
|
||||||
{ return device_; }
|
{ return device_; }
|
||||||
|
@@ -27,6 +27,18 @@ Device::Device(int ordinal): backend_(CUDA), h_(backend_, true)
|
|||||||
Device::Device(cl_device_id const & device, bool take_ownership) : backend_(OPENCL), h_(backend_, take_ownership)
|
Device::Device(cl_device_id const & device, bool take_ownership) : backend_(OPENCL), h_(backend_, take_ownership)
|
||||||
{ h_.cl() = device; }
|
{ h_.cl() = device; }
|
||||||
|
|
||||||
|
|
||||||
|
bool Device::operator==(Device const & other) const
|
||||||
|
{
|
||||||
|
return h_==other.h_;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Device::operator<(Device const & other) const
|
||||||
|
{
|
||||||
|
return h_<other.h_;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
backend_type Device::backend() const
|
backend_type Device::backend() const
|
||||||
{ return backend_; }
|
{ return backend_; }
|
||||||
|
|
||||||
|
@@ -159,6 +159,32 @@ Program::Program(Context const & context, std::string const & source) : backend_
|
|||||||
Context const & Program::context() const
|
Context const & Program::context() const
|
||||||
{ return context_; }
|
{ return context_; }
|
||||||
|
|
||||||
|
Program const & ProgramsHandler::add(Context const & context, std::string const & name, std::string const & src)
|
||||||
|
{
|
||||||
|
std::map<std::string, Program> & pgms = programs_[context];
|
||||||
|
std::map<std::string, Program>::iterator it = pgms.find(name);
|
||||||
|
if(it==pgms.end())
|
||||||
|
{
|
||||||
|
std::string extensions;
|
||||||
|
std::string ext = "cl_khr_fp64";
|
||||||
|
if(context.device().extensions().find(ext)!=std::string::npos)
|
||||||
|
extensions = "#pragma OPENCL EXTENSION " + ext + " : enable\n";
|
||||||
|
return pgms.insert(std::make_pair(name, driver::Program(context, extensions + src))).first->second;
|
||||||
|
}
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
const Program * ProgramsHandler::find(Context const & context, const std::string &name)
|
||||||
|
{
|
||||||
|
std::map<std::string, Program> & pgms = programs_[context];
|
||||||
|
std::map<std::string, Program>::const_iterator it = pgms.find(name);
|
||||||
|
if(it==pgms.end())
|
||||||
|
return NULL;
|
||||||
|
return &it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::map<driver::Context, std::map<std::string, Program>> ProgramsHandler::programs_;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@@ -12,6 +12,7 @@
|
|||||||
#include "isaac/backend/templates/ger.h"
|
#include "isaac/backend/templates/ger.h"
|
||||||
#include "isaac/backend/templates/gemv.h"
|
#include "isaac/backend/templates/gemv.h"
|
||||||
#include "isaac/backend/templates/gemm.h"
|
#include "isaac/backend/templates/gemm.h"
|
||||||
|
#include "isaac/driver/program.h"
|
||||||
#include "isaac/exception/unknown_datatype.h"
|
#include "isaac/exception/unknown_datatype.h"
|
||||||
#include "isaac/exception/operation_not_supported.h"
|
#include "isaac/exception/operation_not_supported.h"
|
||||||
#include "isaac/model/model.h"
|
#include "isaac/model/model.h"
|
||||||
@@ -26,14 +27,6 @@ namespace isaac
|
|||||||
static double time_event(unsigned long sum, driver::Event const & e)
|
static double time_event(unsigned long sum, driver::Event const & e)
|
||||||
{ return sum + e.elapsed_time();}
|
{ return sum + e.elapsed_time();}
|
||||||
|
|
||||||
|
|
||||||
std::string model::define_extension(std::string const & extensions, std::string const & ext)
|
|
||||||
{
|
|
||||||
if(extensions.find(ext)!=std::string::npos)
|
|
||||||
return std::string("#pragma OPENCL EXTENSION " + ext + " : enable\n");
|
|
||||||
return std::string("");
|
|
||||||
}
|
|
||||||
|
|
||||||
void model::fill_program_name(char* program_name, expressions_tuple const & expressions, binding_policy_t binding_policy)
|
void model::fill_program_name(char* program_name, expressions_tuple const & expressions, binding_policy_t binding_policy)
|
||||||
{
|
{
|
||||||
if (expressions.order()==expressions_tuple::INDEPENDENT)
|
if (expressions.order()==expressions_tuple::INDEPENDENT)
|
||||||
@@ -51,9 +44,9 @@ void model::fill_program_name(char* program_name, expressions_tuple const & expr
|
|||||||
delete binder;
|
delete binder;
|
||||||
}
|
}
|
||||||
|
|
||||||
driver::Program& model::init(controller<expressions_tuple> const & expressions)
|
driver::Program const & model::init(controller<expressions_tuple> const & expressions)
|
||||||
{
|
{
|
||||||
driver::Context const & context = expressions.x().context();
|
driver::Context & context = (driver::Context&)expressions.x().context();
|
||||||
std::string pname;
|
std::string pname;
|
||||||
compilation_options_type const & opt = expressions.compilation_options();
|
compilation_options_type const & opt = expressions.compilation_options();
|
||||||
if(opt.program_name.empty())
|
if(opt.program_name.empty())
|
||||||
@@ -65,24 +58,18 @@ driver::Program& model::init(controller<expressions_tuple> const & expressions)
|
|||||||
else
|
else
|
||||||
pname = expressions.compilation_options().program_name;
|
pname = expressions.compilation_options().program_name;
|
||||||
|
|
||||||
|
driver::Program const * program = driver::ProgramsHandler::find(context, pname);
|
||||||
|
if(program)
|
||||||
|
return *program;
|
||||||
|
|
||||||
std::shared_ptr<driver::Program> & program = programs_[context][pname];
|
std::string srcs;
|
||||||
if(!program)
|
for(unsigned int i = 0 ; i < templates_.size() ; ++i){
|
||||||
{
|
char buffer[16];
|
||||||
driver::Device device = queue_.device();
|
sprintf(buffer,"%d",i);
|
||||||
std::string extensions = device.extensions();
|
srcs += templates_[i]->generate(buffer, expressions.x(), context.device());
|
||||||
std::string all_extensions = define_extension(extensions, "cl_khr_fp64");
|
}
|
||||||
|
srcs += fallback_->generate("fallback", expressions.x(), context.device());
|
||||||
std::string srcs;
|
return driver::ProgramsHandler::add(context, pname, srcs);
|
||||||
for(unsigned int i = 0 ; i < templates_.size() ; ++i){
|
|
||||||
char buffer[16];
|
|
||||||
sprintf(buffer,"%d",i);
|
|
||||||
srcs += templates_[i]->generate(buffer, expressions.x(), device);
|
|
||||||
}
|
|
||||||
srcs += fallback_->generate("fallback", expressions.x(), device);
|
|
||||||
program.reset(new driver::Program(context, all_extensions + srcs));
|
|
||||||
}
|
|
||||||
return *program;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
model::model(expression_type etype, numeric_type dtype, predictors::random_forest const & predictor, std::vector< std::shared_ptr<templates::base> > const & templates, driver::CommandQueue const & queue) :
|
model::model(expression_type etype, numeric_type dtype, predictors::random_forest const & predictor, std::vector< std::shared_ptr<templates::base> > const & templates, driver::CommandQueue const & queue) :
|
||||||
@@ -95,7 +82,7 @@ model::model(expression_type etype, numeric_type dtype, templates::base const &
|
|||||||
|
|
||||||
void model::execute(controller<expressions_tuple> const & expr)
|
void model::execute(controller<expressions_tuple> const & expr)
|
||||||
{
|
{
|
||||||
driver::Program & program = init(expr);
|
driver::Program const & program = init(expr);
|
||||||
std::vector<int_t> x = templates_[0]->input_sizes(expr.x());
|
std::vector<int_t> x = templates_[0]->input_sizes(expr.x());
|
||||||
|
|
||||||
//Specific tuning if requested
|
//Specific tuning if requested
|
||||||
@@ -280,13 +267,12 @@ model_map_t init_models(driver::CommandQueue & queue)
|
|||||||
|
|
||||||
model_map_t& models(driver::CommandQueue & queue)
|
model_map_t& models(driver::CommandQueue & queue)
|
||||||
{
|
{
|
||||||
std::map<driver::CommandQueue, model_map_t>::iterator it = models_.find(queue);
|
static std::map<driver::Device, model_map_t> models_;
|
||||||
|
std::map<driver::Device, model_map_t>::iterator it = models_.find(queue.device());
|
||||||
if(it == models_.end())
|
if(it == models_.end())
|
||||||
return models_.insert(std::make_pair(queue, init_models(queue))).first->second;
|
return models_.insert(std::make_pair(queue.device(), init_models(queue))).first->second;
|
||||||
return it->second;
|
return it->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::map<std::pair<expression_type, numeric_type>, std::shared_ptr<templates::base> > fallbacks = init_fallback();
|
std::map<std::pair<expression_type, numeric_type>, std::shared_ptr<templates::base> > fallbacks = init_fallback();
|
||||||
std::map<driver::CommandQueue, model_map_t> models_;
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user