[dnn]: Now implementing all existing DNN routines using common base template and auto-tuner
@@ -59,15 +59,7 @@ public:
     triton::driver::cu_buffer tv(ctx, (CUdeviceptr)v->flat<float>().data(), false);
     // create config
     triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "fp32");
-    std::ostringstream oss;
-    batchnorm.src(oss);
-    std::string src = oss.str();
-    triton::jit jit(ctx);
-    jit.add_module("batchnorm", src.c_str(), jit.get_valid("batchnorm", src.c_str()));
-    triton::driver::kernel* kernel = jit.get_function("batchnorm");
-    size_t TM = jit.get_int("TM");
-    triton::jit::launch_information info = jit.get_launch_info("batchnorm");
-    batchnorm.enqueue(stream, kernel, &ty, &tm, &tv, &tx, &tg, &tb, TM, info.num_threads);
+    batchnorm.enqueue(stream, {&ty, &tm, &tv, &tx, &tg, &tb});
   }
 
 private:
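The removed lines above fetched the tile size with jit.get_int("TM") and threaded it through enqueue by hand; with the auto-tuner named in the commit message, that choice moves behind the new enqueue overload. A minimal, hypothetical sketch of such a tuning step (the candidate list and the benchmark callback are illustrative, not the actual triton::jit tuner interface):

// Hypothetical sketch: pick the fastest tile size from a candidate list.
// The candidate values and the benchmark callback are illustrative; they do
// not mirror the real triton::jit auto-tuner interface.
#include <functional>
#include <limits>
#include <vector>

inline int autotune_tile(const std::vector<int>& candidates,
                         const std::function<double(int)>& benchmark) {
  int best = candidates.front();
  double best_time = std::numeric_limits<double>::max();
  for (int tm : candidates) {
    double t = benchmark(tm);   // e.g. compile and time the kernel with TM = tm
    if (t < best_time) {
      best_time = t;
      best = tm;
    }
  }
  return best;
}

The call site then no longer needs to know which TM the tuner settled on, which is exactly what the single `+` line above relies on.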
@@ -135,15 +127,7 @@ public:
     triton::driver::cu_buffer tdb(ctx, (CUdeviceptr)db->flat<float>().data(), false);
     // create config
     triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "fp32");
-    std::ostringstream oss;
-    batchnorm.src(oss);
-    std::string src = oss.str();
-    triton::jit jit(ctx);
-    jit.add_module("batchnorm", src.c_str(), jit.get_valid("batchnorm", src.c_str()));
-    triton::driver::kernel* kernel = jit.get_function("batchnorm");
-    size_t TM = jit.get_int("TM");
-    triton::jit::launch_information info = jit.get_launch_info("batchnorm");
-    batchnorm.enqueue(stream, kernel, &tdx, &tdg, &tdb, &tdy, &tx, &tg, &tm, &tv, TM, info.num_threads);
+    batchnorm.enqueue(stream, {&tdx, &tdg, &tdb, &tdy, &tx, &tg, &tm, &tv});
   }
 
 private:
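Both hunks collapse the same boilerplate: emitting Triton-C source into an std::ostringstream, building a triton::jit, looking up the kernel, TM and launch information, and forwarding all of it to enqueue. A common base template can own that sequence so each op only provides its source and a cache key. The sketch below is illustrative only; op_base, compile_and_tune and launch are hypothetical names, not the actual triton::dnn base class:

// Hypothetical sketch only: op_base, compile_and_tune and launch are
// illustrative names, not the real triton::dnn base template.
#include <map>
#include <ostream>
#include <sstream>
#include <string>
#include <vector>

struct stream;       // stand-in for triton::driver::stream
struct buffer;       // stand-in for triton::driver::buffer
struct kernel {};    // stand-in for a compiled, tuned kernel handle

class op_base {
public:
  virtual ~op_base() = default;

  // Call sites reduce to: op.enqueue(stream, {&a, &b, ...});
  void enqueue(stream* s, const std::vector<buffer*>& args) {
    const std::string key = cache_key();
    auto it = kernels_.find(key);
    if (it == kernels_.end()) {
      std::ostringstream oss;
      src(oss);                                                       // derived op emits its source
      it = kernels_.emplace(key, compile_and_tune(oss.str())).first;  // JIT + auto-tune once
    }
    launch(s, it->second, args);                                      // reuse the tuned kernel
  }

protected:
  virtual void src(std::ostream& os) = 0;                        // per-op source generation
  virtual std::string cache_key() const = 0;                     // shape/dtype-dependent key
  virtual kernel compile_and_tune(const std::string& src) = 0;   // compile + pick best params
  virtual void launch(stream* s, const kernel& k,
                      const std::vector<buffer*>& args) = 0;     // bind buffers and run

private:
  std::map<std::string, kernel> kernels_;                        // one tuned kernel per key
};

A concrete op such as batchnorm_forward then overrides src() and cache_key(), and its TensorFlow wrapper reduces to the single enqueue(stream, {...}) call kept in each hunk.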