[dnn] Adding batchnorm
@@ -5,7 +5,7 @@ if(${TensorFlow_FOUND})
     include_directories("${CUDA_HOME}/include")
     link_directories(${TF_LIB})
     add_definitions(-D_GLIBCXX_USE_CXX11_ABI=${TF_ABI})
-    add_library(tf_blocksparse SHARED dot.cpp conv2d.cpp shift.cpp)
+    add_library(tf_blocksparse SHARED dot.cpp conv.cpp shift.cpp batchnorm.cpp)
     target_link_libraries(tf_blocksparse tensorflow_framework triton)
     configure_file(${CMAKE_CURRENT_SOURCE_DIR}/run.py
                    ${CMAKE_CURRENT_BINARY_DIR}/run.py
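The tf_blocksparse target built here is the op library that run.py imports as `module`. A minimal sketch of that import, assuming the default shared-library name produced by add_library() on Linux and that the script runs from the build directory (path and name are assumptions, not taken from this commit):

    import tensorflow as tf
    # Hypothetical path; the actual name/location depends on the platform and CMake generator.
    module = tf.load_op_library('./libtf_blocksparse.so')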
174  examples/python/tensorflow/batchnorm.cpp  Normal file
@@ -0,0 +1,174 @@
#include <iostream>

#include "triton/driver/buffer.h"
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "triton/runtime/jit.h"
#include "triton/tools/bench.hpp"
#include "triton/dnn/batchnorm.h"

#define EIGEN_USE_GPU
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/framework/common_shape_fns.h"

using namespace tensorflow;
using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
using GPUDevice = Eigen::GpuDevice;

class BatchnormForwardOp : public OpKernel {
public:
  explicit BatchnormForwardOp(OpKernelConstruction* context): OpKernel(context) {
    context->GetAttr("eps", &eps_);
  }

  void Compute(OpKernelContext* context){
    // get device/stream
    GPUDevice device = context->eigen_device<GPUDevice>();
    triton::driver::cu_stream sstream(device.stream(), false);
    triton::driver::context* ctx = sstream.context();
    triton::driver::stream* stream = &sstream;
    // get inputs
    const Tensor& x = context->input(0);
    const Tensor& g = context->input(1);
    const Tensor& b = context->input(2);
    // get sizes
    int C = x.dim_size(0);
    int H = x.dim_size(1);
    int W = x.dim_size(2);
    int B = x.dim_size(3);
    // allocate outputs
    Tensor* y = nullptr;
    Tensor* m = nullptr;
    Tensor* v = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, x.shape(), &y));
    OP_REQUIRES_OK(context, context->allocate_output(1, g.shape(), &m));
    OP_REQUIRES_OK(context, context->allocate_output(2, g.shape(), &v));
    // triton handles
    triton::driver::cu_buffer tx(ctx, (CUdeviceptr)x.flat<float>().data(), false);
    triton::driver::cu_buffer tg(ctx, (CUdeviceptr)g.flat<float>().data(), false);
    triton::driver::cu_buffer tb(ctx, (CUdeviceptr)b.flat<float>().data(), false);
    triton::driver::cu_buffer ty(ctx, (CUdeviceptr)y->flat<float>().data(), false);
    triton::driver::cu_buffer tm(ctx, (CUdeviceptr)m->flat<float>().data(), false);
    triton::driver::cu_buffer tv(ctx, (CUdeviceptr)v->flat<float>().data(), false);
    // create config
    triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "fp32");
    std::ostringstream oss;
    batchnorm.src(oss);
    std::string src = oss.str();
    triton::jit jit(ctx);
    jit.add_module("batchnorm", src.c_str(), jit.get_valid("batchnorm", src.c_str()));
    triton::driver::kernel* kernel = jit.get_function("batchnorm");
    size_t TM = jit.get_int("TM");
    triton::jit::launch_information info = jit.get_launch_info("batchnorm");
    batchnorm.enqueue(stream, kernel, &ty, &tm, &tv, &tx, &tg, &tb, TM, info.num_threads);
  }

private:
  float eps_;
};


REGISTER_KERNEL_BUILDER(Name("BatchnormForward").Device(DEVICE_GPU), BatchnormForwardOp);
REGISTER_OP("BatchnormForward")
    .Input("x: T")
    .Input("g: float")
    .Input("b: float")
    .Output("y: T")
    .Output("m: float")
    .Output("v: float")
    .Attr("T: {float}")
    .Attr("eps: float")
    .SetShapeFn([](InferenceContext* ctx) {
      ctx->set_output(0, ctx->input(0));
      ctx->set_output(1, ctx->input(1));
      ctx->set_output(2, ctx->input(1));
      return Status::OK();
    })
;


class BatchnormBackwardOp : public OpKernel {
public:
  explicit BatchnormBackwardOp(OpKernelConstruction* context): OpKernel(context) {
    context->GetAttr("eps", &eps_);
  }

  void Compute(OpKernelContext* context){
    // get device/stream
    GPUDevice device = context->eigen_device<GPUDevice>();
    triton::driver::cu_stream sstream(device.stream(), false);
    triton::driver::context* ctx = sstream.context();
    triton::driver::stream* stream = &sstream;
    // get inputs
    const Tensor& dy = context->input(0);
    const Tensor& x = context->input(1);
    const Tensor& g = context->input(2);
    const Tensor& m = context->input(3);
    const Tensor& v = context->input(4);
    // get sizes
    int C = x.dim_size(0);
    int H = x.dim_size(1);
    int W = x.dim_size(2);
    int B = x.dim_size(3);
    // allocate outputs
    Tensor* dx = nullptr;
    Tensor* dg = nullptr;
    Tensor* db = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, x.shape(), &dx));
    OP_REQUIRES_OK(context, context->allocate_output(1, g.shape(), &dg));
    OP_REQUIRES_OK(context, context->allocate_output(2, g.shape(), &db));
    // triton handles
    triton::driver::cu_buffer tdy(ctx, (CUdeviceptr)dy.flat<float>().data(), false);
    triton::driver::cu_buffer tx(ctx, (CUdeviceptr)x.flat<float>().data(), false);
    triton::driver::cu_buffer tg(ctx, (CUdeviceptr)g.flat<float>().data(), false);
    triton::driver::cu_buffer tm(ctx, (CUdeviceptr)m.flat<float>().data(), false);
    triton::driver::cu_buffer tv(ctx, (CUdeviceptr)v.flat<float>().data(), false);
    triton::driver::cu_buffer tdx(ctx, (CUdeviceptr)dx->flat<float>().data(), false);
    triton::driver::cu_buffer tdg(ctx, (CUdeviceptr)dg->flat<float>().data(), false);
    triton::driver::cu_buffer tdb(ctx, (CUdeviceptr)db->flat<float>().data(), false);

    // create config
    triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "fp32");
    std::ostringstream oss;
    batchnorm.src(oss);
    std::string src = oss.str();
    triton::jit jit(ctx);
    jit.add_module("batchnorm", src.c_str(), jit.get_valid("batchnorm", src.c_str()));
    triton::driver::kernel* kernel = jit.get_function("batchnorm");
    size_t TM = jit.get_int("TM");
    triton::jit::launch_information info = jit.get_launch_info("batchnorm");
    batchnorm.enqueue(stream, kernel, &tdx, &tdg, &tdb, &tdy, &tx, &tg, &tm, &tv, TM, info.num_threads);
  }

private:
  float eps_;
};


REGISTER_KERNEL_BUILDER(Name("BatchnormBackward").Device(DEVICE_GPU), BatchnormBackwardOp);
REGISTER_OP("BatchnormBackward")
    .Input("dy: TY")
    .Input("x: TX")
    .Input("g: float")
    .Input("m: float")
    .Input("v: float")
    .Output("dx: TY")
    .Output("dg: float")
    .Output("db: float")
    .Attr("TX: {float}")
    .Attr("TY: {float}")
    .Attr("eps: float")
    .SetShapeFn([](InferenceContext* ctx) {
      ctx->set_output(0, ctx->input(1));
      ctx->set_output(1, ctx->input(2));
      ctx->set_output(2, ctx->input(2));
      return Status::OK();
    })
;
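For reference, the forward op above returns the normalized tensor together with the per-channel mean and variance it computed. A minimal NumPy sketch of the standard batch-normalization forward pass in the CHWB layout used by this kernel; this is the textbook definition offered as a sanity check, not the code the Triton kernel executes, and the use of the biased variance is an assumption:

    import numpy as np

    def batchnorm_forward_ref(x, g, b, eps=1e-5):
        # x: (C, H, W, B); statistics are reduced over the H, W, B axes per channel
        m = x.mean(axis=(1, 2, 3))
        v = x.var(axis=(1, 2, 3))
        x_hat = (x - m[:, None, None, None]) / np.sqrt(v[:, None, None, None] + eps)
        y = g[:, None, None, None] * x_hat + b[:, None, None, None]
        return y, m, v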
@@ -65,8 +65,6 @@ public:
     // Bind memory
     triton::driver::cu_buffer a(ctx, (CUdeviceptr)tfa.flat<Eigen::half>().data(), false);
     triton::driver::cu_buffer b(ctx, (CUdeviceptr)tfb.flat<Eigen::half>().data(), false);
-    // triton::driver::cu_buffer cubias(ctx, (CUdeviceptr)torchbias.storage().data(), false);
-    // triton::driver::buffer* bias = has_bias ? &cubias : nullptr;
     triton::driver::buffer* bias = nullptr;

     // allocate output
@@ -56,8 +56,8 @@ def blocksparse_matmul_grad(op, dy):
    return (dx, dw)

def run_shift():
-   B, C, H, W = 1, 32, 8, 6
-   R, S, F = 3, 3, 16
+   B, C, H, W = 16, 1024, 8, 8
+   R, S, F = 3, 3, 1024
    np.random.seed(2)
    a = tf.placeholder(tf.float32, shape=[C, H, W, B])
    b = tf.placeholder(tf.float32, shape=[C, F])
@@ -65,8 +65,6 @@ def run_shift():
    hshift_w = np.random.randint(- (S//2), R//2 + 1, size=C, dtype=np.int32)
    #hshift_h = np.ones(C, dtype=np.int32)
    #hshift_w = np.ones(C, dtype=np.int32)
-   print(hshift_h)
-   print(hshift_w)
    c = module.shift_conv(a, b, shift_h=tf.make_tensor_proto(hshift_h), shift_w=tf.make_tensor_proto(hshift_w))
    # Reference
    ha = np.random.rand(C, H, W, B)
@@ -74,16 +72,36 @@ def run_shift():
    #ha = np.ones((C, H, W, B), dtype=np.int32)
    #hb = np.ones((C, F), dtype=np.int32)
    sess = tf.InteractiveSession()
-   grads = tf.test.compute_gradient([a, b], [(C, H, W, B), (C, F)], c, (F, H, W, B),
-                                    extra_feed_dict={a: ha, b: hb})
-   dw_t, dw_n = grads[1]
-   dx_t, dx_n = grads[0]
-   print(np.max(np.abs(dw_t - dw_n)))
-   print(np.max(np.abs(dx_t - dx_n)))
+   #grads = tf.test.compute_gradient([a, b], [(C, H, W, B), (C, F)], c, (F, H, W, B),
+   #                                 extra_feed_dict = {a: ha, b: hb})
+   #dw_t, dw_n = grads[1]
+   #dx_t, dx_n = grads[0]
+   #print(np.max(np.abs(dw_t - dw_n)))
+   #print(np.max(np.abs(dx_t - dx_n)))
    # Run
    sess.run(tf.global_variables_initializer())
    result = sess.run([c], feed_dict = {a: ha,
                                        b: hb})[0]
    #print(result)

run_shift()
+
+def run_batchnorm():
+   C, H, W, B = 32, 16, 16, 16
+   np.random.seed(0)
+   # Placeholders
+   x = tf.placeholder(tf.float32, shape=[C, H, W, B])
+   g = tf.placeholder(tf.float32, shape=[C])
+   b = tf.placeholder(tf.float32, shape=[C])
+   # Feed values
+   hx = np.random.rand(C, H, W, B)
+   hg = np.random.rand(C)
+   hb = np.random.rand(C)
+   # batchnorm
+   y, m, v = module.batchnorm_forward(x, g, b, eps=1e-5)
+   # Run
+   sess = tf.InteractiveSession()
+   sess.run(tf.global_variables_initializer())
+   result = sess.run([y, m, v], feed_dict = {x: hx, g: hg, b: hb})
+   print(hx.sum(axis=(1,2,3)))
+   print(result[1])
+
+run_batchnorm()
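run.py only exercises the forward op; the BatchnormBackward op registered above could be wired into TensorFlow's autodiff along the same lines as blocksparse_matmul_grad. The registration below is an illustrative sketch, not part of this commit, and assumes `module` is the loaded tf_blocksparse library:

    from tensorflow.python.framework import ops

    @ops.RegisterGradient("BatchnormForward")
    def batchnorm_forward_grad(op, dy, dm, dv):
        # op.inputs = (x, g, b); op.outputs = (y, m, v)
        eps = op.get_attr("eps")
        dx, dg, db = module.batchnorm_backward(dy, op.inputs[0], op.inputs[1],
                                               op.outputs[1], op.outputs[2], eps=eps)
        # gradients with respect to x, g and b
        return dx, dg, db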
@@ -125,7 +125,7 @@ public:
     triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat<float>().data(), false);
     // get JIT
     triton::jit* jit;
-    bool autotune = false;
+    bool autotune = true;
     if(m_jit.find(key) == m_jit.end()) {
       jit = m_jit.emplace(key, new triton::jit(ctx)).first->second.get();
       std::ostringstream oss;