Improvements w/ Auto-Tuning and standard benchmarks (#57)

[PYTHON] Bug-fixes in the auto-tuning module and improvement of the existing API for it
This commit is contained in:
Philippe Tillet
2021-02-03 13:37:21 -08:00
committed by Philippe Tillet
parent ad005d49ac
commit 6fb4800f57
12 changed files with 215 additions and 149 deletions

View File

@@ -54,19 +54,13 @@ enum asm_mode_t {
ASM_NV_SASS
};
// Holds lists of candidate option values: several string values per define
// and several warp counts.
// NOTE(review): presumably the search space explored by the auto-tuner — confirm against callers.
struct options_space_t {
// One define: its name paired with the list of string values it may take.
typedef std::pair<std::string, std::vector<std::string>> define_t;
// All defines, each with its list of candidate values.
std::vector<define_t> defines;
// Candidate values for the warp count.
std::vector<int> num_warps;
};
// One concrete set of options: a single string value per define plus a warp count.
struct options_t {
// Returns the value stored under `name` in `defines`, passed through
// convert<T> (declared outside this excerpt).
// `defines.at(name)` throws std::out_of_range when `name` is not present.
template<class T>
T D(const std::string& name) const {
return convert<T>(defines.at(name));
}
// Maps a define name to its value; both are kept as strings.
std::unordered_map<std::string, std::string> defines;
// NOTE(review): `num_warps` is declared twice below (size_t, then int). This
// text looks like a diff rendering that shows the old and the new declaration
// together — confirm which one the real header keeps; both cannot coexist.
size_t num_warps;
int num_warps;
};
@@ -111,12 +105,14 @@ public:
// Callable that receives an options_t and returns a grid_t.
typedef std::function<grid_t(const options_t&)> grid_fn_ty;
// An options_t paired with a shared pointer to a kernel.
typedef std::pair<options_t, std::shared_ptr<kernel>> kernel_pair_t;
// Maps a key made of 64-bit unsigned integers to a raw kernel pointer.
typedef std::map<std::vector<uint64_t>, kernel*> cache_t;
// List of (string -> string map, int) pairs.
// NOTE(review): presumably each entry is (defines, num_warps) for one auto-tuning candidate — confirm.
typedef std::vector<std::pair<std::map<std::string, std::string>, int>> autotune_vals_t;
private:
// Static helper taking a list of sizes (`ranges`) and a callback `f` that
// receives a vector of size_t.
// NOTE(review): the body is not part of this excerpt; presumably `f` is called
// once per index combination of the nested ranges — confirm in the definition.
static void do_loop_nest(std::vector<size_t> const & ranges,
std::function<void(std::vector<size_t> const &)> const & f);
public:
// Constructors: build a function from source text `src`, options, and a device.
// NOTE(review): two signatures are listed — one taking options_space_t, one
// taking options_t plus autotune_vals_t. This looks like the old and new
// declaration from a diff shown together; confirm which one the real header keeps.
// `autotune_key` defaults to an empty list in both; `autotune_vals` defaults to
// an empty list in the second.
function(const std::string& src, const options_space_t& opt, driver::device *device, const std::vector<std::string> &autotune_key = {});
function(const std::string& src, const options_t& opt, driver::device *device,
const autotune_vals_t& autotune_vals = {}, const std::vector<std::string> &autotune_key = {});
// Call operators taking an opaque argument buffer (`args`, `args_size` bytes
// presumably — confirm), a grid, and a stream.
// The first overload takes the grid as a callable computed from an options_t
// (grid_fn_ty); the second takes a grid_t value directly.
void operator()(void* args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream);
void operator()(void* args, size_t args_size, const grid_t& grid, driver::stream *stream);
// auto-tuning
@@ -126,7 +122,7 @@ public:
// Returns a copy of the stored (options, kernel) pairs; the return type is a
// value, not a reference, so callers do not observe later changes to kernels_.
const std::vector<kernel_pair_t> get_kernels() { return kernels_; }
private:
// Private helper declared with the same inputs as the constructors (source,
// options, device), minus `autotune_key`.
// NOTE(review): two signatures are listed (options_space_t vs. options_t +
// autotune_vals_t); this looks like the old and new declaration from a diff
// shown together — confirm which one the real header keeps.
// NOTE(review): presumably fills kernels_ — the definition is not in this excerpt.
void init_kernels(const std::string& src, const options_space_t& opt, driver::device *device);
void init_kernels(const std::string& src, const options_t& opt, const autotune_vals_t& autotune_vals, driver::device *device);
private:
std::vector<kernel_pair_t> kernels_;