[DOCS] Fix fused softmax example script naive softmax implementation (#178)

This commit is contained in:
Xiangru Lian
2021-08-02 09:37:31 -07:00
committed by GitHub
parent e8031fe61f
commit 9967e9d4b4

View File

@@ -25,7 +25,7 @@ def naive_softmax(x):
# read 2MN elements ; write MN elements
z = x - x_max[:, None]
# read MN elements ; write MN elements
numerator = torch.exp(x)
numerator = torch.exp(z)
# read MN elements ; write M elements
denominator = numerator.sum(dim=1)
# read 2MN elements ; write MN elements