"""A collection of common wrappers.
* ``AutoresetV0`` - Auto-resets the environment
* ``PassiveEnvCheckerV0`` - Passive environment checker that does not modify any environment data
* ``OrderEnforcingV0`` - Enforces the order of function calls to environments
* ``RecordEpisodeStatisticsV0`` - Records the episode statistics
"""
from __future__ import annotations

import time
from collections import deque
from typing import Any, SupportsFloat

import numpy as np

import gymnasium as gym
from gymnasium.core import ActType, ObsType, RenderFrame
from gymnasium.error import ResetNeeded
from gymnasium.utils.passive_env_checker import (
check_action_space,
check_observation_space,
env_render_passive_checker,
env_reset_passive_checker,
env_step_passive_checker,
)


class AutoresetV0(
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
"""A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`."""

    def __init__(self, env: gym.Env[ObsType, ActType]):
"""A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`.
Args:
env (gym.Env): The environment to apply the wrapper
"""
gym.utils.RecordConstructorArgs.__init__(self)
gym.Wrapper.__init__(self, env)

        self._episode_ended: bool = False
self._reset_options: dict[str, Any] | None = None

    def step(
self, action: ActType
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Steps through the environment with action and resets the environment if a terminated or truncated signal is encountered in the previous step.
Args:
action: The action to take
Returns:
The autoreset environment :meth:`step`
"""
if self._episode_ended:
obs, info = self.env.reset(options=self._reset_options)
            # the environment has just been reset, so step normally on the next call
            self._episode_ended = False
return obs, 0, False, False, info
else:
obs, reward, terminated, truncated, info = super().step(action)
self._episode_ended = terminated or truncated
return obs, reward, terminated, truncated, info

    def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[ObsType, dict[str, Any]]:
"""Resets the environment, saving the options used."""
self._episode_ended = False
self._reset_options = options
return super().reset(seed=seed, options=self._reset_options)


class PassiveEnvCheckerV0(
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
"""A passive environment checker wrapper that surrounds the step, reset and render functions to check they follow the gymnasium API."""
def __init__(self, env: gym.Env[ObsType, ActType]):
"""Initialises the wrapper with the environments, run the observation and action space tests."""
gym.utils.RecordConstructorArgs.__init__(self)
gym.Wrapper.__init__(self, env)

        assert hasattr(
env, "action_space"
), "The environment must specify an action space. https://gymnasium.farama.org/content/environment_creation/"
check_action_space(env.action_space)
assert hasattr(
env, "observation_space"
), "The environment must specify an observation space. https://gymnasium.farama.org/content/environment_creation/"
check_observation_space(env.observation_space)
self._checked_reset: bool = False
self._checked_step: bool = False
self._checked_render: bool = False

    def step(
self, action: ActType
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Steps through the environment that on the first call will run the `passive_env_step_check`."""
if self._checked_step is False:
self._checked_step = True
return env_step_passive_checker(self.env, action)
else:
return self.env.step(action)

    def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[ObsType, dict[str, Any]]:
"""Resets the environment that on the first call will run the `passive_env_reset_check`."""
if self._checked_reset is False:
self._checked_reset = True
return env_reset_passive_checker(self.env, seed=seed, options=options)
else:
return self.env.reset(seed=seed, options=options)

    def render(self) -> RenderFrame | list[RenderFrame] | None:
"""Renders the environment that on the first call will run the `passive_env_render_check`."""
if self._checked_render is False:
self._checked_render = True
return env_render_passive_checker(self.env)
else:
return self.env.render()


class OrderEnforcingV0(
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
"""A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`.
Example:
>>> import gymnasium as gym
>>> from gymnasium.experimental.wrappers import OrderEnforcingV0
>>> env = gym.make("CartPole-v1", render_mode="human")
>>> env = OrderEnforcingV0(env)
>>> env.step(0)
Traceback (most recent call last):
...
gymnasium.error.ResetNeeded: Cannot call env.step() before calling env.reset()
>>> env.render()
Traceback (most recent call last):
...
        gymnasium.error.ResetNeeded: Cannot call `env.render()` before calling `env.reset()`, if this is an intended action, set `disable_render_order_enforcing=True` on the OrderEnforcingV0 wrapper.
>>> _ = env.reset()
>>> env.render()
>>> _ = env.step(0)
>>> env.close()
"""

    def __init__(
self,
env: gym.Env[ObsType, ActType],
disable_render_order_enforcing: bool = False,
):
"""A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`.
Args:
env: The environment to wrap
disable_render_order_enforcing: If to disable render order enforcing
"""
gym.utils.RecordConstructorArgs.__init__(
self, disable_render_order_enforcing=disable_render_order_enforcing
)
gym.Wrapper.__init__(self, env)

        self._has_reset: bool = False
self._disable_render_order_enforcing: bool = disable_render_order_enforcing

    def step(
        self, action: ActType
    ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Steps through the environment."""
if not self._has_reset:
raise ResetNeeded("Cannot call env.step() before calling env.reset()")
return super().step(action)

    def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[ObsType, dict[str, Any]]:
"""Resets the environment with `kwargs`."""
self._has_reset = True
return super().reset(seed=seed, options=options)

    def render(self) -> RenderFrame | list[RenderFrame] | None:
"""Renders the environment with `kwargs`."""
if not self._disable_render_order_enforcing and not self._has_reset:
            raise ResetNeeded(
                "Cannot call `env.render()` before calling `env.reset()`, if this is an intended action, "
                "set `disable_render_order_enforcing=True` on the OrderEnforcingV0 wrapper."
            )
return super().render()

    @property
def has_reset(self):
"""Returns if the environment has been reset before."""
return self._has_reset


class RecordEpisodeStatisticsV0(
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
"""This wrapper will keep track of cumulative rewards and episode lengths.
At the end of an episode, the statistics of the episode will be added to ``info``
using the key ``episode``. If using a vectorized environment also the key
``_episode`` is used which indicates whether the env at the respective index has
the episode statistics.

    After the completion of an episode, ``info`` will look like this::

>>> info = {
... "episode": {
... "r": "<cumulative reward>",
... "l": "<episode length>",
... "t": "<elapsed time since beginning of episode>"
... },
... }

    For vectorized environments the output will be in the form of::

>>> infos = {
... "final_observation": "<array of length num-envs>",
... "_final_observation": "<boolean array of length num-envs>",
... "final_info": "<array of length num-envs>",
... "_final_info": "<boolean array of length num-envs>",
... "episode": {
... "r": "<array of cumulative reward>",
... "l": "<array of episode length>",
... "t": "<array of elapsed time since beginning of episode>"
... },
... "_episode": "<boolean array of length num-envs>"
... }

    Moreover, the most recent rewards and episode lengths are stored in buffers that can be accessed via
    :attr:`wrapped_env.episode_reward_buffer` and :attr:`wrapped_env.episode_length_buffer` respectively.

    Attributes:
        episode_reward_buffer: The cumulative rewards of the last ``buffer_length``-many episodes
        episode_length_buffer: The lengths of the last ``buffer_length``-many episodes
        episode_time_length_buffer: The durations of the last ``buffer_length``-many episodes
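
    A minimal sketch of reading the recorded statistics after a single episode:

    Example:
        >>> import gymnasium as gym
        >>> from gymnasium.experimental.wrappers import RecordEpisodeStatisticsV0
        >>> env = RecordEpisodeStatisticsV0(gym.make("CartPole-v1"))
        >>> obs, info = env.reset(seed=123)
        >>> terminated = truncated = False
        >>> while not (terminated or truncated):
        ...     obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
        >>> "episode" in info
        True
        >>> len(env.episode_reward_buffer)
        1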
"""

    def __init__(
self,
env: gym.Env[ObsType, ActType],
buffer_length: int | None = 100,
stats_key: str = "episode",
):
"""This wrapper will keep track of cumulative rewards and episode lengths.
Args:
env (Env): The environment to apply the wrapper
buffer_length: The size of the buffers :attr:`return_queue` and :attr:`length_queue`
stats_key: The info key for the episode statistics
"""
gym.utils.RecordConstructorArgs.__init__(self)
gym.Wrapper.__init__(self, env)

        self._stats_key = stats_key
self.episode_count = 0
self.episode_start_time: float = -1
self.episode_reward: float = -1
self.episode_length: int = -1

        self.episode_time_length_buffer: deque[float] = deque(maxlen=buffer_length)
self.episode_reward_buffer: deque[float] = deque(maxlen=buffer_length)
self.episode_length_buffer: deque[int] = deque(maxlen=buffer_length)

    def step(
self, action: ActType
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Steps through the environment, recording the episode statistics."""
obs, reward, terminated, truncated, info = super().step(action)
self.episode_reward += reward
self.episode_length += 1
if terminated or truncated:
assert self._stats_key not in info
episode_time_length = np.round(
time.perf_counter() - self.episode_start_time, 6
)
info[self._stats_key] = {
"r": self.episode_reward,
"l": self.episode_length,
"t": episode_time_length,
}
self.episode_time_length_buffer.append(episode_time_length)
self.episode_reward_buffer.append(self.episode_reward)
self.episode_length_buffer.append(self.episode_length)
self.episode_count += 1
return obs, reward, terminated, truncated, info

    def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[ObsType, dict[str, Any]]:
"""Resets the environment using seed and options and resets the episode rewards and lengths."""
obs, info = super().reset(seed=seed, options=options)
self.episode_start_time = time.perf_counter()
self.episode_reward = 0
self.episode_length = 0
return obs, info