[PYTHON] bugfix in bench_cross_entropy
This commit is contained in:
@@ -30,7 +30,7 @@ def bench_op(M, N, dtype, mode, provider):
|
||||
if mode == 'backward':
|
||||
y = op(x, idx)
|
||||
dy = torch.randn_like(y)
|
||||
ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True))
|
||||
ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True), grad_to_none=x)
|
||||
return num_gb / ms * 1e3
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Reference in New Issue
Block a user