"""Wrapper that tracks the cumulative rewards and episode lengths."""
from __future__ import annotations

import time
from collections import deque

import numpy as np

from gymnasium.core import ActType, ObsType
from gymnasium.vector.vector_env import ArrayType, VectorEnv, VectorWrapper


__all__ = ["RecordEpisodeStatistics"]


class RecordEpisodeStatistics(VectorWrapper):
    """This wrapper will keep track of cumulative rewards and episode lengths.

    At the end of any episode within the vectorized env, the statistics of the episode
    will be added to ``info`` using the key ``episode``, and the ``_episode`` key
    is used to indicate the environment index which has a terminated or truncated episode.

        >>> infos = {  # doctest: +SKIP
        ...     ...
        ...     "episode": {
        ...         "r": "<array of cumulative reward for each done sub-environment>",
        ...         "l": "<array of episode length for each done sub-environment>",
        ...         "t": "<array of elapsed time since beginning of episode for each done sub-environment>"
        ...     },
        ...     "_episode": "<boolean array of length num-envs>"
        ... }
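
    In a training loop, the reported statistics can be read back out of ``infos`` by
    masking with ``_episode`` (a minimal usage sketch, not part of the wrapper itself):

        >>> if "episode" in infos:  # doctest: +SKIP
        ...     finished = infos["_episode"]
        ...     mean_return = infos["episode"]["r"][finished].mean()
        ...     mean_length = infos["episode"]["l"][finished].mean()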

    Moreover, the most recent rewards and episode lengths are stored in buffers that can
    be accessed via :attr:`wrapped_env.return_queue` and :attr:`wrapped_env.length_queue`
    respectively.

    Attributes:
        return_queue: The cumulative rewards of the last ``deque_size``-many episodes
        length_queue: The lengths of the last ``deque_size``-many episodes

    Example:
        >>> from pprint import pprint
        >>> import gymnasium as gym
        >>> envs = gym.make_vec("CartPole-v1", num_envs=3)
        >>> envs = RecordEpisodeStatistics(envs)
        >>> obs, info = envs.reset(seed=123)
        >>> _ = envs.action_space.seed(123)
        >>> end = False
        >>> while not end:
        ...     obs, rew, term, trunc, info = envs.step(envs.action_space.sample())
        ...     end = term.any() or trunc.any()
        ...
        >>> envs.close()
        >>> pprint(info)  # doctest: +SKIP
        {'_episode': array([ True, False, False]),
         '_final_info': array([ True, False, False]),
         '_final_observation': array([ True, False, False]),
         'episode': {'l': array([11,  0,  0], dtype=int32),
                     'r': array([11.,  0.,  0.], dtype=float32),
                     't': array([0.007812, 0.      , 0.      ], dtype=float32)},
         'final_info': array([{}, None, None], dtype=object),
         'final_observation': array([array([ 0.11448676,  0.9416149 , -0.20946532, -1.7619033 ], dtype=float32),
                                     None, None], dtype=object)}
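
        The rolling buffers can be inspected afterwards as well (a sketch; exact values
        depend on the run):

        >>> len(envs.return_queue)  # doctest: +SKIP
        1
        >>> float(envs.return_queue[0])  # doctest: +SKIP
        11.0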
    """

    def __init__(self, env: VectorEnv, deque_size: int = 100):
        """This wrapper will keep track of cumulative rewards and episode lengths.

        Args:
            env (VectorEnv): The vector environment to apply the wrapper to
            deque_size: The size of the buffers :attr:`return_queue` and :attr:`length_queue`
        """
        super().__init__(env)

        self.episode_count = 0

        # Per-sub-environment running statistics; resized to ``num_envs`` arrays on ``reset()``.
        self.episode_start_times: np.ndarray = np.zeros(())
        self.episode_returns: np.ndarray = np.zeros(())
        self.episode_lengths: np.ndarray = np.zeros(())

        # Rolling buffers holding the stats of the most recently completed episodes.
        self.return_queue = deque(maxlen=deque_size)
        self.length_queue = deque(maxlen=deque_size)

    def reset(
        self,
        seed: int | list[int] | None = None,
        options: dict | None = None,
    ):
        """Resets the environment using ``seed`` and ``options``, and zeroes the running episode statistics."""
        obs, info = super().reset(seed=seed, options=options)

        self.episode_start_times = np.full(
            self.num_envs, time.perf_counter(), dtype=np.float32
        )
        self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
        self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)

        return obs, info

    def step(
        self, actions: ActType
    ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
        """Steps through the environment, recording the episode statistics."""
        (
            observations,
            rewards,
            terminations,
            truncations,
            infos,
        ) = self.env.step(actions)

        assert isinstance(
            infos, dict
        ), f"`info` dtype is {type(infos)} while supported dtype is `dict`. This may be due to usage of other wrappers in the wrong order."

        # Accumulate this step's rewards and step counts for every sub-environment.
        self.episode_returns += rewards
        self.episode_lengths += 1

        # An episode is done when its sub-environment either terminates or truncates.
        dones = np.logical_or(terminations, truncations)
        num_dones = np.sum(dones)

        if num_dones:
            if "episode" in infos or "_episode" in infos:
                raise ValueError(
                    "Attempted to add episode stats when they already exist"
                )
            else:
                # Stats are reported only for finished sub-environments; the
                # entries of still-running sub-environments are zeroed.
                infos["episode"] = {
                    "r": np.where(dones, self.episode_returns, 0.0),
                    "l": np.where(dones, self.episode_lengths, 0),
                    "t": np.where(
                        dones,
                        np.round(time.perf_counter() - self.episode_start_times, 6),
                        0.0,
                    ),
                }

            infos["_episode"] = dones

        self.episode_count += num_dones

        # ``np.where`` returns a 1-tuple; index it to get the array of done indices,
        # then extend the queues with the stats of every finished sub-environment.
        done_indices = np.where(dones)[0]
        self.return_queue.extend(self.episode_returns[done_indices])
        self.length_queue.extend(self.episode_lengths[done_indices])

        # Start fresh statistics for the sub-environments that just finished.
        self.episode_lengths[dones] = 0
        self.episode_returns[dones] = 0
        self.episode_start_times[dones] = time.perf_counter()

        return (
            observations,
            rewards,
            terminations,
            truncations,
            infos,
        )