[CORE] Auto-tuning now copies scalar buffers. Still needs to copy all buffers that are both read from and written to.
This commit is contained in:
committed by
Philippe Tillet
parent
78cd54b0c8
commit
5995cbff8e
@@ -347,21 +347,24 @@ function::cache_key_t function::get_key(driver::stream *stream, const std::vecto
|
||||
// returns program with best compilation options for given parameter
|
||||
function::caller* function::autotune(driver::stream* stream, const grid_fn_ty& grid_fn,
|
||||
const std::vector<arg>& args) {
|
||||
// // copy buffer argument so that auto-tuning doesn't corrupt data
|
||||
// std::list<std::shared_ptr<driver::cu_buffer>> copies;
|
||||
// std::vector<arg> args = args;
|
||||
// for(arg& x: args)
|
||||
// if(x.type() == BUFFER_T){
|
||||
// driver::buffer* old = x.buffer();
|
||||
// driver::context* ctx = old->context();
|
||||
// size_t size = old->size();
|
||||
// copies.push_back(std::make_shared<driver::cu_buffer>(ctx, size));
|
||||
// x = arg(copies.back().get());
|
||||
// }
|
||||
// fast path -- no autotuning necessary
|
||||
if(callers_.size() == 1)
|
||||
return &*callers_.begin()->second;
|
||||
// slow path -- autotuning necessary
|
||||
// copy buffer argument so that auto-tuning doesn't corrupt data
|
||||
std::list<std::shared_ptr<driver::cu_buffer>> copies;
|
||||
std::vector<arg> _args = args;
|
||||
for(size_t i = 0; i < args.size(); i++)
|
||||
if(_args[i].type() == BUFFER_T){
|
||||
driver::buffer* old = _args[i].buffer();
|
||||
size_t size = old->size();
|
||||
// only copy scalars
|
||||
// TODO: change that
|
||||
if(size != 4 && size != 2)
|
||||
continue;
|
||||
copies.push_back(std::make_shared<driver::cu_buffer>(old->context(), size));
|
||||
_args[i] = arg(copies.back().get());
|
||||
}
|
||||
double best_ts = INFINITY;
|
||||
caller* ret = nullptr;
|
||||
for(auto &x : callers_){
|
||||
@@ -373,6 +376,7 @@ function::caller* function::autotune(driver::stream* stream, const grid_fn_ty& g
|
||||
ret = (ts < best_ts) ? current : ret;
|
||||
best_ts = std::min(ts, best_ts);
|
||||
}
|
||||
stream->synchronize();
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user