[PYTHON][EINSUM] Added support for FP16
@@ -5,8 +5,8 @@ def run_tf():
     M, N, K = 2048, 2048, 2048
     a = tf.placeholder(tf.float32, shape=[M, K])
     b = tf.placeholder(tf.float32, shape=[N, K])
-    tr_c = triton.ops.dot(a, b, transpose_a = False, transpose_b = True, bench=10)
-    tr_d = triton.ops.dot(tr_c, b, transpose_a = True, transpose_b = False, bench=10)
+    tr_c = triton.ops.dot(a, b, transpose_a = False, transpose_b = True, bench=1)
+    tr_d = triton.ops.dot(tr_c, b, transpose_a = True, transpose_b = False, bench=1)
     tf_c = tf.matmul(a, b, transpose_a = False, transpose_b = True)
     tf_d = tf.matmul(tf_c, b, transpose_a = True, transpose_b = False)
     # Gradient
@@ -23,7 +23,7 @@ def run_tf():
     # Benchmark
     nanosec = triton.bench_registry[tr_d]
     print('NANOSEC: ', nanosec)
-    print('TFLOPS:', 2. * M * N * K / nanosec * 1e-3)
+    #print('TFLOPS:', 2. * M * N * K / nanosec * 1e-3)
     # Test
     print(result[0][0])
     print(result[1][0])
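
For reference, the TFLOPS line commented out in the second hunk converts the time read from triton.bench_registry (in nanoseconds) into throughput: the GEMM performs 2*M*N*K floating-point operations, FLOPs per nanosecond equals GFLOPS, and the trailing 1e-3 rescales that to TFLOPS. A minimal worked example, with a hypothetical timing value:

    # Worked example of the commented-out TFLOPS formula.
    M = N = K = 2048
    nanosec = 1e6  # hypothetical measurement: 1 ms for the 2048^3 GEMM
    # 2*M*N*K FLOPs / nanoseconds = GFLOPS; * 1e-3 converts to TFLOPS.
    tflops = 2. * M * N * K / nanosec * 1e-3
    print(tflops)  # ~17.18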
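The diff itself only touches the FP32 benchmark; the FP16 support announced in the commit title would be exercised by feeding half-precision tensors through the same triton.ops.dot path. A minimal sketch under that assumption (the placeholders, session setup, and random inputs below are illustrative, not taken from the diff):

    import numpy as np
    import tensorflow as tf
    import triton

    M, N, K = 2048, 2048, 2048
    # Assumption: after this commit, triton.ops.dot accepts tf.float16 inputs.
    a = tf.placeholder(tf.float16, shape=[M, K])
    b = tf.placeholder(tf.float16, shape=[N, K])
    # Computes a @ b^T, shape [M, N], like tr_c above but in half precision.
    c = triton.ops.dot(a, b, transpose_a = False, transpose_b = True, bench=1)
    sess = tf.Session()
    result = sess.run(c, feed_dict={a: np.random.randn(M, K).astype(np.float16),
                                    b: np.random.randn(N, K).astype(np.float16)})
    print(result[0][0])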