removing C++11 interface

This commit is contained in:
Philippe Tillet
2015-02-08 23:19:38 -05:00
parent 85fb438806
commit a6d7671831
21 changed files with 423 additions and 956 deletions

View File

@@ -82,11 +82,11 @@ std::vector<cl_ext::lazy_compiler>& model::init(controller<expressions_tuple> co
return to_init;
}
model::model(predictors::random_forest const & predictor, std::vector< std::shared_ptr<base> > const & templates, cl::CommandQueue & queue) :
model::model(predictors::random_forest const & predictor, std::vector< tools::shared_ptr<base> > const & templates, cl::CommandQueue & queue) :
templates_(templates), predictor_(new predictors::random_forest(predictor)), queue_(queue)
{}
model::model(std::vector< std::shared_ptr<base> > const & templates, cl::CommandQueue & queue) : templates_(templates), queue_(queue)
model::model(std::vector< tools::shared_ptr<base> > const & templates, cl::CommandQueue & queue) : templates_(templates), queue_(queue)
{}
model::model(base const & tp, cl::CommandQueue & queue) : templates_(1,tp.clone()), queue_(queue)
@@ -158,27 +158,27 @@ namespace detail
throw std::invalid_argument("Invalid datatype: " + name);
}
static std::shared_ptr<base> create(std::string const & template_name, std::vector<int> const & a)
static tools::shared_ptr<base> create(std::string const & template_name, std::vector<int> const & a)
{
fetching_policy_type fetch[] = {FETCH_FROM_LOCAL, FETCH_FROM_GLOBAL_STRIDED, FETCH_FROM_GLOBAL_CONTIGUOUS};
if(template_name=="vaxpy")
return std::shared_ptr<base>(new vaxpy(a[0], a[1], a[2], fetch[a[3]]));
return tools::shared_ptr<base>(new vaxpy(a[0], a[1], a[2], fetch[a[3]]));
else if(template_name=="dot")
return std::shared_ptr<base>(new reduction(a[0], a[1], a[2], fetch[a[3]]));
return tools::shared_ptr<base>(new reduction(a[0], a[1], a[2], fetch[a[3]]));
else if(template_name=="maxpy")
return std::shared_ptr<base>(new maxpy(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
return tools::shared_ptr<base>(new maxpy(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
else if(template_name.find("gemvN")!=std::string::npos)
return std::shared_ptr<base>(new mreduction_rows(a[0], a[1], a[2], a[3], fetch[a[4]]));
return tools::shared_ptr<base>(new mreduction_rows(a[0], a[1], a[2], a[3], fetch[a[4]]));
else if(template_name.find("gemvT")!=std::string::npos)
return std::shared_ptr<base>(new mreduction_cols(a[0], a[1], a[2], a[3], fetch[a[4]]));
return tools::shared_ptr<base>(new mreduction_cols(a[0], a[1], a[2], a[3], fetch[a[4]]));
else if(template_name.find("gemmNN")!=std::string::npos)
return std::shared_ptr<base>(new mproduct_nn(a[0], a[1], a[2], a[3], a[4], a[5], a[6], fetch[a[7]], fetch[a[8]], a[9], a[10]));
return tools::shared_ptr<base>(new mproduct_nn(a[0], a[1], a[2], a[3], a[4], a[5], a[6], fetch[a[7]], fetch[a[8]], a[9], a[10]));
else if(template_name.find("gemmTN")!=std::string::npos)
return std::shared_ptr<base>(new mproduct_tn(a[0], a[1], a[2], a[3], a[4], a[5], a[6], fetch[a[7]], fetch[a[8]], a[9], a[10]));
return tools::shared_ptr<base>(new mproduct_tn(a[0], a[1], a[2], a[3], a[4], a[5], a[6], fetch[a[7]], fetch[a[8]], a[9], a[10]));
else if(template_name.find("gemmNT")!=std::string::npos)
return std::shared_ptr<base>(new mproduct_nt(a[0], a[1], a[2], a[3], a[4], a[5], a[6], fetch[a[7]], fetch[a[8]], a[9], a[10]));
return tools::shared_ptr<base>(new mproduct_nt(a[0], a[1], a[2], a[3], a[4], a[5], a[6], fetch[a[7]], fetch[a[8]], a[9], a[10]));
else if(template_name.find("gemmTT")!=std::string::npos)
return std::shared_ptr<base>(new mproduct_tt(a[0], a[1], a[2], a[3], a[4], a[5], a[6], fetch[a[7]], fetch[a[8]], a[9], a[10]));
return tools::shared_ptr<base>(new mproduct_tt(a[0], a[1], a[2], a[3], a[4], a[5], a[6], fetch[a[7]], fetch[a[8]], a[9], a[10]));
else
throw std::invalid_argument("Invalid expression: " + template_name);
}
@@ -214,7 +214,7 @@ void import(std::string const & fname, cl::CommandQueue & queue, model_map_t& re
numeric_type dtype = detail::get_dtype(elem);
// Get profiles
std::vector<std::shared_ptr<base> > templates;
std::vector<tools::shared_ptr<base> > templates;
js::Value const & profiles = document[opcstr][dtcstr]["profiles"];
for (js::SizeType id = 0 ; id < profiles.Size() ; ++id)
templates.push_back(detail::create(operation, tools::to_int_array<int>(profiles[id])));
@@ -222,10 +222,10 @@ void import(std::string const & fname, cl::CommandQueue & queue, model_map_t& re
{
// Get predictor
predictors::random_forest predictor(document[opcstr][dtcstr]["predictor"]);
result[std::make_pair(etype, dtype)] = std::shared_ptr<model>(new model(predictor, templates, queue));
result[std::make_pair(etype, dtype)] = tools::shared_ptr<model>(new model(predictor, templates, queue));
}
else
result[std::make_pair(etype, dtype)] = std::shared_ptr<model>(new model(templates, queue));
result[std::make_pair(etype, dtype)] = tools::shared_ptr<model>(new model(templates, queue));
}
}
}
@@ -235,7 +235,7 @@ void import(std::string const & fname, cl::CommandQueue & queue, model_map_t& re
model_map_t init_models(cl::CommandQueue & queue)
{
model_map_t res;
typedef std::shared_ptr<model> ptr_t;
typedef tools::shared_ptr<model> ptr_t;
numeric_type types[] = {CHAR_TYPE, UCHAR_TYPE, SHORT_TYPE, USHORT_TYPE, INT_TYPE, UINT_TYPE, LONG_TYPE, ULONG_TYPE, FLOAT_TYPE, DOUBLE_TYPE};
for(auto DTYPE : types){