Models: cleaning of the global caching mechanism

This commit is contained in:
Philippe Tillet
2015-08-04 10:06:52 -07:00
parent d88ff6b39b
commit df2d5e7d00
6 changed files with 53 additions and 34 deletions

View File

@@ -179,9 +179,11 @@ namespace detail
}
}
void import(std::string const & fname, driver::CommandQueue & queue, model_map_t& result)
void models::import(std::string const & fname, driver::CommandQueue const & queue)
{
namespace js = rapidjson;
map_type & result = data_[queue];
//Parse the JSON document
js::Document document;
std::ifstream t(fname.c_str());
@@ -228,6 +230,37 @@ void import(std::string const & fname, driver::CommandQueue & queue, model_map_t
}
}
models::map_type& models::init(driver::CommandQueue const & queue)
{
map_type & result = data_[queue];
numeric_type dtypes[] = {CHAR_TYPE, UCHAR_TYPE, SHORT_TYPE, USHORT_TYPE, INT_TYPE, UINT_TYPE, LONG_TYPE, ULONG_TYPE, FLOAT_TYPE, DOUBLE_TYPE};
expression_type etypes[] = {AXPY_TYPE, DOT_TYPE, GER_TYPE, GEMV_N_TYPE, GEMV_T_TYPE, GEMM_NN_TYPE, GEMM_NT_TYPE, GEMM_TN_TYPE, GEMM_TT_TYPE};
for(numeric_type dtype: dtypes)
for(expression_type etype: etypes)
result[std::make_pair(etype, dtype)] = std::shared_ptr<model>(new model(etype, dtype, *fallbacks[std::make_pair(etype, dtype)], queue));
if(const char * homepath = std::getenv("HOME"))
import(std::string(homepath) + "/.isaac/devices/device0.json", queue);
return result;
}
models::map_type& models::get(driver::CommandQueue const & queue)
{
std::map<driver::CommandQueue, map_type>::iterator it = data_.find(queue);
if(it == data_.end())
return init(queue);
return it->second;
}
void models::set(driver::CommandQueue const & queue, expression_type operation, numeric_type dtype, std::shared_ptr<model> const & model)
{
data_[queue][std::make_pair(operation,dtype)] = model;
}
//
std::map<std::pair<expression_type, numeric_type>, std::shared_ptr<templates::base> > init_fallback()
{
@@ -249,30 +282,8 @@ std::map<std::pair<expression_type, numeric_type>, std::shared_ptr<templates::ba
return res;
}
//TODO: Clean everything by overloading operator[]
model_map_t init_models(driver::CommandQueue & queue)
{
model_map_t res;
numeric_type dtypes[] = {CHAR_TYPE, UCHAR_TYPE, SHORT_TYPE, USHORT_TYPE, INT_TYPE, UINT_TYPE, LONG_TYPE, ULONG_TYPE, FLOAT_TYPE, DOUBLE_TYPE};
expression_type etypes[] = {AXPY_TYPE, DOT_TYPE, GER_TYPE, GEMV_N_TYPE, GEMV_T_TYPE, GEMM_NN_TYPE, GEMM_NT_TYPE, GEMM_TN_TYPE, GEMM_TT_TYPE};
for(numeric_type dtype: dtypes)
for(expression_type etype: etypes)
res[std::make_pair(etype, dtype)] = std::shared_ptr<model>(new model(etype, dtype, *fallbacks[std::make_pair(etype, dtype)], queue));
if(const char * homepath = std::getenv("HOME"))
import(std::string(homepath) + "/.isaac/devices/device0.json", queue, res);
return res;
}
model_map_t& models(driver::CommandQueue & queue)
{
static std::map<driver::Device, model_map_t> models_;
std::map<driver::Device, model_map_t>::iterator it = models_.find(queue.device());
if(it == models_.end())
return models_.insert(std::make_pair(queue.device(), init_models(queue))).first->second;
return it->second;
}
std::map<driver::CommandQueue, models::map_type> models::data_;
std::map<std::pair<expression_type, numeric_type>, std::shared_ptr<templates::base> > fallbacks = init_fallback();
}