Run:
```
isort ./python
autopep8 -i --ignore E501,E701,E731 $(find ./python/ -name '*.py')
```
with an `.isort.cfg`, and then clean up a few warts. This PR should be a no-op: the idea is that it contains only boring whitespace changes, and any config-file changes will go in a separate change to make them easier to review.
42 lines
1.3 KiB
Python
42 lines
1.3 KiB
Python
import torch
|
|
|
|
import triton
|
|
|
|
# Two benchmark configurations, one per pass direction ('forward' and
# 'backward'). Each sweeps the row length N at a fixed M=2048 in fp16 and
# compares the 'triton' and 'torch' providers; the y-axis label is GBPS.
confs = [
    triton.testing.Benchmark(
        plot_name=direction + '-2048',
        args=dict(M=2048, dtype=torch.float16, mode=direction),
        x_names=['N'],
        x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144, 8192],
        line_arg='provider',
        line_vals=['triton', 'torch'],
        line_names=['Triton', 'Torch'],
        ylabel='GBPS',
    )
    for direction in ('forward', 'backward')
]
|
|
|
|
|
|
@triton.testing.perf_report(confs)
def bench_op(M, N, dtype, mode, provider):
    """Benchmark cross-entropy (forward or backward) for a single provider.

    Args:
        M: number of rows of the logits tensor ``x`` (one target index per row).
        N: number of columns of ``x``.
        dtype: torch dtype used for ``x``.
        mode: ``'forward'`` times ``op(x, idx)``; ``'backward'`` times
            ``y.backward(dy)`` on a precomputed output ``y``.
        provider: ``'torch'`` selects ``torch.nn.CrossEntropyLoss(reduction='none')``;
            ``'triton'`` selects ``triton.ops.cross_entropy``.

    Returns:
        ``(gbps_mean, gbps_min, gbps_max)``: throughput in GB/s computed from the
        ``(mean, min, max)`` millisecond timings of ``triton.testing.do_bench``.

    Raises:
        ValueError: if ``mode`` is not ``'forward'`` or ``'backward'``.
        KeyError: if ``provider`` is not ``'torch'`` or ``'triton'``.
    """
    # Reject an unknown mode before allocating on the GPU. Without this check
    # an unknown mode skipped both timing branches and failed later with
    # UnboundLocalError on mean_ms.
    if mode not in ('forward', 'backward'):
        raise ValueError(f"unknown mode {mode!r}; expected 'forward' or 'backward'")

    # Construct only the requested op, so benchmarking one provider never
    # touches (or requires) the other.
    op_factories = {
        'torch': lambda: torch.nn.CrossEntropyLoss(reduction='none'),
        'triton': lambda: triton.ops.cross_entropy,
    }
    op = op_factories[provider]()

    # create inputs
    x = torch.randn(M, N, dtype=dtype, device='cuda', requires_grad=True)
    # 4 + ones == 5: every row uses class index 5 as its target.
    idx = 4 + torch.ones(M, dtype=torch.int64, device='cuda')

    # Traffic model for the bandwidth figure: twice the byte size of x, in GB.
    num_gb = 2 * x.numel() * x.element_size() * 1e-9

    def gbps(ms):
        # ms is a duration in milliseconds; result is GB per second.
        return num_gb / ms * 1e3

    if mode == 'forward':
        mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(x, idx))
    else:  # mode == 'backward'
        y = op(x, idx)
        dy = torch.randn_like(y)
        # retain_graph=True keeps the autograd graph alive so backward can be
        # replayed on every benchmark repetition. grad_to_none=[x] presumably
        # has do_bench reset x.grad between runs (see do_bench docs).
        fn = lambda: y.backward(dy, retain_graph=True)
        mean_ms, min_ms, max_ms = triton.testing.do_bench(fn, grad_to_none=[x])

    # The slowest run (max_ms) gives the lowest throughput and the fastest run
    # (min_ms) the highest. Swap them so the tuple really is (mean, min, max)
    # in GB/s.
    return gbps(mean_ms), gbps(max_ms), gbps(min_ms)
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: run the benchmark sweep attached to bench_op by the
    # perf_report decorator; print_data=True requests the numbers be printed.
    bench_op.run(print_data=True)
|