mirror of https://github.com/Farama-Foundation/Gymnasium.git
Fixup for 'Training An Agent' page (#1281)
Co-authored-by: chr0nikler <jchahal@diffzero.com>
committed by GitHub
parent 87cc458437
commit fc74bb8fc0
@@ -155,38 +155,49 @@ You can use `matplotlib` to visualize the training reward and length.
 
 ```python
 from matplotlib import pyplot as plt
-# visualize the episode rewards, episode length and training error in one figure
-fig, axs = plt.subplots(1, 3, figsize=(20, 8))
 
-axs[0].plot(np.convolve(env.return_queue, np.ones(100)/100))
-axs[0].set_title("Episode Rewards")
-axs[0].set_xlabel("Episode")
-axs[0].set_ylabel("Reward")
-
-axs[1].plot(np.convolve(env.length_queue, np.ones(100)/100))
-axs[1].set_title("Episode Lengths")
-axs[1].set_xlabel("Episode")
-axs[1].set_ylabel("Length")
-
-axs[2].plot(np.convolve(agent.training_error, np.ones(100)/100))
+# np.convolve will compute the rolling mean for 100 episodes
+def get_moving_avgs(arr, window, convolution_mode):
+    return np.convolve(
+        np.array(arr).flatten(),
+        np.ones(window),
+        mode=convolution_mode
+    ) / window
+
+# Smooth over a 500 episode window
+rolling_length = 500
+fig, axs = plt.subplots(ncols=3, figsize=(12, 5))
+
+axs[0].set_title("Episode rewards")
+reward_moving_average = get_moving_avgs(
+    env.return_queue,
+    rolling_length,
+    "valid"
+)
+axs[0].plot(range(len(reward_moving_average)), reward_moving_average)
+
+axs[1].set_title("Episode lengths")
+length_moving_average = get_moving_avgs(
+    env.length_queue,
+    rolling_length,
+    "valid"
+)
+axs[1].plot(range(len(length_moving_average)), length_moving_average)
+
 axs[2].set_title("Training Error")
-axs[2].set_xlabel("Episode")
-axs[2].set_ylabel("Temporal Difference")
+training_error_moving_average = get_moving_avgs(
+    agent.training_error,
+    rolling_length,
+    "same"
+)
+axs[2].plot(range(len(training_error_moving_average)), training_error_moving_average)
 plt.tight_layout()
 plt.show()
 ```
 
 
 
-## Visualising the policy
-
-
-
-
 Hopefully this tutorial helped you get a grip of how to interact with Gymnasium environments and sets you on a journey to solve many more RL challenges.
 
 It is recommended that you solve this environment by yourself (project based learning is really effective!). You can apply your favorite discrete RL algorithm or give Monte Carlo ES a try (covered in `Sutton & Barto <http://incompleteideas.net/book/the-book-2nd.html>`_, section 5.3) - this way you can compare your results directly to the book.
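
A note on what the new plotting code relies on: `env.return_queue` and `env.length_queue` are filled by Gymnasium's `RecordEpisodeStatistics` wrapper (e.g. `env = gym.wrappers.RecordEpisodeStatistics(env)`), and the smoothing is a plain box-filter rolling mean via `np.convolve`. The old code called `np.convolve` without a `mode` argument, so NumPy's default `"full"` mode produced ramp-up and ramp-down artifacts at both ends of each curve; the new helper normalizes by the window size and chooses the mode explicitly. Below is a minimal sanity check of that helper on synthetic data; the reward series is made up for illustration, not taken from training:

```python
import numpy as np

def get_moving_avgs(arr, window, convolution_mode):
    # Rolling mean: convolve with a box filter of length `window`,
    # then divide by the window size to normalize.
    return np.convolve(
        np.array(arr).flatten(),
        np.ones(window),
        mode=convolution_mode
    ) / window

# Synthetic "episode returns": noise around a slowly improving trend.
rng = np.random.default_rng(0)
returns = np.linspace(-0.5, 0.2, 1000) + rng.normal(0.0, 0.5, size=1000)

# "valid" keeps only fully overlapping windows, so there are no edge
# artifacts, at the cost of a shorter output: 1000 - 100 + 1 = 901 points.
smoothed = get_moving_avgs(returns, 100, "valid")
assert len(smoothed) == 901

# "same" pads at the edges so the output stays index-aligned with the
# input, which is why the page uses it for agent.training_error.
smoothed_same = get_moving_avgs(returns, 100, "same")
assert len(smoothed_same) == 1000
```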
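
On the closing suggestion to try Monte Carlo ES: the algorithm as written in Sutton & Barto assumes episodes can start from arbitrary state-action pairs, which `Blackjack-v1` does not expose, so a close substitute is on-policy every-visit Monte Carlo control with an epsilon-greedy policy. The sketch below is a rough starting point under that assumption, not the tutorial's code, and the hyperparameters are illustrative rather than tuned:

```python
from collections import defaultdict

import gymnasium as gym
import numpy as np

# sab=True follows the exact rules from Sutton & Barto, so results
# can be compared against the book.
env = gym.make("Blackjack-v1", sab=True)
n_actions = env.action_space.n

q_values = defaultdict(lambda: np.zeros(n_actions))      # Q(s, a) estimates
visit_counts = defaultdict(lambda: np.zeros(n_actions))  # visits per (s, a)
epsilon = 0.1  # illustrative exploration rate

for _ in range(100_000):
    obs, info = env.reset()
    episode = []
    done = False
    while not done:
        # Epsilon-greedy behaviour policy over the current Q estimates.
        if np.random.random() < epsilon:
            action = int(env.action_space.sample())
        else:
            action = int(np.argmax(q_values[obs]))
        next_obs, reward, terminated, truncated, info = env.step(action)
        episode.append((obs, action, reward))
        done = terminated or truncated
        obs = next_obs

    # Every-visit Monte Carlo update: walk the episode backwards,
    # accumulating the undiscounted return G, and move each Q(s, a)
    # toward the running mean of the returns observed after it.
    g = 0.0
    for obs_t, action_t, reward_t in reversed(episode):
        g += reward_t
        visit_counts[obs_t][action_t] += 1
        q_values[obs_t][action_t] += (
            g - q_values[obs_t][action_t]
        ) / visit_counts[obs_t][action_t]
```

With enough episodes, the greedy policy read off `q_values` should roughly reproduce the optimal blackjack policy discussed in section 5.3 of the book.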