diff --git a/docs/introduction/train_agent.md b/docs/introduction/train_agent.md
index 4c0f2542e..461a0d9b1 100644
--- a/docs/introduction/train_agent.md
+++ b/docs/introduction/train_agent.md
@@ -155,38 +155,49 @@
 You can use `matplotlib` to visualize the training reward and length.
 
 ```python
 from matplotlib import pyplot as plt
-# visualize the episode rewards, episode length and training error in one figure
-fig, axs = plt.subplots(1, 3, figsize=(20, 8))
-# np.convolve will compute the rolling mean for 100 episodes
+def get_moving_avgs(arr, window, convolution_mode):
+    return np.convolve(
+        np.array(arr).flatten(),
+        np.ones(window),
+        mode=convolution_mode
+    ) / window
 
-axs[0].plot(np.convolve(env.return_queue, np.ones(100)/100))
-axs[0].set_title("Episode Rewards")
-axs[0].set_xlabel("Episode")
-axs[0].set_ylabel("Reward")
+# Smooth over a 500 episode window
+rolling_length = 500
+fig, axs = plt.subplots(ncols=3, figsize=(12, 5))
 
-axs[1].plot(np.convolve(env.length_queue, np.ones(100)/100))
-axs[1].set_title("Episode Lengths")
-axs[1].set_xlabel("Episode")
-axs[1].set_ylabel("Length")
+axs[0].set_title("Episode rewards")
+reward_moving_average = get_moving_avgs(
+    env.return_queue,
+    rolling_length,
+    "valid"
+)
+axs[0].plot(range(len(reward_moving_average)), reward_moving_average)
+
+axs[1].set_title("Episode lengths")
+length_moving_average = get_moving_avgs(
+    env.length_queue,
+    rolling_length,
+    "valid"
+)
+axs[1].plot(range(len(length_moving_average)), length_moving_average)
 
-axs[2].plot(np.convolve(agent.training_error, np.ones(100)/100))
 axs[2].set_title("Training Error")
-axs[2].set_xlabel("Episode")
-axs[2].set_ylabel("Temporal Difference")
-
+training_error_moving_average = get_moving_avgs(
+    agent.training_error,
+    rolling_length,
+    "same"
+)
+axs[2].plot(range(len(training_error_moving_average)), training_error_moving_average)
 plt.tight_layout()
 plt.show()
+
+
 ```
 
 ![](../_static/img/tutorials/blackjack_training_plots.png "Training Plot")
 
-## Visualising the policy
-
-![](../_static/img/tutorials/blackjack_with_usable_ace.png "With a usable ace")
-
-![](../_static/img/tutorials/blackjack_without_usable_ace.png "Without a usable ace")
-
 Hopefully this tutorial helped you get a grip of how to interact with Gymnasium environments and sets you on a journey to solve many more RL challenges.
 
 It is recommended that you solve this environment by yourself (project based learning is really effective!). You can apply your favorite discrete RL algorithm or give Monte Carlo ES a try (covered in `Sutton & Barto <http://incompleteideas.net/book/the-book-2nd.html>`_, section 5.3) - this way you can compare your results directly to the book.
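
A note for reviewers on the new helper: convolving with a `np.ones(window)` kernel sums each window of values, so dividing by `window` yields a rolling mean. The `"valid"` mode (used for rewards and lengths) only emits fully-populated windows, shortening the output by `window - 1`; `"same"` (used for the training error) zero-pads the edges to preserve length, so edge values are biased toward zero. A minimal standalone sketch; the toy `returns` data is invented for illustration:

```python
import numpy as np

def get_moving_avgs(arr, window, convolution_mode):
    # Rolling mean: convolve with a ones-kernel, then divide by the window size.
    return np.convolve(
        np.array(arr).flatten(),
        np.ones(window),
        mode=convolution_mode
    ) / window

returns = [0, 1, 1, 0, 1, 1, 1, 0, 1, 1]  # toy stand-in for env.return_queue

# "valid": only full windows -> len(returns) - window + 1 = 8 values
print(get_moving_avgs(returns, 3, "valid"))

# "same": zero-padded edges -> 10 values; edge entries still divide by the
# full window, so they are biased low.
print(get_moving_avgs(returns, 3, "same"))
```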
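
It may also help to recall where `env.return_queue` and `env.length_queue` come from: they are attributes of `gymnasium.wrappers.RecordEpisodeStatistics`, which the tutorial wraps around the environment earlier on. A rough sketch under that assumption, using a random policy just to populate the queues:

```python
import gymnasium as gym
from gymnasium.wrappers import RecordEpisodeStatistics

# The wrapper records per-episode returns and lengths in bounded deques.
env = RecordEpisodeStatistics(gym.make("Blackjack-v1"))

for _ in range(100):  # arbitrary number of episodes
    obs, info = env.reset()
    done = False
    while not done:
        action = env.action_space.sample()  # random policy, just for the demo
        obs, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated

# One entry per recorded episode (up to the wrapper's buffer size).
print(len(env.return_queue), len(env.length_queue))
```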