diff --git a/bench/blas.cpp b/bench/blas.cpp index 431a7239d..dee5aeb53 100644 --- a/bench/blas.cpp +++ b/bench/blas.cpp @@ -96,7 +96,7 @@ void bench(ad::numeric_type dtype) ad::array x(N, dtype), y(N, dtype); /* ATIDLAS */ y = x + y; queue.flush(); queue.finish(); - BENCHMARK_ATIDLAS(y = ad::controller(x + y, ad::execution_options_type(0, &events)), 3*N*dtsize/t) + BENCHMARK_ATIDLAS(y = ad::control(x + y, ad::execution_options_type(0, &events), ad::dispatcher_options_type(true)), 3*N*dtsize/t) /* clAmdBlas */ #ifdef BENCH_CLAMDBLAS BENCHMARK_CLAMDBLAS(clAmdBlasSaxpy(N, 1, x.data()(), 0, 1, y.data()(), 0, 1, 1, &queue(), 0, NULL, &event()), 3*N*dtsize/t) diff --git a/include/atidlas/model/model.h b/include/atidlas/model/model.h index e3fea979f..1c8d0c0d4 100644 --- a/include/atidlas/model/model.h +++ b/include/atidlas/model/model.h @@ -29,8 +29,6 @@ namespace atidlas model(base const &, cl::CommandQueue &); void execute(controller const &); - void tune(controller const &); - templates_container const & templates() const; private: templates_container templates_; diff --git a/include/atidlas/symbolic/expression.h b/include/atidlas/symbolic/expression.h index 5eae54e62..020edff76 100644 --- a/include/atidlas/symbolic/expression.h +++ b/include/atidlas/symbolic/expression.h @@ -260,7 +260,8 @@ struct execution_options_type struct dispatcher_options_type { - dispatcher_options_type(int _label = -1) : label(_label){} + dispatcher_options_type(bool _tune = false, int _label = -1) : tune(_tune), label(_label){} + bool tune; int label; }; diff --git a/lib/model/model.cpp b/lib/model/model.cpp index 1d139c3f4..d01380818 100644 --- a/lib/model/model.cpp +++ b/lib/model/model.cpp @@ -95,51 +95,41 @@ model::model(base const & tp, cl::CommandQueue & queue) : templates_(1,tp.clone( void model::execute(controller const & expressions) { std::vector & compilers = init(expressions); + std::vector x = templates_[0]->input_sizes(expressions.x()); + + //Specific tuning if requested + if(expressions.dispatcher_options().tune && hardcoded_.find(x)==hardcoded_.end()) + { + std::vector timings(templates_.size()); + tools::timer timer; + for(size_t i = 0 ; i < templates_.size() ; ++i) + { + timer.start(); + templates_[i]->enqueue(queue_, compilers, i, expressions); + queue_.finish(); + timings[i] = timer.get(); + } + //Fill the override + std::vector x = templates_[0]->input_sizes(expressions.x()); + hardcoded_[x] = std::distance(timings.begin(),std::min_element(timings.begin(), timings.end())); + } //Prediction int label = 0; if(expressions.dispatcher_options().label>=0) + label = expressions.dispatcher_options().label; + else if(hardcoded_.find(x)!=hardcoded_.end()) + label = hardcoded_.at(x); + else if(predictor_.get()) { - label = expressions.dispatcher_options().label; - } - else - { - std::vector x = templates_[0]->input_sizes(expressions.x()); - //The user tuned the model specifically for this input size - if(hardcoded_.find(x)!=hardcoded_.end()) - label = hardcoded_.at(x); - //The user bypasses the random forest - else if(predictor_.get()) - { - std::vector predictions = predictor_->predict(x); - label = std::distance(predictions.begin(),std::min_element(predictions.begin(), predictions.end())); - } + std::vector predictions = predictor_->predict(x); + label = std::distance(predictions.begin(),std::min_element(predictions.begin(), predictions.end())); } //Execution return templates_[label]->enqueue(queue_, compilers, label, expressions); } -void model::tune(controller const & expressions) -{ - std::vector & compilers = init(expressions); - - //Collect the timings - std::vector timings(templates_.size()); - tools::timer timer; - for(size_t i = 0 ; i < templates_.size() ; ++i) - { - timer.start(); - templates_[i]->enqueue(queue_, compilers, i, expressions); - queue_.finish(); - timings[i] = timer.get(); - } - - //Fill the override - std::vector x = templates_[0]->input_sizes(expressions.x()); - hardcoded_[x] = std::distance(timings.begin(),std::min_element(timings.begin(), timings.end())); -} - model::templates_container const & model::templates() const { return templates_; }