Inference: take care not to repeatedly modify the output when evaluating multiple different templates
This commit is contained in:
@@ -124,6 +124,7 @@ void elementwise_1d::enqueue(driver::CommandQueue &, driver::Program const & pro
|
|||||||
//Kernel
|
//Kernel
|
||||||
std::string name = "elementwise_1d";
|
std::string name = "elementwise_1d";
|
||||||
name += suffix;
|
name += suffix;
|
||||||
|
// std::cout << name << std::endl;
|
||||||
driver::Kernel kernel(program, name.c_str());
|
driver::Kernel kernel(program, name.c_str());
|
||||||
//NDRange
|
//NDRange
|
||||||
driver::NDRange global(ls0_*ng_);
|
driver::NDRange global(ls0_*ng_);
|
||||||
|
@@ -80,8 +80,9 @@ profiles::value_type::value_type(numeric_type dtype, std::shared_ptr<templates::
|
|||||||
void profiles::value_type::execute(runtime::execution_handler const & expr)
|
void profiles::value_type::execute(runtime::execution_handler const & expr)
|
||||||
{
|
{
|
||||||
static const int MAX_TEMPORARY_WORKSPACE = 1e6;
|
static const int MAX_TEMPORARY_WORKSPACE = 1e6;
|
||||||
|
expression_tree const & tree = expr.x();
|
||||||
driver::Program const & program = init(expr);
|
driver::Program const & program = init(expr);
|
||||||
std::vector<int_t> x = templates_[0]->input_sizes(expr.x());
|
std::vector<int_t> x = templates_[0]->input_sizes(tree);
|
||||||
|
|
||||||
//Cached
|
//Cached
|
||||||
auto it = labels_.find(x);
|
auto it = labels_.find(x);
|
||||||
@@ -91,6 +92,16 @@ void profiles::value_type::execute(runtime::execution_handler const & expr)
|
|||||||
}
|
}
|
||||||
|
|
||||||
//Not cached
|
//Not cached
|
||||||
|
expression_tree::node const & root = tree[tree.root()];
|
||||||
|
expression_tree::node const & left = tree[root.binary_operator.lhs];
|
||||||
|
array_base* out = left.array.base;
|
||||||
|
auto read_out = [&](expression_tree::node const & x){
|
||||||
|
return x.type == DENSE_ARRAY_TYPE && (&x != &left) && x.array.base == out;
|
||||||
|
};
|
||||||
|
bool modify_output = std::find_if(tree.data().begin(), tree.data().end(), read_out) != tree.data().end();
|
||||||
|
std::unique_ptr<array> bkp;
|
||||||
|
if(modify_output)
|
||||||
|
bkp.reset(new array(*out));
|
||||||
tools::Timer tmr;
|
tools::Timer tmr;
|
||||||
std::vector<double> times;
|
std::vector<double> times;
|
||||||
std::vector<float> perf = predictor_->predict(x);
|
std::vector<float> perf = predictor_->predict(x);
|
||||||
@@ -100,7 +111,7 @@ void profiles::value_type::execute(runtime::execution_handler const & expr)
|
|||||||
bool valid_found = false;
|
bool valid_found = false;
|
||||||
for(size_t k = 0 ; k < std::min<size_t>(5, idx.size()) || !valid_found ; k++){
|
for(size_t k = 0 ; k < std::min<size_t>(5, idx.size()) || !valid_found ; k++){
|
||||||
size_t i = idx[k];
|
size_t i = idx[k];
|
||||||
if(templates_[i]->temporary_workspace(expr.x()) > MAX_TEMPORARY_WORKSPACE){
|
if(templates_[i]->temporary_workspace(tree) > MAX_TEMPORARY_WORKSPACE){
|
||||||
times.push_back(INFINITY);
|
times.push_back(INFINITY);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -109,7 +120,7 @@ void profiles::value_type::execute(runtime::execution_handler const & expr)
|
|||||||
std::vector<double> ctimes;
|
std::vector<double> ctimes;
|
||||||
while(total_time < 1e-2){
|
while(total_time < 1e-2){
|
||||||
tmr.start();
|
tmr.start();
|
||||||
templates_[i]->enqueue(queue_, program, tools::to_string(i), runtime::execution_handler(expr.x()));
|
templates_[i]->enqueue(queue_, program, tools::to_string(i), runtime::execution_handler(tree));
|
||||||
queue_.synchronize();
|
queue_.synchronize();
|
||||||
ctimes.push_back(1e-9*tmr.get().count());
|
ctimes.push_back(1e-9*tmr.get().count());
|
||||||
total_time += ctimes.back();
|
total_time += ctimes.back();
|
||||||
@@ -122,6 +133,8 @@ void profiles::value_type::execute(runtime::execution_handler const & expr)
|
|||||||
}
|
}
|
||||||
size_t i = idx[std::distance(times.begin(),std::min_element(times.begin(), times.end()))];
|
size_t i = idx[std::distance(times.begin(),std::min_element(times.begin(), times.end()))];
|
||||||
labels_.insert({x, i});
|
labels_.insert({x, i});
|
||||||
|
if(modify_output)
|
||||||
|
*out = *bkp;
|
||||||
templates_[i]->enqueue(queue_, program, tools::to_string(i), expr);
|
templates_[i]->enqueue(queue_, program, tools::to_string(i), expr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user