[dnn/shift] bugfix in wgrad

Philippe Tillet
2019-07-06 11:27:49 -07:00
parent 3e49dbe6ab
commit b0cf3143c5
3 changed files with 5 additions and 3 deletions


@@ -56,7 +56,7 @@ def blocksparse_matmul_grad(op, dy):
   return (dx, dw)
 
 def run_shift():
-  B, C, H, W = 1, 16, 8, 8
+  B, C, H, W = 1, 32, 8, 6
   R, S, F = 3, 3, 16
   np.random.seed(2)
   a = tf.placeholder(tf.float32, shape=[C, H, W, B])
@@ -76,8 +76,8 @@ def run_shift():
   sess = tf.InteractiveSession()
   grads = tf.test.compute_gradient([a, b], [(C, H, W, B), (C, F)], c, (F, H, W, B),
                                    extra_feed_dict={a: ha, b: hb})
-  dx_t, dx_n = grads[0]
   dw_t, dw_n = grads[1]
+  dx_t, dx_n = grads[0]
   print(np.max(np.abs(dw_t - dw_n)))
   print(np.max(np.abs(dx_t - dx_n)))
   # Run
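
The new sizes make C, F, H, and W pairwise distinct (C=32 vs. F=16, H=8 vs. W=6), presumably so a transposed or mis-shaped gradient in wgrad fails the check instead of passing by coincidence. Below, a minimal sketch of the same gradient-check pattern, assuming the TensorFlow 1.x API (tf.test.compute_gradient returns one (theoretical, numerical) Jacobian pair per input) and using a plain matmul as a stand-in for the shift op:

    import numpy as np
    import tensorflow as tf

    C, F = 32, 16
    a = tf.placeholder(tf.float32, shape=[C, F])
    b = tf.placeholder(tf.float32, shape=[F, C])
    c = tf.matmul(a, b)                        # stand-in for the shift op
    ha = np.random.rand(C, F).astype(np.float32)
    hb = np.random.rand(F, C).astype(np.float32)
    sess = tf.InteractiveSession()
    grads = tf.test.compute_gradient([a, b], [(C, F), (F, C)], c, (C, C),
                                     extra_feed_dict={a: ha, b: hb})
    da_t, da_n = grads[0]                      # Jacobians w.r.t. a
    db_t, db_n = grads[1]                      # Jacobians w.r.t. b
    # with float32 and the default delta=1e-3, these should be ~1e-3 or less
    print(np.max(np.abs(da_t - da_n)))
    print(np.max(np.abs(db_t - db_n)))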


@@ -68,6 +68,8 @@ public:
   // shapes for b
   int64_t Cb = tf_b.dim_size(0);
   F = tf_b.dim_size(1);
+  if(OP == triton::dnn::shift::BPROP)
+    std::swap(Cb, F);
   // checks
   OP_REQUIRES(context, Ca == Cb, tensorflow::errors::InvalidArgument("operands must have the same number of channels"));
   C = Ca;
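
The added swap accounts for BPROP reading the weight transposed: in that pass, operand a is dy, which carries F channels, and the op produces C channels, so without the swap the Ca == Cb check would compare F against C and reject valid shapes. A small Python sketch of this bookkeeping, with hypothetical string tags standing in for the triton::dnn::shift op enum:

    def infer_shapes(op, a_channels, b_shape):
        Ca = a_channels              # channels of operand a
        Cb, F = b_shape              # weight is stored as (C, F)
        if op == "BPROP":
            # a is dy (F channels) and the output recovers C channels,
            # so b's dims must be read the other way around
            Cb, F = F, Cb
        assert Ca == Cb, "operands must have the same number of channels"
        return F                     # channel count of the result

    assert infer_shapes("FPROP", 32, (32, 16)) == 16   # forward pass
    assert infer_shapes("BPROP", 16, (32, 16)) == 32   # data gradient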


@@ -49,7 +49,7 @@ shift::shift(int B, int C,
     shapes_a_.swap(shapes_b_);
     AT_ = true;
     BT_ = false;
-    M_ = K_;
+    M_ = F_;
     N_ = C_;
     K_ = B_*AH_*AW_;
   }
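
This is the wgrad bug from the commit title: M_ was copied from K_, which at that point still held its value from the preceding layout setup; K_ only receives its wgrad value (B_*AH_*AW_) two lines later. Writing the weight gradient as dW[f, c] = sum over {h, w, b} of dy[f,h,w,b] * x[c,h,w,b] gives GEMM dimensions M = F, N = C, K = B*AH*AW, which is what the fixed code sets. A NumPy sketch of that contraction (shifts omitted, so only the shapes and the reduction axis are exercised; layouts follow the test above):

    import numpy as np

    B, C, H, W, F = 1, 32, 8, 6, 16
    x  = np.random.randn(C, H, W, B).astype(np.float32)   # (shifted) input
    dy = np.random.randn(F, H, W, B).astype(np.float32)   # output gradient
    dw = dy.reshape(F, -1) @ x.reshape(C, -1).T            # (M, N) = (F, C)
    assert dw.shape == (F, C)
    ref = np.einsum('fhwb,chwb->fc', dy, x)                # cross-check
    assert np.allclose(dw, ref, atol=1e-4)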