testing a simple shiftnet
This commit is contained in:
@@ -100,7 +100,7 @@ def batch_norm_grad(op, dy, mean, var):
|
||||
|
||||
|
||||
def run_batchnorm():
|
||||
C, H, W, B = 1, 4, 4, 4
|
||||
C, H, W, B = 32, 14, 14, 64
|
||||
np.random.seed(0)
|
||||
# Placeholders
|
||||
x = tf.placeholder(tf.float32, shape=[C, H, W, B])
|
||||
@@ -117,7 +117,8 @@ def run_batchnorm():
|
||||
sess = tf.InteractiveSession()
|
||||
sess.run(tf.global_variables_initializer())
|
||||
result = sess.run([y, m, v], feed_dict = {x: hx, g: hg, b: hb})
|
||||
#print(result[0], result[1], result[2])
|
||||
print(result[1])
|
||||
print(np.mean(hx, (1, 2, 3)))
|
||||
grads = tf.test.compute_gradient([x, g, b], [(C, H, W, B), (C, ), (C, )], y, (C, H, W, B),
|
||||
extra_feed_dict = {x: hx, g: hg, b: hb})
|
||||
dx_t, dx_n = grads[0]
|
||||
|
Reference in New Issue
Block a user