[dnn] added base template class for mutualized auto-tuning
This commit is contained in:
@@ -72,18 +72,18 @@ torch::Tensor shift_common(
|
||||
if(m_shift_jit.find(key) == m_shift_jit.end()){
|
||||
jit = m_shift_jit.emplace(key, new triton::jit(ctx)).first->second.get();
|
||||
std::ostringstream oss;
|
||||
configuration->src(oss);
|
||||
configuration->get_src(oss);
|
||||
std::string src = oss.str();
|
||||
// benchmark a given shiftolution kernel
|
||||
auto benchmark = [&](triton::driver::kernel* kernel,
|
||||
triton::jit::launch_information info) {
|
||||
configuration->init(stream, (triton::driver::cu_module*)kernel->module());
|
||||
configuration->init_impl(stream, (triton::driver::cu_module*)kernel->module());
|
||||
unsigned TM = info.global_range_size[0];
|
||||
unsigned TN = info.global_range_size[1];
|
||||
unsigned nthreads = info.num_threads;
|
||||
configuration->enqueue(stream, kernel, &a, &b, &c, TM, TN, nthreads);
|
||||
configuration->enqueue_impl(stream, kernel, &a, &b, &c, TM, TN, nthreads);
|
||||
stream->synchronize();
|
||||
double ts = triton::tools::bench([&](){ configuration->enqueue(stream, kernel, &a, &b, &c, TM, TN, nthreads); },
|
||||
double ts = triton::tools::bench([&](){ configuration->enqueue_impl(stream, kernel, &a, &b, &c, TM, TN, nthreads); },
|
||||
[&](){ stream->synchronize(); }, stream->context()->device());
|
||||
return configuration->get_nflops() / ts * 1e-3;
|
||||
};
|
||||
@@ -96,7 +96,7 @@ torch::Tensor shift_common(
|
||||
jit->add_module("shift", src.c_str(), jit->get_valid("shift", src.c_str()));
|
||||
}
|
||||
triton::driver::kernel* kernel = jit->get_function("shift");
|
||||
configuration->init(stream, (triton::driver::cu_module*)kernel->module());
|
||||
configuration->init_impl(stream, (triton::driver::cu_module*)kernel->module());
|
||||
}
|
||||
else
|
||||
jit = m_shift_jit.at(key).get();
|
||||
@@ -109,6 +109,6 @@ torch::Tensor shift_common(
|
||||
unsigned TN = info.global_range_size[1];
|
||||
unsigned nthreads = info.num_threads;
|
||||
// enqueue
|
||||
configuration->enqueue(stream, kernel, &a, &b, &c, TM, TN, nthreads);
|
||||
configuration->enqueue_impl(stream, kernel, &a, &b, &c, TM, TN, nthreads);
|
||||
return torchc;
|
||||
}
|
||||
|
@@ -133,7 +133,6 @@ public:
|
||||
triton::driver::cu_buffer tdx(ctx, (CUdeviceptr)dx->flat<float>().data(), false);
|
||||
triton::driver::cu_buffer tdg(ctx, (CUdeviceptr)dg->flat<float>().data(), false);
|
||||
triton::driver::cu_buffer tdb(ctx, (CUdeviceptr)db->flat<float>().data(), false);
|
||||
|
||||
// create config
|
||||
triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "fp32");
|
||||
std::ostringstream oss;
|
||||
|
@@ -128,3 +128,4 @@ def run_batchnorm():
|
||||
print(np.max(np.abs(db_t - db_n)))
|
||||
|
||||
run_shift()
|
||||
#run_batchnorm()
|
||||
|
@@ -99,7 +99,7 @@ public:
|
||||
FillShapes(context, C, H, W, B, F, tf_a, tf_b);
|
||||
int64_t D = 1, T = 1;
|
||||
bool has_bias = false;
|
||||
// shift configuration
|
||||
// shift offsets
|
||||
int32_t* shift_h_data = h_shift_h_.flat<int32_t>().data();
|
||||
int32_t* shift_w_data = h_shift_w_.flat<int32_t>().data();
|
||||
std::vector<int32_t> shift_h(shift_h_data, shift_h_data + C);
|
||||
@@ -116,7 +116,6 @@ public:
|
||||
.first->second.get();
|
||||
else
|
||||
shift = m_config.at(key).get();
|
||||
|
||||
// shapes for c
|
||||
std::vector<int64> c_shapes;
|
||||
for(int32_t x: shift->c_shapes())
|
||||
@@ -131,49 +130,7 @@ public:
|
||||
triton::driver::cu_buffer da(ctx, (CUdeviceptr)tf_a.flat<float>().data(), false);
|
||||
triton::driver::cu_buffer db(ctx, (CUdeviceptr)tf_b.flat<float>().data(), false);
|
||||
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat<float>().data(), false);
|
||||
// get JIT
|
||||
triton::jit* jit;
|
||||
bool autotune = false;
|
||||
if(m_jit.find(key) == m_jit.end()) {
|
||||
jit = m_jit.emplace(key, new triton::jit(ctx)).first->second.get();
|
||||
std::ostringstream oss;
|
||||
shift->src(oss);
|
||||
std::string src = oss.str();
|
||||
auto benchmark = [&](triton::driver::kernel* kernel,
|
||||
triton::jit::launch_information info) {
|
||||
// launch info
|
||||
unsigned TM = info.global_range_size[0];
|
||||
unsigned TN = info.global_range_size[1];
|
||||
unsigned nthreads = info.num_threads;
|
||||
shift->init(stream, (triton::driver::cu_module*)kernel->module());
|
||||
shift->enqueue(stream, kernel, &da, &db, &dc, TM, TN, nthreads);
|
||||
stream->synchronize();
|
||||
double ts = triton::tools::bench([&](){ shift->enqueue(stream, kernel, &da, &db, &dc, TM, TN, nthreads); },
|
||||
[&](){ stream->synchronize(); }, ctx->device());
|
||||
return shift->get_nflops() / ts * 1e-3;
|
||||
};
|
||||
// auto-tune and save result
|
||||
if(autotune) {
|
||||
triton::jit::tune_res_t best = jit->autotune("shift", src.c_str(), benchmark);
|
||||
jit->add_module("shift", src.c_str(), best.params);
|
||||
}
|
||||
else {
|
||||
jit->add_module("shift", src.c_str(), jit->get_valid("shift", src.c_str()));
|
||||
}
|
||||
triton::driver::kernel* kernel = jit->get_function("shift");
|
||||
shift->init(stream, (triton::driver::cu_module*)kernel->module());
|
||||
}
|
||||
else
|
||||
jit = m_jit.at(key).get();
|
||||
// Run
|
||||
triton::driver::kernel* kernel = jit->get_function("shift");
|
||||
triton::jit::launch_information info = jit->get_launch_info("shift");
|
||||
// launch info
|
||||
unsigned TM = info.global_range_size[0];
|
||||
unsigned TN = info.global_range_size[1];
|
||||
unsigned nthreads = info.num_threads;
|
||||
// enqueue
|
||||
shift->enqueue(stream, kernel, &da, &db, &dc, TM, TN, nthreads);
|
||||
shift->enqueue(stream, {&da, &db, &dc});
|
||||
}
|
||||
|
||||
private:
|
||||
|
Reference in New Issue
Block a user