[dnn/batchnorm]: added some more code in Triton-C batchnorm implementations
This commit is contained in:
@@ -84,6 +84,21 @@ def run_shift():
|
||||
b: hb})[0]
|
||||
#print(result)
|
||||
|
||||
|
||||
def batch_norm(x, g, b, epsilon=1e-6):
|
||||
shape = x.shape
|
||||
C = int(shape[1])
|
||||
assert g.get_shape().num_elements() == C
|
||||
assert b.get_shape().num_elements() == C
|
||||
return module.batchnorm_forward(x, g, b, eps=epsilon)
|
||||
|
||||
@ops.RegisterGradient("BatchnormForward")
|
||||
def batch_norm_grad(op, dy, mean, var):
|
||||
eps = op.get_attr("eps")
|
||||
return module.batchnorm_backward(dy, op.inputs[0], op.inputs[1],
|
||||
op.outputs[1], op.outputs[2], eps=eps)
|
||||
|
||||
|
||||
def run_batchnorm():
|
||||
C, H, W, B = 32, 16, 16, 16
|
||||
np.random.seed(0)
|
||||
@@ -101,7 +116,7 @@ def run_batchnorm():
|
||||
sess = tf.InteractiveSession()
|
||||
sess.run(tf.global_variables_initializer())
|
||||
result = sess.run([y, m, v], feed_dict = {x: hx, g: hg, b: hb})
|
||||
print(hx.sum(axis=(1,2,3)))
|
||||
print(result[1])
|
||||
|
||||
|
||||
|
||||
run_batchnorm()
|
||||
|
Reference in New Issue
Block a user