Ported to C++11
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
#include <set>
|
||||
#include <fstream>
|
||||
#include <stdexcept>
|
||||
#include <algorithm>
|
||||
|
||||
#include "rapidjson/document.h"
|
||||
#include "atidlas/backend/parse.h"
|
||||
@@ -39,8 +40,8 @@ void model::fill_program_name(char* program_name, expressions_tuple const & expr
|
||||
binder = new bind_to_handle();
|
||||
else
|
||||
binder = new bind_all_unique();
|
||||
for (expressions_tuple::data_type::const_iterator it = expressions.data().begin(); it != expressions.data().end(); ++it)
|
||||
traverse(**it, (*it)->root(), array_expression_representation_functor(*binder, program_name),true);
|
||||
for (const auto & elem : expressions.data())
|
||||
traverse(*elem, elem->root(), array_expression_representation_functor(*binder, program_name),true);
|
||||
*program_name='\0';
|
||||
delete binder;
|
||||
}
|
||||
@@ -80,11 +81,11 @@ std::vector<cl_ext::lazy_compiler>& model::init(expressions_tuple const & expres
|
||||
return to_init;
|
||||
}
|
||||
|
||||
model::model(predictors::random_forest const & predictor, std::vector< tools::shared_ptr<base> > const & templates, cl::CommandQueue & queue) :
|
||||
model::model(predictors::random_forest const & predictor, std::vector< std::shared_ptr<base> > const & templates, cl::CommandQueue & queue) :
|
||||
templates_(templates), predictor_(new predictors::random_forest(predictor)), queue_(queue)
|
||||
{}
|
||||
|
||||
model::model(std::vector< tools::shared_ptr<base> > const & templates, cl::CommandQueue & queue) : templates_(templates), queue_(queue)
|
||||
model::model(std::vector< std::shared_ptr<base> > const & templates, cl::CommandQueue & queue) : templates_(templates), queue_(queue)
|
||||
{}
|
||||
|
||||
model::model(base const & tp, cl::CommandQueue & queue) : templates_(1,tp.clone()), queue_(queue)
|
||||
@@ -166,27 +167,27 @@ namespace detail
|
||||
throw std::invalid_argument("Invalid datatype: " + name);
|
||||
}
|
||||
|
||||
static tools::shared_ptr<base> create(std::string const & template_name, std::vector<int> const & a)
|
||||
static std::shared_ptr<base> create(std::string const & template_name, std::vector<int> const & a)
|
||||
{
|
||||
fetching_policy_type fetch[] = {FETCH_FROM_LOCAL, FETCH_FROM_GLOBAL_STRIDED, FETCH_FROM_GLOBAL_CONTIGUOUS};
|
||||
if(template_name=="vaxpy")
|
||||
return tools::shared_ptr<base>(new vaxpy(a[0], a[1], a[2], fetch[a[3]]));
|
||||
return std::shared_ptr<base>(new vaxpy(a[0], a[1], a[2], fetch[a[3]]));
|
||||
else if(template_name=="dot")
|
||||
return tools::shared_ptr<base>(new reduction(a[0], a[1], a[2], fetch[a[3]]));
|
||||
return std::shared_ptr<base>(new reduction(a[0], a[1], a[2], fetch[a[3]]));
|
||||
else if(template_name=="maxpy")
|
||||
return tools::shared_ptr<base>(new maxpy(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
|
||||
return std::shared_ptr<base>(new maxpy(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
|
||||
else if(template_name.find("gemvN")!=std::string::npos)
|
||||
return tools::shared_ptr<base>(new mreduction_rows(a[0], a[1], a[2], a[3], fetch[a[4]]));
|
||||
return std::shared_ptr<base>(new mreduction_rows(a[0], a[1], a[2], a[3], fetch[a[4]]));
|
||||
else if(template_name.find("gemvT")!=std::string::npos)
|
||||
return tools::shared_ptr<base>(new mreduction_cols(a[0], a[1], a[2], a[3], fetch[a[4]]));
|
||||
return std::shared_ptr<base>(new mreduction_cols(a[0], a[1], a[2], a[3], fetch[a[4]]));
|
||||
else if(template_name.find("gemmNN")!=std::string::npos)
|
||||
return tools::shared_ptr<base>(new mproduct_nn(a[0], a[1], a[2], a[3], a[4], a[5], a[6], fetch[a[7]], fetch[a[8]], a[9], a[10]));
|
||||
return std::shared_ptr<base>(new mproduct_nn(a[0], a[1], a[2], a[3], a[4], a[5], a[6], fetch[a[7]], fetch[a[8]], a[9], a[10]));
|
||||
else if(template_name.find("gemmTN")!=std::string::npos)
|
||||
return tools::shared_ptr<base>(new mproduct_tn(a[0], a[1], a[2], a[3], a[4], a[5], a[6], fetch[a[7]], fetch[a[8]], a[9], a[10]));
|
||||
return std::shared_ptr<base>(new mproduct_tn(a[0], a[1], a[2], a[3], a[4], a[5], a[6], fetch[a[7]], fetch[a[8]], a[9], a[10]));
|
||||
else if(template_name.find("gemmNT")!=std::string::npos)
|
||||
return tools::shared_ptr<base>(new mproduct_nt(a[0], a[1], a[2], a[3], a[4], a[5], a[6], fetch[a[7]], fetch[a[8]], a[9], a[10]));
|
||||
return std::shared_ptr<base>(new mproduct_nt(a[0], a[1], a[2], a[3], a[4], a[5], a[6], fetch[a[7]], fetch[a[8]], a[9], a[10]));
|
||||
else if(template_name.find("gemmTT")!=std::string::npos)
|
||||
return tools::shared_ptr<base>(new mproduct_tt(a[0], a[1], a[2], a[3], a[4], a[5], a[6], fetch[a[7]], fetch[a[8]], a[9], a[10]));
|
||||
return std::shared_ptr<base>(new mproduct_tt(a[0], a[1], a[2], a[3], a[4], a[5], a[6], fetch[a[7]], fetch[a[8]], a[9], a[10]));
|
||||
else
|
||||
throw std::invalid_argument("Invalid expression: " + template_name);
|
||||
}
|
||||
@@ -208,32 +209,32 @@ void import(std::string const & fname, cl::CommandQueue & queue, model_map_t& re
|
||||
//Deserialize
|
||||
std::vector<std::string> operations = tools::make_vector<std::string>() << "vaxpy" << "dot" << "maxpy" << "gemvN" << "gemvT" << "gemmNN" << "gemmTN" << "gemmTT";
|
||||
std::vector<std::string> dtype = tools::make_vector<std::string>() << "float32" << "float64";
|
||||
for(std::vector<std::string>::iterator op = operations.begin() ; op != operations.end() ; ++op)
|
||||
for(auto & operation : operations)
|
||||
{
|
||||
const char * opcstr = op->c_str();
|
||||
const char * opcstr = operation.c_str();
|
||||
if(document.HasMember(opcstr))
|
||||
{
|
||||
expression_type etype = detail::get_expression_type(*op);
|
||||
for(std::vector<std::string>::iterator dt = dtype.begin() ; dt != dtype.end() ; ++dt)
|
||||
expression_type etype = detail::get_expression_type(operation);
|
||||
for(auto & elem : dtype)
|
||||
{
|
||||
const char * dtcstr = dt->c_str();
|
||||
const char * dtcstr = elem.c_str();
|
||||
if(document[opcstr].HasMember(dtcstr))
|
||||
{
|
||||
numeric_type dtype = detail::get_dtype(*dt);
|
||||
numeric_type dtype = detail::get_dtype(elem);
|
||||
|
||||
// Get profiles
|
||||
std::vector<tools::shared_ptr<base> > templates;
|
||||
std::vector<std::shared_ptr<base> > templates;
|
||||
js::Value const & profiles = document[opcstr][dtcstr]["profiles"];
|
||||
for (js::SizeType id = 0 ; id < profiles.Size() ; ++id)
|
||||
templates.push_back(detail::create(*op, tools::to_int_array<int>(profiles[id])));
|
||||
templates.push_back(detail::create(operation, tools::to_int_array<int>(profiles[id])));
|
||||
if(templates.size()>1)
|
||||
{
|
||||
// Get predictor
|
||||
predictors::random_forest predictor(document[opcstr][dtcstr]["predictor"]);
|
||||
result[std::make_pair(etype, dtype)] = tools::shared_ptr<model>(new model(predictor, templates, queue));
|
||||
result[std::make_pair(etype, dtype)] = std::shared_ptr<model>(new model(predictor, templates, queue));
|
||||
}
|
||||
else
|
||||
result[std::make_pair(etype, dtype)] = tools::shared_ptr<model>(new model(templates, queue));
|
||||
result[std::make_pair(etype, dtype)] = std::shared_ptr<model>(new model(templates, queue));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -243,11 +244,11 @@ void import(std::string const & fname, cl::CommandQueue & queue, model_map_t& re
|
||||
model_map_t init_models(cl::CommandQueue & queue)
|
||||
{
|
||||
model_map_t res;
|
||||
typedef tools::shared_ptr<model> ptr_t;
|
||||
typedef std::shared_ptr<model> ptr_t;
|
||||
numeric_type types[] = {CHAR_TYPE, UCHAR_TYPE, SHORT_TYPE, USHORT_TYPE, INT_TYPE, UINT_TYPE, LONG_TYPE, ULONG_TYPE, FLOAT_TYPE, DOUBLE_TYPE};
|
||||
|
||||
for(size_t i = 0 ; i < 10 ; ++i){
|
||||
numeric_type DTYPE = types[i];
|
||||
for(auto DTYPE : types){
|
||||
|
||||
res[std::make_pair(SCALAR_AXPY_TYPE, DTYPE)] = ptr_t(new model(vaxpy(1,64,128,FETCH_FROM_GLOBAL_STRIDED), queue));
|
||||
res[std::make_pair(VECTOR_AXPY_TYPE, DTYPE)] = ptr_t (new model(vaxpy(1,64,128,FETCH_FROM_GLOBAL_STRIDED), queue));
|
||||
res[std::make_pair(REDUCTION_TYPE, DTYPE)] = ptr_t(new model(reduction(1,64,128,FETCH_FROM_GLOBAL_STRIDED), queue));
|
||||
|
Reference in New Issue
Block a user