[dnn]: Now implementing all existing DNN routines using common base template and auto-tuner

This commit is contained in:
Philippe Tillet
2019-07-09 17:30:58 -07:00
parent 88675fa01a
commit b7986baffa
19 changed files with 388 additions and 375 deletions

View File

@@ -33,7 +33,6 @@ class DotOp : public OpKernel {
// get inputs
const Tensor& a = context->input(0);
const Tensor& b = context->input(1);
const Tensor& locks = context->input(2);
// get shapes
const int32_t M = a.dim_size(0);
const int32_t N = b.dim_size(0);
@@ -45,40 +44,13 @@ class DotOp : public OpKernel {
// return early if possible
if (out_shape.num_elements() == 0)
return;
// initialize default compute device
triton::jit jit(ctx);
// matrix multiplication parameters
triton::driver::cu_buffer da(ctx, (CUdeviceptr)a.flat<Eigen::half>().data(), false);
triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat<Eigen::half>().data(), false);
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat<float>().data(), false);
triton::driver::cu_buffer dlocks(ctx, (CUdeviceptr)locks.flat<int32_t>().data(), false);
// benchmark a given matrix multiplication kernel
auto benchmark = [&](triton::driver::kernel* kernel,
triton::jit::launch_information info) {
// launch info
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
unsigned GZ = jit.get_int("GZ");
std::array<size_t, 3> grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, GZ};
triton::dnn::gemm::set_arg(kernel, &da, &db, &dc, M, N, K, &dlocks, grid[0], grid[1]);
stream->enqueue(kernel, grid, {nthreads, 1, 1});
stream->synchronize();
double ts = triton::tools::bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});},
[&](){ stream->synchronize(); }, ctx->device());
return 2.*M*N*K / ts * 1e-3;
};
std::string src = triton::dnn::gemm::src(false, true, "fp16", "fp16", 1, 1);
// just-in-time compile source-code
jit.autotune("matmul", src.c_str(), benchmark);
// jit.add_module("matmul", src.c_str(), {4, 2, 8, 4, 2, 32, 1, 4, 1, 1, 8, 8, 8, 1});
// jit.add_module("matmul", src.c_str(), {16, 4, 128, 16, 4, 128, 2, 2, 2, 2, 8, 32, 8, 1});
// jit.add_module("matmul", src.c_str(), {8, 8, 128, 16, 8, 128, 2, 2, 2, 2, 16, 32, 8, 1 });
// jit.add_module("matmul", src.c_str(), {16, 4, 128, 16, 4, 128, 2, 2, 2, 2, 8, 16, 8, 1});
jit.add_module("matmul", src.c_str(), {16, 2, 128, 32, 32, 2, 2, 2, 2, 8, 8, 4, 2, 1}); //NN
triton::driver::kernel* kernel = jit.get_function("matmul");
triton::jit::launch_information info = jit.get_launch_info("matmul");
std::cout << benchmark(kernel, info) << std::endl;
// template
triton::dnn::gemm dot(M, N, K, false, true, "fp16", "fp16", 4, 4);
dot.enqueue(stream, {&da, &db, &dc});
}
private:
@@ -88,6 +60,5 @@ REGISTER_KERNEL_BUILDER(Name("Dot").Device(DEVICE_GPU), DotOp);
REGISTER_OP("Dot")
.Input("a: float16")
.Input("b: float16")
.Input("locks: int32")
.Output("c: float32")
;