mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-07-31 22:04:31 +00:00
Fix reinforce tutorial (#1337)
This commit is contained in:
@@ -202,14 +202,11 @@ class REINFORCE:
|
||||
|
||||
deltas = torch.tensor(gs)
|
||||
|
||||
log_probs = torch.stack(self.probs)
|
||||
|
||||
# Calculate the mean of log probabilities for all actions in the episode
|
||||
log_prob_mean = log_probs.mean()
|
||||
log_probs = torch.stack(self.probs).squeeze()
|
||||
|
||||
# Update the loss with the mean log probability and deltas
|
||||
# Now, we compute the correct total loss by taking the sum of the element-wise products.
|
||||
loss = -torch.sum(log_prob_mean * deltas)
|
||||
loss = -torch.sum(log_probs * deltas)
|
||||
|
||||
# Update the policy network
|
||||
self.optimizer.zero_grad()
|
||||
@@ -298,8 +295,7 @@ for seed in [1, 2, 3, 5, 8]: # Fibonacci seeds
|
||||
# ~~~~~~~~~~~~~~~~~~~
|
||||
#
|
||||
|
||||
rewards_to_plot = [[reward[0] for reward in rewards] for rewards in rewards_over_seeds]
|
||||
df1 = pd.DataFrame(rewards_to_plot).melt()
|
||||
df1 = pd.DataFrame(rewards_over_seeds).melt()
|
||||
df1.rename(columns={"variable": "episodes", "value": "reward"}, inplace=True)
|
||||
sns.set(style="darkgrid", context="talk", palette="rainbow")
|
||||
sns.lineplot(x="episodes", y="reward", data=df1).set(
|
||||
|
Reference in New Issue
Block a user