"""Wrapper that tracks the cumulative rewards and episode lengths."""
|
|
import time
|
|
from collections import deque
|
|
from typing import Optional
|
|
|
|
import numpy as np
|
|
|
|
import gymnasium as gym
|
|
|
|
|
|
class RecordEpisodeStatistics(gym.Wrapper, gym.utils.RecordConstructorArgs):
    """This wrapper will keep track of cumulative rewards and episode lengths.

    At the end of an episode, the statistics of the episode will be added to ``info``
    using the key ``episode``. If using a vectorized environment, the key ``_episode``
    is also used, indicating whether the env at the respective index has the episode
    statistics.

    After the completion of an episode, ``info`` will look like this::

        >>> info = {
        ...     "episode": {
        ...         "r": "<cumulative reward>",
        ...         "l": "<episode length>",
        ...         "t": "<elapsed time since beginning of episode>"
        ...     },
        ... }

    For vectorized environments the output will be in the form of::

        >>> infos = {
        ...     "final_observation": "<array of length num-envs>",
        ...     "_final_observation": "<boolean array of length num-envs>",
        ...     "final_info": "<array of length num-envs>",
        ...     "_final_info": "<boolean array of length num-envs>",
        ...     "episode": {
        ...         "r": "<array of cumulative reward>",
        ...         "l": "<array of episode length>",
        ...         "t": "<array of elapsed time since beginning of episode>"
        ...     },
        ...     "_episode": "<boolean array of length num-envs>"
        ... }

    Moreover, the most recent episode returns and lengths are stored in buffers that can be
    accessed via :attr:`wrapped_env.return_queue` and :attr:`wrapped_env.length_queue` respectively.

    Attributes:
        return_queue: The cumulative rewards of the last ``deque_size``-many episodes
        length_queue: The lengths of the last ``deque_size``-many episodes
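
    Example (a minimal usage sketch; ``CartPole-v1`` is only an assumed example environment
    and the exact statistics will vary between runs)::

        >>> import gymnasium as gym
        >>> from gymnasium.wrappers import RecordEpisodeStatistics
        >>> env = RecordEpisodeStatistics(gym.make("CartPole-v1"))
        >>> obs, info = env.reset(seed=42)
        >>> done = False
        >>> while not done:
        ...     obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
        ...     done = terminated or truncated
        >>> sorted(info["episode"].keys())  # cumulative reward, length, elapsed time
        ['l', 'r', 't']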
"""

    def __init__(self, env: gym.Env, deque_size: int = 100):
        """This wrapper will keep track of cumulative rewards and episode lengths.

        Args:
            env (Env): The environment to apply the wrapper
            deque_size: The size of the buffers :attr:`return_queue` and :attr:`length_queue`
        """
        gym.utils.RecordConstructorArgs.__init__(self, deque_size=deque_size)
        gym.Wrapper.__init__(self, env)
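
        # Bookkeeping: ``num_envs`` is 1 for a single (non-vectorized) environment. The
        # per-episode accumulators below are allocated in ``reset``, while the two deques
        # keep the returns and lengths of the most recent ``deque_size`` episodes.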
        self.num_envs = getattr(env, "num_envs", 1)
        self.episode_count = 0
        self.episode_start_times: Optional[np.ndarray] = None
        self.episode_returns: Optional[np.ndarray] = None
        self.episode_lengths: Optional[np.ndarray] = None
        self.return_queue = deque(maxlen=deque_size)
        self.length_queue = deque(maxlen=deque_size)
        self.is_vector_env = getattr(env, "is_vector_env", False)

    def reset(self, **kwargs):
        """Resets the environment using kwargs and resets the episode returns and lengths."""
        obs, info = super().reset(**kwargs)
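        # One accumulator entry per (sub-)environment; a non-vectorized env uses a single slot.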
        self.episode_start_times = np.full(
            self.num_envs, time.perf_counter(), dtype=np.float32
        )
        self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
        self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
        return obs, info

    def step(self, action):
        """Steps through the environment, recording the episode statistics."""
        (
            observations,
            rewards,
            terminations,
            truncations,
            infos,
        ) = self.env.step(action)
        assert isinstance(
            infos, dict
        ), f"`info` dtype is {type(infos)} while supported dtype is `dict`. This may be due to usage of other wrappers in the wrong order."
        self.episode_returns += rewards
        self.episode_lengths += 1
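        # An episode is considered finished when it is either terminated or truncated.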
        dones = np.logical_or(terminations, truncations)
        num_dones = np.sum(dones)
        if num_dones:
            if "episode" in infos or "_episode" in infos:
                raise ValueError(
                    "Attempted to add episode stats when they already exist"
                )
            else:
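                # Entries are masked so that only the indices whose episode just finished
                # carry statistics; all other entries are zero.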
infos["episode"] = {
|
|
"r": np.where(dones, self.episode_returns, 0.0),
|
|
"l": np.where(dones, self.episode_lengths, 0),
|
|
"t": np.where(
|
|
dones,
|
|
np.round(time.perf_counter() - self.episode_start_times, 6),
|
|
0.0,
|
|
),
|
|
}
|
|
if self.is_vector_env:
|
|
infos["_episode"] = np.where(dones, True, False)
|
|
self.return_queue.extend(self.episode_returns[dones])
|
|
self.length_queue.extend(self.episode_lengths[dones])
|
|
self.episode_count += num_dones
|
|
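            # Clear the accumulators of the finished episodes so the next episodes start fresh.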
            self.episode_lengths[dones] = 0
            self.episode_returns[dones] = 0
            self.episode_start_times[dones] = time.perf_counter()
        return (
            observations,
            rewards,
            terminations,
            truncations,
            infos,
        )