[PYTHON][KERNEL] Enforcing shapes to be known at compile-time for
TensorFlow Graph Execution
This commit is contained in:
@@ -5,30 +5,28 @@ def run_tf():
|
||||
M, N, K = 2048, 2048, 2048
|
||||
a = tf.placeholder(tf.float32, shape=[M, K])
|
||||
b = tf.placeholder(tf.float32, shape=[N, K])
|
||||
tr_c = triton.ops.dot(a, b, transpose_a = False, transpose_b = True, bench=1)
|
||||
tr_d = triton.ops.dot(tr_c, b, transpose_a = True, transpose_b = False, bench=1)
|
||||
tf_c = tf.matmul(a, b, transpose_a = False, transpose_b = True)
|
||||
tf_d = tf.matmul(tf_c, b, transpose_a = True, transpose_b = False)
|
||||
triton_c = triton.ops.dot(a, b, False, True, 1)
|
||||
triton_d = triton.ops.dot(triton_c, b, True, False, 1)
|
||||
triton_y = tf.math.reduce_mean(triton_d)
|
||||
fw_c = tf.matmul(a, b, False, True)
|
||||
fw_d = tf.matmul(fw_c, b, True, False)
|
||||
fw_y = tf.math.reduce_mean(fw_d)
|
||||
# Gradient
|
||||
tr_da = tf.gradients(tr_d, [a])
|
||||
tf_da = tf.gradients(tf_d, [a])
|
||||
triton_da, triton_db = tf.gradients(triton_y, [a, b])
|
||||
fw_da, fw_db = tf.gradients(fw_y, [a, b])
|
||||
# Reference
|
||||
ha = np.random.rand(M, K).astype(np.float32)
|
||||
hb = np.random.rand(K, N).astype(np.float32)
|
||||
# Run
|
||||
feed_dict = {a: np.random.rand(M, K).astype(np.float32),
|
||||
b: np.random.rand(K, N).astype(np.float32)}
|
||||
sess = tf.InteractiveSession()
|
||||
sess.run(tf.global_variables_initializer())
|
||||
result = sess.run([tr_da, tf_da], feed_dict = {a: ha,
|
||||
b: hb})
|
||||
result = sess.run([triton_da, fw_da, triton_db, fw_db, fw_y, triton_y], feed_dict = feed_dict)
|
||||
triton_da, fw_da = result[0][0], result[1][0]
|
||||
triton_db, fw_db = result[2][0], result[3][0]
|
||||
# Benchmark
|
||||
nanosec = triton.bench_registry[tr_d]
|
||||
print('NANOSEC: ', nanosec)
|
||||
#print('TFLOPS:', 2. * M * N * K / nanosec * 1e-3)
|
||||
# Test
|
||||
print(result[0][0])
|
||||
print(result[1][0])
|
||||
dif = np.abs(result[0][0] - result[1][0])
|
||||
print("dif: %f" % np.max(dif))
|
||||
nanosec = triton.bench_registry[triton_d]
|
||||
print('TFLOPS:', 2. * M * N * K / nanosec * 1e-3)
|
||||
print('Diff DA:', (triton_da - fw_da).max())
|
||||
print('Diff DB:', (triton_db - fw_db).max())
|
||||
|
||||
|
||||
def run_torch():
|
||||
@@ -41,7 +39,7 @@ def run_torch():
|
||||
torch_c = torch.matmul(a, torch.t(b))
|
||||
torch_d = torch.matmul(torch.t(torch_c), b)
|
||||
torch_y = torch.mean(torch_d)
|
||||
triton_c = triton.ops.dot(a, b, False, True)
|
||||
triton_c = triton.ops.dot(a, b, False, True, 1)
|
||||
triton_d = triton.ops.dot(triton_c, b, True, False, 1)
|
||||
triton_y = torch.mean(triton_d)
|
||||
# torch gradient
|
||||
@@ -56,7 +54,6 @@ def run_torch():
|
||||
triton_db = b.grad.clone()
|
||||
|
||||
nanosec = triton.bench_registry[triton_d]
|
||||
print(nanosec)
|
||||
print('TFLOPS:', 2. * M * N * K / nanosec * 1e-3)
|
||||
print('Diff DA:', (torch_da - triton_da).max())
|
||||
print('Diff DB:', (torch_db - triton_db).max())
|
||||
|
Reference in New Issue
Block a user