[RUNTIME] Auto-tuning now works as expected when the values of autotune_key change
This commit is contained in:
Philippe Tillet
2021-01-31 14:17:27 -05:00
parent 52af8cda34
commit 3fde4b8f5b
7 changed files with 53 additions and 15 deletions

View File

@@ -211,6 +211,8 @@ kernel::kernel(const std::string& src, const options_t& opt, driver::device *dev
init_ir(preheader() + src);
init_ker();
init_sig();
for(auto arg: ir_->get_function_list()[0]->args())
arg_names_.push_back(arg->get_name());
}
void kernel::operator()(void *args, size_t args_size, driver::stream *stream, const std::vector<size_t>& _grid) const{
@@ -328,7 +330,11 @@ kernel* function::autotune(void* args, size_t args_size, const grid_fn_ty& grid_
if(kernels_.size() == 1)
return &*kernels_.begin()->second;
// auto-tuning key
std::vector<uint64_t> key;
std::vector<uint64_t> key(key_idxs_.size());
for(size_t i = 0; i < key.size(); i++){
int idx = key_idxs_[i];
std::memcpy((void*)&key[i], (void*)((char*)args + arg_off_[idx]), arg_size_[idx]);
}
auto it = cache_.find(key);
if(it != cache_.end())
return it->second;
@@ -350,8 +356,26 @@ kernel* function::autotune(void* args, size_t args_size, const grid_fn_ty& grid_
return it->second;
}
function::function(const std::string& src, const options_space_t& opt, driver::device *device) {
// Constructs the function from source `src` and option space `opt` on `device`,
// then records the bookkeeping that autotune() reads when it builds a cache key:
//   - key_idxs_ : position, in the kernel's argument list, of every name in
//                 `autotune_key`
//   - arg_size_ : size (as reported by size_of) of each argument in the signature
//   - arg_off_  : running sum of the preceding sizes, i.e. arguments are treated
//                 as laid out back-to-back with no padding between them
// Throws std::runtime_error if a name in `autotune_key` is not an argument of
// the first kernel.
function::function(const std::string& src, const options_space_t& opt,
                   driver::device *device, const std::vector<std::string>& autotune_key) {
  init_kernels(src, opt, device);
  // translate auto-tuning key names into argument indices
  auto names = kernels_.at(0).second->get_arg_names();
  for(size_t k = 0; k < autotune_key.size(); k++){
    const std::string& wanted = autotune_key[k];
    auto pos = std::find(names.begin(), names.end(), wanted);
    if(pos == names.end())
      throw std::runtime_error(wanted + " is not a valid argument name");
    key_idxs_.push_back(std::distance(names.begin(), pos));
  }
  // per-argument size and cumulative offset, in signature order
  auto sig = kernels_.at(0).second->get_sig();
  size_t offset = 0;
  for(auto sig_it = sig.begin(); sig_it != sig.end(); ++sig_it){
    arg_type ty = *sig_it;
    arg_size_.push_back(size_of(ty));
    arg_off_.push_back(offset);
    // advance by the value actually stored, exactly as the size was recorded
    offset += arg_size_.back();
  }
}
void function::operator()(void* args, size_t args_size, const grid_fn_ty& grid_fn, driver::stream *stream) {