[dnn] added base template class for mutualized auto-tuning

This commit is contained in:
Philippe Tillet
2019-07-09 16:09:34 -07:00
parent 066ae338f1
commit 88675fa01a
9 changed files with 181 additions and 106 deletions

View File

@@ -99,7 +99,7 @@ public:
FillShapes(context, C, H, W, B, F, tf_a, tf_b);
int64_t D = 1, T = 1;
bool has_bias = false;
// shift configuration
// shift offsets
int32_t* shift_h_data = h_shift_h_.flat<int32_t>().data();
int32_t* shift_w_data = h_shift_w_.flat<int32_t>().data();
std::vector<int32_t> shift_h(shift_h_data, shift_h_data + C);
@@ -116,7 +116,6 @@ public:
.first->second.get();
else
shift = m_config.at(key).get();
// shapes for c
std::vector<int64> c_shapes;
for(int32_t x: shift->c_shapes())
@@ -131,49 +130,7 @@ public:
triton::driver::cu_buffer da(ctx, (CUdeviceptr)tf_a.flat<float>().data(), false);
triton::driver::cu_buffer db(ctx, (CUdeviceptr)tf_b.flat<float>().data(), false);
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat<float>().data(), false);
// get JIT
triton::jit* jit;
bool autotune = false;
if(m_jit.find(key) == m_jit.end()) {
jit = m_jit.emplace(key, new triton::jit(ctx)).first->second.get();
std::ostringstream oss;
shift->src(oss);
std::string src = oss.str();
auto benchmark = [&](triton::driver::kernel* kernel,
triton::jit::launch_information info) {
// launch info
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
shift->init(stream, (triton::driver::cu_module*)kernel->module());
shift->enqueue(stream, kernel, &da, &db, &dc, TM, TN, nthreads);
stream->synchronize();
double ts = triton::tools::bench([&](){ shift->enqueue(stream, kernel, &da, &db, &dc, TM, TN, nthreads); },
[&](){ stream->synchronize(); }, ctx->device());
return shift->get_nflops() / ts * 1e-3;
};
// auto-tune and save result
if(autotune) {
triton::jit::tune_res_t best = jit->autotune("shift", src.c_str(), benchmark);
jit->add_module("shift", src.c_str(), best.params);
}
else {
jit->add_module("shift", src.c_str(), jit->get_valid("shift", src.c_str()));
}
triton::driver::kernel* kernel = jit->get_function("shift");
shift->init(stream, (triton::driver::cu_module*)kernel->module());
}
else
jit = m_jit.at(key).get();
// Run
triton::driver::kernel* kernel = jit->get_function("shift");
triton::jit::launch_information info = jit->get_launch_info("shift");
// launch info
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
// enqueue
shift->enqueue(stream, kernel, &da, &db, &dc, TM, TN, nthreads);
shift->enqueue(stream, {&da, &db, &dc});
}
private: