Files
triton/python/examples/dot.py
Philippe Tillet 448f4433d9 [PYTHON][KERNEL] Enforcing shapes to be known at compile-time for
TensorFlow Graph Execution
2019-10-29 00:48:53 -04:00

72 lines
2.1 KiB
Python

import numpy as np
import triton
def run_tf():
    """Check ``triton.ops.dot`` against ``tf.matmul`` in TensorFlow graph mode.

    Builds the chain ``d = (a @ b.T).T @ b`` twice: once with
    ``triton.ops.dot`` and once with ``tf.matmul``. It then differentiates
    the mean of ``d`` with respect to ``a`` and ``b``. Finally it prints the
    throughput derived from the Triton benchmark registry and the largest
    absolute gradient discrepancies.

    Uses the module-level ``tf`` name bound by ``import tensorflow as tf``
    at the bottom of this file.
    """
    # The second product contracts c's rows against b's rows, so this
    # example only type-checks when M == N (hence the square sizes).
    M, N, K = 2048, 2048, 2048
    a = tf.placeholder(tf.float32, shape=[M, K])
    b = tf.placeholder(tf.float32, shape=[N, K])
    # The two boolean flags mirror tf.matmul's (transpose_a, transpose_b)
    # in the reference graph below.
    # NOTE(review): the meaning of the trailing ``1`` is not visible here -- confirm.
    triton_c = triton.ops.dot(a, b, False, True, 1)
    triton_d = triton.ops.dot(triton_c, b, True, False, 1)
    triton_y = tf.math.reduce_mean(triton_d)
    fw_c = tf.matmul(a, b, False, True)
    fw_d = tf.matmul(fw_c, b, True, False)
    fw_y = tf.math.reduce_mean(fw_d)
    # Gradient
    triton_da, triton_db = tf.gradients(triton_y, [a, b])
    fw_da, fw_db = tf.gradients(fw_y, [a, b])
    # Random inputs whose shapes match the placeholders (b is declared [N, K]).
    feed_dict = {a: np.random.rand(M, K).astype(np.float32),
                 b: np.random.rand(N, K).astype(np.float32)}
    sess = tf.InteractiveSession()
    try:
        sess.run(tf.global_variables_initializer())
        result = sess.run([triton_da, fw_da, triton_db, fw_db, fw_y, triton_y],
                          feed_dict=feed_dict)
        # Keep the full gradient arrays; indexing [0] would keep only row 0.
        triton_da, fw_da, triton_db, fw_db = result[:4]
        # Benchmark
        # presumably bench_registry maps a Triton output to its kernel time
        # in nanoseconds (the TFLOPS formula below treats it that way).
        nanosec = triton.bench_registry[triton_d]
        print('TFLOPS:', 2. * M * N * K / nanosec * 1e-3)
        # Absolute value so negative discrepancies are not hidden by max().
        print('Diff DA:', np.abs(triton_da - fw_da).max())
        print('Diff DB:', np.abs(triton_db - fw_db).max())
    finally:
        sess.close()
def run_torch():
    """Check ``triton.ops.dot`` against ``torch.matmul`` under PyTorch autograd.

    Computes ``d = (a @ b.T).T @ b`` with both backends on CUDA tensors. It
    then back-propagates the mean of each result and records the gradients
    of ``a`` and ``b`` for each backend. Finally it prints the throughput
    derived from the Triton benchmark registry and the largest absolute
    gradient discrepancies.

    Uses the module-level ``torch`` name bound by ``import torch`` at the
    bottom of this file.
    """
    torch.manual_seed(0)
    # The second product contracts c's rows against b's rows, so this
    # example only type-checks when M == N (hence the square sizes).
    M, N, K = 2048, 2048, 2048
    a = torch.randn(M, K).cuda()
    # b is (N, K): both a @ b.T and dot(a, b, False, True) contract over K.
    b = torch.randn(N, K).cuda()
    a.requires_grad_(True)
    b.requires_grad_(True)
    torch_c = torch.matmul(a, torch.t(b))
    torch_d = torch.matmul(torch.t(torch_c), b)
    torch_y = torch.mean(torch_d)
    # The two boolean flags mirror the transposes in the torch graph above.
    # NOTE(review): the meaning of the trailing ``1`` is not visible here -- confirm.
    triton_c = triton.ops.dot(a, b, False, True, 1)
    triton_d = triton.ops.dot(triton_c, b, True, False, 1)
    triton_y = torch.mean(triton_d)
    # torch gradient
    torch_y.backward()
    torch_da = a.grad.clone()
    torch_db = b.grad.clone()
    # triton gradient: reset the accumulated .grad first so it is not summed
    # with the torch gradients computed above.
    a.grad.zero_()
    b.grad.zero_()
    triton_y.backward()
    triton_da = a.grad.clone()
    triton_db = b.grad.clone()
    # presumably bench_registry maps a Triton output to its kernel time
    # in nanoseconds (the TFLOPS formula below treats it that way).
    nanosec = triton.bench_registry[triton_d]
    print('TFLOPS:', 2. * M * N * K / nanosec * 1e-3)
    # Absolute value so negative discrepancies are not hidden by max().
    print('Diff DA:', (torch_da - triton_da).abs().max())
    print('Diff DB:', (torch_db - triton_db).abs().max())
# Run each benchmark only if its framework is installed. The try bodies cover
# only the imports, so a ModuleNotFoundError raised inside run_tf() or
# run_torch() propagates instead of being mistaken for a missing framework.
# The imports stay at module scope because run_tf/run_torch read the global
# names ``tf`` and ``torch``.
try:
    import tensorflow as tf
except ModuleNotFoundError:
    pass
else:
    run_tf()
try:
    import torch
except ModuleNotFoundError:
    pass
else:
    run_torch()