diff --git a/python/tutorials/02-fused-softmax.py b/python/tutorials/02-fused-softmax.py index aba4a7835..0f96152e2 100644 --- a/python/tutorials/02-fused-softmax.py +++ b/python/tutorials/02-fused-softmax.py @@ -25,7 +25,7 @@ def naive_softmax(x): # read 2MN elements ; write MN elements z = x - x_max[:, None] # read MN elements ; write MN elements - numerator = torch.exp(x) + numerator = torch.exp(z) # read MN elements ; write M elements denominator = numerator.sum(dim=1) # read 2MN elements ; write MN elements