Models: Now initialize with preset if existing

This commit is contained in:
Philippe Tillet
2015-08-06 16:39:48 -07:00
parent e4ff883688
commit a730e11148
3 changed files with 22 additions and 11 deletions

View File

@@ -15,6 +15,7 @@ namespace isaac
struct database struct database
{ {
typedef std::map<std::tuple<driver::Device::Vendor, driver::Device::Architecture> , const char *> presets_type;
public: public:
typedef std::map<std::pair<expression_type, numeric_type>, std::shared_ptr<model> > map_type; typedef std::map<std::pair<expression_type, numeric_type>, std::shared_ptr<model> > map_type;
private: private:
@@ -24,7 +25,7 @@ public:
static map_type & get(driver::CommandQueue const & queue); static map_type & get(driver::CommandQueue const & queue);
static void set(driver::CommandQueue const & queue, expression_type operation, numeric_type dtype, std::shared_ptr<model> const & model); static void set(driver::CommandQueue const & queue, expression_type operation, numeric_type dtype, std::shared_ptr<model> const & model);
private: private:
static const std::map<std::tuple<driver::Device::Vendor, driver::Device::Architecture> , const char *> presets_; static const presets_type presets_;
static std::map<driver::CommandQueue, map_type> cache_; static std::map<driver::CommandQueue, map_type> cache_;
}; };

View File

@@ -47,20 +47,13 @@ namespace detail
} }
} }
void database::import(std::string const & fname, driver::CommandQueue const & queue) void database::import(std::string const & str, driver::CommandQueue const & queue)
{ {
namespace js = rapidjson; namespace js = rapidjson;
map_type & result = cache_[queue]; map_type & result = cache_[queue];
//Parse the JSON document //Parse the JSON document
js::Document document; js::Document document;
std::ifstream t(fname.c_str());
if(!t) return;
std::string str;
t.seekg(0, std::ios::end);
str.reserve(t.tellg());
t.seekg(0, std::ios::beg);
str.assign((std::istreambuf_iterator<char>(t)), std::istreambuf_iterator<char>());
document.Parse<0>(str.c_str()); document.Parse<0>(str.c_str());
//Deserialize //Deserialize
std::vector<std::string> operations = {"axpy", "dot", "ger", "gemv_n", "gemv_t", "gemm_nn", "gemm_tn", "gemm_nt", "gemm_tt"}; std::vector<std::string> operations = {"axpy", "dot", "ger", "gemv_n", "gemv_t", "gemm_nn", "gemm_tn", "gemm_nt", "gemm_tt"};
@@ -107,9 +100,26 @@ database::map_type& database::init(driver::CommandQueue const & queue)
for(expression_type etype: etypes) for(expression_type etype: etypes)
result[std::make_pair(etype, dtype)] = std::shared_ptr<model>(new model(etype, dtype, *fallbacks[std::make_pair(etype, dtype)], queue)); result[std::make_pair(etype, dtype)] = std::shared_ptr<model>(new model(etype, dtype, *fallbacks[std::make_pair(etype, dtype)], queue));
driver::Device const & device = queue.device();
presets_type::const_iterator it = presets_.find(std::make_tuple(device.vendor(), device.architecture()));
if(it==presets_.end())
import(presets_.at(std::make_tuple(device.vendor(), driver::Device::Architecture::UNKNOWN)), queue);
else
import(it->second, queue);
std::string homepath = tools::getenv("HOME"); std::string homepath = tools::getenv("HOME");
if(homepath.size()) if(homepath.size())
import(homepath + "/.isaac/devices/device0.json", queue); {
std::string json_path = homepath + "/.isaac/devices/device0.json";
std::ifstream t(json_path);
if(!t)
return result;
std::string str;
t.seekg(0, std::ios::end);
str.reserve(t.tellg());
t.seekg(0, std::ios::beg);
str.assign((std::istreambuf_iterator<char>(t)), std::istreambuf_iterator<char>());
import(str, queue);
}
return result; return result;
} }

View File

@@ -115,7 +115,7 @@ def main():
include =' src/include'.split() + ['external/boost/include', os.path.join(find_module("numpy")[1], "core", "include")] include =' src/include'.split() + ['external/boost/include', os.path.join(find_module("numpy")[1], "core", "include")]
#Source files #Source files
src = 'src/lib/array.cpp src/lib/wrap/clBLAS.cpp src/lib/value_scalar.cpp src/lib/kernels/mapped_object.cpp src/lib/kernels/parse.cpp src/lib/kernels/templates/gemv.cpp src/lib/kernels/templates/gemm.cpp src/lib/kernels/templates/axpy.cpp src/lib/kernels/templates/base.cpp src/lib/kernels/templates/ger.cpp src/lib/kernels/templates/dot.cpp src/lib/kernels/stream.cpp src/lib/kernels/keywords.cpp src/lib/kernels/binder.cpp src/lib/symbolic/preset.cpp src/lib/symbolic/expression.cpp src/lib/symbolic/execute.cpp src/lib/symbolic/io.cpp src/lib/model/model.cpp src/lib/model/database.cpp src/lib/model/predictors/random_forest.cpp src/lib/exception/unknown_datatype.cpp src/lib/exception/operation_not_supported.cpp src/lib/driver/program.cpp src/lib/driver/device.cpp src/lib/driver/event.cpp src/lib/driver/program_cache.cpp src/lib/driver/ndrange.cpp src/lib/driver/kernel.cpp src/lib/driver/handle.cpp src/lib/driver/command_queue.cpp src/lib/driver/check.cpp src/lib/driver/buffer.cpp src/lib/driver/backend.cpp src/lib/driver/context.cpp src/lib/driver/platform.cpp '.split() + [os.path.join('src', 'wrap', sf) for sf in ['_isaac.cpp', 'core.cpp', 'driver.cpp', 'model.cpp', 'exceptions.cpp']] src = 'src/lib/array.cpp src/lib/wrap/clBLAS.cpp src/lib/value_scalar.cpp src/lib/kernels/mapped_object.cpp src/lib/kernels/parse.cpp src/lib/kernels/templates/ger.cpp src/lib/kernels/templates/gemv.cpp src/lib/kernels/templates/gemm.cpp src/lib/kernels/templates/dot.cpp src/lib/kernels/templates/base.cpp src/lib/kernels/templates/axpy.cpp src/lib/kernels/stream.cpp src/lib/kernels/keywords.cpp src/lib/kernels/binder.cpp src/lib/symbolic/io.cpp src/lib/symbolic/execute.cpp src/lib/symbolic/preset.cpp src/lib/symbolic/expression.cpp src/lib/model/model.cpp src/lib/model/database.cpp src/lib/model/predictors/random_forest.cpp src/lib/exception/unknown_datatype.cpp src/lib/exception/operation_not_supported.cpp src/lib/driver/program.cpp src/lib/driver/platform.cpp src/lib/driver/device.cpp src/lib/driver/context.cpp src/lib/driver/event.cpp src/lib/driver/program_cache.cpp src/lib/driver/ndrange.cpp src/lib/driver/kernel.cpp src/lib/driver/handle.cpp src/lib/driver/command_queue.cpp src/lib/driver/check.cpp src/lib/driver/buffer.cpp src/lib/driver/backend.cpp '.split() + [os.path.join('src', 'wrap', sf) for sf in ['_isaac.cpp', 'core.cpp', 'driver.cpp', 'model.cpp', 'exceptions.cpp']]
boostsrc = 'external/boost/libs/' boostsrc = 'external/boost/libs/'
for s in ['numpy','python','smart_ptr','system','thread']: for s in ['numpy','python','smart_ptr','system','thread']:
src = src + [x for x in recursive_glob('external/boost/libs/' + s + '/src/','.cpp') if 'win32' not in x and 'pthread' not in x] src = src + [x for x in recursive_glob('external/boost/libs/' + s + '/src/','.cpp') if 'win32' not in x and 'pthread' not in x]