Fixup for 'Training An Agent' page (#1281)

Co-authored-by: chr0nikler <jchahal@diffzero.com>
Joraaver S. Chahal
2025-01-06 04:12:00 -08:00
parent 87cc458437
commit fc74bb8fc0

@@ -155,38 +155,49 @@ You can use `matplotlib` to visualize the training reward and length.
```python
import numpy as np
from matplotlib import pyplot as plt

# Visualize the episode rewards, episode lengths and training error in one figure.
# np.convolve computes a rolling mean over `window` entries.
def get_moving_avgs(arr, window, convolution_mode):
    return np.convolve(
        np.array(arr).flatten(),
        np.ones(window),
        mode=convolution_mode
    ) / window

# Smooth over a 500 episode window
rolling_length = 500
fig, axs = plt.subplots(ncols=3, figsize=(12, 5))

axs[0].set_title("Episode rewards")
reward_moving_average = get_moving_avgs(
    env.return_queue,
    rolling_length,
    "valid"
)
axs[0].plot(range(len(reward_moving_average)), reward_moving_average)

axs[1].set_title("Episode lengths")
length_moving_average = get_moving_avgs(
    env.length_queue,
    rolling_length,
    "valid"
)
axs[1].plot(range(len(length_moving_average)), length_moving_average)

axs[2].set_title("Training Error")
training_error_moving_average = get_moving_avgs(
    agent.training_error,
    rolling_length,
    "same"
)
axs[2].plot(range(len(training_error_moving_average)), training_error_moving_average)

plt.tight_layout()
plt.show()
```
![](../_static/img/tutorials/blackjack_training_plots.png "Training Plot")
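Note that the helper passes different convolution modes to different panels: with `mode="valid"`, `np.convolve` emits only fully overlapping windows, so the smoothed reward and length curves have `len(arr) - window + 1` points; with `mode="same"`, the output keeps the input's length, which is convenient for the much longer per-step training-error array but averages fewer points near the edges. A quick check of the two output shapes:

```python
import numpy as np

x = np.arange(10, dtype=float)

# "valid": only windows fully inside the data -> len(x) - 4 + 1 points
print(np.convolve(x, np.ones(4) / 4, mode="valid").shape)  # (7,)

# "same": output matches the input length; edge values average fewer points
print(np.convolve(x, np.ones(4) / 4, mode="same").shape)   # (10,)
```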
## Visualising the policy
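The full tutorial builds the two policy grids below from the trained agent's action values. Here is a minimal sketch of that idea, assuming the Q-learning agent from this tutorial, i.e. that `agent.q_values` is a defaultdict mapping `(player_sum, dealer_card, usable_ace)` observation tuples to arrays of per-action values; `plot_policy_grid` is a hypothetical helper, not part of the tutorial:

```python
import numpy as np
from matplotlib import pyplot as plt

def plot_policy_grid(agent, usable_ace, ax):
    # Greedy action per state: 0 = stick, 1 = hit (Blackjack-v1's action encoding).
    # Unvisited states default to all-zero values, so argmax falls back to "stick".
    player_sums = np.arange(12, 22)   # below 12, hitting can never bust
    dealer_cards = np.arange(1, 11)   # dealer's face-up card; ace shown as 1
    policy = np.zeros((len(player_sums), len(dealer_cards)))
    for i, player_sum in enumerate(player_sums):
        for j, dealer_card in enumerate(dealer_cards):
            obs = (int(player_sum), int(dealer_card), usable_ace)
            policy[i, j] = np.argmax(agent.q_values[obs])
    ax.imshow(policy, origin="lower", cmap="coolwarm", vmin=0, vmax=1)
    ax.set_xticks(range(len(dealer_cards)), dealer_cards)
    ax.set_yticks(range(len(player_sums)), player_sums)
    ax.set_xlabel("Dealer showing")
    ax.set_ylabel("Player sum")

fig, axs = plt.subplots(ncols=2, figsize=(10, 4))
axs[0].set_title("With a usable ace")
plot_policy_grid(agent, True, axs[0])
axs[1].set_title("Without a usable ace")
plot_policy_grid(agent, False, axs[1])
plt.tight_layout()
plt.show()
```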
![](../_static/img/tutorials/blackjack_with_usable_ace.png "With a usable ace")
![](../_static/img/tutorials/blackjack_without_usable_ace.png "Without a usable ace")
Hopefully this tutorial helped you get a grip on how to interact with Gymnasium environments and set you on your way to solving many more RL challenges.
It is recommended that you solve this environment by yourself (project-based learning is really effective!). You can apply your favorite discrete RL algorithm or give Monte Carlo ES a try (covered in [Sutton & Barto](http://incompleteideas.net/book/the-book-2nd.html), section 5.3) - that way you can compare your results directly with the book.
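If you take the Monte Carlo suggestion, note one practical wrinkle: Gymnasium's `Blackjack-v1` cannot be reset to an arbitrary state, so true exploring starts (the "ES" in Monte Carlo ES) are awkward to reproduce. The sketch below swaps them for an epsilon-soft behaviour policy, i.e. on-policy first-visit Monte Carlo control (Sutton & Barto, section 5.4) rather than Monte Carlo ES proper; the hyperparameters are illustrative, not tuned.

```python
from collections import defaultdict

import gymnasium as gym
import numpy as np

env = gym.make("Blackjack-v1", sab=True)  # sab=True follows Sutton & Barto's rules
n_actions = env.action_space.n

q_values = defaultdict(lambda: np.zeros(n_actions))
visit_counts = defaultdict(lambda: np.zeros(n_actions))
epsilon = 0.1          # illustrative, not tuned
n_episodes = 500_000

for _ in range(n_episodes):
    obs, _ = env.reset()
    episode = []
    done = False
    while not done:
        # epsilon-greedy action selection instead of exploring starts
        if np.random.random() < epsilon:
            action = env.action_space.sample()
        else:
            action = int(np.argmax(q_values[obs]))
        next_obs, reward, terminated, truncated, _ = env.step(action)
        episode.append((obs, action, reward))
        done = terminated or truncated
        obs = next_obs

    # Work backwards through the episode; Blackjack is undiscounted (gamma = 1).
    g = 0.0
    for t in reversed(range(len(episode))):
        obs_t, action_t, reward_t = episode[t]
        g += reward_t
        # First-visit check: only update the earliest occurrence of (state, action).
        if any((o, a) == (obs_t, action_t) for o, a, _ in episode[:t]):
            continue
        visit_counts[obs_t][action_t] += 1
        # Incremental mean of the returns observed from this state-action pair.
        q_values[obs_t][action_t] += (g - q_values[obs_t][action_t]) / visit_counts[obs_t][action_t]
```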