[RUNTIME] Auto-tuning now works as expected when the values of
autotune_key change
This commit is contained in:
@@ -82,6 +82,7 @@ public:
|
||||
void operator()(void* args, size_t args_size, driver::stream *stream, const std::vector<size_t>& grid) const;
|
||||
// getters
|
||||
const std::vector<arg_type>& get_sig() const { return sig_; }
|
||||
const std::vector<std::string>& get_arg_names() const { return arg_names_; }
|
||||
std::string get_asm(asm_mode_t mode);
|
||||
|
||||
private:
|
||||
@@ -96,6 +97,7 @@ private:
|
||||
driver::device* dev_;
|
||||
// signature
|
||||
std::vector<arg_type> sig_;
|
||||
std::vector<std::string> arg_names_;
|
||||
// triton context for parsing
|
||||
ir::context ctx_;
|
||||
// handles
|
||||
@@ -114,7 +116,7 @@ private:
|
||||
static void do_loop_nest(std::vector<size_t> const & ranges,
|
||||
std::function<void(std::vector<size_t> const &)> const & f);
|
||||
public:
|
||||
function(const std::string& src, const options_space_t& opt, driver::device *device);
|
||||
function(const std::string& src, const options_space_t& opt, driver::device *device, const std::vector<std::string> &autotune_key = {});
|
||||
void operator()(void* args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream);
|
||||
void operator()(void* args, size_t args_size, const grid_t& grid, driver::stream *stream);
|
||||
// auto-tuning
|
||||
@@ -129,6 +131,9 @@ private:
|
||||
private:
|
||||
std::vector<kernel_pair_t> kernels_;
|
||||
std::map<std::vector<uint64_t>, kernel*> cache_;
|
||||
std::vector<int> key_idxs_;
|
||||
std::vector<int> arg_size_;
|
||||
std::vector<int> arg_off_;
|
||||
};
|
||||
|
||||
}
|
||||
|
Reference in New Issue
Block a user