#include #include "rapidjson/document.h" #include "rapidjson/to_array.hpp" #include "isaac/database/database.h" #include "isaac/kernels/parse.h" #include "isaac/kernels/templates/axpy.h" #include "isaac/kernels/templates/dot.h" #include "isaac/kernels/templates/ger.h" #include "isaac/kernels/templates/gemv.h" #include "isaac/kernels/templates/gemm.h" #include "presets/broadwell.hpp" #include "getenv.hpp" namespace isaac { std::shared_ptr database::create(std::string const & template_name, std::vector const & x) { templates::fetching_policy_type fetch[] = {templates::FETCH_FROM_LOCAL, templates::FETCH_FROM_GLOBAL_STRIDED, templates::FETCH_FROM_GLOBAL_CONTIGUOUS}; if(template_name=="axpy") return std::shared_ptr(new templates::axpy(x[0], x[1], x[2], fetch[x[3]])); else if(template_name=="dot") return std::shared_ptr(new templates::dot(x[0], x[1], x[2], fetch[x[3]])); else if(template_name=="ger") return std::shared_ptr(new templates::ger(x[0], x[1], x[2], x[3], x[4], fetch[x[5]])); else if(template_name.find("gemv_n")!=std::string::npos) return std::shared_ptr(new templates::gemv_n(x[0], x[1], x[2], x[3], x[4], fetch[x[5]])); else if(template_name.find("gemv_t")!=std::string::npos) return std::shared_ptr(new templates::gemv_t(x[0], x[1], x[2], x[3], x[4], fetch[x[5]])); else if(template_name.find("gemm_nn")!=std::string::npos) return std::shared_ptr(new templates::gemm_nn(x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7], fetch[x[8]], fetch[x[9]], x[10], x[11])); else if(template_name.find("gemm_tn")!=std::string::npos) return std::shared_ptr(new templates::gemm_tn(x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7], fetch[x[8]], fetch[x[9]], x[10], x[11])); else if(template_name.find("gemm_nt")!=std::string::npos) return std::shared_ptr(new templates::gemm_nt(x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7], fetch[x[8]], fetch[x[9]], x[10], x[11])); else if(template_name.find("gemm_tt")!=std::string::npos) return std::shared_ptr(new templates::gemm_tt(x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7], fetch[x[8]], fetch[x[9]], x[10], x[11])); else throw std::invalid_argument("Invalid expression: " + template_name); } void database::import(std::string const & str, driver::CommandQueue const & queue) { map_type & result = cache_[queue]; //Parse the JSON document rapidjson::Document document; document.Parse<0>(str.c_str()); //Deserialize std::vector operations = {"axpy", "dot", "ger", "gemv_n", "gemv_t", "gemm_nn", "gemm_tn", "gemm_nt", "gemm_tt"}; std::vector dtype = {"float32", "float64"}; for(auto & operation : operations) { const char * opcstr = operation.c_str(); if(document.HasMember(opcstr)) { expression_type etype = expression_type_from_string(operation); for(auto & elem : dtype) { const char * dtcstr = elem.c_str(); if(document[opcstr].HasMember(dtcstr)) { numeric_type dtype = numeric_type_from_string(elem); // Get profiles std::vector > templates; rapidjson::Value const & profiles = document[opcstr][dtcstr]["profiles"]; for (rapidjson::SizeType id = 0 ; id < profiles.Size() ; ++id) templates.push_back(create(operation, rapidjson::to_int_array(profiles[id]))); if(templates.size()>1) { // Get predictor predictors::random_forest predictor(document[opcstr][dtcstr]["predictor"]); result[std::make_pair(etype, dtype)] = std::shared_ptr(new model(etype, dtype, predictor, templates, queue)); } else result[std::make_pair(etype, dtype)] = std::shared_ptr(new model(etype, dtype, *templates[0], queue)); } } } } } database::map_type& database::init(driver::CommandQueue const & queue) { map_type & result = cache_[queue]; numeric_type dtypes[] = {CHAR_TYPE, UCHAR_TYPE, SHORT_TYPE, USHORT_TYPE, INT_TYPE, UINT_TYPE, LONG_TYPE, ULONG_TYPE, FLOAT_TYPE, DOUBLE_TYPE}; expression_type etypes[] = {AXPY_TYPE, DOT_TYPE, GER_TYPE, GEMV_N_TYPE, GEMV_T_TYPE, GEMM_NN_TYPE, GEMM_NT_TYPE, GEMM_TN_TYPE, GEMM_TT_TYPE}; for(numeric_type dtype: dtypes) for(expression_type etype: etypes) result[std::make_pair(etype, dtype)] = std::shared_ptr(new model(etype, dtype, *fallbacks[std::make_pair(etype, dtype)], queue)); driver::Device const & device = queue.device(); presets_type::const_iterator it = presets_.find(std::make_tuple(device.vendor(), device.architecture())); if(it==presets_.end()) import(presets_.at(std::make_tuple(device.vendor(), driver::Device::Architecture::UNKNOWN)), queue); else import(it->second, queue); std::string homepath = tools::getenv("HOME"); if(homepath.size()) { std::string json_path = homepath + "/.isaac/devices/device0.json"; std::ifstream t(json_path); if(!t) return result; std::string str; t.seekg(0, std::ios::end); str.reserve(t.tellg()); t.seekg(0, std::ios::beg); str.assign((std::istreambuf_iterator(t)), std::istreambuf_iterator()); import(str, queue); } return result; } database::map_type& database::get(driver::CommandQueue const & queue) { std::map::iterator it = cache_.find(queue); if(it == cache_.end()) return init(queue); return it->second; } void database::set(driver::CommandQueue const & queue, expression_type operation, numeric_type dtype, std::shared_ptr const & model) { cache_[queue][std::make_pair(operation,dtype)] = model; } std::map database::cache_; }