[dnn]: Now implementing all existing DNN routines using common base template and auto-tuner
@@ -50,56 +50,28 @@ public:
int32_t stride_d = 1, stride_h = 1, stride_w = 1;
int32_t pad_d = 0, pad_h = 0, pad_w = 0;
bool has_bias = false;

// get conv configuration
triton::dnn::conv configuration(B, C,
                                D, H, W,
                                T, R, S,
                                NF,
                                stride_d, stride_h, stride_w,
                                pad_d, pad_h, pad_w,
                                1, 1, 1,
                                "fp16", "fp16",
                                triton::dnn::conv::FPROP, has_bias);

// Bind memory
// wrap buffers
triton::driver::cu_buffer a(ctx, (CUdeviceptr)tfa.flat<Eigen::half>().data(), false);
triton::driver::cu_buffer b(ctx, (CUdeviceptr)tfb.flat<Eigen::half>().data(), false);
triton::driver::buffer* bias = nullptr;

// template
triton::dnn::conv conv(B, C,
                       D, H, W,
                       T, R, S,
                       NF,
                       stride_d, stride_h, stride_w,
                       pad_d, pad_h, pad_w,
                       1, 1, 1,
                       "fp16", "fp16",
                       triton::dnn::conv::FPROP, has_bias);
// allocate output
auto c_shapes = configuration.c_shapes();
auto c_shapes = conv.c_shapes();
Tensor* tfc = nullptr;
TensorShape out_shape({c_shapes[0], c_shapes[1], c_shapes[2], c_shapes[3]});
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &tfc));
triton::driver::cu_buffer c(ctx, (CUdeviceptr)tfc->flat<float>().data(), false);

// benchmark a given convolution kernel
triton::jit jit(ctx);
auto benchmark = [&](triton::driver::kernel* kernel,
                     triton::jit::launch_information info) {
  configuration.init(stream, (triton::driver::cu_module*)kernel->module());
  unsigned TM = info.global_range_size[0];
  unsigned TN = info.global_range_size[1];
  unsigned nthreads = info.num_threads;
  unsigned GZ = jit.get_int("GZ");
  configuration.enqueue(stream, kernel, &a, &b, &c, bias, TM, TN, GZ, nthreads);
  stream->synchronize();
  double ts = triton::tools::bench([&](){ configuration.enqueue(stream, kernel, &a, &b, &c, bias, TM, TN, GZ, nthreads); },
                                   [&](){ stream->synchronize(); }, stream->context()->device());
  return configuration.get_nflops() / ts * 1e-3;
};
std::ostringstream oss;
configuration.src(oss);
std::string src = oss.str();

triton::jit::tune_res_t best = jit.autotune("conv", src.c_str(), benchmark);
jit.add_module("conv", src.c_str(), best.params);
// jit.add_module("conv", src.c_str(), {16, 2, 32, 32, 2, 64, 2, 2, 2, 2, 8, 2, 16, 4, 1});
triton::driver::kernel* kernel = jit.get_function("conv");
triton::jit::launch_information info = jit.get_launch_info("conv");
std::cout << benchmark(kernel, info) << std::endl;
// enqueue
conv.enqueue(stream, {&a, &b, &c, bias});
}
};
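For context, the new code path reduces to the pattern below: construct a single triton::dnn::conv descriptor, query c_shapes() to size the output, wrap the device pointers, and call enqueue(), which, per the commit message, now selects a tuned kernel through the shared auto-tuner instead of the removed jit.autotune()/benchmark code. This is only a sketch distilled from the hunk above; the helper name run_conv_fprop, the header paths, and the exact types of ctx/stream are assumptions, not part of this commit.

// Sketch only, not part of the commit. Header paths, the helper name and the
// ctx/stream parameter types are assumptions; the triton::dnn::conv calls
// mirror the new code path shown in the hunk above.
#include <cstdint>
#include <cuda.h>                   // for CUdeviceptr
#include "triton/dnn/conv.h"        // assumed header location
#include "triton/driver/buffer.h"   // assumed header location
#include "triton/driver/stream.h"   // assumed header location

void run_conv_fprop(triton::driver::context* ctx,   // hypothetical helper
                    triton::driver::stream* stream,
                    int B, int C,                    // batch size, input channels
                    int D, int H, int W,             // input depth/height/width
                    int T, int R, int S, int NF,     // filter depth/height/width, #filters
                    CUdeviceptr a, CUdeviceptr b, CUdeviceptr c) {
  int32_t stride_d = 1, stride_h = 1, stride_w = 1;
  int32_t pad_d = 0, pad_h = 0, pad_w = 0;
  bool has_bias = false;

  // One descriptor object replaces the old "configuration" object plus the
  // hand-written benchmark lambda and the explicit jit.autotune(...) call.
  triton::dnn::conv conv(B, C,
                         D, H, W,
                         T, R, S,
                         NF,
                         stride_d, stride_h, stride_w,
                         pad_d, pad_h, pad_w,
                         1, 1, 1,
                         "fp16", "fp16",
                         triton::dnn::conv::FPROP, has_bias);

  // Output shape comes from the descriptor; the TensorFlow op above uses it
  // to allocate its output tensor before launching.
  auto c_shapes = conv.c_shapes();
  (void)c_shapes;

  // Wrap raw device pointers (the `false` flag presumably means the wrapper
  // does not take ownership of the memory; that reading is an assumption).
  triton::driver::cu_buffer buf_a(ctx, a, false);
  triton::driver::cu_buffer buf_b(ctx, b, false);
  triton::driver::cu_buffer buf_c(ctx, c, false);
  triton::driver::buffer* bias = nullptr;

  // enqueue() now compiles, auto-tunes and launches the kernel internally.
  conv.enqueue(stream, {&buf_a, &buf_b, &buf_c, bias});
}

Compared with the removed path, the per-kernel benchmark lambda, the explicit jit.add_module() call, and the TM/TN/GZ/nthreads plumbing all disappear behind the descriptor's enqueue().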