[dnn/shift] bugfix in wgrad
This commit is contained in:
@@ -56,7 +56,7 @@ def blocksparse_matmul_grad(op, dy):
|
||||
return (dx, dw)
|
||||
|
||||
def run_shift():
|
||||
B, C, H, W = 1, 16, 8, 8
|
||||
B, C, H, W = 1, 32, 8, 6
|
||||
R, S, F = 3, 3, 16
|
||||
np.random.seed(2)
|
||||
a = tf.placeholder(tf.float32, shape=[C, H, W, B])
|
||||
@@ -76,8 +76,8 @@ def run_shift():
|
||||
sess = tf.InteractiveSession()
|
||||
grads = tf.test.compute_gradient([a, b], [(C, H, W, B), (C, F)], c, (F, H, W, B),
|
||||
extra_feed_dict={a: ha, b: hb})
|
||||
dx_t, dx_n = grads[0]
|
||||
dw_t, dw_n = grads[1]
|
||||
dx_t, dx_n = grads[0]
|
||||
print(np.max(np.abs(dw_t - dw_n)))
|
||||
print(np.max(np.abs(dx_t - dx_n)))
|
||||
# Run
|
||||
|
Reference in New Issue
Block a user