[RUNTIME] Auto-tuning now works as expected when the values of
autotune_key change
This commit is contained in:
@@ -211,6 +211,8 @@ kernel::kernel(const std::string& src, const options_t& opt, driver::device *dev
|
||||
init_ir(preheader() + src);
|
||||
init_ker();
|
||||
init_sig();
|
||||
for(auto arg: ir_->get_function_list()[0]->args())
|
||||
arg_names_.push_back(arg->get_name());
|
||||
}
|
||||
|
||||
void kernel::operator()(void *args, size_t args_size, driver::stream *stream, const std::vector<size_t>& _grid) const{
|
||||
@@ -328,7 +330,11 @@ kernel* function::autotune(void* args, size_t args_size, const grid_fn_ty& grid_
|
||||
if(kernels_.size() == 1)
|
||||
return &*kernels_.begin()->second;
|
||||
// auto-tuning key
|
||||
std::vector<uint64_t> key;
|
||||
std::vector<uint64_t> key(key_idxs_.size());
|
||||
for(size_t i = 0; i < key.size(); i++){
|
||||
int idx = key_idxs_[i];
|
||||
std::memcpy((void*)&key[i], (void*)((char*)args + arg_off_[idx]), arg_size_[idx]);
|
||||
}
|
||||
auto it = cache_.find(key);
|
||||
if(it != cache_.end())
|
||||
return it->second;
|
||||
@@ -350,8 +356,26 @@ kernel* function::autotune(void* args, size_t args_size, const grid_fn_ty& grid_
|
||||
return it->second;
|
||||
}
|
||||
|
||||
function::function(const std::string& src, const options_space_t& opt, driver::device *device) {
|
||||
// Builds a function from source `src` over the option space `opt` for `device`,
// and records which kernel arguments make up the auto-tuning key.
//
// `autotune_key` lists argument names; each one is resolved to its position in
// the argument list of the first kernel and stored in `key_idxs_`.
// Throws std::runtime_error if a name does not match any argument.
// Also fills `arg_size_` / `arg_off_` with the byte size and running byte
// offset of every argument in the first kernel's signature.
function::function(const std::string& src, const options_space_t& opt,
                   driver::device *device, const std::vector<std::string>& autotune_key) {
  // kernels_ is read immediately below, so init_kernels is expected to fill it.
  init_kernels(src, opt, device);

  // resolve every auto-tuning key name to an argument index
  auto names = kernels_.at(0).second->get_arg_names();
  for(const std::string& key_name: autotune_key){
    auto pos = std::find(names.begin(), names.end(), key_name);
    if(pos == names.end())
      throw std::runtime_error(key_name + " is not a valid argument name");
    key_idxs_.push_back(std::distance(names.begin(), pos));
  }

  // per-argument size and offset
  // NOTE(review): offsets are a plain running sum of sizes, i.e. they assume
  // the caller's argument buffer is tightly packed with no alignment padding
  // -- confirm against how that buffer is actually built.
  auto signature = kernels_.at(0).second->get_sig();
  size_t offset = 0;
  for(arg_type ty: signature){
    arg_size_.push_back(size_of(ty));
    arg_off_.push_back(offset);
    offset += arg_size_.back();
  }
}
|
||||
|
||||
void function::operator()(void* args, size_t args_size, const grid_fn_ty& grid_fn, driver::stream *stream) {
|
||||
|
Reference in New Issue
Block a user