mirror of https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-07-31 22:04:31 +00:00
Fix CI (#644)
committed by GitHub
parent 023dc89112
commit d6ea183807
@@ -193,34 +193,33 @@ class REINFORCE:
        """Updates the policy network's weights."""
        running_g = 0
        gs = []

        # Discounted return (backwards) - [::-1] will return an array in reverse
        for R in self.rewards[::-1]:
            running_g = R + self.gamma * running_g
            gs.insert(0, running_g)
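For context, the loop above accumulates the discounted return G_t = R_{t+1} + gamma * G_{t+1} backwards through the episode, so each prepended value is the return from that step onward. A minimal standalone sketch of the same computation; gamma and the reward values are made up for illustration:

gamma = 0.99
rewards = [1.0, 1.0, 1.0]  # hypothetical per-step rewards from one episode

running_g = 0.0
gs = []
for R in rewards[::-1]:      # walk the episode backwards
    running_g = R + gamma * running_g
    gs.insert(0, running_g)  # prepend so gs[t] is the return from step t

print(gs)  # approximately [2.9701, 1.99, 1.0]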
        deltas = torch.tensor(gs)

        log_probs = torch.stack(self.probs)

        # Calculate the mean of log probabilities for all actions in the episode
        log_prob_mean = log_probs.mean()

        # Update the loss with the mean log probability and deltas
        # Now, we compute the correct total loss by taking the sum of the element-wise products.
        loss = -torch.sum(log_prob_mean * deltas)
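The loss above is the REINFORCE objective with the sign flipped for gradient descent; note that because log_prob_mean is a scalar, torch.sum(log_prob_mean * deltas) equals log_prob_mean * deltas.sum(). A self-contained sketch with dummy tensors (all values illustrative):

import torch

log_probs = torch.tensor([-0.1, -0.5, -0.3], requires_grad=True)  # dummy per-action log-probs
deltas = torch.tensor([2.9701, 1.99, 1.0])  # discounted returns, as in the sketch above

log_prob_mean = log_probs.mean()           # scalar mean over the episode
loss = -torch.sum(log_prob_mean * deltas)  # same formula as in the diff

loss.backward()        # gradients flow back into log_probs
print(loss.item())     # approximately 1.788
print(log_probs.grad)  # each entry is -deltas.sum() / 3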
        # Update the policy network
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Empty / zero out all episode-centric/related variables
        self.probs = []
        self.rewards = []
# %%
# Now let's train the policy using REINFORCE to master the task of Inverted Pendulum.
#
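Downstream of this hunk, the tutorial drives the patched update() from the usual episode loop. A condensed sketch of that loop, assuming the tutorial's REINFORCE agent exposes a sample_action() method alongside update(); the environment id, seeding scheme, and episode count are illustrative:

import gymnasium as gym

env = gym.make("InvertedPendulum-v4")
obs_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]
agent = REINFORCE(obs_dim, act_dim)  # assumed constructor signature

for episode in range(1000):  # episode count is illustrative
    obs, info = env.reset(seed=episode)
    done = False
    while not done:
        action = agent.sample_action(obs)
        obs, reward, terminated, truncated, info = env.step(action)
        agent.rewards.append(reward)  # update() consumes these
        done = terminated or truncated
    agent.update()  # the method patched in this diff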