[dnn]: Now implementing all existing DNN routines using common base template and auto-tuner

This commit is contained in:
Philippe Tillet
2019-07-09 17:30:58 -07:00
parent 88675fa01a
commit b7986baffa
19 changed files with 388 additions and 375 deletions

View File

@@ -59,15 +59,7 @@ public:
triton::driver::cu_buffer tv(ctx, (CUdeviceptr)v->flat<float>().data(), false);
// create config
triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "fp32");
std::ostringstream oss;
batchnorm.src(oss);
std::string src = oss.str();
triton::jit jit(ctx);
jit.add_module("batchnorm", src.c_str(), jit.get_valid("batchnorm", src.c_str()));
triton::driver::kernel* kernel = jit.get_function("batchnorm");
size_t TM = jit.get_int("TM");
triton::jit::launch_information info = jit.get_launch_info("batchnorm");
batchnorm.enqueue(stream, kernel, &ty, &tm, &tv, &tx, &tg, &tb, TM, info.num_threads);
batchnorm.enqueue(stream, {&ty, &tm, &tv, &tx, &tg, &tb});
}
private:
@@ -135,15 +127,7 @@ public:
triton::driver::cu_buffer tdb(ctx, (CUdeviceptr)db->flat<float>().data(), false);
// create config
triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "fp32");
std::ostringstream oss;
batchnorm.src(oss);
std::string src = oss.str();
triton::jit jit(ctx);
jit.add_module("batchnorm", src.c_str(), jit.get_valid("batchnorm", src.c_str()));
triton::driver::kernel* kernel = jit.get_function("batchnorm");
size_t TM = jit.get_int("TM");
triton::jit::launch_information info = jit.get_launch_info("batchnorm");
batchnorm.enqueue(stream, kernel, &tdx, &tdg, &tdb, &tdy, &tx, &tg, &tm, &tv, TM, info.num_threads);
batchnorm.enqueue(stream, {&tdx, &tdg, &tdb, &tdy, &tx, &tg, &tm, &tv});
}
private: