more stuff
This commit is contained in:
@@ -45,11 +45,11 @@ if mode == MODE.TF:
|
||||
fw_beta = tf.placeholder(shape=beta.shape, dtype=beta.dtype)
|
||||
fw_dy = tf.placeholder(shape=dy.shape, dtype=dy.dtype)
|
||||
# execute
|
||||
fw_y = triton.ops.batchnorm(fw_x, fw_gamma, fw_beta, 1e-4)
|
||||
#fw_dx, fw_dgamma, fw_dbeta = tf.gradients(fw_y, [fw_x, fw_gamma, fw_beta])
|
||||
fw_mean, fw_var = tf.nn.moments(fw_x, [1, 2, 3])
|
||||
fw_y = triton.ops.batchnorm(fw_x, fw_mean, fw_var, fw_gamma, fw_beta, 1e-4)
|
||||
fw_dx, fw_dgamma, fw_dbeta = tf.gradients(fw_y, [fw_x, fw_gamma, fw_beta], fw_dy)
|
||||
# run
|
||||
sess = tf.InteractiveSession()
|
||||
feed_dict = {fw_x: x, fw_gamma: gamma, fw_beta: beta, fw_dy: dy}
|
||||
sess.run(tf.global_variables_initializer())
|
||||
#print(fw_dx, fw_dgamma, fw_dbeta)
|
||||
result = sess.run([fw_y], feed_dict=feed_dict)
|
||||
print(result)
|
||||
result = sess.run([fw_dx, fw_dgamma, fw_dbeta], feed_dict=feed_dict)
|
Reference in New Issue
Block a user