Models: Now initialize with preset if existing
This commit is contained in:
@@ -15,6 +15,7 @@ namespace isaac
|
||||
|
||||
struct database
|
||||
{
|
||||
typedef std::map<std::tuple<driver::Device::Vendor, driver::Device::Architecture> , const char *> presets_type;
|
||||
public:
|
||||
typedef std::map<std::pair<expression_type, numeric_type>, std::shared_ptr<model> > map_type;
|
||||
private:
|
||||
@@ -24,7 +25,7 @@ public:
|
||||
static map_type & get(driver::CommandQueue const & queue);
|
||||
static void set(driver::CommandQueue const & queue, expression_type operation, numeric_type dtype, std::shared_ptr<model> const & model);
|
||||
private:
|
||||
static const std::map<std::tuple<driver::Device::Vendor, driver::Device::Architecture> , const char *> presets_;
|
||||
static const presets_type presets_;
|
||||
static std::map<driver::CommandQueue, map_type> cache_;
|
||||
};
|
||||
|
||||
|
@@ -47,20 +47,13 @@ namespace detail
|
||||
}
|
||||
}
|
||||
|
||||
void database::import(std::string const & fname, driver::CommandQueue const & queue)
|
||||
void database::import(std::string const & str, driver::CommandQueue const & queue)
|
||||
{
|
||||
namespace js = rapidjson;
|
||||
map_type & result = cache_[queue];
|
||||
|
||||
//Parse the JSON document
|
||||
js::Document document;
|
||||
std::ifstream t(fname.c_str());
|
||||
if(!t) return;
|
||||
std::string str;
|
||||
t.seekg(0, std::ios::end);
|
||||
str.reserve(t.tellg());
|
||||
t.seekg(0, std::ios::beg);
|
||||
str.assign((std::istreambuf_iterator<char>(t)), std::istreambuf_iterator<char>());
|
||||
document.Parse<0>(str.c_str());
|
||||
//Deserialize
|
||||
std::vector<std::string> operations = {"axpy", "dot", "ger", "gemv_n", "gemv_t", "gemm_nn", "gemm_tn", "gemm_nt", "gemm_tt"};
|
||||
@@ -107,9 +100,26 @@ database::map_type& database::init(driver::CommandQueue const & queue)
|
||||
for(expression_type etype: etypes)
|
||||
result[std::make_pair(etype, dtype)] = std::shared_ptr<model>(new model(etype, dtype, *fallbacks[std::make_pair(etype, dtype)], queue));
|
||||
|
||||
driver::Device const & device = queue.device();
|
||||
presets_type::const_iterator it = presets_.find(std::make_tuple(device.vendor(), device.architecture()));
|
||||
if(it==presets_.end())
|
||||
import(presets_.at(std::make_tuple(device.vendor(), driver::Device::Architecture::UNKNOWN)), queue);
|
||||
else
|
||||
import(it->second, queue);
|
||||
std::string homepath = tools::getenv("HOME");
|
||||
if(homepath.size())
|
||||
import(homepath + "/.isaac/devices/device0.json", queue);
|
||||
{
|
||||
std::string json_path = homepath + "/.isaac/devices/device0.json";
|
||||
std::ifstream t(json_path);
|
||||
if(!t)
|
||||
return result;
|
||||
std::string str;
|
||||
t.seekg(0, std::ios::end);
|
||||
str.reserve(t.tellg());
|
||||
t.seekg(0, std::ios::beg);
|
||||
str.assign((std::istreambuf_iterator<char>(t)), std::istreambuf_iterator<char>());
|
||||
import(str, queue);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
@@ -115,7 +115,7 @@ def main():
|
||||
include =' src/include'.split() + ['external/boost/include', os.path.join(find_module("numpy")[1], "core", "include")]
|
||||
|
||||
#Source files
|
||||
src = 'src/lib/array.cpp src/lib/wrap/clBLAS.cpp src/lib/value_scalar.cpp src/lib/kernels/mapped_object.cpp src/lib/kernels/parse.cpp src/lib/kernels/templates/gemv.cpp src/lib/kernels/templates/gemm.cpp src/lib/kernels/templates/axpy.cpp src/lib/kernels/templates/base.cpp src/lib/kernels/templates/ger.cpp src/lib/kernels/templates/dot.cpp src/lib/kernels/stream.cpp src/lib/kernels/keywords.cpp src/lib/kernels/binder.cpp src/lib/symbolic/preset.cpp src/lib/symbolic/expression.cpp src/lib/symbolic/execute.cpp src/lib/symbolic/io.cpp src/lib/model/model.cpp src/lib/model/database.cpp src/lib/model/predictors/random_forest.cpp src/lib/exception/unknown_datatype.cpp src/lib/exception/operation_not_supported.cpp src/lib/driver/program.cpp src/lib/driver/device.cpp src/lib/driver/event.cpp src/lib/driver/program_cache.cpp src/lib/driver/ndrange.cpp src/lib/driver/kernel.cpp src/lib/driver/handle.cpp src/lib/driver/command_queue.cpp src/lib/driver/check.cpp src/lib/driver/buffer.cpp src/lib/driver/backend.cpp src/lib/driver/context.cpp src/lib/driver/platform.cpp '.split() + [os.path.join('src', 'wrap', sf) for sf in ['_isaac.cpp', 'core.cpp', 'driver.cpp', 'model.cpp', 'exceptions.cpp']]
|
||||
src = 'src/lib/array.cpp src/lib/wrap/clBLAS.cpp src/lib/value_scalar.cpp src/lib/kernels/mapped_object.cpp src/lib/kernels/parse.cpp src/lib/kernels/templates/ger.cpp src/lib/kernels/templates/gemv.cpp src/lib/kernels/templates/gemm.cpp src/lib/kernels/templates/dot.cpp src/lib/kernels/templates/base.cpp src/lib/kernels/templates/axpy.cpp src/lib/kernels/stream.cpp src/lib/kernels/keywords.cpp src/lib/kernels/binder.cpp src/lib/symbolic/io.cpp src/lib/symbolic/execute.cpp src/lib/symbolic/preset.cpp src/lib/symbolic/expression.cpp src/lib/model/model.cpp src/lib/model/database.cpp src/lib/model/predictors/random_forest.cpp src/lib/exception/unknown_datatype.cpp src/lib/exception/operation_not_supported.cpp src/lib/driver/program.cpp src/lib/driver/platform.cpp src/lib/driver/device.cpp src/lib/driver/context.cpp src/lib/driver/event.cpp src/lib/driver/program_cache.cpp src/lib/driver/ndrange.cpp src/lib/driver/kernel.cpp src/lib/driver/handle.cpp src/lib/driver/command_queue.cpp src/lib/driver/check.cpp src/lib/driver/buffer.cpp src/lib/driver/backend.cpp '.split() + [os.path.join('src', 'wrap', sf) for sf in ['_isaac.cpp', 'core.cpp', 'driver.cpp', 'model.cpp', 'exceptions.cpp']]
|
||||
boostsrc = 'external/boost/libs/'
|
||||
for s in ['numpy','python','smart_ptr','system','thread']:
|
||||
src = src + [x for x in recursive_glob('external/boost/libs/' + s + '/src/','.cpp') if 'win32' not in x and 'pthread' not in x]
|
||||
|
Reference in New Issue
Block a user