Code quality: removed tools::shared_ptr<>

This commit is contained in:
Philippe Tillet
2015-07-28 15:26:10 -07:00
parent 0434ac551c
commit 9c15debf8b
10 changed files with 70 additions and 231 deletions

View File

@@ -85,7 +85,7 @@ driver::Program& model::init(controller<expressions_tuple> const & expressions)
return *program;
}
model::model(expression_type etype, numeric_type dtype, predictors::random_forest const & predictor, std::vector< tools::shared_ptr<templates::base> > const & templates, driver::CommandQueue const & queue) :
model::model(expression_type etype, numeric_type dtype, predictors::random_forest const & predictor, std::vector< std::shared_ptr<templates::base> > const & templates, driver::CommandQueue const & queue) :
templates_(templates), fallback_(fallbacks[std::make_pair(etype, dtype)]), predictor_(new predictors::random_forest(predictor)), queue_(queue)
{}
@@ -166,27 +166,27 @@ namespace detail
throw std::invalid_argument("Invalid datatype: " + name);
}
static tools::shared_ptr<templates::base> create(std::string const & template_name, std::vector<int> const & a)
static std::shared_ptr<templates::base> create(std::string const & template_name, std::vector<int> const & a)
{
templates::fetching_policy_type fetch[] = {templates::FETCH_FROM_LOCAL, templates::FETCH_FROM_GLOBAL_STRIDED, templates::FETCH_FROM_GLOBAL_CONTIGUOUS};
if(template_name=="axpy")
return tools::shared_ptr<templates::base>(new templates::axpy(a[0], a[1], a[2], fetch[a[3]]));
return std::shared_ptr<templates::base>(new templates::axpy(a[0], a[1], a[2], fetch[a[3]]));
else if(template_name=="dot")
return tools::shared_ptr<templates::base>(new templates::dot(a[0], a[1], a[2], fetch[a[3]]));
return std::shared_ptr<templates::base>(new templates::dot(a[0], a[1], a[2], fetch[a[3]]));
else if(template_name=="ger")
return tools::shared_ptr<templates::base>(new templates::ger(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
return std::shared_ptr<templates::base>(new templates::ger(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
else if(template_name.find("gemv_n")!=std::string::npos)
return tools::shared_ptr<templates::base>(new templates::gemv_n(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
return std::shared_ptr<templates::base>(new templates::gemv_n(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
else if(template_name.find("gemv_t")!=std::string::npos)
return tools::shared_ptr<templates::base>(new templates::gemv_t(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
return std::shared_ptr<templates::base>(new templates::gemv_t(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
else if(template_name.find("gemm_nn")!=std::string::npos)
return tools::shared_ptr<templates::base>(new templates::gemm_nn(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
return std::shared_ptr<templates::base>(new templates::gemm_nn(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
else if(template_name.find("gemm_tn")!=std::string::npos)
return tools::shared_ptr<templates::base>(new templates::gemm_tn(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
return std::shared_ptr<templates::base>(new templates::gemm_tn(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
else if(template_name.find("gemm_nt")!=std::string::npos)
return tools::shared_ptr<templates::base>(new templates::gemm_nt(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
return std::shared_ptr<templates::base>(new templates::gemm_nt(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
else if(template_name.find("gemm_tt")!=std::string::npos)
return tools::shared_ptr<templates::base>(new templates::gemm_tt(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
return std::shared_ptr<templates::base>(new templates::gemm_tt(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
else
throw std::invalid_argument("Invalid expression: " + template_name);
}
@@ -222,7 +222,7 @@ void import(std::string const & fname, driver::CommandQueue & queue, model_map_t
numeric_type dtype = detail::get_dtype(elem);
// Get profiles
std::vector<tools::shared_ptr<templates::base> > templates;
std::vector<std::shared_ptr<templates::base> > templates;
js::Value const & profiles = document[opcstr][dtcstr]["profiles"];
for (js::SizeType id = 0 ; id < profiles.Size() ; ++id)
templates.push_back(detail::create(operation, tools::to_int_array<int>(profiles[id])));
@@ -231,10 +231,10 @@ void import(std::string const & fname, driver::CommandQueue & queue, model_map_t
{
// Get predictor
predictors::random_forest predictor(document[opcstr][dtcstr]["predictor"]);
result[std::make_pair(etype, dtype)] = tools::shared_ptr<model>(new model(etype, dtype, predictor, templates, queue));
result[std::make_pair(etype, dtype)] = std::shared_ptr<model>(new model(etype, dtype, predictor, templates, queue));
}
else
result[std::make_pair(etype, dtype)] = tools::shared_ptr<model>(new model(etype, dtype, *templates[0], queue));
result[std::make_pair(etype, dtype)] = std::shared_ptr<model>(new model(etype, dtype, *templates[0], queue));
}
}
}
@@ -242,9 +242,9 @@ void import(std::string const & fname, driver::CommandQueue & queue, model_map_t
}
std::map<std::pair<expression_type, numeric_type>, tools::shared_ptr<templates::base> > init_fallback()
std::map<std::pair<expression_type, numeric_type>, std::shared_ptr<templates::base> > init_fallback()
{
typedef tools::shared_ptr<templates::base> ptr_t;
typedef std::shared_ptr<templates::base> ptr_t;
std::map<std::pair<expression_type, numeric_type>, ptr_t > res;
numeric_type types[] = {CHAR_TYPE, UCHAR_TYPE, SHORT_TYPE, USHORT_TYPE, INT_TYPE, UINT_TYPE, LONG_TYPE, ULONG_TYPE, FLOAT_TYPE, DOUBLE_TYPE};
for(auto DTYPE : types)
@@ -271,7 +271,7 @@ model_map_t init_models(driver::CommandQueue & queue)
for(numeric_type dtype: dtypes)
for(expression_type etype: etypes)
res[std::make_pair(etype, dtype)] = tools::shared_ptr<model>(new model(etype, dtype, *fallbacks[std::make_pair(etype, dtype)], queue));
res[std::make_pair(etype, dtype)] = std::shared_ptr<model>(new model(etype, dtype, *fallbacks[std::make_pair(etype, dtype)], queue));
if(const char * homepath = std::getenv("HOME"))
import(std::string(homepath) + "/.isaac/devices/device0.json", queue, res);
@@ -286,7 +286,7 @@ model_map_t& models(driver::CommandQueue & queue)
return it->second;
}
std::map<std::pair<expression_type, numeric_type>, tools::shared_ptr<templates::base> > fallbacks = init_fallback();
std::map<std::pair<expression_type, numeric_type>, std::shared_ptr<templates::base> > fallbacks = init_fallback();
std::map<driver::CommandQueue, model_map_t> models_;
}