Files
Gymnasium/gymnasium/wrappers/record_episode_statistics.py

127 lines
5.0 KiB
Python

"""Wrapper that tracks the cumulative rewards and episode lengths."""
import time
from collections import deque
from typing import Optional
import numpy as np
import gymnasium as gym
class RecordEpisodeStatistics(gym.Wrapper, gym.utils.RecordConstructorArgs):
    """This wrapper will keep track of cumulative rewards and episode lengths.

    At the end of an episode, the statistics of the episode will be added to ``info``
    using the key ``episode``. If using a vectorized environment also the key
    ``_episode`` is used which indicates whether the env at the respective index has
    the episode statistics.

    After the completion of an episode, ``info`` will look like this::

        >>> info = {
        ...     "episode": {
        ...         "r": "<cumulative reward>",
        ...         "l": "<episode length>",
        ...         "t": "<elapsed time since beginning of episode>"
        ...     },
        ... }

    For a vectorized environments the output will be in the form of::

        >>> infos = {
        ...     "final_observation": "<array of length num-envs>",
        ...     "_final_observation": "<boolean array of length num-envs>",
        ...     "final_info": "<array of length num-envs>",
        ...     "_final_info": "<boolean array of length num-envs>",
        ...     "episode": {
        ...         "r": "<array of cumulative reward>",
        ...         "l": "<array of episode length>",
        ...         "t": "<array of elapsed time since beginning of episode>"
        ...     },
        ...     "_episode": "<boolean array of length num-envs>"
        ... }

    Moreover, the most recent rewards and episode lengths are stored in buffers that can be accessed via
    :attr:`wrapped_env.return_queue` and :attr:`wrapped_env.length_queue` respectively.

    Attributes:
        return_queue: The cumulative rewards of the last ``deque_size``-many episodes
        length_queue: The lengths of the last ``deque_size``-many episodes
    """

    def __init__(self, env: gym.Env, deque_size: int = 100):
        """This wrapper will keep track of cumulative rewards and episode lengths.

        Args:
            env (Env): The environment to apply the wrapper
            deque_size: The size of the buffers :attr:`return_queue` and :attr:`length_queue`
        """
        gym.utils.RecordConstructorArgs.__init__(self, deque_size=deque_size)
        gym.Wrapper.__init__(self, env)

        # Environments that do not expose ``num_envs`` are treated as a single env.
        self.num_envs = getattr(env, "num_envs", 1)
        self.episode_count = 0
        # The per-env accumulators stay ``None`` until ``reset`` allocates them.
        self.episode_start_times: Optional[np.ndarray] = None
        self.episode_returns: Optional[np.ndarray] = None
        self.episode_lengths: Optional[np.ndarray] = None
        self.return_queue = deque(maxlen=deque_size)
        self.length_queue = deque(maxlen=deque_size)
        self.is_vector_env = getattr(env, "is_vector_env", False)

    def reset(self, **kwargs):
        """Resets the environment using kwargs and resets the episode returns and lengths.

        Args:
            **kwargs: Forwarded unchanged to the wrapped environment's ``reset``.

        Returns:
            The ``(obs, info)`` pair returned by the wrapped environment.
        """
        obs, info = super().reset(**kwargs)
        # Start times must be float64: ``time.perf_counter()`` can be a large
        # value, and float32's 24-bit significand would round it so coarsely
        # that the elapsed time reported in ``step`` (a difference of two
        # nearby large numbers, rounded to 6 decimals) becomes inaccurate.
        self.episode_start_times = np.full(
            self.num_envs, time.perf_counter(), dtype=np.float64
        )
        self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
        self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
        return obs, info

    def step(self, action):
        """Steps through the environment, recording the episode statistics.

        ``reset`` must have been called beforehand, because the per-env
        accumulators are ``None`` until then.

        Args:
            action: The action forwarded to the wrapped environment's ``step``.

        Returns:
            The ``(observations, rewards, terminations, truncations, infos)``
            tuple from the wrapped environment. When at least one episode ended
            on this step, ``infos`` additionally holds the ``"episode"`` entry
            (and ``"_episode"`` for vector environments).

        Raises:
            AssertionError: If the wrapped environment's info is not a ``dict``.
            ValueError: If an episode ended and ``infos`` already contains the
                ``"episode"`` or ``"_episode"`` key.
        """
        (
            observations,
            rewards,
            terminations,
            truncations,
            infos,
        ) = self.env.step(action)
        assert isinstance(
            infos, dict
        ), f"`info` dtype is {type(infos)} while supported dtype is `dict`. This may be due to usage of other wrappers in the wrong order."

        self.episode_returns += rewards
        self.episode_lengths += 1

        # An episode is over when it either terminated or was truncated.
        dones = np.logical_or(terminations, truncations)
        num_dones = np.sum(dones)
        if num_dones:
            if "episode" in infos or "_episode" in infos:
                raise ValueError(
                    "Attempted to add episode stats when they already exist"
                )
            else:
                # Entries for environments that are still running are zero-filled.
                infos["episode"] = {
                    "r": np.where(dones, self.episode_returns, 0.0),
                    "l": np.where(dones, self.episode_lengths, 0),
                    "t": np.where(
                        dones,
                        np.round(time.perf_counter() - self.episode_start_times, 6),
                        0.0,
                    ),
                }
                if self.is_vector_env:
                    infos["_episode"] = np.where(dones, True, False)

            # Archive the finished episodes before clearing their accumulators.
            self.return_queue.extend(self.episode_returns[dones])
            self.length_queue.extend(self.episode_lengths[dones])
            self.episode_count += num_dones
            self.episode_lengths[dones] = 0
            self.episode_returns[dones] = 0
            self.episode_start_times[dones] = time.perf_counter()

        return (
            observations,
            rewards,
            terminations,
            truncations,
            infos,
        )