mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-01 06:07:08 +00:00
Update train_agent.md (#1237)
Co-authored-by: Mark Towers <mark.m.towers@gmail.com>
This commit is contained in:
@@ -115,8 +115,11 @@ start_epsilon = 1.0
|
||||
epsilon_decay = start_epsilon / (n_episodes / 2) # reduce the exploration over time
|
||||
final_epsilon = 0.1
|
||||
|
||||
env = gym.make("Blackjack-v1", sab=False)
|
||||
env = gym.wrappers.RecordEpisodeStatistics(env, deque_size=n_episodes)
|
||||
|
||||
agent = BlackjackAgent(
|
||||
env,
|
||||
env=env,
|
||||
learning_rate=learning_rate,
|
||||
initial_epsilon=start_epsilon,
|
||||
epsilon_decay=epsilon_decay,
|
||||
@@ -129,9 +132,6 @@ Info: The current hyperparameters are set to quickly train a decent agent. If yo
|
||||
```python
|
||||
from tqdm import tqdm
|
||||
|
||||
env = gym.make("Blackjack-v1", sab=False)
|
||||
env = gym.wrappers.RecordEpisodeStatistics(env, deque_size=n_episodes)
|
||||
|
||||
for episode in tqdm(range(n_episodes)):
|
||||
obs, info = env.reset()
|
||||
done = False
|
||||
@@ -151,6 +151,34 @@ for episode in tqdm(range(n_episodes)):
|
||||
agent.decay_epsilon()
|
||||
```
|
||||
|
||||
You can use `matplotlib` to visualize the episode rewards, episode lengths and training error recorded during training.
|
||||
|
||||
```python
|
||||
from matplotlib import pyplot as plt
|
||||
# visualize the episode rewards, episode length and training error in one figure
|
||||
fig, axs = plt.subplots(1, 3, figsize=(20, 8))
|
||||
|
||||
# np.convolve with a window of 100 ones computes the rolling sum over 100 episodes (divide by 100 to get the rolling mean)
|
||||
|
||||
axs[0].plot(np.convolve(env.return_queue, np.ones(100)))
|
||||
axs[0].set_title("Episode Rewards")
|
||||
axs[0].set_xlabel("Episode")
|
||||
axs[0].set_ylabel("Reward")
|
||||
|
||||
axs[1].plot(np.convolve(env.length_queue, np.ones(100)))
|
||||
axs[1].set_title("Episode Lengths")
|
||||
axs[1].set_xlabel("Episode")
|
||||
axs[1].set_ylabel("Length")
|
||||
|
||||
axs[2].plot(np.convolve(agent.training_error, np.ones(100)))
|
||||
axs[2].set_title("Training Error")
|
||||
axs[2].set_xlabel("Episode")
|
||||
axs[2].set_ylabel("Temporal Difference")
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
```
|
||||
|
||||

|
||||
|
||||
## Visualising the policy
|
||||
|
Reference in New Issue
Block a user