[dnn/shift] bugfix in wgrad
@@ -56,7 +56,7 @@ def blocksparse_matmul_grad(op, dy):
   return (dx, dw)
 
 def run_shift():
-  B, C, H, W = 1, 16, 8, 8
+  B, C, H, W = 1, 32, 8, 6
   R, S, F = 3, 3, 16
   np.random.seed(2)
   a = tf.placeholder(tf.float32, shape=[C, H, W, B])
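For reference, the host arrays fed to these placeholders further down (ha, hb) would be shaped to match the (C, H, W, B) activation layout and the (C, F) weight layout; a minimal sketch, assuming plain random initialization (not part of the commit):

  ha = np.random.rand(C, H, W, B).astype(np.float32)  # activations in the (C, H, W, B) layout used by the test
  hb = np.random.rand(C, F).astype(np.float32)        # weights, one F-vector per input channel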
@@ -76,8 +76,8 @@ def run_shift():
   sess = tf.InteractiveSession()
   grads = tf.test.compute_gradient([a, b], [(C, H, W, B), (C, F)], c, (F, H, W, B),
                                    extra_feed_dict={a: ha, b: hb})
-  dx_t, dx_n = grads[0]
   dw_t, dw_n = grads[1]
+  dx_t, dx_n = grads[0]
   print(np.max(np.abs(dw_t - dw_n)))
   print(np.max(np.abs(dx_t - dx_n)))
   # Run
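The check above uses TF1's tf.test.compute_gradient, which returns a (theoretical, numerical) Jacobian pair per input when given a list of inputs, so grads[0] and grads[1] correspond to a and b, and the two prints report the worst-case error for dx and dw. A standalone sketch of the same pattern on a stand-in op (tf.tanh here; illustrative, assuming TF1.x):

  import numpy as np
  import tensorflow as tf

  x = tf.placeholder(tf.float32, shape=[4, 3])
  y = tf.tanh(x)
  hx = np.random.rand(4, 3).astype(np.float32)
  with tf.Session():
    jt, jn = tf.test.compute_gradient(x, (4, 3), y, (4, 3), x_init_value=hx)
    print(np.max(np.abs(jt - jn)))  # should be small (on the order of the finite-difference delta, 1e-3 by default)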
@@ -68,6 +68,8 @@ public:
   // shapes for b
   int64_t Cb = tf_b.dim_size(0);
   F = tf_b.dim_size(1);
+  if(OP == triton::dnn::shift::BPROP)
+    std::swap(Cb, F);
   // checks
   OP_REQUIRES(context, Ca == Cb, tensorflow::errors::InvalidArgument("operands must have the same number of channels"));
   C = Ca;
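The two added lines handle the data-gradient (BPROP) case: there the b operand is contracted along its second axis rather than its first, so the two dim_size reads are swapped before the Ca == Cb channel check (otherwise channels would be compared against filters). A shape-only Python sketch of the same bookkeeping (illustrative, assumed semantics):

  def check_channels(Ca, b_shape, op):
      Cb, F = b_shape        # forward pass: b is laid out as (C, F)
      if op == "BPROP":      # backward w.r.t. data: b is consumed transposed, so swap the reads
          Cb, F = F, Cb
      assert Ca == Cb, "operands must have the same number of channels"
      return Cb, F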
@@ -49,7 +49,7 @@ shift::shift(int B, int C,
   shapes_a_.swap(shapes_b_);
   AT_ = true;
   BT_ = false;
-  M_ = K_;
+  M_ = F_;
   N_ = C_;
   K_ = B_*AH_*AW_;
 }
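This is the wgrad fix itself: with AT_ set, the weight-gradient GEMM computes dW[M, N] from A stored as (K, M) and B stored as (K, N), so M must be the filter count F; the old code copied K_, which at that point still held its previous value and is only assigned to B_*AH_*AW_ afterwards. A NumPy sketch of those dimensions, using the sizes from the test above (illustrative, assuming dy and the shifted input are flattened to (B*AH*AW, F) and (B*AH*AW, C)):

  import numpy as np
  B_, C_, F_, AH_, AW_ = 1, 32, 16, 8, 6
  M_, N_, K_ = F_, C_, B_ * AH_ * AW_
  dy = np.random.rand(K_, M_).astype(np.float32)  # operand A, stored (K, M): gradient of the output
  xs = np.random.rand(K_, N_).astype(np.float32)  # operand B, stored (K, N): shifted input
  dw = dy.T @ xs                                  # (M, N) = (F, C) weight gradient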