72 lines
2.1 KiB
Python
72 lines
2.1 KiB
Python
import numpy as np
|
|
import triton
|
|
|
|
def run_tf():
    """Check Triton's dot op against ``tf.matmul`` and report its throughput.

    The same two-product graph is built twice, once with ``triton.ops.dot``
    and once with ``tf.matmul``.  Each result is reduced to a scalar mean and
    differentiated with respect to both inputs.  The function prints:

    * ``TFLOPS`` -- derived from the ``triton.bench_registry`` entry of the
      second Triton product,
    * ``Diff DA`` / ``Diff DB`` -- the largest absolute element-wise mismatch
      between the Triton and TensorFlow gradients.

    Relies on the module-level names ``tf`` (graph-mode API:
    ``tf.placeholder`` / ``tf.InteractiveSession``), ``triton`` and ``np``
    being bound before the call.  Returns nothing.
    """
    M, N, K = 2048, 2048, 2048
    a = tf.placeholder(tf.float32, shape=[M, K])
    b = tf.placeholder(tf.float32, shape=[N, K])

    # Triton graph.  The boolean flags sit in the same positions as the
    # transpose flags of the tf.matmul reference below; the meaning of the
    # trailing `1` is not visible from this file -- TODO confirm.
    triton_c = triton.ops.dot(a, b, False, True, 1)
    triton_d = triton.ops.dot(triton_c, b, True, False, 1)
    triton_y = tf.math.reduce_mean(triton_d)

    # Reference graph: c = a @ b^T is (M, N); d = c^T @ b contracts c's M rows
    # against b's N rows, so these shapes only line up when M == N.
    fw_c = tf.matmul(a, b, False, True)
    fw_d = tf.matmul(fw_c, b, True, False)
    fw_y = tf.math.reduce_mean(fw_d)

    # Gradient
    triton_da, triton_db = tf.gradients(triton_y, [a, b])
    fw_da, fw_db = tf.gradients(fw_y, [a, b])

    # Reference inputs.  Each array is created with the shape its placeholder
    # declares ((M, K) for a, (N, K) for b).
    feed_dict = {
        a: np.random.rand(M, K).astype(np.float32),
        b: np.random.rand(N, K).astype(np.float32),
    }

    sess = tf.InteractiveSession()
    try:
        sess.run(tf.global_variables_initializer())
        result = sess.run(
            [triton_da, fw_da, triton_db, fw_db, fw_y, triton_y],
            feed_dict=feed_dict,
        )
        # Benchmark: read while the session is still open.  The entry is keyed
        # by the graph tensor and is presumably a time in nanoseconds (judging
        # by the variable name) -- TODO confirm.
        nanosec = triton.bench_registry[triton_d]
    finally:
        # Release the session even if graph execution fails.
        sess.close()

    # Keep the full gradient arrays (not just their first row) so the reported
    # difference covers every element, as the PyTorch variant does.
    triton_da_val, fw_da_val, triton_db_val, fw_db_val = result[:4]

    # One M x N x K product costs 2*M*N*K flops; flops per nanosecond is
    # GFLOPS, and the 1e-3 factor converts that to TFLOPS.
    print('TFLOPS:', 2. * M * N * K / nanosec * 1e-3)
    # Absolute value: a signed max would hide large negative deviations.
    print('Diff DA:', np.abs(triton_da_val - fw_da_val).max())
    print('Diff DB:', np.abs(triton_db_val - fw_db_val).max())
|
|
|
|
|
|
def run_torch():
    """Check Triton's dot op against ``torch.matmul`` and report its throughput.

    Two CUDA tensors are pushed through the same two-product computation with
    ``torch.matmul`` and with ``triton.ops.dot``.  Each result is reduced to a
    scalar mean and back-propagated.  The function prints:

    * ``TFLOPS`` -- derived from the ``triton.bench_registry`` entry of the
      second Triton product,
    * ``Diff DA`` / ``Diff DB`` -- the largest absolute element-wise mismatch
      between the PyTorch and Triton gradients.

    Relies on the module-level names ``torch`` and ``triton`` being bound
    before the call and moves its inputs to a CUDA device.  Returns nothing.
    """
    torch.manual_seed(0)
    M, N, K = 2048, 2048, 2048
    a = torch.randn(M, K).cuda()
    # b is used as an (N, K) operand (a @ b^T below), so create it with that
    # shape rather than (K, N).
    b = torch.randn(N, K).cuda()
    a.requires_grad_(True)
    b.requires_grad_(True)

    # Reference: c = a @ b^T is (M, N); d = c^T @ b contracts c's M rows
    # against b's N rows, so these shapes only line up when M == N.
    torch_c = torch.matmul(a, torch.t(b))
    torch_d = torch.matmul(torch.t(torch_c), b)
    torch_y = torch.mean(torch_d)

    # Triton version.  The boolean flags take the place of the explicit
    # torch.t() calls above; the meaning of the trailing `1` is not visible
    # from this file -- TODO confirm.
    triton_c = triton.ops.dot(a, b, False, True, 1)
    triton_d = triton.ops.dot(triton_c, b, True, False, 1)
    triton_y = torch.mean(triton_d)

    # torch gradient
    torch_y.backward()
    torch_da = a.grad.clone()
    torch_db = b.grad.clone()

    # triton gradient: .grad accumulates across backward() calls, so clear it
    # before the second pass.
    a.grad.zero_()
    b.grad.zero_()
    triton_y.backward()
    triton_da = a.grad.clone()
    triton_db = b.grad.clone()

    # The registry entry is keyed by the Triton output and is presumably a
    # time in nanoseconds (judging by the variable name) -- TODO confirm.
    nanosec = triton.bench_registry[triton_d]
    # One M x N x K product costs 2*M*N*K flops; flops per nanosecond is
    # GFLOPS, and the 1e-3 factor converts that to TFLOPS.
    print('TFLOPS:', 2. * M * N * K / nanosec * 1e-3)
    # Absolute value: a signed max would hide large negative deviations.
    print('Diff DA:', (torch_da - triton_da).abs().max())
    print('Diff DB:', (torch_db - triton_db).abs().max())
|
|
|
|
# TensorFlow is optional: skip its check only when the package itself cannot
# be imported.
try:
    import tensorflow as tf
except ModuleNotFoundError:
    pass
else:
    # Run outside the `try` so that a ModuleNotFoundError raised while the
    # check executes is reported instead of being silently discarded.
    run_tf()
|
|
|
|
# PyTorch is optional: skip its check only when the package itself cannot be
# imported.
try:
    import torch
except ModuleNotFoundError:
    pass
else:
    # Run outside the `try` so that a ModuleNotFoundError raised while the
    # check executes is reported instead of being silently discarded.
    run_torch()
|