diff --git a/examples/python/tensorflow/run.py b/examples/python/tensorflow/run.py index 0d8019f27..c85ef2f7e 100644 --- a/examples/python/tensorflow/run.py +++ b/examples/python/tensorflow/run.py @@ -56,38 +56,30 @@ def blocksparse_matmul_grad(op, dy): return (dx, dw) def run_shift(): - B, C, H, W = 1, 16, 4, 4 + B, C, H, W = 1, 16, 8, 8 R, S, F = 3, 3, 16 np.random.seed(2) a = tf.placeholder(tf.float32, shape=[C, H, W, B]) b = tf.placeholder(tf.float32, shape=[C, F]) - #hshift_h = np.random.randint(- (R//2), R//2 + 1, size=C, dtype=np.int32) - #hshift_w = np.random.randint(- (S//2), R//2 + 1, size=C, dtype=np.int32) - hshift_h = -1*np.ones(C, dtype=np.int32) - hshift_w = -1*np.ones(C, dtype=np.int32) + hshift_h = np.random.randint(- (R//2), R//2 + 1, size=C, dtype=np.int32) + hshift_w = np.random.randint(- (S//2), R//2 + 1, size=C, dtype=np.int32) + #hshift_h = np.ones(C, dtype=np.int32) + #hshift_w = np.ones(C, dtype=np.int32) print(hshift_h) print(hshift_w) c = module.shift_conv(a, b, shift_h=tf.make_tensor_proto(hshift_h), shift_w=tf.make_tensor_proto(hshift_w)) - c = tf.math.reduce_sum(c) # Reference - ha = np.ones((C, H, W, B), dtype=np.float32) - hb = np.ones((C, F), dtype=np.float32) + ha = np.random.rand(C, H, W, B) + hb = np.random.rand(C, F) #ha = np.ones((C, H, W, B), dtype=np.int32) #hb = np.ones((C, F), dtype=np.int32) sess = tf.InteractiveSession() - grads = tf.test.compute_gradient([a, b], [(C, H, W, B), (C, F)], c, (1,), + grads = tf.test.compute_gradient([a, b], [(C, H, W, B), (C, F)], c, (F, H, W, B), extra_feed_dict={a: ha, b: hb}) dx_t, dx_n = grads[0] dw_t, dw_n = grads[1] - #print(dw_t - dw_n) - #np.savetxt('diff.dat', dw_t - dw_n, fmt='%2.4f') - #np.savetxt('theoretical.dat', dw_t, fmt='%2.4f') - #np.savetxt('numerical.dat', dw_n, fmt='%2.4f') print(np.max(np.abs(dw_t - dw_n))) print(np.max(np.abs(dx_t - dx_n))) - np.savetxt('diff.dat', dx_t - dx_n, fmt='%2.4f') - np.savetxt('theoretical.dat', dx_t, fmt='%2.4f') - np.savetxt('numerical.dat', dx_n, fmt='%2.4f') # Run sess.run(tf.global_variables_initializer()) result = sess.run([c], feed_dict = {a: ha, diff --git a/examples/python/tensorflow/shift.cpp b/examples/python/tensorflow/shift.cpp index a049f869d..c000c9db2 100644 --- a/examples/python/tensorflow/shift.cpp +++ b/examples/python/tensorflow/shift.cpp @@ -71,6 +71,8 @@ public: // checks OP_REQUIRES(context, Ca == Cb, tensorflow::errors::InvalidArgument("operands must have the same number of channels")); C = Ca; + if(OP == triton::dnn::shift::BPROP) + std::swap(C, F); } } diff --git a/lib/dnn/shift.cpp b/lib/dnn/shift.cpp index 92245ba16..122a51578 100644 --- a/lib/dnn/shift.cpp +++ b/lib/dnn/shift.cpp @@ -125,14 +125,16 @@ void shift::init(driver::stream *stream, driver::cu_module *module) { void shift::enqueue(driver::stream *stream, driver::kernel *kernel, driver::buffer *a, driver::buffer *b, driver::buffer *c, size_t TM, size_t TN, size_t nthreads) { + int32_t lda = AT_ ? K_ : M_; + int32_t ldb = BT_ ? N_ : K_; kernel->setArg(0, a); kernel->setArg(1, b); kernel->setArg(2, c); kernel->setArg(3, M_); kernel->setArg(4, N_); kernel->setArg(5, K_); - kernel->setArg(6, M_); - kernel->setArg(7, N_); + kernel->setArg(6, lda); + kernel->setArg(7, ldb); kernel->setArg(8, B_); kernel->setArg(9, AH_); kernel->setArg(10, AW_);