[PYTHON] bugfix in bench_cross_entropy

This commit is contained in:
Philippe Tillet
2021-02-26 02:37:05 -05:00
parent 50ff1aea86
commit ff62f7fffc
2 changed files with 9 additions and 23 deletions

View File

@@ -30,7 +30,7 @@ def bench_op(M, N, dtype, mode, provider):
if mode == 'backward':
y = op(x, idx)
dy = torch.randn_like(y)
ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True))
ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True), grad_to_none=x)
return num_gb / ms * 1e3
if __name__ == '__main__':