Ariel Kwiatkowski
2023-08-02 12:43:23 +02:00
committed by GitHub
parent 023dc89112
commit d6ea183807

@@ -193,34 +193,33 @@ class REINFORCE:
"""Updates the policy network's weights."""
running_g = 0
gs = []
# Discounted return (backwards) - [::-1] will return an array in reverse
for R in self.rewards[::-1]:
running_g = R + self.gamma * running_g
gs.insert(0, running_g)
deltas = torch.tensor(gs)
log_probs = torch.stack(self.probs)
# Calculate the mean of log probabilities for all actions in the episode
log_prob_mean = log_probs.mean()
# Update the loss with the mean log probability and deltas
# Now, we compute the correct total loss by taking the sum of the element-wise products.
loss = -torch.sum(log_prob_mean * deltas)
# Update the policy network
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
# Empty / zero out all episode-centric/related variables
self.probs = []
self.rewards = []
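
# %%
# As a quick aside, the backwards pass above implements the recursion
# G_t = r_t + gamma * G_{t+1}, so ``gs[t]`` holds the discounted return from
# step ``t`` onward. A minimal standalone sanity check (not part of the agent;
# the toy rewards and gamma below are made up purely for illustration):

rewards = [1.0, 2.0, 3.0]
gamma = 0.9
running_g, gs = 0.0, []
for R in rewards[::-1]:
    running_g = R + gamma * running_g
    gs.insert(0, running_g)

# gs[2] = 3.0, gs[1] = 2.0 + 0.9 * 3.0 = 4.7, gs[0] = 1.0 + 0.9 * 4.7 = 5.23
print(gs)  # [5.23, 4.7, 3.0]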
# %%
# Now let's train the policy using REINFORCE to master the task of the Inverted Pendulum.
#
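# Below is a minimal sketch of that training loop. It assumes the agent
# exposes a ``sample_action(obs)`` helper alongside the ``update()`` method
# above, appends each step's reward to ``self.rewards``, and is constructed
# as ``REINFORCE(obs_dim, action_dim)`` as defined earlier in the tutorial;
# treat the exact names and episode count as placeholders.

import gymnasium as gym

env = gym.make("InvertedPendulum-v4")
obs_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
agent = REINFORCE(obs_dim, action_dim)  # constructor signature assumed

for episode in range(5000):  # episode count chosen arbitrarily
    obs, info = env.reset(seed=episode)
    done = False
    while not done:
        action = agent.sample_action(obs)  # assumed helper from the tutorial
        obs, reward, terminated, truncated, info = env.step(action)
        agent.rewards.append(reward)  # update() consumes these at episode end
        done = terminated or truncated
    agent.update()  # one policy-gradient step per episode

env.close()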