[PYTHON][EINSUM] Added support for FP16

Philippe Tillet
2019-10-28 11:33:18 -04:00
parent 0ec213547c
commit e9c787ef05
13 changed files with 97 additions and 83 deletions


@@ -12,7 +12,7 @@ from tensorflow.python.ops import gradient_checker
 one = 0
 out = 0
-bench = 10
+bench = 0
 class ProdKeyTest(tf.test.TestCase):
@@ -37,14 +37,14 @@ class ProdKeyTest(tf.test.TestCase):
         # key_dim = 16
         for a_shape, b_shape, c_shape, einsum in [
-            #[ [ 4, 8, 8 ], [ 8, 8 ], [ 4, 8, 8 ], "btc,ck->btk" ],
-            [ [4, 2048, 2048 ], [ 2048, 2048 ], [4, 2048, 2048 ], "btc,ck->btk" ],
-            #[ (batch_dim, ctx_dim, head_dim, 2, key_dim//2),(head_dim, 2, n_keys, key_dim//2), (batch_dim, ctx_dim, head_dim, 2, n_keys), "bchak,hank->bchan" ],
+            [ [ 4, 8, 8 ], [ 8, 8 ], [ 4, 8, 8 ], "btc,ck->btk" ],
+            [ [4, 1024, 1024], [ 1024, 1024 ], [4, 1024, 1024 ], "btc,ck->btk" ],
+            [ (batch_dim, ctx_dim, head_dim, 2, key_dim//2),(head_dim, 2, n_keys, key_dim//2), (batch_dim, ctx_dim, head_dim, 2, n_keys), "bchak,hank->bchan" ],
             ]:
             if one:
-                A = np.ones(a_shape, dtype=np.float32)
-                B = np.ones(b_shape, dtype=np.float32)
+                A = np.ones(a_shape, dtype=np.float16).astype(np.float32)
+                B = np.ones(b_shape, dtype=np.float16).astype(np.float32)
                 E = np.ones(c_shape, dtype=np.float32)
             else:
                 # QK = np.random.normal(loc=0.0, scale=1.0, size=qk_shape).astype(np.float16).astype(np.float32)
@@ -53,12 +53,14 @@ class ProdKeyTest(tf.test.TestCase):
                 B = np.random.uniform(-1.0, 1.0, b_shape).astype(np.float16).astype(np.float32)
                 E = np.random.uniform(-1.0, 1.0, c_shape).astype(np.float16).astype(np.float32)
-            a = tf.placeholder(tf.float32, a_shape, name="a")
-            b = tf.placeholder(tf.float32, b_shape, name="b")
-            e = tf.placeholder(tf.float32, c_shape, name="e")
-            feed_dict = { a:A, b:B, e:E }
+            a = tf.placeholder(tf.float16, a_shape, name="a")
+            b = tf.placeholder(tf.float16, b_shape, name="b")
+            e = tf.placeholder(tf.float16, c_shape, name="e")
+            feed_dict = { a: A.astype(np.float16),
+                          b: B.astype(np.float16),
+                          e: E }
-            cc = triton.ops.einsum(einsum, a, b, bench=bench)
+            c = triton.ops.einsum(einsum, a, b, bench=bench)
             # error = gradient_checker.compute_gradient_error(a, a_shape, c, c_shape, delta=1e-1, extra_feed_dict={ b:B }) #
             # print(error)
@@ -66,21 +68,24 @@ class ProdKeyTest(tf.test.TestCase):
             # print(error)
             # return
-            with tf.control_dependencies([cc.op]):
-                da, db = tf.gradients(cc, [a, b], e)
+            with tf.control_dependencies([c.op]):
+                da, db = tf.gradients(c, [a, b], e)
             # c, = sess.run( [ c, ], feed_dict )
-            c, da, db = sess.run( [ cc, da, db ], feed_dict )
+            rc, rda, rdb = sess.run( [ c, da, db ], feed_dict )
             if bench > 0:
-                nanosec = triton.bench_registry[cc]
-                print(A.shape, B.shape)
-                print(nanosec)
+                nanosec = triton.bench_registry[c]
+                ctx = triton.ctx_registry[c]
+                b, m, n, k = tuple((ctx.bmnk[i] for i in range(0, 4)))
+                ops = 2. * b * m * n * k
+                print('C TFLOPS:', ops / triton.bench_registry[c] * 1e-3)
+                print('DA TFLOPS:', ops / triton.bench_registry[da] * 1e-3)
+                print('DB TFLOPS:', ops / triton.bench_registry[db] * 1e-3)
             else:
                 C = np.einsum(einsum, A, B)
-                id = cc.op.get_attr('id')
-                ctx = triton.ops._einsum.contexts[id]
+                ctx = triton.ctx_registry[c]
                 t_a = ctx.trans_a
                 t_b = ctx.trans_b
                 e_a = ctx.einsum_a
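
A note on the unit conversion in the benchmark branch above: the value read from triton.bench_registry is, judging by the variable name nanosec, a time in nanoseconds, so ops / nanosec is FLOPs per nanosecond (numerically GFLOP/s) and the extra factor of 1e-3 turns it into TFLOP/s. A tiny worked example with an invented kernel time:

# Worked example of the TFLOPS arithmetic above; the timing is made up.
b, m, n, k = 4, 1024, 1024, 1024        # bmnk shape of the 1024x1024 test case
ops = 2. * b * m * n * k                # one multiply and one add per inner-product term
nanosec = 150000.                       # assumed: kernel time reported in nanoseconds
print('TFLOPS:', ops / nanosec * 1e-3)  # FLOP/ns = GFLOP/s; * 1e-3 gives ~57.3 TFLOP/s here
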
@@ -100,9 +105,9 @@ class ProdKeyTest(tf.test.TestCase):
             print("testProdKey", einsum)
             if not bench:
                 for op, dev, cpu in [
-                    [ "C", rc, C ],
-                    [ "DA", rda, DA ],
-                    [ "DB", rdb, DB ],
+                    [ "C", rc, C ],
+                    [ "DA", rda, DA ],
+                    [ "DB", rdb, DB ],
                     ]:
                     self.compare_results(op, dev, cpu)
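
For reference, a minimal standalone sketch of the FP16 flow this test now exercises, assuming the TF1-era triton.ops.einsum call exactly as it appears in the diff above; the shapes and the final error print are illustrative, not part of the test itself:

import numpy as np
import tensorflow as tf
import triton

expr = "btc,ck->btk"
a_shape, b_shape = [4, 1024, 1024], [1024, 1024]

# Generate host data in fp16, then widen to fp32 so that the numpy reference
# sees exactly the rounded values the fp16 kernel receives.
A = np.random.uniform(-1.0, 1.0, a_shape).astype(np.float16).astype(np.float32)
B = np.random.uniform(-1.0, 1.0, b_shape).astype(np.float16).astype(np.float32)

a = tf.placeholder(tf.float16, a_shape, name="a")
b = tf.placeholder(tf.float16, b_shape, name="b")
c = triton.ops.einsum(expr, a, b, bench=0)

with tf.Session() as sess:
    rc = sess.run(c, feed_dict={a: A.astype(np.float16), b: B.astype(np.float16)})

# fp32 reference; any comparison tolerance has to allow for fp16 rounding.
C = np.einsum(expr, A, B)
print('max abs error:', np.max(np.abs(rc.astype(np.float32) - C)))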