diff --git a/python/examples/batchnorm.py b/python/examples/batchnorm.py
new file mode 100644
index 000000000..673c8ec2c
--- /dev/null
+++ b/python/examples/batchnorm.py
@@ -0,0 +1,55 @@
+import triton
+import numpy as np
+from enum import Enum
+
+class MODE(Enum):
+    TF = 1
+    TORCH = 2
+
+try:
+    import tensorflow as tf
+    mode = MODE.TF
+except ModuleNotFoundError:
+    pass
+
+try:
+    import torch
+    mode = MODE.TORCH
+except ModuleNotFoundError:
+    pass
+
+
+C, H, W, B = 32, 1, 1, 128
+
+x = np.random.uniform(-1, 1, (C, H, W, B)).astype(np.float32)
+gamma = np.random.uniform(-1, 1, C).astype(np.float32)
+beta = np.random.uniform(-1, 1, C).astype(np.float32)
+dy = np.random.uniform(-1, 1, (C, H, W, B)).astype(np.float32)
+
+if mode == MODE.TORCH:
+    fw_x = torch.from_numpy(x).cuda()
+    fw_gamma = torch.from_numpy(gamma).cuda()
+    fw_beta = torch.from_numpy(beta).cuda()
+    fw_dy = torch.from_numpy(dy).cuda()
+    # register gradients
+    fw_x.requires_grad_(True)
+    fw_gamma.requires_grad_(True)
+    fw_beta.requires_grad_(True)
+    # execute
+    fw_y = triton.ops.batchnorm(fw_x, fw_gamma, fw_beta, 1e-4)
+    fw_y.backward(fw_dy)
+
+if mode == MODE.TF:
+    fw_x = tf.placeholder(shape=x.shape, dtype=x.dtype)
+    fw_gamma = tf.placeholder(shape=gamma.shape, dtype=gamma.dtype)
+    fw_beta = tf.placeholder(shape=beta.shape, dtype=beta.dtype)
+    fw_dy = tf.placeholder(shape=dy.shape, dtype=dy.dtype)
+    # execute
+    fw_y = triton.ops.batchnorm(fw_x, fw_gamma, fw_beta, 1e-4)
+    #fw_dx, fw_dgamma, fw_dbeta = tf.gradients(fw_y, [fw_x, fw_gamma, fw_beta])
+    sess = tf.InteractiveSession()
+    feed_dict = {fw_x: x, fw_gamma: gamma, fw_beta: beta, fw_dy: dy}
+    sess.run(tf.global_variables_initializer())
+    #print(fw_dx, fw_dgamma, fw_dbeta)
+    result = sess.run([fw_y], feed_dict=feed_dict)
+    print(result)
diff --git a/python/triton/utils.py b/python/triton/utils.py
index e55afd602..17534abb7 100644
--- a/python/triton/utils.py
+++ b/python/triton/utils.py
@@ -18,7 +18,7 @@ def empty(shape, dtype):
     return tf_empty_proxy(shape, dtype)
     #return fw.tf_extra_ops.alloc_empty(args, T = dtype)
   elif fw.has_torch():
-    return fw.torch.empty(*shapes).cuda()
+    return fw.torch.empty(*shape).cuda()
 
 class lazy_shape:
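
Note (not part of the patch): the new example exercises triton.ops.batchnorm through whichever frontend is importable (TensorFlow or PyTorch), and the utils.py hunk fixes a NameError on the torch path, where the undefined name `shapes` was unpacked instead of the function's `shape` argument. A minimal sketch of the corrected call, assuming PyTorch is installed and using an illustrative shape matching the example above:

    import torch

    shape = (32, 1, 1, 128)  # C, H, W, B, as in the example
    # Unpacking passes each dimension as a separate positional argument,
    # i.e. torch.empty(*shape) is equivalent to torch.empty(32, 1, 1, 128).
    t = torch.empty(*shape)
    assert tuple(t.shape) == shape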