[PYTHON][KERNEL] Added benchmarking functionality for kernels
This commit is contained in:
@@ -2,11 +2,11 @@ import numpy as np
|
||||
import triton
|
||||
|
||||
def run_tf():
|
||||
M, N, K = 128, 128, 128
|
||||
M, N, K = 2048, 2048, 2048
|
||||
a = tf.placeholder(tf.float32, shape=[M, K])
|
||||
b = tf.placeholder(tf.float32, shape=[N, K])
|
||||
tr_c = triton.ops.dot(a, b, transpose_a = False, transpose_b = True)
|
||||
tr_d = triton.ops.dot(tr_c, b, transpose_a = True, transpose_b = False)
|
||||
tr_c = triton.ops.dot(a, b, transpose_a = False, transpose_b = True, bench=10)
|
||||
tr_d = triton.ops.dot(tr_c, b, transpose_a = True, transpose_b = False, bench=10)
|
||||
tf_c = tf.matmul(a, b, transpose_a = False, transpose_b = True)
|
||||
tf_d = tf.matmul(tf_c, b, transpose_a = True, transpose_b = False)
|
||||
# Gradient
|
||||
@@ -20,15 +20,20 @@ def run_tf():
|
||||
sess.run(tf.global_variables_initializer())
|
||||
result = sess.run([tr_da, tf_da], feed_dict = {a: ha,
|
||||
b: hb})
|
||||
# Benchmark
|
||||
nanosec = triton.bench_registry[tr_d]
|
||||
print('NANOSEC: ', nanosec)
|
||||
print('TFLOPS:', 2. * M * N * K / nanosec * 1e-3)
|
||||
# Test
|
||||
print(result[0][0])
|
||||
print(result[1][0])
|
||||
dif = np.abs(result[0][0] - result[1][0])
|
||||
print("dif: %f" % np.max(dif))
|
||||
|
||||
|
||||
def run_torch():
|
||||
torch.manual_seed(0)
|
||||
M, N, K = 128, 128, 128
|
||||
M, N, K = 2048, 2048, 2048
|
||||
a = torch.randn(M, K).cuda()
|
||||
b = torch.randn(K, N).cuda()
|
||||
a.requires_grad_(True)
|
||||
@@ -37,9 +42,8 @@ def run_torch():
|
||||
torch_d = torch.matmul(torch.t(torch_c), b)
|
||||
torch_y = torch.mean(torch_d)
|
||||
triton_c = triton.ops.dot(a, b, False, True)
|
||||
triton_d = triton.ops.dot(triton_c, b, True, False)
|
||||
triton_d = triton.ops.dot(triton_c, b, True, False, 1)
|
||||
triton_y = torch.mean(triton_d)
|
||||
|
||||
# torch gradient
|
||||
torch_y.backward()
|
||||
torch_da = a.grad.clone()
|
||||
@@ -51,6 +55,9 @@ def run_torch():
|
||||
triton_da = a.grad.clone()
|
||||
triton_db = b.grad.clone()
|
||||
|
||||
nanosec = triton.bench_registry[triton_d]
|
||||
print(nanosec)
|
||||
print('TFLOPS:', 2. * M * N * K / nanosec * 1e-3)
|
||||
print('Diff DA:', (torch_da - triton_da).max())
|
||||
print('Diff DB:', (torch_db - triton_db).max())
|
||||
|
||||
|
@@ -12,7 +12,8 @@ from tensorflow.python.ops import gradient_checker
|
||||
|
||||
one = 0
|
||||
out = 0
|
||||
bench = 0
|
||||
bench = 10
|
||||
|
||||
class ProdKeyTest(tf.test.TestCase):
|
||||
|
||||
def testEinsum(self):
|
||||
@@ -36,9 +37,9 @@ class ProdKeyTest(tf.test.TestCase):
|
||||
# key_dim = 16
|
||||
|
||||
for a_shape, b_shape, c_shape, einsum in [
|
||||
[ [ 4, 8, 8 ], [ 8, 8 ], [ 4, 8, 8 ], "btc,ck->btk" ],
|
||||
[ [ 4, 1024, 1024 ], [ 1024, 512 ], [ 4, 1024, 512 ], "btc,ck->btk" ],
|
||||
[ (batch_dim, ctx_dim, head_dim, 2, key_dim//2),(head_dim, 2, n_keys, key_dim//2), (batch_dim, ctx_dim, head_dim, 2, n_keys), "bchak,hank->bchan" ],
|
||||
#[ [ 4, 8, 8 ], [ 8, 8 ], [ 4, 8, 8 ], "btc,ck->btk" ],
|
||||
[ [4, 2048, 2048 ], [ 2048, 2048 ], [4, 2048, 2048 ], "btc,ck->btk" ],
|
||||
#[ (batch_dim, ctx_dim, head_dim, 2, key_dim//2),(head_dim, 2, n_keys, key_dim//2), (batch_dim, ctx_dim, head_dim, 2, n_keys), "bchak,hank->bchan" ],
|
||||
]:
|
||||
|
||||
if one:
|
||||
@@ -57,7 +58,7 @@ class ProdKeyTest(tf.test.TestCase):
|
||||
e = tf.placeholder(tf.float32, c_shape, name="e")
|
||||
feed_dict = { a:A, b:B, e:E }
|
||||
|
||||
cc = triton.ops.einsum(einsum, a, b)
|
||||
cc = triton.ops.einsum(einsum, a, b, bench=bench)
|
||||
|
||||
# error = gradient_checker.compute_gradient_error(a, a_shape, c, c_shape, delta=1e-1, extra_feed_dict={ b:B }) #
|
||||
# print(error)
|
||||
@@ -71,8 +72,12 @@ class ProdKeyTest(tf.test.TestCase):
|
||||
# c, = sess.run( [ c, ], feed_dict )
|
||||
c, da, db = sess.run( [ cc, da, db ], feed_dict )
|
||||
|
||||
if bench == 0:
|
||||
if bench > 0:
|
||||
nanosec = triton.bench_registry[cc]
|
||||
print(A.shape, B.shape)
|
||||
print(nanosec)
|
||||
|
||||
else:
|
||||
C = np.einsum(einsum, A, B)
|
||||
id = cc.op.get_attr('id')
|
||||
ctx = triton.ops._einsum.contexts[id]
|
||||
|
Reference in New Issue
Block a user