[dnn]: Now implementing all existing DNN routines using common base template and auto-tuner

This commit is contained in:
Philippe Tillet
2019-07-09 17:30:58 -07:00
parent 88675fa01a
commit b7986baffa
19 changed files with 388 additions and 375 deletions

View File

@@ -50,56 +50,28 @@ public:
int32_t stride_d = 1, stride_h = 1, stride_w = 1;
int32_t pad_d = 0, pad_h = 0, pad_w = 0;
bool has_bias = false;
// get conv configuration
triton::dnn::conv configuration(B, C,
D, H, W,
T, R, S,
NF,
stride_d, stride_h, stride_w,
pad_d, pad_h, pad_w,
1, 1, 1,
"fp16", "fp16",
triton::dnn::conv::FPROP, has_bias);
// Bind memory
// wrap buffers
triton::driver::cu_buffer a(ctx, (CUdeviceptr)tfa.flat<Eigen::half>().data(), false);
triton::driver::cu_buffer b(ctx, (CUdeviceptr)tfb.flat<Eigen::half>().data(), false);
triton::driver::buffer* bias = nullptr;
// template
triton::dnn::conv conv(B, C,
D, H, W,
T, R, S,
NF,
stride_d, stride_h, stride_w,
pad_d, pad_h, pad_w,
1, 1, 1,
"fp16", "fp16",
triton::dnn::conv::FPROP, has_bias);
// allocate output
auto c_shapes = configuration.c_shapes();
auto c_shapes = conv.c_shapes();
Tensor* tfc = nullptr;
TensorShape out_shape({c_shapes[0], c_shapes[1], c_shapes[2], c_shapes[3]});
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &tfc));
triton::driver::cu_buffer c(ctx, (CUdeviceptr)tfc->flat<float>().data(), false);
// benchmark a given convolution kernel
triton::jit jit(ctx);
auto benchmark = [&](triton::driver::kernel* kernel,
triton::jit::launch_information info) {
configuration.init(stream, (triton::driver::cu_module*)kernel->module());
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
unsigned GZ = jit.get_int("GZ");
configuration.enqueue(stream, kernel, &a, &b, &c, bias, TM, TN, GZ, nthreads);
stream->synchronize();
double ts = triton::tools::bench([&](){ configuration.enqueue(stream, kernel, &a, &b, &c, bias, TM, TN, GZ, nthreads); },
[&](){ stream->synchronize(); }, stream->context()->device());
return configuration.get_nflops() / ts * 1e-3;
};
std::ostringstream oss;
configuration.src(oss);
std::string src = oss.str();
triton::jit::tune_res_t best = jit.autotune("conv", src.c_str(), benchmark);
jit.add_module("conv", src.c_str(), best.params);
// jit.add_module("conv", src.c_str(), {16, 2, 32, 32, 2, 64, 2, 2, 2, 2, 8, 2, 16, 4, 1});
triton::driver::kernel* kernel = jit.get_function("conv");
triton::jit::launch_information info = jit.get_launch_info("conv");
std::cout << benchmark(kernel, info) << std::endl;
// enqueue
conv.enqueue(stream, {&a, &b, &c, bias});
}
};