[dnn]: Now implementing all existing DNN routines using a common base template and auto-tuner
@@ -59,15 +59,7 @@ public:
 triton::driver::cu_buffer tv(ctx, (CUdeviceptr)v->flat<float>().data(), false);
 // create config
 triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "fp32");
-std::ostringstream oss;
-batchnorm.src(oss);
-std::string src = oss.str();
-triton::jit jit(ctx);
-jit.add_module("batchnorm", src.c_str(), jit.get_valid("batchnorm", src.c_str()));
-triton::driver::kernel* kernel = jit.get_function("batchnorm");
-size_t TM = jit.get_int("TM");
-triton::jit::launch_information info = jit.get_launch_info("batchnorm");
-batchnorm.enqueue(stream, kernel, &ty, &tm, &tv, &tx, &tg, &tb, TM, info.num_threads);
+batchnorm.enqueue(stream, {&ty, &tm, &tv, &tx, &tg, &tb});
 }

 private:
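Note: both batchnorm hunks make the same change: the hand-rolled JIT pipeline (emit Triton-C source, build a triton::jit, add the module, query TM and the launch info, and feed them back into enqueue) is deleted, and the op object launches itself from a list of device buffers. A minimal sketch of the resulting call site, assuming only the constructor and enqueue signatures visible in this diff; the header paths and the wrapper function are illustrative:

    #include <vector>
    // assumed header locations, not verified against the tree
    #include "triton/driver/stream.h"
    #include "triton/dnn/batchnorm.h"

    // Launch a batchnorm forward pass. JIT compilation and autotuning are
    // presumed to happen lazily inside the common base template on first use.
    void batchnorm_fwd(triton::driver::stream* stream,
                       triton::driver::buffer* y, triton::driver::buffer* m,
                       triton::driver::buffer* v, triton::driver::buffer* x,
                       triton::driver::buffer* g, triton::driver::buffer* b,
                       int C, int H, int W, int B) {
      triton::dnn::batchnorm_forward op(C, 1, H, W, B, "fp32");
      op.enqueue(stream, {y, m, v, x, g, b});
    }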
@@ -135,15 +127,7 @@ public:
 triton::driver::cu_buffer tdb(ctx, (CUdeviceptr)db->flat<float>().data(), false);
 // create config
 triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "fp32");
-std::ostringstream oss;
-batchnorm.src(oss);
-std::string src = oss.str();
-triton::jit jit(ctx);
-jit.add_module("batchnorm", src.c_str(), jit.get_valid("batchnorm", src.c_str()));
-triton::driver::kernel* kernel = jit.get_function("batchnorm");
-size_t TM = jit.get_int("TM");
-triton::jit::launch_information info = jit.get_launch_info("batchnorm");
-batchnorm.enqueue(stream, kernel, &tdx, &tdg, &tdb, &tdy, &tx, &tg, &tm, &tv, TM, info.num_threads);
+batchnorm.enqueue(stream, {&tdx, &tdg, &tdb, &tdy, &tx, &tg, &tm, &tv});
 }

 private:
@@ -50,56 +50,28 @@ public:
 int32_t stride_d = 1, stride_h = 1, stride_w = 1;
 int32_t pad_d = 0, pad_h = 0, pad_w = 0;
 bool has_bias = false;

-// get conv configuration
-triton::dnn::conv configuration(B, C,
-                                D, H, W,
-                                T, R, S,
-                                NF,
-                                stride_d, stride_h, stride_w,
-                                pad_d, pad_h, pad_w,
-                                1, 1, 1,
-                                "fp16", "fp16",
-                                triton::dnn::conv::FPROP, has_bias);
-
-// Bind memory
+// wrap buffers
 triton::driver::cu_buffer a(ctx, (CUdeviceptr)tfa.flat<Eigen::half>().data(), false);
 triton::driver::cu_buffer b(ctx, (CUdeviceptr)tfb.flat<Eigen::half>().data(), false);
 triton::driver::buffer* bias = nullptr;

+// template
+triton::dnn::conv conv(B, C,
+                       D, H, W,
+                       T, R, S,
+                       NF,
+                       stride_d, stride_h, stride_w,
+                       pad_d, pad_h, pad_w,
+                       1, 1, 1,
+                       "fp16", "fp16",
+                       triton::dnn::conv::FPROP, has_bias);
 // allocate output
-auto c_shapes = configuration.c_shapes();
+auto c_shapes = conv.c_shapes();
 Tensor* tfc = nullptr;
 TensorShape out_shape({c_shapes[0], c_shapes[1], c_shapes[2], c_shapes[3]});
 OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &tfc));
 triton::driver::cu_buffer c(ctx, (CUdeviceptr)tfc->flat<float>().data(), false);

-// benchmark a given convolution kernel
-triton::jit jit(ctx);
-auto benchmark = [&](triton::driver::kernel* kernel,
-                     triton::jit::launch_information info) {
-  configuration.init(stream, (triton::driver::cu_module*)kernel->module());
-  unsigned TM = info.global_range_size[0];
-  unsigned TN = info.global_range_size[1];
-  unsigned nthreads = info.num_threads;
-  unsigned GZ = jit.get_int("GZ");
-  configuration.enqueue(stream, kernel, &a, &b, &c, bias, TM, TN, GZ, nthreads);
-  stream->synchronize();
-  double ts = triton::tools::bench([&](){ configuration.enqueue(stream, kernel, &a, &b, &c, bias, TM, TN, GZ, nthreads); },
-                                   [&](){ stream->synchronize(); }, stream->context()->device());
-  return configuration.get_nflops() / ts * 1e-3;
-};
-
-std::ostringstream oss;
-configuration.src(oss);
-std::string src = oss.str();
-
-triton::jit::tune_res_t best = jit.autotune("conv", src.c_str(), benchmark);
-jit.add_module("conv", src.c_str(), best.params);
-// jit.add_module("conv", src.c_str(), {16, 2, 32, 32, 2, 64, 2, 2, 2, 2, 8, 2, 16, 4, 1});
-triton::driver::kernel* kernel = jit.get_function("conv");
-triton::jit::launch_information info = jit.get_launch_info("conv");
-std::cout << benchmark(kernel, info) << std::endl;
+// enqueue
+conv.enqueue(stream, {&a, &b, &c, bias});
 }
};

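Note: the conv hunk shows the full scope of what moves into the shared base: source emission, autotuning, kernel caching, and the benchmark printout all disappear from the op kernel, leaving buffer wrapping plus one enqueue. A hypothetical sketch of a base class that could absorb that logic, reusing only calls that appear in the removed code (triton::jit, autotune, add_module, get_function, get_launch_info); this is illustrative, not the actual Triton base class:

    #include <functional>
    #include <memory>
    #include <ostream>
    #include <sstream>
    #include <string>
    #include <vector>

    // Illustrative base: a concrete op supplies its name, Triton-C source,
    // a benchmark callback, and argument binding; the base compiles and
    // autotunes once, then reuses the tuned kernel on every later enqueue.
    class dnn_op_sketch {
    public:
      void enqueue(triton::driver::stream* stream,
                   const std::vector<triton::driver::buffer*>& args) {
        if(!jit_) {
          std::ostringstream oss;
          src(oss);                                    // op-specific source
          std::string code = oss.str();
          jit_.reset(new triton::jit(stream->context()));
          triton::jit::tune_res_t best =
              jit_->autotune(name(), code.c_str(), make_benchmark(stream, args));
          jit_->add_module(name(), code.c_str(), best.params);
        }
        launch(stream, jit_->get_function(name()),
               jit_->get_launch_info(name()), args);   // op-specific binding
      }
    protected:
      using benchmark_t = std::function<double(triton::driver::kernel*,
                                               triton::jit::launch_information)>;
      virtual const char* name() const = 0;
      virtual void src(std::ostream& os) = 0;
      virtual benchmark_t make_benchmark(triton::driver::stream* stream,
                                         const std::vector<triton::driver::buffer*>& args) = 0;
      virtual void launch(triton::driver::stream* stream,
                          triton::driver::kernel* kernel,
                          triton::jit::launch_information info,
                          const std::vector<triton::driver::buffer*>& args) = 0;
    private:
      std::unique_ptr<triton::jit> jit_;
    };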
@@ -33,7 +33,6 @@ class DotOp : public OpKernel {
 // get inputs
 const Tensor& a = context->input(0);
 const Tensor& b = context->input(1);
-const Tensor& locks = context->input(2);
 // get shapes
 const int32_t M = a.dim_size(0);
 const int32_t N = b.dim_size(0);
@@ -45,40 +44,13 @@ class DotOp : public OpKernel {
 // return early if possible
 if (out_shape.num_elements() == 0)
   return;
-// initialize default compute device
-triton::jit jit(ctx);
 // matrix multiplication parameters
 triton::driver::cu_buffer da(ctx, (CUdeviceptr)a.flat<Eigen::half>().data(), false);
 triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat<Eigen::half>().data(), false);
 triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat<float>().data(), false);
-triton::driver::cu_buffer dlocks(ctx, (CUdeviceptr)locks.flat<int32_t>().data(), false);
-// benchmark a given matrix multiplication kernel
-auto benchmark = [&](triton::driver::kernel* kernel,
-                     triton::jit::launch_information info) {
-  // launch info
-  unsigned TM = info.global_range_size[0];
-  unsigned TN = info.global_range_size[1];
-  unsigned nthreads = info.num_threads;
-  unsigned GZ = jit.get_int("GZ");
-  std::array<size_t, 3> grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, GZ};
-  triton::dnn::gemm::set_arg(kernel, &da, &db, &dc, M, N, K, &dlocks, grid[0], grid[1]);
-  stream->enqueue(kernel, grid, {nthreads, 1, 1});
-  stream->synchronize();
-  double ts = triton::tools::bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});},
-                                   [&](){ stream->synchronize(); }, ctx->device());
-  return 2.*M*N*K / ts * 1e-3;
-};
-std::string src = triton::dnn::gemm::src(false, true, "fp16", "fp16", 1, 1);
-// just-in-time compile source-code
-jit.autotune("matmul", src.c_str(), benchmark);
-// jit.add_module("matmul", src.c_str(), {4, 2, 8, 4, 2, 32, 1, 4, 1, 1, 8, 8, 8, 1});
-// jit.add_module("matmul", src.c_str(), {16, 4, 128, 16, 4, 128, 2, 2, 2, 2, 8, 32, 8, 1});
-// jit.add_module("matmul", src.c_str(), {8, 8, 128, 16, 8, 128, 2, 2, 2, 2, 16, 32, 8, 1 });
-// jit.add_module("matmul", src.c_str(), {16, 4, 128, 16, 4, 128, 2, 2, 2, 2, 8, 16, 8, 1});
-jit.add_module("matmul", src.c_str(), {16, 2, 128, 32, 32, 2, 2, 2, 2, 8, 8, 4, 2, 1}); //NN
-triton::driver::kernel* kernel = jit.get_function("matmul");
-triton::jit::launch_information info = jit.get_launch_info("matmul");
-std::cout << benchmark(kernel, info) << std::endl;
+// template
+triton::dnn::gemm dot(M, N, K, false, true, "fp16", "fp16", 4, 4);
+dot.enqueue(stream, {&da, &db, &dc});
 }

 private:
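Note: two details of the removed benchmark are worth keeping in mind: the launch grid is sized by ceiling division so that ragged tiles at the matrix edges still get a block, and the reported figure is the standard 2*M*N*K flop count of a GEMM divided by the measured time. A self-contained sketch of the grid computation, with a worked example:

    #include <array>
    #include <cstddef>

    // One block per TM x TN tile of the output, rounded up so partial tiles
    // at the matrix edges are still covered; GZ splits the reduction axis.
    std::array<std::size_t, 3> make_grid(std::size_t M, std::size_t N,
                                         std::size_t TM, std::size_t TN,
                                         std::size_t GZ) {
      return {(M + TM - 1) / TM, (N + TN - 1) / TN, GZ};
    }

    // Worked example: M = N = 1000, TM = TN = 128, GZ = 1 gives
    // (1000 + 127) / 128 = 8 in each dimension, i.e. a {8, 8, 1} grid.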
@@ -88,6 +60,5 @@ REGISTER_KERNEL_BUILDER(Name("Dot").Device(DEVICE_GPU), DotOp);
 REGISTER_OP("Dot")
   .Input("a: float16")
   .Input("b: float16")
-  .Input("locks: int32")
   .Output("c: float32")
 ;
@@ -127,5 +127,5 @@ def run_batchnorm():
 print(np.max(np.abs(dg_t - dg_n)))
 print(np.max(np.abs(db_t - db_n)))

-run_shift()
-#run_batchnorm()
+#run_shift()
+run_batchnorm()
@@ -19,15 +19,6 @@
 using namespace tensorflow;
 using GPUDevice = Eigen::GpuDevice;

-typedef std::tuple<int32_t, int32_t, int32_t, int32_t, int32_t,
-                   int32_t, int32_t, int32_t, int32_t,
-                   int32_t*, int32_t*,
-                   triton::dnn::shift::type, bool> shift_key_t;
-
-static std::map<CUstream, std::unique_ptr<triton::driver::stream>> m_stream;
-static std::map<shift_key_t, std::unique_ptr<triton::jit>> m_jit;
-static std::map<shift_key_t, std::unique_ptr<triton::dnn::shift>> m_config;
-
 template<triton::dnn::shift::type OP>
 class ShiftConvOp : public OpKernel {
 public:
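Note: the deleted globals memoized one stream wrapper, one JIT, and one op instance per shift configuration, keyed by a tuple of shape parameters. Two of the key fields were raw int32_t* pointers, so lookups compared pointer identity rather than the shift values themselves; with the common base template, per-configuration caching presumably no longer needs to live in the TensorFlow op at all. For reference, a generic sketch of the removed get-or-create pattern (names are illustrative):

    #include <map>
    #include <memory>

    // Construct the value on first lookup, return the cached one afterwards:
    // the same logic the deleted m_config.find / m_config.emplace code inlined.
    template<class K, class V, class F>
    V* get_or_create(std::map<K, std::unique_ptr<V>>& cache, const K& key, F make) {
      auto it = cache.find(key);
      if(it == cache.end())
        it = cache.emplace(key, make()).first;  // make() returns std::unique_ptr<V>
      return it->second.get();
    }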
@@ -102,23 +93,15 @@ public:
 // shift offsets
 int32_t* shift_h_data = h_shift_h_.flat<int32_t>().data();
 int32_t* shift_w_data = h_shift_w_.flat<int32_t>().data();
-std::vector<int32_t> shift_h(shift_h_data, shift_h_data + C);
-std::vector<int32_t> shift_w(shift_w_data, shift_w_data + C);
-shift_key_t key = {B, C, 1, H, W, 1, R_, S_, F, shift_h_data, shift_w_data, OP, has_bias};
 // create configuration
-triton::dnn::shift* shift;
-if(m_config.find(key) == m_config.end())
-  shift = m_config.emplace(key, new triton::dnn::shift(
-                             B, C, D, H, W, T, R_, S_, F,
-                             stride_h_, stride_w_,
-                             shift_h, shift_w,
-                             "fp32", "fp32", OP, has_bias))
-          .first->second.get();
-else
-  shift = m_config.at(key).get();
+triton::dnn::shift shift(B, C, D, H, W, T, R_, S_, F,
+                         stride_h_, stride_w_,
+                         shift_h_data, shift_w_data,
+                         "fp32", "fp32", OP, has_bias);

 // shapes for c
 std::vector<int64> c_shapes;
-for(int32_t x: shift->c_shapes())
+for(int32_t x: shift.c_shapes())
   c_shapes.push_back(x);
 TensorShape out_shapes(c_shapes);
 Tensor* tf_c = nullptr;
@@ -130,7 +113,7 @@ public:
 triton::driver::cu_buffer da(ctx, (CUdeviceptr)tf_a.flat<float>().data(), false);
 triton::driver::cu_buffer db(ctx, (CUdeviceptr)tf_b.flat<float>().data(), false);
 triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat<float>().data(), false);
-shift->enqueue(stream, {&da, &db, &dc});
+shift.enqueue(stream, {&da, &db, &dc});
 }

 private: