diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index 406dd89a4..fb2daea43 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -347,21 +347,24 @@ function::cache_key_t function::get_key(driver::stream *stream, const std::vecto // returns program with best compilation options for given parameter function::caller* function::autotune(driver::stream* stream, const grid_fn_ty& grid_fn, const std::vector& args) { -// // copy buffer argument so that auto-tuning doesn't corrupt data -// std::list> copies; -// std::vector args = args; -// for(arg& x: args) -// if(x.type() == BUFFER_T){ -// driver::buffer* old = x.buffer(); -// driver::context* ctx = old->context(); -// size_t size = old->size(); -// copies.push_back(std::make_shared(ctx, size)); -// x = arg(copies.back().get()); -// } // fast path -- no autotuning necessary if(callers_.size() == 1) return &*callers_.begin()->second; // slow path -- autotuning necessary + // copy buffer argument so that auto-tuning doesn't corrupt data + std::list> copies; + std::vector _args = args; + for(size_t i = 0; i < args.size(); i++) + if(_args[i].type() == BUFFER_T){ + driver::buffer* old = _args[i].buffer(); + size_t size = old->size(); + // only copy scalars + // TODO: change that + if(size != 4 && size != 2) + continue; + copies.push_back(std::make_shared(old->context(), size)); + _args[i] = arg(copies.back().get()); + } double best_ts = INFINITY; caller* ret = nullptr; for(auto &x : callers_){ @@ -373,6 +376,7 @@ function::caller* function::autotune(driver::stream* stream, const grid_fn_ty& g ret = (ts < best_ts) ? current : ret; best_ts = std::min(ts, best_ts); } + stream->synchronize(); return ret; } diff --git a/python/src/bindings.cc b/python/src/bindings.cc index e36760316..ec4c1c6a1 100644 --- a/python/src/bindings.cc +++ b/python/src/bindings.cc @@ -194,7 +194,7 @@ void gen_make_handles(std::ostream &os, const std::vector& args) if(!arg->get_type()->is_pointer_ty()) continue; const std::string& name = arg->get_name(); - os << " drv::cu_buffer cu_" + name + "(ctx, " + name + "->tensor_data().size(), (CUdeviceptr)" + name + "->tensor_data().data(), false);\n "; + os << " drv::cu_buffer cu_" + name + "(ctx, " + name + "->nbytes(), (CUdeviceptr)" + name + "->tensor_data().data(), false);\n "; } } @@ -524,7 +524,7 @@ void gen_torch_make_handles(std::ostream &os, os << " " << to_c_ty(arg) << " " << name << " = " << th_name << ";" << std::endl; else{ os << " CHECK_INPUT(" << th_name << ");" << std::endl; - os << " drv::cu_buffer " + name + "(ctx, " + th_name + ".storage().size(), " + os << " drv::cu_buffer " + name + "(ctx, " + th_name + ".nbytes(), " " (CUdeviceptr)((char*)" + th_name + ".storage().data() + " + th_name + ".storage_offset() * " + th_name + ".itemsize()), false);" << std::endl; } } @@ -561,16 +561,7 @@ void gen_torch_make_launch_function(std::ostream &os, os << "args.push_back(rt::arg(ty" << i << ", val" << i << "));\n "; } os << " std::function run = [&](){\n "; - os << " (*id_fn_map.at({id, dev_id}))({"; - for(unsigned i = 0; i < args.size() ; i++){ - std::string name = "arg_" + std::to_string(i); - if(args[i] == rt::BUFFER_T) - name = "&" + name; - if(i > 0) - os << ", "; - os << name; - } - os << "}, *id_grid_map.at({id, dev_id}), &stream);\n"; + os << " (*id_fn_map.at({id, dev_id}))(args , *id_grid_map.at({id, dev_id}), &stream);\n"; os << " };\n"; os << " run();\n"; os << " if(bench > 0)\n ";