2019-09-04 21:55:47 -04:00
|
|
|
import numpy as np
|
2019-08-16 15:56:58 -07:00
|
|
|
import tensorflow as tf
|
2019-08-29 21:34:23 -07:00
|
|
|
import triton
|
2019-08-25 21:26:09 -07:00
|
|
|
|
2019-08-16 20:50:18 -07:00
|
|
|
def run_dot():
    """Compare the gradient of a Triton matmul chain against TensorFlow's.

    Builds the same two-step product with ``triton.ops.dot.apply`` and with
    ``tf.matmul``::

        c = a @ b^T
        d = c^T @ b

    then differentiates ``d`` with respect to ``a`` in both graphs, evaluates
    both gradients on the same random inputs, prints them, and prints the
    largest absolute element-wise difference.

    Returns:
        None. Results are only printed.
    """
    M, N, K = 128, 128, 128

    # Graph inputs: a is [M, K], b is [N, K].
    a = tf.placeholder(tf.float32, shape=[M, K])
    b = tf.placeholder(tf.float32, shape=[N, K])

    _dot = triton.ops.dot.apply

    # Triton graph. c is [M, N]; the second product contracts c^T ([N, M])
    # with b ([N, K]), so this chain is only shape-valid while M == N.
    tr_c = _dot(a, b, transpose_a=False, transpose_b=True)
    tr_d = _dot(tr_c, b, transpose_a=True, transpose_b=False)

    # TensorFlow reference graph, same structure as above.
    tf_c = tf.matmul(a, b, transpose_a=False, transpose_b=True)
    tf_d = tf.matmul(tf_c, b, transpose_a=True, transpose_b=False)

    # Gradient of the final product with respect to `a`; each call returns a
    # one-element list because a single input is requested.
    tr_da = tf.gradients(tr_d, [a])
    tf_da = tf.gradients(tf_d, [a])

    # Host inputs. hb is generated as (N, K) so that it matches the declared
    # shape of placeholder `b`; the previous (K, N) only fit because N == K.
    ha = np.random.rand(M, K).astype(np.float32)
    hb = np.random.rand(N, K).astype(np.float32)

    # Run both gradients in one call so they see identical inputs, and always
    # release the session afterwards.
    sess = tf.InteractiveSession()
    try:
        # NOTE(review): no variables are created in this function, so this is
        # presumably a no-op; kept in case the Triton op registers any.
        sess.run(tf.global_variables_initializer())
        result = sess.run([tr_da, tf_da], feed_dict={a: ha, b: hb})
    finally:
        sess.close()

    # result[0] is the Triton gradient list, result[1] the TensorFlow one;
    # index [0] selects the gradient with respect to `a`.
    tr_grad = result[0][0]
    tf_grad = result[1][0]
    print(tr_grad)
    print(tf_grad)
    dif = np.abs(tr_grad - tf_grad)
    print("dif: %f" % np.max(dif))
|
|
|
|
|
|
|
|
# Unguarded top-level call: the comparison runs whenever this file is
# executed or imported.
run_dot()
|