[dnn/batchnorm]: added some more code in Triton-C batchnorm implementations

Philippe Tillet
2019-07-08 20:18:20 -07:00
parent fa3270dcf2
commit f74dcb7e30
13 changed files with 103 additions and 11 deletions


@@ -84,6 +84,21 @@ def run_shift():
                                    b: hb})[0]
    #print(result)
def batch_norm(x, g, b, epsilon=1e-6):
    # Forward pass: delegates to the Triton batchnorm kernel.
    shape = x.shape
    C = int(shape[1])
    assert g.get_shape().num_elements() == C
    assert b.get_shape().num_elements() == C
    return module.batchnorm_forward(x, g, b, eps=epsilon)

@ops.RegisterGradient("BatchnormForward")
def batch_norm_grad(op, dy, mean, var):
    # Backward pass: reuses the mean/variance saved as forward outputs.
    eps = op.get_attr("eps")
    return module.batchnorm_backward(dy, op.inputs[0], op.inputs[1],
                                     op.outputs[1], op.outputs[2], eps=eps)

def run_batchnorm():
    C, H, W, B = 32, 16, 16, 16
    np.random.seed(0)
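For context, `batchnorm_forward` returns the normalized output together with the per-channel mean and variance that the registered gradient consumes. A minimal NumPy sketch of the presumed semantics, assuming the (C, H, W, B) layout used by the test below (`batchnorm_reference` is a hypothetical name, not part of this commit):

    import numpy as np

    def batchnorm_reference(x, g, b, eps=1e-6):
        # x: (C, H, W, B); g, b: (C,). Statistics are reduced per channel
        # over the spatial and batch axes, matching hx.sum(axis=(1,2,3)) below.
        mean = x.mean(axis=(1, 2, 3))
        var = x.var(axis=(1, 2, 3))
        xhat = (x - mean[:, None, None, None]) / np.sqrt(var[:, None, None, None] + eps)
        y = g[:, None, None, None] * xhat + b[:, None, None, None]
        return y, mean, var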
@@ -101,7 +116,7 @@ def run_batchnorm():
    sess = tf.InteractiveSession()
    sess.run(tf.global_variables_initializer())
    result = sess.run([y, m, v], feed_dict = {x: hx, g: hg, b: hb})
    print(hx.sum(axis=(1,2,3)))
    print(result[1])

run_batchnorm()
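The two prints compare the raw per-channel input sums against the mean returned by the kernel; dividing the sums by H*W*B (= 16*16*16) should reproduce result[1]. A hedged sanity check against the sketch above, reusing the test's hx, hg, hb, and result (so it only runs in that context, and assumes hx has shape (C, H, W, B)):

    # Hypothetical check: kernel mean vs. NumPy reference mean.
    ref_y, ref_mean, ref_var = batchnorm_reference(hx, hg, hb)
    np.testing.assert_allclose(result[1], ref_mean, rtol=1e-4)
    np.testing.assert_allclose(result[1], hx.sum(axis=(1, 2, 3)) / (16 * 16 * 16), rtol=1e-4)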