2015-01-12 13:20:53 -05:00
|
|
|
#include <set>
|
|
|
|
#include <fstream>
|
2015-01-28 17:08:39 -05:00
|
|
|
#include <stdexcept>
|
|
|
|
|
2015-01-12 13:20:53 -05:00
|
|
|
#include "rapidjson/document.h"
|
|
|
|
#include "atidlas/backend/parse.h"
|
|
|
|
#include "atidlas/backend/templates/vaxpy.h"
|
|
|
|
#include "atidlas/backend/templates/reduction.h"
|
|
|
|
#include "atidlas/backend/templates/maxpy.h"
|
|
|
|
#include "atidlas/backend/templates/mreduction.h"
|
|
|
|
#include "atidlas/backend/templates/mproduct.h"
|
2015-01-28 17:08:39 -05:00
|
|
|
#include "atidlas/exception/unknown_datatype.h"
|
2015-01-12 13:20:53 -05:00
|
|
|
#include "atidlas/exception/operation_not_supported.h"
|
|
|
|
#include "atidlas/model/model.h"
|
|
|
|
#include "atidlas/tools/make_vector.hpp"
|
|
|
|
#include "atidlas/tools/timer.hpp"
|
|
|
|
#include "convert.hpp"
|
|
|
|
|
|
|
|
|
|
|
|
namespace atidlas
|
|
|
|
{
|
|
|
|
|
|
|
|
|
|
|
|
// Returns the OpenCL pragma that enables `ext` when `ext` occurs (as a
// substring) in `extensions`; returns an empty string otherwise.
std::string model::define_extension(std::string const & extensions, std::string const & ext)
{
  bool const advertised = extensions.find(ext) != std::string::npos;
  if(!advertised)
    return std::string("");

  std::string pragma("#pragma OPENCL EXTENSION ");
  pragma += ext;
  pragma += " : enable\n";
  return pragma;
}
|
|
|
|
|
2015-02-01 22:28:49 -05:00
|
|
|
void model::fill_program_name(char* program_name, expressions_tuple const & expressions, binding_policy_t binding_policy)
|
2015-01-12 13:20:53 -05:00
|
|
|
{
|
2015-02-01 22:28:49 -05:00
|
|
|
if (expressions.order()==expressions_tuple::INDEPENDENT)
|
2015-01-12 13:20:53 -05:00
|
|
|
*program_name++='i';
|
|
|
|
else
|
|
|
|
*program_name++='s';
|
|
|
|
symbolic_binder* binder = NULL;
|
|
|
|
if(binding_policy==BIND_TO_HANDLE)
|
|
|
|
binder = new bind_to_handle();
|
|
|
|
else
|
|
|
|
binder = new bind_all_unique();
|
2015-02-01 22:28:49 -05:00
|
|
|
for (expressions_tuple::data_type::const_iterator it = expressions.data().begin(); it != expressions.data().end(); ++it)
|
2015-01-31 22:01:48 -05:00
|
|
|
traverse(**it, (*it)->root(), array_expression_representation_functor(*binder, program_name),true);
|
2015-01-12 13:20:53 -05:00
|
|
|
*program_name='\0';
|
|
|
|
delete binder;
|
|
|
|
}
|
|
|
|
|
2015-02-01 22:28:49 -05:00
|
|
|
// Returns the lazy compilers cached for (context of `expressions`, program
// name), creating and filling them on first use.
//
// The program name is opt.program_name when it is non-empty; otherwise it is
// generated with fill_program_name(..., BIND_TO_HANDLE). On first use two
// compilers are created - one named `pname`, one named `pname + "_fb"` - each
// seeded with the cl_khr_fp64 pragma (if the device lists the extension), and
// the j-th source returned by every template's generate() is appended to the
// j-th compiler.
std::vector<cl_ext::lazy_compiler>& model::init(expressions_tuple const & expressions, runtime_options const & opt)
{
  cl::Context const & context = expressions.context();

  std::string pname;
  if(!opt.program_name.empty())
    pname = opt.program_name;
  else
  {
    char program_name[256];
    fill_program_name(program_name, expressions, BIND_TO_HANDLE);
    pname = std::string(program_name);
  }

  std::vector<cl_ext::lazy_compiler> & to_init = lazy_programs_[context()][pname];
  if(!to_init.empty())
    return to_init;

  cl::Device device = queue_.getInfo<CL_QUEUE_DEVICE>();
  std::string extensions = device.getInfo<CL_DEVICE_EXTENSIONS>();
  std::string const fp64_pragma = define_extension(extensions, "cl_khr_fp64");

  // Main program
  to_init.push_back(cl_ext::lazy_compiler(context, pname, opt.recompile));
  to_init.back().add(fp64_pragma);

  // "_fb" program
  to_init.push_back(cl_ext::lazy_compiler(context, pname + "_fb", opt.recompile));
  to_init.back().add(fp64_pragma);

  for(size_t t = 0 ; t < templates_.size() ; ++t)
  {
    std::vector<std::string> sources = templates_[t]->generate(t, expressions, device);
    size_t slot = 0;
    for(std::vector<std::string>::const_iterator src = sources.begin() ; src != sources.end() ; ++src, ++slot)
      to_init[slot].add(*src);
  }

  return to_init;
}
|
|
|
|
|
2015-01-17 10:48:02 -05:00
|
|
|
// Builds a model over `templates` that owns a heap-allocated copy of
// `predictor` and keeps `queue` for later use.
model::model(predictors::random_forest const & predictor,
             std::vector< tools::shared_ptr<base> > const & templates,
             cl::CommandQueue & queue)
  : templates_(templates)
  , predictor_(new predictors::random_forest(predictor))
  , queue_(queue)
{
}
|
|
|
|
|
2015-01-17 10:48:02 -05:00
|
|
|
// Builds a model over `templates` without supplying a predictor.
model::model(std::vector< tools::shared_ptr<base> > const & templates,
             cl::CommandQueue & queue)
  : templates_(templates)
  , queue_(queue)
{
}
|
|
|
|
|
2015-01-17 10:48:02 -05:00
|
|
|
// Builds a model holding a single template: a clone of `tp`.
model::model(base const & tp, cl::CommandQueue & queue)
  : templates_(1, tp.clone())
  , queue_(queue)
{
}
|
|
|
|
|
2015-02-01 22:28:49 -05:00
|
|
|
// Selects one template for `expressions` and enqueues it on the model's queue.
//
// Selection order:
//   1. opt.label, when it is non-negative;
//   2. the label stored in hardcoded_ for this input size (filled by tune());
//   3. the index of the smallest value returned by the predictor, if any;
//   4. template 0 otherwise.
//
// Throws std::out_of_range when the model holds no template or when the
// selected label does not index an existing template (previously this indexed
// templates_ out of bounds).
void model::execute(expressions_tuple const & expressions, runtime_options const & opt)
{
  if(templates_.size()==0)
    throw std::out_of_range("atidlas::model::execute: the model holds no template");

  std::vector<cl_ext::lazy_compiler> & compilers = init(expressions, opt);

  //Prediction
  int label = 0;
  if(opt.label>=0)
  {
    label = opt.label;
  }
  else
  {
    std::vector<int_t> x = templates_[0]->input_sizes(expressions);
    //The user tuned the model specifically for this input size
    if(hardcoded_.find(x)!=hardcoded_.end())
      label = hardcoded_.at(x);
    //Otherwise ask the predictor, when the model has one
    else if(predictor_.get())
    {
      std::vector<float> predictions = predictor_->predict(x);
      label = std::distance(predictions.begin(),std::min_element(predictions.begin(), predictions.end()));
    }
  }

  //Validate before indexing: the label may come straight from the caller
  if(label < 0 || static_cast<size_t>(label) >= templates_.size())
    throw std::out_of_range("atidlas::model::execute: template label out of range");

  //Execution
  templates_[label]->enqueue(queue_, compilers, label, expressions);
}
|
|
|
|
|
2015-02-01 22:28:49 -05:00
|
|
|
// Times every template on `expressions` and records the index of the one with
// the smallest measured time in hardcoded_, keyed by the input sizes reported
// by the first template.
void model::tune(expressions_tuple const & expressions)
{
  std::vector<cl_ext::lazy_compiler> & compilers = init(expressions);

  //Collect the timings: one enqueue + finish per template
  size_t const count = templates_.size();
  std::vector<float> timings(count);
  tools::timer timer;
  for(size_t idx = 0 ; idx != count ; ++idx)
  {
    timer.start();
    templates_[idx]->enqueue(queue_, compilers, idx, expressions);
    queue_.finish();
    timings[idx] = timer.get();
  }

  //Fill the override
  std::vector<int_t> x = templates_[0]->input_sizes(expressions);
  std::vector<float>::iterator const fastest = std::min_element(timings.begin(), timings.end());
  hardcoded_[x] = std::distance(timings.begin(), fastest);
}
|
|
|
|
|
|
|
|
// Read-only access to the templates held by this model.
model::templates_container const & model::templates() const
{
  return templates_;
}
|
|
|
|
|
|
|
|
///////////////////
|
|
|
|
|
|
|
|
namespace detail
|
|
|
|
{
|
|
|
|
// Maps an operation name to its expression_type.
// Throws std::invalid_argument for any name not listed in the table.
static expression_type get_expression_type(std::string const & name)
{
  struct entry
  {
    const char * key;
    expression_type type;
  };
  static const entry table[] =
  {
    {"vaxpy",  VECTOR_AXPY_TYPE},
    {"dot",    REDUCTION_TYPE},
    {"maxpy",  MATRIX_AXPY_TYPE},
    {"gemvN",  ROW_WISE_REDUCTION_TYPE},
    {"gemvT",  COL_WISE_REDUCTION_TYPE},
    {"gemmNN", MATRIX_PRODUCT_NN_TYPE},
    {"gemmNT", MATRIX_PRODUCT_NT_TYPE},
    {"gemmTN", MATRIX_PRODUCT_TN_TYPE},
    {"gemmTT", MATRIX_PRODUCT_TT_TYPE}
  };

  for(size_t i = 0 ; i < sizeof(table)/sizeof(table[0]) ; ++i)
    if(name == table[i].key)
      return table[i].type;

  throw std::invalid_argument("Invalid expression: " + name);
}
|
|
|
|
|
|
|
|
// Maps a datatype name ("float32" or "float64") to its numeric_type.
// Throws std::invalid_argument for any other name.
static numeric_type get_dtype(std::string const & name)
{
  if(name.compare("float32")==0)
    return FLOAT_TYPE;
  else if(name.compare("float64")==0)
    return DOUBLE_TYPE;
  else
    throw std::invalid_argument("Invalid datatype: " + name);
}
|
|
|
|
|
2015-01-17 10:48:02 -05:00
|
|
|
static tools::shared_ptr<base> create(std::string const & template_name, std::vector<int> const & a)
|
2015-01-12 13:20:53 -05:00
|
|
|
{
|
|
|
|
fetching_policy_type fetch[] = {FETCH_FROM_LOCAL, FETCH_FROM_GLOBAL_STRIDED, FETCH_FROM_GLOBAL_CONTIGUOUS};
|
2015-01-25 01:08:18 -05:00
|
|
|
if(template_name=="vaxpy")
|
2015-01-27 02:41:27 -05:00
|
|
|
return tools::shared_ptr<base>(new vaxpy(a[0], a[1], a[2], fetch[a[3]]));
|
|
|
|
else if(template_name=="dot")
|
|
|
|
return tools::shared_ptr<base>(new reduction(a[0], a[1], a[2], fetch[a[3]]));
|
2015-01-25 18:19:19 -05:00
|
|
|
else if(template_name=="maxpy")
|
2015-01-27 02:41:27 -05:00
|
|
|
return tools::shared_ptr<base>(new maxpy(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
|
|
|
|
else if(template_name.find("gemvN")!=std::string::npos)
|
|
|
|
return tools::shared_ptr<base>(new mreduction_rows(a[0], a[1], a[2], a[3], fetch[a[4]]));
|
|
|
|
else if(template_name.find("gemvT")!=std::string::npos)
|
|
|
|
return tools::shared_ptr<base>(new mreduction_cols(a[0], a[1], a[2], a[3], fetch[a[4]]));
|
|
|
|
else if(template_name.find("gemmNN")!=std::string::npos)
|
|
|
|
return tools::shared_ptr<base>(new mproduct_nn(a[0], a[1], a[2], a[3], a[4], a[5], a[6], fetch[a[7]], fetch[a[8]], a[9], a[10]));
|
|
|
|
else if(template_name.find("gemmTN")!=std::string::npos)
|
|
|
|
return tools::shared_ptr<base>(new mproduct_tn(a[0], a[1], a[2], a[3], a[4], a[5], a[6], fetch[a[7]], fetch[a[8]], a[9], a[10]));
|
|
|
|
else if(template_name.find("gemmNT")!=std::string::npos)
|
|
|
|
return tools::shared_ptr<base>(new mproduct_nt(a[0], a[1], a[2], a[3], a[4], a[5], a[6], fetch[a[7]], fetch[a[8]], a[9], a[10]));
|
|
|
|
else if(template_name.find("gemmTT")!=std::string::npos)
|
|
|
|
return tools::shared_ptr<base>(new mproduct_tt(a[0], a[1], a[2], a[3], a[4], a[5], a[6], fetch[a[7]], fetch[a[8]], a[9], a[10]));
|
2015-01-12 13:20:53 -05:00
|
|
|
else
|
2015-01-28 17:08:39 -05:00
|
|
|
throw std::invalid_argument("Invalid expression: " + template_name);
|
2015-01-12 13:20:53 -05:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2015-01-25 18:19:19 -05:00
|
|
|
void import(std::string const & fname, cl::CommandQueue & queue, model_map_t& result)
|
2015-01-12 13:20:53 -05:00
|
|
|
{
|
|
|
|
namespace js = rapidjson;
|
|
|
|
//Parse the JSON document
|
|
|
|
js::Document document;
|
|
|
|
std::ifstream t(fname.c_str());
|
2015-01-28 22:07:09 -05:00
|
|
|
if(!t) return;
|
2015-01-12 13:20:53 -05:00
|
|
|
std::string str;
|
|
|
|
t.seekg(0, std::ios::end);
|
|
|
|
str.reserve(t.tellg());
|
|
|
|
t.seekg(0, std::ios::beg);
|
|
|
|
str.assign((std::istreambuf_iterator<char>(t)), std::istreambuf_iterator<char>());
|
|
|
|
document.Parse<0>(str.c_str());
|
|
|
|
//Deserialize
|
2015-01-31 22:01:48 -05:00
|
|
|
std::vector<std::string> operations = tools::make_vector<std::string>() << "vaxpy" << "dot" << "maxpy" << "gemvN" << "gemvT" << "gemmNN" << "gemmTN" << "gemmTT";
|
2015-01-12 13:20:53 -05:00
|
|
|
std::vector<std::string> dtype = tools::make_vector<std::string>() << "float32" << "float64";
|
|
|
|
for(std::vector<std::string>::iterator op = operations.begin() ; op != operations.end() ; ++op)
|
|
|
|
{
|
|
|
|
const char * opcstr = op->c_str();
|
|
|
|
if(document.HasMember(opcstr))
|
|
|
|
{
|
|
|
|
expression_type etype = detail::get_expression_type(*op);
|
|
|
|
for(std::vector<std::string>::iterator dt = dtype.begin() ; dt != dtype.end() ; ++dt)
|
|
|
|
{
|
|
|
|
const char * dtcstr = dt->c_str();
|
|
|
|
if(document[opcstr].HasMember(dtcstr))
|
|
|
|
{
|
|
|
|
numeric_type dtype = detail::get_dtype(*dt);
|
|
|
|
|
|
|
|
// Get profiles
|
2015-01-17 10:48:02 -05:00
|
|
|
std::vector<tools::shared_ptr<base> > templates;
|
2015-01-12 13:20:53 -05:00
|
|
|
js::Value const & profiles = document[opcstr][dtcstr]["profiles"];
|
|
|
|
for (js::SizeType id = 0 ; id < profiles.Size() ; ++id)
|
|
|
|
templates.push_back(detail::create(*op, tools::to_int_array<int>(profiles[id])));
|
2015-01-25 18:19:19 -05:00
|
|
|
if(templates.size()>1)
|
|
|
|
{
|
2015-01-25 01:08:18 -05:00
|
|
|
// Get predictor
|
|
|
|
predictors::random_forest predictor(document[opcstr][dtcstr]["predictor"]);
|
|
|
|
result[std::make_pair(etype, dtype)] = tools::shared_ptr<model>(new model(predictor, templates, queue));
|
|
|
|
}
|
2015-01-25 18:19:19 -05:00
|
|
|
else
|
|
|
|
result[std::make_pair(etype, dtype)] = tools::shared_ptr<model>(new model(templates, queue));
|
2015-01-12 13:20:53 -05:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Builds the model map for `queue`: one single-template model with fixed
// parameters per (expression type, numeric type) pair, for each of the ten
// numeric types listed below. When the HOME environment variable is set, the
// profiles in $HOME/.atidlas/devices/device0.json are then imported on top of
// those defaults.
model_map_t init_models(cl::CommandQueue & queue)
{
  typedef tools::shared_ptr<model> ptr_t;

  numeric_type types[] = {CHAR_TYPE, UCHAR_TYPE, SHORT_TYPE, USHORT_TYPE, INT_TYPE, UINT_TYPE, LONG_TYPE, ULONG_TYPE, FLOAT_TYPE, DOUBLE_TYPE};
  size_t const ntypes = sizeof(types)/sizeof(types[0]);

  model_map_t res;
  for(numeric_type * cur = types ; cur != types + ntypes ; ++cur)
  {
    numeric_type DTYPE = *cur;

    // Vector-like operations
    res[std::make_pair(SCALAR_AXPY_TYPE, DTYPE)] = ptr_t(new model(vaxpy(1,64,128,FETCH_FROM_GLOBAL_STRIDED), queue));
    res[std::make_pair(VECTOR_AXPY_TYPE, DTYPE)] = ptr_t (new model(vaxpy(1,64,128,FETCH_FROM_GLOBAL_STRIDED), queue));
    res[std::make_pair(REDUCTION_TYPE, DTYPE)] = ptr_t(new model(reduction(1,64,128,FETCH_FROM_GLOBAL_STRIDED), queue));

    // Matrix element-wise and matrix-vector operations
    res[std::make_pair(MATRIX_AXPY_TYPE, DTYPE)] = ptr_t(new model(maxpy(1,8,8,8,8,FETCH_FROM_GLOBAL_STRIDED), queue));
    res[std::make_pair(ROW_WISE_REDUCTION_TYPE, DTYPE)] = ptr_t(new model(mreduction_rows(1, 8, 8, 16, FETCH_FROM_GLOBAL_STRIDED), queue));
    res[std::make_pair(COL_WISE_REDUCTION_TYPE, DTYPE)] = ptr_t(new model(mreduction_cols(1, 8, 8, 16, FETCH_FROM_GLOBAL_STRIDED), queue));

    // Matrix products
    res[std::make_pair(MATRIX_PRODUCT_NN_TYPE, DTYPE)] = ptr_t(new model(mproduct_nn(1, 8, 8, 8, 4, 1, 4, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8), queue));
    res[std::make_pair(MATRIX_PRODUCT_TN_TYPE, DTYPE)] = ptr_t(new model(mproduct_tn(1, 8, 8, 8, 4, 1, 4, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8), queue));
    res[std::make_pair(MATRIX_PRODUCT_NT_TYPE, DTYPE)] = ptr_t(new model(mproduct_nt(1, 8, 8, 8, 4, 1, 4, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8), queue));
    res[std::make_pair(MATRIX_PRODUCT_TT_TYPE, DTYPE)] = ptr_t(new model(mproduct_tt(1, 8, 8, 8, 4, 1, 4, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8), queue));
  }

  const char * homepath = std::getenv("HOME");
  if(homepath)
    import(std::string(homepath) + "/.atidlas/devices/device0.json", queue, res);

  return res;
}
|
|
|
|
|
|
|
|
// Returns the model map registered for `queue` in the global `models` table,
// building it with init_models() and inserting it on the first request.
model_map_t& get_model_map(cl::CommandQueue & queue)
{
  typedef std::map<cl::CommandQueue, model_map_t, cl_ext::compare> queue_models_t;

  queue_models_t::iterator found = models.find(queue);
  if(found != models.end())
    return found->second;

  std::pair<queue_models_t::iterator, bool> inserted = models.insert(std::make_pair(queue, init_models(queue)));
  return inserted.first->second;
}
|
|
|
|
|
|
|
|
// Returns the model registered for (`expression`, `dtype`) on `queue`.
// The lookup uses at(), so a missing key is reported by that call rather than
// by inserting a new entry.
model& get_model(cl::CommandQueue & queue, expression_type expression, numeric_type dtype)
{
  model_map_t & per_queue = get_model_map(queue);
  std::pair<expression_type, numeric_type> const key(expression, dtype);
  return *per_queue.at(key);
}
|
|
|
|
|
2015-01-27 16:14:02 -05:00
|
|
|
// Global table of model maps, one per command queue (ordered by cl_ext::compare).
std::map<cl::CommandQueue, model_map_t, cl_ext::compare> models;
|
2015-01-12 13:20:53 -05:00
|
|
|
|
|
|
|
}
|