[dnn] Adding batchnorm

This commit is contained in:
Philippe Tillet
2019-07-08 18:44:37 -07:00
parent b0cf3143c5
commit f9db0449b7
42 changed files with 682 additions and 1763 deletions

View File

@@ -5,7 +5,7 @@ if(${TensorFlow_FOUND})
include_directories("${CUDA_HOME}/include")
link_directories(${TF_LIB})
add_definitions(-D_GLIBCXX_USE_CXX11_ABI=${TF_ABI})
add_library(tf_blocksparse SHARED dot.cpp conv2d.cpp shift.cpp)
add_library(tf_blocksparse SHARED dot.cpp conv.cpp shift.cpp batchnorm.cpp)
target_link_libraries(tf_blocksparse tensorflow_framework triton)
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/run.py
${CMAKE_CURRENT_BINARY_DIR}/run.py

View File

@@ -0,0 +1,174 @@
#include <iostream>
#include "triton/driver/buffer.h"
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "triton/runtime/jit.h"
#include "triton/tools/bench.hpp"
#include "triton/dnn/batchnorm.h"
#define EIGEN_USE_GPU
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/framework/common_shape_fns.h"
using namespace tensorflow;
using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
using GPUDevice = Eigen::GpuDevice;
class BatchnormForwardOp : public OpKernel {
public:
explicit BatchnormForwardOp(OpKernelConstruction* context): OpKernel(context) {
context->GetAttr("eps", &eps_);
}
void Compute(OpKernelContext* context){
// get device/stream
GPUDevice device = context->eigen_device<GPUDevice>();
triton::driver::cu_stream sstream(device.stream(), false);
triton::driver::context* ctx = sstream.context();
triton::driver::stream* stream = &sstream;
// get inputs
const Tensor& x = context->input(0);
const Tensor& g = context->input(1);
const Tensor& b = context->input(2);
// get sizes
int C = x.dim_size(0);
int H = x.dim_size(1);
int W = x.dim_size(2);
int B = x.dim_size(3);
// allocate outputs
Tensor* y = nullptr;
Tensor* m = nullptr;
Tensor* v = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, x.shape(), &y));
OP_REQUIRES_OK(context, context->allocate_output(1, g.shape(), &m));
OP_REQUIRES_OK(context, context->allocate_output(2, g.shape(), &v));
// triton handles
triton::driver::cu_buffer tx(ctx, (CUdeviceptr)x.flat<float>().data(), false);
triton::driver::cu_buffer tg(ctx, (CUdeviceptr)g.flat<float>().data(), false);
triton::driver::cu_buffer tb(ctx, (CUdeviceptr)b.flat<float>().data(), false);
triton::driver::cu_buffer ty(ctx, (CUdeviceptr)y->flat<float>().data(), false);
triton::driver::cu_buffer tm(ctx, (CUdeviceptr)m->flat<float>().data(), false);
triton::driver::cu_buffer tv(ctx, (CUdeviceptr)v->flat<float>().data(), false);
// create config
triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "fp32");
std::ostringstream oss;
batchnorm.src(oss);
std::string src = oss.str();
triton::jit jit(ctx);
jit.add_module("batchnorm", src.c_str(), jit.get_valid("batchnorm", src.c_str()));
triton::driver::kernel* kernel = jit.get_function("batchnorm");
size_t TM = jit.get_int("TM");
triton::jit::launch_information info = jit.get_launch_info("batchnorm");
batchnorm.enqueue(stream, kernel, &ty, &tm, &tv, &tx, &tg, &tb, TM, info.num_threads);
}
private:
float eps_;
};
REGISTER_KERNEL_BUILDER(Name("BatchnormForward").Device(DEVICE_GPU), BatchnormForwardOp);
REGISTER_OP("BatchnormForward")
.Input("x: T")
.Input("g: float")
.Input("b: float")
.Output("y: T")
.Output("m: float")
.Output("v: float")
.Attr("T: {float}")
.Attr("eps: float")
.SetShapeFn([](InferenceContext* ctx) {
ctx->set_output(0, ctx->input(0));
ctx->set_output(1, ctx->input(1));
ctx->set_output(2, ctx->input(1));
return Status::OK();
})
;
class BatchnormBackwardOp : public OpKernel {
public:
explicit BatchnormBackwardOp(OpKernelConstruction* context): OpKernel(context) {
context->GetAttr("eps", &eps_);
}
void Compute(OpKernelContext* context){
// get device/stream
GPUDevice device = context->eigen_device<GPUDevice>();
triton::driver::cu_stream sstream(device.stream(), false);
triton::driver::context* ctx = sstream.context();
triton::driver::stream* stream = &sstream;
// get inputs
const Tensor& dy = context->input(0);
const Tensor& x = context->input(1);
const Tensor& g = context->input(2);
const Tensor& m = context->input(3);
const Tensor& v = context->input(4);
// get sizes
int C = x.dim_size(0);
int H = x.dim_size(1);
int W = x.dim_size(2);
int B = x.dim_size(3);
// allocate outputs
Tensor* dx = nullptr;
Tensor* dg = nullptr;
Tensor* db = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, x.shape(), &dx));
OP_REQUIRES_OK(context, context->allocate_output(1, g.shape(), &dg));
OP_REQUIRES_OK(context, context->allocate_output(2, g.shape(), &db));
// triton handles
triton::driver::cu_buffer tdy(ctx, (CUdeviceptr)dy.flat<float>().data(), false);
triton::driver::cu_buffer tx(ctx, (CUdeviceptr)x.flat<float>().data(), false);
triton::driver::cu_buffer tg(ctx, (CUdeviceptr)g.flat<float>().data(), false);
triton::driver::cu_buffer tm(ctx, (CUdeviceptr)m.flat<float>().data(), false);
triton::driver::cu_buffer tv(ctx, (CUdeviceptr)v.flat<float>().data(), false);
triton::driver::cu_buffer tdx(ctx, (CUdeviceptr)dx->flat<float>().data(), false);
triton::driver::cu_buffer tdg(ctx, (CUdeviceptr)dg->flat<float>().data(), false);
triton::driver::cu_buffer tdb(ctx, (CUdeviceptr)db->flat<float>().data(), false);
// create config
triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "fp32");
std::ostringstream oss;
batchnorm.src(oss);
std::string src = oss.str();
triton::jit jit(ctx);
jit.add_module("batchnorm", src.c_str(), jit.get_valid("batchnorm", src.c_str()));
triton::driver::kernel* kernel = jit.get_function("batchnorm");
size_t TM = jit.get_int("TM");
triton::jit::launch_information info = jit.get_launch_info("batchnorm");
batchnorm.enqueue(stream, kernel, &tdx, &tdg, &tdb, &tdy, &tx, &tg, &tm, &tv, TM, info.num_threads);
}
private:
float eps_;
};
REGISTER_KERNEL_BUILDER(Name("BatchnormBackward").Device(DEVICE_GPU), BatchnormBackwardOp);
REGISTER_OP("BatchnormBackward")
.Input("dy: TY")
.Input("x: TX")
.Input("g: float")
.Input("m: float")
.Input("v: float")
.Output("dx: TY")
.Output("dg: float")
.Output("db: float")
.Attr("TX: {float}")
.Attr("TY: {float}")
.Attr("eps: float")
.SetShapeFn([](InferenceContext* ctx) {
ctx->set_output(0, ctx->input(1));
ctx->set_output(1, ctx->input(2));
ctx->set_output(2, ctx->input(2));
return Status::OK();
})
;

View File

@@ -65,8 +65,6 @@ public:
// Bind memory
triton::driver::cu_buffer a(ctx, (CUdeviceptr)tfa.flat<Eigen::half>().data(), false);
triton::driver::cu_buffer b(ctx, (CUdeviceptr)tfb.flat<Eigen::half>().data(), false);
// triton::driver::cu_buffer cubias(ctx, (CUdeviceptr)torchbias.storage().data(), false);
// triton::driver::buffer* bias = has_bias ? &cubias : nullptr;
triton::driver::buffer* bias = nullptr;
// allocate output

View File

@@ -56,8 +56,8 @@ def blocksparse_matmul_grad(op, dy):
return (dx, dw)
def run_shift():
B, C, H, W = 1, 32, 8, 6
R, S, F = 3, 3, 16
B, C, H, W = 16, 1024, 8, 8
R, S, F = 3, 3, 1024
np.random.seed(2)
a = tf.placeholder(tf.float32, shape=[C, H, W, B])
b = tf.placeholder(tf.float32, shape=[C, F])
@@ -65,8 +65,6 @@ def run_shift():
hshift_w = np.random.randint(- (S//2), R//2 + 1, size=C, dtype=np.int32)
#hshift_h = np.ones(C, dtype=np.int32)
#hshift_w = np.ones(C, dtype=np.int32)
print(hshift_h)
print(hshift_w)
c = module.shift_conv(a, b, shift_h=tf.make_tensor_proto(hshift_h), shift_w=tf.make_tensor_proto(hshift_w))
# Reference
ha = np.random.rand(C, H, W, B)
@@ -74,16 +72,36 @@ def run_shift():
#ha = np.ones((C, H, W, B), dtype=np.int32)
#hb = np.ones((C, F), dtype=np.int32)
sess = tf.InteractiveSession()
grads = tf.test.compute_gradient([a, b], [(C, H, W, B), (C, F)], c, (F, H, W, B),
extra_feed_dict={a: ha, b: hb})
dw_t, dw_n = grads[1]
dx_t, dx_n = grads[0]
print(np.max(np.abs(dw_t - dw_n)))
print(np.max(np.abs(dx_t - dx_n)))
#grads = tf.test.compute_gradient([a, b], [(C, H, W, B), (C, F)], c, (F, H, W, B),
# extra_feed_dict = {a: ha, b: hb})
#dw_t, dw_n = grads[1]
#dx_t, dx_n = grads[0]
#print(np.max(np.abs(dw_t - dw_n)))
#print(np.max(np.abs(dx_t - dx_n)))
# Run
sess.run(tf.global_variables_initializer())
result = sess.run([c], feed_dict = {a: ha,
b: hb})[0]
#print(result)
run_shift()
def run_batchnorm():
C, H, W, B = 32, 16, 16, 16
np.random.seed(0)
# Placeholders
x = tf.placeholder(tf.float32, shape=[C, H, W, B])
g = tf.placeholder(tf.float32, shape=[C])
b = tf.placeholder(tf.float32, shape=[C])
# Feed values
hx = np.random.rand(C, H, W, B)
hg = np.random.rand(C)
hb = np.random.rand(C)
# batchnorm
y, m, v = module.batchnorm_forward(x, g, b, eps=1e-5)
# Run
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
result = sess.run([y, m, v], feed_dict = {x: hx, g: hg, b: hb})
print(hx.sum(axis=(1,2,3)))
print(result[1])
run_batchnorm()

View File

@@ -125,7 +125,7 @@ public:
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat<float>().data(), false);
// get JIT
triton::jit* jit;
bool autotune = false;
bool autotune = true;
if(m_jit.find(key) == m_jit.end()) {
jit = m_jit.emplace(key, new triton::jit(ctx)).first->second.get();
std::ostringstream oss;