[dnn] added base template class for mutualized auto-tuning

This commit is contained in:
Philippe Tillet
2019-07-09 16:09:34 -07:00
parent 066ae338f1
commit 88675fa01a
9 changed files with 181 additions and 106 deletions

View File

@@ -72,18 +72,18 @@ torch::Tensor shift_common(
if(m_shift_jit.find(key) == m_shift_jit.end()){
jit = m_shift_jit.emplace(key, new triton::jit(ctx)).first->second.get();
std::ostringstream oss;
configuration->src(oss);
configuration->get_src(oss);
std::string src = oss.str();
// benchmark a given shiftolution kernel
auto benchmark = [&](triton::driver::kernel* kernel,
triton::jit::launch_information info) {
configuration->init(stream, (triton::driver::cu_module*)kernel->module());
configuration->init_impl(stream, (triton::driver::cu_module*)kernel->module());
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
configuration->enqueue(stream, kernel, &a, &b, &c, TM, TN, nthreads);
configuration->enqueue_impl(stream, kernel, &a, &b, &c, TM, TN, nthreads);
stream->synchronize();
double ts = triton::tools::bench([&](){ configuration->enqueue(stream, kernel, &a, &b, &c, TM, TN, nthreads); },
double ts = triton::tools::bench([&](){ configuration->enqueue_impl(stream, kernel, &a, &b, &c, TM, TN, nthreads); },
[&](){ stream->synchronize(); }, stream->context()->device());
return configuration->get_nflops() / ts * 1e-3;
};
@@ -96,7 +96,7 @@ torch::Tensor shift_common(
jit->add_module("shift", src.c_str(), jit->get_valid("shift", src.c_str()));
}
triton::driver::kernel* kernel = jit->get_function("shift");
configuration->init(stream, (triton::driver::cu_module*)kernel->module());
configuration->init_impl(stream, (triton::driver::cu_module*)kernel->module());
}
else
jit = m_shift_jit.at(key).get();
@@ -109,6 +109,6 @@ torch::Tensor shift_common(
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
// enqueue
configuration->enqueue(stream, kernel, &a, &b, &c, TM, TN, nthreads);
configuration->enqueue_impl(stream, kernel, &a, &b, &c, TM, TN, nthreads);
return torchc;
}