[dnn] added base template class for mutualized auto-tuning
This commit is contained in:
@@ -14,8 +14,6 @@ int main() {
|
||||
|
||||
// initialize default compute device
|
||||
auto context = triton::driver::backend::contexts::get_default();
|
||||
// initialize just-in-time compiler
|
||||
triton::jit jit(context);
|
||||
|
||||
// initialization
|
||||
int32_t R = 3, S = 3;
|
||||
@@ -31,7 +29,7 @@ int main() {
|
||||
shift_w[c] = rand() % S - S/2;
|
||||
}
|
||||
// configuration
|
||||
triton::dnn::shift shift(BS, C, 1, H, W, 1, R, S, F, shift_h, shift_w, numeric_t_str, numeric_t_str, triton::dnn::shift::BPROP);
|
||||
triton::dnn::shift shift(BS, C, 1, H, W, 1, R, S, F, 1, 1, shift_h, shift_w, numeric_t_str, numeric_t_str, triton::dnn::shift::BPROP);
|
||||
// host buffers
|
||||
std::vector<float> hc(shift.c_size());
|
||||
std::vector<float> rc(shift.c_size());
|
||||
@@ -55,35 +53,7 @@ int main() {
|
||||
stream->write(db, true, 0, hb);
|
||||
stream->write(dc, true, 0, hc);
|
||||
stream->synchronize();
|
||||
// benchmark
|
||||
auto benchmark = [&](triton::driver::kernel* kernel,
|
||||
triton::jit::launch_information info) {
|
||||
shift.init(stream, (triton::driver::cu_module*)kernel->module());
|
||||
// launch infoRR
|
||||
unsigned TM = info.global_range_size[0];
|
||||
unsigned TN = info.global_range_size[1];
|
||||
unsigned nthreads = info.num_threads;
|
||||
// set argument
|
||||
shift.enqueue(stream, kernel, da, db, dc, TM, TN, nthreads);
|
||||
stream->synchronize();
|
||||
// benchmark
|
||||
double ts = triton::tools::bench([&](){shift.enqueue(stream, kernel, da, db, dc, TM, TN, nthreads);},
|
||||
[&](){ stream->synchronize(); }, context->device());
|
||||
return shift.get_nflops() / ts * 1e-3;
|
||||
};
|
||||
|
||||
// shift
|
||||
std::vector<unsigned> params = {
|
||||
4, 2, 16, 4, 128, 2, 2, 1, 1, 8, 16, 8, 2
|
||||
};
|
||||
std::ostringstream oss;
|
||||
shift.src(oss);
|
||||
std::string src = oss.str();
|
||||
// jit.autotune("shift", src.c_str(), benchmark);
|
||||
jit.add_module("shift", src.c_str(), params);
|
||||
triton::driver::kernel* kernel = jit.get_function("shift");
|
||||
triton::jit::launch_information info = jit.get_launch_info("shift");
|
||||
std::cout << "Performance: " << benchmark(kernel, info) << " TFLOPS " << std::endl;
|
||||
shift.enqueue(stream, da, db, dc);
|
||||
// stream->read(dc, true, 0, hc);
|
||||
// shift.cpu_ref(rc.data(), ha.data(), hb.data());
|
||||
// for(size_t i = 0; i < hc.size(); i++)
|
||||
|
Reference in New Issue
Block a user