Fix scaling in plots (#1259)

This commit is contained in:
Gregor Kikelj
2024-11-22 10:35:00 +01:00
committed by GitHub
parent 79a2306dba
commit 9ba3283af7

View File

@@ -160,17 +160,17 @@ fig, axs = plt.subplots(1, 3, figsize=(20, 8))
# np.convolve will compute the rolling mean for 100 episodes
axs[0].plot(np.convolve(env.return_queue, np.ones(100)))
axs[0].plot(np.convolve(env.return_queue, np.ones(100)/100))
axs[0].set_title("Episode Rewards")
axs[0].set_xlabel("Episode")
axs[0].set_ylabel("Reward")
axs[1].plot(np.convolve(env.length_queue, np.ones(100)))
axs[1].plot(np.convolve(env.length_queue, np.ones(100)/100))
axs[1].set_title("Episode Lengths")
axs[1].set_xlabel("Episode")
axs[1].set_ylabel("Length")
axs[2].plot(np.convolve(agent.training_error, np.ones(100)))
axs[2].plot(np.convolve(agent.training_error, np.ones(100)/100))
axs[2].set_title("Training Error")
axs[2].set_xlabel("Episode")
axs[2].set_ylabel("Temporal Difference")