[DOCS] Fix fused softmax example script naive softmax implementation (#178)
This commit is contained in:
@@ -25,7 +25,7 @@ def naive_softmax(x):
|
|||||||
# read 2MN elements ; write MN elements
|
# read 2MN elements ; write MN elements
|
||||||
z = x - x_max[:, None]
|
z = x - x_max[:, None]
|
||||||
# read MN elements ; write MN elements
|
# read MN elements ; write MN elements
|
||||||
numerator = torch.exp(x)
|
numerator = torch.exp(z)
|
||||||
# read MN elements ; write M elements
|
# read MN elements ; write M elements
|
||||||
denominator = numerator.sum(dim=1)
|
denominator = numerator.sum(dim=1)
|
||||||
# read 2MN elements ; write MN elements
|
# read 2MN elements ; write MN elements
|
||||||
|
Reference in New Issue
Block a user