2015-01-12 13:20:53 -05:00
|
|
|
#include <set>
|
|
|
|
#include <fstream>
|
2015-01-28 17:08:39 -05:00
|
|
|
#include <stdexcept>
|
2015-02-04 22:06:15 -05:00
|
|
|
#include <algorithm>
|
2015-02-09 01:58:32 -05:00
|
|
|
#include <numeric>
|
2015-04-29 15:50:57 -04:00
|
|
|
#include <memory>
|
2015-01-28 17:08:39 -05:00
|
|
|
|
2015-01-12 13:20:53 -05:00
|
|
|
#include "rapidjson/document.h"
|
2015-08-04 20:56:05 -07:00
|
|
|
#include "isaac/kernels/parse.h"
|
|
|
|
#include "isaac/kernels/templates/axpy.h"
|
|
|
|
#include "isaac/kernels/templates/dot.h"
|
|
|
|
#include "isaac/kernels/templates/ger.h"
|
|
|
|
#include "isaac/kernels/templates/gemv.h"
|
|
|
|
#include "isaac/kernels/templates/gemm.h"
|
2015-08-04 11:11:38 -07:00
|
|
|
#include "isaac/driver/program_cache.h"
|
2015-04-29 15:50:57 -04:00
|
|
|
#include "isaac/exception/unknown_datatype.h"
|
|
|
|
#include "isaac/exception/operation_not_supported.h"
|
|
|
|
#include "isaac/model/model.h"
|
|
|
|
#include "isaac/tools/make_vector.hpp"
|
|
|
|
#include "isaac/tools/timer.hpp"
|
2015-08-05 11:13:49 -07:00
|
|
|
#include "isaac/tools/getenv.hpp"
|
|
|
|
|
2015-01-12 13:20:53 -05:00
|
|
|
#include "convert.hpp"
|
|
|
|
|
|
|
|
|
2015-04-29 15:50:57 -04:00
|
|
|
namespace isaac
|
2015-01-12 13:20:53 -05:00
|
|
|
{
|
|
|
|
|
2015-08-05 12:47:20 -07:00
|
|
|
static long time_event(long sum, driver::Event const & e)
|
2015-08-04 11:11:38 -07:00
|
|
|
{
|
|
|
|
return sum + e.elapsed_time();
|
|
|
|
}
|
2015-02-09 01:58:32 -05:00
|
|
|
|
2015-02-01 22:28:49 -05:00
|
|
|
void model::fill_program_name(char* program_name, expressions_tuple const & expressions, binding_policy_t binding_policy)
|
2015-01-12 13:20:53 -05:00
|
|
|
{
|
2015-02-01 22:28:49 -05:00
|
|
|
if (expressions.order()==expressions_tuple::INDEPENDENT)
|
2015-01-12 13:20:53 -05:00
|
|
|
*program_name++='i';
|
|
|
|
else
|
|
|
|
*program_name++='s';
|
|
|
|
symbolic_binder* binder = NULL;
|
|
|
|
if(binding_policy==BIND_TO_HANDLE)
|
|
|
|
binder = new bind_to_handle();
|
|
|
|
else
|
|
|
|
binder = new bind_all_unique();
|
2015-02-04 22:06:15 -05:00
|
|
|
for (const auto & elem : expressions.data())
|
|
|
|
traverse(*elem, elem->root(), array_expression_representation_functor(*binder, program_name),true);
|
2015-01-12 13:20:53 -05:00
|
|
|
*program_name='\0';
|
|
|
|
delete binder;
|
|
|
|
}
|
|
|
|
|
2015-07-30 14:35:41 -07:00
|
|
|
/**
 * Returns the compiled program for `expressions`, building it on a cache miss.
 *
 * The program name is taken from the compilation options when one is given,
 * otherwise it is generated with fill_program_name(). On a cache miss the
 * source of every template (named by its index) plus the fallback template
 * (named "fallback") is concatenated and added to the cache.
 */
driver::Program const & model::init(controller<expressions_tuple> const & expressions)
{
  driver::Context & context = (driver::Context&)expressions.x().context();
  compilation_options_type const & opt = expressions.compilation_options();

  // Resolve the cache key
  std::string pname = opt.program_name;
  if(pname.empty())
  {
    // NOTE(review): fixed-size buffer; fill_program_name() does not check bounds.
    char generated[256];
    fill_program_name(generated, expressions.x(), BIND_TO_HANDLE);
    pname = generated;
  }

  // Cache hit: nothing to generate
  if(driver::Program const * cached = cache_.find(pname))
    return *cached;

  // Cache miss: generate the sources of all the templates, then the fallback
  std::string srcs;
  unsigned int idx = 0;
  while(idx < templates_.size())
  {
    srcs += templates_[idx]->generate(std::to_string(idx), expressions.x(), context.device());
    ++idx;
  }
  srcs += fallback_->generate("fallback", expressions.x(), context.device());

  return cache_.add(context, pname, srcs);
}
|
|
|
|
|
2015-07-28 15:26:10 -07:00
|
|
|
/**
 * Builds a model that chooses among several templates using a predictor.
 *
 * The fallback template is looked up in the global `fallbacks` table for
 * (etype, dtype), the predictor is copied, and the program cache obtained for
 * (queue, etype, dtype) is cleared on construction.
 */
model::model(expression_type etype,
             numeric_type dtype,
             predictors::random_forest const & predictor,
             std::vector< std::shared_ptr<templates::base> > const & templates,
             driver::CommandQueue const & queue)
  : templates_(templates),
    fallback_(fallbacks[std::make_pair(etype, dtype)]),
    predictor_(new predictors::random_forest(predictor)),
    queue_(queue),
    cache_(driver::backend::programs::get(queue,etype,dtype))
{
  cache_.clear();
}
|
2015-01-12 13:20:53 -05:00
|
|
|
|
|
|
|
|
2015-08-04 10:53:39 -07:00
|
|
|
/**
 * Builds a model backed by a single template (a clone of `tp`) and no predictor.
 *
 * The fallback template is looked up in the global `fallbacks` table for
 * (etype, dtype), and the program cache obtained for (queue, etype, dtype)
 * is cleared on construction.
 */
model::model(expression_type etype,
             numeric_type dtype,
             templates::base const & tp,
             driver::CommandQueue const & queue)
  : templates_(1,tp.clone()),
    fallback_(fallbacks[std::make_pair(etype, dtype)]),
    queue_(queue),
    cache_(driver::backend::programs::get(queue,etype,dtype))
{
  cache_.clear();
}
|
2015-01-12 13:20:53 -05:00
|
|
|
|
2015-02-09 01:58:32 -05:00
|
|
|
/**
 * Selects a template for `expr` and enqueues it on the model's queue.
 *
 * Selection order:
 *   1. the label forced through the dispatcher options, when it is >= 0;
 *   2. a label previously recorded in `hardcoded_` for these input sizes
 *      (filled by the tuning pass below when the `tune` option is set);
 *   3. the arg-max of the predictor's output, when a predictor is present;
 *   4. template 0 otherwise.
 *
 * @throws std::out_of_range if the selected label does not index a template.
 */
void model::execute(controller<expressions_tuple> const & expr)
{
  driver::Program const & program = init(expr);
  std::vector<int_t> x = templates_[0]->input_sizes(expr.x());

  //Specific tuning if requested
  if(expr.dispatcher_options().tune && hardcoded_.find(x)==hardcoded_.end())
  {
    std::vector<double> timings(templates_.size());
    for(unsigned int i = 0 ; i < templates_.size() ; ++i)
    {
      std::list<driver::Event> events;
      try{
        templates_[i]->enqueue(queue_, program, std::to_string(i), *fallback_, control(expr.x(), execution_options_type(0, &events)));
        queue_.synchronize();
        // The initial value fixes the accumulator type: it must be a long so
        // that the sum returned by time_event() is not narrowed to int.
        timings[i] = 1e-9*std::accumulate(events.begin(), events.end(), 0L, &time_event);
      }catch(...){
        // A template that cannot run for these inputs must never be selected.
        timings[i] = INFINITY;
      }
    }
    //Fill the override
    hardcoded_[x] = std::distance(timings.begin(),std::min_element(timings.begin(), timings.end()));
  }

  //Prediction
  int label = 0;
  if(expr.dispatcher_options().label>=0)
    label = expr.dispatcher_options().label;
  else if(hardcoded_.find(x)!=hardcoded_.end())
    label = hardcoded_.at(x);
  else if(predictor_.get())
  {
    std::vector<float> predictions = predictor_->predict(x);
    label = std::distance(predictions.begin(),std::max_element(predictions.begin(), predictions.end()));
  }

  // Reject labels that do not index a template instead of reading out of bounds.
  if(label < 0 || static_cast<std::size_t>(label) >= templates_.size())
    throw std::out_of_range("isaac::model::execute: template label " + std::to_string(label) + " is out of range");

  //Execution
  return templates_[label]->enqueue(queue_, program, std::to_string(label), *fallback_, expr);
}
|
|
|
|
|
|
|
|
/// Read-only access to the container of templates held by this model.
model::templates_container const & model::templates() const
{
  return this->templates_;
}
|
2015-01-12 13:20:53 -05:00
|
|
|
|
|
|
|
///////////////////
|
|
|
|
|
|
|
|
namespace detail
|
|
|
|
{
|
|
|
|
/// Maps an operation name ("axpy", "dot", "ger", "gemv_n", ...) to its
/// expression_type. Throws std::invalid_argument for any other name.
static expression_type get_expression_type(std::string const & name)
{
  struct entry
  {
    const char * key;
    expression_type value;
  };
  static const entry table[] = {
    {"axpy",    AXPY_TYPE},
    {"dot",     DOT_TYPE},
    {"ger",     GER_TYPE},
    {"gemv_n",  GEMV_N_TYPE},
    {"gemv_t",  GEMV_T_TYPE},
    {"gemm_nn", GEMM_NN_TYPE},
    {"gemm_nt", GEMM_NT_TYPE},
    {"gemm_tn", GEMM_TN_TYPE},
    {"gemm_tt", GEMM_TT_TYPE}
  };

  for(entry const & candidate : table)
  {
    if(name == candidate.key)
      return candidate.value;
  }
  throw std::invalid_argument("Invalid expression: " + name);
}
|
|
|
|
|
|
|
|
/// Maps a datatype name to its numeric_type: "float32" -> FLOAT_TYPE,
/// "float64" -> DOUBLE_TYPE. Throws std::invalid_argument for any other name.
static numeric_type get_dtype(std::string const & name)
{
  const bool is_single = (name == "float32");
  if(is_single || name == "float64")
    return is_single ? FLOAT_TYPE : DOUBLE_TYPE;

  throw std::invalid_argument("Invalid datatype: " + name);
}
|
|
|
|
|
2015-07-28 15:26:10 -07:00
|
|
|
static std::shared_ptr<templates::base> create(std::string const & template_name, std::vector<int> const & a)
|
2015-01-12 13:20:53 -05:00
|
|
|
{
|
2015-07-11 09:36:01 -04:00
|
|
|
templates::fetching_policy_type fetch[] = {templates::FETCH_FROM_LOCAL, templates::FETCH_FROM_GLOBAL_STRIDED, templates::FETCH_FROM_GLOBAL_CONTIGUOUS};
|
|
|
|
if(template_name=="axpy")
|
2015-07-28 15:26:10 -07:00
|
|
|
return std::shared_ptr<templates::base>(new templates::axpy(a[0], a[1], a[2], fetch[a[3]]));
|
2015-01-27 02:41:27 -05:00
|
|
|
else if(template_name=="dot")
|
2015-07-28 15:26:10 -07:00
|
|
|
return std::shared_ptr<templates::base>(new templates::dot(a[0], a[1], a[2], fetch[a[3]]));
|
2015-07-11 09:36:01 -04:00
|
|
|
else if(template_name=="ger")
|
2015-07-28 15:26:10 -07:00
|
|
|
return std::shared_ptr<templates::base>(new templates::ger(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
|
2015-07-11 09:36:01 -04:00
|
|
|
else if(template_name.find("gemv_n")!=std::string::npos)
|
2015-07-28 15:26:10 -07:00
|
|
|
return std::shared_ptr<templates::base>(new templates::gemv_n(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
|
2015-07-11 09:36:01 -04:00
|
|
|
else if(template_name.find("gemv_t")!=std::string::npos)
|
2015-07-28 15:26:10 -07:00
|
|
|
return std::shared_ptr<templates::base>(new templates::gemv_t(a[0], a[1], a[2], a[3], a[4], fetch[a[5]]));
|
2015-07-11 09:36:01 -04:00
|
|
|
else if(template_name.find("gemm_nn")!=std::string::npos)
|
2015-07-28 15:26:10 -07:00
|
|
|
return std::shared_ptr<templates::base>(new templates::gemm_nn(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
|
2015-07-11 09:36:01 -04:00
|
|
|
else if(template_name.find("gemm_tn")!=std::string::npos)
|
2015-07-28 15:26:10 -07:00
|
|
|
return std::shared_ptr<templates::base>(new templates::gemm_tn(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
|
2015-07-11 09:36:01 -04:00
|
|
|
else if(template_name.find("gemm_nt")!=std::string::npos)
|
2015-07-28 15:26:10 -07:00
|
|
|
return std::shared_ptr<templates::base>(new templates::gemm_nt(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
|
2015-07-11 09:36:01 -04:00
|
|
|
else if(template_name.find("gemm_tt")!=std::string::npos)
|
2015-07-28 15:26:10 -07:00
|
|
|
return std::shared_ptr<templates::base>(new templates::gemm_tt(a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], fetch[a[8]], fetch[a[9]], a[10], a[11]));
|
2015-01-12 13:20:53 -05:00
|
|
|
else
|
2015-01-28 17:08:39 -05:00
|
|
|
throw std::invalid_argument("Invalid expression: " + template_name);
|
2015-01-12 13:20:53 -05:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2015-08-04 10:06:52 -07:00
|
|
|
void models::import(std::string const & fname, driver::CommandQueue const & queue)
|
2015-01-12 13:20:53 -05:00
|
|
|
{
|
|
|
|
namespace js = rapidjson;
|
2015-08-04 10:06:52 -07:00
|
|
|
map_type & result = data_[queue];
|
|
|
|
|
2015-01-12 13:20:53 -05:00
|
|
|
//Parse the JSON document
|
|
|
|
js::Document document;
|
|
|
|
std::ifstream t(fname.c_str());
|
2015-01-28 22:07:09 -05:00
|
|
|
if(!t) return;
|
2015-01-12 13:20:53 -05:00
|
|
|
std::string str;
|
|
|
|
t.seekg(0, std::ios::end);
|
|
|
|
str.reserve(t.tellg());
|
|
|
|
t.seekg(0, std::ios::beg);
|
|
|
|
str.assign((std::istreambuf_iterator<char>(t)), std::istreambuf_iterator<char>());
|
|
|
|
document.Parse<0>(str.c_str());
|
|
|
|
//Deserialize
|
2015-07-11 09:36:01 -04:00
|
|
|
std::vector<std::string> operations = {"axpy", "dot", "ger", "gemv_n", "gemv_t", "gemm_nn", "gemm_tn", "gemm_nt", "gemm_tt"};
|
2015-04-29 15:50:57 -04:00
|
|
|
std::vector<std::string> dtype = {"float32", "float64"};
|
2015-02-04 22:06:15 -05:00
|
|
|
for(auto & operation : operations)
|
2015-01-12 13:20:53 -05:00
|
|
|
{
|
2015-02-04 22:06:15 -05:00
|
|
|
const char * opcstr = operation.c_str();
|
2015-01-12 13:20:53 -05:00
|
|
|
if(document.HasMember(opcstr))
|
|
|
|
{
|
2015-02-04 22:06:15 -05:00
|
|
|
expression_type etype = detail::get_expression_type(operation);
|
|
|
|
for(auto & elem : dtype)
|
2015-01-12 13:20:53 -05:00
|
|
|
{
|
2015-02-04 22:06:15 -05:00
|
|
|
const char * dtcstr = elem.c_str();
|
2015-01-12 13:20:53 -05:00
|
|
|
if(document[opcstr].HasMember(dtcstr))
|
|
|
|
{
|
2015-02-04 22:06:15 -05:00
|
|
|
numeric_type dtype = detail::get_dtype(elem);
|
2015-01-12 13:20:53 -05:00
|
|
|
|
|
|
|
// Get profiles
|
2015-07-28 15:26:10 -07:00
|
|
|
std::vector<std::shared_ptr<templates::base> > templates;
|
2015-01-12 13:20:53 -05:00
|
|
|
js::Value const & profiles = document[opcstr][dtcstr]["profiles"];
|
|
|
|
for (js::SizeType id = 0 ; id < profiles.Size() ; ++id)
|
2015-02-04 22:06:15 -05:00
|
|
|
templates.push_back(detail::create(operation, tools::to_int_array<int>(profiles[id])));
|
2015-04-29 15:50:57 -04:00
|
|
|
|
2015-01-25 18:19:19 -05:00
|
|
|
if(templates.size()>1)
|
|
|
|
{
|
2015-01-25 01:08:18 -05:00
|
|
|
// Get predictor
|
|
|
|
predictors::random_forest predictor(document[opcstr][dtcstr]["predictor"]);
|
2015-07-28 15:26:10 -07:00
|
|
|
result[std::make_pair(etype, dtype)] = std::shared_ptr<model>(new model(etype, dtype, predictor, templates, queue));
|
2015-01-25 01:08:18 -05:00
|
|
|
}
|
2015-01-25 18:19:19 -05:00
|
|
|
else
|
2015-07-28 15:26:10 -07:00
|
|
|
result[std::make_pair(etype, dtype)] = std::shared_ptr<model>(new model(etype, dtype, *templates[0], queue));
|
2015-01-12 13:20:53 -05:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2015-08-04 10:06:52 -07:00
|
|
|
/**
 * Creates the registry of `queue`: one fallback-based model for every
 * (expression type, numeric type) pair, then overlays whatever
 * $HOME/.isaac/devices/device0.json provides when HOME is set and non-empty.
 *
 * @return the registry associated with `queue`.
 */
models::map_type& models::init(driver::CommandQueue const & queue)
{
  map_type & registry = data_[queue];

  const numeric_type all_dtypes[] = {CHAR_TYPE, UCHAR_TYPE, SHORT_TYPE, USHORT_TYPE, INT_TYPE, UINT_TYPE, LONG_TYPE, ULONG_TYPE, FLOAT_TYPE, DOUBLE_TYPE};
  const expression_type all_etypes[] = {AXPY_TYPE, DOT_TYPE, GER_TYPE, GEMV_N_TYPE, GEMV_T_TYPE, GEMM_NN_TYPE, GEMM_NT_TYPE, GEMM_TN_TYPE, GEMM_TT_TYPE};

  // Default models built from the fallback templates
  for(numeric_type dtype: all_dtypes)
  {
    for(expression_type etype: all_etypes)
    {
      const std::pair<expression_type, numeric_type> key(etype, dtype);
      registry[key] = std::shared_ptr<model>(new model(etype, dtype, *fallbacks[key], queue));
    }
  }

  // User-provided profiles, if any
  const std::string home = tools::getenv("HOME");
  if(!home.empty())
    import(home + "/.isaac/devices/device0.json", queue);

  return registry;
}
|
|
|
|
|
|
|
|
/// Returns the registry associated with `queue`, creating it through init()
/// the first time the queue is seen.
models::map_type& models::get(driver::CommandQueue const & queue)
{
  auto found = data_.find(queue);
  return (found != data_.end()) ? found->second : init(queue);
}
|
|
|
|
|
|
|
|
/// Registers `model` for (operation, dtype) in the registry of `queue`,
/// replacing any model previously stored under that key. The registry entry
/// for `queue` is created if it does not exist yet.
void models::set(driver::CommandQueue const & queue, expression_type operation, numeric_type dtype, std::shared_ptr<model> const & model)
{
  map_type & registry = data_[queue];
  registry[std::make_pair(operation,dtype)] = model;
}
|
|
|
|
|
2015-08-04 10:53:39 -07:00
|
|
|
// Definition of the static registry of models: one map_type per command queue.
// Entries are created by models::init() / models::import() and overwritten by models::set().
std::map<driver::CommandQueue, models::map_type> models::data_;
|
|
|
|
|
2015-08-04 10:06:52 -07:00
|
|
|
//
|
2015-04-29 15:50:57 -04:00
|
|
|
|
2015-07-28 15:26:10 -07:00
|
|
|
std::map<std::pair<expression_type, numeric_type>, std::shared_ptr<templates::base> > init_fallback()
|
2015-01-12 13:20:53 -05:00
|
|
|
{
|
2015-07-28 15:26:10 -07:00
|
|
|
typedef std::shared_ptr<templates::base> ptr_t;
|
2015-04-29 15:50:57 -04:00
|
|
|
std::map<std::pair<expression_type, numeric_type>, ptr_t > res;
|
2015-01-12 13:20:53 -05:00
|
|
|
numeric_type types[] = {CHAR_TYPE, UCHAR_TYPE, SHORT_TYPE, USHORT_TYPE, INT_TYPE, UINT_TYPE, LONG_TYPE, ULONG_TYPE, FLOAT_TYPE, DOUBLE_TYPE};
|
2015-04-29 15:50:57 -04:00
|
|
|
for(auto DTYPE : types)
|
|
|
|
{
|
2015-07-11 09:36:01 -04:00
|
|
|
res[std::make_pair(AXPY_TYPE, DTYPE)] = ptr_t (new templates::axpy(1,64,128,templates::FETCH_FROM_GLOBAL_STRIDED));
|
|
|
|
res[std::make_pair(DOT_TYPE, DTYPE)] = ptr_t(new templates::dot(1,64,128,templates::FETCH_FROM_GLOBAL_STRIDED));
|
|
|
|
res[std::make_pair(GER_TYPE, DTYPE)] = ptr_t(new templates::ger(1,8,8,8,8,templates::FETCH_FROM_GLOBAL_STRIDED));
|
|
|
|
res[std::make_pair(GEMV_N_TYPE, DTYPE)] = ptr_t(new templates::gemv_n(1, 8, 8, 4, 16, templates::FETCH_FROM_GLOBAL_STRIDED));
|
|
|
|
res[std::make_pair(GEMV_T_TYPE, DTYPE)] = ptr_t(new templates::gemv_t(1, 8, 8, 64, 8, templates::FETCH_FROM_GLOBAL_STRIDED));
|
2015-07-16 13:29:07 -04:00
|
|
|
res[std::make_pair(GEMM_NN_TYPE, DTYPE)] = ptr_t(new templates::gemm_nn(1, 8, 16, 8, 1, 8, 1, 8, templates::FETCH_FROM_LOCAL, templates::FETCH_FROM_LOCAL, 8, 8, true));
|
|
|
|
res[std::make_pair(GEMM_TN_TYPE, DTYPE)] = ptr_t(new templates::gemm_tn(1, 8, 16, 8, 1, 8, 1, 8, templates::FETCH_FROM_LOCAL, templates::FETCH_FROM_LOCAL, 8, 8, true));
|
2015-07-11 09:36:01 -04:00
|
|
|
res[std::make_pair(GEMM_NT_TYPE, DTYPE)] = ptr_t(new templates::gemm_nt(1, 8, 16, 8, 1, 8, 1, 8, templates::FETCH_FROM_LOCAL, templates::FETCH_FROM_LOCAL, 8, 8, true));
|
2015-07-16 13:29:07 -04:00
|
|
|
res[std::make_pair(GEMM_TT_TYPE, DTYPE)] = ptr_t(new templates::gemm_tt(1, 8, 16, 8, 1, 8, 1, 8, templates::FETCH_FROM_LOCAL, templates::FETCH_FROM_LOCAL, 8, 8, true));
|
2015-01-12 13:20:53 -05:00
|
|
|
}
|
|
|
|
return res;
|
|
|
|
}
|
|
|
|
|
|
|
|
|
2015-07-28 15:26:10 -07:00
|
|
|
// Global table of default templates, filled by init_fallback() when this
// variable is initialized. It is indexed by the model constructors and by
// models::init().
// NOTE(review): code running before this initializer (e.g. static
// initialization in another translation unit) would presumably see an empty
// table - confirm no such caller exists.
std::map<std::pair<expression_type, numeric_type>, std::shared_ptr<templates::base> > fallbacks = init_fallback();
|
2015-08-04 10:06:52 -07:00
|
|
|
|
2015-01-12 13:20:53 -05:00
|
|
|
}
|