[dnn]: Now implementing all existing DNN routines using a common base template and auto-tuner

Philippe Tillet
2019-07-09 17:30:58 -07:00
parent 88675fa01a
commit b7986baffa
19 changed files with 388 additions and 375 deletions
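
Before this commit, every TensorFlow op below repeated the same boilerplate: emit Triton-C source, hand it to triton::jit, query the tuned tile sizes, and launch with explicit parameters. After it, each routine sits on a shared template whose enqueue() owns that whole pipeline. A minimal sketch of the implied interface — the base-class name and these exact signatures are assumptions; only the calls visible in the diffs below are verbatim:

#include <ostream>
#include <vector>

namespace triton {
namespace driver { class stream; class buffer; }  // real types, fwd-declared
namespace dnn {

class base {  // assumed name for the common base template
public:
  virtual ~base() = default;
  // each routine still knows how to emit its own Triton-C source
  virtual void src(std::ostream& os) = 0;
  // one call replacing the old src() -> jit::add_module() -> get_function()
  // -> get_launch_info() -> enqueue(kernel, ..., TM, nthreads) sequence:
  // compiles on first use, auto-tunes, caches, and launches
  void enqueue(driver::stream* stream,
               const std::vector<driver::buffer*>& args);
};

} // namespace dnn
} // namespace triton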

View File

@@ -59,15 +59,7 @@ public:
triton::driver::cu_buffer tv(ctx, (CUdeviceptr)v->flat<float>().data(), false);
// create config
triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "fp32");
- std::ostringstream oss;
- batchnorm.src(oss);
- std::string src = oss.str();
- triton::jit jit(ctx);
- jit.add_module("batchnorm", src.c_str(), jit.get_valid("batchnorm", src.c_str()));
- triton::driver::kernel* kernel = jit.get_function("batchnorm");
- size_t TM = jit.get_int("TM");
- triton::jit::launch_information info = jit.get_launch_info("batchnorm");
- batchnorm.enqueue(stream, kernel, &ty, &tm, &tv, &tx, &tg, &tb, TM, info.num_threads);
+ batchnorm.enqueue(stream, {&ty, &tm, &tv, &tx, &tg, &tb});
}
private:
@@ -135,15 +127,7 @@ public:
triton::driver::cu_buffer tdb(ctx, (CUdeviceptr)db->flat<float>().data(), false);
// create config
triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "fp32");
- std::ostringstream oss;
- batchnorm.src(oss);
- std::string src = oss.str();
- triton::jit jit(ctx);
- jit.add_module("batchnorm", src.c_str(), jit.get_valid("batchnorm", src.c_str()));
- triton::driver::kernel* kernel = jit.get_function("batchnorm");
- size_t TM = jit.get_int("TM");
- triton::jit::launch_information info = jit.get_launch_info("batchnorm");
- batchnorm.enqueue(stream, kernel, &tdx, &tdg, &tdb, &tdy, &tx, &tg, &tm, &tv, TM, info.num_threads);
+ batchnorm.enqueue(stream, {&tdx, &tdg, &tdb, &tdy, &tx, &tg, &tm, &tv});
}
private:
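
Both hunks in this file make the same trade; annotated side by side (a condensed restatement of the diff lines above, not a standalone program — forward pass shown):

// Old pattern, repeated by hand in every op:
std::ostringstream oss;
batchnorm.src(oss);                              // 1. emit Triton-C source
std::string src = oss.str();
triton::jit jit(ctx);                            // 2. compile it
jit.add_module("batchnorm", src.c_str(),
               jit.get_valid("batchnorm", src.c_str()));
triton::driver::kernel* kernel = jit.get_function("batchnorm");
size_t TM = jit.get_int("TM");                   // 3. read the tuned tile size
triton::jit::launch_information info = jit.get_launch_info("batchnorm");
batchnorm.enqueue(stream, kernel, &ty, &tm, &tv, &tx, &tg, &tb,
                  TM, info.num_threads);         // 4. launch with explicit params
// New pattern: steps 1-4 are folded into the base template.
batchnorm.enqueue(stream, {&ty, &tm, &tv, &tx, &tg, &tb});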

View File

@@ -50,56 +50,28 @@ public:
int32_t stride_d = 1, stride_h = 1, stride_w = 1;
int32_t pad_d = 0, pad_h = 0, pad_w = 0;
bool has_bias = false;
- // get conv configuration
- triton::dnn::conv configuration(B, C,
- D, H, W,
- T, R, S,
- NF,
- stride_d, stride_h, stride_w,
- pad_d, pad_h, pad_w,
- 1, 1, 1,
- "fp16", "fp16",
- triton::dnn::conv::FPROP, has_bias);
- // Bind memory
+ // wrap buffers
triton::driver::cu_buffer a(ctx, (CUdeviceptr)tfa.flat<Eigen::half>().data(), false);
triton::driver::cu_buffer b(ctx, (CUdeviceptr)tfb.flat<Eigen::half>().data(), false);
triton::driver::buffer* bias = nullptr;
+ // template
+ triton::dnn::conv conv(B, C,
+ D, H, W,
+ T, R, S,
+ NF,
+ stride_d, stride_h, stride_w,
+ pad_d, pad_h, pad_w,
+ 1, 1, 1,
+ "fp16", "fp16",
+ triton::dnn::conv::FPROP, has_bias);
// allocate output
- auto c_shapes = configuration.c_shapes();
+ auto c_shapes = conv.c_shapes();
Tensor* tfc = nullptr;
TensorShape out_shape({c_shapes[0], c_shapes[1], c_shapes[2], c_shapes[3]});
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &tfc));
triton::driver::cu_buffer c(ctx, (CUdeviceptr)tfc->flat<float>().data(), false);
- // benchmark a given convolution kernel
- triton::jit jit(ctx);
- auto benchmark = [&](triton::driver::kernel* kernel,
- triton::jit::launch_information info) {
- configuration.init(stream, (triton::driver::cu_module*)kernel->module());
- unsigned TM = info.global_range_size[0];
- unsigned TN = info.global_range_size[1];
- unsigned nthreads = info.num_threads;
- unsigned GZ = jit.get_int("GZ");
- configuration.enqueue(stream, kernel, &a, &b, &c, bias, TM, TN, GZ, nthreads);
- stream->synchronize();
- double ts = triton::tools::bench([&](){ configuration.enqueue(stream, kernel, &a, &b, &c, bias, TM, TN, GZ, nthreads); },
- [&](){ stream->synchronize(); }, stream->context()->device());
- return configuration.get_nflops() / ts * 1e-3;
- };
- std::ostringstream oss;
- configuration.src(oss);
- std::string src = oss.str();
- triton::jit::tune_res_t best = jit.autotune("conv", src.c_str(), benchmark);
- jit.add_module("conv", src.c_str(), best.params);
- // jit.add_module("conv", src.c_str(), {16, 2, 32, 32, 2, 64, 2, 2, 2, 2, 8, 2, 16, 4, 1});
- triton::driver::kernel* kernel = jit.get_function("conv");
- triton::jit::launch_information info = jit.get_launch_info("conv");
- std::cout << benchmark(kernel, info) << std::endl;
+ // enqueue
+ conv.enqueue(stream, {&a, &b, &c, bias});
}
};
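
The deleted lambda is what the auto-tuner used as its objective: launch, synchronize, time, and convert elapsed time into throughput via get_nflops() / ts * 1e-3. A self-contained sketch of that arithmetic — the FLOP formula and the assumption that bench() reports nanoseconds are mine, not taken from the library:

#include <cstdint>
#include <cstdio>

// One multiply-accumulate per (output element, reduction index):
// 2 * B * NF * P * Q * C * T * R * S flops, with P/Q the output height/width.
std::int64_t conv_fprop_nflops(std::int64_t B, std::int64_t C, std::int64_t NF,
                               std::int64_t P, std::int64_t Q,
                               std::int64_t T, std::int64_t R, std::int64_t S) {
  return 2 * B * NF * P * Q * C * T * R * S;
}

int main() {
  double ts = 1.5e6;  // assumed measured time, in nanoseconds
  // flops per nanosecond is GFLOP/s; the extra 1e-3 yields TFLOP/s,
  // mirroring `get_nflops() / ts * 1e-3` in the deleted lambda
  double tflops = conv_fprop_nflops(32, 64, 64, 56, 56, 1, 3, 3) / ts * 1e-3;
  std::printf("%.1f TFLOP/s\n", tflops);  // about 4.9 for these sizes
}

The commented-out jit.add_module(...) call with a hard-coded parameter vector looks like a pinned tuner output used to skip re-tuning during development; the new base template makes that caching its own job.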

View File

@@ -33,7 +33,6 @@ class DotOp : public OpKernel {
// get inputs
const Tensor& a = context->input(0);
const Tensor& b = context->input(1);
- const Tensor& locks = context->input(2);
// get shapes
const int32_t M = a.dim_size(0);
const int32_t N = b.dim_size(0);
@@ -45,40 +44,13 @@ class DotOp : public OpKernel {
// return early if possible
if (out_shape.num_elements() == 0)
return;
- // initialize default compute device
- triton::jit jit(ctx);
// matrix multiplication parameters
triton::driver::cu_buffer da(ctx, (CUdeviceptr)a.flat<Eigen::half>().data(), false);
triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat<Eigen::half>().data(), false);
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat<float>().data(), false);
- triton::driver::cu_buffer dlocks(ctx, (CUdeviceptr)locks.flat<int32_t>().data(), false);
- // benchmark a given matrix multiplication kernel
- auto benchmark = [&](triton::driver::kernel* kernel,
- triton::jit::launch_information info) {
- // launch info
- unsigned TM = info.global_range_size[0];
- unsigned TN = info.global_range_size[1];
- unsigned nthreads = info.num_threads;
- unsigned GZ = jit.get_int("GZ");
- std::array<size_t, 3> grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, GZ};
- triton::dnn::gemm::set_arg(kernel, &da, &db, &dc, M, N, K, &dlocks, grid[0], grid[1]);
- stream->enqueue(kernel, grid, {nthreads, 1, 1});
- stream->synchronize();
- double ts = triton::tools::bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});},
- [&](){ stream->synchronize(); }, ctx->device());
- return 2.*M*N*K / ts * 1e-3;
- };
- std::string src = triton::dnn::gemm::src(false, true, "fp16", "fp16", 1, 1);
- // just-in-time compile source-code
- jit.autotune("matmul", src.c_str(), benchmark);
- // jit.add_module("matmul", src.c_str(), {4, 2, 8, 4, 2, 32, 1, 4, 1, 1, 8, 8, 8, 1});
- // jit.add_module("matmul", src.c_str(), {16, 4, 128, 16, 4, 128, 2, 2, 2, 2, 8, 32, 8, 1});
- // jit.add_module("matmul", src.c_str(), {8, 8, 128, 16, 8, 128, 2, 2, 2, 2, 16, 32, 8, 1 });
- // jit.add_module("matmul", src.c_str(), {16, 4, 128, 16, 4, 128, 2, 2, 2, 2, 8, 16, 8, 1});
- jit.add_module("matmul", src.c_str(), {16, 2, 128, 32, 32, 2, 2, 2, 2, 8, 8, 4, 2, 1}); //NN
- triton::driver::kernel* kernel = jit.get_function("matmul");
- triton::jit::launch_information info = jit.get_launch_info("matmul");
- std::cout << benchmark(kernel, info) << std::endl;
+ // template
+ triton::dnn::gemm dot(M, N, K, false, true, "fp16", "fp16", 4, 4);
+ dot.enqueue(stream, {&da, &db, &dc});
}
private:
@@ -88,6 +60,5 @@ REGISTER_KERNEL_BUILDER(Name("Dot").Device(DEVICE_GPU), DotOp);
REGISTER_OP("Dot")
.Input("a: float16")
.Input("b: float16")
.Input("locks: int32")
.Output("c: float32")
;
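
Besides moving matmul onto the shared template, this file drops the locks input from both the kernel and the op registration: the deleted launch code split the K reduction across GZ grid blocks, and the locks buffer presumably let those blocks combine partial sums; after this commit whatever synchronization is needed stays inside triton::dnn::gemm. The grid arithmetic from the deleted benchmark, restated as a self-contained sketch (function names are mine):

#include <array>
#include <cstddef>

constexpr std::size_t ceil_div(std::size_t a, std::size_t b) {
  return (a + b - 1) / b;  // smallest number of size-b tiles covering a
}

// grid dims from the deleted launch code: one block per TM x TN output tile,
// GZ blocks splitting the K reduction (hence the old need for `locks`)
std::array<std::size_t, 3> matmul_grid(std::size_t M, std::size_t N,
                                       std::size_t TM, std::size_t TN,
                                       std::size_t GZ) {
  return {ceil_div(M, TM), ceil_div(N, TN), GZ};
}

static_assert(ceil_div(100, 32) == 4, "100 elements need 4 tiles of 32");

The four commented-out jit.add_module(...) parameter vectors are the same pinned-configuration workflow noted for conv above.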

View File

@@ -127,5 +127,5 @@ def run_batchnorm():
print(np.max(np.abs(dg_t - dg_n)))
print(np.max(np.abs(db_t - db_n)))
- run_shift()
- #run_batchnorm()
+ #run_shift()
+ run_batchnorm()

View File

@@ -19,15 +19,6 @@
using namespace tensorflow;
using GPUDevice = Eigen::GpuDevice;
- typedef std::tuple<int32_t, int32_t, int32_t, int32_t, int32_t,
- int32_t, int32_t, int32_t, int32_t,
- int32_t*, int32_t*,
- triton::dnn::shift::type, bool> shift_key_t;
- static std::map<CUstream, std::unique_ptr<triton::driver::stream>> m_stream;
- static std::map<shift_key_t, std::unique_ptr<triton::jit>> m_jit;
- static std::map<shift_key_t, std::unique_ptr<triton::dnn::shift>> m_config;
template<triton::dnn::shift::type OP>
class ShiftConvOp : public OpKernel {
public:
@@ -102,23 +93,15 @@ public:
// shift offsets
int32_t* shift_h_data = h_shift_h_.flat<int32_t>().data();
int32_t* shift_w_data = h_shift_w_.flat<int32_t>().data();
- std::vector<int32_t> shift_h(shift_h_data, shift_h_data + C);
- std::vector<int32_t> shift_w(shift_w_data, shift_w_data + C);
- shift_key_t key = {B, C, 1, H, W, 1, R_, S_, F, shift_h_data, shift_w_data, OP, has_bias};
// create configuration
- triton::dnn::shift* shift;
- if(m_config.find(key) == m_config.end())
- shift = m_config.emplace(key, new triton::dnn::shift(
- B, C, D, H, W, T, R_, S_, F,
- stride_h_, stride_w_,
- shift_h, shift_w,
- "fp32", "fp32", OP, has_bias))
- .first->second.get();
- else
- shift = m_config.at(key).get();
+ triton::dnn::shift shift(B, C, D, H, W, T, R_, S_, F,
+ stride_h_, stride_w_,
+ shift_h_data, shift_w_data,
+ "fp32", "fp32", OP, has_bias);
// shapes for c
std::vector<int64> c_shapes;
- for(int32_t x: shift->c_shapes())
+ for(int32_t x: shift.c_shapes())
c_shapes.push_back(x);
TensorShape out_shapes(c_shapes);
Tensor* tf_c = nullptr;
@@ -130,7 +113,7 @@ public:
triton::driver::cu_buffer da(ctx, (CUdeviceptr)tf_a.flat<float>().data(), false);
triton::driver::cu_buffer db(ctx, (CUdeviceptr)tf_b.flat<float>().data(), false);
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat<float>().data(), false);
- shift->enqueue(stream, {&da, &db, &dc});
+ shift.enqueue(stream, {&da, &db, &dc});
}
private:
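
The shift op previously kept its own process-wide caches (m_stream, m_jit, m_config) keyed by shift_key_t — and since that key held raw int32_t* offset pointers, hits depended on pointer identity rather than the offsets' contents. After this commit the op simply constructs a triton::dnn::shift per call, and enqueue() is expected to reuse compiled kernels internally. A generic sketch of the deleted caching pattern, with invented names:

#include <map>
#include <memory>

// get_or_create restates the deleted find/emplace/at dance; `make` is any
// callable returning a std::unique_ptr<Value> for a cache miss
template<class Key, class Value, class Make>
Value& get_or_create(std::map<Key, std::unique_ptr<Value>>& cache,
                     const Key& key, Make make) {
  auto it = cache.find(key);
  if (it == cache.end())
    it = cache.emplace(key, make()).first;  // build once on first use
  return *it->second;
}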