more stuff

This commit is contained in:
Philippe Tillet
2019-10-30 13:44:31 -04:00
parent bf3dc63858
commit 9b0f1a0807
4 changed files with 91 additions and 106 deletions

View File

@@ -45,11 +45,11 @@ if mode == MODE.TF:
fw_beta = tf.placeholder(shape=beta.shape, dtype=beta.dtype)
fw_dy = tf.placeholder(shape=dy.shape, dtype=dy.dtype)
# execute
fw_y = triton.ops.batchnorm(fw_x, fw_gamma, fw_beta, 1e-4)
#fw_dx, fw_dgamma, fw_dbeta = tf.gradients(fw_y, [fw_x, fw_gamma, fw_beta])
fw_mean, fw_var = tf.nn.moments(fw_x, [1, 2, 3])
fw_y = triton.ops.batchnorm(fw_x, fw_mean, fw_var, fw_gamma, fw_beta, 1e-4)
fw_dx, fw_dgamma, fw_dbeta = tf.gradients(fw_y, [fw_x, fw_gamma, fw_beta], fw_dy)
# run
sess = tf.InteractiveSession()
feed_dict = {fw_x: x, fw_gamma: gamma, fw_beta: beta, fw_dy: dy}
sess.run(tf.global_variables_initializer())
#print(fw_dx, fw_dgamma, fw_dbeta)
result = sess.run([fw_y], feed_dict=feed_dict)
print(result)
result = sess.run([fw_dx, fw_dgamma, fw_dbeta], feed_dict=feed_dict)