[RUNTIME] Auto-tuning now works as expected when the values of

autotune_key change
This commit is contained in:
Philippe Tillet
2021-01-31 14:17:27 -05:00
parent 52af8cda34
commit 3fde4b8f5b
7 changed files with 53 additions and 15 deletions

View File

@@ -82,6 +82,7 @@ public:
void operator()(void* args, size_t args_size, driver::stream *stream, const std::vector<size_t>& grid) const;
// getters
const std::vector<arg_type>& get_sig() const { return sig_; }
const std::vector<std::string>& get_arg_names() const { return arg_names_; }
std::string get_asm(asm_mode_t mode);
private:
@@ -96,6 +97,7 @@ private:
driver::device* dev_;
// signature
std::vector<arg_type> sig_;
std::vector<std::string> arg_names_;
// triton context for parsing
ir::context ctx_;
// handles
@@ -114,7 +116,7 @@ private:
static void do_loop_nest(std::vector<size_t> const & ranges,
std::function<void(std::vector<size_t> const &)> const & f);
public:
function(const std::string& src, const options_space_t& opt, driver::device *device);
function(const std::string& src, const options_space_t& opt, driver::device *device, const std::vector<std::string> &autotune_key = {});
void operator()(void* args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream);
void operator()(void* args, size_t args_size, const grid_t& grid, driver::stream *stream);
// auto-tuning
@@ -129,6 +131,9 @@ private:
private:
std::vector<kernel_pair_t> kernels_;
std::map<std::vector<uint64_t>, kernel*> cache_;
std::vector<int> key_idxs_;
std::vector<int> arg_size_;
std::vector<int> arg_off_;
};
}