[PYTHON] Made bench_blocksparse and bench_cross_entropy compatible

with the new performance report API
This commit is contained in:
Philippe Tillet
2021-03-08 20:19:10 -05:00
parent 5b9afaa688
commit 58a5c87c53
2 changed files with 16 additions and 9 deletions

View File

@@ -16,22 +16,26 @@ confs = [
for mode in ['forward', 'backward']
]
@triton.testing.perf_report(confs)
def bench_op(M, N, dtype, mode, provider):
# create inputs
x = torch.randn(M, N, dtype=dtype, device='cuda', requires_grad=True)
idx = 4 + torch.ones(M, dtype=torch.int64, device='cuda')
num_gb = (2 * x.numel() * x.element_size() * 1e-9)
gbps = lambda ms: num_gb / ms * 1e3
# forward pass
op = {'torch': torch.nn.CrossEntropyLoss(reduction='none'), \
'triton': triton.ops.cross_entropy}[provider]
if mode == 'forward':
ms = triton.testing.do_bench(lambda: op(x, idx))
mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(x, idx))
if mode == 'backward':
y = op(x, idx)
dy = torch.randn_like(y)
ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True), grad_to_none=x)
return num_gb / ms * 1e3
fn = lambda: y.backward(dy, retain_graph=True)
mean_ms, min_ms, max_ms = triton.testing.do_bench(fn, grad_to_none=x)
return gbps(mean_ms), gbps(min_ms), gbps(max_ms)
if __name__ == '__main__':
bench_op.run('tmp', False)