Models: cleaning of the global caching mechanism
This commit is contained in:
@@ -179,9 +179,11 @@ namespace detail
|
||||
}
|
||||
}
|
||||
|
||||
void import(std::string const & fname, driver::CommandQueue & queue, model_map_t& result)
|
||||
void models::import(std::string const & fname, driver::CommandQueue const & queue)
|
||||
{
|
||||
namespace js = rapidjson;
|
||||
map_type & result = data_[queue];
|
||||
|
||||
//Parse the JSON document
|
||||
js::Document document;
|
||||
std::ifstream t(fname.c_str());
|
||||
@@ -228,6 +230,37 @@ void import(std::string const & fname, driver::CommandQueue & queue, model_map_t
|
||||
}
|
||||
}
|
||||
|
||||
models::map_type& models::init(driver::CommandQueue const & queue)
|
||||
{
|
||||
map_type & result = data_[queue];
|
||||
|
||||
numeric_type dtypes[] = {CHAR_TYPE, UCHAR_TYPE, SHORT_TYPE, USHORT_TYPE, INT_TYPE, UINT_TYPE, LONG_TYPE, ULONG_TYPE, FLOAT_TYPE, DOUBLE_TYPE};
|
||||
expression_type etypes[] = {AXPY_TYPE, DOT_TYPE, GER_TYPE, GEMV_N_TYPE, GEMV_T_TYPE, GEMM_NN_TYPE, GEMM_NT_TYPE, GEMM_TN_TYPE, GEMM_TT_TYPE};
|
||||
|
||||
for(numeric_type dtype: dtypes)
|
||||
for(expression_type etype: etypes)
|
||||
result[std::make_pair(etype, dtype)] = std::shared_ptr<model>(new model(etype, dtype, *fallbacks[std::make_pair(etype, dtype)], queue));
|
||||
|
||||
if(const char * homepath = std::getenv("HOME"))
|
||||
import(std::string(homepath) + "/.isaac/devices/device0.json", queue);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
models::map_type& models::get(driver::CommandQueue const & queue)
|
||||
{
|
||||
std::map<driver::CommandQueue, map_type>::iterator it = data_.find(queue);
|
||||
if(it == data_.end())
|
||||
return init(queue);
|
||||
return it->second;
|
||||
}
|
||||
|
||||
void models::set(driver::CommandQueue const & queue, expression_type operation, numeric_type dtype, std::shared_ptr<model> const & model)
|
||||
{
|
||||
data_[queue][std::make_pair(operation,dtype)] = model;
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
std::map<std::pair<expression_type, numeric_type>, std::shared_ptr<templates::base> > init_fallback()
|
||||
{
|
||||
@@ -249,30 +282,8 @@ std::map<std::pair<expression_type, numeric_type>, std::shared_ptr<templates::ba
|
||||
return res;
|
||||
}
|
||||
|
||||
//TODO: Clean everything by overloading operator[]
|
||||
model_map_t init_models(driver::CommandQueue & queue)
|
||||
{
|
||||
model_map_t res;
|
||||
numeric_type dtypes[] = {CHAR_TYPE, UCHAR_TYPE, SHORT_TYPE, USHORT_TYPE, INT_TYPE, UINT_TYPE, LONG_TYPE, ULONG_TYPE, FLOAT_TYPE, DOUBLE_TYPE};
|
||||
expression_type etypes[] = {AXPY_TYPE, DOT_TYPE, GER_TYPE, GEMV_N_TYPE, GEMV_T_TYPE, GEMM_NN_TYPE, GEMM_NT_TYPE, GEMM_TN_TYPE, GEMM_TT_TYPE};
|
||||
|
||||
for(numeric_type dtype: dtypes)
|
||||
for(expression_type etype: etypes)
|
||||
res[std::make_pair(etype, dtype)] = std::shared_ptr<model>(new model(etype, dtype, *fallbacks[std::make_pair(etype, dtype)], queue));
|
||||
|
||||
if(const char * homepath = std::getenv("HOME"))
|
||||
import(std::string(homepath) + "/.isaac/devices/device0.json", queue, res);
|
||||
return res;
|
||||
}
|
||||
|
||||
model_map_t& models(driver::CommandQueue & queue)
|
||||
{
|
||||
static std::map<driver::Device, model_map_t> models_;
|
||||
std::map<driver::Device, model_map_t>::iterator it = models_.find(queue.device());
|
||||
if(it == models_.end())
|
||||
return models_.insert(std::make_pair(queue.device(), init_models(queue))).first->second;
|
||||
return it->second;
|
||||
}
|
||||
std::map<driver::CommandQueue, models::map_type> models::data_;
|
||||
|
||||
std::map<std::pair<expression_type, numeric_type>, std::shared_ptr<templates::base> > fallbacks = init_fallback();
|
||||
|
||||
}
|
||||
|
Reference in New Issue
Block a user