[PYTHON] [OPS] Added einsum implementation
python/examples/einsum_test.py (Normal file, 129 lines)
@@ -0,0 +1,129 @@
#!/usr/bin/env python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf
import triton
import blocksparse as bs
from tensorflow.python.ops import gradient_checker

# debug/benchmark switches
one   = 0  # feed all-ones inputs (easier to eyeball wrong results)
out   = 0  # dump result tensors to text files in compare_results
bench = 0  # skip the numpy reference check when benchmarking
class ProdKeyTest(tf.test.TestCase):

    def testEinsum(self):
        # multi-threading screws up benchmarking
        conf = tf.ConfigProto(
            intra_op_parallelism_threads=1,
            inter_op_parallelism_threads=1)

        with self.test_session(config=conf) as sess, tf.device("/gpu:0"):

            batch_dim = 4
            ctx_dim   = 256
            head_dim  = 8
            n_keys    = 512
            key_dim   = 128

            # smaller debug sizes:
            # batch_dim = 2
            # ctx_dim   = 8
            # head_dim  = 2
            # n_keys    = 16
            # key_dim   = 16

            # (a_shape, b_shape, c_shape, einsum): two plain batched matmuls,
            # then a product-key style contraction with two batch-like dims
            for a_shape, b_shape, c_shape, einsum in [
                [ [4, 8, 8],       [8, 8],      [4, 8, 8],      "btc,ck->btk" ],
                [ [4, 1024, 1024], [1024, 512], [4, 1024, 512], "btc,ck->btk" ],
                [ (batch_dim, ctx_dim, head_dim, 2, key_dim//2),
                  (head_dim, 2, n_keys, key_dim//2),
                  (batch_dim, ctx_dim, head_dim, 2, n_keys),    "bchak,hank->bchan" ],
            ]:

                if one:
                    A = np.ones(a_shape, dtype=np.float32)
                    B = np.ones(b_shape, dtype=np.float32)
                    E = np.ones(c_shape, dtype=np.float32)
                else:
                    # QK = np.random.normal(loc=0.0, scale=1.0, size=qk_shape).astype(np.float16).astype(np.float32)
                    # V  = np.random.normal(loc=0.0, scale=1.0, size=vw_shape).astype(np.float16).astype(np.float32)
                    A = np.random.uniform(-1.0, 1.0, a_shape).astype(np.float16).astype(np.float32)
                    B = np.random.uniform(-1.0, 1.0, b_shape).astype(np.float16).astype(np.float32)
                    E = np.random.uniform(-1.0, 1.0, c_shape).astype(np.float16).astype(np.float32)

                a = tf.placeholder(tf.float32, a_shape, name="a")
                b = tf.placeholder(tf.float32, b_shape, name="b")
                e = tf.placeholder(tf.float32, c_shape, name="e")
                feed_dict = { a: A, b: B, e: E }

                cc = triton.ops.einsum(einsum, a, b)

                # error = gradient_checker.compute_gradient_error(a, a_shape, c, c_shape, delta=1e-1, extra_feed_dict={ b: B })
                # print(error)
                # error = gradient_checker.compute_gradient_error(b, b_shape, c, c_shape, delta=1e-1, extra_feed_dict={ a: A })
                # print(error)
                # return

                with tf.control_dependencies([cc.op]):
                    da, db = tf.gradients(cc, [a, b], e)

                # c, = sess.run( [ cc, ], feed_dict )
                c, da, db = sess.run( [ cc, da, db ], feed_dict )

                if bench == 0:
                    # numpy reference for the forward pass
                    C = np.einsum(einsum, A, B)

                    # recover how the op lowered this einsum: the reduced
                    # subscripts and whether either operand was transposed
                    id  = cc.op.get_attr('id')
                    ctx = triton.ops._einsum.contexts[id]
                    t_a = ctx.trans_a
                    t_b = ctx.trans_b
                    e_a = ctx.einsum_a
                    e_b = ctx.einsum_b
                    e_c = ctx.einsum_c

                    # reference gradients: dA and dB are themselves einsums of
                    # the upstream gradient E with the other operand
                    if not t_a and not t_b: # NN
                        DA = np.einsum(f"{e_c},{e_b}->{e_a}", E, B)
                        DB = np.einsum(f"{e_a},{e_c}->{e_b}", A, E)
                    elif not t_a and t_b: # NT
                        DA = np.einsum(f"{e_c},{e_b}->{e_a}", E, B)
                        DB = np.einsum(f"{e_c},{e_a}->{e_b}", E, A)
                    elif t_a and not t_b: # TN
                        DA = np.einsum(f"{e_b},{e_c}->{e_a}", B, E)
                        DB = np.einsum(f"{e_a},{e_c}->{e_b}", A, E)

                print("testProdKey", einsum)
                if not bench:
                    for op, dev, cpu in [
                        [ "C",  c,  C  ],
                        [ "DA", da, DA ],
                        [ "DB", db, DB ],
                    ]:
                        self.compare_results(op, dev, cpu)

    def compare_results(self, op, dev, cpu):
        dev = dev.astype(np.float64)
        cpu = cpu.astype(np.float64)

        # print(dev.reshape(-1)[0:4])
        # print(cpu.reshape(-1)[0:4])

        dif    = np.abs(cpu - dev)
        maxval = np.max(abs(cpu))
        avgval = np.average(abs(cpu))
        maxdif = dif.max()
        # max elementwise error, normalized by the mean reference magnitude
        max_err = maxdif if avgval == 0 else maxdif / avgval
        # relative L2 error over the whole tensor
        l2_err = 0.0 if avgval == 0 else np.sqrt(np.square(dif).sum()) / np.sqrt(np.square(cpu).sum())

        print("op:%3s, max:%18.12f, avg:%18.12f, dif:%18.12f, err:%18.12f, l2_err:%18.12f shape:%15s" % (op, maxval, avgval, maxdif, max_err, l2_err, str(cpu.shape)))

        if out:
            dim = cpu.shape[-1]
            np.savetxt("%s_dif.txt" % op, dif.reshape((-1, dim)), fmt='%4.1f')
            np.savetxt("%s_cpu.txt" % op, cpu.reshape((-1, dim)), fmt='%4.1f')
            np.savetxt("%s_dev.txt" % op, dev.reshape((-1, dim)), fmt='%4.1f')
            exit()


if __name__ == "__main__":
    tf.test.main()

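For intuition, the gradient identities the test checks in the NN case can be verified standalone with NumPy. A minimal sketch, assuming the simple "btc,ck->btk" case from the first test entry; the finite-difference spot check is illustrative only and not part of this commit:

import numpy as np

# For C = einsum("btc,ck->btk", A, B) and upstream gradient E:
#   dA = einsum("btk,ck->btc", E, B)
#   dB = einsum("btc,btk->ck", A, E)
A = np.random.uniform(-1.0, 1.0, (4, 8, 8)).astype(np.float32)
B = np.random.uniform(-1.0, 1.0, (8, 8)).astype(np.float32)
E = np.random.uniform(-1.0, 1.0, (4, 8, 8)).astype(np.float32)

DA = np.einsum("btk,ck->btc", E, B)
DB = np.einsum("btc,btk->ck", A, E)

# finite-difference check of one element of dA
eps = 1e-3
Ap, Am = A.copy(), A.copy()
Ap[0, 0, 0] += eps
Am[0, 0, 0] -= eps
fd = ((np.einsum("btc,ck->btk", Ap, B)
     - np.einsum("btc,ck->btk", Am, B)) * E).sum() / (2 * eps)
assert np.allclose(fd, DA[0, 0, 0], atol=1e-2)
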
@@ -1,3 +1,6 @@
# Special thanks to Scott Gray from OpenAI for writing the einsum parsing function


import triton

class _einsum(triton.function):
@@ -31,14 +34,18 @@ class _einsum(triton.function):
         TYPE a[SHAPE_A] = *pa;
         TYPE b[SHAPE_B] = *pb;
         c += USE_A @ USE_B;
-        pa += TK;
-        pb += TK;
+        pa += TK * STRIDE_AK;
+        pb += TK * STRIDE_BK;
     }
     // write-back
     TYPE *pc[TM, TN, TB] = C + rm[:, newaxis, newaxis] * std_C1
                              + rn[newaxis, :, newaxis] * 1
                              + rb[newaxis, newaxis, :] * std_C0;
-    *pc = c;
+    bool checkm[TM] = rm < dim_M;
+    bool checkn[TN] = rn < dim_N;
+    bool checkc[TM, TN, TB] = checkm[:, newaxis, newaxis] &&
+                              checkn[newaxis, :, newaxis];
+    *?(checkc)pc = c;
 }
 """
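Two fixes land in this hunk: the K-loop pointer increments now scale by each operand's K stride (STRIDE_AK/STRIDE_BK), so transposed layouts advance correctly, and the write-back is predicated so a tile overhanging the output bounds never stores out of range. A NumPy sketch of the masked-store semantics; tile and tensor sizes here are made up for illustration:

import numpy as np

# Mask mirroring `*?(checkc)pc = c;`: only lanes with rm < dim_M and
# rn < dim_N are written back.
TM, TN, TB = 4, 4, 1           # tile shape
dim_M, dim_N = 3, 2            # output extent; the tile overhangs both axes
rm, rn = np.arange(TM), np.arange(TN)
checkc = (rm[:, None, None] < dim_M) & (rn[None, :, None] < dim_N)

C = np.zeros((TM, TN, TB), dtype=np.float32)   # destination tile
c = np.ones((TM, TN, TB), dtype=np.float32)    # accumulator
C[checkc] = c[checkc]                          # predicated write-back
assert C.sum() == dim_M * dim_N * TB
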
@@ -141,12 +148,12 @@ class _einsum(triton.function):
 'BROADCAST_AM': 'newaxis, :, newaxis' if trans_a else ':, newaxis, newaxis',
 'SHAPE_A'     : 'TK, TM, TB' if trans_a else 'TM, TK, TB',
 # handle B transposition
-'USE_B'       : 'b[^1, ^0, ^2]' if not trans_b else 'b',
+'USE_B'       : 'b' if not trans_b else 'b[^1, ^0, ^2]',
 'STRIDE_BK'   : 'std_B1' if not trans_b else '1',
 'STRIDE_BN'   : '1' if not trans_b else 'std_B1',
-'BROADCAST_BK': 'newaxis, :, newaxis' if not trans_b else ':, newaxis, newaxis',
-'BROADCAST_BN': ':, newaxis, newaxis' if not trans_b else 'newaxis, :, newaxis',
-'SHAPE_B'     : 'TN, TK, TB' if not trans_b else 'TK, TN, TB'}
+'BROADCAST_BK': ':, newaxis, newaxis' if not trans_b else 'newaxis, :, newaxis',
+'BROADCAST_BN': 'newaxis, :, newaxis' if not trans_b else ':, newaxis, newaxis',
+'SHAPE_B'     : 'TK, TN, TB' if not trans_b else 'TN, TK, TB'}
 return _einsum.kernel(a, b, c,
                       bmnk[1], bmnk[2], bmnk[3],
                       std0[0], std0[1], std0[2],
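The corrected BROADCAST_B*/SHAPE_B selections follow from memory layout: when b is not transposed it is loaded as a (TK, TN, TB) tile, so K is its leading axis and stepping along K moves by the row stride std_B1 while stepping along N moves by 1, and vice versa when transposed. A small NumPy analogy, using made-up sizes:

import numpy as np

K, N = 8, 4
B_nn = np.zeros((K, N), dtype=np.float32)   # not transposed: K-major layout
B_nt = np.zeros((N, K), dtype=np.float32)   # transposed: K is the fastest axis

elem = lambda arr: [s // arr.itemsize for s in arr.strides]
assert elem(B_nn) == [N, 1]   # STRIDE_BK = std_B1, STRIDE_BN = 1
assert elem(B_nt) == [K, 1]   # stepping along K now moves by 1: STRIDE_BK = 1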