Improvements w/ Auto-Tuning and standard benchmarks (#57)

[PYTHON] Bug-fixes in the auto-tuning module and improvement of the existing API for it
This commit is contained in:
Philippe Tillet
2021-02-03 13:37:21 -08:00
committed by Philippe Tillet
parent ad005d49ac
commit 6fb4800f57
12 changed files with 215 additions and 149 deletions

View File

@@ -45,7 +45,8 @@ void delete_grid(const map_key_t& key) {
void register_fn(int op_id,
int dev_id,
const std::string& src,
const rt::options_space_t& opt,
const rt::options_t& opt,
const rt::function::autotune_vals_t& autotune_vals,
const std::vector<std::string>& autotune_key) {
if(tt_devices.find(dev_id) == tt_devices.end()) {
driver::device* device;
@@ -62,7 +63,7 @@ void register_fn(int op_id,
tt_streams[dev_id].reset(stream);
}
if(id_fn_map.find(op_id) == id_fn_map.end()){
id_fn_map[op_id].reset(new rt::function(src, opt, &*tt_devices[dev_id], autotune_key));
id_fn_map[op_id].reset(new rt::function(src, opt, &*tt_devices[dev_id], autotune_vals, autotune_key));
}
for(const auto& k: id_fn_map[op_id]->get_kernels()){
const rt::options_t* opt = &k.first;
@@ -197,13 +198,9 @@ PYBIND11_MODULE(libtriton, m) {
.value("sass", rt::ASM_NV_SASS);
pybind11::class_<rt::options_t>(m, "options", pybind11::dynamic_attr())
.def_readwrite("num_warps", &rt::options_t::num_warps)
.def_readwrite("defines" , &rt::options_t::defines);
pybind11::class_<rt::options_space_t>(m, "options_space")
.def(pybind11::init<>())
.def_readwrite("num_warps", &rt::options_space_t::num_warps)
.def_readwrite("defines" , &rt::options_space_t::defines);
.def_readwrite("defines" , &rt::options_t::defines)
.def_readwrite("num_warps", &rt::options_t::num_warps);
// hooks into triton constructs since frameworks may not use pybind11
m.def("extract_kernels", &extract_kernels);