This commit is contained in:
Ariel Kwiatkowski
2023-08-02 12:43:23 +02:00
committed by GitHub
parent 023dc89112
commit d6ea183807


@@ -193,34 +193,33 @@ class REINFORCE:
"""Updates the policy network's weights.""" """Updates the policy network's weights."""
running_g = 0 running_g = 0
gs = [] gs = []
# Discounted return (backwards) - [::-1] will return an array in reverse # Discounted return (backwards) - [::-1] will return an array in reverse
for R in self.rewards[::-1]: for R in self.rewards[::-1]:
running_g = R + self.gamma * running_g running_g = R + self.gamma * running_g
gs.insert(0, running_g) gs.insert(0, running_g)
deltas = torch.tensor(gs) deltas = torch.tensor(gs)
log_probs = torch.stack(self.probs) log_probs = torch.stack(self.probs)
# Calculate the mean of log probabilities for all actions in the episode # Calculate the mean of log probabilities for all actions in the episode
log_prob_mean = log_probs.mean() log_prob_mean = log_probs.mean()
# Update the loss with the mean log probability and deltas # Update the loss with the mean log probability and deltas
# Now, we compute the correct total loss by taking the sum of the element-wise products. # Now, we compute the correct total loss by taking the sum of the element-wise products.
loss = -torch.sum(log_prob_mean * deltas) loss = -torch.sum(log_prob_mean * deltas)
# Update the policy network # Update the policy network
self.optimizer.zero_grad() self.optimizer.zero_grad()
loss.backward() loss.backward()
self.optimizer.step() self.optimizer.step()
# Empty / zero out all episode-centric/related variables # Empty / zero out all episode-centric/related variables
self.probs = [] self.probs = []
self.rewards = [] self.rewards = []
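
# %%
# The snippet below is a small standalone sanity check (not part of the
# original tutorial code) that reproduces the return and loss computation from
# the ``update`` hunk above on a toy three-step episode. The reward values,
# log-probabilities, and discount factor are made up purely for illustration.

import torch

toy_rewards = [1.0, 1.0, 1.0]  # r_0, r_1, r_2 of a toy episode
toy_log_probs = torch.log(torch.tensor([0.5, 0.4, 0.9]))  # stand-ins for log pi(a_t | s_t)
gamma = 0.9

running_g = 0.0
gs = []
# Walk the rewards backwards so gs[t] = r_t + gamma * r_{t+1} + gamma^2 * r_{t+2} + ...
for R in toy_rewards[::-1]:
    running_g = R + gamma * running_g
    gs.insert(0, running_g)

deltas = torch.tensor(gs)  # tensor([2.7100, 1.9000, 1.0000])
log_prob_mean = toy_log_probs.mean()  # a single scalar for the whole episode
loss = -torch.sum(log_prob_mean * deltas)  # scalar broadcast over deltas, then summed
print(deltas, loss.item())
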
# %%
# Now let's train the policy using REINFORCE to master the task of Inverted Pendulum.
#
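
# %%
# A minimal sketch of such a training loop, under stated assumptions, is given
# below. It assumes the ``REINFORCE`` agent takes the observation and action
# dimensionalities in its constructor, exposes a ``sample_action(obs)`` method,
# and accumulates per-step rewards in the ``rewards`` list consumed by
# ``update()``; the episode count and per-episode seeding are arbitrary choices
# for this sketch.

import gymnasium as gym

env = gym.make("InvertedPendulum-v4")
obs_space_dims = env.observation_space.shape[0]
action_space_dims = env.action_space.shape[0]
agent = REINFORCE(obs_space_dims, action_space_dims)  # assumed constructor signature

for episode in range(1000):
    obs, info = env.reset(seed=episode)
    done = False
    while not done:
        action = agent.sample_action(obs)  # assumed method name
        obs, reward, terminated, truncated, info = env.step(action)
        agent.rewards.append(reward)  # update() reads and then clears this list
        done = terminated or truncated
    agent.update()  # one policy-gradient step per finished episode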