diff --git a/v1.2.0/.buildinfo b/v1.2.0/.buildinfo new file mode 100644 index 000000000..8a0801519 --- /dev/null +++ b/v1.2.0/.buildinfo @@ -0,0 +1,4 @@ +# Sphinx build info version 1 +# This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done. +config: cf2f46ba41b047a4e64ef75619dcaa75 +tags: d77d1c0d9ca2f4c8421862c7c5a0d620 diff --git a/v1.2.0/.nojekyll b/v1.2.0/.nojekyll new file mode 100644 index 000000000..e69de29bb diff --git a/v1.2.0/404.html b/v1.2.0/404.html new file mode 100644 index 000000000..0c2596cb8 --- /dev/null +++ b/v1.2.0/404.html @@ -0,0 +1,695 @@ + + +
+ + + + + + + + + + + +This folder contains the documentation for Gymnasium.
+Fork Gymnasium and edit the docstring in the environment’s Python file. Then, pip install your Gymnasium fork and run docs/_scripts/gen_mds.py
in this repo. This will automatically generate a Markdown documentation file for the environment.
Ensure the environment is in Gymnasium (or your fork). Ensure that the environment’s Python file has a properly formatted markdown docstring. Install using pip install -e .
and then run docs/_scripts/gen_mds.py
. This will automatically generate a md page for the environment. Then complete the other steps.
Add the corresponding gif into the docs/_static/videos/{ENV_TYPE}
folder, where ENV_TYPE
is the category of your new environment (e.g. mujoco). Follow snake_case naming convention. Alternatively, run docs/_scripts/gen_gifs.py
.
Edit docs/environments/{ENV_TYPE}/index.md
, and add the name of the file corresponding to your new environment to the toctree
.
Install the required packages and Gymnasium (or your fork):
+pip install gymnasium
+pip install -r docs/requirements.txt
+
To build the documentation once:
+cd docs
+make dirhtml
+
To rebuild the documentation automatically every time a change is made:
+cd docs
+sphinx-autobuild -b dirhtml --watch ../gymnasium --re-ignore "pickle$" . _build
+
You can then open http://localhost:8000 in your browser to watch a live updated version of the documentation.
+We use Sphinx-Gallery to build the tutorials inside the docs/tutorials
directory. Check docs/tutorials/demo.py
to see an example of a tutorial and Sphinx-Gallery documentation for more information.
To convert Jupyter Notebooks to the python tutorials you can use this script.
+If you want Sphinx-Gallery to execute the tutorial (which adds outputs and plots) then the file name should start with run_
. Note that this adds to the build time so make sure the script doesn’t take more than a few seconds to execute.
+"""Core API for Environment, Wrapper, ActionWrapper, RewardWrapper and ObservationWrapper."""
+
+from __future__ import annotations
+
+from copy import deepcopy
+from typing import TYPE_CHECKING, Any, Generic, SupportsFloat, TypeVar
+
+import numpy as np
+
+import gymnasium
+from gymnasium import spaces
+from gymnasium.utils import RecordConstructorArgs, seeding
+
+
+if TYPE_CHECKING:
+ from gymnasium.envs.registration import EnvSpec, WrapperSpec
+
+ObsType = TypeVar("ObsType")
+ActType = TypeVar("ActType")
+RenderFrame = TypeVar("RenderFrame")
+
+
+
+[docs]
+class Env(Generic[ObsType, ActType]):
+ r"""The main Gymnasium class for implementing Reinforcement Learning Agents environments.
+
+ The class encapsulates an environment with arbitrary behind-the-scenes dynamics through the :meth:`step` and :meth:`reset` functions.
+ An environment can be partially or fully observed by single agents. For multi-agent environments, see PettingZoo.
+
+ The main API methods that users of this class need to know are:
+
+ - :meth:`step` - Updates an environment with actions returning the next agent observation, the reward for taking that actions,
+ if the environment has terminated or truncated due to the latest action and information from the environment about the step, i.e. metrics, debug info.
+ - :meth:`reset` - Resets the environment to an initial state, required before calling step.
+ Returns the first agent observation for an episode and information, i.e. metrics, debug info.
+ - :meth:`render` - Renders the environments to help visualise what the agent see, examples modes are "human", "rgb_array", "ansi" for text.
+ - :meth:`close` - Closes the environment, important when external software is used, i.e. pygame for rendering, databases
+
+ Environments have additional attributes for users to understand the implementation
+
+ - :attr:`action_space` - The Space object corresponding to valid actions, all valid actions should be contained within the space.
+ - :attr:`observation_space` - The Space object corresponding to valid observations, all valid observations should be contained within the space.
+ - :attr:`spec` - An environment spec that contains the information used to initialize the environment from :meth:`gymnasium.make`
+ - :attr:`metadata` - The metadata of the environment, e.g. `{"render_modes": ["rgb_array", "human"], "render_fps": 30}`. For Jax or Torch, this can be indicated to users with `"jax"=True` or `"torch"=True`.
+ - :attr:`np_random` - The random number generator for the environment. This is automatically assigned during
+ ``super().reset(seed=seed)`` and when assessing :attr:`np_random`.
+
+ .. seealso:: For modifying or extending environments use the :class:`gymnasium.Wrapper` class
+
+ Note:
+ To get reproducible sampling of actions, a seed can be set with ``env.action_space.seed(123)``.
+
+ Note:
+ For strict type checking (e.g. mypy or pyright), :class:`Env` is a generic class with two parameterized types: ``ObsType`` and ``ActType``.
+ The ``ObsType`` and ``ActType`` are the expected types of the observations and actions used in :meth:`reset` and :meth:`step`.
+ The environment's :attr:`observation_space` and :attr:`action_space` should have type ``Space[ObsType]`` and ``Space[ActType]``,
+ see a space's implementation to find its parameterized type.
+ """
+
+ # Set this in SOME subclasses
+ metadata: dict[str, Any] = {"render_modes": []}
+ # define render_mode if your environment supports rendering
+ render_mode: str | None = None
+ spec: EnvSpec | None = None
+
+ # Set these in ALL subclasses
+ action_space: spaces.Space[ActType]
+ observation_space: spaces.Space[ObsType]
+
+ # Created
+ _np_random: np.random.Generator | None = None
+ # will be set to the "invalid" value -1 if the seed of the currently set rng is unknown
+ _np_random_seed: int | None = None
+
+
+[docs]
+ def step(
+ self, action: ActType
+ ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
+ """Run one timestep of the environment's dynamics using the agent actions.
+
+ When the end of an episode is reached (``terminated or truncated``), it is necessary to call :meth:`reset` to
+ reset this environment's state for the next episode.
+
+ .. versionchanged:: 0.26
+
+ The Step API was changed removing ``done`` in favor of ``terminated`` and ``truncated`` to make it clearer
+ to users when the environment had terminated or truncated which is critical for reinforcement learning
+ bootstrapping algorithms.
+
+ Args:
+ action (ActType): an action provided by the agent to update the environment state.
+
+ Returns:
+ observation (ObsType): An element of the environment's :attr:`observation_space` as the next observation due to the agent actions.
+ An example is a numpy array containing the positions and velocities of the pole in CartPole.
+ reward (SupportsFloat): The reward as a result of taking the action.
+ terminated (bool): Whether the agent reaches the terminal state (as defined under the MDP of the task)
+ which can be positive or negative. An example is reaching the goal state or moving into the lava from
+ the Sutton and Barto Gridworld. If true, the user needs to call :meth:`reset`.
+ truncated (bool): Whether the truncation condition outside the scope of the MDP is satisfied.
+ Typically, this is a timelimit, but could also be used to indicate an agent physically going out of bounds.
+ Can be used to end the episode prematurely before a terminal state is reached.
+ If true, the user needs to call :meth:`reset`.
+ info (dict): Contains auxiliary diagnostic information (helpful for debugging, learning, and logging).
+ This might, for instance, contain: metrics that describe the agent's performance state, variables that are
+ hidden from observations, or individual reward terms that are combined to produce the total reward.
+ In OpenAI Gym <v26, it contains "TimeLimit.truncated" to distinguish truncation and termination,
+ however this is deprecated in favour of returning terminated and truncated variables.
+ done (bool): (Deprecated) A boolean value for if the episode has ended, in which case further :meth:`step` calls will
+ return undefined results. This was removed in OpenAI Gym v26 in favor of terminated and truncated attributes.
+ A done signal may be emitted for different reasons: Maybe the task underlying the environment was solved successfully,
+ a certain timelimit was exceeded, or the physics simulation has entered an invalid state.
+ """
+ raise NotImplementedError
+
+
+
+[docs]
+ def reset(
+ self,
+ *,
+ seed: int | None = None,
+ options: dict[str, Any] | None = None,
+ ) -> tuple[ObsType, dict[str, Any]]: # type: ignore
+ """Resets the environment to an initial internal state, returning an initial observation and info.
+
+ This method generates a new starting state often with some randomness to ensure that the agent explores the
+ state space and learns a generalised policy about the environment. This randomness can be controlled
+ with the ``seed`` parameter otherwise if the environment already has a random number generator and
+ :meth:`reset` is called with ``seed=None``, the RNG is not reset.
+
+ Therefore, :meth:`reset` should (in the typical use case) be called with a seed right after initialization and then never again.
+
+ For Custom environments, the first line of :meth:`reset` should be ``super().reset(seed=seed)`` which implements
+ the seeding correctly.
+
+ .. versionchanged:: v0.25
+
+ The ``return_info`` parameter was removed and now info is expected to be returned.
+
+ Args:
+ seed (optional int): The seed that is used to initialize the environment's PRNG (`np_random`) and
+ the read-only attribute `np_random_seed`.
+ If the environment does not already have a PRNG and ``seed=None`` (the default option) is passed,
+ a seed will be chosen from some source of entropy (e.g. timestamp or /dev/urandom).
+ However, if the environment already has a PRNG and ``seed=None`` is passed, the PRNG will *not* be reset
+ and the env's :attr:`np_random_seed` will *not* be altered.
+ If you pass an integer, the PRNG will be reset even if it already exists.
+ Usually, you want to pass an integer *right after the environment has been initialized and then never again*.
+ Please refer to the minimal example above to see this paradigm in action.
+ options (optional dict): Additional information to specify how the environment is reset (optional,
+ depending on the specific environment)
+
+ Returns:
+ observation (ObsType): Observation of the initial state. This will be an element of :attr:`observation_space`
+ (typically a numpy array) and is analogous to the observation returned by :meth:`step`.
+ info (dictionary): This dictionary contains auxiliary information complementing ``observation``. It should be analogous to
+ the ``info`` returned by :meth:`step`.
+ """
+ # Initialize the RNG if the seed is manually passed
+ if seed is not None:
+ self._np_random, self._np_random_seed = seeding.np_random(seed)
+
+
+
+[docs]
+ def render(self) -> RenderFrame | list[RenderFrame] | None:
+ """Compute the render frames as specified by :attr:`render_mode` during the initialization of the environment.
+
+ The environment's :attr:`metadata` render modes (`env.metadata["render_modes"]`) should contain the possible
+ ways to implement the render modes. In addition, list versions for most render modes is achieved through
+ `gymnasium.make` which automatically applies a wrapper to collect rendered frames.
+
+ Note:
+ As the :attr:`render_mode` is known during ``__init__``, the objects used to render the environment state
+ should be initialised in ``__init__``.
+
+ By convention, if the :attr:`render_mode` is:
+
+ - None (default): no render is computed.
+ - "human": The environment is continuously rendered in the current display or terminal, usually for human consumption.
+ This rendering should occur during :meth:`step` and :meth:`render` doesn't need to be called. Returns ``None``.
+ - "rgb_array": Return a single frame representing the current state of the environment.
+ A frame is a ``np.ndarray`` with shape ``(x, y, 3)`` representing RGB values for an x-by-y pixel image.
+ - "ansi": Return a strings (``str``) or ``StringIO.StringIO`` containing a terminal-style text representation
+ for each time step. The text can include newlines and ANSI escape sequences (e.g. for colors).
+ - "rgb_array_list" and "ansi_list": List based version of render modes are possible (except Human) through the
+ wrapper, :py:class:`gymnasium.wrappers.RenderCollection` that is automatically applied during ``gymnasium.make(..., render_mode="rgb_array_list")``.
+ The frames collected are popped after :meth:`render` is called or :meth:`reset`.
+
+ Note:
+ Make sure that your class's :attr:`metadata` ``"render_modes"`` key includes the list of supported modes.
+
+ .. versionchanged:: 0.25.0
+
+ The render function was changed to no longer accept parameters, rather these parameters should be specified
+ in the environment initialised, i.e., ``gymnasium.make("CartPole-v1", render_mode="human")``
+ """
+ raise NotImplementedError
+
+
+
+[docs]
+ def close(self):
+ """After the user has finished using the environment, close contains the code necessary to "clean up" the environment.
+
+ This is critical for closing rendering windows, database or HTTP connections.
+ Calling ``close`` on an already closed environment has no effect and won't raise an error.
+ """
+ pass
+
+
+ @property
+ def unwrapped(self) -> Env[ObsType, ActType]:
+ """Returns the base non-wrapped environment.
+
+ Returns:
+ Env: The base non-wrapped :class:`gymnasium.Env` instance
+ """
+ return self
+
+ @property
+ def np_random_seed(self) -> int:
+ """Returns the environment's internal :attr:`_np_random_seed` that if not set will first initialise with a random int as seed.
+
+ If :attr:`np_random_seed` was set directly instead of through :meth:`reset` or :meth:`set_np_random_through_seed`,
+ the seed will take the value -1.
+
+ Returns:
+ int: the seed of the current `np_random` or -1, if the seed of the rng is unknown
+ """
+ if self._np_random_seed is None:
+ self._np_random, self._np_random_seed = seeding.np_random()
+ return self._np_random_seed
+
+ @property
+ def np_random(self) -> np.random.Generator:
+ """Returns the environment's internal :attr:`_np_random` that if not set will initialise with a random seed.
+
+ Returns:
+ Instances of `np.random.Generator`
+ """
+ if self._np_random is None:
+ self._np_random, self._np_random_seed = seeding.np_random()
+ return self._np_random
+
+ @np_random.setter
+ def np_random(self, value: np.random.Generator):
+ """Sets the environment's internal :attr:`_np_random` with the user-provided Generator.
+
+ Since it is generally not possible to extract a seed from an instance of a random number generator,
+ this will also set the :attr:`_np_random_seed` to `-1`, which is not valid as input for the creation
+ of a numpy rng.
+ """
+ self._np_random = value
+ # Setting a numpy rng with -1 will cause a ValueError
+ self._np_random_seed = -1
+
+ def __str__(self):
+ """Returns a string of the environment with :attr:`spec` id's if :attr:`spec.
+
+ Returns:
+ A string identifying the environment
+ """
+ if self.spec is None:
+ return f"<{type(self).__name__} instance>"
+ else:
+ return f"<{type(self).__name__}<{self.spec.id}>>"
+
+ def __enter__(self):
+ """Support with-statement for the environment."""
+ return self
+
+ def __exit__(self, *args: Any):
+ """Support with-statement for the environment and closes the environment."""
+ self.close()
+ # propagate exception
+ return False
+
+ def has_wrapper_attr(self, name: str) -> bool:
+ """Checks if the attribute `name` exists in the environment."""
+ return hasattr(self, name)
+
+ def get_wrapper_attr(self, name: str) -> Any:
+ """Gets the attribute `name` from the environment."""
+ return getattr(self, name)
+
+ def set_wrapper_attr(self, name: str, value: Any, *, force: bool = True) -> bool:
+ """Sets the attribute `name` on the environment with `value`, see `Wrapper.set_wrapper_attr` for more info."""
+ if force or hasattr(self, name):
+ setattr(self, name, value)
+ return True
+ return False
+
+
+
+WrapperObsType = TypeVar("WrapperObsType")
+WrapperActType = TypeVar("WrapperActType")
+
+
+
+[docs]
+class Wrapper(
+ Env[WrapperObsType, WrapperActType],
+ Generic[WrapperObsType, WrapperActType, ObsType, ActType],
+):
+ """Wraps a :class:`gymnasium.Env` to allow a modular transformation of the :meth:`step` and :meth:`reset` methods.
+
+ This class is the base class of all wrappers to change the behavior of the underlying environment.
+ Wrappers that inherit from this class can modify the :attr:`action_space`, :attr:`observation_space`
+ and :attr:`metadata` attributes, without changing the underlying environment's attributes.
+ Moreover, the behavior of the :meth:`step` and :meth:`reset` methods can be changed by these wrappers.
+
+ Some attributes (:attr:`spec`, :attr:`render_mode`, :attr:`np_random`) will point back to the wrapper's environment
+ (i.e. to the corresponding attributes of :attr:`env`).
+
+ Note:
+ If you inherit from :class:`Wrapper`, don't forget to call ``super().__init__(env)``
+ """
+
+ def __init__(self, env: Env[ObsType, ActType]):
+ """Wraps an environment to allow a modular transformation of the :meth:`step` and :meth:`reset` methods.
+
+ Args:
+ env: The environment to wrap
+ """
+ self.env = env
+ assert isinstance(
+ env, Env
+ ), f"Expected env to be a `gymnasium.Env` but got {type(env)}"
+
+ self._action_space: spaces.Space[WrapperActType] | None = None
+ self._observation_space: spaces.Space[WrapperObsType] | None = None
+ self._metadata: dict[str, Any] | None = None
+
+ self._cached_spec: EnvSpec | None = None
+
+
+[docs]
+ def step(
+ self, action: WrapperActType
+ ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
+ """Uses the :meth:`step` of the :attr:`env` that can be overwritten to change the returned data."""
+ return self.env.step(action)
+
+
+
+[docs]
+ def reset(
+ self, *, seed: int | None = None, options: dict[str, Any] | None = None
+ ) -> tuple[WrapperObsType, dict[str, Any]]:
+ """Uses the :meth:`reset` of the :attr:`env` that can be overwritten to change the returned data."""
+ return self.env.reset(seed=seed, options=options)
+
+
+
+[docs]
+ def render(self) -> RenderFrame | list[RenderFrame] | None:
+ """Uses the :meth:`render` of the :attr:`env` that can be overwritten to change the returned data."""
+ return self.env.render()
+
+
+
+
+
+ @property
+ def np_random_seed(self) -> int | None:
+ """Returns the base environment's :attr:`np_random_seed`."""
+ return self.env.np_random_seed
+
+ @property
+ def unwrapped(self) -> Env[ObsType, ActType]:
+ """Returns the base environment of the wrapper.
+
+ This will be the bare :class:`gymnasium.Env` environment, underneath all layers of wrappers.
+ """
+ return self.env.unwrapped
+
+ @property
+ def spec(self) -> EnvSpec | None:
+ """Returns the :attr:`Env` :attr:`spec` attribute with the `WrapperSpec` if the wrapper inherits from `EzPickle`."""
+ if self._cached_spec is not None:
+ return self._cached_spec
+
+ env_spec = self.env.spec
+ if env_spec is not None:
+ # See if the wrapper inherits from `RecordConstructorArgs` then add the kwargs otherwise use `None` for the wrapper kwargs. This will raise an error in `make`
+ if isinstance(self, RecordConstructorArgs):
+ kwargs = getattr(self, "_saved_kwargs")
+ if "env" in kwargs:
+ kwargs = deepcopy(kwargs)
+ kwargs.pop("env")
+ else:
+ kwargs = None
+
+ from gymnasium.envs.registration import WrapperSpec
+
+ wrapper_spec = WrapperSpec(
+ name=self.class_name(),
+ entry_point=f"{self.__module__}:{type(self).__name__}",
+ kwargs=kwargs,
+ )
+
+ # to avoid reference issues we deepcopy the prior environments spec and add the new information
+ try:
+ env_spec = deepcopy(env_spec)
+ env_spec.additional_wrappers += (wrapper_spec,)
+ except Exception as e:
+ gymnasium.logger.warn(
+ f"An exception occurred ({e}) while copying the environment spec={env_spec}"
+ )
+ return None
+
+ self._cached_spec = env_spec
+ return env_spec
+
+
+[docs]
+ @classmethod
+ def wrapper_spec(cls, **kwargs: Any) -> WrapperSpec:
+ """Generates a `WrapperSpec` for the wrappers."""
+ from gymnasium.envs.registration import WrapperSpec
+
+ return WrapperSpec(
+ name=cls.class_name(),
+ entry_point=f"{cls.__module__}:{cls.__name__}",
+ kwargs=kwargs,
+ )
+
+
+ def has_wrapper_attr(self, name: str) -> bool:
+ """Checks if the given attribute is within the wrapper or its environment."""
+ if hasattr(self, name):
+ return True
+ else:
+ return self.env.has_wrapper_attr(name)
+
+
+[docs]
+ def get_wrapper_attr(self, name: str) -> Any:
+ """Gets an attribute from the wrapper and lower environments if `name` doesn't exist in this object.
+
+ Args:
+ name: The variable name to get
+
+ Returns:
+ The variable with name in wrapper or lower environments
+ """
+ if hasattr(self, name):
+ return getattr(self, name)
+ else:
+ try:
+ return self.env.get_wrapper_attr(name)
+ except AttributeError as e:
+ raise AttributeError(
+ f"wrapper {self.class_name()} has no attribute {name!r}"
+ ) from e
+
+
+
+[docs]
+ def set_wrapper_attr(self, name: str, value: Any, *, force: bool = True) -> bool:
+ """Sets an attribute on this wrapper or lower environment if `name` is already defined.
+
+ Args:
+ name: The variable name
+ value: The new variable value
+ force: Whether to create the attribute on this wrapper if it does not exists on the
+ lower environment instead of raising an exception
+
+ Returns:
+ If the variable has been set in this or a lower wrapper.
+ """
+ if hasattr(self, name):
+ setattr(self, name, value)
+ return True
+ else:
+ already_set = self.env.set_wrapper_attr(name, value, force=False)
+ if already_set:
+ return True
+ elif force:
+ setattr(self, name, value)
+ return True
+ else:
+ return False
+
+
+ def __str__(self):
+ """Returns the wrapper name and the :attr:`env` representation string."""
+ return f"<{type(self).__name__}{self.env}>"
+
+ def __repr__(self):
+ """Returns the string representation of the wrapper."""
+ return str(self)
+
+ @classmethod
+ def class_name(cls) -> str:
+ """Returns the class name of the wrapper."""
+ return cls.__name__
+
+ @property
+ def action_space(
+ self,
+ ) -> spaces.Space[ActType] | spaces.Space[WrapperActType]:
+ """Return the :attr:`Env` :attr:`action_space` unless overwritten then the wrapper :attr:`action_space` is used."""
+ if self._action_space is None:
+ return self.env.action_space
+ return self._action_space
+
+ @action_space.setter
+ def action_space(self, space: spaces.Space[WrapperActType]):
+ self._action_space = space
+
+ @property
+ def observation_space(
+ self,
+ ) -> spaces.Space[ObsType] | spaces.Space[WrapperObsType]:
+ """Return the :attr:`Env` :attr:`observation_space` unless overwritten then the wrapper :attr:`observation_space` is used."""
+ if self._observation_space is None:
+ return self.env.observation_space
+ return self._observation_space
+
+ @observation_space.setter
+ def observation_space(self, space: spaces.Space[WrapperObsType]):
+ self._observation_space = space
+
+ @property
+ def metadata(self) -> dict[str, Any]:
+ """Returns the :attr:`Env` :attr:`metadata`."""
+ if self._metadata is None:
+ return self.env.metadata
+ return self._metadata
+
+ @metadata.setter
+ def metadata(self, value: dict[str, Any]):
+ self._metadata = value
+
+ @property
+ def render_mode(self) -> str | None:
+ """Returns the :attr:`Env` :attr:`render_mode`."""
+ return self.env.render_mode
+
+ @property
+ def np_random(self) -> np.random.Generator:
+ """Returns the :attr:`Env` :attr:`np_random` attribute."""
+ return self.env.np_random
+
+ @np_random.setter
+ def np_random(self, value: np.random.Generator):
+ self.env.np_random = value
+
+ @property
+ def _np_random(self):
+ """This code will never be run due to __getattr__ being called prior this.
+
+ It seems that @property overwrites the variable (`_np_random`) meaning that __getattr__ gets called with the missing variable.
+ """
+ raise AttributeError(
+ "Can't access `_np_random` of a wrapper, use `.unwrapped._np_random` or `.np_random`."
+ )
+
+
+
+
+[docs]
+class ObservationWrapper(Wrapper[WrapperObsType, ActType, ObsType, ActType]):
+ """Modify observations from :meth:`Env.reset` and :meth:`Env.step` using :meth:`observation` function.
+
+ If you would like to apply a function to only the observation before
+ passing it to the learning code, you can simply inherit from :class:`ObservationWrapper` and overwrite the method
+ :meth:`observation` to implement that transformation. The transformation defined in that method must be
+ reflected by the :attr:`env` observation space. Otherwise, you need to specify the new observation space of the
+ wrapper by setting :attr:`self.observation_space` in the :meth:`__init__` method of your wrapper.
+ """
+
+ def __init__(self, env: Env[ObsType, ActType]):
+ """Constructor for the observation wrapper.
+
+ Args:
+ env: Environment to be wrapped.
+ """
+ Wrapper.__init__(self, env)
+
+ def reset(
+ self, *, seed: int | None = None, options: dict[str, Any] | None = None
+ ) -> tuple[WrapperObsType, dict[str, Any]]:
+ """Modifies the :attr:`env` after calling :meth:`reset`, returning a modified observation using :meth:`self.observation`."""
+ obs, info = self.env.reset(seed=seed, options=options)
+ return self.observation(obs), info
+
+ def step(
+ self, action: ActType
+ ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
+ """Modifies the :attr:`env` after calling :meth:`step` using :meth:`self.observation` on the returned observations."""
+ observation, reward, terminated, truncated, info = self.env.step(action)
+ return self.observation(observation), reward, terminated, truncated, info
+
+
+[docs]
+ def observation(self, observation: ObsType) -> WrapperObsType:
+ """Returns a modified observation.
+
+ Args:
+ observation: The :attr:`env` observation
+
+ Returns:
+ The modified observation
+ """
+ raise NotImplementedError
+
+
+
+
+
+[docs]
+class RewardWrapper(Wrapper[ObsType, ActType, ObsType, ActType]):
+ """Superclass of wrappers that can modify the returning reward from a step.
+
+ If you would like to apply a function to the reward that is returned by the base environment before
+ passing it to learning code, you can simply inherit from :class:`RewardWrapper` and overwrite the method
+ :meth:`reward` to implement that transformation.
+ """
+
+ def __init__(self, env: Env[ObsType, ActType]):
+ """Constructor for the Reward wrapper.
+
+ Args:
+ env: Environment to be wrapped.
+ """
+ Wrapper.__init__(self, env)
+
+ def step(
+ self, action: ActType
+ ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
+ """Modifies the :attr:`env` :meth:`step` reward using :meth:`self.reward`."""
+ observation, reward, terminated, truncated, info = self.env.step(action)
+ return observation, self.reward(reward), terminated, truncated, info
+
+
+[docs]
+ def reward(self, reward: SupportsFloat) -> SupportsFloat:
+ """Returns a modified environment ``reward``.
+
+ Args:
+ reward: The :attr:`env` :meth:`step` reward
+
+ Returns:
+ The modified `reward`
+ """
+ raise NotImplementedError
+
+
+
+
+
+[docs]
+class ActionWrapper(Wrapper[ObsType, WrapperActType, ObsType, ActType]):
+ """Superclass of wrappers that can modify the action before :meth:`step`.
+
+ If you would like to apply a function to the action before passing it to the base environment,
+ you can simply inherit from :class:`ActionWrapper` and overwrite the method :meth:`action` to implement
+ that transformation. The transformation defined in that method must take values in the base environment’s
+ action space. However, its domain might differ from the original action space.
+ In that case, you need to specify the new action space of the wrapper by setting :attr:`action_space` in
+ the :meth:`__init__` method of your wrapper.
+
+ Among others, Gymnasium provides the action wrappers :class:`gymnasium.wrappers.ClipAction` and
+ :class:`gymnasium.wrappers.RescaleAction` for clipping and rescaling actions.
+ """
+
+ def __init__(self, env: Env[ObsType, ActType]):
+ """Constructor for the action wrapper.
+
+ Args:
+ env: Environment to be wrapped.
+ """
+ Wrapper.__init__(self, env)
+
+ def step(
+ self, action: WrapperActType
+ ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
+ """Runs the :attr:`env` :meth:`env.step` using the modified ``action`` from :meth:`self.action`."""
+ return self.env.step(self.action(action))
+
+
+[docs]
+ def action(self, action: WrapperActType) -> ActType:
+ """Returns a modified action before :meth:`step` is called.
+
+ Args:
+ action: The original :meth:`step` actions
+
+ Returns:
+ The modified actions
+ """
+ raise NotImplementedError
+
+
+
+"""Functional to Environment compatibility."""
+
+from __future__ import annotations
+
+from typing import Any, Generic, TypeAlias
+
+import jax
+import jax.numpy as jnp
+import jax.random as jrng
+
+import gymnasium as gym
+from gymnasium.envs.registration import EnvSpec
+from gymnasium.experimental.functional import ActType, FuncEnv, ObsType, StateType
+from gymnasium.utils import seeding
+from gymnasium.vector import AutoresetMode
+from gymnasium.vector.utils import batch_space
+
+
+PRNGKeyType: TypeAlias = jax.Array
+
+
+
+[docs]
class FunctionalJaxEnv(gym.Env, Generic[StateType]):
    """A conversion layer turning a stateless ``FuncEnv`` into a stateful ``gym.Env``.

    The functional environment supplies pure functions (``initial``, ``transition``,
    ``observation``, ``reward``, ``terminal``); this wrapper owns the mutable
    ``state`` and the JAX PRNG key between calls.
    """

    # Current environment state and the JAX PRNG key, carried between calls.
    state: StateType
    rng: PRNGKeyType

    def __init__(
        self,
        func_env: FuncEnv,
        metadata: dict[str, Any] | None = None,
        render_mode: str | None = None,
        spec: EnvSpec | None = None,
    ):
        """Initialize the environment from a FuncEnv.

        Args:
            func_env: The functional environment providing the pure functions.
            metadata: The environment metadata; defaults to ``{"render_mode": [], "jax": True}``.
            render_mode: Only ``"rgb_array"`` is supported by :meth:`render`.
            spec: The environment spec, if any.
        """
        if metadata is None:
            # metadata.get("jax", False) can be used downstream to know that the environment returns jax arrays
            metadata = {"render_mode": [], "jax": True}

        self.func_env = func_env

        self.observation_space = func_env.observation_space
        self.action_space = func_env.action_space

        self.metadata = metadata
        self.render_mode = render_mode

        self.spec = spec

        # Render state is only required for image rendering.
        if self.render_mode == "rgb_array":
            self.render_state = self.func_env.render_init()
        else:
            self.render_state = None

        # Seed the JAX PRNG key from a fresh numpy generator.
        np_random, _ = seeding.np_random()
        seed = np_random.integers(0, 2**32 - 1, dtype="uint32")

        self.rng = jrng.PRNGKey(seed)

    def reset(self, *, seed: int | None = None, options: dict | None = None):
        """Resets the environment using the seed."""
        super().reset(seed=seed)
        if seed is not None:
            self.rng = jrng.PRNGKey(seed)

        rng, self.rng = jrng.split(self.rng)

        self.state = self.func_env.initial(rng=rng)
        obs = self.func_env.observation(self.state, rng)
        info = self.func_env.state_info(self.state)

        return obs, info

    def step(self, action: ActType):
        """Steps through the environment using the action."""
        rng, self.rng = jrng.split(self.rng)

        next_state = self.func_env.transition(self.state, action, rng)
        observation = self.func_env.observation(next_state, rng)
        reward = self.func_env.reward(self.state, action, next_state, rng)
        terminated = self.func_env.terminal(next_state, rng)
        info = self.func_env.transition_info(self.state, action, next_state)
        self.state = next_state

        # Functional envs never truncate themselves; a TimeLimit wrapper handles that.
        return observation, float(reward), bool(terminated), False, info

    def render(self):
        """Returns the render state if `render_mode` is "rgb_array"."""
        if self.render_mode == "rgb_array":
            self.render_state, image = self.func_env.render_image(
                self.state, self.render_state
            )
            return image
        else:
            raise NotImplementedError

    def close(self):
        """Closes the environments and render state if set."""
        if self.render_state is not None:
            self.func_env.render_close(self.render_state)
            self.render_state = None
+
+
class FunctionalJaxVectorEnv(
    gym.vector.VectorEnv[ObsType, ActType, Any], Generic[ObsType, ActType, StateType]
):
    """A vector env implementation for functional Jax envs."""

    # Batched environment state and the JAX PRNG key carried between calls.
    state: StateType
    rng: PRNGKeyType

    def __init__(
        self,
        func_env: FuncEnv[StateType, ObsType, ActType, Any, Any, Any, Any],
        num_envs: int,
        max_episode_steps: int = 0,
        metadata: dict[str, Any] | None = None,
        render_mode: str | None = None,
        spec: EnvSpec | None = None,
    ):
        """Initialize the environment from a FuncEnv.

        Args:
            func_env: The functional environment whose pure functions are vectorized with ``jax.vmap``.
            num_envs: The number of parallel sub-environments.
            max_episode_steps: Steps before truncation; ``0`` disables truncation.
            metadata: The environment metadata; defaults to next-step autoreset.
            render_mode: Only ``"rgb_array"`` is supported by :meth:`render`.
            spec: The environment spec, if any.
        """
        super().__init__()
        if metadata is None:
            metadata = {"autoreset_mode": AutoresetMode.NEXT_STEP}
        self.func_env = func_env
        self.num_envs = num_envs

        self.single_observation_space = func_env.observation_space
        self.single_action_space = func_env.action_space
        self.observation_space = batch_space(
            self.single_observation_space, self.num_envs
        )
        self.action_space = batch_space(self.single_action_space, self.num_envs)

        self.metadata = metadata
        self.render_mode = render_mode
        self.spec = spec
        self.time_limit = max_episode_steps

        # Per-env elapsed step counters, used for truncation.
        self.steps = jnp.zeros(self.num_envs, dtype=jnp.int32)

        # Envs that finished last step are reset on the *next* step (NEXT_STEP autoreset).
        self.prev_done = jnp.zeros(self.num_envs, dtype=jnp.bool_)

        if self.render_mode == "rgb_array":
            self.render_state = self.func_env.render_init()
        else:
            self.render_state = None

        # Seed the JAX PRNG key from a fresh numpy generator.
        np_random, _ = seeding.np_random()
        seed = np_random.integers(0, 2**32 - 1, dtype="uint32")

        self.rng = jrng.PRNGKey(seed)

        # Vectorize the functional environment's pure functions over the batch axis.
        self.func_env.transform(jax.vmap)

    def reset(self, *, seed: int | None = None, options: dict | None = None):
        """Resets the environment."""
        super().reset(seed=seed)
        if seed is not None:
            self.rng = jrng.PRNGKey(seed)

        rng, self.rng = jrng.split(self.rng)

        # One sub-key per sub-environment.
        rng = jrng.split(rng, self.num_envs)

        self.state = self.func_env.initial(rng=rng)
        obs = self.func_env.observation(self.state, rng)
        info = self.func_env.state_info(self.state)

        self.steps = jnp.zeros(self.num_envs, dtype=jnp.int32)

        return obs, info

    def step(self, action: ActType):
        """Steps through the environment using the action."""
        self.steps += 1

        rng, self.rng = jrng.split(self.rng)

        rng = jrng.split(rng, self.num_envs)

        next_state = self.func_env.transition(self.state, action, rng)
        reward = self.func_env.reward(self.state, action, next_state, rng)

        terminated = self.func_env.terminal(next_state, rng)
        truncated = (
            self.steps >= self.time_limit
            if self.time_limit > 0
            else jnp.zeros_like(terminated)
        )

        # NOTE(review): info is computed before the autoreset below, so envs
        # being reset this step still report the pre-reset transition info.
        info = self.func_env.transition_info(self.state, action, next_state)

        if jnp.any(self.prev_done):
            to_reset = jnp.where(self.prev_done)[0]
            reset_count = to_reset.shape[0]

            rng, self.rng = jrng.split(self.rng)
            rng = jrng.split(rng, reset_count)

            new_initials = self.func_env.initial(rng)

            # NOTE(review): this rebuilds `next_state` from `self.state` (the
            # pre-transition state) for the non-reset indices, which discards the
            # transition computed above whenever any env resets — confirm whether
            # `next_state.at[to_reset].set(...)` was intended.
            next_state = self.state.at[to_reset].set(new_initials)
            self.steps = self.steps.at[to_reset].set(0)
            terminated = terminated.at[to_reset].set(False)
            truncated = truncated.at[to_reset].set(False)

        self.prev_done = jnp.logical_or(terminated, truncated)

        rng = jrng.split(self.rng, self.num_envs)

        observation = self.func_env.observation(next_state, rng)

        self.state = next_state

        return observation, reward, terminated, truncated, info

    def render(self):
        """Returns the render state if `render_mode` is "rgb_array"."""
        if self.render_mode == "rgb_array":
            self.render_state, image = self.func_env.render_image(
                self.state, self.render_state
            )
            return image
        else:
            raise NotImplementedError

    def close(self):
        """Closes the environments and render state if set."""
        if self.render_state is not None:
            self.func_env.render_close(self.render_state)
            self.render_state = None
+
+"""Functions for registering environments within gymnasium using public functions ``make``, ``register`` and ``spec``."""
+
+from __future__ import annotations
+
+import contextlib
+import copy
+import dataclasses
+import difflib
+import importlib
+import importlib.metadata as metadata
+import importlib.util
+import json
+import re
+from collections import defaultdict
+from collections.abc import Callable, Iterable, Sequence
+from dataclasses import dataclass, field
+from enum import Enum
+from types import ModuleType
+from typing import Any, Protocol
+
+import gymnasium as gym
+from gymnasium import Env, Wrapper, error, logger
+from gymnasium.logger import warn
+from gymnasium.vector import AutoresetMode
+
+
# Compiled regex for env ids of the form ``[namespace/](env-name)[-v(version)]``.
# The namespace and the ``-v(version)`` suffix are both optional.
ENV_ID_RE = re.compile(
    r"^(?:(?P<namespace>[\w:-]+)\/)?(?:(?P<name>[\w:.-]+?))(?:-v(?P<version>\d+))?$"
)


# Public API of this module.
__all__ = [
    "registry",
    "current_namespace",
    "EnvSpec",
    "WrapperSpec",
    "VectorizeMode",
    # Functions
    "register",
    "make",
    "make_vec",
    "spec",
    "pprint_registry",
    "register_envs",
]
+
+
class EnvCreator(Protocol):
    """Callable signature expected for creating a single environment."""

    def __call__(self, **kwargs: Any) -> Env: ...
+
+
class VectorEnvCreator(Protocol):
    """Callable signature expected for creating a vector environment."""

    def __call__(self, **kwargs: Any) -> gym.vector.VectorEnv: ...
+
+
+
+[docs]
@dataclass
class WrapperSpec:
    """A specification for recording wrapper configs.

    * name: The name of the wrapper.
    * entry_point: The location of the wrapper to create from.
    * kwargs: Additional keyword arguments passed to the wrapper. If the wrapper doesn't inherit from EzPickle then this is ``None``
    """

    # The wrapper class name, e.g. ``"TimeLimit"``.
    name: str
    # Import location, ``"(module path):(wrapper name)"``, loaded via `load_env_creator`.
    entry_point: str
    # Constructor kwargs; ``None`` when they could not be recorded.
    kwargs: dict[str, Any] | None
+
+
+
+
+[docs]
@dataclass
class EnvSpec:
    """A specification for creating environments with :meth:`gymnasium.make`.

    * **id**: The string used to create the environment with :meth:`gymnasium.make`
    * **entry_point**: A string for the environment location, ``(import path):(environment name)`` or a function that creates the environment.
    * **reward_threshold**: The reward threshold for completing the environment.
    * **nondeterministic**: If the observation of an environment cannot be repeated with the same initial state, random number generator state and actions.
    * **max_episode_steps**: The max number of steps that the environment can take before truncation
    * **order_enforce**: If to enforce the order of :meth:`gymnasium.Env.reset` before :meth:`gymnasium.Env.step` and :meth:`gymnasium.Env.render` functions
    * **disable_env_checker**: If to disable the environment checker wrapper in :meth:`gymnasium.make`, by default False (runs the environment checker)
    * **kwargs**: Additional keyword arguments passed to the environment during initialisation
    * **additional_wrappers**: A tuple of additional wrappers applied to the environment (WrapperSpec)
    * **vector_entry_point**: The location of the vectorized environment to create from

    Changelogs:
        v1.0.0 - Autoreset attribute removed
    """

    id: str
    entry_point: EnvCreator | str | None = field(default=None)

    # Environment attributes
    reward_threshold: float | None = field(default=None)
    nondeterministic: bool = field(default=False)

    # Wrappers
    max_episode_steps: int | None = field(default=None)
    order_enforce: bool = field(default=True)
    disable_env_checker: bool = field(default=False)

    # Environment arguments
    kwargs: dict = field(default_factory=dict)

    # post-init attributes, derived from `id` in `__post_init__`
    namespace: str | None = field(init=False)
    name: str = field(init=False)
    version: int | None = field(init=False)

    # applied wrappers
    additional_wrappers: tuple[WrapperSpec, ...] = field(default_factory=tuple)

    # Vectorized environment entry point
    vector_entry_point: VectorEnvCreator | str | None = field(default=None)

    def __post_init__(self):
        """Calls after the spec is created to extract the namespace, name and version from the environment id."""
        self.namespace, self.name, self.version = parse_env_id(self.id)

    def make(self, **kwargs: Any) -> Env:
        """Calls ``make`` using the environment spec and any keyword arguments."""
        # Delegates to the module-level `make` function.
        return make(self, **kwargs)

    def to_json(self) -> str:
        """Converts the environment spec into a json compatible string.

        Returns:
            A jsonifyied string for the environment spec

        Raises:
            ValueError: If the spec contains a callable attribute (see `_check_can_jsonify`).
        """
        env_spec_dict = dataclasses.asdict(self)
        # As the namespace, name and version are initialised after `init` then we remove the attributes
        env_spec_dict.pop("namespace")
        env_spec_dict.pop("name")
        env_spec_dict.pop("version")

        # To check that the environment spec can be transformed to a json compatible type
        self._check_can_jsonify(env_spec_dict)

        return json.dumps(env_spec_dict)

    @staticmethod
    def _check_can_jsonify(env_spec: dict[str, Any]):
        """Raises if the spec dictionary contains a callable, as callables cannot be serialised.

        Args:
            env_spec: An environment or wrapper specification.

        Raises:
            ValueError: If any attribute value is callable.
        """
        spec_name = env_spec["name"] if "name" in env_spec else env_spec["id"]

        for key, value in env_spec.items():
            if callable(value):
                raise ValueError(
                    f"Callable found in {spec_name} for {key} attribute with value={value}. Currently, Gymnasium does not support serialising callables."
                )

    @staticmethod
    def from_json(json_env_spec: str) -> EnvSpec:
        """Converts a JSON string into a specification stack.

        Args:
            json_env_spec: A JSON string representing the env specification.

        Returns:
            An environment spec

        Raises:
            ValueError: If the JSON cannot be converted into a WrapperSpec or EnvSpec.
        """
        parsed_env_spec = json.loads(json_env_spec)

        applied_wrapper_specs: list[WrapperSpec] = []
        for wrapper_spec_json in parsed_env_spec.pop("additional_wrappers"):
            try:
                applied_wrapper_specs.append(WrapperSpec(**wrapper_spec_json))
            except Exception as e:
                raise ValueError(
                    f"An issue occurred when trying to make {wrapper_spec_json} a WrapperSpec"
                ) from e

        try:
            env_spec = EnvSpec(**parsed_env_spec)
            env_spec.additional_wrappers = tuple(applied_wrapper_specs)
        except Exception as e:
            raise ValueError(
                f"An issue occurred when trying to make {parsed_env_spec} an EnvSpec"
            ) from e

        return env_spec

    def pprint(
        self,
        disable_print: bool = False,
        include_entry_points: bool = False,
        print_all: bool = False,
    ) -> str | None:
        """Pretty prints the environment spec.

        Args:
            disable_print: If to disable print and return the output
            include_entry_points: If to include the entry_points in the output
            print_all: If to print all information, including variables with default values

        Returns:
            If ``disable_print is True`` a string otherwise ``None``
        """
        output = f"id={self.id}"
        if print_all or include_entry_points:
            output += f"\nentry_point={self.entry_point}"

        # Only non-default values are shown unless `print_all` is set.
        if print_all or self.reward_threshold is not None:
            output += f"\nreward_threshold={self.reward_threshold}"
        if print_all or self.nondeterministic is not False:
            output += f"\nnondeterministic={self.nondeterministic}"

        if print_all or self.max_episode_steps is not None:
            output += f"\nmax_episode_steps={self.max_episode_steps}"
        if print_all or self.order_enforce is not True:
            output += f"\norder_enforce={self.order_enforce}"
        if print_all or self.disable_env_checker is not False:
            output += f"\ndisable_env_checker={self.disable_env_checker}"

        if print_all or self.additional_wrappers:
            wrapper_output: list[str] = []
            for wrapper_spec in self.additional_wrappers:
                if include_entry_points:
                    wrapper_output.append(
                        f"\n\tname={wrapper_spec.name}, entry_point={wrapper_spec.entry_point}, kwargs={wrapper_spec.kwargs}"
                    )
                else:
                    wrapper_output.append(
                        f"\n\tname={wrapper_spec.name}, kwargs={wrapper_spec.kwargs}"
                    )

            if len(wrapper_output) == 0:
                output += "\nadditional_wrappers=[]"
            else:
                output += f"\nadditional_wrappers=[{','.join(wrapper_output)}\n]"

        if disable_print:
            return output
        else:
            print(output)
+
+
+
class VectorizeMode(Enum):
    """All possible vectorization modes used in `make_vec`."""

    # Vectorize the base environment asynchronously (see `make_vec`).
    ASYNC = "async"
    # Vectorize the base environment sequentially via `gymnasium.vector.SyncVectorEnv`.
    SYNC = "sync"
    # Use the spec's dedicated `vector_entry_point` to build the vector env.
    VECTOR_ENTRY_POINT = "vector_entry_point"
+
+
# Global registry of environments. Meant to be accessed through `register` and `make`
registry: dict[str, EnvSpec] = {}
# Namespace applied to `register` calls, set via the `namespace` context manager.
current_namespace: str | None = None
+
+
+
+[docs]
def parse_env_id(env_id: str) -> tuple[str | None, str, int | None]:
    """Parse environment ID string format - ``[namespace/](env-name)[-v(version)]`` where the namespace and version are optional.

    Args:
        env_id: The environment id to parse

    Returns:
        A tuple of environment namespace, environment name and version number

    Raises:
        Error: If the environment id is not valid environment regex
    """
    match = ENV_ID_RE.fullmatch(env_id)
    if match is None:
        raise error.Error(
            f"Malformed environment ID: {env_id}. (Currently all IDs must be of the form [namespace/](env-name)-v(version). (namespace is optional))"
        )

    namespace = match.group("namespace")
    name = match.group("name")
    raw_version = match.group("version")
    # A missing version group stays `None`; otherwise convert to an int.
    version = int(raw_version) if raw_version is not None else None

    return namespace, name, version
+
+
+
+
+[docs]
+def get_env_id(ns: str | None, name: str, version: int | None) -> str:
+ """Get the full env ID given a name and (optional) version and namespace. Inverse of :meth:`parse_env_id`.
+
+ Args:
+ ns: The environment namespace
+ name: The environment name
+ version: The environment version
+
+ Returns:
+ The environment id
+ """
+ full_name = name
+ if ns is not None:
+ full_name = f"{ns}/{name}"
+ if version is not None:
+ full_name = f"{full_name}-v{version}"
+
+ return full_name
+
+
+
+
+[docs]
def find_highest_version(ns: str | None, name: str) -> int | None:
    """Finds the highest registered version of the environment given the namespace and name in the registry.

    Args:
        ns: The environment namespace
        name: The environment name (id)

    Returns:
        The highest version of an environment with matching namespace and name, otherwise ``None`` is returned.
    """
    # Collect versions of specs that match on namespace and name lazily.
    matching_versions = (
        env_spec.version
        for env_spec in registry.values()
        if env_spec.namespace == ns
        and env_spec.name == name
        and env_spec.version is not None
    )
    return max(matching_versions, default=None)
+
+
+
def _check_namespace_exists(ns: str | None):
    """Check if a namespace exists. If it doesn't, print a helpful error message."""
    # A `None` namespace always "exists".
    if ns is None:
        return

    # Gather every namespace present in the registry's specs.
    known_namespaces: set[str] = {
        env_spec.namespace
        for env_spec in registry.values()
        if env_spec.namespace is not None
    }
    if ns in known_namespaces:
        return

    # Build the most helpful hint available before raising.
    close_matches = (
        difflib.get_close_matches(ns, known_namespaces, n=1)
        if len(known_namespaces) > 0
        else None
    )
    if close_matches:
        suggestion_msg = f"Did you mean: `{close_matches[0]}`?"
    else:
        suggestion_msg = f"Have you installed the proper package for {ns}?"

    raise error.NamespaceNotFound(f"Namespace {ns} not found. {suggestion_msg}")
+
+
def _check_name_exists(ns: str | None, name: str):
    """Check if an env exists in a namespace. If it doesn't, print a helpful error message."""
    # Validate the namespace first so namespace typos get the more specific error.
    _check_namespace_exists(ns)

    # All environment names registered under this namespace.
    names_in_namespace: set[str] = {
        env_spec.name for env_spec in registry.values() if env_spec.namespace == ns
    }
    if name in names_in_namespace:
        return

    close_matches = difflib.get_close_matches(name, names_in_namespace, n=1)
    namespace_msg = f" in namespace {ns}" if ns else ""
    suggestion_msg = f" Did you mean: `{close_matches[0]}`?" if close_matches else ""

    raise error.NameNotFound(
        f"Environment `{name}` doesn't exist{namespace_msg}.{suggestion_msg}"
    )
+
+
def _check_version_exists(ns: str | None, name: str, version: int | None):
    """Check if an env version exists in a namespace. If it doesn't, print a helpful error message.

    This is a complete test whether an environment identifier is valid, and will provide the best available hints.

    Args:
        ns: The environment namespace
        name: The environment space
        version: The environment version

    Raises:
        DeprecatedEnv: The environment doesn't exist but a default version does
        VersionNotFound: The ``version`` used doesn't exist
        DeprecatedEnv: Environment version is deprecated
    """
    # The exact id (including version) is registered - nothing to check.
    if get_env_id(ns, name, version) in registry:
        return

    # Raises NameNotFound / NamespaceNotFound with suggestions if the base name is unknown.
    _check_name_exists(ns, name)
    if version is None:
        return

    message = f"Environment version `v{version}` for environment `{get_env_id(ns, name, None)}` doesn't exist."

    # All specs sharing the namespace and name, sorted by version (unversioned first).
    env_specs = [
        env_spec
        for env_spec in registry.values()
        if env_spec.namespace == ns and env_spec.name == name
    ]
    env_specs = sorted(env_specs, key=lambda env_spec: int(env_spec.version or -1))

    # Specs registered without a version act as the "default" version.
    default_spec = [env_spec for env_spec in env_specs if env_spec.version is None]

    if default_spec:
        message += f" It provides the default version `{default_spec[0].id}`."
        if len(env_specs) == 1:
            raise error.DeprecatedEnv(message)

    # Process possible versioned environments

    versioned_specs = [
        env_spec for env_spec in env_specs if env_spec.version is not None
    ]

    latest_spec = max(versioned_specs, key=lambda env_spec: env_spec.version, default=None)  # type: ignore
    # Requested version is newer than anything registered -> VersionNotFound.
    if latest_spec is not None and version > latest_spec.version:
        version_list_msg = ", ".join(f"`v{env_spec.version}`" for env_spec in env_specs)
        message += f" It provides versioned environments: [ {version_list_msg} ]."

        raise error.VersionNotFound(message)

    # Requested version is older than the latest registered -> deprecated.
    if latest_spec is not None and version < latest_spec.version:
        raise error.DeprecatedEnv(
            f"Environment version v{version} for `{get_env_id(ns, name, None)}` is deprecated. "
            f"Please use `{latest_spec.id}` instead."
        )
+
+
def _check_spec_register(testing_spec: EnvSpec):
    """Checks whether the spec is valid to be registered. Helper function for `register`.

    Raises:
        RegistrationError: If a versioned spec would clash with an existing unversioned
            spec of the same name, or vice versa.
    """
    # Highest versioned spec sharing the namespace and name, if any.
    latest_versioned_spec = max(
        (
            env_spec
            for env_spec in registry.values()
            if env_spec.namespace == testing_spec.namespace
            and env_spec.name == testing_spec.name
            and env_spec.version is not None
        ),
        key=lambda spec_: int(spec_.version),  # type: ignore
        default=None,
    )

    # Unversioned spec sharing the namespace and name, if any.
    unversioned_spec = next(
        (
            env_spec
            for env_spec in registry.values()
            if env_spec.namespace == testing_spec.namespace
            and env_spec.name == testing_spec.name
            and env_spec.version is None
        ),
        None,
    )

    # Versioned and unversioned ids of the same name are mutually exclusive.
    if unversioned_spec is not None and testing_spec.version is not None:
        raise error.RegistrationError(
            "Can't register the versioned environment "
            f"`{testing_spec.id}` when the unversioned environment "
            f"`{unversioned_spec.id}` of the same name already exists."
        )
    elif latest_versioned_spec is not None and testing_spec.version is None:
        raise error.RegistrationError(
            f"Can't register the unversioned environment `{testing_spec.id}` when the versioned environment "
            f"`{latest_versioned_spec.id}` of the same name already exists. Note: the default behavior is "
            "that `gym.make` with the unversioned environment will return the latest versioned environment"
        )
+
+
def _check_metadata(testing_metadata: dict[str, Any]):
    """Check the metadata of an environment creator.

    Args:
        testing_metadata: The metadata dictionary to validate.

    Raises:
        InvalidMetadata: If the metadata is not a dictionary.
    """
    if not isinstance(testing_metadata, dict):
        # Bug fix: the message previously used `type(metadata)`, which resolved to
        # the module-level `importlib.metadata as metadata` import rather than the
        # argument being validated.
        raise error.InvalidMetadata(
            f"Expect the environment metadata to be dict, actual type: {type(testing_metadata)}"
        )

    render_modes = testing_metadata.get("render_modes")
    if render_modes is None:
        logger.warn(
            f"The environment creator metadata doesn't include `render_modes`, contains: {list(testing_metadata.keys())}"
        )
    elif not isinstance(render_modes, Iterable):
        logger.warn(
            f"Expects the environment metadata render_modes to be a Iterable, actual type: {type(render_modes)}"
        )
+
+
def _find_spec(env_id: str) -> EnvSpec:
    """Finds the registered `EnvSpec` for an id, importing any ``module:`` prefix first.

    Args:
        env_id: A string id, optionally prefixed with an unloaded module, e.g. ``"module:Env-v0"``.

    Returns:
        The registered environment spec (resolved to the latest version when none was given).

    Raises:
        ModuleNotFoundError: If the module prefix cannot be imported.
        Error: If no matching environment is registered.
    """
    # For string id's, load the environment spec from the registry then make the environment spec
    assert isinstance(env_id, str)

    # The environment name can include an unloaded module in "module:env_name" style
    # NOTE(review): ids containing more than one ":" would raise ValueError here.
    module, env_name = (None, env_id) if ":" not in env_id else env_id.split(":")
    if module is not None:
        try:
            importlib.import_module(module)
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError(
                f"{e}. Environment registration via importing a module failed. "
                f"Check whether '{module}' contains env registration and can be imported."
            ) from e

    # load the env spec from the registry
    env_spec = registry.get(env_name)

    # update env spec is not version provided, raise warning if out of date
    ns, name, version = parse_env_id(env_name)

    latest_version = find_highest_version(ns, name)
    if version is not None and latest_version is not None and latest_version > version:
        logger.deprecation(
            f"The environment {env_name} is out of date. You should consider "
            f"upgrading to version `v{latest_version}`."
        )
    if version is None and latest_version is not None:
        # No version requested: resolve to the latest registered version.
        version = latest_version
        new_env_id = get_env_id(ns, name, version)
        env_spec = registry.get(new_env_id)
        logger.warn(
            f"Using the latest versioned environment `{new_env_id}` "
            f"instead of the unversioned environment `{env_name}`."
        )

    if env_spec is None:
        # Produces the most specific error possible (bad version / name / namespace).
        _check_version_exists(ns, name, version)
        raise error.Error(
            f"No registered env with id: {env_name}. Did you register it, or import the package that registers it? Use `gymnasium.pprint_registry()` to see all of the registered environments."
        )

    return env_spec
+
+
+
+[docs]
def load_env_creator(name: str) -> EnvCreator | VectorEnvCreator:
    """Loads an environment with name of style ``"(import path):(environment name)"`` and returns the environment creation function, normally the environment class type.

    Args:
        name: The environment name

    Returns:
        The environment constructor for the given environment name.
    """
    # Split "module.path:attribute", import the module, then fetch the attribute.
    module_name, attribute_name = name.split(":")
    module = importlib.import_module(module_name)
    return getattr(module, attribute_name)
+
+
+
def register_envs(env_module: ModuleType):
    """A No-op function such that it can appear to IDEs that a module is used.

    Args:
        env_module: A module whose import registers environments; accepting it here
            stops linters/IDEs from flagging the import as unused.
    """
    pass
+
+
+
+[docs]
@contextlib.contextmanager
def namespace(ns: str):
    """Context manager for modifying the current namespace.

    Args:
        ns: The namespace applied to `register` calls made inside the block.

    Yields:
        None. On exit the previous namespace is restored, even if the body raises.
    """
    global current_namespace
    old_namespace = current_namespace
    current_namespace = ns
    try:
        yield
    finally:
        # Restore in a `finally` so a failing `register` inside the block does not
        # leak the namespace into later registrations (previously it did).
        current_namespace = old_namespace
+
+
+
+
+[docs]
def register(
    id: str,
    entry_point: EnvCreator | str | None = None,
    reward_threshold: float | None = None,
    nondeterministic: bool = False,
    max_episode_steps: int | None = None,
    order_enforce: bool = True,
    disable_env_checker: bool = False,
    additional_wrappers: tuple[WrapperSpec, ...] = (),
    vector_entry_point: VectorEnvCreator | str | None = None,
    kwargs: dict | None = None,
):
    """Registers an environment in gymnasium with an ``id`` to use with :meth:`gymnasium.make` with the ``entry_point`` being a string or callable for creating the environment.

    The ``id`` parameter corresponds to the name of the environment, with the syntax as follows:
    ``[namespace/](env_name)[-v(version)]`` where ``namespace`` and ``-v(version)`` is optional.

    It takes arbitrary keyword arguments, which are passed to the :class:`EnvSpec` ``kwargs`` parameter.

    Args:
        id: The environment id
        entry_point: The entry point for creating the environment
        reward_threshold: The reward threshold considered for an agent to have learnt the environment
        nondeterministic: If the environment is nondeterministic (even with knowledge of the initial seed and all actions, the same state cannot be reached)
        max_episode_steps: The maximum number of episodes steps before truncation. Used by the :class:`gymnasium.wrappers.TimeLimit` wrapper if not ``None``.
        order_enforce: If to enable the order enforcer wrapper to ensure users run functions in the correct order.
            If ``True``, then the :class:`gymnasium.wrappers.OrderEnforcing` is applied to the environment.
        disable_env_checker: If to disable the :class:`gymnasium.wrappers.PassiveEnvChecker` to the environment.
        additional_wrappers: Additional wrappers to apply the environment.
        vector_entry_point: The entry point for creating the vector environment
        kwargs: arbitrary keyword arguments which are passed to the environment constructor on initialisation.

    Changelogs:
        v1.0.0 - `autoreset` and `apply_api_compatibility` parameter was removed
    """
    assert (
        entry_point is not None or vector_entry_point is not None
    ), "Either `entry_point` or `vector_entry_point` (or both) must be provided"
    ns, name, version = parse_env_id(id)

    if kwargs is None:
        kwargs = dict()
    # A namespace set through the `namespace` context manager overrides the id's namespace.
    if current_namespace is not None:
        if (
            kwargs.get("namespace") is not None
            and kwargs.get("namespace") != current_namespace
        ):
            logger.warn(
                f"Custom namespace `{kwargs.get('namespace')}` is being overridden by namespace `{current_namespace}`. "
                f"If you are developing a plugin you shouldn't specify a namespace in `register` calls. "
                "The namespace is specified through the entry point package metadata."
            )
        ns_id = current_namespace
    else:
        ns_id = ns
    full_env_id = get_env_id(ns_id, name, version)

    new_spec = EnvSpec(
        id=full_env_id,
        entry_point=entry_point,
        reward_threshold=reward_threshold,
        nondeterministic=nondeterministic,
        max_episode_steps=max_episode_steps,
        order_enforce=order_enforce,
        disable_env_checker=disable_env_checker,
        kwargs=kwargs,
        additional_wrappers=additional_wrappers,
        vector_entry_point=vector_entry_point,
    )
    # Guard against mixing versioned and unversioned ids of the same name.
    _check_spec_register(new_spec)

    if new_spec.id in registry:
        logger.warn(f"Overriding environment {new_spec.id} already in registry.")
    registry[new_spec.id] = new_spec
+
+
+
+
+[docs]
def make(
    id: str | EnvSpec,
    max_episode_steps: int | None = None,
    disable_env_checker: bool | None = None,
    **kwargs: Any,
) -> Env:
    """Creates an environment previously registered with :meth:`gymnasium.register` or a :class:`EnvSpec`.

    To find all available environments use ``gymnasium.envs.registry.keys()`` for all valid ids.

    Args:
        id: A string for the environment id or a :class:`EnvSpec`. Optionally if using a string, a module to import can be included, e.g. ``'module:Env-v0'``.
            This is equivalent to importing the module first to register the environment followed by making the environment.
        max_episode_steps: Maximum length of an episode, can override the registered :class:`EnvSpec` ``max_episode_steps``
            with the value being passed to :class:`gymnasium.wrappers.TimeLimit`.
            Using ``max_episode_steps=-1`` will not apply the wrapper to the environment.
        disable_env_checker: If to add :class:`gymnasium.wrappers.PassiveEnvChecker`, ``None`` will default to the
            :class:`EnvSpec` ``disable_env_checker`` value otherwise use this value will be used.
        kwargs: Additional arguments to pass to the environment constructor.

    Returns:
        An instance of the environment with wrappers applied.

    Raises:
        Error: If the ``id`` doesn't exist in the :attr:`registry`

    Changelogs:
        v1.0.0 - `autoreset` and `apply_api_compatibility` was removed
    """
    if isinstance(id, EnvSpec):
        env_spec = id
        if not hasattr(env_spec, "additional_wrappers"):
            logger.warn(
                f"The env spec passed to `make` does not have a `additional_wrappers`, set it to an empty tuple. Env_spec={env_spec}"
            )
            env_spec.additional_wrappers = ()
    else:
        # For string id's, load the environment spec from the registry then make the environment spec
        assert isinstance(id, str)

        # The environment name can include an unloaded module in "module:env_name" style
        env_spec = _find_spec(id)

    assert isinstance(env_spec, EnvSpec)

    # Update the env spec kwargs with the `make` kwargs
    env_spec_kwargs = copy.deepcopy(env_spec.kwargs)
    env_spec_kwargs.update(kwargs)

    # Load the environment creator
    if env_spec.entry_point is None:
        raise error.Error(f"{env_spec.id} registered but entry_point is not specified")
    elif callable(env_spec.entry_point):
        env_creator = env_spec.entry_point
    else:
        # Assume it's a string
        env_creator = load_env_creator(env_spec.entry_point)

    # Determine if to use the rendering
    render_modes: list[str] | None = None
    if hasattr(env_creator, "metadata"):
        _check_metadata(env_creator.metadata)
        render_modes = env_creator.metadata.get("render_modes")
    render_mode = env_spec_kwargs.get("render_mode")
    apply_human_rendering = False
    apply_render_collection = False

    # If mode is not valid, try applying HumanRendering/RenderCollection wrappers
    if (
        render_mode is not None
        and render_modes is not None
        and render_mode not in render_modes
    ):
        displayable_modes = {"rgb_array", "rgb_array_list"}.intersection(render_modes)
        if render_mode == "human" and len(displayable_modes) > 0:
            logger.warn(
                "You are trying to use 'human' rendering for an environment that doesn't natively support it. "
                "The HumanRendering wrapper is being applied to your environment."
            )
            env_spec_kwargs["render_mode"] = displayable_modes.pop()
            apply_human_rendering = True
        elif (
            render_mode.endswith("_list")
            and render_mode[: -len("_list")] in render_modes
        ):
            # "..._list" modes are served by rendering the base mode and collecting frames.
            env_spec_kwargs["render_mode"] = render_mode[: -len("_list")]
            apply_render_collection = True
        else:
            logger.warn(
                f"The environment is being initialised with render_mode={render_mode!r} "
                f"that is not in the possible render_modes ({render_modes})."
            )

    try:
        env = env_creator(**env_spec_kwargs)
    except TypeError as e:
        if (
            str(e).find("got an unexpected keyword argument 'render_mode'") >= 0
            and apply_human_rendering
        ):
            raise error.Error(
                f"You passed render_mode='human' although {env_spec.id} doesn't implement human-rendering natively. "
                "Gym tried to apply the HumanRendering wrapper but it looks like your environment is using the old "
                "rendering API, which is not supported by the HumanRendering wrapper."
            ) from e
        else:
            # NOTE(review): re-raised without `from e`; the implicit exception
            # context is still attached by Python during exception handling.
            raise type(e)(
                f"{e} was raised from the environment creator for {env_spec.id} with kwargs ({env_spec_kwargs})"
            )

    if not isinstance(env, gym.Env):
        # Distinguish legacy `gym` environments from arbitrary non-Env objects.
        if (
            str(env.__class__.__base__) == "<class 'gym.core.Env'>"
            or str(env.__class__.__base__) == "<class 'gym.core.Wrapper'>"
        ):
            raise TypeError(
                "Gym is incompatible with Gymnasium, please update the environment class to `gymnasium.Env`. "
                "See https://gymnasium.farama.org/introduction/create_custom_env/ for more info."
            )
        else:
            raise TypeError(
                f"The environment must inherit from the gymnasium.Env class, actual class: {type(env)}. "
                "See https://gymnasium.farama.org/introduction/create_custom_env/ for more info."
            )

    # Set the minimal env spec for the environment.
    env.unwrapped.spec = EnvSpec(
        id=env_spec.id,
        entry_point=env_spec.entry_point,
        reward_threshold=env_spec.reward_threshold,
        nondeterministic=env_spec.nondeterministic,
        max_episode_steps=None,
        order_enforce=False,
        disable_env_checker=True,
        kwargs=env_spec_kwargs,
        additional_wrappers=(),
        vector_entry_point=env_spec.vector_entry_point,
    )

    # Check if pre-wrapped wrappers
    assert env.spec is not None
    num_prior_wrappers = len(env.spec.additional_wrappers)
    if (
        env_spec.additional_wrappers[:num_prior_wrappers]
        != env.spec.additional_wrappers
    ):
        # NOTE(review): this raises on the first zipped pair even when that
        # particular pair is equal — confirm whether a per-pair inequality
        # check was intended before raising.
        for env_spec_wrapper_spec, recreated_wrapper_spec in zip(
            env_spec.additional_wrappers, env.spec.additional_wrappers
        ):
            raise ValueError(
                f"The environment's wrapper spec {recreated_wrapper_spec} is different from the saved `EnvSpec` additional wrapper {env_spec_wrapper_spec}"
            )

    # Run the environment checker as the lowest level wrapper
    if disable_env_checker is False or (
        disable_env_checker is None and env_spec.disable_env_checker is False
    ):
        env = gym.wrappers.PassiveEnvChecker(env)

    # Add the order enforcing wrapper
    if env_spec.order_enforce:
        env = gym.wrappers.OrderEnforcing(env)

    # Add the time limit wrapper
    if max_episode_steps != -1:
        if max_episode_steps is not None:
            env = gym.wrappers.TimeLimit(env, max_episode_steps)
        elif env_spec.max_episode_steps is not None:
            env = gym.wrappers.TimeLimit(env, env_spec.max_episode_steps)

    # Recreate any spec-recorded wrappers that are not already applied.
    for wrapper_spec in env_spec.additional_wrappers[num_prior_wrappers:]:
        if wrapper_spec.kwargs is None:
            raise ValueError(
                f"{wrapper_spec.name} wrapper does not inherit from `gymnasium.utils.RecordConstructorArgs`, therefore, the wrapper cannot be recreated."
            )

        env = load_env_creator(wrapper_spec.entry_point)(env=env, **wrapper_spec.kwargs)

    # Add human rendering wrapper
    if apply_human_rendering:
        env = gym.wrappers.HumanRendering(env)
    elif apply_render_collection:
        env = gym.wrappers.RenderCollection(env)

    return env
+
+
+
+
+[docs]
def make_vec(
    id: str | EnvSpec,
    num_envs: int = 1,
    vectorization_mode: VectorizeMode | str | None = None,
    vector_kwargs: dict[str, Any] | None = None,
    wrappers: Sequence[Callable[[Env], Wrapper]] | None = None,
    **kwargs,
) -> gym.vector.VectorEnv:
    """Create a vector environment according to the given ID.

    To find all available environments use :func:`gymnasium.pprint_registry` or ``gymnasium.registry.keys()`` for all valid ids.
    We refer to the Vector environment as the vectorizer while the environment being vectorized is the base or vectorized environment (``vectorizer(vectorized env)``).

    Args:
        id: Name of the environment. Optionally, a module to import can be included, e.g. 'module:Env-v0'
        num_envs: Number of environments to create
        vectorization_mode: The vectorization method used, defaults to ``None`` such that if env id's spec has a ``vector_entry_point`` (not ``None``),
            this is first used otherwise defaults to ``sync`` to use the :class:`gymnasium.vector.SyncVectorEnv`.
            Valid modes are ``"async"``, ``"sync"`` or ``"vector_entry_point"``. Recommended to use the :class:`VectorizeMode` enum rather than strings.
        vector_kwargs: Additional arguments to pass to the vectorizer environment constructor, i.e., ``SyncVectorEnv(..., **vector_kwargs)``.
        wrappers: A sequence of wrapper functions to apply to the base environment. Can only be used in ``"sync"`` or ``"async"`` mode.
        **kwargs: Additional arguments passed to the base environment constructor.

    Returns:
        An instance of the environment.

    Raises:
        Error: If the ``id`` doesn't exist then an error is raised
    """
    if isinstance(id, EnvSpec):
        env_spec = id
    elif isinstance(id, str):
        env_spec = _find_spec(id)
    else:
        raise error.Error(f"Invalid id type: {type(id)}. Expected `str` or `EnvSpec`")

    env_spec = copy.deepcopy(env_spec)
    env_spec_kwargs = env_spec.kwargs
    # for sync or async, these parameters should be passed in `make(..., **kwargs)` rather than in the env spec kwargs, therefore, we `reset` the kwargs
    env_spec.kwargs = dict()

    # Arguments stored in the spec kwargs take precedence over this call's arguments.
    num_envs = env_spec_kwargs.pop("num_envs", num_envs)
    vectorization_mode = env_spec_kwargs.pop("vectorization_mode", vectorization_mode)
    vector_kwargs = env_spec_kwargs.pop("vector_kwargs", vector_kwargs)
    wrappers = env_spec_kwargs.pop("wrappers", wrappers)

    # Normalise the optional containers AFTER the pops so that a spec that
    # stored ``None`` for either argument is handled identically to the caller
    # passing ``None`` (previously a spec-stored ``None`` slipped through).
    if vector_kwargs is None:
        vector_kwargs = {}
    if wrappers is None:
        wrappers = []

    env_spec_kwargs.update(kwargs)

    # Specify the vectorization mode if None or update to a `VectorizeMode`
    if vectorization_mode is None:
        if env_spec.vector_entry_point is not None:
            vectorization_mode = VectorizeMode.VECTOR_ENTRY_POINT
        else:
            vectorization_mode = VectorizeMode.SYNC
    else:
        try:
            vectorization_mode = VectorizeMode(vectorization_mode)
        except ValueError:
            raise ValueError(
                f"Invalid vectorization mode: {vectorization_mode!r}, "
                f"valid modes: {[mode.value for mode in VectorizeMode]}"
            )
    assert isinstance(vectorization_mode, VectorizeMode)

    def create_single_env() -> Env:
        """Create one base environment and apply the user wrappers in order."""
        # Copy the kwargs so a created environment cannot mutate the shared dict.
        single_env = make(env_spec, **env_spec_kwargs.copy())
        for wrapper in wrappers:
            single_env = wrapper(single_env)
        return single_env

    if vectorization_mode == VectorizeMode.SYNC:
        if env_spec.entry_point is None:
            raise error.Error(
                f"Cannot create vectorized environment for {env_spec.id} because it doesn't have an entry point defined."
            )

        env = gym.vector.SyncVectorEnv(
            env_fns=(create_single_env for _ in range(num_envs)),
            **vector_kwargs,
        )
    elif vectorization_mode == VectorizeMode.ASYNC:
        if env_spec.entry_point is None:
            raise error.Error(
                f"Cannot create vectorized environment for {env_spec.id} because it doesn't have an entry point defined."
            )

        env = gym.vector.AsyncVectorEnv(
            env_fns=[create_single_env for _ in range(num_envs)],
            **vector_kwargs,
        )

    elif vectorization_mode == VectorizeMode.VECTOR_ENTRY_POINT:
        # A custom vectorizer takes all its configuration through **kwargs, so
        # the sync/async-only arguments must be empty.
        if len(vector_kwargs) > 0:
            raise error.Error(
                f"Custom vector environment can be passed arguments only through kwargs and `vector_kwargs` is not empty ({vector_kwargs})"
            )
        elif len(wrappers) > 0:
            raise error.Error(
                f"Cannot use `vector_entry_point` vectorization mode with the wrappers argument ({wrappers})."
            )
        elif len(env_spec.additional_wrappers) > 0:
            raise error.Error(
                f"Cannot use `vector_entry_point` vectorization mode with the additional_wrappers parameter in spec being not empty ({env_spec.additional_wrappers})."
            )

        entry_point = env_spec.vector_entry_point
        if entry_point is None:
            raise error.Error(
                f"Cannot create vectorized environment for {id} because it doesn't have a vector entry point defined."
            )
        elif callable(entry_point):
            env_creator = entry_point
        else:  # Assume it's a string
            env_creator = load_env_creator(entry_point)

        # Forward the spec's episode limit unless the caller already set one.
        if (
            env_spec.max_episode_steps is not None
            and "max_episode_steps" not in env_spec_kwargs
        ):
            env_spec_kwargs["max_episode_steps"] = env_spec.max_episode_steps

        env = env_creator(num_envs=num_envs, **env_spec_kwargs)
    else:
        raise error.Error(f"Unknown vectorization mode: {vectorization_mode}")

    # Copies the environment creation specification and kwargs to add to the environment specification details
    copied_id_spec = copy.deepcopy(env_spec)
    copied_id_spec.kwargs = env_spec_kwargs.copy()
    if num_envs != 1:
        copied_id_spec.kwargs["num_envs"] = num_envs
    copied_id_spec.kwargs["vectorization_mode"] = vectorization_mode.value
    if len(vector_kwargs) > 0:
        copied_id_spec.kwargs["vector_kwargs"] = vector_kwargs
    if len(wrappers) > 0:
        copied_id_spec.kwargs["wrappers"] = wrappers
    env.unwrapped.spec = copied_id_spec

    # Warn when the vectorizer does not declare how it autoresets sub-envs.
    if "autoreset_mode" not in env.metadata:
        warn(
            f"The VectorEnv ({env}) is missing AutoresetMode metadata, metadata={env.metadata}"
        )
    elif not isinstance(env.metadata["autoreset_mode"], AutoresetMode):
        warn(
            f"The VectorEnv ({env}) metadata['autoreset_mode'] is not an instance of AutoresetMode, {type(env.metadata['autoreset_mode'])}."
        )

    return env
+
+
+
+
+[docs]
def spec(env_id: str) -> EnvSpec:
    """Look up the :class:`EnvSpec` registered in :attr:`registry` under ``env_id``.

    Args:
        env_id: The environment id with the expected format of ``[(namespace)/]id[-v(version)]``

    Returns:
        The environment spec if it exists

    Raises:
        Error: If the environment id doesn't exist
    """
    env_spec = registry.get(env_id)
    if env_spec is not None:
        assert isinstance(
            env_spec, EnvSpec
        ), f"Expected the registry for {env_id} to be an `EnvSpec`, actual type is {type(env_spec)}"
        return env_spec

    # Produce the most helpful error possible: a bad namespace or version raises
    # its own specific message; otherwise fall through to the generic missing-id error.
    ns, name, version = parse_env_id(env_id)
    _check_version_exists(ns, name, version)
    raise error.Error(f"No registered env with id: {env_id}")
+
+
+
+
+[docs]
def pprint_registry(
    print_registry: dict[str, EnvSpec] = registry,
    *,
    num_cols: int = 3,
    exclude_namespaces: list[str] | None = None,
    disable_print: bool = False,
) -> str | None:
    """Pretty prints all environments in the :attr:`registry`.

    Note:
        All arguments are keyword only

    Args:
        print_registry: Environment registry to be printed. By default, :attr:`registry`
        num_cols: Number of columns to arrange environments in, for display.
        exclude_namespaces: A list of namespaces to be excluded from printing. Helpful if only ALE environments are wanted.
        disable_print: Whether to return a string of all the namespaces and environment IDs
            or to print the string to console.

    Returns:
        The formatted registry string when ``disable_print`` is True, otherwise ``None``.
    """
    # Defaultdict to store environment ids according to namespace.
    namespace_envs: dict[str, list[str]] = defaultdict(list)
    # Widest column needed; stays -inf only if print_registry is empty, in
    # which case the loops below never use it.
    max_justify = float("-inf")

    # Find the namespace associated with each environment spec
    for env_spec in print_registry.values():
        ns = env_spec.namespace

        if ns is None and isinstance(env_spec.entry_point, str):
            # Use regex to obtain namespace from entrypoints (strip ":ClassName").
            env_entry_point = re.sub(r":\w+", "", env_spec.entry_point)
            split_entry_point = env_entry_point.split(".")

            if len(split_entry_point) >= 3:
                # If namespace is of the format:
                #  - gymnasium.envs.mujoco.ant_v4:AntEnv
                #  - gymnasium.envs.mujoco:HumanoidEnv
                ns = split_entry_point[2]
            elif len(split_entry_point) > 1:
                # If namespace is of the format - shimmy.atari_env
                ns = split_entry_point[1]
            else:
                # If namespace cannot be found, default to env name
                ns = env_spec.name

        namespace_envs[ns].append(env_spec.id)
        # NOTE(review): the column width uses len(env_spec.name) while the
        # printed text below is env_spec.id (which also contains namespace and
        # version) — columns can overflow; confirm this is intended.
        max_justify = max(max_justify, len(env_spec.name))

    # Iterate through each namespace and print environment alphabetically
    output: list[str] = []
    for ns, env_ids in namespace_envs.items():
        # Ignore namespaces to exclude.
        if exclude_namespaces is not None and ns in exclude_namespaces:
            continue

        # Print the namespace header, e.g. "===== classic_control ====="
        namespace_output = f"{'=' * 5} {ns} {'=' * 5}\n"

        # Reference: https://stackoverflow.com/a/33464001
        for count, env_id in enumerate(sorted(env_ids), 1):
            # Print column with justification.
            namespace_output += env_id.ljust(max_justify) + " "

            # Once all rows printed, switch to new column.
            if count % num_cols == 0:
                namespace_output = namespace_output.rstrip(" ")

                if count != len(env_ids):
                    namespace_output += "\n"

        output.append(namespace_output.rstrip(" "))

    if disable_print:
        return "\n".join(output)
    else:
        print("\n".join(output))
+
+
+"""Base class and definitions for an alternative, functional backend for gym envs, particularly suitable for hardware accelerated and otherwise transformed environments."""
+
+from __future__ import annotations
+
+from collections.abc import Callable
+from typing import Any, Generic, TypeVar
+
+import numpy as np
+
+from gymnasium import Space
+
+
# Type variables naming the moving parts of a functional environment.
StateType = TypeVar("StateType")
ActType = TypeVar("ActType")
ObsType = TypeVar("ObsType")
RewardType = TypeVar("RewardType")
TerminalType = TypeVar("TerminalType")
RenderStateType = TypeVar("RenderStateType")
Params = TypeVar("Params")


class FuncEnv(
    Generic[
        StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType, Params
    ]
):
    """Template base class for functional (stateless) environments.

    The environment state is passed around explicitly rather than held on the
    instance, so each piece of the POMDP is its own instance method:

    * ``initial`` - the starting state of the POMDP
    * ``observation`` - the observation for a given state
    * ``transition`` - the next state after taking an action in a state
    * ``reward`` - the reward for a (state, action, next_state) tuple
    * ``terminal`` - whether a given state is terminal
    * ``state_info`` / ``transition_info`` - optional info dictionaries

    Nothing prevents stateful use, it is simply not recommended. The class
    structure exists so environment constants can be declared on the class and
    referenced by name in the methods. For the moment this API is predominantly
    for internal use and is likely to change before being officially exposed.
    """

    # Spaces are expected to be assigned by subclasses.
    observation_space: Space
    action_space: Space

    def __init__(self, options: dict[str, Any] | None = None):
        """Copy ``options`` onto the instance as constants and resolve default params."""
        self.__dict__.update(options or {})
        self.default_params = self.get_default_params()

    def initial(self, rng: Any, params: Params | None = None) -> StateType:
        """Generate the initial environment state using the random number generator."""
        raise NotImplementedError

    def transition(
        self, state: StateType, action: ActType, rng: Any, params: Params | None = None
    ) -> StateType:
        """Compute the successor state for ``action`` taken in ``state``."""
        raise NotImplementedError

    def observation(
        self, state: StateType, rng: Any, params: Params | None = None
    ) -> ObsType:
        """Produce the observation associated with ``state``."""
        raise NotImplementedError

    def reward(
        self,
        state: StateType,
        action: ActType,
        next_state: StateType,
        rng: Any,
        params: Params | None = None,
    ) -> RewardType:
        """Compute the reward for the ``state`` -> ``next_state`` transition under ``action``."""
        raise NotImplementedError

    def terminal(
        self, state: StateType, rng: Any, params: Params | None = None
    ) -> TerminalType:
        """Report whether ``state`` is a final (terminal) state."""
        raise NotImplementedError

    def state_info(self, state: StateType, params: Params | None = None) -> dict:
        """Optional info dict for a single state; empty by default."""
        return {}

    def transition_info(
        self,
        state: StateType,
        action: ActType,
        next_state: StateType,
        params: Params | None = None,
    ) -> dict:
        """Optional info dict for a full transition; empty by default."""
        return {}

    def transform(self, func: Callable[[Callable], Callable]):
        """Rebind each functional piece to ``func`` applied to it (e.g. jit/vmap)."""
        for piece in (
            "initial",
            "transition",
            "observation",
            "reward",
            "terminal",
            "state_info",
        ):
            setattr(self, piece, func(getattr(self, piece)))
        # Mirrors the original behavior: the wrapped `transition_info` is
        # stored under the `step_info` attribute; `transition_info` itself is
        # left unwrapped.
        self.step_info = func(self.transition_info)

    def render_image(
        self,
        state: StateType,
        render_state: RenderStateType,
        params: Params | None = None,
    ) -> tuple[RenderStateType, np.ndarray]:
        """Render ``state`` to an image, returning the updated render state as well."""
        raise NotImplementedError

    def render_init(self, params: Params | None = None, **kwargs) -> RenderStateType:
        """Create and return the initial render state."""
        raise NotImplementedError

    def render_close(self, render_state: RenderStateType, params: Params | None = None):
        """Dispose of ``render_state``."""
        raise NotImplementedError

    def get_default_params(self, **kwargs) -> Params | None:
        """Return the default params (``None`` unless overridden)."""
        return None
+
+
+"""Implementation of a space that represents closed boxes in euclidean space."""
+
+from __future__ import annotations
+
+from collections.abc import Iterable, Mapping, Sequence
+from typing import Any, SupportsFloat
+
+import numpy as np
+from numpy.typing import NDArray
+
+import gymnasium as gym
+from gymnasium.spaces.space import Space
+
+
def array_short_repr(arr: NDArray[Any]) -> str:
    """Return a compact string form of a numpy array.

    When every element of a non-empty array holds the same value, only that
    scalar is rendered; otherwise the full array representation is returned.

    Args:
        arr: The array to represent

    Returns:
        A short representation of the array
    """
    if arr.size == 0:
        return str(arr)
    smallest, largest = np.min(arr), np.max(arr)
    return str(smallest) if smallest == largest else str(arr)
+
+
def is_float_integer(var: Any) -> bool:
    """Check whether a scalar is a (numpy or builtin) integer or float — bool is excluded."""
    scalar_type = type(var)
    if np.issubdtype(scalar_type, np.integer):
        return True
    return np.issubdtype(scalar_type, np.floating)
+
+
+
+[docs]
class Box(Space[NDArray[Any]]):
    r"""A (possibly unbounded) box in :math:`\mathbb{R}^n`.

    Specifically, a Box represents the Cartesian product of n closed intervals.
    Each interval has the form of one of :math:`[a, b]`, :math:`(-\infty, b]`,
    :math:`[a, \infty)`, or :math:`(-\infty, \infty)`.

    There are two common use cases:

    * Identical bound for each dimension::

        >>> Box(low=-1.0, high=2.0, shape=(3, 4), dtype=np.float32)
        Box(-1.0, 2.0, (3, 4), float32)

    * Independent bound for each dimension::

        >>> Box(low=np.array([-1.0, -2.0]), high=np.array([2.0, 4.0]), dtype=np.float32)
        Box([-1. -2.], [2. 4.], (2,), float32)
    """

    def __init__(
        self,
        low: SupportsFloat | NDArray[Any],
        high: SupportsFloat | NDArray[Any],
        shape: Sequence[int] | None = None,
        dtype: type[np.floating[Any]] | type[np.integer[Any]] = np.float32,
        seed: int | np.random.Generator | None = None,
    ):
        r"""Constructor of :class:`Box`.

        The argument ``low`` specifies the lower bound of each dimension and ``high`` specifies the upper bounds.
        I.e., the space that is constructed will be the product of the intervals :math:`[\text{low}[i], \text{high}[i]]`.

        If ``low`` (or ``high``) is a scalar, the lower bound (or upper bound, respectively) will be assumed to be
        this value across all dimensions.

        Args:
            low (SupportsFloat | np.ndarray): Lower bounds of the intervals. If integer, must be at least ``-2**63``.
            high (SupportsFloat | np.ndarray]): Upper bounds of the intervals. If integer, must be at most ``2**63 - 2``.
            shape (Optional[Sequence[int]]): The shape is inferred from the shape of `low` or `high` `np.ndarray`s with
                `low` and `high` scalars defaulting to a shape of (1,)
            dtype: The dtype of the elements of the space. If this is an integer type, the :class:`Box` is essentially a discrete space.
            seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.

        Raises:
            ValueError: If no shape information is provided (shape is None, low is None and high is None) then a
                value error is raised.
        """
        # determine dtype
        if dtype is None:
            raise ValueError("Box dtype must be explicitly provided, cannot be None.")
        self.dtype = np.dtype(dtype)

        # * check that dtype is an accepted dtype (integer, floating, or bool)
        if not (
            np.issubdtype(self.dtype, np.integer)
            or np.issubdtype(self.dtype, np.floating)
            or self.dtype == np.bool_
        ):
            raise ValueError(
                f"Invalid Box dtype ({self.dtype}), must be an integer, floating, or bool dtype"
            )

        # determine shape: an explicit `shape` wins, otherwise it is inferred
        # from whichever of low/high is an ndarray, else scalars give (1,)
        if shape is not None:
            if not isinstance(shape, Iterable):
                raise TypeError(
                    f"Expected Box shape to be an iterable, actual type={type(shape)}"
                )
            elif not all(np.issubdtype(type(dim), np.integer) for dim in shape):
                raise TypeError(
                    f"Expected all Box shape elements to be integer, actual type={tuple(type(dim) for dim in shape)}"
                )

            # Casts the `shape` argument to tuple[int, ...] (otherwise dim can `np.int64`)
            shape = tuple(int(dim) for dim in shape)
        elif isinstance(low, np.ndarray) and isinstance(high, np.ndarray):
            if low.shape != high.shape:
                raise ValueError(
                    f"Box low.shape and high.shape don't match, low.shape={low.shape}, high.shape={high.shape}"
                )
            shape = low.shape
        elif isinstance(low, np.ndarray):
            shape = low.shape
        elif isinstance(high, np.ndarray):
            shape = high.shape
        elif is_float_integer(low) and is_float_integer(high):
            shape = (1,)  # low and high are scalars
        else:
            raise ValueError(
                "Box shape is not specified, therefore inferred from low and high. Expected low and high to be np.ndarray, integer, or float."
                f"Actual types low={type(low)}, high={type(high)}"
            )
        self._shape: tuple[int, ...] = shape

        # Cast scalar values to `np.ndarray` and capture the boundedness information
        # disallowed cases
        # * out of range - this must be done before casting to low and high otherwise, the value is within dtype and cannot be out of range
        # * nan - must be done beforehand as int dtype can cast `nan` to another value
        # * unsigned int inf and -inf - special case that is disallowed

        if self.dtype == np.bool_:
            dtype_min, dtype_max = 0, 1
        elif np.issubdtype(self.dtype, np.floating):
            dtype_min = float(np.finfo(self.dtype).min)
            dtype_max = float(np.finfo(self.dtype).max)
        else:
            dtype_min = int(np.iinfo(self.dtype).min)
            dtype_max = int(np.iinfo(self.dtype).max)

        # Cast `low` and `high` to ndarray for the dtype min and max for out of range tests
        self.low, self.bounded_below = self._cast_low(low, dtype_min)
        self.high, self.bounded_above = self._cast_high(high, dtype_max)

        # recheck shape for case where shape and (low or high) are provided
        if self.low.shape != shape:
            raise ValueError(
                f"Box low.shape doesn't match provided shape, low.shape={self.low.shape}, shape={self.shape}"
            )
        if self.high.shape != shape:
            raise ValueError(
                f"Box high.shape doesn't match provided shape, high.shape={self.high.shape}, shape={self.shape}"
            )

        # check that low <= high elementwise
        if np.any(self.low > self.high):
            raise ValueError(
                f"Box all low values must be less than or equal to high (some values break this), low={self.low}, high={self.high}"
            )

        # Cache the short reprs used by __repr__.
        self.low_repr = array_short_repr(self.low)
        self.high_repr = array_short_repr(self.high)

        super().__init__(self.shape, self.dtype, seed)

    def _cast_low(self, low, dtype_min) -> tuple[np.ndarray, np.ndarray]:
        """Casts the input Box low value to ndarray with provided dtype.

        Args:
            low: The input box low value
            dtype_min: The dtype's minimum value

        Returns:
            The updated low value and for what values the input is bounded (below)
        """
        if is_float_integer(low):
            # Scalar case: boundedness is uniform across the whole shape.
            bounded_below = -np.inf < np.full(self.shape, low, dtype=float)

            if np.isnan(low):
                raise ValueError(f"No low value can be equal to `np.nan`, low={low}")
            elif np.isneginf(low):
                if self.dtype.kind == "i":  # signed int: clamp -inf to dtype min
                    low = dtype_min
                elif self.dtype.kind in {"u", "b"}:  # unsigned int and bool
                    raise ValueError(
                        f"Box unsigned int dtype don't support `-np.inf`, low={low}"
                    )
            elif low < dtype_min:
                raise ValueError(
                    f"Box low is out of bounds of the dtype range, low={low}, min dtype={dtype_min}"
                )

            low = np.full(self.shape, low, dtype=self.dtype)
            return low, bounded_below
        else:  # cast for low - array
            if not isinstance(low, np.ndarray):
                raise ValueError(
                    f"Box low must be a np.ndarray, integer, or float, actual type={type(low)}"
                )
            elif not (
                np.issubdtype(low.dtype, np.floating)
                or np.issubdtype(low.dtype, np.integer)
                or low.dtype == np.bool_
            ):
                raise ValueError(
                    f"Box low must be a floating, integer, or bool dtype, actual dtype={low.dtype}"
                )
            elif np.any(np.isnan(low)):
                raise ValueError(f"No low value can be equal to `np.nan`, low={low}")

            bounded_below = -np.inf < low

            if np.any(np.isneginf(low)):
                if self.dtype.kind == "i":  # signed int: clamp -inf entries to dtype min
                    low[np.isneginf(low)] = dtype_min
                elif self.dtype.kind in {"u", "b"}:  # unsigned int and bool
                    raise ValueError(
                        f"Box unsigned int dtype don't support `-np.inf`, low={low}"
                    )
            elif low.dtype != self.dtype and np.any(low < dtype_min):
                raise ValueError(
                    f"Box low is out of bounds of the dtype range, low={low}, min dtype={dtype_min}"
                )

            # Warn when the cast narrows float precision (e.g. float64 -> float32).
            if (
                np.issubdtype(low.dtype, np.floating)
                and np.issubdtype(self.dtype, np.floating)
                and np.finfo(self.dtype).precision < np.finfo(low.dtype).precision
            ):
                gym.logger.warn(
                    f"Box low's precision lowered by casting to {self.dtype}, current low.dtype={low.dtype}"
                )
            return low.astype(self.dtype), bounded_below

    def _cast_high(self, high, dtype_max) -> tuple[np.ndarray, np.ndarray]:
        """Casts the input Box high value to ndarray with provided dtype.

        Args:
            high: The input box high value
            dtype_max: The dtype's maximum value

        Returns:
            The updated high value and for what values the input is bounded (above)
        """
        if is_float_integer(high):
            # Scalar case: boundedness is uniform across the whole shape.
            bounded_above = np.full(self.shape, high, dtype=float) < np.inf

            if np.isnan(high):
                raise ValueError(f"No high value can be equal to `np.nan`, high={high}")
            elif np.isposinf(high):
                if self.dtype.kind == "i":  # signed int: clamp +inf to dtype max
                    high = dtype_max
                elif self.dtype.kind in {"u", "b"}:  # unsigned int and bool
                    raise ValueError(
                        f"Box unsigned int dtype don't support `np.inf`, high={high}"
                    )
            elif high > dtype_max:
                raise ValueError(
                    f"Box high is out of bounds of the dtype range, high={high}, max dtype={dtype_max}"
                )

            high = np.full(self.shape, high, dtype=self.dtype)
            return high, bounded_above
        else:
            if not isinstance(high, np.ndarray):
                raise ValueError(
                    f"Box high must be a np.ndarray, integer, or float, actual type={type(high)}"
                )
            elif not (
                np.issubdtype(high.dtype, np.floating)
                or np.issubdtype(high.dtype, np.integer)
                or high.dtype == np.bool_
            ):
                raise ValueError(
                    f"Box high must be a floating or integer dtype, actual dtype={high.dtype}"
                )
            elif np.any(np.isnan(high)):
                raise ValueError(f"No high value can be equal to `np.nan`, high={high}")

            bounded_above = high < np.inf

            posinf = np.isposinf(high)
            if np.any(posinf):
                if self.dtype.kind == "i":  # signed int: clamp +inf entries to dtype max
                    high[posinf] = dtype_max
                elif self.dtype.kind in {"u", "b"}:  # unsigned int and bool
                    raise ValueError(
                        f"Box unsigned int dtype don't support `np.inf`, high={high}"
                    )
            elif high.dtype != self.dtype and np.any(dtype_max < high):
                raise ValueError(
                    f"Box high is out of bounds of the dtype range, high={high}, max dtype={dtype_max}"
                )

            # Warn when the cast narrows float precision (e.g. float64 -> float32).
            if (
                np.issubdtype(high.dtype, np.floating)
                and np.issubdtype(self.dtype, np.floating)
                and np.finfo(self.dtype).precision < np.finfo(high.dtype).precision
            ):
                gym.logger.warn(
                    f"Box high's precision lowered by casting to {self.dtype}, current high.dtype={high.dtype}"
                )
            return high.astype(self.dtype), bounded_above

    @property
    def shape(self) -> tuple[int, ...]:
        """Has stricter type than gym.Space - never None."""
        return self._shape

    @property
    def is_np_flattenable(self):
        """Checks whether this space can be flattened to a :class:`spaces.Box`."""
        return True

    def is_bounded(self, manner: str = "both") -> bool:
        """Checks whether the box is bounded in some sense.

        Args:
            manner (str): One of ``"both"``, ``"below"``, ``"above"``.

        Returns:
            If the space is bounded

        Raises:
            ValueError: If `manner` is neither ``"both"`` nor ``"below"`` or ``"above"``
        """
        below = bool(np.all(self.bounded_below))
        above = bool(np.all(self.bounded_above))
        if manner == "both":
            return below and above
        elif manner == "below":
            return below
        elif manner == "above":
            return above
        else:
            raise ValueError(
                f"manner is not in {{'below', 'above', 'both'}}, actual value: {manner}"
            )

    def sample(self, mask: None = None, probability: None = None) -> NDArray[Any]:
        r"""Generates a single random sample inside the Box.

        In creating a sample of the box, each coordinate is sampled (independently) from a distribution
        that is chosen according to the form of the interval:

        * :math:`[a, b]` : uniform distribution
        * :math:`[a, \infty)` : shifted exponential distribution
        * :math:`(-\infty, b]` : shifted negative exponential distribution
        * :math:`(-\infty, \infty)` : normal distribution

        Args:
            mask: A mask for sampling values from the Box space, currently unsupported.
            probability: A probability mask for sampling values from the Box space, currently unsupported.

        Returns:
            A sampled value from the Box
        """
        if mask is not None:
            raise gym.error.Error(
                f"Box.sample cannot be provided a mask, actual value: {mask}"
            )
        elif probability is not None:
            raise gym.error.Error(
                f"Box.sample cannot be provided a probability mask, actual value: {probability}"
            )

        # For integer dtypes, sample from [low, high + 1) then floor, so that
        # `high` itself is an attainable sample.
        high = self.high if self.dtype.kind == "f" else self.high.astype("int64") + 1
        sample = np.empty(self.shape)

        # Masking arrays which classify the coordinates according to interval type
        unbounded = ~self.bounded_below & ~self.bounded_above
        upp_bounded = ~self.bounded_below & self.bounded_above
        low_bounded = self.bounded_below & ~self.bounded_above
        bounded = self.bounded_below & self.bounded_above

        # Vectorized sampling by interval type
        sample[unbounded] = self.np_random.normal(size=unbounded[unbounded].shape)

        sample[low_bounded] = (
            self.np_random.exponential(size=low_bounded[low_bounded].shape)
            + self.low[low_bounded]
        )

        sample[upp_bounded] = (
            -self.np_random.exponential(size=upp_bounded[upp_bounded].shape)
            + high[upp_bounded]
        )

        sample[bounded] = self.np_random.uniform(
            low=self.low[bounded], high=high[bounded], size=bounded[bounded].shape
        )

        if self.dtype.kind in ["i", "u", "b"]:
            sample = np.floor(sample)

        # clip values that would underflow/overflow
        if np.issubdtype(self.dtype, np.signedinteger):
            dtype_min = np.iinfo(self.dtype).min + 2
            dtype_max = np.iinfo(self.dtype).max - 2
            sample = sample.clip(min=dtype_min, max=dtype_max)
        elif np.issubdtype(self.dtype, np.unsignedinteger):
            dtype_min = np.iinfo(self.dtype).min
            dtype_max = np.iinfo(self.dtype).max
            sample = sample.clip(min=dtype_min, max=dtype_max)

        sample = sample.astype(self.dtype)

        # float64 values have lower than integer precision near int64 min/max, so clip
        # again in case something has been cast to an out-of-bounds value
        if self.dtype == np.int64:
            sample = sample.clip(min=self.low, max=self.high)

        return sample

    def contains(self, x: Any) -> bool:
        """Return boolean specifying if x is a valid member of this space."""
        if not isinstance(x, np.ndarray):
            gym.logger.warn("Casting input x to numpy array.")
            try:
                x = np.asarray(x, dtype=self.dtype)
            except (ValueError, TypeError):
                return False

        return bool(
            np.can_cast(x.dtype, self.dtype)
            and x.shape == self.shape
            and np.all(x >= self.low)
            and np.all(x <= self.high)
        )

    def to_jsonable(self, sample_n: Sequence[NDArray[Any]]) -> list[list]:
        """Convert a batch of samples from this space to a JSONable data type."""
        return [sample.tolist() for sample in sample_n]

    def from_jsonable(self, sample_n: Sequence[float | int]) -> list[NDArray[Any]]:
        """Convert a JSONable data type to a batch of samples from this space."""
        return [np.asarray(sample, dtype=self.dtype) for sample in sample_n]

    def __repr__(self) -> str:
        """A string representation of this space.

        The representation will include bounds, shape and dtype.
        If a bound is uniform, only the corresponding scalar will be given to avoid redundant and ugly strings.

        Returns:
            A representation of the space
        """
        return f"Box({self.low_repr}, {self.high_repr}, {self.shape}, {self.dtype})"

    def __eq__(self, other: Any) -> bool:
        """Check whether `other` is equivalent to this instance. Doesn't check dtype equivalence."""
        return (
            isinstance(other, Box)
            and (self.shape == other.shape)
            and (self.dtype == other.dtype)
            and np.allclose(self.low, other.low)
            and np.allclose(self.high, other.high)
        )

    def __setstate__(self, state: Iterable[tuple[str, Any]] | Mapping[str, Any]):
        """Sets the state of the box for unpickling a box with legacy support."""
        super().__setstate__(state)

        # legacy support through re-adding "low_repr" and "high_repr" if missing from pickled state
        if not hasattr(self, "low_repr"):
            self.low_repr = array_short_repr(self.low)

        if not hasattr(self, "high_repr"):
            self.high_repr = array_short_repr(self.high)
+
+
+"""Implementation of a space that represents the cartesian product of other spaces as a dictionary."""
+
+from __future__ import annotations
+
+import collections.abc
+import typing
+from collections import OrderedDict
+from collections.abc import KeysView, Sequence
+from typing import Any
+
+import numpy as np
+
+from gymnasium.spaces.space import Space
+
+
+
+[docs]
+class Dict(Space[dict[str, Any]], typing.Mapping[str, Space[Any]]):
+ """A dictionary of :class:`Space` instances.
+
+ Elements of this space are (ordered) dictionaries of elements from the constituent spaces.
+
+ Example:
+ >>> from gymnasium.spaces import Dict, Box, Discrete
+ >>> observation_space = Dict({"position": Box(-1, 1, shape=(2,)), "color": Discrete(3)}, seed=42)
+ >>> observation_space.sample()
+ {'color': np.int64(0), 'position': array([-0.3991573 , 0.21649833], dtype=float32)}
+
+ With a nested dict:
+
+ >>> from gymnasium.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete
+ >>> Dict( # doctest: +SKIP
+ ... {
+ ... "ext_controller": MultiDiscrete([5, 2, 2]),
+ ... "inner_state": Dict(
+ ... {
+ ... "charge": Discrete(100),
+ ... "system_checks": MultiBinary(10),
+ ... "job_status": Dict(
+ ... {
+ ... "task": Discrete(5),
+ ... "progress": Box(low=0, high=100, shape=()),
+ ... }
+ ... ),
+ ... }
+ ... ),
+ ... }
+ ... )
+
+ It can be convenient to use :class:`Dict` spaces if you want to make complex observations or actions more human-readable.
+ Usually, it will not be possible to use elements of this space directly in learning code. However, you can easily
+ convert :class:`Dict` observations to flat arrays by using a :class:`gymnasium.wrappers.FlattenObservation` wrapper.
+ Similar wrappers can be implemented to deal with :class:`Dict` actions.
+ """
+
+ def __init__(
+ self,
+ spaces: None | dict[str, Space] | Sequence[tuple[str, Space]] = None,
+ seed: dict | int | np.random.Generator | None = None,
+ **spaces_kwargs: Space,
+ ):
+ """Constructor of :class:`Dict` space.
+
+ This space can be instantiated in one of two ways: Either you pass a dictionary
+ of spaces to :meth:`__init__` via the ``spaces`` argument, or you pass the spaces as separate
+ keyword arguments (where you will need to avoid the keys ``spaces`` and ``seed``)
+
+ Args:
+ spaces: A dictionary of spaces. This specifies the structure of the :class:`Dict` space
+ seed: Optionally, you can use this argument to seed the RNGs of the spaces that make up the :class:`Dict` space.
+ **spaces_kwargs: If ``spaces`` is ``None``, you need to pass the constituent spaces as keyword arguments, as described above.
+ """
+ if isinstance(spaces, OrderedDict):
+ spaces = dict(spaces.items())
+ elif isinstance(spaces, collections.abc.Mapping):
+ # for legacy reasons, we need to preserve the sorted dictionary items.
+ # as this could matter for projects flatten the dictionary.
+ try:
+ spaces = dict(sorted(spaces.items()))
+ except TypeError:
+ # Incomparable types (e.g. `int` vs. `str`, or user-defined types) found.
+ # The keys remain in the insertion order.
+ spaces = dict(spaces.items())
+ elif isinstance(spaces, Sequence):
+ spaces = dict(spaces)
+ elif spaces is None:
+ spaces = dict()
+ else:
+ raise TypeError(
+ f"Unexpected Dict space input, expecting dict, OrderedDict or Sequence, actual type: {type(spaces)}"
+ )
+
+ # Add kwargs to spaces to allow both dictionary and keywords to be used
+ for key, space in spaces_kwargs.items():
+ if key not in spaces:
+ spaces[key] = space
+ else:
+ raise ValueError(
+ f"Dict space keyword '{key}' already exists in the spaces dictionary."
+ )
+
+ self.spaces: dict[str, Space[Any]] = spaces
+ for key, space in self.spaces.items():
+ assert isinstance(
+ space, Space
+ ), f"Dict space element is not an instance of Space: key='{key}', space={space}"
+
+ # None for shape and dtype, since it'll require special handling
+ super().__init__(None, None, seed) # type: ignore
+
+ @property
+ def is_np_flattenable(self):
+ """Checks whether this space can be flattened to a :class:`spaces.Box`."""
+ return all(space.is_np_flattenable for space in self.spaces.values())
+
+
+[docs]
+ def seed(self, seed: int | dict[str, Any] | None = None) -> dict[str, int]:
+ """Seed the PRNG of this space and all subspaces.
+
+ Depending on the type of seed, the subspaces will be seeded differently
+
+ * ``None`` - All the subspaces will use a random initial seed
+ * ``Int`` - The integer is used to seed the :class:`Dict` space that is used to generate seed values for each of the subspaces. Warning, this does not guarantee unique seeds for all subspaces, though is very unlikely.
+ * ``Dict`` - A dictionary of seeds for each subspace, requires a seed key for every subspace. This supports seeding of multiple composite subspaces (``Dict["space": Dict[...], ...]`` with ``{"space": {...}, ...}``).
+
+ Args:
+ seed: An optional int or dictionary of subspace keys to int to seed each PRNG. See above for more details.
+
+ Returns:
+ A dictionary for the seed values of the subspaces
+ """
+ if seed is None:
+ return {key: subspace.seed(None) for (key, subspace) in self.spaces.items()}
+ elif isinstance(seed, int):
+ super().seed(seed)
+ # Using `np.int32` will mean that the same key occurring is extremely low, even for large subspaces
+ subseeds = self.np_random.integers(
+ np.iinfo(np.int32).max, size=len(self.spaces)
+ )
+ return {
+ key: subspace.seed(int(subseed))
+ for (key, subspace), subseed in zip(self.spaces.items(), subseeds)
+ }
+ elif isinstance(seed, dict):
+ if seed.keys() != self.spaces.keys():
+ raise ValueError(
+ f"The seed keys: {seed.keys()} are not identical to space keys: {self.spaces.keys()}"
+ )
+
+ return {key: self.spaces[key].seed(seed[key]) for key in seed.keys()}
+ else:
+ raise TypeError(
+ f"Expected seed type: dict, int or None, actual type: {type(seed)}"
+ )
+
+
+
+[docs]
+ def sample(
+ self,
+ mask: dict[str, Any] | None = None,
+ probability: dict[str, Any] | None = None,
+ ) -> dict[str, Any]:
+ """Generates a single random sample from this space.
+
+ The sample is an ordered dictionary of independent samples from the constituent spaces.
+
+ Args:
+ mask: An optional mask for each of the subspaces, expects the same keys as the space
+ probability: An optional probability mask for each of the subspaces, expects the same keys as the space
+
+ Returns:
+ A dictionary with the same key and sampled values from :attr:`self.spaces`
+ """
+ if mask is not None and probability is not None:
+ raise ValueError(
+ f"Only one of `mask` or `probability` can be provided, actual values: mask={mask}, probability={probability}"
+ )
+ elif mask is not None:
+ assert isinstance(
+ mask, dict
+ ), f"Expected sample mask to be a dict, actual type: {type(mask)}"
+ assert (
+ mask.keys() == self.spaces.keys()
+ ), f"Expected sample mask keys to be same as space keys, mask keys: {mask.keys()}, space keys: {self.spaces.keys()}"
+
+ return {k: space.sample(mask=mask[k]) for k, space in self.spaces.items()}
+ elif probability is not None:
+ assert isinstance(
+ probability, dict
+ ), f"Expected sample probability mask to be a dict, actual type: {type(probability)}"
+ assert (
+ probability.keys() == self.spaces.keys()
+ ), f"Expected sample probability mask keys to be same as space keys, mask keys: {probability.keys()}, space keys: {self.spaces.keys()}"
+
+ return {
+ k: space.sample(probability=probability[k])
+ for k, space in self.spaces.items()
+ }
+ else:
+ return {k: space.sample() for k, space in self.spaces.items()}
+
+
+ def contains(self, x: Any) -> bool:
+ """Return boolean specifying if x is a valid member of this space."""
+ if isinstance(x, dict) and x.keys() == self.spaces.keys():
+ return all(x[key] in self.spaces[key] for key in self.spaces.keys())
+ return False
+
+ def __getitem__(self, key: str) -> Space[Any]:
+ """Get the space that is associated to `key`."""
+ return self.spaces[key]
+
+ def keys(self) -> KeysView:
+ """Returns the keys of the Dict."""
+ return KeysView(self.spaces)
+
+ def __setitem__(self, key: str, value: Space[Any]):
+ """Set the space that is associated to `key`."""
+ assert isinstance(
+ value, Space
+ ), f"Trying to set {key} to Dict space with value that is not a gymnasium space, actual type: {type(value)}"
+ self.spaces[key] = value
+
+ def __iter__(self):
+ """Iterator through the keys of the subspaces."""
+ yield from self.spaces
+
+ def __len__(self) -> int:
+ """Gives the number of simpler spaces that make up the `Dict` space."""
+ return len(self.spaces)
+
+ def __repr__(self) -> str:
+ """Gives a string representation of this space."""
+ return (
+ "Dict(" + ", ".join([f"{k!r}: {s}" for k, s in self.spaces.items()]) + ")"
+ )
+
+ def __eq__(self, other: Any) -> bool:
+ """Check whether `other` is equivalent to this instance."""
+ return (
+ isinstance(other, Dict)
+ # Comparison of `OrderedDict`s is order-sensitive
+ and self.spaces == other.spaces # OrderedDict.__eq__
+ )
+
+ def to_jsonable(self, sample_n: Sequence[dict[str, Any]]) -> dict[str, list[Any]]:
+ """Convert a batch of samples from this space to a JSONable data type."""
+ # serialize as dict-repr of vectors
+ return {
+ key: space.to_jsonable([sample[key] for sample in sample_n])
+ for key, space in self.spaces.items()
+ }
+
+ def from_jsonable(self, sample_n: dict[str, list[Any]]) -> list[dict[str, Any]]:
+ """Convert a JSONable data type to a batch of samples from this space."""
+ dict_of_list: dict[str, list[Any]] = {
+ key: space.from_jsonable(sample_n[key])
+ for key, space in self.spaces.items()
+ }
+
+ n_elements = len(next(iter(dict_of_list.values())))
+ result = [
+ {key: value[n] for key, value in dict_of_list.items()}
+ for n in range(n_elements)
+ ]
+ return result
+
+
+"""Implementation of a space consisting of finitely many elements."""
+
+from __future__ import annotations
+
+from collections.abc import Iterable, Mapping, Sequence
+from typing import Any
+
+import numpy as np
+
+from gymnasium.spaces.space import MaskNDArray, Space
+
+
+
+[docs]
+class Discrete(Space[np.int64]):
+ r"""A space consisting of finitely many elements.
+
+ This class represents a finite subset of integers, more specifically a set of the form :math:`\{ a, a+1, \dots, a+n-1 \}`.
+
+ Example:
+ >>> from gymnasium.spaces import Discrete
+ >>> observation_space = Discrete(2, seed=42) # {0, 1}
+ >>> observation_space.sample()
+ np.int64(0)
+ >>> observation_space = Discrete(3, start=-1, seed=42) # {-1, 0, 1}
+ >>> observation_space.sample()
+ np.int64(-1)
+ >>> observation_space.sample(mask=np.array([0,0,1], dtype=np.int8))
+ np.int64(1)
+ >>> observation_space.sample(probability=np.array([0,0,1], dtype=np.float64))
+ np.int64(1)
+ >>> observation_space.sample(probability=np.array([0,0.3,0.7], dtype=np.float64))
+ np.int64(1)
+ """
+
+ def __init__(
+ self,
+ n: int | np.integer[Any],
+ seed: int | np.random.Generator | None = None,
+ start: int | np.integer[Any] = 0,
+ ):
+ r"""Constructor of :class:`Discrete` space.
+
+ This will construct the space :math:`\{\text{start}, ..., \text{start} + n - 1\}`.
+
+ Args:
+ n (int): The number of elements of this space.
+ seed: Optionally, you can use this argument to seed the RNG that is used to sample from the ``Dict`` space.
+ start (int): The smallest element of this space.
+ """
+ assert np.issubdtype(
+ type(n), np.integer
+ ), f"Expects `n` to be an integer, actual dtype: {type(n)}"
+ assert n > 0, "n (counts) have to be positive"
+ assert np.issubdtype(
+ type(start), np.integer
+ ), f"Expects `start` to be an integer, actual type: {type(start)}"
+
+ self.n = np.int64(n)
+ self.start = np.int64(start)
+ super().__init__((), np.int64, seed)
+
+ @property
+ def is_np_flattenable(self):
+ """Checks whether this space can be flattened to a :class:`spaces.Box`."""
+ return True
+
+
+[docs]
+ def sample(
+ self, mask: MaskNDArray | None = None, probability: MaskNDArray | None = None
+ ) -> np.int64:
+ """Generates a single random sample from this space.
+
+ A sample will be chosen uniformly at random with the mask if provided, or it will be chosen according to a specified probability distribution if the probability mask is provided.
+
+ Args:
+ mask: An optional mask for if an action can be selected.
+ Expected `np.ndarray` of shape ``(n,)`` and dtype ``np.int8`` where ``1`` represents valid actions and ``0`` invalid / infeasible actions.
+ If there are no possible actions (i.e. ``np.all(mask == 0)``) then ``space.start`` will be returned.
+ probability: An optional probability mask describing the probability of each action being selected.
+ Expected `np.ndarray` of shape ``(n,)`` and dtype ``np.float64`` where each value is in the range ``[0, 1]`` and the sum of all values is 1.
+ If the values do not sum to 1, an exception will be thrown.
+
+ Returns:
+ A sampled integer from the space
+ """
+ if mask is not None and probability is not None:
+ raise ValueError(
+ f"Only one of `mask` or `probability` can be provided, actual values: mask={mask}, probability={probability}"
+ )
+ # binary mask sampling
+ elif mask is not None:
+ assert isinstance(
+ mask, np.ndarray
+ ), f"The expected type of the sample mask is np.ndarray, actual type: {type(mask)}"
+ assert (
+ mask.dtype == np.int8
+ ), f"The expected dtype of the sample mask is np.int8, actual dtype: {mask.dtype}"
+ assert mask.shape == (
+ self.n,
+ ), f"The expected shape of the sample mask is {(int(self.n),)}, actual shape: {mask.shape}"
+
+ valid_action_mask = mask == 1
+ assert np.all(
+ np.logical_or(mask == 0, valid_action_mask)
+ ), f"All values of the sample mask should be 0 or 1, actual values: {mask}"
+
+ if np.any(valid_action_mask):
+ return self.start + self.np_random.choice(
+ np.where(valid_action_mask)[0]
+ )
+ else:
+ return self.start
+ # probability mask sampling
+ elif probability is not None:
+ assert isinstance(
+ probability, np.ndarray
+ ), f"The expected type of the sample probability is np.ndarray, actual type: {type(probability)}"
+ assert (
+ probability.dtype == np.float64
+ ), f"The expected dtype of the sample probability is np.float64, actual dtype: {probability.dtype}"
+ assert probability.shape == (
+ self.n,
+ ), f"The expected shape of the sample probability is {(int(self.n),)}, actual shape: {probability.shape}"
+
+ assert np.all(
+ np.logical_and(probability >= 0, probability <= 1)
+ ), f"All values of the sample probability should be between 0 and 1, actual values: {probability}"
+ assert np.isclose(
+ np.sum(probability), 1
+ ), f"The sum of the sample probability should be equal to 1, actual sum: {np.sum(probability)}"
+
+ return self.start + self.np_random.choice(np.arange(self.n), p=probability)
+ # uniform sampling
+ else:
+ return self.start + self.np_random.integers(self.n)
+
+
+ def contains(self, x: Any) -> bool:
+ """Return boolean specifying if x is a valid member of this space."""
+ if isinstance(x, int):
+ as_int64 = np.int64(x)
+ elif isinstance(x, (np.generic, np.ndarray)) and (
+ np.issubdtype(x.dtype, np.integer) and x.shape == ()
+ ):
+ as_int64 = np.int64(x)
+ else:
+ return False
+
+ return bool(self.start <= as_int64 < self.start + self.n)
+
+ def __repr__(self) -> str:
+ """Gives a string representation of this space."""
+ if self.start != 0:
+ return f"Discrete({self.n}, start={self.start})"
+ return f"Discrete({self.n})"
+
+ def __eq__(self, other: Any) -> bool:
+ """Check whether ``other`` is equivalent to this instance."""
+ return (
+ isinstance(other, Discrete)
+ and self.n == other.n
+ and self.start == other.start
+ )
+
+ def __setstate__(self, state: Iterable[tuple[str, Any]] | Mapping[str, Any]):
+ """Used when loading a pickled space.
+
+ This method has to be implemented explicitly to allow for loading of legacy states.
+
+ Args:
+ state: The new state
+ """
+ # Don't mutate the original state
+ state = dict(state)
+
+ # Allow for loading of legacy states.
+ # See https://github.com/openai/gym/pull/2470
+ if "start" not in state:
+ state["start"] = np.int64(0)
+
+ super().__setstate__(state)
+
+ def to_jsonable(self, sample_n: Sequence[np.int64]) -> list[int]:
+ """Converts a list of samples to a list of ints."""
+ return [int(x) for x in sample_n]
+
+ def from_jsonable(self, sample_n: list[int]) -> list[np.int64]:
+ """Converts a list of json samples to a list of np.int64."""
+ return [np.int64(x) for x in sample_n]
+
+
+"""Implementation of a space that represents graph information where nodes and edges can be represented with euclidean space."""
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+from typing import Any, NamedTuple
+
+import numpy as np
+from numpy.typing import NDArray
+
+import gymnasium as gym
+from gymnasium.spaces.box import Box
+from gymnasium.spaces.discrete import Discrete
+from gymnasium.spaces.multi_discrete import MultiDiscrete
+from gymnasium.spaces.space import Space
+
+
+class GraphInstance(NamedTuple):
+ """A Graph space instance.
+
+ * nodes (np.ndarray): an (n x ...) sized array representing the features for n nodes, (...) must adhere to the shape of the node space.
+ * edges (Optional[np.ndarray]): an (m x ...) sized array representing the features for m edges, (...) must adhere to the shape of the edge space.
+ * edge_links (Optional[np.ndarray]): an (m x 2) sized array of ints representing the indices of the two nodes that each edge connects.
+ """
+
+ nodes: NDArray[Any]
+ edges: NDArray[Any] | None
+ edge_links: NDArray[Any] | None
+
+
+
+[docs]
+class Graph(Space[GraphInstance]):
+ r"""A space representing graph information as a series of ``nodes`` connected with ``edges`` according to an adjacency matrix represented as a series of ``edge_links``.
+
+ Example:
+ >>> from gymnasium.spaces import Graph, Box, Discrete
+ >>> observation_space = Graph(node_space=Box(low=-100, high=100, shape=(3,)), edge_space=Discrete(3), seed=123)
+ >>> observation_space.sample(num_nodes=4, num_edges=8)
+ GraphInstance(nodes=array([[ 36.47037 , -89.235794, -55.928024],
+ [-63.125637, -64.81882 , 62.4189 ],
+ [ 84.669 , -44.68512 , 63.950912],
+ [ 77.97854 , 2.594091, -51.00708 ]], dtype=float32), edges=array([2, 0, 2, 1, 2, 0, 2, 1]), edge_links=array([[3, 0],
+ [0, 0],
+ [0, 1],
+ [0, 2],
+ [1, 0],
+ [1, 0],
+ [0, 1],
+ [0, 2]], dtype=int32))
+ """
+
+ def __init__(
+ self,
+ node_space: Box | Discrete,
+ edge_space: None | Box | Discrete,
+ seed: int | np.random.Generator | None = None,
+ ):
+ r"""Constructor of :class:`Graph`.
+
+ The argument ``node_space`` specifies the base space that each node feature will use.
+ This argument must be either a Box or Discrete instance.
+
+ The argument ``edge_space`` specifies the base space that each edge feature will use.
+ This argument must be either a None, Box or Discrete instance.
+
+ Args:
+ node_space (Union[Box, Discrete]): space of the node features.
+ edge_space (Union[None, Box, Discrete]): space of the edge features.
+ seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.
+ """
+ assert isinstance(
+ node_space, (Box, Discrete)
+ ), f"Values of the node_space should be instances of Box or Discrete, got {type(node_space)}"
+ if edge_space is not None:
+ assert isinstance(
+ edge_space, (Box, Discrete)
+ ), f"Values of the edge_space should be instances of None Box or Discrete, got {type(edge_space)}"
+
+ self.node_space = node_space
+ self.edge_space = edge_space
+
+ super().__init__(None, None, seed)
+
+ @property
+ def is_np_flattenable(self):
+ """Checks whether this space can be flattened to a :class:`spaces.Box`."""
+ return False
+
+ def _generate_sample_space(
+ self, base_space: None | Box | Discrete, num: int
+ ) -> Box | MultiDiscrete | None:
+ if num == 0 or base_space is None:
+ return None
+
+ if isinstance(base_space, Box):
+ return Box(
+ low=np.array(max(1, num) * [base_space.low]),
+ high=np.array(max(1, num) * [base_space.high]),
+ shape=(num,) + base_space.shape,
+ dtype=base_space.dtype,
+ seed=self.np_random,
+ )
+ elif isinstance(base_space, Discrete):
+ return MultiDiscrete(nvec=[base_space.n] * num, seed=self.np_random)
+ else:
+ raise TypeError(
+ f"Expects base space to be Box and Discrete, actual space: {type(base_space)}."
+ )
+
+
+[docs]
+ def seed(
+ self, seed: int | tuple[int, int] | tuple[int, int, int] | None = None
+ ) -> tuple[int, int] | tuple[int, int, int]:
+ """Seeds the PRNG of this space and node / edge subspace.
+
+ Depending on the type of seed, the subspaces will be seeded differently
+
+ * ``None`` - The root, node and edge spaces PRNG are randomly initialized
+ * ``Int`` - The integer is used to seed the :class:`Graph` space that is used to generate seed values for the node and edge subspaces.
+ * ``Tuple[int, int]`` - Seeds the :class:`Graph` and node subspace with a particular value. Only if edge subspace isn't specified
+ * ``Tuple[int, int, int]`` - Seeds the :class:`Graph`, node and edge subspaces with a particular value.
+
+ Args:
+ seed: An optional int or tuple of ints for this space and the node / edge subspaces. See above for more details.
+
+ Returns:
+ A tuple of two or three ints depending on if the edge subspace is specified.
+ """
+ if seed is None:
+ if self.edge_space is None:
+ return super().seed(None), self.node_space.seed(None)
+ else:
+ return (
+ super().seed(None),
+ self.node_space.seed(None),
+ self.edge_space.seed(None),
+ )
+ elif isinstance(seed, int):
+ if self.edge_space is None:
+ super_seed = super().seed(seed)
+ node_seed = int(self.np_random.integers(np.iinfo(np.int32).max))
+ # this is necessary such that after int or list/tuple seeding, the Graph PRNG are equivalent
+ super().seed(seed)
+ return super_seed, self.node_space.seed(node_seed)
+ else:
+ super_seed = super().seed(seed)
+ node_seed, edge_seed = self.np_random.integers(
+ np.iinfo(np.int32).max, size=(2,)
+ )
+ # this is necessary such that after int or list/tuple seeding, the Graph PRNG are equivalent
+ super().seed(seed)
+ return (
+ super_seed,
+ self.node_space.seed(int(node_seed)),
+ self.edge_space.seed(int(edge_seed)),
+ )
+ elif isinstance(seed, (list, tuple)):
+ if self.edge_space is None:
+ if len(seed) != 2:
+ raise ValueError(
+ f"Expects a tuple of two values for Graph and node space, actual length: {len(seed)}"
+ )
+
+ return super().seed(seed[0]), self.node_space.seed(seed[1])
+ else:
+ if len(seed) != 3:
+ raise ValueError(
+ f"Expects a tuple of three values for Graph, node and edge space, actual length: {len(seed)}"
+ )
+
+ return (
+ super().seed(seed[0]),
+ self.node_space.seed(seed[1]),
+ self.edge_space.seed(seed[2]),
+ )
+ else:
+ raise TypeError(
+ f"Expects `None`, int or tuple of ints, actual type: {type(seed)}"
+ )
+
+
+
+[docs]
+ def sample(
+ self,
+ mask: None | (
+ tuple[
+ NDArray[Any] | tuple[Any, ...] | None,
+ NDArray[Any] | tuple[Any, ...] | None,
+ ]
+ ) = None,
+ probability: None | (
+ tuple[
+ NDArray[Any] | tuple[Any, ...] | None,
+ NDArray[Any] | tuple[Any, ...] | None,
+ ]
+ ) = None,
+ num_nodes: int = 10,
+ num_edges: int | None = None,
+ ) -> GraphInstance:
+ """Generates a single sample graph with num_nodes between ``1`` and ``10`` sampled from the Graph.
+
+ Args:
+ mask: An optional tuple of optional node and edge mask that is only possible with Discrete spaces
+ (Box spaces don't support sample masks).
+ If no ``num_edges`` is provided then the ``edge_mask`` is multiplied by the number of edges
+ probability: An optional tuple of optional node and edge probability mask that is only possible with Discrete spaces
+ (Box spaces don't support sample probability masks).
+ If no ``num_edges`` is provided then the ``edge_mask`` is multiplied by the number of edges
+ num_nodes: The number of nodes that will be sampled, the default is `10` nodes
+ num_edges: An optional number of edges, otherwise, a random number between `0` and :math:`num_nodes^2`
+
+ Returns:
+ A :class:`GraphInstance` with attributes `.nodes`, `.edges`, and `.edge_links`.
+ """
+ assert (
+ num_nodes > 0
+ ), f"The number of nodes is expected to be greater than 0, actual value: {num_nodes}"
+
+ if mask is not None and probability is not None:
+ raise ValueError(
+ f"Only one of `mask` or `probability` can be provided, actual values: mask={mask}, probability={probability}"
+ )
+ elif mask is not None:
+ node_space_mask, edge_space_mask = mask
+ mask_type = "mask"
+ elif probability is not None:
+ node_space_mask, edge_space_mask = probability
+ mask_type = "probability"
+ else:
+ node_space_mask = edge_space_mask = mask_type = None
+
+ # we only have edges when we have at least 2 nodes
+ if num_edges is None:
+ if num_nodes > 1:
+ # maximal number of edges is `n*(n-1)` allowing self connections and two-way is allowed
+ num_edges = self.np_random.integers(num_nodes * (num_nodes - 1))
+ else:
+ num_edges = 0
+
+ if edge_space_mask is not None:
+ edge_space_mask = tuple(edge_space_mask for _ in range(num_edges))
+ else:
+ if self.edge_space is None:
+ gym.logger.warn(
+ f"The number of edges is set ({num_edges}) but the edge space is None."
+ )
+ assert (
+ num_edges >= 0
+ ), f"Expects the number of edges to be greater than 0, actual value: {num_edges}"
+ assert num_edges is not None
+
+ sampled_node_space = self._generate_sample_space(self.node_space, num_nodes)
+ assert sampled_node_space is not None
+ sampled_edge_space = self._generate_sample_space(self.edge_space, num_edges)
+
+ if mask_type is not None:
+ node_sample_kwargs = {mask_type: node_space_mask}
+ edge_sample_kwargs = {mask_type: edge_space_mask}
+ else:
+ node_sample_kwargs = edge_sample_kwargs = {}
+
+ sampled_nodes = sampled_node_space.sample(**node_sample_kwargs)
+ sampled_edges = None
+ if sampled_edge_space is not None:
+ sampled_edges = sampled_edge_space.sample(**edge_sample_kwargs)
+
+ sampled_edge_links = None
+ if sampled_edges is not None and num_edges > 0:
+ sampled_edge_links = self.np_random.integers(
+ low=0, high=num_nodes, size=(num_edges, 2), dtype=np.int32
+ )
+
+ return GraphInstance(sampled_nodes, sampled_edges, sampled_edge_links)
+
+
+ def contains(self, x: GraphInstance) -> bool:
+ """Return boolean specifying if x is a valid member of this space."""
+ if isinstance(x, GraphInstance):
+ # Checks the nodes
+ if isinstance(x.nodes, np.ndarray):
+ if all(node in self.node_space for node in x.nodes):
+ # Check the edges and edge links which are optional
+ if isinstance(x.edges, np.ndarray) and isinstance(
+ x.edge_links, np.ndarray
+ ):
+ assert x.edges is not None
+ assert x.edge_links is not None
+ if self.edge_space is not None:
+ if all(edge in self.edge_space for edge in x.edges):
+ if np.issubdtype(x.edge_links.dtype, np.integer):
+ if x.edge_links.shape == (len(x.edges), 2):
+ if np.all(
+ np.logical_and(
+ x.edge_links >= 0,
+ x.edge_links < len(x.nodes),
+ )
+ ):
+ return True
+ else:
+ return x.edges is None and x.edge_links is None
+ return False
+
+ def __repr__(self) -> str:
+ """A string representation of this space.
+
+ The representation will include ``node_space`` and ``edge_space``
+
+ Returns:
+ A representation of the space
+ """
+ return f"Graph({self.node_space}, {self.edge_space})"
+
+ def __eq__(self, other: Any) -> bool:
+ """Check whether `other` is equivalent to this instance."""
+ return (
+ isinstance(other, Graph)
+ and (self.node_space == other.node_space)
+ and (self.edge_space == other.edge_space)
+ )
+
+ def to_jsonable(
+ self, sample_n: Sequence[GraphInstance]
+ ) -> list[dict[str, list[int | float]]]:
+ """Convert a batch of samples from this space to a JSONable data type."""
+ ret_n = []
+ for sample in sample_n:
+ ret = {"nodes": sample.nodes.tolist()}
+ if sample.edges is not None and sample.edge_links is not None:
+ ret["edges"] = sample.edges.tolist()
+ ret["edge_links"] = sample.edge_links.tolist()
+ ret_n.append(ret)
+ return ret_n
+
+ def from_jsonable(
+ self, sample_n: Sequence[dict[str, list[list[int] | list[float]]]]
+ ) -> list[GraphInstance]:
+ """Convert a JSONable data type to a batch of samples from this space."""
+ ret: list[GraphInstance] = []
+ for sample in sample_n:
+ if "edges" in sample:
+ assert self.edge_space is not None
+ ret_n = GraphInstance(
+ np.asarray(sample["nodes"], dtype=self.node_space.dtype),
+ np.asarray(sample["edges"], dtype=self.edge_space.dtype),
+ np.asarray(sample["edge_links"], dtype=np.int32),
+ )
+ else:
+ ret_n = GraphInstance(
+ np.asarray(sample["nodes"], dtype=self.node_space.dtype),
+ None,
+ None,
+ )
+ ret.append(ret_n)
+ return ret
+
+
+"""Implementation of a space that consists of binary np.ndarrays of a fixed shape."""
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+from typing import Any
+
+import numpy as np
+from numpy.typing import NDArray
+
+from gymnasium.spaces.space import MaskNDArray, Space
+
+
+
+[docs]
+class MultiBinary(Space[NDArray[np.int8]]):
+ """An n-shape binary space.
+
+ Elements of this space are binary arrays of a shape that is fixed during construction.
+
+ Example:
+ >>> from gymnasium.spaces import MultiBinary
+ >>> observation_space = MultiBinary(5, seed=42)
+ >>> observation_space.sample()
+ array([1, 0, 1, 0, 1], dtype=int8)
+ >>> observation_space = MultiBinary([3, 2], seed=42)
+ >>> observation_space.sample()
+ array([[1, 0],
+ [1, 0],
+ [1, 1]], dtype=int8)
+ """
+
+ def __init__(
+ self,
+ n: NDArray[np.integer[Any]] | Sequence[int] | int,
+ seed: int | np.random.Generator | None = None,
+ ):
+ """Constructor of :class:`MultiBinary` space.
+
+ Args:
+ n: This will fix the shape of elements of the space. It can either be an integer (if the space is flat)
+ or some sort of sequence (tuple, list or np.ndarray) if there are multiple axes.
+ seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.
+ """
+ if isinstance(n, int):
+ self.n = n = int(n)
+ input_n = (n,)
+ assert (np.asarray(input_n) > 0).all() # n (counts) have to be positive
+ elif isinstance(n, (Sequence, np.ndarray)):
+ self.n = input_n = tuple(int(i) for i in n)
+ assert (np.asarray(input_n) > 0).all() # n (counts) have to be positive
+ else:
+ raise ValueError(
+ f"Expected n to be an int or a sequence of ints, actual type: {type(n)}"
+ )
+
+ super().__init__(input_n, np.int8, seed)
+
+ @property
+ def shape(self) -> tuple[int, ...]:
+ """Has stricter type than gym.Space - never None."""
+ return self._shape # type: ignore
+
+ @property
+ def is_np_flattenable(self):
+ """Checks whether this space can be flattened to a :class:`spaces.Box`."""
+ return True
+
+
+[docs]
+ def sample(
+ self, mask: MaskNDArray | None = None, probability: MaskNDArray | None = None
+ ) -> NDArray[np.int8]:
+ """Generates a single random sample from this space.
+
+ A sample is drawn by independent, fair coin tosses (one toss per binary variable of the space).
+
+ Args:
+ mask: An optional ``np.ndarray`` to mask samples with expected shape of ``space.shape``.
+ For ``mask == 0`` then the samples will be ``0``, for a ``mask == 1`` then the samples will be ``1``.
+ For random samples, using a mask value of ``2``.
+ The expected mask shape is the space shape and mask dtype is ``np.int8``.
+ probability: An optional ``np.ndarray`` to mask samples with expected shape of space.shape where each element
+ represents the probability of the corresponding sample element being a 1.
+ The expected mask shape is the space shape and mask dtype is ``np.float64``.
+
+ Returns:
+ Sampled values from space
+ """
+ if mask is not None and probability is not None:
+ raise ValueError(
+ f"Only one of `mask` or `probability` can be provided, actual values: mask={mask}, probability={probability}"
+ )
+ if mask is not None:
+ assert isinstance(
+ mask, np.ndarray
+ ), f"The expected type of the mask is np.ndarray, actual type: {type(mask)}"
+ assert (
+ mask.dtype == np.int8
+ ), f"The expected dtype of the mask is np.int8, actual dtype: {mask.dtype}"
+ assert (
+ mask.shape == self.shape
+ ), f"The expected shape of the mask is {self.shape}, actual shape: {mask.shape}"
+ assert np.all(
+ (mask == 0) | (mask == 1) | (mask == 2)
+ ), f"All values of a mask should be 0, 1 or 2, actual values: {mask}"
+
+ return np.where(
+ mask == 2,
+ self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype),
+ mask.astype(self.dtype),
+ )
+ elif probability is not None:
+ assert isinstance(
+ probability, np.ndarray
+ ), f"The expected type of the probability is np.ndarray, actual type: {type(probability)}"
+ assert (
+ probability.dtype == np.float64
+ ), f"The expected dtype of the probability is np.float64, actual dtype: {probability.dtype}"
+ assert (
+ probability.shape == self.shape
+ ), f"The expected shape of the probability is {self.shape}, actual shape: {probability}"
+ assert np.all(
+ np.logical_and(probability >= 0, probability <= 1)
+ ), f"All values of the sample probability should be between 0 and 1, actual values: {probability}"
+
+ return (self.np_random.random(size=self.shape) <= probability).astype(
+ self.dtype
+ )
+ else:
+ return self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype)
+
+
+ def contains(self, x: Any) -> bool:
+ """Return boolean specifying if x is a valid member of this space."""
+ if isinstance(x, Sequence):
+ x = np.array(x) # Promote list to array for contains check
+
+ return bool(
+ isinstance(x, np.ndarray)
+ and self.shape == x.shape
+ and np.all(np.logical_or(x == 0, x == 1))
+ )
+
+ def to_jsonable(self, sample_n: Sequence[NDArray[np.int8]]) -> list[Sequence[int]]:
+ """Convert a batch of samples from this space to a JSONable data type."""
+ return np.array(sample_n).tolist()
+
+ def from_jsonable(self, sample_n: list[Sequence[int]]) -> list[NDArray[np.int8]]:
+ """Convert a JSONable data type to a batch of samples from this space."""
+ return [np.asarray(sample, self.dtype) for sample in sample_n]
+
+ def __repr__(self) -> str:
+ """Gives a string representation of this space."""
+ return f"MultiBinary({self.n})"
+
+ def __eq__(self, other: Any) -> bool:
+ """Check whether `other` is equivalent to this instance."""
+ return isinstance(other, MultiBinary) and self.n == other.n
+
+
+"""Implementation of a space that represents the cartesian product of `Discrete` spaces."""
+
+from __future__ import annotations
+
+from collections.abc import Iterable, Mapping, Sequence
+from typing import Any
+
+import numpy as np
+from numpy.typing import NDArray
+
+import gymnasium as gym
+from gymnasium.spaces.discrete import Discrete
+from gymnasium.spaces.space import MaskNDArray, Space
+
+
+
+[docs]
+class MultiDiscrete(Space[NDArray[np.integer]]):
+ """This represents the cartesian product of arbitrary :class:`Discrete` spaces.
+
+ It is useful to represent game controllers or keyboards where each key can be represented as a discrete action space.
+
+ Note:
+ Some environment wrappers assume a value of 0 always represents the NOOP action.
+
+ e.g. Nintendo Game Controller - Can be conceptualized as 3 discrete action spaces:
+
+ 1. Arrow Keys: Discrete 5 - NOOP[0], UP[1], RIGHT[2], DOWN[3], LEFT[4] - params: min: 0, max: 4
+ 2. Button A: Discrete 2 - NOOP[0], Pressed[1] - params: min: 0, max: 1
+ 3. Button B: Discrete 2 - NOOP[0], Pressed[1] - params: min: 0, max: 1
+
+ It can be initialized as ``MultiDiscrete([ 5, 2, 2 ])`` such that a sample might be ``array([3, 1, 0])``.
+
+ Although this feature is rarely used, :class:`MultiDiscrete` spaces may also have several axes
+ if ``nvec`` has several axes:
+
+ Example:
+ >>> from gymnasium.spaces import MultiDiscrete
+ >>> import numpy as np
+ >>> observation_space = MultiDiscrete(np.array([[1, 2], [3, 4]]), seed=42)
+ >>> observation_space.sample()
+ array([[0, 0],
+ [2, 2]])
+ """
+
    def __init__(
        self,
        nvec: NDArray[np.integer[Any]] | list[int],
        dtype: str | type[np.integer[Any]] = np.int64,
        seed: int | np.random.Generator | None = None,
        start: NDArray[np.integer[Any]] | list[int] | None = None,
    ):
        """Constructor of :class:`MultiDiscrete` space.

        The argument ``nvec`` will determine the number of values each categorical variable can take. If
        ``start`` is provided, it will define the minimal values corresponding to each categorical variable.

        Args:
            nvec: vector of counts of each categorical variable. This will usually be a list of integers. However,
                you may also pass a more complicated numpy array if you'd like the space to have several axes.
            dtype: This should be some kind of integer type.
            seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.
            start: Optionally, the starting value the element of each class will take (defaults to 0).

        Raises:
            ValueError: if ``dtype`` is ``None`` or not an integer dtype.
        """
        # determine dtype
        if dtype is None:
            raise ValueError(
                "MultiDiscrete dtype must be explicitly provided, cannot be None."
            )
        self.dtype = np.dtype(dtype)

        # * check that dtype is an accepted dtype
        if not (np.issubdtype(self.dtype, np.integer)):
            raise ValueError(
                f"Invalid MultiDiscrete dtype ({self.dtype}), must be an integer dtype"
            )

        # Copy the caller's arrays so later external mutation cannot change the space.
        self.nvec = np.array(nvec, dtype=dtype, copy=True)
        if start is not None:
            self.start = np.array(start, dtype=dtype, copy=True)
        else:
            # Default lower bound of 0 for every categorical variable.
            self.start = np.zeros(self.nvec.shape, dtype=dtype)

        assert (
            self.start.shape == self.nvec.shape
        ), "start and nvec (counts) should have the same shape"
        assert (self.nvec > 0).all(), "nvec (counts) have to be positive"

        super().__init__(self.nvec.shape, self.dtype, seed)
+
    @property
    def shape(self) -> tuple[int, ...]:
        """Has stricter type than :class:`gym.Space` - never None."""
        # _shape is always set in __init__ from nvec.shape, hence never None here.
        return self._shape  # type: ignore

    @property
    def is_np_flattenable(self) -> bool:
        """Checks whether this space can be flattened to a :class:`spaces.Box`."""
        return True
+
+
+[docs]
+ def sample(
+ self,
+ mask: tuple[MaskNDArray, ...] | None = None,
+ probability: tuple[MaskNDArray, ...] | None = None,
+ ) -> NDArray[np.integer[Any]]:
+ """Generates a single random sample from this space.
+
+ Args:
+ mask: An optional mask for multi-discrete, expects tuples with a ``np.ndarray`` mask in the position of each
+ action with shape ``(n,)`` where ``n`` is the number of actions and ``dtype=np.int8``.
+ Only ``mask values == 1`` are possible to sample unless all mask values for an action are ``0`` then the default action ``self.start`` (the smallest element) is sampled.
+ probability: An optional probability mask for multi-discrete, expects tuples with a ``np.ndarray`` probability mask in the position of each
+ action with shape ``(n,)`` where ``n`` is the number of actions and ``dtype=np.float64``.
+ Only probability mask values within ``[0,1]`` are possible to sample as long as the sum of all values is ``1``.
+
+ Returns:
+ An ``np.ndarray`` of :meth:`Space.shape`
+ """
+ if mask is not None and probability is not None:
+ raise ValueError(
+ f"Only one of `mask` or `probability` can be provided, actual values: mask={mask}, probability={probability}"
+ )
+ elif mask is not None:
+ return np.array(
+ self._apply_mask(mask, self.nvec, self.start, "mask"),
+ dtype=self.dtype,
+ )
+ elif probability is not None:
+ return np.array(
+ self._apply_mask(probability, self.nvec, self.start, "probability"),
+ dtype=self.dtype,
+ )
+ else:
+ return (self.np_random.random(self.nvec.shape) * self.nvec).astype(
+ self.dtype
+ ) + self.start
+
+
    def _apply_mask(
        self,
        sub_mask: MaskNDArray | tuple[MaskNDArray, ...],
        sub_nvec: MaskNDArray | np.integer[Any],
        sub_start: MaskNDArray | np.integer[Any],
        mask_type: str,
    ) -> int | list[Any]:
        """Returns a sample using the provided mask or probability mask.

        Recurses through multi-axis ``nvec`` arrays: tuple masks are unpacked
        element-wise until a scalar action count is reached, then a single
        masked value is drawn for that action.
        """
        if isinstance(sub_nvec, np.ndarray):
            # Recursive case: descend into each sub-axis with its own mask entry.
            assert isinstance(
                sub_mask, tuple
            ), f"Expects the mask to be a tuple for sub_nvec ({sub_nvec}), actual type: {type(sub_mask)}"
            assert len(sub_mask) == len(
                sub_nvec
            ), f"Expects the mask length to be equal to the number of actions, mask length: {len(sub_mask)}, nvec length: {len(sub_nvec)}"
            return [
                self._apply_mask(new_mask, new_nvec, new_start, mask_type)
                for new_mask, new_nvec, new_start in zip(sub_mask, sub_nvec, sub_start)
            ]

        # Base case: sub_nvec is a scalar action count.
        assert np.issubdtype(
            type(sub_nvec), np.integer
        ), f"Expects the sub_nvec to be an action, actually: {sub_nvec}, {type(sub_nvec)}"
        assert isinstance(
            sub_mask, np.ndarray
        ), f"Expects the sub mask to be np.ndarray, actual type: {type(sub_mask)}"
        assert (
            len(sub_mask) == sub_nvec
        ), f"Expects the mask length to be equal to the number of actions, mask length: {len(sub_mask)}, action: {sub_nvec}"

        if mask_type == "mask":
            assert (
                sub_mask.dtype == np.int8
            ), f"Expects the mask dtype to be np.int8, actual dtype: {sub_mask.dtype}"

            valid_action_mask = sub_mask == 1
            assert np.all(
                np.logical_or(sub_mask == 0, valid_action_mask)
            ), f"Expects all masks values to 0 or 1, actual values: {sub_mask}"

            if np.any(valid_action_mask):
                # Sample uniformly among the allowed actions, offset by the start value.
                return self.np_random.choice(np.where(valid_action_mask)[0]) + sub_start
            else:
                # All actions masked out: fall back to the default (smallest) action.
                return sub_start
        elif mask_type == "probability":
            assert (
                sub_mask.dtype == np.float64
            ), f"Expects the mask dtype to be np.float64, actual dtype: {sub_mask.dtype}"
            valid_action_mask = np.logical_and(sub_mask > 0, sub_mask <= 1)
            assert np.all(
                np.logical_or(sub_mask == 0, valid_action_mask)
            ), f"Expects all masks values to be between 0 and 1, actual values: {sub_mask}"
            assert np.isclose(
                np.sum(sub_mask), 1
            ), f"Expects the sum of all mask values to be 1, actual sum: {np.sum(sub_mask)}"

            # Re-normalise to guard against floating-point drift before sampling.
            normalized_sub_mask = sub_mask / np.sum(sub_mask)
            return (
                self.np_random.choice(
                    np.where(valid_action_mask)[0],
                    p=normalized_sub_mask[valid_action_mask],
                )
                + sub_start
            )
        raise ValueError(f"Unsupported mask type: {mask_type}")
+
+ def contains(self, x: Any) -> bool:
+ """Return boolean specifying if x is a valid member of this space."""
+ if isinstance(x, Sequence):
+ x = np.array(x) # Promote list to array for contains check
+
+ # if nvec is uint32 and space dtype is uint32, then 0 <= x < self.nvec guarantees that x
+ # is within correct bounds for space dtype (even though x does not have to be unsigned)
+ return bool(
+ isinstance(x, np.ndarray)
+ and x.shape == self.shape
+ and x.dtype != object
+ and np.all(self.start <= x)
+ and np.all(x - self.start < self.nvec)
+ )
+
+ def to_jsonable(
+ self, sample_n: Sequence[NDArray[np.integer[Any]]]
+ ) -> list[Sequence[int]]:
+ """Convert a batch of samples from this space to a JSONable data type."""
+ return [sample.tolist() for sample in sample_n]
+
+ def from_jsonable(
+ self, sample_n: list[Sequence[int]]
+ ) -> list[NDArray[np.integer[Any]]]:
+ """Convert a JSONable data type to a batch of samples from this space."""
+ return [np.array(sample, dtype=np.int64) for sample in sample_n]
+
    def __repr__(self) -> str:
        """Gives a string representation of this space."""
        # Only mention `start` when it is non-default (non-zero).
        if np.any(self.start != 0):
            return f"MultiDiscrete({self.nvec}, start={self.start})"
        return f"MultiDiscrete({self.nvec})"

    def __getitem__(self, index: int | tuple[int, ...]):
        """Extract a subspace from this ``MultiDiscrete`` space.

        Returns a :class:`Discrete` space for a scalar index and a
        :class:`MultiDiscrete` space otherwise; the extracted subspace shares
        this space's RNG state at the time of extraction.
        """
        nvec = self.nvec[index]
        start = self.start[index]
        if nvec.ndim == 0:
            subspace = Discrete(nvec, start=start)
        else:
            subspace = MultiDiscrete(nvec, self.dtype, start=start)

        # you don't need to deepcopy as np random generator call replaces the state not the data
        subspace.np_random.bit_generator.state = self.np_random.bit_generator.state

        return subspace

    def __len__(self):
        """Gives the ``len`` of samples from this space."""
        # Only the first axis length is returned, hence the warning for >= 2 dims.
        if self.nvec.ndim >= 2:
            gym.logger.warn(
                "Getting the length of a multi-dimensional MultiDiscrete space."
            )
        return len(self.nvec)

    def __eq__(self, other: Any) -> bool:
        """Check whether ``other`` is equivalent to this instance."""
        return bool(
            isinstance(other, MultiDiscrete)
            and self.dtype == other.dtype
            and self.shape == other.shape
            and np.all(self.nvec == other.nvec)
            and np.all(self.start == other.start)
        )

    def __setstate__(self, state: Iterable[tuple[str, Any]] | Mapping[str, Any]):
        """Used when loading a pickled space.

        This method has to be implemented explicitly to allow for loading of legacy states.

        Args:
            state: The new state
        """
        state = dict(state)

        # Legacy pickles predate the `start` attribute; default it to zeros.
        if "start" not in state:
            state["start"] = np.zeros(state["_shape"], dtype=state["dtype"])

        super().__setstate__(state)
+
+
+"""Implementation of a space that represents the cartesian product of other spaces."""
+
+from __future__ import annotations
+
+import typing
+from collections.abc import Iterable
+from typing import Any
+
+import numpy as np
+
+from gymnasium.spaces.space import Space
+
+
+
+[docs]
+class OneOf(Space[Any]):
+ """An exclusive tuple (more precisely: the direct sum) of :class:`Space` instances.
+
+ Elements of this space are elements of one of the constituent spaces.
+
+ Example:
+ >>> from gymnasium.spaces import OneOf, Box, Discrete
+ >>> observation_space = OneOf((Discrete(2), Box(-1, 1, shape=(2,))), seed=123)
+ >>> observation_space.sample() # the first element is the space index (Discrete in this case) and the second element is the sample from Discrete
+ (np.int64(0), np.int64(0))
+ >>> observation_space.sample() # this time the Box space was sampled as index=1
+ (np.int64(1), array([-0.00711833, -0.7257502 ], dtype=float32))
+ >>> observation_space[0]
+ Discrete(2)
+ >>> observation_space[1]
+ Box(-1.0, 1.0, (2,), float32)
+ >>> len(observation_space)
+ 2
+ """
+
    def __init__(
        self,
        spaces: Iterable[Space[Any]],
        seed: int | typing.Sequence[int] | np.random.Generator | None = None,
    ):
        r"""Constructor of :class:`OneOf` space.

        The generated instance will represent the direct sum
        :math:`\text{spaces}[0] \oplus \dots \oplus \text{spaces}[-1]`:
        each element of this space belongs to exactly one constituent space.

        Args:
            spaces (Iterable[Space]): The spaces from which exactly one is sampled at a time.
            seed: Optionally, you can use this argument to seed the RNGs of the ``spaces`` to ensure reproducible sampling.
        """
        assert isinstance(spaces, Iterable), f"{spaces} is not an iterable"
        self.spaces = tuple(spaces)
        assert len(self.spaces) > 0, "Empty `OneOf` spaces are not supported."
        for space in self.spaces:
            assert isinstance(
                space, Space
            ), f"{space} does not inherit from `gymnasium.Space`. Actual Type: {type(space)}"
        # shape and dtype are None: elements are (index, subspace sample) pairs.
        super().__init__(None, None, seed)

    @property
    def is_np_flattenable(self) -> bool:
        """Checks whether this space can be flattened to a :class:`spaces.Box`."""
        return all(space.is_np_flattenable for space in self.spaces)
+
+
+[docs]
    def seed(self, seed: int | tuple[int, ...] | None = None) -> tuple[int, ...]:
        """Seed the PRNG of this space and all subspaces.

        Depending on the type of seed, the subspaces will be seeded differently

        * ``None`` - All the subspaces will use a random initial seed
        * ``Int`` - The integer is used to seed the :class:`Tuple` space that is used to generate seed values for each of the subspaces. Warning, this does not guarantee unique seeds for all the subspaces.
        * ``Tuple[int, ...]`` - Values used to seed the subspaces, first value seeds the OneOf and subsequent seed the subspaces. This allows the seeding of multiple composite subspaces ``[42, 54, ...]``.

        Args:
            seed: An optional int or tuple of ints to seed the OneOf space and subspaces. See above for more details.

        Returns:
            A tuple of ints used to seed the OneOf space and subspaces

        Raises:
            ValueError: if a tuple/list seed does not have ``len(self.spaces) + 1`` elements.
            TypeError: if ``seed`` is not ``None``, an int, or a tuple/list of ints.
        """
        if seed is None:
            super_seed = super().seed(None)
            return (super_seed,) + tuple(space.seed(None) for space in self.spaces)
        elif isinstance(seed, int):
            super_seed = super().seed(seed)
            # Derive one sub-seed per subspace from this space's freshly seeded PRNG.
            subseeds = self.np_random.integers(
                np.iinfo(np.int32).max, size=len(self.spaces)
            )
            # this is necessary such that after int or list/tuple seeding, the OneOf PRNG are equivalent
            super().seed(seed)
            return (super_seed,) + tuple(
                space.seed(int(subseed))
                for space, subseed in zip(self.spaces, subseeds)
            )
        elif isinstance(seed, (tuple, list)):
            if len(seed) != len(self.spaces) + 1:
                raise ValueError(
                    f"Expects that the subspaces of seeds equals the number of subspaces + 1. Actual length of seeds: {len(seed)}, length of subspaces: {len(self.spaces)}"
                )

            return (super().seed(seed[0]),) + tuple(
                space.seed(subseed) for space, subseed in zip(self.spaces, seed[1:])
            )
        else:
            raise TypeError(
                f"Expected None, int, or tuple of ints, actual type: {type(seed)}"
            )
+
+
+
+[docs]
+ def sample(
+ self,
+ mask: tuple[Any | None, ...] | None = None,
+ probability: tuple[Any | None, ...] | None = None,
+ ) -> tuple[int, Any]:
+ """Generates a single random sample inside this space.
+
+ This method draws independent samples from the subspaces.
+
+ Args:
+ mask: An optional tuple of optional masks for each of the subspace's samples,
+ expects the same number of masks as spaces
+ probability: An optional tuple of optional probability masks for each of the subspace's samples,
+ expects the same number of probability masks as spaces
+
+ Returns:
+ Tuple of the subspace's samples
+ """
+ subspace_idx = self.np_random.integers(0, len(self.spaces), dtype=np.int64)
+ subspace = self.spaces[subspace_idx]
+
+ if mask is not None and probability is not None:
+ raise ValueError(
+ f"Only one of `mask` or `probability` can be provided, actual values: mask={mask}, probability={probability}"
+ )
+ elif mask is not None:
+ assert isinstance(
+ mask, tuple
+ ), f"Expected type of `mask` is tuple, actual type: {type(mask)}"
+ assert len(mask) == len(
+ self.spaces
+ ), f"Expected length of `mask` is {len(self.spaces)}, actual length: {len(mask)}"
+
+ subspace_sample = subspace.sample(mask=mask[subspace_idx])
+
+ elif probability is not None:
+ assert isinstance(
+ probability, tuple
+ ), f"Expected type of `probability` is tuple, actual type: {type(probability)}"
+ assert len(probability) == len(
+ self.spaces
+ ), f"Expected length of `probability` is {len(self.spaces)}, actual length: {len(probability)}"
+
+ subspace_sample = subspace.sample(probability=probability[subspace_idx])
+ else:
+ subspace_sample = subspace.sample()
+
+ return subspace_idx, subspace_sample
+
+
+ def contains(self, x: tuple[int, Any]) -> bool:
+ """Return boolean specifying if x is a valid member of this space."""
+ # subspace_idx, subspace_value = x
+ return (
+ isinstance(x, tuple)
+ and len(x) == 2
+ and isinstance(x[0], (np.int64, int))
+ and 0 <= x[0] < len(self.spaces)
+ and self.spaces[x[0]].contains(x[1])
+ )
+
+ def __repr__(self) -> str:
+ """Gives a string representation of this space."""
+ return "OneOf(" + ", ".join([str(s) for s in self.spaces]) + ")"
+
+ def to_jsonable(
+ self, sample_n: typing.Sequence[tuple[int, Any]]
+ ) -> list[list[Any]]:
+ """Convert a batch of samples from this space to a JSONable data type."""
+ return [
+ [int(i), self.spaces[i].to_jsonable([subsample])[0]]
+ for (i, subsample) in sample_n
+ ]
+
+ def from_jsonable(self, sample_n: list[list[Any]]) -> list[tuple[Any, ...]]:
+ """Convert a JSONable data type to a batch of samples from this space."""
+ return [
+ (
+ np.int64(space_idx),
+ self.spaces[space_idx].from_jsonable([jsonable_sample])[0],
+ )
+ for space_idx, jsonable_sample in sample_n
+ ]
+
    def __getitem__(self, index: int) -> Space[Any]:
        """Get the subspace at specific `index`."""
        return self.spaces[index]

    def __len__(self) -> int:
        """Get the number of constituent subspaces of this ``OneOf`` space."""
        return len(self.spaces)

    def __eq__(self, other: Any) -> bool:
        """Check whether ``other`` is equivalent to this instance."""
        # Equality requires the same subspaces in the same order.
        return isinstance(other, OneOf) and self.spaces == other.spaces
+
+
+"""Implementation of a space that represents finite-length sequences."""
+
+from __future__ import annotations
+
+import typing
+from typing import Any, Union
+
+import numpy as np
+from numpy.typing import NDArray
+
+import gymnasium as gym
+from gymnasium.spaces.space import Space
+
+
+
+[docs]
+class Sequence(Space[Union[tuple[Any, ...], Any]]):
+ r"""This space represent sets of finite-length sequences.
+
+ This space represents the set of tuples of the form :math:`(a_0, \dots, a_n)` where the :math:`a_i` belong
+ to some space that is specified during initialization and the integer :math:`n` is not fixed
+
+ Example:
+ >>> from gymnasium.spaces import Sequence, Box
+ >>> observation_space = Sequence(Box(0, 1), seed=0)
+ >>> observation_space.sample()
+ (array([0.6822636], dtype=float32), array([0.18933342], dtype=float32), array([0.19049619], dtype=float32))
+ >>> observation_space.sample()
+ (array([0.83506], dtype=float32), array([0.9053838], dtype=float32), array([0.5836242], dtype=float32), array([0.63214064], dtype=float32))
+
+ Example with stacked observations
+ >>> observation_space = Sequence(Box(0, 1), stack=True, seed=0)
+ >>> observation_space.sample()
+ array([[0.6822636 ],
+ [0.18933342],
+ [0.19049619]], dtype=float32)
+ """
+
    def __init__(
        self,
        space: Space[Any],
        seed: int | np.random.Generator | None = None,
        stack: bool = False,
    ):
        """Constructor of the :class:`Sequence` space.

        Args:
            space: Elements in the sequences this space represent must belong to this space.
            seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.
            stack: If ``True`` then the resulting samples would be stacked.
        """
        assert isinstance(
            space, Space
        ), f"Expects the feature space to be instance of a gym Space, actual type: {type(space)}"
        self.feature_space = space
        self.stack = stack
        if self.stack:
            # A batched (size-1) version of the feature space, used to stack
            # samples and for `contains`/JSON conversion of stacked values.
            self.stacked_feature_space: Space = gym.vector.utils.batch_space(
                self.feature_space, 1
            )

        # None for shape and dtype, since it'll require special handling
        super().__init__(None, None, seed)
+
+
+[docs]
    def seed(self, seed: int | tuple[int, int] | None = None) -> tuple[int, int]:
        """Seed the PRNG of the Sequence space and the feature space.

        Depending on the type of seed, the subspaces will be seeded differently

        * ``None`` - All the subspaces will use a random initial seed
        * ``Int`` - The integer is used to seed the :class:`Sequence` space that is used to generate a seed value for the feature space.
        * ``Tuple of ints`` - A tuple for the :class:`Sequence` and feature space.

        Args:
            seed: An optional int or tuple of ints to seed the PRNG. See above for more details

        Returns:
            A tuple of the seeding values for the Sequence and feature space

        Raises:
            ValueError: if a tuple/list seed does not have exactly two elements.
            TypeError: if ``seed`` is not ``None``, an int, or a tuple/list of ints.
        """
        if seed is None:
            return super().seed(None), self.feature_space.seed(None)
        elif isinstance(seed, int):
            super_seed = super().seed(seed)
            # Derive the feature-space seed from this space's freshly seeded PRNG.
            feature_seed = int(self.np_random.integers(np.iinfo(np.int32).max))
            # this is necessary such that after int or list/tuple seeding, the Sequence PRNG are equivalent
            super().seed(seed)
            return super_seed, self.feature_space.seed(feature_seed)
        elif isinstance(seed, (tuple, list)):
            if len(seed) != 2:
                raise ValueError(
                    f"Expects the seed to have two elements for the Sequence and feature space, actual length: {len(seed)}"
                )
            return super().seed(seed[0]), self.feature_space.seed(seed[1])
        else:
            raise TypeError(
                f"Expected None, int, tuple of ints, actual type: {type(seed)}"
            )
+
+
    @property
    def is_np_flattenable(self) -> bool:
        """Checks whether this space can be flattened to a :class:`spaces.Box`."""
        # Sequences are variable-length, so they cannot map to a fixed-shape Box.
        return False
+
+
+[docs]
    def sample(
        self,
        mask: None | (
            tuple[
                None | int | NDArray[np.integer],
                Any,
            ]
        ) = None,
        probability: None | (
            tuple[
                None | int | NDArray[np.integer],
                Any,
            ]
        ) = None,
    ) -> tuple[Any] | Any:
        """Generates a single random sample from this space.

        Args:
            mask: An optional mask for (optionally) the length of the sequence and (optionally) the values in the sequence.
                If you specify ``mask``, it is expected to be a tuple of the form ``(length_mask, sample_mask)`` where ``length_mask`` is

                * ``None`` - The length will be randomly drawn from a geometric distribution
                * ``int`` - Fixed length
                * ``np.ndarray`` of integers - Length of the sampled sequence is randomly drawn from this array.

                The second element of the tuple ``sample_mask`` specifies how the feature space will be sampled.
                Depending on if mask or probability is used will affect what argument is used.
            probability: See mask description above, the only difference is on the ``sample_mask`` for the feature space being probability rather than mask.

        Returns:
            A tuple of random length with random samples of elements from the :attr:`feature_space`.

        Raises:
            ValueError: if both ``mask`` and ``probability`` are provided.
        """
        if mask is not None and probability is not None:
            raise ValueError(
                f"Only one of `mask` or `probability` can be provided, actual values: mask={mask}, probability={probability}"
            )
        elif mask is not None:
            sample_length = self.generate_sample_length(mask[0], "mask")
            sampled_values = tuple(
                self.feature_space.sample(mask=mask[1]) for _ in range(sample_length)
            )
        elif probability is not None:
            sample_length = self.generate_sample_length(probability[0], "probability")
            sampled_values = tuple(
                self.feature_space.sample(probability=probability[1])
                for _ in range(sample_length)
            )
        else:
            # Unmasked length follows a geometric distribution (p=0.25, arbitrary).
            sample_length = self.np_random.geometric(0.25)
            sampled_values = tuple(
                self.feature_space.sample() for _ in range(sample_length)
            )

        if self.stack:
            # Concatenate values if stacked.
            out = gym.vector.utils.create_empty_array(
                self.feature_space, len(sampled_values)
            )
            return gym.vector.utils.concatenate(self.feature_space, sampled_values, out)

        return sampled_values
+
+
+ def generate_sample_length(
+ self,
+ length_mask: None | np.integer | NDArray[np.integer],
+ mask_type: None | str,
+ ) -> int:
+ """Generate the sample length for a given length mask and mask type."""
+ if length_mask is not None:
+ if np.issubdtype(type(length_mask), np.integer):
+ assert (
+ 0 <= length_mask
+ ), f"Expects the length mask of `{mask_type}` to be greater than or equal to zero, actual value: {length_mask}"
+
+ return length_mask
+ elif isinstance(length_mask, np.ndarray):
+ assert (
+ len(length_mask.shape) == 1
+ ), f"Expects the shape of the length mask of `{mask_type}` to be 1-dimensional, actual shape: {length_mask.shape}"
+ assert np.all(
+ 0 <= length_mask
+ ), f"Expects all values in the length_mask of `{mask_type}` to be greater than or equal to zero, actual values: {length_mask}"
+ assert np.issubdtype(
+ length_mask.dtype, np.integer
+ ), f"Expects the length mask array of `{mask_type}` to have dtype of np.integer, actual type: {length_mask.dtype}"
+
+ return self.np_random.choice(length_mask)
+ else:
+ raise TypeError(
+ f"Expects the type of length_mask of `{mask_type}` to be an integer or a np.ndarray, actual type: {type(length_mask)}"
+ )
+ else:
+ # The choice of 0.25 is arbitrary
+ return self.np_random.geometric(0.25)
+
+ def contains(self, x: Any) -> bool:
+ """Return boolean specifying if x is a valid member of this space."""
+ # by definition, any sequence is an iterable
+ if self.stack:
+ return all(
+ item in self.feature_space
+ for item in gym.vector.utils.iterate(self.stacked_feature_space, x)
+ )
+ else:
+ return isinstance(x, tuple) and all(
+ self.feature_space.contains(item) for item in x
+ )
+
+ def __repr__(self) -> str:
+ """Gives a string representation of this space."""
+ return f"Sequence({self.feature_space}, stack={self.stack})"
+
+ def to_jsonable(
+ self, sample_n: typing.Sequence[tuple[Any, ...] | Any]
+ ) -> list[list[Any]]:
+ """Convert a batch of samples from this space to a JSONable data type."""
+ if self.stack:
+ return self.stacked_feature_space.to_jsonable(sample_n)
+ else:
+ return [self.feature_space.to_jsonable(sample) for sample in sample_n]
+
+ def from_jsonable(self, sample_n: list[list[Any]]) -> list[tuple[Any, ...] | Any]:
+ """Convert a JSONable data type to a batch of samples from this space."""
+ if self.stack:
+ return self.stacked_feature_space.from_jsonable(sample_n)
+ else:
+ return [
+ tuple(self.feature_space.from_jsonable(sample)) for sample in sample_n
+ ]
+
+ def __eq__(self, other: Any) -> bool:
+ """Check whether ``other`` is equivalent to this instance."""
+ return (
+ isinstance(other, Sequence)
+ and self.feature_space == other.feature_space
+ and self.stack == other.stack
+ )
+
+
+"""Implementation of the `Space` metaclass."""
+
+from __future__ import annotations
+
+from collections.abc import Iterable, Mapping, Sequence
+from typing import Any, Generic, TypeAlias, TypeVar
+
+import numpy as np
+import numpy.typing as npt
+
+from gymnasium.utils import seeding
+
+
+T_cov = TypeVar("T_cov", covariant=True)
+
+
+MaskNDArray: TypeAlias = npt.NDArray[np.int8]
+
+
+
+[docs]
+class Space(Generic[T_cov]):
+ """Superclass that is used to define observation and action spaces.
+
+ Spaces are crucially used in Gym to define the format of valid actions and observations.
+ They serve various purposes:
+
+ * They clearly define how to interact with environments, i.e. they specify what actions need to look like
+ and what observations will look like
+ * They allow us to work with highly structured data (e.g. in the form of elements of :class:`Dict` spaces)
+ and painlessly transform them into flat arrays that can be used in learning code
+ * They provide a method to sample random elements. This is especially useful for exploration and debugging.
+
+ Different spaces can be combined hierarchically via container spaces (:class:`Tuple` and :class:`Dict`) to build a
+ more expressive space
+
+ Warning:
+ Custom observation & action spaces can inherit from the ``Space``
+ class. However, most use-cases should be covered by the existing space
+ classes (e.g. :class:`Box`, :class:`Discrete`, etc...), and container classes (:class:`Tuple` &
+ :class:`Dict`). Note that parametrized probability distributions (through the
+ :meth:`Space.sample()` method), and batching functions (in :class:`gym.vector.VectorEnv`), are
+ only well-defined for instances of spaces provided in gym by default.
+ Moreover, some implementations of Reinforcement Learning algorithms might
+ not handle custom spaces properly. Use custom spaces with care.
+ """
+
    def __init__(
        self,
        shape: Sequence[int] | None = None,
        dtype: npt.DTypeLike | None = None,
        seed: int | np.random.Generator | None = None,
    ):
        """Constructor of :class:`Space`.

        Args:
            shape (Optional[Sequence[int]]): If elements of the space are numpy arrays, this should specify their shape.
            dtype (Optional[Type | str]): If elements of the space are numpy arrays, this should specify their dtype.
            seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space
        """
        self._shape = None if shape is None else tuple(shape)
        self.dtype = None if dtype is None else np.dtype(dtype)
        # The PRNG is created lazily (see `np_random`) unless a seed/generator is given.
        self._np_random = None
        if seed is not None:
            if isinstance(seed, np.random.Generator):
                # Reuse the caller's generator directly rather than reseeding.
                self._np_random = seed
            else:
                self.seed(seed)
+
    @property
    def np_random(self) -> np.random.Generator:
        """Lazily seed the PRNG since this is expensive and only needed if sampling from this space.

        As :meth:`seed` is not guaranteed to set the `_np_random` for particular seeds. We add a
        check after :meth:`seed` to set a new random number generator.
        """
        if self._np_random is None:
            self.seed()

        # As `seed` is not guaranteed (in particular for composite spaces) to set the `_np_random` then we set it randomly.
        if self._np_random is None:
            self._np_random, _ = seeding.np_random()

        return self._np_random
+
    @property
    def shape(self) -> tuple[int, ...] | None:
        """Return the shape of the space as an immutable property."""
        return self._shape

    @property
    def is_np_flattenable(self) -> bool:
        """Checks whether this space can be flattened to a :class:`gymnasium.spaces.Box`."""
        # Abstract by convention: concrete space classes must override this property.
        raise NotImplementedError
+
+
+[docs]
    def sample(self, mask: Any | None = None, probability: Any | None = None) -> T_cov:
        """Randomly sample an element of this space.

        Can be uniform or non-uniform sampling based on boundedness of space.

        The binary mask and the probability mask can't be used at the same time.

        Args:
            mask: A mask used for random sampling, expected ``dtype=np.int8`` and see sample implementation for expected shape.
            probability: A probability mask used for sampling according to the given probability distribution, expected ``dtype=np.float64`` and see sample implementation for expected shape.

        Returns:
            A sampled actions from the space

        Raises:
            NotImplementedError: always; concrete space classes must override this method.
        """
        raise NotImplementedError
+
+
+
+[docs]
    def seed(self, seed: int | None = None) -> int | list[int] | dict[str, int]:
        """Seed the pseudorandom number generator (PRNG) of this space and, if applicable, the PRNGs of subspaces.

        Args:
            seed: The seed value for the space. This is expanded for composite spaces to accept multiple values. For further details, please refer to the space's documentation.

        Returns:
            The seed values used for all the PRNGs, for composite spaces this can be a tuple or dictionary of values.
        """
        # seeding.np_random returns (generator, actual_seed); the actual seed may
        # differ from the argument when `seed` is None.
        self._np_random, np_random_seed = seeding.np_random(seed)
        return np_random_seed
+
+
+
+[docs]
    def contains(self, x: Any) -> bool:
        """Return boolean specifying if x is a valid member of this space, equivalent to ``sample in space``.

        Abstract: concrete subclasses must override this method.
        """
        raise NotImplementedError
+
+
    def __contains__(self, x: Any) -> bool:
        """Return boolean specifying if x is a valid member of this space."""
        # Delegates to the subclass-defined membership test so `x in space` works.
        return self.contains(x)
+
+ def __setstate__(self, state: Iterable[tuple[str, Any]] | Mapping[str, Any]):
+ """Used when loading a pickled space.
+
+ This method was implemented explicitly to allow for loading of legacy states.
+
+ Args:
+ state: The updated state value
+ """
+ # Don't mutate the original state
+ state = dict(state)
+
+ # Allow for loading of legacy states.
+ # See:
+ # https://github.com/openai/gym/pull/2397 -- shape
+ # https://github.com/openai/gym/pull/1913 -- np_random
+ #
+ if "shape" in state:
+ state["_shape"] = state.get("shape")
+ del state["shape"]
+ if "np_random" in state:
+ state["_np_random"] = state["np_random"]
+ del state["np_random"]
+
+ # Update our state
+ self.__dict__.update(state)
+
+
+[docs]
+ def to_jsonable(self, sample_n: Sequence[T_cov]) -> list[Any]:
+ """Convert a batch of samples from this space to a JSONable data type."""
+ # By default, assume identity is JSONable
+ return list(sample_n)
+
+
+
+[docs]
    def from_jsonable(self, sample_n: list[Any]) -> list[T_cov]:
        """Convert a JSONable data type to a batch of samples from this space."""
        # By default, assume identity is JSONable: the list is returned unchanged.
        return sample_n
+
+
+
+"""Implementation of a space that represents textual strings."""
+
+from __future__ import annotations
+
+from typing import Any
+
+import numpy as np
+from numpy.typing import NDArray
+
+from gymnasium.spaces.space import Space
+
+
# Default charset for Text spaces: ASCII letters (both cases) plus decimal digits.
alphanumeric: frozenset[str] = frozenset(
    "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
)
+
+
+
+[docs]
class Text(Space[str]):
    r"""A space representing a string comprised of characters from a given charset.

    Example:
        >>> from gymnasium.spaces import Text
        >>> # {"", "B5", "hello", ...}
        >>> Text(5)
        Text(1, 5, charset=0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz)
        >>> # {"0", "42", "0123456789", ...}
        >>> import string
        >>> Text(min_length = 1,
        ...     max_length = 10,
        ...     charset = string.digits)
        Text(1, 10, charset=0123456789)
    """

    def __init__(
        self,
        max_length: int,
        *,
        min_length: int = 1,
        charset: frozenset[str] | str = alphanumeric,
        seed: int | np.random.Generator | None = None,
    ):
        r"""Constructor of :class:`Text` space.

        Both bounds for text length are inclusive.

        Args:
            min_length (int): Minimum text length (in characters). Defaults to 1 to prevent empty strings.
            max_length (int): Maximum text length (in characters).
            charset (Union[set], str): Character set, defaults to the lower and upper english alphabet plus latin digits.
            seed: The seed for sampling from the space.
        """
        assert np.issubdtype(
            type(min_length), np.integer
        ), f"Expects the min_length to be an integer, actual type: {type(min_length)}"
        assert np.issubdtype(
            type(max_length), np.integer
        ), f"Expects the max_length to be an integer, actual type: {type(max_length)}"
        assert (
            0 <= min_length
        ), f"Minimum text length must be non-negative, actual value: {min_length}"
        assert (
            min_length <= max_length
        ), f"The min_length must be less than or equal to the max_length, min_length: {min_length}, max_length: {max_length}"

        self.min_length: int = int(min_length)
        self.max_length: int = int(max_length)

        # Cache several views of the charset: a set for membership tests, a tuple for
        # ordered sampling, a char->index map for flattening, and a sorted string for repr.
        self._char_set: frozenset[str] = frozenset(charset)
        self._char_list: tuple[str, ...] = tuple(charset)
        self._char_index: dict[str, np.int32] = {
            val: np.int32(i) for i, val in enumerate(tuple(charset))
        }
        self._char_str: str = "".join(sorted(tuple(charset)))

        # As the shape is dynamic (between min_length and max_length) then None
        super().__init__(dtype=str, seed=seed)

    def sample(
        self,
        mask: None | (tuple[int | None, NDArray[np.int8] | None]) = None,
        probability: None | (tuple[int | None, NDArray[np.float64] | None]) = None,
    ) -> str:
        """Generates a single random sample from this space with by default a random length between ``min_length`` and ``max_length`` and sampled from the ``charset``.

        Args:
            mask: An optional tuples of length and mask for the text.
                The length is expected to be between the ``min_length`` and ``max_length``.
                Otherwise, a random integer between ``min_length`` and ``max_length`` is selected.
                For the mask, we expect a numpy array of length of the charset passed with ``dtype == np.int8``.
                If the charlist mask is all zero then an empty string is returned no matter the ``min_length``
            probability: An optional tuples of length and probability mask for the text.
                The length is expected to be between the ``min_length`` and ``max_length``.
                Otherwise, a random integer between ``min_length`` and ``max_length`` is selected.
                For the probability mask, we expect a numpy array of length of the charset passed with ``dtype == np.float64``.
                The sum of the probability mask should be 1, otherwise an exception is raised.

        Returns:
            A sampled string from the space

        Raises:
            ValueError: if both ``mask`` and ``probability`` are given, or if the mask
                excludes every character while ``min_length > 0``.
        """
        if mask is not None and probability is not None:
            raise ValueError(
                f"Only one of `mask` or `probability` can be provided, actual values: mask={mask}, probability={probability}"
            )
        elif mask is not None:
            length, charlist_mask = self._validate_mask(mask, np.int8, "mask")

            if charlist_mask is not None:
                assert np.all(
                    np.logical_or(charlist_mask == 0, charlist_mask == 1)
                ), f"Expects all mask values to 0 or 1, actual values: {charlist_mask}"

                # normalise the mask to use as a probability
                if np.sum(charlist_mask) > 0:
                    charlist_mask = charlist_mask / np.sum(charlist_mask)
        elif probability is not None:
            length, charlist_mask = self._validate_mask(
                probability, np.float64, "probability"
            )

            if charlist_mask is not None:
                assert np.all(
                    np.logical_and(charlist_mask >= 0, charlist_mask <= 1)
                ), f"Expects all probability mask values to be within 0 and 1, actual values: {charlist_mask}"
                assert np.isclose(
                    np.sum(charlist_mask), 1
                ), f"Expects the sum of the probability mask to be 1, actual sum: {np.sum(charlist_mask)}"
        else:
            length = charlist_mask = None

        # Fill in whatever the caller left unspecified: a random length and/or a
        # uniform distribution over the charset.
        if length is None:
            length = self.np_random.integers(self.min_length, self.max_length + 1)
        if charlist_mask is None:  # uniform sampling
            charlist_mask = np.ones(len(self.character_set)) / len(self.character_set)

        # An all-zero mask (not normalised above) means no character may be drawn.
        if np.all(charlist_mask == 0):
            if self.min_length == 0:
                return ""
            else:
                # Otherwise the string will not be contained in the space
                raise ValueError(
                    f"Trying to sample with a minimum length > 0 (actual minimum length={self.min_length}) but the character mask is all zero meaning that no character could be sampled."
                )

        string = self.np_random.choice(
            self.character_list, size=length, p=charlist_mask
        )
        return "".join(string)

    def _validate_mask(
        self,
        mask: tuple[int | None, NDArray[np.int8] | NDArray[np.float64] | None],
        expected_dtype: np.dtype,
        mask_type: str,
    ) -> tuple[int | None, NDArray[np.int8] | NDArray[np.float64] | None]:
        """Check a ``(length, charlist_mask)`` tuple and return its two components.

        Args:
            mask: The ``(length, per-character mask)`` tuple supplied to :meth:`sample`.
            expected_dtype: dtype the mask array must have (``np.int8`` or ``np.float64``).
            mask_type: Name used in assertion messages ("mask" or "probability").
        """
        assert isinstance(
            mask, tuple
        ), f"Expects the `{mask_type}` type to be a tuple, actual type: {type(mask)}"
        assert (
            len(mask) == 2
        ), f"Expects the `{mask_type}` length to be two, actual length: {len(mask)}"
        length, charlist_mask = mask

        if length is not None:
            assert np.issubdtype(
                type(length), np.integer
            ), f"Expects the Text sample length to be an integer, actual type: {type(length)}"
            assert (
                self.min_length <= length <= self.max_length
            ), f"Expects the Text sample length be between {self.min_length} and {self.max_length}, actual length: {length}"
        if charlist_mask is not None:
            assert isinstance(
                charlist_mask, np.ndarray
            ), f"Expects the Text sample `{mask_type}` to be an np.ndarray, actual type: {type(charlist_mask)}"
            assert (
                charlist_mask.dtype == expected_dtype
            ), f"Expects the Text sample `{mask_type}` to be type {expected_dtype}, actual dtype: {charlist_mask.dtype}"
            assert charlist_mask.shape == (
                len(self.character_set),
            ), f"expects the Text sample `{mask_type}` to be {(len(self.character_set),)}, actual shape: {charlist_mask.shape}"

        return length, charlist_mask

    def contains(self, x: Any) -> bool:
        """Return boolean specifying if x is a valid member of this space."""
        if isinstance(x, str):
            if self.min_length <= len(x) <= self.max_length:
                return all(c in self.character_set for c in x)
        return False

    def __repr__(self) -> str:
        """Gives a string representation of this space."""
        return f"Text({self.min_length}, {self.max_length}, charset={self.characters})"

    def __eq__(self, other: Any) -> bool:
        """Check whether ``other`` is equivalent to this instance."""
        # Equality ignores charset ordering: only the set of characters matters.
        return (
            isinstance(other, Text)
            and self.min_length == other.min_length
            and self.max_length == other.max_length
            and self.character_set == other.character_set
        )

    @property
    def character_set(self) -> frozenset[str]:
        """Returns the character set for the space."""
        return self._char_set

    @property
    def character_list(self) -> tuple[str, ...]:
        """Returns a tuple of characters in the space."""
        return self._char_list

    def character_index(self, char: str) -> np.int32:
        """Returns a unique index for each character in the space's character set."""
        return self._char_index[char]

    @property
    def characters(self) -> str:
        """Returns a string with all Text characters."""
        return self._char_str

    @property
    def is_np_flattenable(self) -> bool:
        """The flattened version is an integer array for each character, padded to the max character length."""
        return True
+
+
+"""Implementation of a space that represents the cartesian product of other spaces."""
+
+from __future__ import annotations
+
+import typing
+from collections.abc import Iterable
+from typing import Any
+
+import numpy as np
+
+from gymnasium.spaces.space import Space
+
+
+
+[docs]
class Tuple(Space[tuple[Any, ...]], typing.Sequence[Any]):
    """A tuple (more precisely: the cartesian product) of :class:`Space` instances.

    Elements of this space are tuples of elements of the constituent spaces.

    Example:
        >>> from gymnasium.spaces import Tuple, Box, Discrete
        >>> observation_space = Tuple((Discrete(2), Box(-1, 1, shape=(2,))), seed=42)
        >>> observation_space.sample()
        (np.int64(0), array([-0.3991573 ,  0.21649833], dtype=float32))
    """

    def __init__(
        self,
        spaces: Iterable[Space[Any]],
        seed: int | typing.Sequence[int] | np.random.Generator | None = None,
    ):
        r"""Constructor of :class:`Tuple` space.

        The generated instance will represent the cartesian product :math:`\text{spaces}[0] \times ... \times \text{spaces}[-1]`.

        Args:
            spaces (Iterable[Space]): The spaces that are involved in the cartesian product.
            seed: Optionally, you can use this argument to seed the RNGs of the ``spaces`` to ensure reproducible sampling.
        """
        self.spaces = tuple(spaces)
        # Every constituent must itself be a Space.
        for subspace in self.spaces:
            assert isinstance(
                subspace, Space
            ), f"{subspace} does not inherit from `gymnasium.Space`. Actual Type: {type(subspace)}"
        super().__init__(None, None, seed)  # type: ignore

    @property
    def is_np_flattenable(self):
        """Checks whether this space can be flattened to a :class:`spaces.Box`."""
        # Flattenable only when every subspace is.
        return all(subspace.is_np_flattenable for subspace in self.spaces)

    def seed(self, seed: int | typing.Sequence[int] | None = None) -> tuple[int, ...]:
        """Seed the PRNG of this space and all subspaces.

        Depending on the type of seed, the subspaces will be seeded differently

        * ``None`` - All the subspaces will use a random initial seed
        * ``Int`` - The integer is used to seed the :class:`Tuple` space that is used to generate seed values for each of the subspaces. Warning, this does not guarantee unique seeds for all the subspaces.
        * ``List`` / ``Tuple`` - Values used to seed the subspaces. This allows the seeding of multiple composite subspaces ``[42, 54, ...]``.

        Args:
            seed: An optional list of ints or int to seed the (sub-)spaces.

        Returns:
            A tuple of the seed values for all subspaces
        """
        if seed is None:
            # Each subspace seeds itself randomly.
            return tuple(subspace.seed(None) for subspace in self.spaces)
        if isinstance(seed, int):
            # Seed this space's own PRNG, then derive one sub-seed per subspace from it.
            super().seed(seed)
            subseeds = self.np_random.integers(
                np.iinfo(np.int32).max, size=len(self.spaces)
            )
            return tuple(
                subspace.seed(int(subseed))
                for subspace, subseed in zip(self.spaces, subseeds)
            )
        if isinstance(seed, (tuple, list)):
            # Explicit per-subspace seeds; must match the number of subspaces.
            if len(seed) != len(self.spaces):
                raise ValueError(
                    f"Expects that the subspaces of seeds equals the number of subspaces. Actual length of seeds: {len(seed)}, length of subspaces: {len(self.spaces)}"
                )

            return tuple(
                subspace.seed(subseed)
                for subseed, subspace in zip(seed, self.spaces)
            )
        raise TypeError(
            f"Expected seed type: list, tuple, int or None, actual type: {type(seed)}"
        )

    def sample(
        self,
        mask: tuple[Any | None, ...] | None = None,
        probability: tuple[Any | None, ...] | None = None,
    ) -> tuple[Any, ...]:
        """Generates a single random sample inside this space.

        This method draws independent samples from the subspaces.

        Args:
            mask: An optional tuple of optional masks for each of the subspace's samples,
                expects the same number of masks as spaces
            probability: An optional tuple of optional probability masks for each of the subspace's samples,
                expects the same number of probability masks as spaces

        Returns:
            Tuple of the subspace's samples
        """
        if mask is not None and probability is not None:
            raise ValueError(
                f"Only one of `mask` or `probability` can be provided, actual values: mask={mask}, probability={probability}"
            )
        if mask is not None:
            assert isinstance(
                mask, tuple
            ), f"Expected type of `mask` to be tuple, actual type: {type(mask)}"
            assert len(mask) == len(
                self.spaces
            ), f"Expected length of `mask` to be {len(self.spaces)}, actual length: {len(mask)}"

            return tuple(
                subspace.sample(mask=submask)
                for subspace, submask in zip(self.spaces, mask)
            )
        if probability is not None:
            assert isinstance(
                probability, tuple
            ), f"Expected type of `probability` to be tuple, actual type: {type(probability)}"
            assert len(probability) == len(
                self.spaces
            ), f"Expected length of `probability` to be {len(self.spaces)}, actual length: {len(probability)}"

            return tuple(
                subspace.sample(probability=subprobability)
                for subspace, subprobability in zip(self.spaces, probability)
            )
        # No mask of any kind: sample every subspace independently.
        return tuple(subspace.sample() for subspace in self.spaces)

    def contains(self, x: Any) -> bool:
        """Return boolean specifying if x is a valid member of this space."""
        if isinstance(x, (list, np.ndarray)):
            x = tuple(x)  # Promote list and ndarray to tuple for contains check

        if not isinstance(x, tuple) or len(x) != len(self.spaces):
            return False
        return all(subspace.contains(part) for subspace, part in zip(self.spaces, x))

    def __repr__(self) -> str:
        """Gives a string representation of this space."""
        return f"Tuple({', '.join(str(subspace) for subspace in self.spaces)})"

    def to_jsonable(
        self, sample_n: typing.Sequence[tuple[Any, ...]]
    ) -> list[list[Any]]:
        """Convert a batch of samples from this space to a JSONable data type."""
        # serialize as list-repr of tuple of vectors
        return [
            subspace.to_jsonable([sample[i] for sample in sample_n])
            for i, subspace in enumerate(self.spaces)
        ]

    def from_jsonable(self, sample_n: list[list[Any]]) -> list[tuple[Any, ...]]:
        """Convert a JSONable data type to a batch of samples from this space."""
        # Decode each subspace's batch, then transpose back to per-sample tuples.
        per_space = [
            subspace.from_jsonable(sample_n[i])
            for i, subspace in enumerate(self.spaces)
        ]
        return list(zip(*per_space))

    def __getitem__(self, index: int) -> Space[Any]:
        """Get the subspace at specific `index`."""
        return self.spaces[index]

    def __len__(self) -> int:
        """Get the number of subspaces that are involved in the cartesian product."""
        return len(self.spaces)

    def __eq__(self, other: Any) -> bool:
        """Check whether ``other`` is equivalent to this instance."""
        return isinstance(other, Tuple) and self.spaces == other.spaces
+
+
+"""Implementation of utility functions that can be applied to spaces.
+
+These functions mostly take care of flattening and unflattening elements of spaces
+ to facilitate their usage in learning code.
+"""
+
+from __future__ import annotations
+
+import operator as op
+from functools import reduce, singledispatch
+from typing import Any, TypeVar, Union
+
+import numpy as np
+from numpy.typing import NDArray
+
+import gymnasium as gym
+from gymnasium.spaces import (
+ Box,
+ Dict,
+ Discrete,
+ Graph,
+ GraphInstance,
+ MultiBinary,
+ MultiDiscrete,
+ OneOf,
+ Sequence,
+ Space,
+ Text,
+ Tuple,
+)
+
+
+
+[docs]
@singledispatch
def flatdim(space: Space[Any]) -> int:
    """Return the number of dimensions a flattened equivalent of this space would have.

    Args:
        space: The space to return the number of dimensions of the flattened spaces

    Returns:
        The number of dimensions for the flattened spaces

    Raises:
        NotImplementedError: if the space is not defined in :mod:`gym.spaces`.
        ValueError: if the space cannot be flattened into a :class:`gymnasium.spaces.Box`

    Example:
        >>> from gymnasium.spaces import Dict, Discrete
        >>> space = Dict({"position": Discrete(2), "velocity": Discrete(3)})
        >>> flatdim(space)
        5
    """
    # Fallback for unregistered space types: distinguish "has no fixed flat size"
    # from "flattening is simply not implemented for this type".
    if space.is_np_flattenable is False:
        raise ValueError(
            f"{space} cannot be flattened to a numpy array, probably because it contains a `Graph` or `Sequence` subspace"
        )

    raise NotImplementedError(f"Unknown space: `{space}`")
+
+
+
@flatdim.register(Box)
@flatdim.register(MultiBinary)
def _flatdim_box_multibinary(space: Box | MultiBinary) -> int:
    # Product of all axis lengths; an empty shape () yields 1 (a scalar).
    size = 1
    for dim in space.shape:
        size *= dim
    return size
+
+
@flatdim.register(Discrete)
def _flatdim_discrete(space: Discrete) -> int:
    # A Discrete sample flattens to a one-hot vector with `n` entries.
    return int(space.n)
+
+
@flatdim.register(MultiDiscrete)
def _flatdim_multidiscrete(space: MultiDiscrete) -> int:
    # Concatenated one-hot encodings: one block of size nvec[i] per sub-dimension.
    return int(np.sum(space.nvec))
+
+
@flatdim.register(Tuple)
def _flatdim_tuple(space: Tuple) -> int:
    # A flattenable tuple is the concatenation of its flattened subspaces.
    if not space.is_np_flattenable:
        raise ValueError(
            f"{space} cannot be flattened to a numpy array, probably because it contains a `Graph` or `Sequence` subspace"
        )
    return sum(flatdim(subspace) for subspace in space.spaces)
+
+
@flatdim.register(Dict)
def _flatdim_dict(space: Dict) -> int:
    # A flattenable dict is the concatenation of its flattened value subspaces.
    if not space.is_np_flattenable:
        raise ValueError(
            f"{space} cannot be flattened to a numpy array, probably because it contains a `Graph` or `Sequence` subspace"
        )
    return sum(flatdim(subspace) for subspace in space.spaces.values())
+
+
@flatdim.register(Graph)
def _flatdim_graph(space: Graph):
    # Graphs have a variable number of nodes/edges, so no fixed flat size exists.
    raise ValueError(
        "Cannot get flattened size as the Graph Space in Gym has a dynamic size."
    )
+
+
@flatdim.register(Text)
def _flatdim_text(space: Text) -> int:
    # Text flattens to one integer per character, padded to the maximum length.
    return space.max_length
+
+
@flatdim.register(OneOf)
def _flatdim_oneof(space: OneOf) -> int:
    # One slot for the chosen-subspace index plus room for the largest subspace.
    subspace_dims = [flatdim(subspace) for subspace in space.spaces]
    return 1 + max(subspace_dims)
+
+
# Generic sample type of a space, and the union of all possible flattened sample forms.
T = TypeVar("T")
FlatType = Union[NDArray[Any], dict[str, Any], tuple[Any, ...], GraphInstance]
+
+
+
+[docs]
@singledispatch
def flatten(space: Space[T], x: T) -> FlatType:
    """Flatten a data point from a space.

    This is useful when e.g. points from spaces must be passed to a neural
    network, which only understands flat arrays of floats.

    Args:
        space: The space that ``x`` is flattened by
        x: The value to flatten

    Returns:
        The flattened datapoint

        - For :class:`gymnasium.spaces.Box` and :class:`gymnasium.spaces.MultiBinary`, this is a flattened array
        - For :class:`gymnasium.spaces.Discrete` and :class:`gymnasium.spaces.MultiDiscrete`, this is a flattened one-hot array of the sample
        - For :class:`gymnasium.spaces.Tuple` and :class:`gymnasium.spaces.Dict`, this is a concatenated array the subspaces (does not support graph subspaces)
        - For graph spaces, returns :class:`GraphInstance` where:
            - :attr:`GraphInstance.nodes` are n x k arrays
            - :attr:`GraphInstance.edges` are either:
                - m x k arrays
                - None
            - :attr:`GraphInstance.edge_links` are either:
                - m x 2 arrays
                - None

    Raises:
        NotImplementedError: If the space is not defined in :mod:`gymnasium.spaces`.

    Example:
        >>> from gymnasium.spaces import Box, Discrete, Tuple
        >>> space = Box(0, 1, shape=(3, 5))
        >>> flatten(space, space.sample()).shape
        (15,)
        >>> space = Discrete(4)
        >>> flatten(space, 2)
        array([0, 0, 1, 0])
        >>> space = Tuple((Box(0, 1, shape=(2,)), Box(0, 1, shape=(3,)), Discrete(3)))
        >>> example = ((.5, .25), (1., 0., .2), 1)
        >>> flatten(space, example)
        array([0.5 , 0.25, 1.  , 0.  , 0.2 , 0.  , 1.  , 0.  ])
    """
    # Fallback for unregistered space types.
    raise NotImplementedError(f"Unknown space: `{space}`")
+
+
+
@flatten.register(Box)
@flatten.register(MultiBinary)
def _flatten_box_multibinary(space: Box | MultiBinary, x: NDArray[Any]) -> NDArray[Any]:
    # Copy into the space's dtype and collapse to a single dimension.
    return np.array(x, dtype=space.dtype).reshape(-1)
+
+
@flatten.register(Discrete)
def _flatten_discrete(space: Discrete, x: np.int64) -> NDArray[np.int64]:
    # One-hot encode the sample relative to the space's start value.
    encoded = np.zeros(space.n, dtype=space.dtype)
    encoded[x - space.start] = 1
    return encoded
+
+
@flatten.register(MultiDiscrete)
def _flatten_multidiscrete(
    space: MultiDiscrete, x: NDArray[np.int64]
) -> NDArray[np.int64]:
    # Cumulative start offset of each sub-dimension's one-hot block.
    sizes = space.nvec.flatten()
    block_starts = np.zeros((sizes.size + 1,), dtype=np.int32)
    block_starts[1:] = np.cumsum(sizes)

    # Set one bit per sub-dimension: block start + (value - start).
    encoded = np.zeros((block_starts[-1],), dtype=space.dtype)
    encoded[block_starts[:-1] + (x - space.start).flatten()] = 1
    return encoded
+
+
@flatten.register(Tuple)
def _flatten_tuple(space: Tuple, x: tuple[Any, ...]) -> tuple[Any, ...] | NDArray[Any]:
    # Non-flattenable tuples keep their per-subspace structure.
    if not space.is_np_flattenable:
        return tuple(
            flatten(subspace, part) for part, subspace in zip(x, space.spaces)
        )
    parts = [
        np.array(flatten(subspace, part)) for part, subspace in zip(x, space.spaces)
    ]
    return np.concatenate(parts)
+
+
@flatten.register(Dict)
def _flatten_dict(space: Dict, x: dict[str, Any]) -> dict[str, Any] | NDArray[Any]:
    # Non-flattenable dicts keep their per-key structure.
    if not space.is_np_flattenable:
        return {key: flatten(subspace, x[key]) for key, subspace in space.spaces.items()}
    parts = [
        np.array(flatten(subspace, x[key])) for key, subspace in space.spaces.items()
    ]
    return np.concatenate(parts)
+
+
@flatten.register(Graph)
def _flatten_graph(space: Graph, x: GraphInstance) -> GraphInstance:
    """Flatten the node and edge feature matrices of a graph instance.

    We're not using ``flatten()`` on :class:`Box` and :class:`Discrete` subspaces
    because a graph is not a homogeneous space, see :func:`flatten` docstring.
    The number of nodes/edges is dynamic, so each feature row is flattened (Box)
    or one-hot encoded (Discrete) independently.
    """

    # NOTE: the helper was previously named `_graph_unflatten` with a docstring
    # referencing `.unflatten()` — both copied from the unflatten counterpart;
    # it performs *flattening*, hence the corrected name.
    def _graph_flatten(
        feature_space: Discrete | Box | None,
        feature_x: NDArray[Any] | None,
    ) -> NDArray[Any] | None:
        """Flatten one feature matrix (nodes or edges); ``None`` when absent."""
        ret = None
        if feature_space is not None and feature_x is not None:
            if isinstance(feature_space, Box):
                # Collapse each row's features into a single vector (rows preserved).
                ret = feature_x.reshape(feature_x.shape[0], -1)
            else:
                assert isinstance(feature_space, Discrete)
                # One-hot encode each row relative to the space's start value.
                ret = np.zeros(
                    (feature_x.shape[0], feature_space.n - feature_space.start),
                    dtype=feature_space.dtype,
                )
                ret[
                    np.arange(feature_x.shape[0]), feature_x - feature_space.start
                ] = 1
        return ret

    nodes = _graph_flatten(space.node_space, x.nodes)
    assert nodes is not None
    edges = _graph_flatten(space.edge_space, x.edges)

    # edge_links are node indices, not features — passed through unchanged.
    return GraphInstance(nodes, edges, x.edge_links)
+
+
@flatten.register(Text)
def _flatten_text(space: Text, x: str) -> NDArray[np.int32]:
    # Pre-fill with len(charset), an out-of-range code that marks padding slots.
    encoded = np.full(
        shape=(space.max_length,), fill_value=len(space.character_set), dtype=np.int32
    )
    for position, char in enumerate(x):
        encoded[position] = space.character_index(char)
    return encoded
+
+
@flatten.register(Sequence)
def _flatten_sequence(
    space: Sequence, x: tuple[Any, ...] | Any
) -> tuple[Any, ...] | Any:
    # Unstacked sequences flatten element-wise into a tuple.
    if not space.stack:
        return tuple(flatten(space.feature_space, item) for item in x)

    # Stacked sequences: flatten each element, then re-stack into one batch.
    elements = gym.vector.utils.iterate(space.stacked_feature_space, x)
    flat_elements = [flatten(space.feature_space, element) for element in elements]
    flat_space = flatten_space(space.feature_space)
    out = gym.vector.utils.create_empty_array(flat_space, n=len(flat_elements))
    return gym.vector.utils.concatenate(flat_space, flat_elements, out)
+
+
@flatten.register(OneOf)
def _flatten_oneof(space: OneOf, x: tuple[int, Any]) -> NDArray[Any]:
    index, sub_sample = x
    subspace = space.spaces[index]
    flat = flatten(subspace, sub_sample)

    # Pad to the widest subspace so every sample has the same flat width
    # (the first flattened value is repeated as filler).
    target_size = flatdim(space) - 1  # Don't include the index
    if flat.size < target_size:
        filler = np.full(target_size - flat.size, flat[0], dtype=flat.dtype)
        flat = np.concatenate([flat, filler])

    # Prepend the subspace index so the sample can be routed back on unflatten.
    return np.concatenate([[index], flat])
+
+
+
+[docs]
@singledispatch
def unflatten(space: Space[T], x: FlatType) -> T:
    """Unflatten a data point from a space.

    This reverses the transformation applied by :func:`flatten`. You must ensure
    that the ``space`` argument is the same as for the :func:`flatten` call.

    Args:
        space: The space used to unflatten ``x``
        x: The array to unflatten

    Returns:
        A point with a structure that matches the space.

    Raises:
        NotImplementedError: if the space is not defined in :mod:`gymnasium.spaces`.
    """
    # Fallback for unregistered space types.
    raise NotImplementedError(f"Unknown space: `{space}`")
+
+
+
@unflatten.register(Box)
@unflatten.register(MultiBinary)
def _unflatten_box_multibinary(
    space: Box | MultiBinary, x: NDArray[Any]
) -> NDArray[Any]:
    # Cast back to the space dtype and restore the space's original shape.
    arr = np.asarray(x, dtype=space.dtype)
    return arr.reshape(space.shape)
+
+
@unflatten.register(Discrete)
def _unflatten_discrete(space: Discrete, x: NDArray[np.int64]) -> np.int64:
    # Recover the index of the set bit in the one-hot encoding.
    hot_indices = np.nonzero(x)[0]
    if len(hot_indices) == 0:
        raise ValueError(
            f"{x} is not a valid one-hot encoded vector and can not be unflattened to space {space}. "
            "Not all valid samples in a flattened space can be unflattened."
        )
    return space.start + hot_indices[0]
+
+
@unflatten.register(MultiDiscrete)
def _unflatten_multidiscrete(
    space: MultiDiscrete, x: NDArray[np.integer[Any]]
) -> NDArray[np.integer[Any]]:
    # Cumulative start offset of each sub-dimension's one-hot block.
    block_starts = np.zeros((space.nvec.size + 1,), dtype=space.dtype)
    block_starts[1:] = np.cumsum(space.nvec.flatten())

    (hot_indices,) = np.nonzero(x)
    if len(hot_indices) == 0:
        raise ValueError(
            f"{x} is not a concatenation of one-hot encoded vectors and can not be unflattened to space {space}. "
            "Not all valid samples in a flattened space can be unflattened."
        )
    # Per-block position of the set bit, shifted back by the space's start value.
    values = np.asarray(hot_indices - block_starts[:-1], dtype=space.dtype)
    return values.reshape(space.shape) + space.start
+
+
@unflatten.register(Tuple)
def _unflatten_tuple(
    space: Tuple, x: NDArray[Any] | tuple[Any, ...]
) -> tuple[Any, ...]:
    # Non-flattenable tuples arrive as per-subspace tuples.
    if not space.is_np_flattenable:
        assert isinstance(
            x, tuple
        ), f"{space} is not numpy-flattenable. Thus, you should only unflatten tuples for this space. Got a {type(x)}"
        return tuple(unflatten(subspace, part) for part, subspace in zip(x, space.spaces))

    assert isinstance(
        x, np.ndarray
    ), f"{space} is numpy-flattenable. Thus, you should only unflatten numpy arrays for this space. Got a {type(x)}"
    # Split the flat vector at cumulative subspace sizes, then unflatten each chunk.
    dims = np.asarray([flatdim(subspace) for subspace in space.spaces], dtype=np.int_)
    chunks = np.split(x, np.cumsum(dims[:-1]))
    return tuple(
        unflatten(subspace, chunk) for chunk, subspace in zip(chunks, space.spaces)
    )
+
+
@unflatten.register(Dict)
def _unflatten_dict(space: Dict, x: NDArray[Any] | dict[str, Any]) -> dict[str, Any]:
    # Non-flattenable dicts arrive as per-key dictionaries.
    if not space.is_np_flattenable:
        assert isinstance(
            x, dict
        ), f"{space} is not numpy-flattenable. Thus, you should only unflatten dictionary for this space. Got a {type(x)}"
        return {key: unflatten(subspace, x[key]) for key, subspace in space.spaces.items()}

    # Split the flat vector at cumulative subspace sizes, then unflatten each chunk.
    dims = np.asarray(
        [flatdim(subspace) for subspace in space.spaces.values()], dtype=np.int_
    )
    chunks = np.split(x, np.cumsum(dims[:-1]))
    return {
        key: unflatten(subspace, chunk)
        for chunk, (key, subspace) in zip(chunks, space.spaces.items())
    }
+
+
@unflatten.register(Graph)
def _unflatten_graph(space: Graph, x: GraphInstance) -> GraphInstance:
    """We're not using `.unflatten() for :class:`Box` and :class:`Discrete` because a graph is not a homogeneous space.

    The size of the outcome is actually not fixed, but determined based on the number of
    nodes and edges in the graph.
    """

    def _graph_unflatten(feature_space, feature_x):
        # Returns None when the feature space or the flattened features are absent.
        if feature_space is None or feature_x is None:
            return None
        if isinstance(feature_space, Box):
            # Restore each row's original feature shape.
            return feature_x.reshape(-1, *feature_space.shape)
        if isinstance(feature_space, Discrete):
            # Recover the set-bit column index of each one-hot row.
            return np.asarray(np.nonzero(feature_x))[-1, :]
        return None

    nodes = _graph_unflatten(space.node_space, x.nodes)
    edges = _graph_unflatten(space.edge_space, x.edges)

    return GraphInstance(nodes, edges, x.edge_links)
+
+
@unflatten.register(Text)
def _unflatten_text(space: Text, x: NDArray[np.int32]) -> str:
    # Codes >= len(charset) are padding slots and are dropped.
    charset_size = len(space.character_set)
    chars = [space.character_list[code] for code in x if code < charset_size]
    return "".join(chars)
+
+
@unflatten.register(Sequence)
def _unflatten_sequence(space: Sequence, x: tuple[Any, ...]) -> tuple[Any, ...] | Any:
    # Unstacked sequences unflatten element-wise into a tuple.
    if not space.stack:
        return tuple(unflatten(space.feature_space, item) for item in x)

    # Stacked sequences: iterate the flattened batch, unflatten each element,
    # then re-stack into the feature space's batched form.
    flat_space = flatten_space(space.feature_space)
    unflattened_elements = [
        unflatten(space.feature_space, element)
        for element in gym.vector.utils.iterate(flat_space, x)
    ]
    out = gym.vector.utils.create_empty_array(
        space.feature_space, len(unflattened_elements)
    )
    return gym.vector.utils.concatenate(
        space.feature_space, unflattened_elements, out
    )
+
+
@unflatten.register(OneOf)
def _unflatten_oneof(space: OneOf, x: NDArray[Any]) -> tuple[int, Any]:
    # The first entry records which subspace the sample came from.
    index = np.int64(x[0])
    subspace = space.spaces[index]

    # Drop any padding appended by flatten beyond this subspace's true size.
    width = flatdim(subspace)
    payload = x[1 : 1 + width]

    return index, unflatten(subspace, payload)
+
+
+
+[docs]
@singledispatch
def flatten_space(space: Space[Any]) -> Box | Dict | Sequence | Tuple | Graph:
    """Flatten a space into a space that is as flat as possible.

    This function will attempt to flatten ``space`` into a single :class:`gymnasium.spaces.Box` space.
    However, this might not be possible when ``space`` is an instance of :class:`gymnasium.spaces.Graph`,
    :class:`gymnasium.spaces.Sequence` or a compound space that contains a :class:`gymnasium.spaces.Graph`
    or :class:`gymnasium.spaces.Sequence` space.
    This is equivalent to :func:`flatten`, but operates on the space itself. The
    result for non-graph spaces is always a :class:`gymnasium.spaces.Box` with flat boundaries. While
    the result for graph spaces is always a :class:`gymnasium.spaces.Graph` with
    :attr:`Graph.node_space` being a ``Box``
    with flat boundaries and :attr:`Graph.edge_space` being a ``Box`` with flat boundaries or
    ``None``. The box has exactly :func:`flatdim` dimensions. Flattening a sample
    of the original space has the same effect as taking a sample of the flattened
    space. However, sampling from the flattened space is not necessarily reversible.
    For example, sampling from a flattened Discrete space is the same as sampling from
    a Box, and the results may not be integers or one-hot encodings. This may result in
    errors or non-uniform sampling.

    Args:
        space: The space to flatten

    Returns:
        A flattened Box

    Raises:
        NotImplementedError: if the space is not defined in :mod:`gymnasium.spaces`.

    Example - Flatten spaces.Box:
        >>> from gymnasium.spaces import Box
        >>> box = Box(0.0, 1.0, shape=(3, 4, 5))
        >>> box
        Box(0.0, 1.0, (3, 4, 5), float32)
        >>> flatten_space(box)
        Box(0.0, 1.0, (60,), float32)
        >>> flatten(box, box.sample()) in flatten_space(box)
        True

    Example - Flatten spaces.Discrete:
        >>> from gymnasium.spaces import Discrete
        >>> discrete = Discrete(5)
        >>> flatten_space(discrete)
        Box(0, 1, (5,), int64)
        >>> flatten(discrete, discrete.sample()) in flatten_space(discrete)
        True

    Example - Flatten spaces.Dict:
        >>> from gymnasium.spaces import Dict, Discrete, Box
        >>> space = Dict({"position": Discrete(2), "velocity": Box(0, 1, shape=(2, 2))})
        >>> flatten_space(space)
        Box(0.0, 1.0, (6,), float64)
        >>> flatten(space, space.sample()) in flatten_space(space)
        True

    Example - Flatten spaces.Graph:
        >>> from gymnasium.spaces import Graph, Discrete, Box
        >>> space = Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5))
        >>> flatten_space(space)
        Graph(Box(-100.0, 100.0, (12,), float32), Box(0, 1, (5,), int64))
        >>> flatten(space, space.sample()) in flatten_space(space)
        True
    """
    # Fallback for unregistered space types.
    raise NotImplementedError(f"Unknown space: `{space}`")
+
+
+
@flatten_space.register(Box)
def _flatten_space_box(space: Box) -> Box:
    """Flatten a Box space into a 1-D Box, preserving the per-element bounds and dtype."""
    flat_low = space.low.flatten()
    flat_high = space.high.flatten()
    return Box(flat_low, flat_high, dtype=space.dtype)
+
+
@flatten_space.register(Discrete)
@flatten_space.register(MultiBinary)
@flatten_space.register(MultiDiscrete)
def _flatten_space_binary(space: Discrete | MultiBinary | MultiDiscrete) -> Box:
    """Flatten a discrete-style space into a {0, 1}-bounded Box of its flat dimension."""
    flat_size = flatdim(space)
    return Box(low=0, high=1, shape=(flat_size,), dtype=space.dtype)
+
+
@flatten_space.register(Tuple)
def _flatten_space_tuple(space: Tuple) -> Box | Tuple:
    """Flatten a Tuple space, concatenating subspace bounds into a single Box when possible."""
    flat_subspaces = [flatten_space(sub) for sub in space.spaces]
    if not space.is_np_flattenable:
        # Graph/Sequence subspaces cannot be concatenated into one Box.
        return Tuple(spaces=flat_subspaces)
    return Box(
        low=np.concatenate([flat.low for flat in flat_subspaces]),
        high=np.concatenate([flat.high for flat in flat_subspaces]),
        dtype=np.result_type(*[flat.dtype for flat in flat_subspaces]),
    )
+
+
@flatten_space.register(Dict)
def _flatten_space_dict(space: Dict) -> Box | Dict:
    """Flatten a Dict space, concatenating subspace bounds into a single Box when possible."""
    if not space.is_np_flattenable:
        # Keep the dict structure when some subspace cannot become a Box.
        return Dict(
            spaces={key: flatten_space(sub) for key, sub in space.spaces.items()}
        )
    flat_subspaces = [flatten_space(sub) for sub in space.spaces.values()]
    return Box(
        low=np.concatenate([flat.low for flat in flat_subspaces]),
        high=np.concatenate([flat.high for flat in flat_subspaces]),
        dtype=np.result_type(*[flat.dtype for flat in flat_subspaces]),
    )
+
+
@flatten_space.register(Graph)
def _flatten_space_graph(space: Graph) -> Graph:
    """Flatten a Graph space by flattening its node space and its (optional) edge space."""
    flat_edge_space = None
    if space.edge_space is not None:
        flat_edge_space = flatten_space(space.edge_space)
    return Graph(
        node_space=flatten_space(space.node_space), edge_space=flat_edge_space
    )
+
+
@flatten_space.register(Text)
def _flatten_space_text(space: Text) -> Box:
    """Flatten a Text space into a Box of character indices, one slot per position."""
    charset_size = len(space.character_set)
    return Box(low=0, high=charset_size, shape=(space.max_length,), dtype=np.int32)
+
+
@flatten_space.register(Sequence)
def _flatten_space_sequence(space: Sequence) -> Sequence:
    """Flatten a Sequence space by flattening its feature space, keeping the stack flag."""
    flat_feature_space = flatten_space(space.feature_space)
    return Sequence(flat_feature_space, stack=space.stack)
+
+
@flatten_space.register(OneOf)
def _flatten_space_oneof(space: OneOf) -> Box:
    """Flatten a OneOf space into a Box: one index slot plus a padded payload region.

    The first element holds the selected subspace index; the remaining
    ``max(flatdim)`` elements hold the (zero-padded) flattened sample, bounded by
    the overall min/max across all subspaces.
    """
    flat_subspaces = [flatten_space(sub) for sub in space.spaces]
    payload_dim = max(flatdim(sub) for sub in space.spaces)
    total_dim = payload_dim + 1

    overall_low = np.min([np.min(flat.low) for flat in flat_subspaces])
    overall_high = np.max([np.max(flat.high) for flat in flat_subspaces])

    low = np.concatenate([[0], np.full(payload_dim, overall_low)])
    high = np.concatenate(
        [[len(space.spaces) - 1], np.full(payload_dim, overall_high)]
    )

    dtype = np.result_type(*[sub.dtype for sub in space.spaces if hasattr(sub, "dtype")])
    return Box(low=low, high=high, shape=(total_dim,), dtype=dtype)
+
+
@singledispatch
def is_space_dtype_shape_equiv(space_1: Space, space_2: Space) -> bool:
    """Returns if two spaces share a common dtype and shape (plus any critical variables).

    This function is primarily used to check for compatibility of different spaces in a vector environment.

    Args:
        space_1: A Gymnasium space
        space_2: A Gymnasium space

    Returns:
        If the two spaces share a common dtype and shape (plus any critical variables).

    Raises:
        NotImplementedError: both arguments are Spaces but no handler is registered for the type.
        TypeError: either argument is not a Gymnasium space.
    """
    if not (isinstance(space_1, Space) and isinstance(space_2, Space)):
        raise TypeError()
    raise NotImplementedError(
        "`check_dtype_shape_equivalence` doesn't support Generic Gymnasium Spaces, "
    )
+
+
@is_space_dtype_shape_equiv.register(Box)
@is_space_dtype_shape_equiv.register(Discrete)
@is_space_dtype_shape_equiv.register(MultiDiscrete)
@is_space_dtype_shape_equiv.register(MultiBinary)
def _is_space_fundamental_dtype_shape_equiv(space_1, space_2):
    """Fundamental spaces are equivalent when their exact type, shape and dtype all match."""
    # singledispatch only dispatches on the first argument, so the concrete
    # types of both spaces must be compared explicitly here.
    if type(space_1) is not type(space_2):
        return False
    return space_1.shape == space_2.shape and space_1.dtype == space_2.dtype
+
+
@is_space_dtype_shape_equiv.register(Text)
def _is_space_text_dtype_shape_equiv(space_1: Text, space_2):
    """Text spaces are equivalent when their max lengths and character sets match."""
    if not isinstance(space_2, Text):
        return False
    return (
        space_1.max_length == space_2.max_length
        and space_1.character_set == space_2.character_set
    )
+
+
@is_space_dtype_shape_equiv.register(Dict)
def _is_space_dict_dtype_shape_equiv(space_1: Dict, space_2):
    """Dict spaces are equivalent when the key sets match and every subspace is equivalent."""
    if not isinstance(space_2, Dict):
        return False
    if space_1.keys() != space_2.keys():
        return False
    return all(
        is_space_dtype_shape_equiv(sub_space, space_2[key])
        for key, sub_space in space_1.spaces.items()
    )
+
+
@is_space_dtype_shape_equiv.register(Tuple)
def _is_space_tuple_dtype_shape_equiv(space_1, space_2):
    """Tuple spaces are equivalent when they have the same length and pairwise-equivalent subspaces.

    The explicit length check mirrors the ``OneOf`` handler below: without it, a
    shorter ``space_2`` raised ``IndexError`` and a longer ``space_2`` with a
    matching prefix wrongly compared as equivalent.
    """
    return (
        isinstance(space_2, Tuple)
        and len(space_1) == len(space_2)
        and all(
            is_space_dtype_shape_equiv(space_1[i], space_2[i])
            for i in range(len(space_1))
        )
    )
+
+
@is_space_dtype_shape_equiv.register(Graph)
def _is_space_graph_dtype_shape_equiv(space_1: Graph, space_2):
    """Graph spaces are equivalent when node spaces match and edge spaces are both absent or both equivalent."""
    if not isinstance(space_2, Graph):
        return False
    if not is_space_dtype_shape_equiv(space_1.node_space, space_2.node_space):
        return False
    # Edge spaces are optional: either both missing, or both present and equivalent.
    if space_1.edge_space is None or space_2.edge_space is None:
        return space_1.edge_space is None and space_2.edge_space is None
    return is_space_dtype_shape_equiv(space_1.edge_space, space_2.edge_space)
+
+
@is_space_dtype_shape_equiv.register(OneOf)
def _is_space_oneof_dtype_shape_equiv(space_1: OneOf, space_2):
    """OneOf spaces are equivalent when they have the same length and pairwise-equivalent options."""
    if not isinstance(space_2, OneOf) or len(space_1) != len(space_2):
        return False
    return all(
        is_space_dtype_shape_equiv(space_1[idx], space_2[idx])
        for idx in range(len(space_1))
    )
+
+
@is_space_dtype_shape_equiv.register(Sequence)
def _is_space_sequence_dtype_shape_equiv(space_1: Sequence, space_2):
    """Sequence spaces are equivalent when their stack flags and feature spaces match."""
    if not isinstance(space_2, Sequence):
        return False
    if space_1.stack is not space_2.stack:
        return False
    return is_space_dtype_shape_equiv(space_1.feature_space, space_2.feature_space)
+
+"""A set of functions for checking an environment implementation.
+
+This file is originally from the Stable Baselines3 repository hosted on GitHub
+(https://github.com/DLR-RM/stable-baselines3/)
+Original Author: Antonin Raffin
+
+It also uses some warnings/assertions from the PettingZoo repository hosted on GitHub
+(https://github.com/PettingZoo-Team/PettingZoo)
+Original Author: J K Terry
+
+This was rewritten and split into "env_checker.py" and "passive_env_checker.py" for invasive and passive environment checking
+Original Author: Mark Towers
+
+These projects are covered by the MIT License.
+"""
+
+import inspect
+from copy import deepcopy
+
+import numpy as np
+
+import gymnasium as gym
+from gymnasium import logger, spaces
+from gymnasium.utils.passive_env_checker import (
+ check_action_space,
+ check_observation_space,
+ env_render_passive_checker,
+ env_reset_passive_checker,
+ env_step_passive_checker,
+)
+
+
def data_equivalence(data_1, data_2, exact: bool = False) -> bool:
    """Assert equality between data 1 and 2, i.e. observations, actions, info.

    Args:
        data_1: data structure 1
        data_2: data structure 2
        exact: whether to compare array exactly or not if false compares with absolute and relative tolerance of 1e-5 (for more information check [np.allclose](https://numpy.org/doc/stable/reference/generated/numpy.allclose.html)).

    Returns:
        If observation 1 and 2 are equivalent
    """
    # Different concrete types are never equivalent (e.g. int vs float, list vs tuple).
    if type(data_1) is not type(data_2):
        return False
    if isinstance(data_1, dict):
        if data_1.keys() != data_2.keys():
            return False
        return all(data_equivalence(data_1[k], data_2[k], exact) for k in data_1)
    if isinstance(data_1, (tuple, list)):
        if len(data_1) != len(data_2):
            return False
        return all(
            data_equivalence(elem_1, elem_2, exact)
            for elem_1, elem_2 in zip(data_1, data_2)
        )
    if isinstance(data_1, np.ndarray):
        if data_1.shape != data_2.shape or data_1.dtype != data_2.dtype:
            return False
        if data_1.dtype == object:
            # Object arrays are compared element-wise, recursively.
            return all(
                data_equivalence(elem_1, elem_2, exact)
                for elem_1, elem_2 in zip(data_1, data_2)
            )
        if exact:
            return np.all(data_1 == data_2)
        return np.allclose(data_1, data_2, rtol=1e-5, atol=1e-5)
    return data_1 == data_2
+
+
def check_reset_seed_determinism(env: gym.Env):
    """Check that the environment can be reset with a seed.

    Args:
        env: The environment to check

    Raises:
        AssertionError: The environment cannot be reset with a random seed,
        even though `seed` or `kwargs` appear in the signature.
    """
    signature = inspect.signature(env.reset)
    # `reset` must accept a seed either as an explicit parameter or via **kwargs.
    if "seed" in signature.parameters or (
        "kwargs" in signature.parameters
        and signature.parameters["kwargs"].kind is inspect.Parameter.VAR_KEYWORD
    ):
        try:
            obs_1, info = env.reset(seed=123)
            assert (
                obs_1 in env.observation_space
            ), "The observation returned by `env.reset(seed=123)` is not within the observation space."
            assert (
                env.unwrapped._np_random is not None
            ), "Expects the random number generator to have been generated given a seed was passed to reset. Most likely the environment reset function does not call `super().reset(seed=seed)`."
            # Snapshot the RNG state right after the first seeded reset.
            seed_123_rng_1 = deepcopy(env.unwrapped._np_random)

            # An unseeded reset should continue from the current RNG, not reseed it.
            obs_2, info = env.reset()
            assert (
                obs_2 in env.observation_space
            ), "The observation returned by `env.reset()` is not within the observation space."

            # Re-seeding with the same value must reproduce the first seeded reset.
            obs_3, info = env.reset(seed=123)
            assert (
                obs_3 in env.observation_space
            ), "The observation returned by `env.reset(seed=123)` is not within the observation space."
            seed_123_rng_3 = deepcopy(env.unwrapped._np_random)

            obs_4, info = env.reset()
            assert (
                obs_4 in env.observation_space
            ), "The observation returned by `env.reset()` is not within the observation space."

            if env.spec is not None and env.spec.nondeterministic is False:
                assert data_equivalence(
                    obs_1, obs_3
                ), "Using `env.reset(seed=123)` is non-deterministic as the observations are not equivalent."
                assert data_equivalence(
                    obs_2, obs_4
                ), "Using `env.reset(seed=123)` then `env.reset()` is non-deterministic as the observations are not equivalent."
                # Near-equal (within tolerance) but not bit-exact only warrants a warning.
                if not data_equivalence(obs_1, obs_3, exact=True):
                    logger.warn(
                        "Using `env.reset(seed=123)` observations are not equal although similar."
                    )
                if not data_equivalence(obs_2, obs_4, exact=True):
                    logger.warn(
                        "Using `env.reset(seed=123)` then `env.reset()` observations are not equal although similar."
                    )

            assert (
                seed_123_rng_1.bit_generator.state == seed_123_rng_3.bit_generator.state
            ), "Most likely the environment reset function does not call `super().reset(seed=seed)` as the random generates are not same when the same seeds are passed to `env.reset`."

            # A different seed must produce a different RNG state.
            obs_5, info = env.reset(seed=456)
            assert (
                obs_5 in env.observation_space
            ), "The observation returned by `env.reset(seed=456)` is not within the observation space."
            assert (
                env.unwrapped._np_random.bit_generator.state
                != seed_123_rng_1.bit_generator.state
            ), "Most likely the environment reset function does not call `super().reset(seed=seed)` as the random number generators are not different when different seeds are passed to `env.reset`."

        except TypeError as e:
            raise AssertionError(
                "The environment cannot be reset with a random seed, even though `seed` or `kwargs` appear in the signature. "
                f"This should never happen, please report this issue. The error was: {e}"
            ) from e

        seed_param = signature.parameters.get("seed")
        # Check the default value is None
        if seed_param is not None and seed_param.default is not None:
            logger.warn(
                "The default seed argument in reset should be `None`, otherwise the environment will by default always be deterministic. "
                f"Actual default: {seed_param.default}"
            )
    else:
        raise gym.error.Error(
            "The `reset` method does not provide a `seed` or `**kwargs` keyword argument."
        )
+
+
def check_reset_options(env: gym.Env):
    """Check that the environment can be reset with options.

    Args:
        env: The environment to check

    Raises:
        AssertionError: The environment cannot be reset with options,
        even though `options` or `kwargs` appear in the signature.
    """
    signature = inspect.signature(env.reset)
    kwargs_param = signature.parameters.get("kwargs")
    accepts_var_kwargs = (
        kwargs_param is not None
        and kwargs_param.kind is inspect.Parameter.VAR_KEYWORD
    )
    # Without an `options` parameter (explicit or via **kwargs) the check cannot run.
    if "options" not in signature.parameters and not accepts_var_kwargs:
        raise gym.error.Error(
            "The `reset` method does not provide an `options` or `**kwargs` keyword argument."
        )

    try:
        env.reset(options={})
    except TypeError as e:
        raise AssertionError(
            "The environment cannot be reset with options, even though `options` or `**kwargs` appear in the signature. "
            f"This should never happen, please report this issue. The error was: {e}"
        ) from e
+
+
def check_step_determinism(env: gym.Env, seed=123):
    """Check that the environment steps deterministically after reset.

    Note: This check assumes that seeded `reset()` is deterministic (it must have passed `check_reset_seed`) and that `step()` returns valid values (passed `env_step_passive_checker`).
    Note: A single step should be enough to assert that the state transition function is deterministic (at least for most environments).

    Args:
        env: The environment to check
        seed: The seed used for both the action-space sampling and the two resets

    Raises:
        AssertionError: The environment cannot be step deterministically after resetting with a random seed,
        or it truncates after 1 step.
    """
    # Environments declared nondeterministic in their spec are exempt.
    if env.spec is not None and env.spec.nondeterministic is True:
        return

    env.action_space.seed(seed)
    action = env.action_space.sample()

    # First rollout: a single step from a seeded reset.
    env.reset(seed=seed)
    obs_0, rew_0, term_0, trunc_0, info_0 = env.step(action)
    seeded_rng: np.random.Generator = deepcopy(env.unwrapped._np_random)

    # Second rollout: identical seed and action must reproduce the first.
    env.reset(seed=seed)
    obs_1, rew_1, term_1, trunc_1, info_1 = env.step(action)

    assert (
        env.unwrapped._np_random.bit_generator.state  # pyright: ignore [reportOptionalMemberAccess]
        == seeded_rng.bit_generator.state
    ), "The `.np_random` is not properly been updated after step."

    assert data_equivalence(
        obs_0, obs_1
    ), "Deterministic step observations are not equivalent for the same seed and action"
    # Near-equal but not bit-exact results only warrant a warning.
    if not data_equivalence(obs_0, obs_1, exact=True):
        logger.warn(
            "Step observations are not equal although similar given the same seed and action"
        )

    assert data_equivalence(
        rew_0, rew_1
    ), "Deterministic step rewards are not equivalent for the same seed and action"
    if not data_equivalence(rew_0, rew_1, exact=True):
        logger.warn(
            "Step rewards are not equal although similar given the same seed and action"
        )

    assert data_equivalence(
        term_0, term_1, exact=True
    ), "Deterministic step termination are not equivalent for the same seed and action"
    assert (
        trunc_0 is False and trunc_1 is False
    ), "Environment truncates after 1 step, something has gone very wrong."

    assert data_equivalence(
        info_0,
        info_1,
    ), "Deterministic step info are not equivalent for the same seed and action"
    if not data_equivalence(info_0, info_1, exact=True):
        logger.warn(
            "Step info are not equal although similar given the same seed and action"
        )
+
+
def check_reset_return_info_deprecation(env: gym.Env):
    """Makes sure support for deprecated `return_info` argument is dropped.

    Args:
        env: The environment to check
    Raises:
        UserWarning
    """
    reset_signature = inspect.signature(env.reset)
    if "return_info" not in reset_signature.parameters:
        return
    logger.warn(
        "`return_info` is deprecated as an optional argument to `reset`. `reset`"
        "should now always return `obs, info` where `obs` is an observation, and `info` is a dictionary"
        "containing additional information."
    )
+
+
def check_seed_deprecation(env: gym.Env):
    """Makes sure support for deprecated function `seed` is dropped.

    Args:
        env: The environment to check
    Raises:
        UserWarning
    """
    seed_attr = getattr(env, "seed", None)
    # Only a *callable* `seed` attribute indicates the deprecated API.
    if not callable(seed_attr):
        return
    logger.warn(
        "Official support for the `seed` function is dropped. "
        "Standard practice is to reset gymnasium environments using `env.reset(seed=<desired seed>)`"
    )
+
+
def check_reset_return_type(env: gym.Env):
    """Checks that :meth:`reset` correctly returns a tuple of the form `(obs , info)`.

    Args:
        env: The environment to check
    Raises:
        AssertionError depending on spec violation
    """
    reset_result = env.reset()
    # The modern API returns exactly (observation, info).
    assert isinstance(
        reset_result, tuple
    ), f"The result returned by `env.reset()` was not a tuple of the form `(obs, info)`, where `obs` is a observation and `info` is a dictionary containing additional information. Actual type: `{type(reset_result)}`"
    assert (
        len(reset_result) == 2
    ), f"Calling the reset method did not return a 2-tuple, actual length: {len(reset_result)}"

    reset_obs, reset_info = reset_result
    assert (
        reset_obs in env.observation_space
    ), "The first element returned by `env.reset()` is not within the observation space."
    assert isinstance(
        reset_info, dict
    ), f"The second element returned by `env.reset()` was not a dictionary, actual type: {type(reset_info)}"
+
+
def check_space_limit(space, space_type: str):
    """Check the space limit for only the Box space as a test that only runs as part of `check_env`."""
    if isinstance(space, spaces.Box):
        # Unbounded Box limits are suspicious but not fatal.
        if np.any(np.equal(space.low, -np.inf)):
            logger.warn(
                f"A Box {space_type} space minimum value is -infinity. This is probably too low."
            )
        if np.any(np.equal(space.high, np.inf)):
            logger.warn(
                f"A Box {space_type} space maximum value is infinity. This is probably too high."
            )

        # Check that the Box space is normalized
        if space_type == "action" and len(space.shape) == 1:  # for vector boxes
            asymmetric = np.logical_and(
                space.low != np.zeros_like(space.low),
                np.abs(space.low) != np.abs(space.high),
            )
            if (
                np.any(asymmetric)
                or np.any(space.low < -1)
                or np.any(space.high > 1)
            ):
                # todo - Add to gymlibrary.ml?
                logger.warn(
                    "For Box action spaces, we recommend using a symmetric and normalized space (range=[-1, 1] or [0, 1]). "
                    "See https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html for more information."
                )
    elif isinstance(space, spaces.Tuple):
        # Recurse into compound spaces.
        for subspace in space.spaces:
            check_space_limit(subspace, space_type)
    elif isinstance(space, spaces.Dict):
        for subspace in space.values():
            check_space_limit(subspace, space_type)
+
+
+
+[docs]
def check_env(
    env: gym.Env,
    warn: bool = None,
    skip_render_check: bool = False,
    skip_close_check: bool = False,
):
    """Check that an environment follows Gymnasium's API.

    .. py:currentmodule:: gymnasium.Env

    To ensure that an environment is implemented "correctly", ``check_env`` checks that the :attr:`observation_space` and :attr:`action_space` are correct.
    Furthermore, the function will call the :meth:`reset`, :meth:`step` and :meth:`render` functions with a variety of values.

    We highly recommend users call this function after an environment is constructed and within a project's continuous integration to keep an environment update with Gymnasium's API.

    Args:
        env: The Gym environment that will be checked
        warn: Ignored, previously silenced particular warnings
        skip_render_check: Whether to skip the checks for the render method. False by default (useful for the CI)
        skip_close_check: Whether to skip the checks for the close method. False by default

    Raises:
        TypeError: if ``env`` is not a :class:`gymnasium.Env` subclass (including old OpenAI Gym environments).
        AttributeError: if the environment does not specify an action or observation space.
    """
    if warn is not None:
        logger.warn("`check_env(warn=...)` parameter is now ignored.")

    if not isinstance(env, gym.Env):
        # Detect old OpenAI Gym environments to give a targeted migration message.
        if (
            str(env.__class__.__base__) == "<class 'gym.core.Env'>"
            or str(env.__class__.__base__) == "<class 'gym.core.Wrapper'>"
        ):
            raise TypeError(
                "Gym is incompatible with Gymnasium, please update the environment class to `gymnasium.Env`. "
                "See https://gymnasium.farama.org/introduction/create_custom_env/ for more info."
            )
        else:
            raise TypeError(
                f"The environment must inherit from the gymnasium.Env class, actual class: {type(env)}. "
                "See https://gymnasium.farama.org/introduction/create_custom_env/ for more info."
            )
    if env.unwrapped is not env:
        logger.warn(
            f"The environment ({env}) is different from the unwrapped version ({env.unwrapped}). This could effect the environment checker as the environment most likely has a wrapper applied to it. We recommend using the raw environment for `check_env` using `env.unwrapped`."
        )

    # Wrap jax/torch-based environments so the numpy-based checks below apply.
    if env.metadata.get("jax", False):
        env = gym.wrappers.JaxToNumpy(env)
    elif env.metadata.get("torch", False):
        env = gym.wrappers.TorchToNumpy(env)

    # ============= Check the spaces (observation and action) ================
    if not hasattr(env, "action_space"):
        raise AttributeError(
            "The environment must specify an action space. See https://gymnasium.farama.org/introduction/create_custom_env/ for more info."
        )
    check_action_space(env.action_space)
    check_space_limit(env.action_space, "action")

    if not hasattr(env, "observation_space"):
        raise AttributeError(
            "The environment must specify an observation space. See https://gymnasium.farama.org/introduction/create_custom_env/ for more info."
        )
    check_observation_space(env.observation_space)
    check_space_limit(env.observation_space, "observation")

    # ==== Check the reset method ====
    check_seed_deprecation(env)
    check_reset_return_info_deprecation(env)
    check_reset_return_type(env)
    check_reset_seed_determinism(env)
    check_reset_options(env)

    # ============ Check the returned values ===============
    env_reset_passive_checker(env)
    env_step_passive_checker(env, env.action_space.sample())

    # ==== Check the step method ====
    check_step_determinism(env)

    # ==== Check the render method and the declared render modes ====
    if not skip_render_check:
        if env.render_mode is not None:
            env_render_passive_checker(env)

        # Re-instantiate the environment for every declared render mode.
        if env.spec is not None:
            for render_mode in env.metadata["render_modes"]:
                new_env = env.spec.make(render_mode=render_mode)
                new_env.reset()
                env_render_passive_checker(new_env)
                new_env.close()
        else:
            logger.warn(
                "Not able to test alternative render modes due to the environment not having a spec. Try instantiating the environment through `gymnasium.make`"
            )

    # A second `close()` on an already-closed environment should not raise.
    if not skip_close_check and env.spec is not None:
        new_env = env.spec.make()
        new_env.close()
        try:
            new_env.close()
        except Exception as e:
            logger.warn(
                f"Calling `env.close()` on the closed environment should be allowed, but it raised an exception: {e}"
            )
+
+
+"""Class for pickling and unpickling objects via their constructor arguments."""
+
+from typing import Any
+
+
+
+[docs]
class EzPickle:
    """Objects that are pickled and unpickled via their constructor arguments.

    Example:
        >>> class Animal: pass
        >>> class Dog(Animal, EzPickle):
        ...     def __init__(self, furcolor, tailkind="bushy"):
        ...         Animal.__init__(self)
        ...         EzPickle.__init__(self, furcolor, tailkind)

    When this object is unpickled, a new ``Dog`` will be constructed by passing the provided furcolor and tailkind into the constructor.
    However, philosophers are still not sure whether it is still the same dog.

    This is generally needed only for environments which wrap C/C++ code, such as MuJoCo and Atari.
    """

    def __init__(self, *args: Any, **kwargs: Any):
        """Record the constructor ``args`` and ``kwargs`` for later pickling."""
        self._ezpickle_args = args
        self._ezpickle_kwargs = kwargs

    def __getstate__(self):
        """Return the pickle state: the stored constructor args and kwargs."""
        return dict(
            _ezpickle_args=self._ezpickle_args,
            _ezpickle_kwargs=self._ezpickle_kwargs,
        )

    def __setstate__(self, d):
        """Rebuild the object by re-invoking the constructor with the stored state."""
        rebuilt = type(self)(*d["_ezpickle_args"], **d["_ezpickle_kwargs"])
        self.__dict__.update(rebuilt.__dict__)
+
+
+"""A collection of runtime performance bencharks, useful for debugging performance related issues."""
+
+import time
+from collections.abc import Callable
+
+import gymnasium
+
+
+
+[docs]
def benchmark_step(env: gymnasium.Env, target_duration: int = 5, seed=None) -> float:
    """A benchmark to measure the runtime performance of step for an environment.

    example usage:
    ```py
    env_old = ...
    old_throughput = benchmark_step(env_old)
    env_new = ...
    new_throughput = benchmark_step(env_new)
    slowdown = old_throughput / new_throughput
    ```

    Args:
        env: the environment to benchmarked.
        target_duration: the duration of the benchmark in seconds (note: it will go slightly over it).
        seed: seeds the environment and action sampled.

    Returns: the average steps per second.
    """
    steps = 0
    end = 0.0
    env.reset(seed=seed)
    # Seed the action space so sampled actions are reproducible, matching the
    # documented behavior of the `seed` argument.
    env.action_space.seed(seed)
    # Warm up action sampling before starting the clock.
    env.action_space.sample()
    start = time.time()

    while True:
        steps += 1
        action = env.action_space.sample()
        _, _, terminal, truncated, _ = env.step(action)

        if terminal or truncated:
            env.reset()

        if time.time() - start > target_duration:
            end = time.time()
            break

    length = end - start

    steps_per_time = steps / length
    return steps_per_time
+
+
+
+
+[docs]
def benchmark_init(
    env_lambda: Callable[[], gymnasium.Env], target_duration: int = 5, seed=None
) -> float:
    """A benchmark to measure the initialization time and first reset.

    Args:
        env_lambda: the function to initialize the environment.
        target_duration: the duration of the benchmark in seconds (note: it will go slightly over it).
        seed: seeds the first reset of the environment.
    """
    init_count = 0
    start = time.time()
    while True:
        init_count += 1
        created_env = env_lambda()
        created_env.reset(seed=seed)

        elapsed = time.time() - start
        if elapsed > target_duration:
            break

    # Average number of environment constructions + first resets per second.
    return init_count / elapsed
+
+
+
+
+[docs]
def benchmark_render(env: gymnasium.Env, target_duration: int = 5) -> float:
    """A benchmark to measure the time of render().

    Note: does not work with `render_mode='human'`
    Args:
        env: the environment to benchmarked (Note: must be renderable).
        target_duration: the duration of the benchmark in seconds (note: it will go slightly over it).

    """
    render_count = 0
    start = time.time()
    while True:
        render_count += 1
        env.render()

        elapsed = time.time() - start
        if elapsed > target_duration:
            break

    # Average number of renders per second.
    return render_count / elapsed
+
+
+"""Utilities of visualising an environment."""
+
+from __future__ import annotations
+
+from collections import deque
+from collections.abc import Callable
+from typing import TYPE_CHECKING
+
+import numpy as np
+
+import gymnasium as gym
+from gymnasium import Env, logger
+from gymnasium.core import ActType, ObsType
+from gymnasium.error import DependencyNotInstalled
+
+
+if TYPE_CHECKING:
+ from matplotlib.axes import Axes
+
+
+try:
+ import pygame
+ from pygame import Surface
+ from pygame.event import Event
+except ImportError as e:
+ raise gym.error.DependencyNotInstalled(
+ 'pygame is not installed, run `pip install "gymnasium[classic_control]"`'
+ ) from e
+
+try:
+ import matplotlib
+
+ matplotlib.use("TkAgg")
+ import matplotlib.pyplot as plt
+except ImportError:
+ logger.warn('matplotlib is not installed, run `pip install "gymnasium[other]"`')
+ matplotlib, plt = None, None
+
+
class MissingKeysToAction(Exception):
    """Raised when the environment does not have a default ``keys_to_action`` mapping and none was provided by the caller."""
+
+
+
+[docs]
+class PlayableGame:
+ """Wraps an environment allowing keyboard inputs to interact with the environment."""
+
    def __init__(
        self,
        env: Env,
        keys_to_action: dict[tuple[int, ...], int] | None = None,
        zoom: float | None = None,
    ):
        """Wraps an environment with a dictionary of keyboard buttons to action and if to zoom in on the environment.

        Args:
            env: The environment to play
            keys_to_action: The dictionary of keyboard tuples and action value
            zoom: If to zoom in on the environment render

        Raises:
            ValueError: if the environment's render mode is not ``rgb_array`` or ``rgb_array_list``.
        """
        # Playing requires frames as arrays; other render modes cannot be displayed here.
        if env.render_mode not in {"rgb_array", "rgb_array_list"}:
            raise ValueError(
                "PlayableGame wrapper works only with rgb_array and rgb_array_list render modes, "
                f"but your environment render_mode = {env.render_mode}."
            )

        self.env = env
        self.relevant_keys = self._get_relevant_keys(keys_to_action)
        # self.video_size is the size of the video that is being displayed.
        # The window size may be larger, in that case we will add black bars
        self.video_size = self._get_video_size(zoom)
        self.screen = pygame.display.set_mode(self.video_size, pygame.RESIZABLE)
        # Keys currently held down; maintained by process_event.
        self.pressed_keys = []
        # Set to False when the window is closed or ESC is pressed.
        self.running = True
+
    def _get_relevant_keys(
        self, keys_to_action: dict[tuple[int], int] | None = None
    ) -> set:
        """Return the set of individual keys that appear in the key-to-action mapping.

        When ``keys_to_action`` is ``None``, the environment's own mapping is
        fetched via the ``get_keys_to_action`` wrapper attribute.

        Raises:
            MissingKeysToAction: if no mapping was given and the environment
                does not provide a default one.
        """
        if keys_to_action is None:
            if self.env.has_wrapper_attr("get_keys_to_action"):
                keys_to_action = self.env.get_wrapper_attr("get_keys_to_action")()
            else:
                assert self.env.spec is not None
                raise MissingKeysToAction(
                    f"{self.env.spec.id} does not have explicit key to action mapping, "
                    "please specify one manually, `play(env, keys_to_action=...)`"
                )
        assert isinstance(keys_to_action, dict)
        # Flatten the key-combination tuples into a single set of individual keys.
        relevant_keys = set(sum((list(k) for k in keys_to_action.keys()), []))
        return relevant_keys
+
    def _get_video_size(self, zoom: float | None = None) -> tuple[int, int]:
        """Return the (width, height) of the rendered frame, optionally scaled by ``zoom``."""
        rendered = self.env.render()
        # rgb_array_list render mode returns a list of frames; use the latest one.
        if isinstance(rendered, list):
            rendered = rendered[-1]
        assert rendered is not None and isinstance(rendered, np.ndarray)
        # ndarray shape is (height, width, channels); pygame sizes are (width, height).
        video_size = (rendered.shape[1], rendered.shape[0])

        if zoom is not None:
            video_size = (int(video_size[0] * zoom), int(video_size[1] * zoom))

        return video_size
+
+
+[docs]
    def process_event(self, event: Event):
        """Processes a PyGame event.

        In particular, this function is used to keep track of which buttons are currently pressed
        and to exit the :func:`play` function when the PyGame window is closed.

        Args:
            event: The event to process
        """
        if event.type == pygame.KEYDOWN:
            if event.key in self.relevant_keys:
                self.pressed_keys.append(event.key)
            elif event.key == pygame.K_ESCAPE:
                # Escape quits the play loop.
                self.running = False
        elif event.type == pygame.KEYUP:
            if event.key in self.relevant_keys:
                self.pressed_keys.remove(event.key)
        elif event.type == pygame.QUIT:
            self.running = False
        elif event.type == pygame.WINDOWRESIZED:
            # Compute the maximum video size that fits into the new window
            # while preserving the frame's aspect ratio.
            scale_width = event.x / self.video_size[0]
            scale_height = event.y / self.video_size[1]
            scale = min(scale_height, scale_width)
            # NOTE(review): this stores float sizes in video_size; later pygame
            # scaling truncates them — confirm this is intended.
            self.video_size = (scale * self.video_size[0], scale * self.video_size[1])
+
+
+
+
def display_arr(
    screen: Surface, arr: np.ndarray, video_size: tuple[int, int], transpose: bool
):
    """Displays a numpy array on screen.

    Args:
        screen: The screen to show the array on
        arr: The array to show
        video_size: The video size of the screen
        transpose: If to transpose the array on the screen
    """
    assert isinstance(arr, np.ndarray) and arr.dtype == np.uint8
    frame = arr.swapaxes(0, 1) if transpose else arr
    surface = pygame.surfarray.make_surface(frame)
    surface = pygame.transform.scale(surface, video_size)
    # We might have to add black bars if surface_size is larger than video_size
    screen_width, screen_height = screen.get_size()
    x_offset = (screen_width - video_size[0]) / 2
    y_offset = (screen_height - video_size[1]) / 2
    screen.fill((0, 0, 0))
    screen.blit(surface, (x_offset, y_offset))
+
+
+
+[docs]
def play(
    env: Env,
    transpose: bool | None = True,
    fps: int | None = None,
    zoom: float | None = None,
    callback: Callable | None = None,
    keys_to_action: dict[tuple[str | int, ...] | str | int, ActType] | None = None,
    seed: int | None = None,
    noop: ActType = 0,
    wait_on_player: bool = False,
):
    """Allows the user to play the environment using a keyboard.

    If playing in a turn-based environment, set wait_on_player to True.

    Args:
        env: Environment to use for playing.
        transpose: If this is ``True``, the output of observation is transposed. Defaults to ``True``.
        fps: Maximum number of steps of the environment executed every second. If ``None`` (the default),
            ``env.metadata["render_fps"]`` (or 30, if the environment does not specify "render_fps") is used.
        zoom: Zoom the observation in, ``zoom`` amount, should be positive float
        callback: If a callback is provided, it will be executed after every step. It takes the following input:

            * obs_t: observation before performing action
            * obs_tp1: observation after performing action
            * action: action that was executed
            * rew: reward that was received
            * terminated: whether the environment is terminated or not
            * truncated: whether the environment is truncated or not
            * info: debug info
        keys_to_action: Mapping from keys pressed to action performed.
            Different formats are supported: Key combinations can either be expressed as a tuple of unicode code
            points of the keys, as a tuple of characters, or as a string where each character of the string represents
            one key.
            For example if pressing 'w' and space at the same time is supposed
            to trigger action number 2 then ``key_to_action`` dict could look like this:

            >>> key_to_action = {
            ...     # ...
            ...     (ord('w'), ord(' ')): 2
            ...     # ...
            ... }

            or like this:

            >>> key_to_action = {
            ...     # ...
            ...     ("w", " "): 2
            ...     # ...
            ... }

            or like this:

            >>> key_to_action = {
            ...     # ...
            ...     "w ": 2
            ...     # ...
            ... }

            If ``None``, default ``key_to_action`` mapping for that environment is used, if provided.
        seed: Random seed used when resetting the environment. If None, no seed is used.
        noop: The action used when no key input has been entered, or the entered key combination is unknown.
        wait_on_player: Play should wait for a user action

    Example:
        >>> import gymnasium as gym
        >>> import numpy as np
        >>> from gymnasium.utils.play import play
        >>> play(gym.make("CarRacing-v3", render_mode="rgb_array"),  # doctest: +SKIP
        ...     keys_to_action={
        ...         "w": np.array([0, 0.7, 0], dtype=np.float32),
        ...         "a": np.array([-1, 0, 0], dtype=np.float32),
        ...         "s": np.array([0, 0, 1], dtype=np.float32),
        ...         "d": np.array([1, 0, 0], dtype=np.float32),
        ...         "wa": np.array([-1, 0.7, 0], dtype=np.float32),
        ...         "dw": np.array([1, 0.7, 0], dtype=np.float32),
        ...         "ds": np.array([1, 0, 1], dtype=np.float32),
        ...         "as": np.array([-1, 0, 1], dtype=np.float32),
        ...     },
        ...     noop=np.array([0, 0, 0], dtype=np.float32)
        ... )

    Above code works also if the environment is wrapped, so it's particularly useful in
    verifying that the frame-level preprocessing does not render the game
    unplayable.

    If you wish to plot real time statistics as you play, you can use
    :class:`PlayPlot`. Here's a sample code for plotting the reward
    for last 150 steps.

    >>> from gymnasium.utils.play import PlayPlot, play
    >>> def callback(obs_t, obs_tp1, action, rew, terminated, truncated, info):
    ...     return [rew,]
    >>> plotter = PlayPlot(callback, 150, ["reward"])  # doctest: +SKIP
    >>> play(gym.make("CartPole-v1"), callback=plotter.callback)  # doctest: +SKIP
    """
    # Initial reset; the returned observation/info are discarded here — the
    # play loop below performs its own reset on the first iteration.
    env.reset(seed=seed)

    if keys_to_action is None:
        # Fall back to the environment's own key mapping, if it exposes one.
        if env.has_wrapper_attr("get_keys_to_action"):
            keys_to_action = env.get_wrapper_attr("get_keys_to_action")()
        else:
            assert env.spec is not None
            raise MissingKeysToAction(
                f"{env.spec.id} does not have explicit key to action mapping, "
                "please specify one manually"
            )

    assert keys_to_action is not None

    # validate the `keys_to_action` set provided
    assert isinstance(keys_to_action, dict)
    for key, action in keys_to_action.items():
        if isinstance(key, tuple):
            assert len(key) > 0
            assert all(isinstance(k, (str, int)) for k in key)
        else:
            assert isinstance(key, (str, int))

        assert action in env.action_space

    # Normalise every key combination to a sorted tuple of integer key codes,
    # so lookups against the currently pressed keys are order-independent.
    key_code_to_action = {}
    for key_combination, action in keys_to_action.items():
        if isinstance(key_combination, int):
            key_combination = (key_combination,)
        key_code = tuple(
            sorted(ord(key) if isinstance(key, str) else key for key in key_combination)
        )
        key_code_to_action[key_code] = action

    game = PlayableGame(env, key_code_to_action, zoom)

    if fps is None:
        fps = env.metadata.get("render_fps", 30)

    done, obs = True, None
    clock = pygame.time.Clock()

    while game.running:
        if done:
            done = False
            # NOTE(review): `env.reset` returns an ``(obs, info)`` tuple, so
            # ``obs`` (and the ``prev_obs`` later passed to ``callback``) is
            # that tuple here — confirm this is intended.
            obs = env.reset(seed=seed)
        elif wait_on_player is False or len(game.pressed_keys) > 0:
            # Unknown key combinations fall back to the no-op action.
            action = key_code_to_action.get(tuple(sorted(game.pressed_keys)), noop)
            prev_obs = obs
            obs, rew, terminated, truncated, info = env.step(action)
            done = terminated or truncated
            if callback is not None:
                callback(prev_obs, obs, action, rew, terminated, truncated, info)
        if obs is not None:
            rendered = env.render()
            # "rgb_array_list" render modes return a list of frames; show the latest.
            if isinstance(rendered, list):
                rendered = rendered[-1]
            assert rendered is not None and isinstance(rendered, np.ndarray)
            display_arr(
                game.screen, rendered, transpose=transpose, video_size=game.video_size
            )

        # process pygame events
        for event in pygame.event.get():
            game.process_event(event)

        pygame.display.flip()
        # Cap the loop rate at `fps` frames per second.
        clock.tick(fps)
    pygame.quit()
+
+
+
+
+[docs]
class PlayPlot:
    """Provides a callback to create live plots of arbitrary metrics when using :func:`play`.

    The class is constructed with a data function that receives a single environment
    transition:
    - obs_t: observation before performing action
    - obs_tp1: observation after performing action
    - action: action that was executed
    - rew: reward that was received
    - terminated: whether the environment is terminated or not
    - truncated: whether the environment is truncated or not
    - info: debug info

    and returns a list of metric values computed from it, e.g.::

        >>> def compute_metrics(obs_t, obs_tp, action, reward, terminated, truncated, info):
        ...     return [reward, info["cumulative_reward"], np.linalg.norm(action)]

    The :meth:`callback` method forwards its arguments to that function and uses
    the returned values to refresh one live scatter plot per metric. It is
    typically passed to :func:`play` so the plots update while you play::

        >>> plotter = PlayPlot(compute_metrics, horizon_timesteps=200,  # doctest: +SKIP
        ...                    plot_names=["Immediate Rew.", "Cumulative Rew.", "Action Magnitude"])
        >>> play(your_env, callback=plotter.callback)  # doctest: +SKIP
    """

    def __init__(
        self, callback: Callable, horizon_timesteps: int, plot_names: list[str]
    ):
        """Constructor of :class:`PlayPlot`.

        The ``callback`` function must return one metric per entry of ``plot_names``.

        Args:
            callback: Function that computes metrics from environment transitions
            horizon_timesteps: The time horizon used for the live plots
            plot_names: List of plot titles

        Raises:
            DependencyNotInstalled: If matplotlib is not installed
        """
        self.data_callback = callback
        self.horizon_timesteps = horizon_timesteps
        self.plot_names = plot_names

        if plt is None:
            raise DependencyNotInstalled(
                'matplotlib is not installed, run `pip install "gymnasium[other]"`'
            )

        num_plots = len(self.plot_names)
        self.fig, self.ax = plt.subplots(num_plots)
        if num_plots == 1:
            # plt.subplots returns a bare Axes for a single subplot; normalise to a list.
            self.ax = [self.ax]
        for axis, title in zip(self.ax, plot_names):
            axis.set_title(title)
        self.t = 0
        self.cur_plot: list[Axes | None] = [None] * num_plots
        self.data = [deque(maxlen=horizon_timesteps) for _ in range(num_plots)]

    def callback(
        self,
        obs_t: ObsType,
        obs_tp1: ObsType,
        action: ActType,
        rew: float,
        terminated: bool,
        truncated: bool,
        info: dict,
    ):
        """Invokes the data callback and refreshes every live plot.

        Args:
            obs_t: The observation at time step t
            obs_tp1: The observation at time step t+1
            action: The action
            rew: The reward
            terminated: If the environment is terminated
            truncated: If the environment is truncated
            info: The information from the environment
        """
        metrics = self.data_callback(
            obs_t, obs_tp1, action, rew, terminated, truncated, info
        )
        for value, series in zip(metrics, self.data):
            series.append(value)
        self.t += 1

        # Show only the most recent `horizon_timesteps` points.
        xmin = max(0, self.t - self.horizon_timesteps)
        xmax = self.t

        for idx, previous in enumerate(self.cur_plot):
            if previous is not None:
                previous.remove()
            self.cur_plot[idx] = self.ax[idx].scatter(
                range(xmin, xmax), list(self.data[idx]), c="blue"
            )
            self.ax[idx].set_xlim(xmin, xmax)

        if plt is None:
            raise DependencyNotInstalled(
                'matplotlib is not installed, run `pip install "gymnasium[other]"`'
            )
        plt.pause(0.000001)
+
+
+
+"""Utility functions to save rendering videos."""
+
+from __future__ import annotations
+
+import os
+from collections.abc import Callable
+
+import gymnasium as gym
+from gymnasium import logger
+
+
+try:
+ from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
+except ImportError as e:
+ raise gym.error.DependencyNotInstalled(
+ 'moviepy is not installed, run `pip install "gymnasium[other]"`'
+ ) from e
+
+
+
+[docs]
def capped_cubic_video_schedule(episode_id: int) -> bool:
    r"""The default episode trigger.

    Records at the perfect cubes below 1000 and every 1000th episode after that,
    i.e. at episode indices :math:`\{0, 1, 8, 27, ..., k^3, ..., 729, 1000, 2000, 3000, ...\}`.

    Args:
        episode_id: The episode number

    Returns:
        If to apply a video schedule number
    """
    if episode_id >= 1000:
        # Past the cap: fixed interval of 1000 episodes.
        return episode_id % 1000 == 0
    # Below the cap: trigger exactly on perfect cubes.
    approx_root = int(round(episode_id ** (1.0 / 3)))
    return approx_root**3 == episode_id
+
+
+
+
+[docs]
def save_video(
    frames: list,
    video_folder: str,
    episode_trigger: Callable[[int], bool] | None = None,
    step_trigger: Callable[[int], bool] | None = None,
    video_length: int | None = None,
    name_prefix: str = "rl-video",
    episode_index: int = 0,
    step_starting_index: int = 0,
    save_logger: str | None = None,
    **kwargs,
):
    """Save videos from rendering frames.

    This function extracts a video from a list of render frames of an episode.

    Args:
        frames (List[RenderFrame]): A list of frames to compose the video.
        video_folder (str): The folder where the recordings will be stored
        episode_trigger: Function that accepts an integer and returns ``True`` iff a recording should be started at this episode
        step_trigger: Function that accepts an integer and returns ``True`` iff a recording should be started at this step
        video_length (int): The length of recorded episodes. If it isn't specified, the entire episode is recorded.
            Otherwise, snippets of the specified length are captured.
        name_prefix (str): Will be prepended to the filename of the recordings.
        episode_index (int): The index of the current episode.
        step_starting_index (int): The step index of the first frame.
        save_logger: If to log the video saving progress, helpful for long videos that take a while, use "bar" to enable.
        **kwargs: The kwargs that will be passed to moviepy's ImageSequenceClip.
            You need to specify either fps or duration.

    Example:
        >>> import gymnasium as gym
        >>> from gymnasium.utils.save_video import save_video
        >>> env = gym.make("FrozenLake-v1", render_mode="rgb_array_list")
        >>> _ = env.reset()
        >>> step_starting_index = 0
        >>> episode_index = 0
        >>> for step_index in range(199):  # doctest: +SKIP
        ...     action = env.action_space.sample()
        ...     _, _, terminated, truncated, _ = env.step(action)
        ...
        ...     if terminated or truncated:
        ...         save_video(
        ...             frames=env.render(),
        ...             video_folder="videos",
        ...             fps=env.metadata["render_fps"],
        ...             step_starting_index=step_starting_index,
        ...             episode_index=episode_index
        ...         )
        ...         step_starting_index = step_index + 1
        ...         episode_index += 1
        ...         env.reset()
        >>> env.close()
    """
    if not isinstance(frames, list):
        # Best-effort warning only; slicing below will still fail loudly for
        # incompatible types.
        logger.error(f"Expected a list of frames, got a {type(frames)} instead.")
    if episode_trigger is None and step_trigger is None:
        # No trigger supplied: default to the capped cubic episode schedule.
        episode_trigger = capped_cubic_video_schedule

    video_folder = os.path.abspath(video_folder)
    os.makedirs(video_folder, exist_ok=True)
    path_prefix = f"{video_folder}/{name_prefix}"

    if episode_trigger is not None and episode_trigger(episode_index):
        clip = ImageSequenceClip(frames[:video_length], **kwargs)
        clip.write_videofile(
            f"{path_prefix}-episode-{episode_index}.mp4", logger=save_logger
        )

    if step_trigger is not None:
        # skip the first frame since it comes from reset
        for step_index, frame_index in enumerate(
            range(1, len(frames)), start=step_starting_index
        ):
            if step_trigger(step_index):
                end_index = (
                    frame_index + video_length if video_length is not None else None
                )
                clip = ImageSequenceClip(frames[frame_index:end_index], **kwargs)
                clip.write_videofile(
                    f"{path_prefix}-step-{step_index}.mp4", logger=save_logger
                )
+
+
+"""Set of random number generator functions: seeding, generator, hashing seeds."""
+
+from __future__ import annotations
+
+import numpy as np
+
+from gymnasium import error
+
+
+
+[docs]
def np_random(seed: int | None = None) -> tuple[np.random.Generator, int]:
    """Returns a NumPy random number generator (RNG) along with seed value from the inputted seed.

    If ``seed`` is ``None`` then a **random** seed will be generated as the RNG's initial seed.
    This randomly selected seed is returned as the second value of the tuple.

    .. py:currentmodule:: gymnasium.Env

    This function is called in :meth:`reset` to reset an environment's initial RNG.

    Args:
        seed: The seed used to create the generator

    Returns:
        A NumPy-based Random Number Generator and generator seed

    Raises:
        Error: Seed must be a non-negative integer
    """
    # Validate the seed before handing it to NumPy: a non-None seed must be a
    # non-negative Python int.
    if seed is not None:
        if not isinstance(seed, int):
            raise error.Error(
                f"Seed must be a python integer, actual type: {type(seed)}"
            )
        if seed < 0:
            raise error.Error(
                f"Seed must be greater or equal to zero, actual value: {seed}"
            )

    sequence = np.random.SeedSequence(seed)
    # When seed is None, SeedSequence draws fresh entropy; report it back so
    # callers can reproduce the run.
    effective_seed = sequence.entropy
    generator = RandomNumberGenerator(np.random.PCG64(sequence))
    return generator, effective_seed
+
+
+
# Backwards-compatible aliases: Gymnasium's RNG type is a plain NumPy Generator.
RNG = RandomNumberGenerator = np.random.Generator
+
+"""Contains methods for step compatibility, from old-to-new and new-to-old API."""
+
+from __future__ import annotations
+
+from typing import SupportsFloat, Union
+
+import numpy as np
+
+from gymnasium.core import ObsType
+
+
# Step tuple of the old "done" API: (observation, reward, done, info).
DoneStepType = tuple[
    Union[ObsType, np.ndarray],
    Union[SupportsFloat, np.ndarray],
    Union[bool, np.ndarray],
    Union[dict, list],
]

# Step tuple of the new API: (observation, reward, terminated, truncated, info).
TerminatedTruncatedStepType = tuple[
    Union[ObsType, np.ndarray],
    Union[SupportsFloat, np.ndarray],
    Union[bool, np.ndarray],
    Union[bool, np.ndarray],
    Union[dict, list],
]
+
+
+
+[docs]
def convert_to_terminated_truncated_step_api(
    step_returns: DoneStepType | TerminatedTruncatedStepType, is_vector_env=False
) -> TerminatedTruncatedStepType:
    """Function to transform step returns to new step API irrespective of input API.

    .. py:currentmodule:: gymnasium.Env

    Args:
        step_returns (tuple): Items returned by :meth:`step`. Can be ``(obs, rew, done, info)`` or ``(obs, rew, terminated, truncated, info)``
        is_vector_env (bool): Whether the ``step_returns`` are from a vector environment

    Returns:
        A ``(obs, rew, terminated, truncated, info)`` tuple. ``done`` is split into
        terminated/truncated using the ``"TimeLimit.truncated"`` entry popped from ``infos``.

    Raises:
        TypeError: If ``is_vector_env`` is ``True`` and ``infos`` is neither a list nor a dict.
    """
    if len(step_returns) == 5:
        # Already in the terminated/truncated API; pass through unchanged.
        return step_returns

    assert len(step_returns) == 4
    observations, rewards, dones, infos = step_returns

    # Cases to handle - info single env / info vector env (list) / info vector env (dict)
    if is_vector_env is False:
        truncated = infos.pop("TimeLimit.truncated", False)
        return (
            observations,
            rewards,
            dones and not truncated,
            dones and truncated,
            infos,
        )
    elif isinstance(infos, list):
        # Vector env with one info dict per sub-environment.
        truncated = np.array(
            [info.pop("TimeLimit.truncated", False) for info in infos]
        )
        return (
            observations,
            rewards,
            np.logical_and(dones, np.logical_not(truncated)),
            np.logical_and(dones, truncated),
            infos,
        )
    elif isinstance(infos, dict):
        # Vector env with a single batched info dict.
        num_envs = len(dones)
        truncated = infos.pop("TimeLimit.truncated", np.zeros(num_envs, dtype=bool))
        return (
            observations,
            rewards,
            np.logical_and(dones, np.logical_not(truncated)),
            np.logical_and(dones, truncated),
            infos,
        )
    else:
        # Fix: this branch is only reachable when is_vector_env is True; the old
        # message incorrectly stated `is_vector_envs=False`.
        raise TypeError(
            f"Unexpected value of infos, as is_vector_envs=True, expects `info` to be a list or dict, actual type: {type(infos)}"
        )
+
+
+
+
+[docs]
def convert_to_done_step_api(
    step_returns: TerminatedTruncatedStepType | DoneStepType,
    is_vector_env: bool = False,
) -> DoneStepType:
    """Function to transform step returns to old step API irrespective of input API.

    .. py:currentmodule:: gymnasium.Env

    Args:
        step_returns (tuple): Items returned by :meth:`step`. Can be ``(obs, rew, done, info)`` or ``(obs, rew, terminated, truncated, info)``
        is_vector_env (bool): Whether the ``step_returns`` are from a vector environment

    Returns:
        A ``(obs, rew, done, info)`` tuple with ``done = terminated or truncated``;
        the truncation flag is preserved in ``infos["TimeLimit.truncated"]``.

    Raises:
        TypeError: If ``is_vector_env`` is ``True`` and ``infos`` is neither a list nor a dict.
    """
    if len(step_returns) == 4:
        # Already in the done API; pass through unchanged.
        return step_returns

    assert len(step_returns) == 5
    observations, rewards, terminated, truncated, infos = step_returns

    # Cases to handle - info single env / info vector env (list) / info vector env (dict)
    if is_vector_env is False:
        if truncated or terminated:
            # Truncation only matters when it occurred without termination.
            infos["TimeLimit.truncated"] = truncated and not terminated
        return (
            observations,
            rewards,
            terminated or truncated,
            infos,
        )
    elif isinstance(infos, list):
        # Vector env with one info dict per sub-environment.
        for info, env_truncated, env_terminated in zip(
            infos, truncated, terminated
        ):
            if env_truncated or env_terminated:
                info["TimeLimit.truncated"] = env_truncated and not env_terminated
        return (
            observations,
            rewards,
            np.logical_or(terminated, truncated),
            infos,
        )
    elif isinstance(infos, dict):
        # Vector env with a single batched info dict.
        if np.logical_or(np.any(truncated), np.any(terminated)):
            infos["TimeLimit.truncated"] = np.logical_and(
                truncated, np.logical_not(terminated)
            )
        return (
            observations,
            rewards,
            np.logical_or(terminated, truncated),
            infos,
        )
    else:
        # Fix: this branch is only reachable when is_vector_env is True; the old
        # message incorrectly stated `is_vector_envs=False`.
        raise TypeError(
            f"Unexpected value of infos, as is_vector_envs=True, expects `info` to be a list or dict, actual type: {type(infos)}"
        )
+
+
+
+
+[docs]
def step_api_compatibility(
    step_returns: TerminatedTruncatedStepType | DoneStepType,
    output_truncation_bool: bool = True,
    is_vector_env: bool = False,
) -> TerminatedTruncatedStepType | DoneStepType:
    """Transforms step returns to the API selected by ``output_truncation_bool``.

    .. py:currentmodule:: gymnasium.Env

    The done (old) step API has :meth:`step` return ``(observation, reward, done, info)``;
    the terminated/truncated (new) step API returns ``(observation, reward, terminated, truncated, info)``.
    (Refer to docs for details on the API change)

    Args:
        step_returns (tuple): Items returned by :meth:`step`. Can be ``(obs, rew, done, info)`` or ``(obs, rew, terminated, truncated, info)``
        output_truncation_bool (bool): Whether the output should return two booleans (new API) or one (old) (``True`` by default)
        is_vector_env (bool): Whether the ``step_returns`` are from a vector environment

    Returns:
        step_returns (tuple): Depending on ``output_truncation_bool``, it can return ``(obs, rew, done, info)`` or ``(obs, rew, terminated, truncated, info)``

    Example:
        Useful to ensure compatibility between conflicting step APIs, e.g. an env written
        in the old API wrapped by new-API code, with the final output desired in the old API.

        >>> import gymnasium as gym
        >>> env = gym.make("CartPole-v0")
        >>> _, _ = env.reset()
        >>> obs, reward, done, info = step_api_compatibility(env.step(0), output_truncation_bool=False)
        >>> obs, reward, terminated, truncated, info = step_api_compatibility(env.step(0), output_truncation_bool=True)

        >>> vec_env = gym.make_vec("CartPole-v0", vectorization_mode="sync")
        >>> _, _ = vec_env.reset()
        >>> obs, rewards, dones, infos = step_api_compatibility(vec_env.step([0]), is_vector_env=True, output_truncation_bool=False)
        >>> obs, rewards, terminations, truncations, infos = step_api_compatibility(vec_env.step([0]), is_vector_env=True, output_truncation_bool=True)

    """
    # Dispatch to the matching converter for the requested output API.
    converter = (
        convert_to_terminated_truncated_step_api
        if output_truncation_bool
        else convert_to_done_step_api
    )
    return converter(step_returns, is_vector_env)
+
+
+"""An async vector environment."""
+
+from __future__ import annotations
+
+import multiprocessing
+import sys
+import time
+import traceback
+from collections.abc import Callable, Sequence
+from copy import deepcopy
+from enum import Enum
+from multiprocessing import Queue
+from multiprocessing.connection import Connection
+from multiprocessing.sharedctypes import SynchronizedArray
+from typing import Any
+
+import numpy as np
+
+from gymnasium import Space, logger
+from gymnasium.core import ActType, Env, ObsType, RenderFrame
+from gymnasium.error import (
+ AlreadyPendingCallError,
+ ClosedEnvironmentError,
+ CustomSpaceError,
+ NoAsyncCallError,
+)
+from gymnasium.spaces.utils import is_space_dtype_shape_equiv
+from gymnasium.vector.utils import (
+ CloudpickleWrapper,
+ batch_differing_spaces,
+ batch_space,
+ clear_mpi_env_vars,
+ concatenate,
+ create_empty_array,
+ create_shared_memory,
+ iterate,
+ read_from_shared_memory,
+ write_to_shared_memory,
+)
+from gymnasium.vector.vector_env import ArrayType, AutoresetMode, VectorEnv
+
+
+__all__ = ["AsyncVectorEnv", "AsyncState"]
+
+
class AsyncState(Enum):
    """The possible states of an ``AsyncVectorEnv`` between asynchronous call pairs."""

    # No asynchronous call pending; the next `*_async` call may be issued.
    DEFAULT = "default"
    # `reset_async` was sent to the workers; `reset_wait` must be called next.
    WAITING_RESET = "reset"
    # A step call is pending on the workers (presumably set by `step_async` — confirm).
    WAITING_STEP = "step"
    # A remote `call` invocation is pending on the workers (presumably `call_async` — confirm).
    WAITING_CALL = "call"
+
+
+
+[docs]
+class AsyncVectorEnv(VectorEnv):
+ """Vectorized environment that runs multiple environments in parallel.
+
+ It uses ``multiprocessing`` processes, and pipes for communication.
+
+ Example:
+ >>> import gymnasium as gym
+ >>> envs = gym.make_vec("Pendulum-v1", num_envs=2, vectorization_mode="async")
+ >>> envs
+ AsyncVectorEnv(Pendulum-v1, num_envs=2)
+ >>> envs = gym.vector.AsyncVectorEnv([
+ ... lambda: gym.make("Pendulum-v1", g=9.81),
+ ... lambda: gym.make("Pendulum-v1", g=1.62)
+ ... ])
+ >>> envs
+ AsyncVectorEnv(num_envs=2)
+ >>> observations, infos = envs.reset(seed=42)
+ >>> observations
+ array([[-0.14995256, 0.9886932 , -0.12224312],
+ [ 0.5760367 , 0.8174238 , -0.91244936]], dtype=float32)
+ >>> infos
+ {}
+ >>> _ = envs.action_space.seed(123)
+ >>> observations, rewards, terminations, truncations, infos = envs.step(envs.action_space.sample())
+ >>> observations
+ array([[-0.1851753 , 0.98270553, 0.714599 ],
+ [ 0.6193494 , 0.7851154 , -1.0808398 ]], dtype=float32)
+ >>> rewards
+ array([-2.96495728, -1.00214607])
+ >>> terminations
+ array([False, False])
+ >>> truncations
+ array([False, False])
+ >>> infos
+ {}
+ """
+
    def __init__(
        self,
        env_fns: Sequence[Callable[[], Env]],
        shared_memory: bool = True,
        copy: bool = True,
        context: str | None = None,
        daemon: bool = True,
        worker: (
            Callable[
                [int, Callable[[], Env], Connection, Connection, bool, Queue], None
            ]
            | None
        ) = None,
        observation_mode: str | Space = "same",
        autoreset_mode: str | AutoresetMode = AutoresetMode.NEXT_STEP,
    ):
        """Vectorized environment that runs multiple environments in parallel.

        Args:
            env_fns: Functions that create the environments.
            shared_memory: If ``True``, then the observations from the worker processes are communicated back through
                shared variables. This can improve the efficiency if the observations are large (e.g. images).
            copy: If ``True``, then the :meth:`AsyncVectorEnv.reset` and :meth:`AsyncVectorEnv.step` methods
                return a copy of the observations.
            context: Context for `multiprocessing`. If ``None``, then the default context is used.
            daemon: If ``True``, then subprocesses have ``daemon`` flag turned on; that is, they will quit if
                the head process quits. However, ``daemon=True`` prevents subprocesses to spawn children,
                so for some environments you may want to have it set to ``False``.
            worker: If set, then use that worker in a subprocess instead of a default one.
                Can be useful to override some inner vector env logic, for instance, how resets on termination or truncation are handled.
            observation_mode: Defines how environment observation spaces should be batched. 'same' defines that there should be ``n`` copies of identical spaces.
                'different' defines that there can be multiple observation spaces with different parameters though requires the same shape and dtype,
                warning, may raise unexpected errors. Passing a ``Tuple[Space, Space]`` object allows defining a custom ``single_observation_space`` and
                ``observation_space``, warning, may raise unexpected errors.
            autoreset_mode: The Autoreset Mode used, see https://farama.org/Vector-Autoreset-Mode for more information.

        Warnings:
            worker is an advanced mode option. It provides a high degree of flexibility and a high chance
            to shoot yourself in the foot; thus, if you are writing your own worker, it is recommended to start
            from the code for ``_worker`` (or ``_async_worker``) method, and add changes.

        Raises:
            RuntimeError: If the observation space of some sub-environment does not match observation_space
                (or, by default, the observation space of the first sub-environment).
            ValueError: If observation_space is a custom space (i.e. not a default space in Gym,
                such as gymnasium.spaces.Box, gymnasium.spaces.Discrete, or gymnasium.spaces.Dict) and shared_memory is True.
        """
        self.env_fns = env_fns
        self.shared_memory = shared_memory
        self.copy = copy
        self.context = context
        self.daemon = daemon
        self.worker = worker
        self.observation_mode = observation_mode
        # Accept either the enum or its string value; normalise to the enum.
        self.autoreset_mode = (
            autoreset_mode
            if isinstance(autoreset_mode, AutoresetMode)
            else AutoresetMode(autoreset_mode)
        )

        self.num_envs = len(env_fns)

        # This would be nice to get rid of, but without it there's a deadlock between shared memory and pipes
        # Create a dummy environment to gather the metadata and observation / action space of the environment
        dummy_env = env_fns[0]()

        # As we support `make_vec(spec)` then we can't include a `spec = dummy_env.spec` as this doesn't guarantee we can actually recreate the vector env.
        self.metadata = dummy_env.metadata
        self.metadata["autoreset_mode"] = self.autoreset_mode
        self.render_mode = dummy_env.render_mode

        self.single_action_space = dummy_env.action_space
        self.action_space = batch_space(self.single_action_space, self.num_envs)

        if isinstance(observation_mode, tuple) and len(observation_mode) == 2:
            # Caller supplied (observation_space, single_observation_space) explicitly.
            assert isinstance(observation_mode[0], Space)
            assert isinstance(observation_mode[1], Space)
            self.observation_space, self.single_observation_space = observation_mode
        else:
            if observation_mode == "same":
                self.single_observation_space = dummy_env.observation_space
                self.observation_space = batch_space(
                    self.single_observation_space, self.num_envs
                )
            elif observation_mode == "different":
                # the environments are created and instantly destroyed, which might cause issues for some environments;
                # but I don't believe there is anything else we can do, for users with issues, pre-compute the spaces and use the custom option.
                env_spaces = [env().observation_space for env in self.env_fns]

                self.single_observation_space = env_spaces[0]
                self.observation_space = batch_differing_spaces(env_spaces)
            else:
                raise ValueError(
                    f"Invalid `observation_mode`, expected: 'same' or 'different' or tuple of single and batch observation space, actual got {observation_mode}"
                )

        dummy_env.close()
        del dummy_env

        # Generate the multiprocessing context for the observation buffer
        ctx = multiprocessing.get_context(context)
        if self.shared_memory:
            try:
                _obs_buffer = create_shared_memory(
                    self.single_observation_space, n=self.num_envs, ctx=ctx
                )
                # `self.observations` is a read view over the shared buffer.
                self.observations = read_from_shared_memory(
                    self.single_observation_space, _obs_buffer, n=self.num_envs
                )
            except CustomSpaceError as e:
                raise ValueError(
                    "Using `AsyncVector(..., shared_memory=True)` caused an error, you can disable this feature with `shared_memory=False` however this is slower."
                ) from e
        else:
            _obs_buffer = None
            self.observations = create_empty_array(
                self.single_observation_space, n=self.num_envs, fn=np.zeros
            )

        self.parent_pipes, self.processes = [], []
        self.error_queue = ctx.Queue()
        # Fall back to the default async worker loop when no custom worker is given.
        target = worker or _async_worker
        with clear_mpi_env_vars():
            for idx, env_fn in enumerate(self.env_fns):
                parent_pipe, child_pipe = ctx.Pipe()
                process = ctx.Process(
                    target=target,
                    name=f"Worker<{type(self).__name__}>-{idx}",
                    args=(
                        idx,
                        CloudpickleWrapper(env_fn),
                        child_pipe,
                        parent_pipe,
                        _obs_buffer,
                        self.error_queue,
                        self.autoreset_mode,
                    ),
                )

                self.parent_pipes.append(parent_pipe)
                self.processes.append(process)

                process.daemon = daemon
                process.start()
                # Close the parent process's copy of the child end of the pipe.
                child_pipe.close()

        self._state = AsyncState.DEFAULT
        self._check_spaces()
+
+ @property
+ def np_random_seed(self) -> tuple[int, ...]:
+ """Returns a tuple of np_random seeds for all the wrapped envs."""
+ return self.get_attr("np_random_seed")
+
+ @property
+ def np_random(self) -> tuple[np.random.Generator, ...]:
+ """Returns the tuple of the numpy random number generators for the wrapped envs."""
+ return self.get_attr("np_random")
+
+
+[docs]
+ def reset(
+ self,
+ *,
+ seed: int | list[int | None] | None = None,
+ options: dict[str, Any] | None = None,
+ ) -> tuple[ObsType, dict[str, Any]]:
+ """Resets all sub-environments in parallel and return a batch of concatenated observations and info.
+
+ Args:
+ seed: The environment reset seeds
+ options: If to return the options
+
+ Returns:
+ A batch of observations and info from the vectorized environment.
+ """
+ self.reset_async(seed=seed, options=options)
+ return self.reset_wait()
+
+
    def reset_async(
        self,
        seed: int | list[int | None] | None = None,
        options: dict | None = None,
    ):
        """Send calls to the :obj:`reset` methods of the sub-environments.

        To get the results of these calls, you may invoke :meth:`reset_wait`.

        Args:
            seed: List of seeds for each environment
            options: The reset option

        Raises:
            ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
            AlreadyPendingCallError: If the environment is already waiting for a pending call to another
                method (e.g. :meth:`step_async`). This can be caused by two consecutive
                calls to :meth:`reset_async`, with no call to :meth:`reset_wait` in between.
        """
        self._assert_is_running()

        # Normalise `seed` to one entry per sub-environment; a single int seeds
        # env i with `seed + i` so the sub-environments differ.
        if seed is None:
            seed = [None for _ in range(self.num_envs)]
        elif isinstance(seed, int):
            seed = [seed + i for i in range(self.num_envs)]
        assert (
            len(seed) == self.num_envs
        ), f"If seeds are passed as a list the length must match num_envs={self.num_envs} but got length={len(seed)}."

        if self._state != AsyncState.DEFAULT:
            raise AlreadyPendingCallError(
                f"Calling `reset_async` while waiting for a pending call to `{self._state.value}` to complete",
                str(self._state.value),
            )

        if options is not None and "reset_mask" in options:
            # NOTE(review): `pop` mutates the caller-provided options dict.
            reset_mask = options.pop("reset_mask")
            assert isinstance(
                reset_mask, np.ndarray
            ), f"`options['reset_mask': mask]` must be a numpy array, got {type(reset_mask)}"
            assert reset_mask.shape == (
                self.num_envs,
            ), f"`options['reset_mask': mask]` must have shape `({self.num_envs},)`, got {reset_mask.shape}"
            assert (
                reset_mask.dtype == np.bool_
            ), f"`options['reset_mask': mask]` must have `dtype=np.bool_`, got {reset_mask.dtype}"
            assert np.any(
                reset_mask
            ), f"`options['reset_mask': mask]` must contain a boolean array, got reset_mask={reset_mask}"

            for pipe, env_seed, env_reset in zip(self.parent_pipes, seed, reset_mask):
                if env_reset:
                    env_kwargs = {"seed": env_seed, "options": options}
                    pipe.send(("reset", env_kwargs))
                else:
                    # Send a no-op so each pipe still yields exactly one reply
                    # for `reset_wait` to receive.
                    pipe.send(("reset-noop", None))
        else:
            # No mask: reset every sub-environment.
            for pipe, env_seed in zip(self.parent_pipes, seed):
                env_kwargs = {"seed": env_seed, "options": options}
                pipe.send(("reset", env_kwargs))

        self._state = AsyncState.WAITING_RESET
+
    def reset_wait(
        self,
        timeout: int | float | None = None,
    ) -> tuple[ObsType, dict[str, Any]]:
        """Waits for the calls triggered by :meth:`reset_async` to finish and returns the results.

        Args:
            timeout: Number of seconds before the call to ``reset_wait`` times out. If `None`, the call to ``reset_wait`` never times out.

        Returns:
            A tuple of batched observations and list of dictionaries

        Raises:
            ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
            NoAsyncCallError: If :meth:`reset_wait` was called without any prior call to :meth:`reset_async`.
            TimeoutError: If :meth:`reset_wait` timed out.
        """
        self._assert_is_running()
        if self._state != AsyncState.WAITING_RESET:
            raise NoAsyncCallError(
                "Calling `reset_wait` without any prior " "call to `reset_async`.",
                AsyncState.WAITING_RESET.value,
            )

        # Wait until every worker pipe has a reply ready (or the timeout elapses).
        if not self._poll_pipe_envs(timeout):
            self._state = AsyncState.DEFAULT
            raise multiprocessing.TimeoutError(
                f"The call to `reset_wait` has timed out after {timeout} second(s)."
            )

        # One `(result, success)` pair per worker, in sub-environment order.
        results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
        self._raise_if_errors(successes)

        # Merge per-environment info dicts into a single batched info dict.
        infos = {}
        results, info_data = zip(*results)
        for i, info in enumerate(info_data):
            infos = self._add_info(infos, info, i)

        # With shared memory the workers already wrote their observations into
        # the shared buffer; otherwise concatenate the returned values here.
        if not self.shared_memory:
            self.observations = concatenate(
                self.single_observation_space, results, self.observations
            )

        self._state = AsyncState.DEFAULT
        return (deepcopy(self.observations) if self.copy else self.observations), infos
+
+
+[docs]
+ def step(
+ self, actions: ActType
+ ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
+ """Take an action for each parallel environment.
+
+ Args:
+ actions: element of :attr:`action_space` batch of actions.
+
+ Returns:
+ Batch of (observations, rewards, terminations, truncations, infos)
+ """
+ self.step_async(actions)
+ return self.step_wait()
+
+
    def step_async(self, actions: np.ndarray):
        """Send the calls to :meth:`Env.step` to each sub-environment.

        Args:
            actions: Batch of actions. element of :attr:`VectorEnv.action_space`

        Raises:
            ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
            AlreadyPendingCallError: If the environment is already waiting for a pending call to another
                method (e.g. :meth:`reset_async`). This can be caused by two consecutive
                calls to :meth:`step_async`, with no call to :meth:`step_wait` in
                between.
        """
        self._assert_is_running()
        if self._state != AsyncState.DEFAULT:
            raise AlreadyPendingCallError(
                f"Calling `step_async` while waiting for a pending call to `{self._state.value}` to complete.",
                str(self._state.value),
            )

        # Split the batched action into one action per sub-environment and
        # send each to its worker; results are collected in `step_wait`.
        iter_actions = iterate(self.action_space, actions)
        for pipe, action in zip(self.parent_pipes, iter_actions):
            pipe.send(("step", action))
        self._state = AsyncState.WAITING_STEP
+
    def step_wait(
        self, timeout: int | float | None = None
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, dict]:
        """Wait for the calls to :obj:`step` in each sub-environment to finish.

        Args:
            timeout: Number of seconds before the call to :meth:`step_wait` times out. If ``None``, the call to :meth:`step_wait` never times out.

        Returns:
            The batched environment step information, (obs, reward, terminated, truncated, info)

        Raises:
            ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
            NoAsyncCallError: If :meth:`step_wait` was called without any prior call to :meth:`step_async`.
            TimeoutError: If :meth:`step_wait` timed out.
        """
        self._assert_is_running()
        if self._state != AsyncState.WAITING_STEP:
            raise NoAsyncCallError(
                "Calling `step_wait` without any prior call " "to `step_async`.",
                AsyncState.WAITING_STEP.value,
            )

        # Wait until every worker pipe has a reply ready (or the timeout elapses).
        if not self._poll_pipe_envs(timeout):
            self._state = AsyncState.DEFAULT
            raise multiprocessing.TimeoutError(
                f"The call to `step_wait` has timed out after {timeout} second(s)."
            )

        observations, rewards, terminations, truncations, infos = [], [], [], [], {}
        successes = []
        for env_idx, pipe in enumerate(self.parent_pipes):
            # Each worker replies with ((obs, reward, terminated, truncated, info), success).
            env_step_return, success = pipe.recv()

            successes.append(success)
            if success:
                observations.append(env_step_return[0])
                rewards.append(env_step_return[1])
                terminations.append(env_step_return[2])
                truncations.append(env_step_return[3])
                infos = self._add_info(infos, env_step_return[4], env_idx)

        self._raise_if_errors(successes)

        # With shared memory the workers already wrote their observations into
        # the shared buffer; otherwise concatenate the returned values here.
        if not self.shared_memory:
            self.observations = concatenate(
                self.single_observation_space,
                observations,
                self.observations,
            )

        self._state = AsyncState.DEFAULT
        return (
            deepcopy(self.observations) if self.copy else self.observations,
            np.array(rewards, dtype=np.float64),
            np.array(terminations, dtype=np.bool_),
            np.array(truncations, dtype=np.bool_),
            infos,
        )
+
+
+[docs]
+ def call(self, name: str, *args: Any, **kwargs: Any) -> tuple[Any, ...]:
+ """Call a method from each parallel environment with args and kwargs.
+
+ Args:
+ name (str): Name of the method or property to call.
+ *args: Position arguments to apply to the method call.
+ **kwargs: Keyword arguments to apply to the method call.
+
+ Returns:
+ List of the results of the individual calls to the method or property for each environment.
+ """
+ self.call_async(name, *args, **kwargs)
+ return self.call_wait()
+
+
+ def render(self) -> tuple[RenderFrame, ...] | None:
+ """Returns a list of rendered frames from the environments."""
+ return self.call("render")
+
    def call_async(self, name: str, *args, **kwargs):
        """Calls the method with name asynchronously and apply args and kwargs to the method.

        Args:
            name: Name of the method or property to call.
            *args: Arguments to apply to the method call.
            **kwargs: Keyword arguments to apply to the method call.

        Raises:
            ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
            AlreadyPendingCallError: Calling `call_async` while waiting for a pending call to complete
        """
        self._assert_is_running()
        if self._state != AsyncState.DEFAULT:
            raise AlreadyPendingCallError(
                f"Calling `call_async` while waiting for a pending call to `{self._state.value}` to complete.",
                str(self._state.value),
            )

        # Broadcast the request; each worker resolves `name` on its own env
        # and replies in `call_wait`.
        for pipe in self.parent_pipes:
            pipe.send(("_call", (name, args, kwargs)))
        self._state = AsyncState.WAITING_CALL
+
    def call_wait(self, timeout: int | float | None = None) -> tuple[Any, ...]:
        """Calls all parent pipes and waits for the results.

        Args:
            timeout: Number of seconds before the call to :meth:`call_wait` times out.
                If ``None`` (default), the call to :meth:`call_wait` never times out.

        Returns:
            List of the results of the individual calls to the method or property for each environment.

        Raises:
            NoAsyncCallError: Calling :meth:`call_wait` without any prior call to :meth:`call_async`.
            TimeoutError: The call to :meth:`call_wait` has timed out after timeout second(s).
        """
        self._assert_is_running()
        if self._state != AsyncState.WAITING_CALL:
            raise NoAsyncCallError(
                "Calling `call_wait` without any prior call to `call_async`.",
                AsyncState.WAITING_CALL.value,
            )

        # Wait until every worker pipe has a reply ready (or the timeout elapses).
        if not self._poll_pipe_envs(timeout):
            self._state = AsyncState.DEFAULT
            raise multiprocessing.TimeoutError(
                f"The call to `call_wait` has timed out after {timeout} second(s)."
            )

        # One `(result, success)` pair per worker, in sub-environment order.
        results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
        self._raise_if_errors(successes)
        self._state = AsyncState.DEFAULT

        return results
+
+
+[docs]
+ def get_attr(self, name: str) -> tuple[Any, ...]:
+ """Get a property from each parallel environment.
+
+ Args:
+ name (str): Name of the property to be get from each individual environment.
+
+ Returns:
+ The property with name
+ """
+ return self.call(name)
+
+
+
    def set_attr(self, name: str, values: list[Any] | tuple[Any] | object):
        """Sets an attribute of the sub-environments.

        Args:
            name: Name of the property to be set in each individual environment.
            values: Values of the property to be set to. If ``values`` is a list or
                tuple, then it corresponds to the values for each individual
                environment, otherwise a single value is set for all environments.

        Raises:
            ValueError: Values must be a list or tuple with length equal to the number of environments.
            AlreadyPendingCallError: Calling :meth:`set_attr` while waiting for a pending call to complete.
        """
        self._assert_is_running()
        # A scalar value is broadcast to every sub-environment.
        if not isinstance(values, (list, tuple)):
            values = [values for _ in range(self.num_envs)]
        if len(values) != self.num_envs:
            raise ValueError(
                "Values must be a list or tuple with length equal to the number of environments. "
                f"Got `{len(values)}` values for {self.num_envs} environments."
            )

        if self._state != AsyncState.DEFAULT:
            raise AlreadyPendingCallError(
                f"Calling `set_attr` while waiting for a pending call to `{self._state.value}` to complete.",
                str(self._state.value),
            )

        # Send all requests first, then collect one acknowledgement per worker.
        for pipe, value in zip(self.parent_pipes, values):
            pipe.send(("_setattr", (name, value)))
        _, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
        self._raise_if_errors(successes)
+
+
    def close_extras(self, timeout: int | float | None = None, terminate: bool = False):
        """Close the environments & clean up the extra resources (processes and pipes).

        Args:
            timeout: Number of seconds before the call to :meth:`close` times out. If ``None``,
                the call to :meth:`close` never times out. If the call to :meth:`close`
                times out, then all processes are terminated.
            terminate: If ``True``, then the :meth:`close` operation is forced and all processes are terminated.

        Raises:
            TimeoutError: If :meth:`close` timed out.
        """
        # A forced close does not wait for pending operations at all.
        timeout = 0 if terminate else timeout
        try:
            if self._state != AsyncState.DEFAULT:
                logger.warn(
                    f"Calling `close` while waiting for a pending call to `{self._state.value}` to complete."
                )
                # Drain the pending operation (`reset_wait`/`step_wait`/`call_wait`)
                # so the workers are in a consistent state before closing.
                function = getattr(self, f"{self._state.value}_wait")
                function(timeout)
        except multiprocessing.TimeoutError:
            # Pending call did not finish in time: fall through to termination.
            terminate = True

        if terminate:
            for process in self.processes:
                if process.is_alive():
                    process.terminate()
        else:
            # Graceful shutdown: ask each worker to close, then wait for its ack.
            for pipe in self.parent_pipes:
                if (pipe is not None) and (not pipe.closed):
                    pipe.send(("close", None))
            for pipe in self.parent_pipes:
                if (pipe is not None) and (not pipe.closed):
                    pipe.recv()

        # Pipes that errored earlier may already be `None` (see `_raise_if_errors`).
        for pipe in self.parent_pipes:
            if pipe is not None:
                pipe.close()
        for process in self.processes:
            process.join()
+
    def _poll_pipe_envs(self, timeout: int | float | None = None):
        """Return ``True`` once every worker pipe has data ready within ``timeout`` seconds."""
        self._assert_is_running()

        # No timeout means the caller is willing to block indefinitely on `recv`.
        if timeout is None:
            return True

        end_time = time.perf_counter() + timeout
        for pipe in self.parent_pipes:
            # The remaining budget shrinks as earlier pipes consume time.
            delta = max(end_time - time.perf_counter(), 0)

            if pipe is None:
                return False
            if pipe.closed or (not pipe.poll(delta)):
                return False
        return True
+
    def _check_spaces(self):
        """Ask every worker whether its spaces match this vector env's single spaces, raising on mismatch."""
        self._assert_is_running()

        for pipe in self.parent_pipes:
            pipe.send(
                (
                    "_check_spaces",
                    (
                        self.observation_mode,
                        self.single_observation_space,
                        self.single_action_space,
                    ),
                )
            )

        # Each worker replies with (obs_space_ok, action_space_ok).
        results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
        self._raise_if_errors(successes)
        same_observation_spaces, same_action_spaces = zip(*results)

        if not all(same_observation_spaces):
            # Error message depends on how the user asked observations to be batched.
            if self.observation_mode == "same":
                raise RuntimeError(
                    "AsyncVectorEnv(..., observation_mode='same') however some of the sub-environments observation spaces are not equivalent. If this is intentional, use `observation_mode='different'` instead."
                )
            else:
                raise RuntimeError(
                    "AsyncVectorEnv(..., observation_mode='different' or custom space) however the sub-environment's observation spaces do not share a common shape and dtype."
                )

        if not all(same_action_spaces):
            raise RuntimeError(
                f"Some environments have an action space different from `{self.single_action_space}`. "
                "In order to batch actions, the action spaces from all environments must be equal."
            )
+
+ def _assert_is_running(self):
+ if self.closed:
+ raise ClosedEnvironmentError(
+ f"Trying to operate on `{type(self).__name__}`, after a call to `close()`."
+ )
+
    def _raise_if_errors(self, successes: list[bool] | tuple[bool]):
        """Drain the error queue for every failed worker, then re-raise the last error."""
        if all(successes):
            return

        num_errors = self.num_envs - sum(successes)
        assert num_errors > 0
        for i in range(num_errors):
            # Each failed worker pushed (index, exc_type, exc_value, traceback_str).
            index, exctype, value, trace = self.error_queue.get()

            logger.error(
                f"Received the following error from Worker-{index} - Shutting it down"
            )
            logger.error(f"{trace}")

            # Mark the dead worker's pipe so later polls/closes skip it.
            self.parent_pipes[index].close()
            self.parent_pipes[index] = None

            if i == num_errors - 1:
                logger.error("Raising the last exception back to the main process.")
                self._state = AsyncState.DEFAULT
                raise exctype(value)
+
+ def __del__(self):
+ """On deleting the object, checks that the vector environment is closed."""
+ if not getattr(self, "closed", True) and hasattr(self, "_state"):
+ self.close(terminate=True)
+
+
+
def _async_worker(
    index: int,
    env_fn: Callable,
    pipe: Connection,
    parent_pipe: Connection,
    shared_memory: SynchronizedArray | dict[str, Any] | tuple[Any, ...],
    error_queue: Queue,
    autoreset_mode: AutoresetMode,
):
    """Worker-process loop: create one sub-environment and service commands from ``pipe``.

    Replies to every command with ``(result, success)``; on any exception, the
    error details are pushed onto ``error_queue`` and ``(None, False)`` is sent.
    """
    env = env_fn()
    observation_space = env.observation_space
    action_space = env.action_space
    # Tracks whether the next `step` should instead reset (NEXT_STEP autoreset).
    autoreset = False
    observation = None

    # The parent's end of the pipe is not used in this process.
    parent_pipe.close()

    try:
        while True:
            command, data = pipe.recv()

            if command == "reset":
                observation, info = env.reset(**data)
                if shared_memory:
                    # Observation travels via shared memory; send only the info.
                    write_to_shared_memory(
                        observation_space, index, observation, shared_memory
                    )
                    observation = None
                autoreset = False
                pipe.send(((observation, info), True))
            elif command == "reset-noop":
                # Masked-out env during a partial reset: reply without resetting.
                pipe.send(((observation, {}), True))
            elif command == "step":
                if autoreset_mode == AutoresetMode.NEXT_STEP:
                    # The step after termination/truncation becomes a reset.
                    if autoreset:
                        observation, info = env.reset()
                        reward, terminated, truncated = 0, False, False
                    else:
                        (
                            observation,
                            reward,
                            terminated,
                            truncated,
                            info,
                        ) = env.step(data)
                    autoreset = terminated or truncated
                elif autoreset_mode == AutoresetMode.SAME_STEP:
                    (
                        observation,
                        reward,
                        terminated,
                        truncated,
                        info,
                    ) = env.step(data)

                    if terminated or truncated:
                        # Reset immediately; expose the terminal obs/info under
                        # `final_obs`/`final_info` alongside the reset info.
                        reset_observation, reset_info = env.reset()

                        info = {
                            "final_info": info,
                            "final_obs": observation,
                            **reset_info,
                        }
                        observation = reset_observation
                elif autoreset_mode == AutoresetMode.DISABLED:
                    # The user is responsible for resetting; stepping a
                    # finished env is a programming error.
                    assert autoreset is False
                    (
                        observation,
                        reward,
                        terminated,
                        truncated,
                        info,
                    ) = env.step(data)
                else:
                    raise ValueError(f"Unexpected autoreset_mode: {autoreset_mode}")

                if shared_memory:
                    write_to_shared_memory(
                        observation_space, index, observation, shared_memory
                    )
                    observation = None

                pipe.send(((observation, reward, terminated, truncated, info), True))
            elif command == "close":
                pipe.send((None, True))
                break
            elif command == "_call":
                name, args, kwargs = data
                if name in ["reset", "step", "close", "_setattr", "_check_spaces"]:
                    raise ValueError(
                        f"Trying to call function `{name}` with `call`, use `{name}` directly instead."
                    )

                attr = env.get_wrapper_attr(name)
                if callable(attr):
                    pipe.send((attr(*args, **kwargs), True))
                else:
                    pipe.send((attr, True))
            elif command == "_setattr":
                name, value = data
                env.set_wrapper_attr(name, value)
                pipe.send((None, True))
            elif command == "_check_spaces":
                obs_mode, single_obs_space, single_action_space = data

                # Reply (obs_space_ok, action_space_ok); 'same' requires space
                # equality, otherwise only dtype/shape equivalence is checked.
                pipe.send(
                    (
                        (
                            (
                                single_obs_space == observation_space
                                if obs_mode == "same"
                                else is_space_dtype_shape_equiv(
                                    single_obs_space, observation_space
                                )
                            ),
                            single_action_space == action_space,
                        ),
                        True,
                    )
                )
            else:
                raise RuntimeError(
                    f"Received unknown command `{command}`. Must be one of [`reset`, `step`, `close`, `_call`, `_setattr`, `_check_spaces`]."
                )
    except (KeyboardInterrupt, Exception):
        error_type, error_message, _ = sys.exc_info()
        trace = traceback.format_exc()

        # Report the failure to the parent, which re-raises via `_raise_if_errors`.
        error_queue.put((index, error_type, error_message, trace))
        pipe.send((None, False))
    finally:
        env.close()
+
+"""Implementation of a synchronous (for loop) vectorization method of any environment."""
+
+from __future__ import annotations
+
+from collections.abc import Callable, Iterator, Sequence
+from copy import deepcopy
+from typing import Any
+
+import numpy as np
+
+from gymnasium import Env, Space
+from gymnasium.core import ActType, ObsType, RenderFrame
+from gymnasium.spaces.utils import is_space_dtype_shape_equiv
+from gymnasium.vector.utils import (
+ batch_differing_spaces,
+ batch_space,
+ concatenate,
+ create_empty_array,
+ iterate,
+)
+from gymnasium.vector.vector_env import ArrayType, AutoresetMode, VectorEnv
+
+
+__all__ = ["SyncVectorEnv"]
+
+
+
class SyncVectorEnv(VectorEnv):
    """Vectorized environment that serially runs multiple environments.

    Example:
        >>> import gymnasium as gym
        >>> envs = gym.make_vec("Pendulum-v1", num_envs=2, vectorization_mode="sync")
        >>> envs
        SyncVectorEnv(Pendulum-v1, num_envs=2)
        >>> envs = gym.vector.SyncVectorEnv([
        ...     lambda: gym.make("Pendulum-v1", g=9.81),
        ...     lambda: gym.make("Pendulum-v1", g=1.62)
        ... ])
        >>> envs
        SyncVectorEnv(num_envs=2)
        >>> obs, infos = envs.reset(seed=42)
        >>> obs
        array([[-0.14995256,  0.9886932 , -0.12224312],
               [ 0.5760367 ,  0.8174238 , -0.91244936]], dtype=float32)
        >>> infos
        {}
        >>> _ = envs.action_space.seed(42)
        >>> actions = envs.action_space.sample()
        >>> obs, rewards, terminates, truncates, infos = envs.step(actions)
        >>> obs
        array([[-0.1878752 ,  0.98219293,  0.7695615 ],
               [ 0.6102389 ,  0.79221743, -0.8498053 ]], dtype=float32)
        >>> rewards
        array([-2.96562607, -0.99902063])
        >>> terminates
        array([False, False])
        >>> truncates
        array([False, False])
        >>> infos
        {}
        >>> envs.close()
    """

    def __init__(
        self,
        env_fns: Iterator[Callable[[], Env]] | Sequence[Callable[[], Env]],
        copy: bool = True,
        observation_mode: str | Space = "same",
        autoreset_mode: str | AutoresetMode = AutoresetMode.NEXT_STEP,
    ):
        """Vectorized environment that serially runs multiple environments.

        Args:
            env_fns: iterable of callable functions that create the environments.
            copy: If ``True``, then the :meth:`reset` and :meth:`step` methods return a copy of the observations.
            observation_mode: Defines how environment observation spaces should be batched. 'same' defines that there should be ``n`` copies of identical spaces.
                'different' defines that there can be multiple observation spaces with the same length but different high/low values batched together. Passing a ``Space`` object
                allows the user to set some custom observation space mode not covered by 'same' or 'different.'
            autoreset_mode: The Autoreset Mode used, see https://farama.org/Vector-Autoreset-Mode for more information.

        Raises:
            RuntimeError: If the observation space of some sub-environment does not match observation_space
                (or, by default, the observation space of the first sub-environment).
        """
        super().__init__()

        # Materialise the factories once so that iterator inputs are not left
        # exhausted in `self.env_fns` after building the environments below.
        self.env_fns = list(env_fns)
        self.copy = copy
        self.observation_mode = observation_mode
        self.autoreset_mode = (
            autoreset_mode
            if isinstance(autoreset_mode, AutoresetMode)
            else AutoresetMode(autoreset_mode)
        )

        # Initialise all sub-environments
        self.envs = [env_fn() for env_fn in self.env_fns]

        # Define core attributes using the sub-environments
        # As we support `make_vec(spec)` then we can't include a `spec = self.envs[0].spec` as this doesn't guarantee we can actual recreate the vector env.
        self.num_envs = len(self.envs)
        self.metadata = self.envs[0].metadata
        self.metadata["autoreset_mode"] = self.autoreset_mode
        self.render_mode = self.envs[0].render_mode

        self.single_action_space = self.envs[0].action_space
        self.action_space = batch_space(self.single_action_space, self.num_envs)

        if isinstance(observation_mode, tuple) and len(observation_mode) == 2:
            # Custom mode: the user supplies (batched space, single space) directly.
            assert isinstance(observation_mode[0], Space)
            assert isinstance(observation_mode[1], Space)
            self.observation_space, self.single_observation_space = observation_mode
        else:
            if observation_mode == "same":
                self.single_observation_space = self.envs[0].observation_space
                self.observation_space = batch_space(
                    self.single_observation_space, self.num_envs
                )
            elif observation_mode == "different":
                self.single_observation_space = self.envs[0].observation_space
                self.observation_space = batch_differing_spaces(
                    [env.observation_space for env in self.envs]
                )
            else:
                raise ValueError(
                    f"Invalid `observation_mode`, expected: 'same' or 'different' or tuple of single and batch observation space, actual got {observation_mode}"
                )

        # check sub-environment obs and action spaces
        for env in self.envs:
            if observation_mode == "same":
                assert (
                    env.observation_space == self.single_observation_space
                ), f"SyncVectorEnv(..., observation_mode='same') however the sub-environments observation spaces are not equivalent. single_observation_space={self.single_observation_space}, sub-environment observation_space={env.observation_space}. If this is intentional, use `observation_mode='different'` instead."
            else:
                assert is_space_dtype_shape_equiv(
                    env.observation_space, self.single_observation_space
                ), f"SyncVectorEnv(..., observation_mode='different' or custom space) however the sub-environments observation spaces do not share a common shape and dtype, single_observation_space={self.single_observation_space}, sub-environment observation space={env.observation_space}"

            assert (
                env.action_space == self.single_action_space
            ), f"Sub-environment action space doesn't match the `single_action_space`, action_space={env.action_space}, single_action_space={self.single_action_space}"

        # Initialise attributes used in `step` and `reset`
        self._env_obs = [None for _ in range(self.num_envs)]
        self._observations = create_empty_array(
            self.single_observation_space, n=self.num_envs, fn=np.zeros
        )
        self._rewards = np.zeros((self.num_envs,), dtype=np.float64)
        self._terminations = np.zeros((self.num_envs,), dtype=np.bool_)
        self._truncations = np.zeros((self.num_envs,), dtype=np.bool_)

        # Marks envs that finished last step and must reset next step (NEXT_STEP mode).
        self._autoreset_envs = np.zeros((self.num_envs,), dtype=np.bool_)

    @property
    def np_random_seed(self) -> tuple[int, ...]:
        """Returns a tuple of np random seeds for the wrapped envs."""
        return self.get_attr("np_random_seed")

    @property
    def np_random(self) -> tuple[np.random.Generator, ...]:
        """Returns a tuple of the numpy random number generators for the wrapped envs."""
        return self.get_attr("np_random")

    def reset(
        self,
        *,
        seed: int | list[int | None] | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[ObsType, dict[str, Any]]:
        """Resets each of the sub-environments and concatenate the results together.

        Args:
            seed: Seeds used to reset the sub-environments, either
                * ``None`` - random seeds for all environment
                * ``int`` - ``[seed, seed+1, ..., seed+n]``
                * List of ints - ``[1, 2, 3, ..., n]``
            options: Option information used for each sub-environment

        Returns:
            Concatenated observations and info from each sub-environment
        """
        # Normalise `seed` to one entry per sub-environment.
        if seed is None:
            seed = [None for _ in range(self.num_envs)]
        elif isinstance(seed, int):
            seed = [seed + i for i in range(self.num_envs)]
        assert (
            len(seed) == self.num_envs
        ), f"If seeds are passed as a list the length must match num_envs={self.num_envs} but got length={len(seed)}."

        if options is not None and "reset_mask" in options:
            # Work on a shallow copy so the caller's `options` dict is not
            # mutated by the `pop` below.
            options = dict(options)
            reset_mask = options.pop("reset_mask")
            assert isinstance(
                reset_mask, np.ndarray
            ), f"`options['reset_mask': mask]` must be a numpy array, got {type(reset_mask)}"
            assert reset_mask.shape == (
                self.num_envs,
            ), f"`options['reset_mask': mask]` must have shape `({self.num_envs},)`, got {reset_mask.shape}"
            assert (
                reset_mask.dtype == np.bool_
            ), f"`options['reset_mask': mask]` must have `dtype=np.bool_`, got {reset_mask.dtype}"
            assert np.any(
                reset_mask
            ), f"`options['reset_mask': mask]` must contain at least one `True` value, got reset_mask={reset_mask}"

            # Only clear the bookkeeping of the environments being reset.
            self._terminations[reset_mask] = False
            self._truncations[reset_mask] = False
            self._autoreset_envs[reset_mask] = False

            infos = {}
            for i, (env, single_seed, env_mask) in enumerate(
                zip(self.envs, seed, reset_mask)
            ):
                if env_mask:
                    self._env_obs[i], env_info = env.reset(
                        seed=single_seed, options=options
                    )

                    infos = self._add_info(infos, env_info, i)
        else:
            self._terminations = np.zeros((self.num_envs,), dtype=np.bool_)
            self._truncations = np.zeros((self.num_envs,), dtype=np.bool_)
            self._autoreset_envs = np.zeros((self.num_envs,), dtype=np.bool_)

            infos = {}
            for i, (env, single_seed) in enumerate(zip(self.envs, seed)):
                self._env_obs[i], env_info = env.reset(
                    seed=single_seed, options=options
                )

                infos = self._add_info(infos, env_info, i)

        # Concatenate the observations
        self._observations = concatenate(
            self.single_observation_space, self._env_obs, self._observations
        )
        return deepcopy(self._observations) if self.copy else self._observations, infos

    def step(
        self, actions: ActType
    ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
        """Steps through each of the environments returning the batched results.

        Returns:
            The batched environment step results
        """
        actions = iterate(self.action_space, actions)

        infos = {}
        for i, action in enumerate(actions):
            if self.autoreset_mode == AutoresetMode.NEXT_STEP:
                # The step after termination/truncation becomes a reset.
                if self._autoreset_envs[i]:
                    self._env_obs[i], env_info = self.envs[i].reset()

                    self._rewards[i] = 0.0
                    self._terminations[i] = False
                    self._truncations[i] = False
                else:
                    (
                        self._env_obs[i],
                        self._rewards[i],
                        self._terminations[i],
                        self._truncations[i],
                        env_info,
                    ) = self.envs[i].step(action)
            elif self.autoreset_mode == AutoresetMode.DISABLED:
                # assumes that the user has correctly autoreset
                assert not self._autoreset_envs[i], f"{self._autoreset_envs=}"
                (
                    self._env_obs[i],
                    self._rewards[i],
                    self._terminations[i],
                    self._truncations[i],
                    env_info,
                ) = self.envs[i].step(action)
            elif self.autoreset_mode == AutoresetMode.SAME_STEP:
                (
                    self._env_obs[i],
                    self._rewards[i],
                    self._terminations[i],
                    self._truncations[i],
                    env_info,
                ) = self.envs[i].step(action)

                if self._terminations[i] or self._truncations[i]:
                    # Expose the terminal obs/info, then reset within the same step.
                    infos = self._add_info(
                        infos,
                        {"final_obs": self._env_obs[i], "final_info": env_info},
                        i,
                    )

                    self._env_obs[i], env_info = self.envs[i].reset()
            else:
                raise ValueError(f"Unexpected autoreset mode, {self.autoreset_mode}")

            infos = self._add_info(infos, env_info, i)

        # Concatenate the observations
        self._observations = concatenate(
            self.single_observation_space, self._env_obs, self._observations
        )
        self._autoreset_envs = np.logical_or(self._terminations, self._truncations)

        return (
            deepcopy(self._observations) if self.copy else self._observations,
            np.copy(self._rewards),
            np.copy(self._terminations),
            np.copy(self._truncations),
            infos,
        )

    def render(self) -> tuple[RenderFrame, ...] | None:
        """Returns the rendered frames from the environments."""
        return tuple(env.render() for env in self.envs)

    def call(self, name: str, *args: Any, **kwargs: Any) -> tuple[Any, ...]:
        """Calls a sub-environment method with name and applies args and kwargs.

        Args:
            name: The method name
            *args: The method args
            **kwargs: The method kwargs

        Returns:
            Tuple of results
        """
        results = []
        for env in self.envs:
            function = env.get_wrapper_attr(name)

            # Non-callable attributes are returned as-is.
            if callable(function):
                results.append(function(*args, **kwargs))
            else:
                results.append(function)

        return tuple(results)

    def get_attr(self, name: str) -> tuple[Any, ...]:
        """Get a property from each parallel environment.

        Args:
            name (str): Name of the property to get from each individual environment.

        Returns:
            The property with name
        """
        return self.call(name)

    def set_attr(self, name: str, values: list[Any] | tuple[Any, ...] | Any):
        """Sets an attribute of the sub-environments.

        Args:
            name: The property name to change
            values: Values of the property to be set to. If ``values`` is a list or
                tuple, then it corresponds to the values for each individual
                environment, otherwise, a single value is set for all environments.

        Raises:
            ValueError: Values must be a list or tuple with length equal to the number of environments.
        """
        # A scalar value is broadcast to every sub-environment.
        if not isinstance(values, (list, tuple)):
            values = [values for _ in range(self.num_envs)]

        if len(values) != self.num_envs:
            raise ValueError(
                "Values must be a list or tuple with length equal to the number of environments. "
                f"Got `{len(values)}` values for {self.num_envs} environments."
            )

        for env, value in zip(self.envs, values):
            env.set_wrapper_attr(name, value)

    def close_extras(self, **kwargs: Any):
        """Close the environments."""
        # `envs` may be missing if `__init__` raised before creating them.
        if hasattr(self, "envs"):
            [env.close() for env in self.envs]
+
+
+"""Miscellaneous utilities."""
+
+from __future__ import annotations
+
+import contextlib
+import os
+from collections.abc import Callable
+
+from gymnasium.core import Env
+
+
+__all__ = ["CloudpickleWrapper", "clear_mpi_env_vars"]
+
+
+
class CloudpickleWrapper:
    """Wrapper that uses cloudpickle to pickle and unpickle the result."""

    def __init__(self, fn: Callable[[], Env]):
        """Store the zero-argument callable ``fn`` to be (cloud)pickled later."""
        self.fn = fn

    def __getstate__(self):
        """Serialise ``self.fn`` with ``cloudpickle.dumps`` (handles closures/lambdas)."""
        import cloudpickle

        serialised = cloudpickle.dumps(self.fn)
        return serialised

    def __setstate__(self, ob):
        """Restore ``self.fn`` from the byte string ``ob`` produced by ``__getstate__``."""
        import pickle

        # cloudpickle output is readable by the standard pickle module.
        self.fn = pickle.loads(ob)

    def __call__(self):
        """Invoke the wrapped callable with no arguments and return its result."""
        return self.fn()
+
+
+
+
@contextlib.contextmanager
def clear_mpi_env_vars():
    """Clears the MPI of environment variables.

    ``from mpi4py import MPI`` will call ``MPI_Init`` by default.
    If the child process has MPI environment variables, MPI will think that the child process
    is an MPI process just like the parent and do bad things such as hang.

    This context manager is a hacky way to clear those environment variables
    temporarily such as when we are starting multiprocessing Processes.

    Yields:
        Yields for the context manager
    """
    saved_vars = {}
    mpi_prefixes = ("OMPI_", "PMI_")
    # Snapshot the key list so we can delete from os.environ while scanning.
    for var_name in list(os.environ):
        if var_name.startswith(mpi_prefixes):
            saved_vars[var_name] = os.environ.pop(var_name)
    try:
        yield
    finally:
        # Restore whatever was removed, even if the body raised.
        os.environ.update(saved_vars)
+
+
+"""Utility functions for vector environments to share memory between processes."""
+
+from __future__ import annotations
+
+import multiprocessing as mp
+from ctypes import c_bool
+from functools import singledispatch
+from multiprocessing.sharedctypes import SynchronizedArray
+from typing import Any
+
+import numpy as np
+
+from gymnasium.error import CustomSpaceError
+from gymnasium.spaces import (
+ Box,
+ Dict,
+ Discrete,
+ Graph,
+ MultiBinary,
+ MultiDiscrete,
+ OneOf,
+ Sequence,
+ Space,
+ Text,
+ Tuple,
+ flatten,
+)
+
+
+__all__ = ["create_shared_memory", "read_from_shared_memory", "write_to_shared_memory"]
+
+
+
+
+
+
@create_shared_memory.register(Box)
@create_shared_memory.register(Discrete)
@create_shared_memory.register(MultiDiscrete)
@create_shared_memory.register(MultiBinary)
def _create_base_shared_memory(
    space: Box | Discrete | MultiDiscrete | MultiBinary, n: int = 1, ctx=mp
):
    """Allocate a flat ctypes array large enough for ``n`` samples of ``space``."""
    assert space.dtype is not None
    type_code = space.dtype.char
    if type_code == "?":
        # ctypes has no '?' typecode, so booleans are stored as c_bool.
        type_code = c_bool
    num_items = n * int(np.prod(space.shape))
    return ctx.Array(type_code, num_items)
+
+
@create_shared_memory.register(Tuple)
def _create_tuple_shared_memory(space: Tuple, n: int = 1, ctx=mp):
    """Allocate one shared block per subspace of a ``Tuple`` space."""
    memories = []
    for subspace in space.spaces:
        memories.append(create_shared_memory(subspace, n=n, ctx=ctx))
    return tuple(memories)
+
+
@create_shared_memory.register(Dict)
def _create_dict_shared_memory(space: Dict, n: int = 1, ctx=mp):
    """Allocate one shared block per keyed subspace of a ``Dict`` space."""
    memories = {}
    for key, subspace in space.spaces.items():
        memories[key] = create_shared_memory(subspace, n=n, ctx=ctx)
    return memories
+
+
@create_shared_memory.register(Text)
def _create_text_shared_memory(space: Text, n: int = 1, ctx=mp):
    """Allocate an int32 buffer for ``n`` strings of up to ``max_length`` codepoints each."""
    int32_typecode = np.dtype(np.int32).char
    return ctx.Array(int32_typecode, space.max_length * n)
+
+
@create_shared_memory.register(OneOf)
def _create_oneof_shared_memory(space: OneOf, n: int = 1, ctx=mp):
    """Allocate an int64 index array plus one shared block per alternative subspace."""
    # Slot 0 records which alternative each of the n samples used.
    index_memory = ctx.Array(np.dtype(np.int64).char, n)
    subspace_memories = tuple(
        create_shared_memory(subspace, n=n, ctx=ctx) for subspace in space.spaces
    )
    return (index_memory,) + subspace_memories
+
+
@create_shared_memory.register(Graph)
@create_shared_memory.register(Sequence)
def _create_dynamic_shared_memory(space: Graph | Sequence, n: int = 1, ctx=mp):
    """Reject dynamically-shaped spaces: fixed-size shared memory cannot hold them."""
    raise TypeError(
        f"As {space} has a dynamic shape so its not possible to make a static shared memory. For `AsyncVectorEnv`, disable `shared_memory`."
    )
+
+
+
+
+
+
@read_from_shared_memory.register(Box)
@read_from_shared_memory.register(Discrete)
@read_from_shared_memory.register(MultiDiscrete)
@read_from_shared_memory.register(MultiBinary)
def _read_base_from_shared_memory(
    space: Box | Discrete | MultiDiscrete | MultiBinary, shared_memory, n: int = 1
):
    """View the shared buffer as an ``(n, *space.shape)`` numpy array (no copy)."""
    flat_view = np.frombuffer(shared_memory.get_obj(), dtype=space.dtype)
    return flat_view.reshape((n,) + space.shape)
+
+
@read_from_shared_memory.register(Tuple)
def _read_tuple_from_shared_memory(space: Tuple, shared_memory, n: int = 1):
    """Recursively read each subspace's batch from its paired shared block."""
    return tuple(
        read_from_shared_memory(subspace, memory, n=n)
        for subspace, memory in zip(space.spaces, shared_memory)
    )
+
+
@read_from_shared_memory.register(Dict)
def _read_dict_from_shared_memory(space: Dict, shared_memory, n: int = 1):
    """Recursively read each keyed subspace's batch from its shared block."""
    batch = {}
    for key, subspace in space.spaces.items():
        batch[key] = read_from_shared_memory(subspace, shared_memory[key], n=n)
    return batch
+
+
@read_from_shared_memory.register(Text)
def _read_text_from_shared_memory(
    space: Text, shared_memory, n: int = 1
) -> tuple[str, ...]:
    """Decode ``n`` strings from the shared int32 buffer, skipping padding entries."""
    codes = np.frombuffer(shared_memory.get_obj(), dtype=np.int32)
    codes = codes.reshape((n, space.max_length))

    charset_size = len(space.character_set)
    decoded = []
    for row in codes:
        # Entries >= charset_size mark unused positions past the string's end.
        characters = [space.character_list[code] for code in row if code < charset_size]
        decoded.append("".join(characters))
    return tuple(decoded)
+
+
@read_from_shared_memory.register(OneOf)
def _read_one_of_from_shared_memory(
    space: OneOf, shared_memory, n: int = 1
) -> tuple[Any, ...]:
    """Read ``n`` ``(subspace index, sample)`` pairs from shared memory."""
    # Slot 0 holds which alternative each environment's sample came from.
    chosen_subspaces = np.frombuffer(shared_memory[0].get_obj(), dtype=np.int64)

    per_subspace_batches = tuple(
        read_from_shared_memory(subspace, memory, n=n)
        for memory, subspace in zip(shared_memory[1:], space.spaces)
    )
    return tuple(
        (subspace_index, per_subspace_batches[subspace_index][env_index])
        for env_index, subspace_index in enumerate(chosen_subspaces)
    )
+
+
+
+
+
+
@write_to_shared_memory.register(Box)
@write_to_shared_memory.register(Discrete)
@write_to_shared_memory.register(MultiDiscrete)
@write_to_shared_memory.register(MultiBinary)
def _write_base_to_shared_memory(
    space: Box | Discrete | MultiDiscrete | MultiBinary,
    index: int,
    value,
    shared_memory,
):
    """Write one flattened sample into the ``index``-th slot of the shared buffer."""
    flat_size = int(np.prod(space.shape))
    buffer = np.frombuffer(shared_memory.get_obj(), dtype=space.dtype)
    slot = buffer[index * flat_size : (index + 1) * flat_size]
    np.copyto(slot, np.asarray(value, dtype=space.dtype).flatten())
+
+
@write_to_shared_memory.register(Tuple)
def _write_tuple_to_shared_memory(
    space: Tuple, index: int, values: tuple[Any, ...], shared_memory
):
    """Recursively write each element of ``values`` into its matching subspace block."""
    for subspace, value, memory in zip(space.spaces, values, shared_memory):
        write_to_shared_memory(subspace, index, value, memory)
+
+
@write_to_shared_memory.register(Dict)
def _write_dict_to_shared_memory(
    space: Dict, index: int, values: dict[str, Any], shared_memory
):
    """Recursively write each keyed value into its matching subspace block."""
    for key in space.spaces:
        write_to_shared_memory(
            space.spaces[key], index, values[key], shared_memory[key]
        )
+
+
@write_to_shared_memory.register(Text)
def _write_text_to_shared_memory(space: Text, index: int, values: str, shared_memory):
    """Flatten the string to int32 codes and write it into slot ``index``."""
    slot_length = space.max_length
    buffer = np.frombuffer(shared_memory.get_obj(), dtype=np.int32)
    slot = buffer[index * slot_length : (index + 1) * slot_length]
    np.copyto(slot, flatten(space, values))
+
+
@write_to_shared_memory.register(OneOf)
def _write_oneof_to_shared_memory(
    space: OneOf, index: int, values: tuple[int, Any], shared_memory
):
    """Write a ``OneOf`` sample, i.e. a ``(subspace index, value)`` pair, into slot ``index``.

    ``shared_memory[0]`` stores the chosen subspace index per sample; the
    remaining entries are one shared block per alternative subspace.
    """
    subspace_idx, space_value = values

    # Record which alternative this sample used in the int64 index array.
    destination = np.frombuffer(shared_memory[0].get_obj(), dtype=np.int64)
    np.copyto(destination[index : index + 1], subspace_idx)

    # only the subspace's memory is updated with the sample value, ignoring the other memories as data might not match
    write_to_shared_memory(
        space.spaces[subspace_idx], index, space_value, shared_memory[1 + subspace_idx]
    )
+
+"""Space-based utility functions for vector environments.
+
+- ``batch_space``: Create a (batched) space containing multiple copies of a single space.
+- ``batch_differing_spaces``: Create a (batched) space containing copies of different compatible spaces (share a common dtype and shape)
+- ``concatenate``: Concatenate multiple samples from (unbatched) space into a single object.
+- ``iterate``: Iterate over the elements of a (batched) space and items.
+- ``create_empty_array``: Create an empty (possibly nested) (normally numpy-based) array, used in conjunction with ``concatenate(..., out=array)``
+"""
+
+from __future__ import annotations
+
+import typing
+from collections.abc import Callable, Iterable, Iterator
+from copy import deepcopy
+from functools import singledispatch
+from typing import Any
+
+import numpy as np
+
+from gymnasium.error import CustomSpaceError
+from gymnasium.spaces import (
+ Box,
+ Dict,
+ Discrete,
+ Graph,
+ GraphInstance,
+ MultiBinary,
+ MultiDiscrete,
+ OneOf,
+ Sequence,
+ Space,
+ Text,
+ Tuple,
+)
+from gymnasium.spaces.space import T_cov
+
+
+__all__ = [
+ "batch_space",
+ "batch_differing_spaces",
+ "iterate",
+ "concatenate",
+ "create_empty_array",
+]
+
+
+
[docs]
@singledispatch
def batch_space(space: Space[Any], n: int = 1) -> Space[Any]:
    """Batch spaces of size `n` optimized for neural networks.

    Args:
        space: Space (e.g. the observation space for a single environment in the vectorized environment).
        n: Number of spaces to batch by (e.g. the number of environments in a vectorized environment).

    Returns:
        Batched space of size `n`.

    Raises:
        TypeError: The argument is not a gymnasium Space or has no registered ``batch_space`` function.

    Example:

        >>> from gymnasium.spaces import Box, Dict
        >>> import numpy as np
        >>> space = Dict({
        ...     'position': Box(low=0, high=1, shape=(3,), dtype=np.float32),
        ...     'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32)
        ... })
        >>> batch_space(space, n=5)
        Dict('position': Box(0.0, 1.0, (5, 3), float32), 'velocity': Box(0.0, 1.0, (5, 2), float32))
    """
    # Base singledispatch case: only reached when no type-specific
    # implementation is registered for `type(space)`.
    raise TypeError(
        f"The space provided to `batch_space` is not a gymnasium Space instance, type: {type(space)}, {space}"
    )
+
+
+
@batch_space.register(Box)
def _batch_space_box(space: Box, n: int = 1):
    """Batch a ``Box`` by tiling its bounds ``n`` times along a new leading axis."""
    tile_reps = (n,) + (1,) * space.low.ndim
    return Box(
        low=np.tile(space.low, tile_reps),
        high=np.tile(space.high, tile_reps),
        dtype=space.dtype,
        seed=deepcopy(space.np_random),
    )
+
+
@batch_space.register(Discrete)
def _batch_space_discrete(space: Discrete, n: int = 1):
    """Batch a ``Discrete`` space into an ``n``-element ``MultiDiscrete``."""
    sizes = np.full((n,), space.n, dtype=space.dtype)
    starts = np.full((n,), space.start, dtype=space.dtype)
    return MultiDiscrete(
        sizes,
        dtype=space.dtype,
        seed=deepcopy(space.np_random),
        start=starts,
    )
+
+
@batch_space.register(MultiDiscrete)
def _batch_space_multidiscrete(space: MultiDiscrete, n: int = 1):
    """Batch a ``MultiDiscrete`` into an integer ``Box`` with per-element bounds."""
    tile_reps = (n,) + (1,) * space.nvec.ndim
    lower = np.tile(space.start, tile_reps)
    # Upper bound is inclusive: start + nvec - 1.
    upper = lower + np.tile(space.nvec, tile_reps) - 1
    return Box(low=lower, high=upper, dtype=space.dtype, seed=deepcopy(space.np_random))
+
+
@batch_space.register(MultiBinary)
def _batch_space_multibinary(space: MultiBinary, n: int = 1):
    """Batch a ``MultiBinary`` into a 0/1-valued ``Box`` with a leading axis of ``n``."""
    batched_shape = (n,) + space.shape
    return Box(
        low=0,
        high=1,
        shape=batched_shape,
        dtype=space.dtype,
        seed=deepcopy(space.np_random),
    )
+
+
@batch_space.register(Tuple)
def _batch_space_tuple(space: Tuple, n: int = 1):
    """Batch each subspace of a ``Tuple`` independently."""
    batched_subspaces = tuple(batch_space(subspace, n=n) for subspace in space.spaces)
    return Tuple(batched_subspaces, seed=deepcopy(space.np_random))
+
+
@batch_space.register(Dict)
def _batch_space_dict(space: Dict, n: int = 1):
    """Batch each keyed subspace of a ``Dict`` independently."""
    batched_subspaces = {
        key: batch_space(subspace, n=n) for key, subspace in space.items()
    }
    return Dict(batched_subspaces, seed=deepcopy(space.np_random))
+
+
@batch_space.register(Graph)
@batch_space.register(Text)
@batch_space.register(Sequence)
@batch_space.register(OneOf)
@batch_space.register(Space)
def _batch_space_custom(
    space: Graph | Text | Sequence | OneOf | Space, n: int = 1
):
    """Fallback batching: wrap ``n`` independently-seeded copies of ``space`` in a ``Tuple``."""
    # Without deepcopy, space.np_random would be the same object as
    # batched_space.spaces[0].np_random, which is an issue if you are sampling
    # actions of both the original space and the batched space.
    batched_space = Tuple(
        tuple(deepcopy(space) for _ in range(n)), seed=deepcopy(space.np_random)
    )
    # Derive fresh seeds from a copy of the source rng so each copy samples
    # a distinct stream without advancing the original space's rng.
    space_rng = deepcopy(space.np_random)
    new_seeds = list(map(int, space_rng.integers(0, 1e8, n)))
    batched_space.seed(new_seeds)
    return batched_space
+
+
@singledispatch
def batch_differing_spaces(spaces: typing.Sequence[Space]) -> Space:
    """Batch a sequence of spaces whose elements may contain minor differences.

    Args:
        spaces: A sequence of Spaces with minor differences (the same space type but different parameters).

    Returns:
        A batched space

    Raises:
        AssertionError: If ``spaces`` is empty, the spaces are not all of the
            same type, or that type has no registered implementation.

    Example:
        >>> from gymnasium.spaces import Discrete
        >>> spaces = [Discrete(3), Discrete(5), Discrete(4), Discrete(8)]
        >>> batch_differing_spaces(spaces)
        MultiDiscrete([3 5 4 8])
    """
    assert len(spaces) > 0, "Expects a non-empty list of spaces"
    # The check below is a type check, so the error message says "type"
    # (the registered per-type implementations verify shapes themselves).
    assert all(
        isinstance(space, type(spaces[0])) for space in spaces
    ), f"Expects all spaces to be of the same type, actual types: {[type(space) for space in spaces]}"
    assert (
        type(spaces[0]) in batch_differing_spaces.registry
    ), f"Requires the Space type to have a registered `batch_differing_space`, current list: {batch_differing_spaces.registry}"

    # Dispatch manually on the element type (singledispatch itself would
    # dispatch on the list, not its contents).
    return batch_differing_spaces.dispatch(type(spaces[0]))(spaces)
+
+
@batch_differing_spaces.register(Box)
def _batch_differing_spaces_box(spaces: list[Box]):
    """Stack same-shape, same-dtype ``Box`` spaces into one ``Box`` with a leading batch axis."""
    assert all(
        spaces[0].dtype == space.dtype for space in spaces
    ), f"Expected all dtypes to be equal, actually {[space.dtype for space in spaces]}"
    assert all(
        spaces[0].low.shape == space.low.shape for space in spaces
    ), f"Expected all Box.low shape to be equal, actually {[space.low.shape for space in spaces]}"
    assert all(
        spaces[0].high.shape == space.high.shape for space in spaces
    ), f"Expected all Box.high shape to be equal, actually {[space.high.shape for space in spaces]}"

    return Box(
        low=np.array([space.low for space in spaces]),
        high=np.array([space.high for space in spaces]),
        dtype=spaces[0].dtype,
        seed=deepcopy(spaces[0].np_random),
    )
+
+
@batch_differing_spaces.register(Discrete)
def _batch_differing_spaces_discrete(spaces: list[Discrete]):
    """Combine ``Discrete`` spaces of differing sizes into a single ``MultiDiscrete``."""
    sizes = np.array([space.n for space in spaces])
    starts = np.array([space.start for space in spaces])
    return MultiDiscrete(nvec=sizes, start=starts, seed=deepcopy(spaces[0].np_random))
+
+
@batch_differing_spaces.register(MultiDiscrete)
def _batch_differing_spaces_multi_discrete(spaces: list[MultiDiscrete]):
    """Combine same-shape ``MultiDiscrete`` spaces into an integer ``Box`` with per-space bounds."""
    assert all(
        spaces[0].dtype == space.dtype for space in spaces
    ), f"Expected all dtypes to be equal, actually {[space.dtype for space in spaces]}"
    assert all(
        spaces[0].nvec.shape == space.nvec.shape for space in spaces
    ), f"Expects all MultiDiscrete.nvec shape, actually {[space.nvec.shape for space in spaces]}"
    assert all(
        spaces[0].start.shape == space.start.shape for space in spaces
    ), f"Expects all MultiDiscrete.start shape, actually {[space.start.shape for space in spaces]}"

    # Upper bound is inclusive, hence the trailing ``- 1``.
    return Box(
        low=np.array([space.start for space in spaces]),
        high=np.array([space.start + space.nvec for space in spaces]) - 1,
        dtype=spaces[0].dtype,
        seed=deepcopy(spaces[0].np_random),
    )
+
+
@batch_differing_spaces.register(MultiBinary)
def _batch_differing_spaces_multi_binary(spaces: list[MultiBinary]):
    """Stack same-shape ``MultiBinary`` spaces into a 0/1-valued ``Box``."""
    first = spaces[0]
    assert all(first.shape == space.shape for space in spaces)

    return Box(
        low=0,
        high=1,
        shape=(len(spaces),) + first.shape,
        dtype=first.dtype,
        seed=deepcopy(first.np_random),
    )
+
+
@batch_differing_spaces.register(Tuple)
def _batch_differing_spaces_tuple(spaces: list[Tuple]):
    """Batch position-wise: the i-th subspaces of every space are batched together."""
    positional_groups = zip(*(space.spaces for space in spaces))
    return Tuple(
        tuple(batch_differing_spaces(group) for group in positional_groups),
        seed=deepcopy(spaces[0].np_random),
    )
+
+
@batch_differing_spaces.register(Dict)
def _batch_differing_spaces_dict(spaces: list[Dict]):
    """Batch key-wise: the values under each key across all spaces are batched together."""
    first = spaces[0]
    assert all(first.keys() == space.keys() for space in spaces)

    batched_subspaces = {
        key: batch_differing_spaces([space[key] for space in spaces])
        for key in first.keys()
    }
    return Dict(batched_subspaces, seed=deepcopy(first.np_random))
+
+
@batch_differing_spaces.register(Graph)
@batch_differing_spaces.register(Text)
@batch_differing_spaces.register(Sequence)
@batch_differing_spaces.register(OneOf)
def _batch_spaces_undefined(spaces: list[Graph | Text | Sequence | OneOf]):
    """Fallback: wrap copies of the spaces in a ``Tuple`` (no true batching is possible)."""
    copies = [deepcopy(space) for space in spaces]
    return Tuple(copies, seed=deepcopy(spaces[0].np_random))
+
+
+
[docs]
@singledispatch
def iterate(space: Space[T_cov], items: T_cov) -> Iterator:
    """Iterate over the elements of a (batched) space.

    Args:
        space: (batched) space (e.g. `action_space` or `observation_space` from vectorized environment).
        items: Batched samples to be iterated over (e.g. sample from the space).

    Raises:
        CustomSpaceError: The space is a gymnasium Space without a registered ``iterate`` function.
        TypeError: The argument is not a gymnasium Space at all.

    Example:
        >>> from gymnasium.spaces import Box, Dict
        >>> import numpy as np
        >>> space = Dict({
        ... 'position': Box(low=0, high=1, shape=(2, 3), seed=42, dtype=np.float32),
        ... 'velocity': Box(low=0, high=1, shape=(2, 2), seed=42, dtype=np.float32)})
        >>> items = space.sample()
        >>> it = iterate(space, items)
        >>> next(it)
        {'position': array([0.77395606, 0.43887845, 0.85859793], dtype=float32), 'velocity': array([0.77395606, 0.43887845], dtype=float32)}
        >>> next(it)
        {'position': array([0.697368  , 0.09417735, 0.97562236], dtype=float32), 'velocity': array([0.85859793, 0.697368  ], dtype=float32)}
        >>> next(it)
        Traceback (most recent call last):
            ...
        StopIteration
    """
    # Base singledispatch case: distinguish "unregistered Space subclass"
    # from "not a Space at all" for a clearer error.
    if isinstance(space, Space):
        raise CustomSpaceError(
            f"Space of type `{type(space)}` doesn't have an registered `iterate` function. Register `{type(space)}` for `iterate` to support it."
        )
    else:
        raise TypeError(
            f"The space provided to `iterate` is not a gymnasium Space instance, type: {type(space)}, {space}"
        )
+
+
+
@iterate.register(Discrete)
def _iterate_discrete(space: Discrete, items: Iterable):
    """``Discrete`` is not a batched space, so iterating over it is undefined."""
    raise TypeError("Unable to iterate over a space of type `Discrete`.")
+
+
@iterate.register(Box)
@iterate.register(MultiDiscrete)
@iterate.register(MultiBinary)
def _iterate_base(space: Box | MultiDiscrete | MultiBinary, items: np.ndarray):
    """Iterate along the leading (batch) axis of the ``items`` array."""
    try:
        return iter(items)
    except TypeError as err:
        raise TypeError(
            f"Unable to iterate over the following elements: {items}"
        ) from err
+
+
@iterate.register(Tuple)
def _iterate_tuple(space: Tuple, items: tuple[Any, ...]):
    """Iterate over a batched ``Tuple`` space, recursing into registered subspaces.

    Falls back to ``iter(items)`` when some subspace types have no
    registered ``iterate`` implementation.
    """
    # If this is a tuple of custom subspaces only, then simply iterate over items
    if all(type(subspace) in iterate.registry for subspace in space):
        return zip(*[iterate(subspace, items[i]) for i, subspace in enumerate(space)])

    try:
        return iter(items)
    except Exception as e:
        # Report which subspace types blocked recursion alongside the iteration error.
        unregistered_spaces = [
            type(subspace)
            for subspace in space
            if type(subspace) not in iterate.registry
        ]
        raise CustomSpaceError(
            f"Could not iterate through {space} as no custom iterate function is registered for {unregistered_spaces} and `iter(items)` raised the following error: {e}."
        ) from e
+
+
@iterate.register(Dict)
def _iterate_dict(space: Dict, items: dict[str, Any]):
    """Yield one dict per batch element, iterating every keyed subspace in lockstep."""
    keys = list(space.spaces.keys())
    value_iterators = [
        iterate(subspace, items[key]) for key, subspace in space.spaces.items()
    ]
    for batch_element in zip(*value_iterators):
        yield dict(zip(keys, batch_element))
+
+
+
[docs]
@singledispatch
def concatenate(
    space: Space, items: Iterable, out: tuple[Any, ...] | dict[str, Any] | np.ndarray
) -> tuple[Any, ...] | dict[str, Any] | np.ndarray:
    """Concatenate multiple samples from space into a single object.

    Args:
        space: Space of each item (e.g. `single_action_space` from vectorized environment)
        items: Samples to be concatenated (e.g. all sample should be an element of the `space`).
        out: The output object (e.g. generated from `create_empty_array`)

    Returns:
        The output object, can be the same object `out`.

    Raises:
        TypeError: The argument is not a gymnasium Space or has no registered ``concatenate`` function.

    Example:
        >>> from gymnasium.spaces import Box
        >>> import numpy as np
        >>> space = Box(low=0, high=1, shape=(3,), seed=42, dtype=np.float32)
        >>> out = np.zeros((2, 3), dtype=np.float32)
        >>> items = [space.sample() for _ in range(2)]
        >>> concatenate(space, items, out)
        array([[0.77395606, 0.43887845, 0.85859793],
        [0.697368  , 0.09417735, 0.97562236]], dtype=float32)
    """
    # Base singledispatch case: only reached for unregistered types.
    raise TypeError(
        f"The space provided to `concatenate` is not a gymnasium Space instance, type: {type(space)}, {space}"
    )
+
+
+
@concatenate.register(Box)
@concatenate.register(Discrete)
@concatenate.register(MultiDiscrete)
@concatenate.register(MultiBinary)
def _concatenate_base(
    space: Box | Discrete | MultiDiscrete | MultiBinary,
    items: Iterable,
    out: np.ndarray,
) -> np.ndarray:
    """Stack ``items`` along a new leading axis, writing directly into ``out``."""
    return np.stack(items, axis=0, out=out)
+
+
@concatenate.register(Tuple)
def _concatenate_tuple(
    space: Tuple, items: Iterable, out: tuple[Any, ...]
) -> tuple[Any, ...]:
    """Concatenate position-wise into the matching element of ``out``."""
    return tuple(
        concatenate(subspace, [sample[pos] for sample in items], out[pos])
        for pos, subspace in enumerate(space.spaces)
    )
+
+
@concatenate.register(Dict)
def _concatenate_dict(
    space: Dict, items: Iterable, out: dict[str, Any]
) -> dict[str, Any]:
    """Concatenate key-wise into the matching entry of ``out``."""
    result = {}
    for key, subspace in space.items():
        result[key] = concatenate(subspace, [sample[key] for sample in items], out[key])
    return result
+
+
@concatenate.register(Graph)
@concatenate.register(Text)
@concatenate.register(Sequence)
@concatenate.register(Space)
@concatenate.register(OneOf)
def _concatenate_custom(space: Space, items: Iterable, out: None) -> tuple[Any, ...]:
    """Fallback: no array layout exists for these spaces, so just tuple the samples (``out`` is ignored)."""
    return tuple(items)
+
+
+
[docs]
@singledispatch
def create_empty_array(
    space: Space, n: int = 1, fn: Callable = np.zeros
) -> tuple[Any, ...] | dict[str, Any] | np.ndarray:
    """Create an empty (possibly nested and normally numpy-based) array, used in conjunction with ``concatenate(..., out=array)``.

    In most cases, the array will be contained within the batched space, however, this is not guaranteed.

    Args:
        space: Observation space of a single environment in the vectorized environment.
        n: Number of environments in the vectorized environment. If ``None``, creates an empty sample from ``space``.
        fn: Function to apply when creating the empty numpy array. Examples of such functions are ``np.empty`` or ``np.zeros``.

    Returns:
        The output object. This object is a (possibly nested) numpy array.

    Raises:
        TypeError: The argument is not a gymnasium Space or has no registered ``create_empty_array`` function.

    Example:
        >>> from gymnasium.spaces import Box, Dict
        >>> import numpy as np
        >>> space = Dict({
        ... 'position': Box(low=0, high=1, shape=(3,), dtype=np.float32),
        ... 'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32)})
        >>> create_empty_array(space, n=2, fn=np.zeros)
        {'position': array([[0., 0., 0.],
               [0., 0., 0.]], dtype=float32), 'velocity': array([[0., 0.],
               [0., 0.]], dtype=float32)}
    """
    # Base singledispatch case: only reached for unregistered types.
    raise TypeError(
        f"The space provided to `create_empty_array` is not a gymnasium Space instance, type: {type(space)}, {space}"
    )
+
+
+
# It is possible for some of the Box low to be greater than 0, then array is not in space
@create_empty_array.register(Box)
# If the Discrete start > 0 or start + length < 0 then array is not in space
@create_empty_array.register(Discrete)
@create_empty_array.register(MultiDiscrete)
@create_empty_array.register(MultiBinary)
def _create_empty_array_multi(
    space: Box | Discrete | MultiDiscrete | MultiBinary, n: int = 1, fn=np.zeros
) -> np.ndarray:
    """Allocate an ``(n, *space.shape)`` array of the space's dtype via ``fn``."""
    return fn((n,) + space.shape, dtype=space.dtype)
+
+
@create_empty_array.register(Tuple)
def _create_empty_array_tuple(space: Tuple, n: int = 1, fn=np.zeros) -> tuple[Any, ...]:
    """Create one empty array per subspace of a ``Tuple`` space."""
    empties = []
    for subspace in space.spaces:
        empties.append(create_empty_array(subspace, n=n, fn=fn))
    return tuple(empties)
+
+
@create_empty_array.register(Dict)
def _create_empty_array_dict(space: Dict, n: int = 1, fn=np.zeros) -> dict[str, Any]:
    """Create one empty array per keyed subspace of a ``Dict`` space."""
    empties = {}
    for key, subspace in space.items():
        empties[key] = create_empty_array(subspace, n=n, fn=fn)
    return empties
+
+
@create_empty_array.register(Graph)
def _create_empty_array_graph(
    space: Graph, n: int = 1, fn=np.zeros
) -> tuple[GraphInstance, ...]:
    """Create ``n`` placeholder graph instances with one node (and one edge when the space has edges)."""
    has_edge_space = space.edge_space is not None

    def _empty_instance() -> GraphInstance:
        # Single placeholder node/edge; real sample sizes vary per instance.
        nodes = fn((1,) + space.node_space.shape, dtype=space.node_space.dtype)
        if has_edge_space:
            return GraphInstance(
                nodes=nodes,
                edges=fn((1,) + space.edge_space.shape, dtype=space.edge_space.dtype),
                edge_links=fn((1, 2), dtype=np.int64),
            )
        return GraphInstance(nodes=nodes, edges=None, edge_links=None)

    return tuple(_empty_instance() for _ in range(n))
+
+
@create_empty_array.register(Text)
def _create_empty_array_text(space: Text, n: int = 1, fn=np.zeros) -> tuple[str, ...]:
    """Create ``n`` minimum-length strings filled with the space's first character."""
    filler = space.characters[0] * space.min_length
    return tuple(filler for _ in range(n))
+
+
@create_empty_array.register(Sequence)
def _create_empty_array_sequence(
    space: Sequence, n: int = 1, fn=np.zeros
) -> tuple[Any, ...]:
    """Create ``n`` empty sequence slots; stacked spaces get a single-row feature array each."""
    if not space.stack:
        return tuple(() for _ in range(n))
    return tuple(create_empty_array(space.feature_space, n=1, fn=fn) for _ in range(n))
+
+
@create_empty_array.register(OneOf)
def _create_empty_array_oneof(space: OneOf, n: int = 1, fn=np.zeros):
    """``OneOf`` samples vary in type per element, so slots start as empty tuples."""
    return tuple(() for _ in range(n))
+
+
@create_empty_array.register(Space)
def _create_empty_array_custom(space, n=1, fn=np.zeros):
    """Fallback for custom spaces: no array layout can be pre-allocated, return ``None``."""
    return None
+
+"""Base class for vectorized environments."""
+
+from __future__ import annotations
+
+from enum import Enum
+from typing import TYPE_CHECKING, Any, Generic, TypeVar
+
+import numpy as np
+
+import gymnasium as gym
+from gymnasium.core import ActType, ObsType, RenderFrame
+from gymnasium.logger import warn
+from gymnasium.utils import seeding
+
+
+if TYPE_CHECKING:
+ from gymnasium.envs.registration import EnvSpec
+
+ArrayType = TypeVar("ArrayType")
+
+
+__all__ = [
+ "VectorEnv",
+ "VectorWrapper",
+ "VectorObservationWrapper",
+ "VectorActionWrapper",
+ "VectorRewardWrapper",
+ "ArrayType",
+ "AutoresetMode",
+]
+
+
class AutoresetMode(Enum):
    """Enum representing the different autoreset modes, next step, same step and disabled."""

    # Sub-environments reset on the step after `terminated or truncated`.
    NEXT_STEP = "NextStep"
    # Sub-environments reset within the same step that ends the episode.
    SAME_STEP = "SameStep"
    # No automatic resets; the user resets manually.
    DISABLED = "Disabled"
+
+
+
+[docs]
+class VectorEnv(Generic[ObsType, ActType, ArrayType]):
+ """Base class for vectorized environments to run multiple independent copies of the same environment in parallel.
+
+ Vector environments can provide a linear speed-up in the steps taken per second through sampling multiple
+ sub-environments at the same time. Gymnasium contains two generalised Vector environments: :class:`AsyncVectorEnv`
+ and :class:`SyncVectorEnv` along with several custom vector environment implementations.
+ For :func:`reset` and :func:`step` batches `observations`, `rewards`, `terminations`, `truncations` and
+ `info` for each sub-environment, see the example below. For the `rewards`, `terminations`, and `truncations`,
+ the data is packaged into a NumPy array of shape `(num_envs,)`. For `observations` (and `actions`, the batching
+ process is dependent on the type of observation (and action) space, and generally optimised for neural network
+ input/outputs. For `info`, the data is kept as a dictionary such that a key will give the data for all sub-environment.
+
+ For creating environments, :func:`make_vec` is a vector environment equivalent to :func:`make` for easily creating
+ vector environments that contains several unique arguments for modifying environment qualities, number of environment,
+ vectorizer type, vectorizer arguments.
+
+ To avoid having to wait for all sub-environments to terminated before resetting, implementations can autoreset
+ sub-environments on episode end (`terminated or truncated is True`). This is crucial for correct implementing training
+ algorithms with vector environments. By default, Gymnasium's implementation uses `next-step` autoreset, with
+ :class:`AutoresetMode` enum as the options. The mode used by vector environment should be available in `metadata["autoreset_mode"]`.
+ Warning, some vector implementations or training algorithms will only support particular autoreset modes.
+ For more information, read https://farama.org/Vector-Autoreset-Mode.
+
+ Note:
+ The info parameter of :meth:`reset` and :meth:`step` was originally implemented before v0.25 as a list
+ of dictionary for each sub-environment. However, this was modified in v0.25+ to be a dictionary with a NumPy
+ array for each key. To use the old info style, utilise the :class:`DictInfoToList` wrapper.
+
+ Examples:
+ >>> import gymnasium as gym
+ >>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync", wrappers=(gym.wrappers.TimeAwareObservation,))
+ >>> envs = gym.wrappers.vector.ClipReward(envs, min_reward=0.2, max_reward=0.8)
+ >>> envs
+ <ClipReward, SyncVectorEnv(CartPole-v1, num_envs=3)>
+ >>> envs.num_envs
+ 3
+ >>> envs.action_space
+ MultiDiscrete([2 2 2])
+ >>> envs.observation_space
+ Box([[-4.80000019 -inf -0.41887903 -inf 0. ]
+ [-4.80000019 -inf -0.41887903 -inf 0. ]
+ [-4.80000019 -inf -0.41887903 -inf 0. ]], [[4.80000019e+00 inf 4.18879032e-01 inf
+ 5.00000000e+02]
+ [4.80000019e+00 inf 4.18879032e-01 inf
+ 5.00000000e+02]
+ [4.80000019e+00 inf 4.18879032e-01 inf
+ 5.00000000e+02]], (3, 5), float64)
+ >>> observations, infos = envs.reset(seed=123)
+ >>> observations
+ array([[ 0.01823519, -0.0446179 , -0.02796401, -0.03156282, 0. ],
+ [ 0.02852531, 0.02858594, 0.0469136 , 0.02480598, 0. ],
+ [ 0.03517495, -0.000635 , -0.01098382, -0.03203924, 0. ]])
+ >>> infos
+ {}
+ >>> _ = envs.action_space.seed(123)
+ >>> actions = envs.action_space.sample()
+ >>> observations, rewards, terminations, truncations, infos = envs.step(actions)
+ >>> observations
+ array([[ 0.01734283, 0.15089367, -0.02859527, -0.33293587, 1. ],
+ [ 0.02909703, -0.16717631, 0.04740972, 0.3319138 , 1. ],
+ [ 0.03516225, -0.19559774, -0.01162461, 0.25715804, 1. ]])
+ >>> rewards
+ array([0.8, 0.8, 0.8])
+ >>> terminations
+ array([False, False, False])
+ >>> truncations
+ array([False, False, False])
+ >>> infos
+ {}
+ >>> envs.close()
+
+ The Vector Environments have the additional attributes for users to understand the implementation
+
+ - :attr:`num_envs` - The number of sub-environment in the vector environment
+ - :attr:`observation_space` - The batched observation space of the vector environment
+ - :attr:`single_observation_space` - The observation space of a single sub-environment
+ - :attr:`action_space` - The batched action space of the vector environment
+ - :attr:`single_action_space` - The action space of a single sub-environment
+ """
+
+ metadata: dict[str, Any] = {}
+ spec: EnvSpec | None = None
+ render_mode: str | None = None
+ closed: bool = False
+
+ observation_space: gym.Space
+ action_space: gym.Space
+ single_observation_space: gym.Space
+ single_action_space: gym.Space
+
+ num_envs: int
+
+ _np_random: np.random.Generator | None = None
+ _np_random_seed: int | None = None
+
+
+[docs]
    def reset(
        self,
        *,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[ObsType, dict[str, Any]]:  # type: ignore
        """Reset all parallel environments and return a batch of initial observations and info.

        This base implementation only re-seeds the vector environment's rng
        when ``seed`` is given and returns ``None``; subclasses override it
        to perform the actual reset (calling ``super().reset(seed=seed)``).

        Args:
            seed: The environment reset seed
            options: Additional reset information passed to the sub-environments
                (unused by this base implementation)

        Returns:
            A batch of observations and info from the vectorized environment.

        Example:
            >>> import gymnasium as gym
            >>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
            >>> observations, infos = envs.reset(seed=42)
            >>> observations
            array([[ 0.0273956 , -0.00611216,  0.03585979,  0.0197368 ],
                   [ 0.01522993, -0.04562247, -0.04799704,  0.03392126],
                   [-0.03774345, -0.02418869, -0.00942293,  0.0469184 ]],
                  dtype=float32)
            >>> infos
            {}
        """
        if seed is not None:
            self._np_random, self._np_random_seed = seeding.np_random(seed)
+
+
+
+[docs]
    def step(
        self, actions: ActType
    ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
        """Take an action for each parallel environment.

        Args:
            actions: Batch of actions with the :attr:`action_space` shape.

        Returns:
            Batch of (observations, rewards, terminations, truncations, infos)

        Raises:
            NotImplementedError: Always, in this base class; subclasses must override.

        Note:
            As the vector environments autoreset for a terminating and truncating sub-environments, this will occur on
            the next step after `terminated or truncated is True`.

        Example:
            >>> import gymnasium as gym
            >>> import numpy as np
            >>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
            >>> _ = envs.reset(seed=42)
            >>> actions = np.array([1, 0, 1], dtype=np.int32)
            >>> observations, rewards, terminations, truncations, infos = envs.step(actions)
            >>> observations
            array([[ 0.02727336,  0.18847767,  0.03625453, -0.26141977],
                   [ 0.01431748, -0.24002443, -0.04731862,  0.3110827 ],
                   [-0.03822722,  0.1710671 , -0.00848456, -0.2487226 ]],
                  dtype=float32)
            >>> rewards
            array([1., 1., 1.])
            >>> terminations
            array([False, False, False])
            >>> truncations
            array([False, False, False])
            >>> infos
            {}
        """
        raise NotImplementedError(f"{self.__str__()} step function is not implemented.")
+
+
+
+[docs]
    def render(self) -> tuple[RenderFrame, ...] | None:
        """Returns the rendered frames from the parallel environments.

        Returns:
            A tuple of rendered frames from the parallel environments

        Raises:
            NotImplementedError: Always, in this base class; subclasses must override.
        """
        raise NotImplementedError(
            f"{self.__str__()} render function is not implemented."
        )
+
+
+
+[docs]
    def close(self, **kwargs: Any):
        """Close all parallel environments and release resources.

        It also closes all the existing image viewers, then calls :meth:`close_extras` and set
        :attr:`closed` as ``True``.

        Warnings:
            This function itself does not close the environments, it should be handled
            in :meth:`close_extras`. This is generic for both synchronous and asynchronous
            vectorized environments.

        Note:
            This will be automatically called when garbage collected or program exited.

        Args:
            **kwargs: Keyword arguments passed to :meth:`close_extras`
        """
        # Idempotent: repeated calls after the first are no-ops.
        if self.closed:
            return

        self.close_extras(**kwargs)
        self.closed = True
+
+
    def close_extras(self, **kwargs: Any):
        """Clean up the extra resources e.g. beyond what's in this base class.

        Subclasses override this hook; the base implementation does nothing.
        """
        pass
+
    @property
    def np_random(self) -> np.random.Generator:
        """Returns the environment's internal :attr:`_np_random` that if not set will initialise with a random seed.

        Returns:
            Instances of `np.random.Generator`
        """
        # Lazily create the rng (and record its seed) on first access.
        if self._np_random is None:
            self._np_random, self._np_random_seed = seeding.np_random()
        return self._np_random
+
    @np_random.setter
    def np_random(self, value: np.random.Generator):
        """Sets the environment's internal rng directly; the seed becomes unknown (-1)."""
        self._np_random = value
        # Seed is unknown when the generator is injected rather than seeded here.
        self._np_random_seed = -1
+
    @property
    def np_random_seed(self) -> int | None:
        """Returns the environment's internal :attr:`_np_random_seed` that if not set will first initialise with a random int as seed.

        If :attr:`np_random_seed` was set directly instead of through :meth:`reset` or :meth:`set_np_random_through_seed`,
        the seed will take the value -1.

        Returns:
            int: the seed of the current `np_random` or -1, if the seed of the rng is unknown
        """
        # Lazily create the rng so the seed is always defined after first access.
        if self._np_random_seed is None:
            self._np_random, self._np_random_seed = seeding.np_random()
        return self._np_random_seed
+
+ @property
+ def unwrapped(self):
+ """Return the base environment."""
+ return self
+
    def _add_info(
        self, vector_infos: dict[str, Any], env_info: dict[str, Any], env_num: int
    ) -> dict[str, Any]:
        """Add env info to the info dictionary of the vectorized environment.

        Given the `info` of a single environment add it to the `infos` dictionary
        which represents all the infos of the vectorized environment.
        Every `key` of `info` is paired with a boolean mask `_key` representing
        whether or not the i-indexed environment has this `info`.

        Args:
            vector_infos (dict): the infos of the vectorized environment
            env_info (dict): the info coming from the single environment
            env_num (int): the index of the single environment

        Returns:
            infos (dict): the (updated) infos of the vectorized environment
        """
        for key, value in env_info.items():
            # It is easier for users to access their `final_obs` in the unbatched array of `obs` objects
            if key == "final_obs":
                if "final_obs" in vector_infos:
                    array = vector_infos["final_obs"]
                else:
                    # Object dtype keeps each sub-environment's observation unbatched.
                    array = np.full(self.num_envs, fill_value=None, dtype=object)
                array[env_num] = value
            # If value is a dictionary, then we apply the `_add_info` recursively.
            elif isinstance(value, dict):
                array = self._add_info(vector_infos.get(key, {}), value, env_num)
            # Otherwise, we are a base case to group the data
            else:
                # If the key doesn't exist in the vector infos, then we can create an array of that batch type
                if key not in vector_infos:
                    # NOTE: exact-type check (`type(...) in`), not isinstance, so e.g. bool is
                    # not widened to int; the dtype follows the first env reporting the key.
                    if type(value) in [int, float, bool] or issubclass(
                        type(value), np.number
                    ):
                        array = np.zeros(self.num_envs, dtype=type(value))
                    elif isinstance(value, np.ndarray):
                        # We assume that all instances of the np.array info are of the same shape
                        array = np.zeros(
                            (self.num_envs, *value.shape), dtype=value.dtype
                        )
                    else:
                        # For unknown objects, we use a Numpy object array
                        array = np.full(self.num_envs, fill_value=None, dtype=object)
                # Otherwise, just use the array that already exists
                else:
                    array = vector_infos[key]

                # Assign the data in the `env_num` position
                # We only want to run this for the base-case data (not recursive data forcing the ugly function structure)
                array[env_num] = value

            # Get the array mask and if it doesn't already exist then create a zero bool array
            array_mask = vector_infos.get(
                f"_{key}", np.zeros(self.num_envs, dtype=np.bool_)
            )
            array_mask[env_num] = True

            # Update the vector info with the updated data and mask information
            vector_infos[key], vector_infos[f"_{key}"] = array, array_mask
        return vector_infos
+
+ def __del__(self):
+ """Closes the vector environment."""
+ if not getattr(self, "closed", True):
+ self.close()
+
+ def __repr__(self) -> str:
+ """Returns a string representation of the vector environment.
+
+ Returns:
+ A string containing the class name, number of environments and environment spec id
+ """
+ if self.spec is None:
+ return f"{self.__class__.__name__}(num_envs={self.num_envs})"
+ else:
+ return (
+ f"{self.__class__.__name__}({self.spec.id}, num_envs={self.num_envs})"
+ )
+
+
+
+
+[docs]
class VectorWrapper(VectorEnv):
    """Wraps the vectorized environment to allow a modular transformation.

    This class is the base class for all wrappers for vectorized environments. The subclass
    could override some methods to change the behavior of the original vectorized environment
    without touching the original code.

    Note:
        Don't forget to call ``super().__init__(env)`` if the subclass overrides :meth:`__init__`.
    """

    def __init__(self, env: VectorEnv):
        """Initialize the vectorized environment wrapper.

        Args:
            env: The environment to wrap
        """
        self.env = env
        assert isinstance(
            env, VectorEnv
        ), f"Expected env to be a `gymnasium.vector.VectorEnv` but got {type(env)}"

        # ``None`` marks an attribute as "not overridden"; the property getters
        # below then fall back to the wrapped environment's value.
        self._observation_space: gym.Space | None = None
        self._action_space: gym.Space | None = None
        self._single_observation_space: gym.Space | None = None
        self._single_action_space: gym.Space | None = None
        self._metadata: dict[str, Any] | None = None

    def reset(
        self,
        *,
        seed: int | list[int] | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[ObsType, dict[str, Any]]:
        """Reset all environment using seed and options."""
        return self.env.reset(seed=seed, options=options)

    def step(
        self, actions: ActType
    ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
        """Step through all environments using the actions returning the batched data."""
        return self.env.step(actions)

    def render(self) -> tuple[RenderFrame, ...] | None:
        """Returns the rendered frames from the wrapped vector environment."""
        return self.env.render()

    def close(self, **kwargs: Any):
        """Close all environments."""
        return self.env.close(**kwargs)

    def close_extras(self, **kwargs: Any):
        """Close all extra resources."""
        return self.env.close_extras(**kwargs)

    @property
    def unwrapped(self):
        """Return the base non-wrapped environment."""
        return self.env.unwrapped

    def __repr__(self):
        """Return the string representation of the vectorized environment."""
        return f"<{self.__class__.__name__}, {self.env}>"

    @property
    def observation_space(self) -> gym.Space:
        """Gets the observation space of the vector environment."""
        if self._observation_space is None:
            return self.env.observation_space
        return self._observation_space

    @observation_space.setter
    def observation_space(self, space: gym.Space):
        """Overrides the observation space of the vector environment."""
        self._observation_space = space

    @property
    def action_space(self) -> gym.Space:
        """Gets the action space of the vector environment."""
        if self._action_space is None:
            return self.env.action_space
        return self._action_space

    @action_space.setter
    def action_space(self, space: gym.Space):
        """Overrides the action space of the vector environment."""
        self._action_space = space

    @property
    def single_observation_space(self) -> gym.Space:
        """Gets the single observation space of the vector environment."""
        if self._single_observation_space is None:
            return self.env.single_observation_space
        return self._single_observation_space

    @single_observation_space.setter
    def single_observation_space(self, space: gym.Space):
        """Overrides the single observation space of the vector environment."""
        self._single_observation_space = space

    @property
    def single_action_space(self) -> gym.Space:
        """Gets the single action space of the vector environment."""
        if self._single_action_space is None:
            return self.env.single_action_space
        return self._single_action_space

    @single_action_space.setter
    def single_action_space(self, space: gym.Space):
        """Overrides the single action space of the vector environment."""
        self._single_action_space = space

    @property
    def num_envs(self) -> int:
        """Gets the wrapped vector environment's num of the sub-environments."""
        return self.env.num_envs

    @property
    def np_random(self) -> np.random.Generator:
        """Returns the wrapped environment's :attr:`np_random` (lazily initialised by the base env if unset).

        Returns:
            Instances of `np.random.Generator`
        """
        return self.env.np_random

    @np_random.setter
    def np_random(self, value: np.random.Generator):
        """Sets the wrapped environment's random number generator."""
        self.env.np_random = value

    @property
    def np_random_seed(self) -> int | None:
        """The seeds of the vector environment's internal :attr:`_np_random`."""
        return self.env.np_random_seed

    @property
    def metadata(self):
        """The metadata of the vector environment (wrapper override takes precedence)."""
        if self._metadata is not None:
            return self._metadata
        return self.env.metadata

    @metadata.setter
    def metadata(self, value):
        """Overrides the metadata seen through this wrapper."""
        self._metadata = value

    @property
    def spec(self) -> EnvSpec | None:
        """Gets the specification of the wrapped environment."""
        return self.env.spec

    @property
    def render_mode(self) -> str | None:
        """Returns the `render_mode` from the base environment.

        NOTE(review): previously annotated ``tuple[RenderFrame, ...] | None``, but this
        property forwards the render-mode identifier string (or ``None``), not frames.
        """
        return self.env.render_mode

    @property
    def closed(self):
        """Whether the vector environment has been closed."""
        return self.env.closed

    @closed.setter
    def closed(self, value: bool):
        """Sets the closed flag on the wrapped vector environment."""
        self.env.closed = value
+
+
+
+
+[docs]
class VectorObservationWrapper(VectorWrapper):
    """Wraps the vectorized environment to allow a modular transformation of the observation.

    Equivalent to :class:`gymnasium.ObservationWrapper` for vectorized environments.
    """

    def __init__(self, env: VectorEnv):
        """Vector observation wrapper that batch transforms observations.

        Args:
            env: Vector environment.
        """
        super().__init__(env)
        if "autoreset_mode" not in env.metadata:
            warn(
                f"Vector environment ({env}) is missing `autoreset_mode` metadata key."
            )
        else:
            # Observation transformation is only supported for these autoreset modes.
            assert env.metadata["autoreset_mode"] in (
                AutoresetMode.NEXT_STEP,
                AutoresetMode.DISABLED,
            )

    def reset(
        self,
        *,
        seed: int | list[int] | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[ObsType, dict[str, Any]]:
        """Resets the environment, transforming the returned observations with :meth:`observations`."""
        batched_observations, infos = self.env.reset(seed=seed, options=options)
        return self.observations(batched_observations), infos

    def step(
        self, actions: ActType
    ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
        """Steps the environment, transforming the returned observations with :meth:`observations`."""
        batched_observations, rewards, terminations, truncations, infos = self.env.step(
            actions
        )
        return (
            self.observations(batched_observations),
            rewards,
            terminations,
            truncations,
            infos,
        )

    def observations(self, observations: ObsType) -> ObsType:
        """Defines the vector observation transformation.

        Args:
            observations: A vector observation from the environment

        Returns:
            the transformed observation
        """
        raise NotImplementedError
+
+
+
+
+
+[docs]
class VectorActionWrapper(VectorWrapper):
    """Wraps the vectorized environment to allow a modular transformation of the actions.

    Equivalent of :class:`gymnasium.ActionWrapper` for vectorized environments.
    """

    def step(
        self, actions: ActType
    ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
        """Steps through the environment using the actions transformed by :meth:`actions`."""
        transformed_actions = self.actions(actions)
        return self.env.step(transformed_actions)

    def actions(self, actions: ActType) -> ActType:
        """Transform the actions before sending them to the environment.

        Args:
            actions (ActType): the actions to transform

        Returns:
            ActType: the transformed actions
        """
        raise NotImplementedError
+
+
+
+
+
+[docs]
class VectorRewardWrapper(VectorWrapper):
    """Wraps the vectorized environment to allow a modular transformation of the reward.

    Equivalent of :class:`gymnasium.RewardWrapper` for vectorized environments.
    """

    def step(
        self, actions: ActType
    ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
        """Steps through the environment, transforming the returned rewards with :meth:`rewards`."""
        observations, rewards, terminations, truncations, infos = self.env.step(actions)
        transformed_rewards = self.rewards(rewards)
        return observations, transformed_rewards, terminations, truncations, infos

    def rewards(self, rewards: ArrayType) -> ArrayType:
        """Transform the reward before returning it.

        Args:
            rewards (array): the reward to transform

        Returns:
            array: the transformed reward
        """
        raise NotImplementedError
+
+
+
+# This wrapper will convert array inputs from an Array API compatible framework A for the actions
+# to any other Array API compatible framework B for an underlying environment that is implemented
+# in framework B, then convert the return observations from framework B back to framework A.
+#
+# More precisely, the wrapper will work for any two frameworks that can be made compatible with the
+# `array-api-compat` package.
+#
+# See https://data-apis.org/array-api/latest/ for more information on the Array API standard, and
+# https://data-apis.org/array-api-compat/ for more information on the Array API compatibility layer.
+#
+# General structure for converting between types originally copied from
+# https://github.com/google/brax/blob/9d6b7ced2a13da0d074b5e9fbd3aad8311e26997/brax/io/torch.py
+# Under the Apache 2.0 license. Copyright is held by the authors
+
+"""Helper functions and wrapper class for converting between arbitrary Array API compatible frameworks and a target framework."""
+
+from __future__ import annotations
+
+import functools
+import importlib
+import numbers
+from collections import abc
+from collections.abc import Iterable, Mapping
+from types import ModuleType, NoneType
+from typing import Any, SupportsFloat
+
+import numpy as np
+from packaging.version import Version
+
+import gymnasium as gym
+from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
+from gymnasium.error import DependencyNotInstalled
+
+
+try:
+ from array_api_compat import array_namespace, is_array_api_obj, to_device
+
+except ImportError:
+ raise DependencyNotInstalled(
+ 'Array API packages are not installed therefore cannot call `array_conversion`, run `pip install "gymnasium[array-api]"`'
+ )
+
+
+if Version(np.__version__) < Version("2.1.0"):
+ raise DependencyNotInstalled("Array API functionality requires numpy >= 2.1.0")
+
+
# Public API of this module.
__all__ = ["ArrayConversion", "array_conversion"]

Array = Any  # TODO: Switch to ArrayAPI type once https://github.com/data-apis/array-api/pull/589 is merged
Device = Any  # TODO: Switch to ArrayAPI type if available
+
+
def module_namespace(xp: ModuleType) -> ModuleType:
    """Determine the Array API compatible namespace of the given module.

    This function is closely linked to the `array_api_compat.array_namespace` function. It returns
    the compatible namespace for a module directly instead of from an array object of that module.

    See https://data-apis.org/array-api-compat/helper-functions.html#array_api_compat.array_namespace

    Raises:
        ValueError: If ``xp`` is not an Array API compatible module.
    """
    try:
        # Probe the module with a tiny array; modules without `empty` raise AttributeError.
        sample = xp.empty(0)
        return array_namespace(sample)
    except AttributeError as e:
        raise ValueError(f"Module {xp} is not an Array API compatible module.") from e
+
+
def module_name_to_namespace(name: str) -> ModuleType:
    """Import the module given by ``name`` and return its Array API compatible namespace.

    Args:
        name: Importable module name, e.g. ``"numpy"``.

    Returns:
        The Array API compatible namespace for the imported module.

    Raises:
        ValueError: If the imported module is not Array API compatible.
    """
    return module_namespace(importlib.import_module(name))
+
+
@functools.singledispatch
def array_conversion(value: Any, xp: ModuleType, device: Device | None = None) -> Any:
    """Convert a value into the specified xp module array type.

    This is the single-dispatch fallback: it only runs for types with no
    registered converter and therefore always raises.
    """
    error = f"No known conversion for ({type(value)}) to xp module ({xp}) registered. Report as issue on github."
    raise Exception(error)
+
+
@array_conversion.register(numbers.Number)
def _number_array_conversion(
    value: numbers.Number, xp: ModuleType, device: Device | None = None
) -> Array:
    """Convert a python number (int, float, complex) to an Array API framework array.

    Args:
        value: The Python scalar to convert.
        xp: The target Array API namespace.
        device: Optional device on which the resulting array should be placed.

    Returns:
        An ``xp`` array wrapping ``value``.
    """
    return xp.asarray(value, device=device)
+
+
@array_conversion.register(abc.Mapping)
def _mapping_array_conversion(
    value: Mapping[str, Any], xp: ModuleType, device: Device | None = None
) -> Mapping[str, Any]:
    """Convert a mapping of Arrays into a Dictionary of the specified xp module array type."""
    converted = {key: array_conversion(item, xp, device) for key, item in value.items()}
    # Rebuild the same mapping type; assumes string keys (passed as keyword arguments).
    return type(value)(**converted)
+
+
@array_conversion.register(abc.Iterable)
def _iterable_array_conversion(
    value: Iterable[Any], xp: ModuleType, device: Device | None = None
) -> Iterable[Any]:
    """Convert an Iterable from Arrays to an iterable of the specified xp module array type."""
    # There is currently no type for ArrayAPI compatible objects, so they fall through to this
    # function registered for any Iterable. If they are arrays, we can convert them directly.
    # We currently cannot pass the device to the from_dlpack function, since it is not supported
    # for some frameworks (see e.g. https://github.com/data-apis/array-api-compat/issues/204)
    if is_array_api_obj(value):
        return _array_api_array_conversion(value, xp, device)

    converted_items = (array_conversion(item, xp, device) for item in value)
    if hasattr(value, "_make"):
        # namedtuple - underline used to prevent potential name conflicts
        # noinspection PyProtectedMember
        return type(value)._make(converted_items)
    return type(value)(converted_items)
+
+
def _array_api_array_conversion(
    value: Array, xp: ModuleType, device: Device | None = None
) -> Array:
    """Convert an Array API compatible array to the specified xp module array type."""
    try:
        converted = xp.from_dlpack(value)
        if device is not None:
            converted = to_device(converted, device)
        return converted
    except (RuntimeError, BufferError):
        # If dlpack fails (e.g. because the array is read-only for frameworks that do not
        # support it), we create a copy of the array that we own and then convert it.
        # TODO: The correct treatment of read-only arrays is currently not fully clear in the
        # Array API. Once ongoing discussions are resolved, we should update this code to remove
        # any fallbacks.
        source_namespace = array_namespace(value)
        owned_copy = source_namespace.asarray(value, copy=True)
        return xp.asarray(owned_copy, device=device)
+
+
@array_conversion.register(NoneType)
def _none_array_conversion(
    value: None, xp: ModuleType, device: Device | None = None
) -> None:
    """Passes through None values unchanged (no array conversion is meaningful)."""
    return value
+
+
+
+[docs]
class ArrayConversion(gym.Wrapper, gym.utils.RecordConstructorArgs):
    """Wraps an Array API compatible environment so that it can be interacted with with another Array API framework.

    Popular Array API frameworks include ``numpy``, ``torch``, ``jax.numpy``, ``cupy`` etc. With this wrapper, you can convert outputs from your environment to
    any of these frameworks. Conversely, actions are automatically mapped back to the environment framework, if possible without moving the
    data or device transfers.

    A vector version of the wrapper exists, :class:`gymnasium.wrappers.vector.ArrayConversion`.

    Example:
        >>> import torch                                       # doctest: +SKIP
        >>> import jax.numpy as jnp                            # doctest: +SKIP
        >>> import gymnasium as gym                            # doctest: +SKIP
        >>> env = gym.make("JaxEnv-vx")                        # doctest: +SKIP
        >>> env = ArrayConversion(env, env_xp=jnp, target_xp=torch)  # doctest: +SKIP
        >>> obs, _ = env.reset(seed=123)                       # doctest: +SKIP
        >>> type(obs)                                          # doctest: +SKIP
        <class 'torch.Tensor'>
        >>> action = torch.tensor(env.action_space.sample())   # doctest: +SKIP
        >>> obs, reward, terminated, truncated, info = env.step(action)  # doctest: +SKIP
        >>> type(obs)                                          # doctest: +SKIP
        <class 'torch.Tensor'>
        >>> type(reward)                                       # doctest: +SKIP
        <class 'float'>
        >>> type(terminated)                                   # doctest: +SKIP
        <class 'bool'>
        >>> type(truncated)                                    # doctest: +SKIP
        <class 'bool'>

    Change logs:
     * v1.2.0 - Initially added
    """

    def __init__(
        self,
        env: gym.Env,
        env_xp: ModuleType,
        target_xp: ModuleType,
        env_device: Device | None = None,
        target_device: Device | None = None,
    ):
        """Wrapper class to change inputs and outputs of environment to any Array API framework.

        Args:
            env: The Array API compatible environment to wrap
            env_xp: The Array API framework the environment is on
            target_xp: The Array API framework to convert to
            env_device: The device the environment is on
            target_device: The device on which Arrays should be returned
        """
        gym.utils.RecordConstructorArgs.__init__(self)
        gym.Wrapper.__init__(self, env)

        # Normalise both modules to their `array_api_compat` namespaces up front.
        self._env_xp = module_namespace(env_xp)
        self._target_xp = module_namespace(target_xp)
        self._env_device: Device | None = env_device
        self._target_device: Device | None = target_device

    def step(
        self, action: WrapperActType
    ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
        """Performs the given action within the environment.

        Args:
            action: The action to perform as any Array API compatible array

        Returns:
            The next observation, reward, termination, truncation, and extra info
        """
        env_action = array_conversion(action, xp=self._env_xp, device=self._env_device)
        obs, reward, terminated, truncated, info = self.env.step(env_action)

        converted_obs = array_conversion(
            obs, xp=self._target_xp, device=self._target_device
        )
        converted_info = array_conversion(
            info, xp=self._target_xp, device=self._target_device
        )
        # Scalars are returned as plain Python types regardless of the framework.
        return (
            converted_obs,
            float(reward),
            bool(terminated),
            bool(truncated),
            converted_info,
        )

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[WrapperObsType, dict[str, Any]]:
        """Resets the environment returning observation and info as Array from any Array API compatible framework.

        Args:
            seed: The seed for resetting the environment
            options: The options for resetting the environment, converted to the environment's framework

        Returns:
            xp-based observations and info
        """
        if options:
            options = array_conversion(options, self._env_xp, self._env_device)

        reset_returns = self.env.reset(seed=seed, options=options)
        return array_conversion(reset_returns, self._target_xp, self._target_device)

    def render(self) -> RenderFrame | list[RenderFrame] | None:
        """Returns the rendered frames as an xp Array."""
        frames = self.env.render()
        return array_conversion(frames, self._target_xp, self._target_device)

    def __getstate__(self):
        """Returns the object pickle state (module names, devices and the wrapped env)."""
        # Modules are not picklable; store importable names and re-resolve in __setstate__.
        return {
            "env_xp_name": self._env_xp.__name__.replace("array_api_compat.", ""),
            "target_xp_name": self._target_xp.__name__.replace("array_api_compat.", ""),
            "env_device": self._env_device,
            "target_device": self._target_device,
            "env": self.env,
        }

    def __setstate__(self, d):
        """Sets the object pickle state using d."""
        self.env = d["env"]
        self._env_xp = module_name_to_namespace(d["env_xp_name"])
        self._target_xp = module_name_to_namespace(d["target_xp_name"])
        self._env_device = d["env_device"]
        self._target_device = d["target_device"]
+
+
+"""Implementation of Atari 2600 Preprocessing following the guidelines of Machado et al., 2018."""
+
+from __future__ import annotations
+
+from typing import Any, SupportsFloat
+
+import numpy as np
+
+import gymnasium as gym
+from gymnasium.core import WrapperActType, WrapperObsType
+from gymnasium.spaces import Box
+
+
+__all__ = ["AtariPreprocessing"]
+
+
+
+[docs]
class AtariPreprocessing(gym.Wrapper, gym.utils.RecordConstructorArgs):
    """Implements the common preprocessing techniques for Atari environments (excluding frame stacking).

    For frame stacking use :class:`gymnasium.wrappers.FrameStackObservation`.
    No vector version of the wrapper exists.

    This class follows the guidelines in Machado et al. (2018),
    "Revisiting the Arcade Learning Environment: Evaluation Protocols and Open Problems for General Agents".

    Specifically, the following preprocess stages applies to the atari environment:

    - Noop Reset: Obtains the initial state by taking a random number of no-ops on reset, default max 30 no-ops.
    - Frame skipping: The number of frames skipped between steps, 4 by default.
    - Max-pooling: Pools over the most recent two observations from the frame skips.
    - Termination signal when a life is lost: When the agent loses a life during the environment, then the environment is terminated.
      Turned off by default. Not recommended by Machado et al. (2018).
    - Resize to a square image: Resizes the atari environment original observation shape from 210x160 to 84x84 by default.
    - Grayscale observation: Makes the observation greyscale, enabled by default.
    - Grayscale new axis: Extends the last channel of the observation such that the image is 3-dimensional, not enabled by default.
    - Scale observation: Whether to scale the observation between [0, 1) or [0, 255), not scaled by default.

    Example:
        >>> import gymnasium as gym
        >>> import ale_py
        >>> gym.register_envs(ale_py)
        >>> env = gym.make("ALE/Pong-v5", frameskip=1)
        >>> env = AtariPreprocessing(
        ...     env,
        ...     noop_max=10, frame_skip=4, terminal_on_life_loss=True,
        ...     screen_size=84, grayscale_obs=False, grayscale_newaxis=False
        ... )

    Change logs:
     * Added in gym v0.12.2 (gym #1455)
    """

    def __init__(
        self,
        env: gym.Env,
        noop_max: int = 30,
        frame_skip: int = 4,
        screen_size: int | tuple[int, int] = 84,
        terminal_on_life_loss: bool = False,
        grayscale_obs: bool = True,
        grayscale_newaxis: bool = False,
        scale_obs: bool = False,
    ):
        """Wrapper for Atari 2600 preprocessing.

        Args:
            env (Env): The environment to apply the preprocessing
            noop_max (int): For No-op reset, the max number no-ops actions are taken at reset, to turn off, set to 0.
            frame_skip (int): The number of frames between new observation the agents observations effecting the frequency at which the agent experiences the game.
            screen_size (int | tuple[int, int]): resize Atari frame.
            terminal_on_life_loss (bool): `if True`, then :meth:`step()` returns `terminated=True` whenever a
                life is lost.
            grayscale_obs (bool): if True, then gray scale observation is returned, otherwise, RGB observation
                is returned.
            grayscale_newaxis (bool): `if True and grayscale_obs=True`, then a channel axis is added to
                grayscale observations to make them 3-dimensional.
            scale_obs (bool): if True, then observation normalized in range [0,1) is returned. It also limits memory
                optimization benefits of FrameStack Wrapper.

        Raises:
            DependencyNotInstalled: opencv-python package not installed
            ValueError: Disable frame-skipping in the original env
        """
        gym.utils.RecordConstructorArgs.__init__(
            self,
            noop_max=noop_max,
            frame_skip=frame_skip,
            screen_size=screen_size,
            terminal_on_life_loss=terminal_on_life_loss,
            grayscale_obs=grayscale_obs,
            grayscale_newaxis=grayscale_newaxis,
            scale_obs=scale_obs,
        )
        gym.Wrapper.__init__(self, env)

        try:
            import cv2  # noqa: F401
        except ImportError:
            raise gym.error.DependencyNotInstalled(
                'opencv-python package not installed, run `pip install "gymnasium[other]"` to get dependencies for atari'
            )

        assert frame_skip > 0
        assert (isinstance(screen_size, int) and screen_size > 0) or (
            isinstance(screen_size, tuple)
            and len(screen_size) == 2
            and all(isinstance(size, int) and size > 0 for size in screen_size)
        ), f"Expect the `screen_size` to be positive, actually: {screen_size}"
        # The underlying env must not skip frames itself, otherwise frames would be
        # skipped twice (once by the base env and once by this wrapper).
        if frame_skip > 1 and getattr(env.unwrapped, "_frameskip", None) != 1:
            raise ValueError(
                "Disable frame-skipping in the original env. Otherwise, more than one frame-skip will happen as through this wrapper"
            )
        assert noop_max >= 0
        self.noop_max = noop_max
        if noop_max > 0:
            # Noop-reset presses action 0, so action 0 must actually be NOOP.
            assert env.unwrapped.get_action_meanings()[0] == "NOOP"

        self.frame_skip = frame_skip
        # Stored as (width, height), the argument order cv2.resize expects for `dsize`.
        self.screen_size: tuple[int, int] = (
            screen_size
            if isinstance(screen_size, tuple)
            else (screen_size, screen_size)
        )
        self.terminal_on_life_loss = terminal_on_life_loss
        self.grayscale_obs = grayscale_obs
        self.grayscale_newaxis = grayscale_newaxis
        self.scale_obs = scale_obs

        # buffer of most recent two observations for max pooling
        assert isinstance(env.observation_space, Box)
        if grayscale_obs:
            self.obs_buffer = [
                np.empty(env.observation_space.shape[:2], dtype=np.uint8),
                np.empty(env.observation_space.shape[:2], dtype=np.uint8),
            ]
        else:
            self.obs_buffer = [
                np.empty(env.observation_space.shape, dtype=np.uint8),
                np.empty(env.observation_space.shape, dtype=np.uint8),
            ]

        # `lives` tracks remaining ALE lives for the terminal_on_life_loss option.
        self.lives = 0
        self.game_over = False

        _low, _high, _dtype = (0, 1, np.float32) if scale_obs else (0, 255, np.uint8)
        # Observation shape is (height, width[, channels]); screen_size is (width, height).
        _shape = (self.screen_size[1], self.screen_size[0], 1 if grayscale_obs else 3)
        if grayscale_obs and not grayscale_newaxis:
            _shape = _shape[:-1]  # Remove channel axis
        self.observation_space = Box(low=_low, high=_high, shape=_shape, dtype=_dtype)

    @property
    def ale(self):
        """Make ale as a class property to avoid serialization error."""
        return self.env.unwrapped.ale

    def step(
        self, action: WrapperActType
    ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
        """Applies the preprocessing for an :meth:`env.step`."""
        total_reward, terminated, truncated, info = 0.0, False, False, {}

        for t in range(self.frame_skip):
            _, reward, terminated, truncated, info = self.env.step(action)
            total_reward += reward
            self.game_over = terminated

            if self.terminal_on_life_loss:
                new_lives = self.ale.lives()
                terminated = terminated or new_lives < self.lives
                self.game_over = terminated
                self.lives = new_lives

            if terminated or truncated:
                break
            # Only the last two frames of the skip window are captured; `_get_obs`
            # max-pools them together (a standard fix for Atari sprite flickering).
            if t == self.frame_skip - 2:
                if self.grayscale_obs:
                    self.ale.getScreenGrayscale(self.obs_buffer[1])
                else:
                    self.ale.getScreenRGB(self.obs_buffer[1])
            elif t == self.frame_skip - 1:
                if self.grayscale_obs:
                    self.ale.getScreenGrayscale(self.obs_buffer[0])
                else:
                    self.ale.getScreenRGB(self.obs_buffer[0])
        return self._get_obs(), total_reward, terminated, truncated, info

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[WrapperObsType, dict[str, Any]]:
        """Resets the environment using preprocessing."""
        # NoopReset
        _, reset_info = self.env.reset(seed=seed, options=options)

        # Advance by a random number of NOOP actions (action 0) so episodes
        # do not always start from the exact same state.
        noops = (
            self.env.unwrapped.np_random.integers(1, self.noop_max + 1)
            if self.noop_max > 0
            else 0
        )
        for _ in range(noops):
            _, _, terminated, truncated, step_info = self.env.step(0)
            reset_info.update(step_info)
            if terminated or truncated:
                # The episode ended during the no-ops; restart it.
                _, reset_info = self.env.reset(seed=seed, options=options)

        self.lives = self.ale.lives()
        if self.grayscale_obs:
            self.ale.getScreenGrayscale(self.obs_buffer[0])
        else:
            self.ale.getScreenRGB(self.obs_buffer[0])
        # Zero the second buffer so the max-pool in `_get_obs` is a no-op after reset.
        self.obs_buffer[1].fill(0)

        return self._get_obs(), reset_info

    def _get_obs(self):
        """Max-pool the two most recent frames, resize, and optionally rescale/reshape the result."""
        if self.frame_skip > 1:  # more efficient in-place pooling
            np.maximum(self.obs_buffer[0], self.obs_buffer[1], out=self.obs_buffer[0])

        import cv2

        # cv2.resize expects dsize as (width, height), matching `self.screen_size`.
        obs = cv2.resize(
            self.obs_buffer[0],
            self.screen_size,
            interpolation=cv2.INTER_AREA,
        )

        if self.scale_obs:
            obs = np.asarray(obs, dtype=np.float32) / 255.0
        else:
            obs = np.asarray(obs, dtype=np.uint8)

        if self.grayscale_obs and self.grayscale_newaxis:
            obs = np.expand_dims(obs, axis=-1)  # Add a channel axis
        return obs
+
+
+"""A collection of common wrappers.
+
+* ``TimeLimit`` - Provides a time limit on the number of steps for an environment before it truncates
+* ``Autoreset`` - Auto-resets the environment
+* ``PassiveEnvChecker`` - Passive environment checker that does not modify any environment data
+* ``OrderEnforcing`` - Enforces the order of function calls to environments
+* ``RecordEpisodeStatistics`` - Records the episode statistics
+"""
+
+from __future__ import annotations
+
+import time
+from collections import deque
+from copy import deepcopy
+from typing import TYPE_CHECKING, Any, SupportsFloat
+
+import gymnasium as gym
+from gymnasium import logger
+from gymnasium.core import ActType, ObsType, RenderFrame, WrapperObsType
+from gymnasium.error import ResetNeeded
+from gymnasium.utils.passive_env_checker import (
+ check_action_space,
+ check_observation_space,
+ env_render_passive_checker,
+ env_reset_passive_checker,
+ env_step_passive_checker,
+)
+
+
+if TYPE_CHECKING:
+ from gymnasium.envs.registration import EnvSpec
+
+
+__all__ = [
+ "TimeLimit",
+ "Autoreset",
+ "PassiveEnvChecker",
+ "OrderEnforcing",
+ "RecordEpisodeStatistics",
+]
+
+
+
+[docs]
+class TimeLimit(
+ gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
+):
+ """Limits the number of steps for an environment through truncating the environment if a maximum number of timesteps is exceeded.
+
+ If a truncation is not defined inside the environment itself, this is the only place that the truncation signal is issued.
+ Critically, this is different from the `terminated` signal that originates from the underlying environment as part of the MDP.
+ No vector wrapper exists.
+
+ Example using the TimeLimit wrapper:
+ >>> from gymnasium.wrappers import TimeLimit
+ >>> from gymnasium.envs.classic_control import CartPoleEnv
+
+ >>> spec = gym.spec("CartPole-v1")
+ >>> spec.max_episode_steps
+ 500
+ >>> env = gym.make("CartPole-v1")
+ >>> env # TimeLimit is included within the environment stack
+ <TimeLimit<OrderEnforcing<PassiveEnvChecker<CartPoleEnv<CartPole-v1>>>>>
+ >>> env.spec # doctest: +ELLIPSIS
+ EnvSpec(id='CartPole-v1', ..., max_episode_steps=500, ...)
+ >>> env = gym.make("CartPole-v1", max_episode_steps=3)
+ >>> env.spec # doctest: +ELLIPSIS
+ EnvSpec(id='CartPole-v1', ..., max_episode_steps=3, ...)
+ >>> env = TimeLimit(CartPoleEnv(), max_episode_steps=10)
+ >>> env
+ <TimeLimit<CartPoleEnv instance>>
+
+ Example of `TimeLimit` determining the episode step
+ >>> env = gym.make("CartPole-v1", max_episode_steps=3)
+ >>> _ = env.reset(seed=123)
+ >>> _ = env.action_space.seed(123)
+ >>> _, _, terminated, truncated, _ = env.step(env.action_space.sample())
+ >>> terminated, truncated
+ (False, False)
+ >>> _, _, terminated, truncated, _ = env.step(env.action_space.sample())
+ >>> terminated, truncated
+ (False, False)
+ >>> _, _, terminated, truncated, _ = env.step(env.action_space.sample())
+ >>> terminated, truncated
+ (False, True)
+
+ Change logs:
+ * v0.10.6 - Initially added
+ * v0.25.0 - With the step API update, the termination and truncation signal is returned separately.
+ """
+
+ def __init__(
+ self,
+ env: gym.Env,
+ max_episode_steps: int,
+ ):
+ """Initializes the :class:`TimeLimit` wrapper with an environment and the number of steps after which truncation will occur.
+
+ Args:
+ env: The environment to apply the wrapper
+ max_episode_steps: the environment step after which the episode is truncated (``elapsed >= max_episode_steps``)
+ """
+ assert (
+ isinstance(max_episode_steps, int) and max_episode_steps > 0
+ ), f"Expect the `max_episode_steps` to be positive, actually: {max_episode_steps}"
+ gym.utils.RecordConstructorArgs.__init__(
+ self, max_episode_steps=max_episode_steps
+ )
+ gym.Wrapper.__init__(self, env)
+
+ self._max_episode_steps = max_episode_steps
+ self._elapsed_steps = None
+
+ def step(
+ self, action: ActType
+ ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
+ """Steps through the environment and if the number of steps elapsed exceeds ``max_episode_steps`` then truncate.
+
+ Args:
+ action: The environment step action
+
+ Returns:
+ The environment step ``(observation, reward, terminated, truncated, info)`` with `truncated=True`
+ if the number of steps elapsed >= max episode steps
+
+ """
+ observation, reward, terminated, truncated, info = self.env.step(action)
+ self._elapsed_steps += 1
+
+ if self._elapsed_steps >= self._max_episode_steps:
+ truncated = True
+
+ return observation, reward, terminated, truncated, info
+
+ def reset(
+ self, *, seed: int | None = None, options: dict[str, Any] | None = None
+ ) -> tuple[ObsType, dict[str, Any]]:
+ """Resets the environment with the given ``seed`` and ``options`` and sets the number of steps elapsed to zero.
+
+ Args:
+ seed: Seed for the environment
+ options: Options for the environment
+
+ Returns:
+ The reset environment
+ """
+ self._elapsed_steps = 0
+ return super().reset(seed=seed, options=options)
+
+ @property
+ def spec(self) -> EnvSpec | None:
+ """Modifies the environment spec to include the `max_episode_steps=self._max_episode_steps`."""
+ if self._cached_spec is not None:
+ return self._cached_spec
+
+ env_spec = self.env.spec
+ if env_spec is not None:
+ try:
+ env_spec = deepcopy(env_spec)
+ env_spec.max_episode_steps = self._max_episode_steps
+ except Exception as e:
+ gym.logger.warn(
+ f"An exception occurred ({e}) while copying the environment spec={env_spec}"
+ )
+ return None
+
+ self._cached_spec = env_spec
+ return env_spec
+
+
+
+
+[docs]
+class Autoreset(
+ gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
+):
+ """The wrapped environment is automatically reset when a terminated or truncated state is reached.
+
+ This follows the vector autoreset api where on the step after an episode terminates or truncated then the environment is reset.
+
+ Change logs:
+ * v0.24.0 - Initially added as `AutoResetWrapper`
+ * v1.0.0 - renamed to `Autoreset` and autoreset order was changed to reset on the step after the environment terminates or truncates. As a result, `"final_observation"` and `"final_info"` are removed.
+ """
+
+ def __init__(self, env: gym.Env):
+ """A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`.
+
+ Args:
+ env (gym.Env): The environment to apply the wrapper
+ """
+ gym.utils.RecordConstructorArgs.__init__(self)
+ gym.Wrapper.__init__(self, env)
+
+ self.autoreset = False
+
+ def reset(
+ self, *, seed: int | None = None, options: dict[str, Any] | None = None
+ ) -> tuple[WrapperObsType, dict[str, Any]]:
+ """Resets the environment and sets autoreset to False, preventing an automatic reset on the next step."""
+ self.autoreset = False
+ return super().reset(seed=seed, options=options)
+
+ def step(
+ self, action: ActType
+ ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
+ """Steps through the environment with action and resets the environment if a terminated or truncated signal is encountered.
+
+ Args:
+ action: The action to take
+
+ Returns:
+ The autoreset environment :meth:`step`
+ """
+ if self.autoreset:
+ obs, info = self.env.reset()
+ reward, terminated, truncated = 0.0, False, False
+ else:
+ obs, reward, terminated, truncated, info = self.env.step(action)
+
+ self.autoreset = terminated or truncated
+ return obs, reward, terminated, truncated, info
+
+
+
+
+[docs]
+class PassiveEnvChecker(
+ gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
+):
+ """A passive wrapper that surrounds the ``step``, ``reset`` and ``render`` functions to check they follow Gymnasium's API.
+
+ This wrapper is automatically applied during make and can be disabled with `disable_env_checker`.
+ No vector version of the wrapper exists.
+
+ Example:
+ >>> import gymnasium as gym
+ >>> env = gym.make("CartPole-v1")
+ >>> env
+ <TimeLimit<OrderEnforcing<PassiveEnvChecker<CartPoleEnv<CartPole-v1>>>>>
+ >>> env = gym.make("CartPole-v1", disable_env_checker=True)
+ >>> env
+ <TimeLimit<OrderEnforcing<CartPoleEnv<CartPole-v1>>>>
+
+ Change logs:
+ * v0.24.1 - Initially added however broken in several ways
+ * v0.25.0 - Bugs were all fixed
+ * v0.29.0 - Removed warnings for infinite bounds for Box observation and action spaces and irregular bound shapes
+ """
+
+ def __init__(self, env: gym.Env[ObsType, ActType]):
+ """Initialises the wrapper with the environments, run the observation and action space tests."""
+ gym.utils.RecordConstructorArgs.__init__(self)
+ gym.Wrapper.__init__(self, env)
+
+ if not isinstance(env, gym.Env):
+ if str(env.__class__.__base__) == "<class 'gym.core.Env'>":
+ raise TypeError(
+ "Gym is incompatible with Gymnasium, please update the environment class to `gymnasium.Env`. "
+ "See https://gymnasium.farama.org/introduction/create_custom_env/ for more info."
+ )
+ else:
+ raise TypeError(
+ f"The environment must inherit from the gymnasium.Env class, actual class: {type(env)}. "
+ "See https://gymnasium.farama.org/introduction/create_custom_env/ for more info."
+ )
+
+ if not hasattr(env, "action_space"):
+ raise AttributeError(
+ "The environment must specify an action space. https://gymnasium.farama.org/introduction/create_custom_env/"
+ )
+ check_action_space(env.action_space)
+
+ if not hasattr(env, "observation_space"):
+ raise AttributeError(
+ "The environment must specify an observation space. https://gymnasium.farama.org/introduction/create_custom_env/"
+ )
+ check_observation_space(env.observation_space)
+
+ self.checked_reset: bool = False
+ self.checked_step: bool = False
+ self.checked_render: bool = False
+ self.close_called: bool = False
+
+ def step(
+ self, action: ActType
+ ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
+ """Steps through the environment that on the first call will run the `passive_env_step_check`."""
+ if self.checked_step is False:
+ self.checked_step = True
+ return env_step_passive_checker(self.env, action)
+ else:
+ return self.env.step(action)
+
+ def reset(
+ self, *, seed: int | None = None, options: dict[str, Any] | None = None
+ ) -> tuple[ObsType, dict[str, Any]]:
+ """Resets the environment that on the first call will run the `passive_env_reset_check`."""
+ if self.checked_reset is False:
+ self.checked_reset = True
+ return env_reset_passive_checker(self.env, seed=seed, options=options)
+ else:
+ return self.env.reset(seed=seed, options=options)
+
+ def render(self) -> RenderFrame | list[RenderFrame] | None:
+ """Renders the environment that on the first call will run the `passive_env_render_check`."""
+ if self.checked_render is False:
+ self.checked_render = True
+ return env_render_passive_checker(self.env)
+ else:
+ return self.env.render()
+
+ @property
+ def spec(self) -> EnvSpec | None:
+ """Modifies the environment spec to such that `disable_env_checker=False`."""
+ if self._cached_spec is not None:
+ return self._cached_spec
+
+ env_spec = self.env.spec
+ if env_spec is not None:
+ try:
+ env_spec = deepcopy(env_spec)
+ env_spec.disable_env_checker = False
+ except Exception as e:
+ gym.logger.warn(
+ f"An exception occurred ({e}) while copying the environment spec={env_spec}"
+ )
+ return None
+
+ self._cached_spec = env_spec
+ return env_spec
+
+ def close(self):
+ """Warns if calling close on a closed environment fails."""
+ if not self.close_called:
+ self.close_called = True
+ return self.env.close()
+ else:
+ try:
+ return self.env.close()
+ except Exception as e:
+ logger.warn(
+ "Calling `env.close()` on the closed environment should be allowed, but it raised the following exception."
+ )
+ raise e
+
+
+
+
+[docs]
+class OrderEnforcing(
+ gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
+):
+ """Will produce an error if ``step`` or ``render`` is called before ``reset``.
+
+ No vector version of the wrapper exists.
+
+ Example:
+ >>> import gymnasium as gym
+ >>> from gymnasium.wrappers import OrderEnforcing
+ >>> env = gym.make("CartPole-v1", render_mode="human")
+ >>> env = OrderEnforcing(env)
+ >>> env.step(0)
+ Traceback (most recent call last):
+ ...
+ gymnasium.error.ResetNeeded: Cannot call env.step() before calling env.reset()
+ >>> env.render()
+ Traceback (most recent call last):
+ ...
+ gymnasium.error.ResetNeeded: Cannot call `env.render()` before calling `env.reset()`, if this is an intended action, set `disable_render_order_enforcing=True` on the OrderEnforcer wrapper.
+ >>> _ = env.reset()
+ >>> env.render()
+ >>> _ = env.step(0)
+ >>> env.close()
+
+ Change logs:
+ * v0.22.0 - Initially added
+ * v0.24.0 - Added order enforcing for the render function
+ """
+
+ def __init__(
+ self,
+ env: gym.Env[ObsType, ActType],
+ disable_render_order_enforcing: bool = False,
+ ):
+ """A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`.
+
+ Args:
+ env: The environment to wrap
+ disable_render_order_enforcing: If to disable render order enforcing
+ """
+ gym.utils.RecordConstructorArgs.__init__(
+ self, disable_render_order_enforcing=disable_render_order_enforcing
+ )
+ gym.Wrapper.__init__(self, env)
+
+ self._has_reset: bool = False
+ self._disable_render_order_enforcing: bool = disable_render_order_enforcing
+
+ def step(self, action: ActType) -> tuple[ObsType, SupportsFloat, bool, bool, dict]:
+ """Steps through the environment."""
+ if not self._has_reset:
+ raise ResetNeeded("Cannot call env.step() before calling env.reset()")
+ return super().step(action)
+
+ def reset(
+ self, *, seed: int | None = None, options: dict[str, Any] | None = None
+ ) -> tuple[ObsType, dict[str, Any]]:
+ """Resets the environment with `kwargs`."""
+ self._has_reset = True
+ return super().reset(seed=seed, options=options)
+
+ def render(self) -> RenderFrame | list[RenderFrame] | None:
+ """Renders the environment with `kwargs`."""
+ if not self._disable_render_order_enforcing and not self._has_reset:
+ raise ResetNeeded(
+ "Cannot call `env.render()` before calling `env.reset()`, if this is an intended action, "
+ "set `disable_render_order_enforcing=True` on the OrderEnforcer wrapper."
+ )
+ return super().render()
+
+ @property
+ def has_reset(self):
+ """Returns if the environment has been reset before."""
+ return self._has_reset
+
+ @property
+ def spec(self) -> EnvSpec | None:
+ """Modifies the environment spec to add the `order_enforce=True`."""
+ if self._cached_spec is not None:
+ return self._cached_spec
+
+ env_spec = self.env.spec
+ if env_spec is not None:
+ try:
+ env_spec = deepcopy(env_spec)
+ env_spec.order_enforce = True
+ except Exception as e:
+ gym.logger.warn(
+ f"An exception occurred ({e}) while copying the environment spec={env_spec}"
+ )
+ return None
+
+ self._cached_spec = env_spec
+ return env_spec
+
+
+
+
+[docs]
+class RecordEpisodeStatistics(
+ gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
+):
+ """This wrapper will keep track of cumulative rewards and episode lengths.
+
+ At the end of an episode, the statistics of the episode will be added to ``info``
+ using the key ``episode``. If using a vectorized environment also the key
+ ``_episode`` is used which indicates whether the env at the respective index has
+ the episode statistics.
+ A vector version of the wrapper exists, :class:`gymnasium.wrappers.vector.RecordEpisodeStatistics`.
+
+ After the completion of an episode, ``info`` will look like this::
+
+ >>> info = {
+ ... "episode": {
+ ... "r": "<cumulative reward>",
+ ... "l": "<episode length>",
+ ... "t": "<elapsed time since beginning of episode>"
+ ... },
+ ... }
+
+ For a vectorized environments the output will be in the form of::
+
+ >>> infos = {
+ ... "episode": {
+ ... "r": "<array of cumulative reward>",
+ ... "l": "<array of episode length>",
+ ... "t": "<array of elapsed time since beginning of episode>"
+ ... },
+ ... "_episode": "<boolean array of length num-envs>"
+ ... }
+
+ Moreover, the most recent rewards and episode lengths are stored in buffers that can be accessed via
+ :attr:`wrapped_env.return_queue` and :attr:`wrapped_env.length_queue` respectively.
+
+ Attributes:
+ * time_queue: The time length of the last ``deque_size``-many episodes
+ * return_queue: The cumulative rewards of the last ``deque_size``-many episodes
+ * length_queue: The lengths of the last ``deque_size``-many episodes
+
+ Change logs:
+ * v0.15.4 - Initially added
+ * v1.0.0 - Removed vector environment support (see :class:`gymnasium.wrappers.vector.RecordEpisodeStatistics`) and add attribute ``time_queue``
+ """
+
+ def __init__(
+ self,
+ env: gym.Env[ObsType, ActType],
+ buffer_length: int = 100,
+ stats_key: str = "episode",
+ ):
+ """This wrapper will keep track of cumulative rewards and episode lengths.
+
+ Args:
+ env (Env): The environment to apply the wrapper
+ buffer_length: The size of the buffers :attr:`return_queue`, :attr:`length_queue` and :attr:`time_queue`
+ stats_key: The info key for the episode statistics
+ """
+ gym.utils.RecordConstructorArgs.__init__(self)
+ gym.Wrapper.__init__(self, env)
+
+ self._stats_key = stats_key
+
+ self.episode_count = 0
+ self.episode_start_time: float = -1
+ self.episode_returns: float = 0.0
+ self.episode_lengths: int = 0
+
+ self.time_queue: deque[float] = deque(maxlen=buffer_length)
+ self.return_queue: deque[float] = deque(maxlen=buffer_length)
+ self.length_queue: deque[int] = deque(maxlen=buffer_length)
+
+ def step(
+ self, action: ActType
+ ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
+ """Steps through the environment, recording the episode statistics."""
+ obs, reward, terminated, truncated, info = super().step(action)
+
+ self.episode_returns += reward
+ self.episode_lengths += 1
+
+ if terminated or truncated:
+ assert self._stats_key not in info
+
+ episode_time_length = round(
+ time.perf_counter() - self.episode_start_time, 6
+ )
+ info[self._stats_key] = {
+ "r": self.episode_returns,
+ "l": self.episode_lengths,
+ "t": episode_time_length,
+ }
+
+ self.time_queue.append(episode_time_length)
+ self.return_queue.append(self.episode_returns)
+ self.length_queue.append(self.episode_lengths)
+
+ self.episode_count += 1
+ self.episode_start_time = time.perf_counter()
+
+ return obs, reward, terminated, truncated, info
+
+ def reset(
+ self, *, seed: int | None = None, options: dict[str, Any] | None = None
+ ) -> tuple[ObsType, dict[str, Any]]:
+ """Resets the environment using seed and options and resets the episode rewards and lengths."""
+ obs, info = super().reset(seed=seed, options=options)
+
+ self.episode_start_time = time.perf_counter()
+ self.episode_returns = 0.0
+ self.episode_lengths = 0
+
+ return obs, info
+
+
+"""Helper functions and wrapper class for converting between numpy and Jax."""
+
+from __future__ import annotations
+
+import functools
+
+import numpy as np
+
+import gymnasium as gym
+from gymnasium.core import ActType, ObsType
+from gymnasium.error import DependencyNotInstalled
+from gymnasium.wrappers.array_conversion import (
+ ArrayConversion,
+ array_conversion,
+ module_namespace,
+)
+
+
+try:
+ import jax.numpy as jnp
+except ImportError:
+ raise DependencyNotInstalled(
+ 'Jax is not installed therefore cannot call `numpy_to_jax`, run `pip install "gymnasium[jax]"`'
+ )
+
+__all__ = ["JaxToNumpy", "jax_to_numpy", "numpy_to_jax"]
+
+
+jax_to_numpy = functools.partial(array_conversion, xp=module_namespace(np))
+
+numpy_to_jax = functools.partial(array_conversion, xp=module_namespace(jnp))
+
+
+
+[docs]
+class JaxToNumpy(ArrayConversion):
+ """Wraps a Jax-based environment such that it can be interacted with NumPy arrays.
+
+ Actions must be provided as numpy arrays and observations will be returned as numpy arrays.
+ A vector version of the wrapper exists, :class:`gymnasium.wrappers.vector.JaxToNumpy`.
+
+ Notes:
+ The Jax To Numpy and Numpy to Jax conversion does not guarantee a roundtrip (jax -> numpy -> jax) and vice versa.
+ The reason for this is jax does not support non-array values, therefore numpy ``np.int32(5) -> DeviceArray([5], dtype=jnp.int32)``
+
+ Example:
+ >>> import gymnasium as gym # doctest: +SKIP
+ >>> env = gym.make("JaxEnv-vx") # doctest: +SKIP
+ >>> env = JaxToNumpy(env) # doctest: +SKIP
+ >>> obs, _ = env.reset(seed=123) # doctest: +SKIP
+ >>> type(obs) # doctest: +SKIP
+ <class 'numpy.ndarray'>
+ >>> action = env.action_space.sample() # doctest: +SKIP
+ >>> obs, reward, terminated, truncated, info = env.step(action) # doctest: +SKIP
+ >>> type(obs) # doctest: +SKIP
+ <class 'numpy.ndarray'>
+ >>> type(reward) # doctest: +SKIP
+ <class 'float'>
+ >>> type(terminated) # doctest: +SKIP
+ <class 'bool'>
+ >>> type(truncated) # doctest: +SKIP
+ <class 'bool'>
+
+ Change logs:
+ * v1.0.0 - Initially added
+ """
+
+ def __init__(self, env: gym.Env[ObsType, ActType]):
+ """Wraps a jax environment such that the input and outputs are numpy arrays.
+
+ Args:
+ env: the jax environment to wrap
+ """
+ if jnp is None:
+ raise DependencyNotInstalled(
+ 'Jax is not installed, run `pip install "gymnasium[jax]"`'
+ )
+ super().__init__(env=env, env_xp=jnp, target_xp=np)
+
+
+# This wrapper will convert torch inputs for the actions and observations to Jax arrays
+# for an underlying Jax environment then convert the return observations from Jax arrays
+# back to torch tensors.
+#
+# Functionality for converting between torch and jax types originally copied from
+# https://github.com/google/brax/blob/9d6b7ced2a13da0d074b5e9fbd3aad8311e26997/brax/io/torch.py
+# Under the Apache 2.0 license. Copyright is held by the authors
+
+"""Helper functions and wrapper class for converting between PyTorch and Jax."""
+
+from __future__ import annotations
+
+import functools
+from typing import Union
+
+import gymnasium as gym
+from gymnasium.error import DependencyNotInstalled
+from gymnasium.wrappers.array_conversion import (
+ ArrayConversion,
+ array_conversion,
+ module_namespace,
+)
+
+
+try:
+ import jax.numpy as jnp
+
+except ImportError:
+ raise DependencyNotInstalled(
+ 'Jax is not installed therefore cannot call `torch_to_jax`, run `pip install "gymnasium[jax]"`'
+ )
+
+try:
+ import torch
+
+ Device = Union[str, torch.device]
+except ImportError:
+ raise DependencyNotInstalled(
+ 'Torch is not installed therefore cannot call `torch_to_jax`, run `pip install "gymnasium[torch]"`'
+ )
+
+
+__all__ = ["JaxToTorch", "jax_to_torch", "torch_to_jax", "Device"]
+
+
+torch_to_jax = functools.partial(array_conversion, xp=module_namespace(jnp))
+
+jax_to_torch = functools.partial(array_conversion, xp=module_namespace(torch))
+
+
+
+[docs]
+class JaxToTorch(ArrayConversion):
+ """Wraps a Jax-based environment so that it can be interacted with PyTorch Tensors.
+
+ Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors.
+ A vector version of the wrapper exists, :class:`gymnasium.wrappers.vector.JaxToTorch`.
+
+ Note:
+ For ``render``, the output is returned as a NumPy array, not a PyTorch Tensor.
+
+ Example:
+ >>> import torch # doctest: +SKIP
+ >>> import gymnasium as gym # doctest: +SKIP
+ >>> env = gym.make("JaxEnv-vx") # doctest: +SKIP
+ >>> env = JaxToTorch(env) # doctest: +SKIP
+ >>> obs, _ = env.reset(seed=123) # doctest: +SKIP
+ >>> type(obs) # doctest: +SKIP
+ <class 'torch.Tensor'>
+ >>> action = torch.tensor(env.action_space.sample()) # doctest: +SKIP
+ >>> obs, reward, terminated, truncated, info = env.step(action) # doctest: +SKIP
+ >>> type(obs) # doctest: +SKIP
+ <class 'torch.Tensor'>
+ >>> type(reward) # doctest: +SKIP
+ <class 'float'>
+ >>> type(terminated) # doctest: +SKIP
+ <class 'bool'>
+ >>> type(truncated) # doctest: +SKIP
+ <class 'bool'>
+
+ Change logs:
+ * v1.0.0 - Initially added
+ """
+
+ def __init__(self, env: gym.Env, device: Device | None = None):
+ """Wrapper class to change inputs and outputs of environment to PyTorch tensors.
+
+ Args:
+ env: The Jax-based environment to wrap
+ device: The device the torch Tensors should be moved to
+ """
+ super().__init__(env=env, env_xp=jnp, target_xp=torch, target_device=device)
+
+ # TODO: Device was part of the public API, but should be removed in favor of _env_device and
+ # _target_device.
+ self.device: Device | None = device
+
+
+"""Helper functions and wrapper class for converting between PyTorch and NumPy."""
+
+from __future__ import annotations
+
+import functools
+from typing import Union
+
+import numpy as np
+
+import gymnasium as gym
+from gymnasium.error import DependencyNotInstalled
+from gymnasium.wrappers.array_conversion import (
+ ArrayConversion,
+ array_conversion,
+ module_namespace,
+)
+
+
+try:
+ import torch
+
+ Device = Union[str, torch.device]
+except ImportError:
+ raise DependencyNotInstalled(
+ 'Torch is not installed therefore cannot call `torch_to_numpy`, run `pip install "gymnasium[torch]"`'
+ )
+
+
+__all__ = ["NumpyToTorch", "torch_to_numpy", "numpy_to_torch", "Device"]
+
+
+torch_to_numpy = functools.partial(array_conversion, xp=module_namespace(np))
+
+numpy_to_torch = functools.partial(array_conversion, xp=module_namespace(torch))
+
+
+
+[docs]
+class NumpyToTorch(ArrayConversion):
+ """Wraps a NumPy-based environment such that it can be interacted with PyTorch Tensors.
+
+ Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors.
+ A vector version of the wrapper exists, :class:`gymnasium.wrappers.vector.NumpyToTorch`.
+
+ Note:
+ For ``render``, the output is returned as a NumPy array, not a PyTorch Tensor.
+
+ Example:
+ >>> import torch
+ >>> import gymnasium as gym
+ >>> env = gym.make("CartPole-v1")
+ >>> env = NumpyToTorch(env)
+ >>> obs, _ = env.reset(seed=123)
+ >>> type(obs)
+ <class 'torch.Tensor'>
+ >>> action = torch.tensor(env.action_space.sample())
+ >>> obs, reward, terminated, truncated, info = env.step(action)
+ >>> type(obs)
+ <class 'torch.Tensor'>
+ >>> type(reward)
+ <class 'float'>
+ >>> type(terminated)
+ <class 'bool'>
+ >>> type(truncated)
+ <class 'bool'>
+
+ Change logs:
+ * v1.0.0 - Initially added
+ """
+
+ def __init__(self, env: gym.Env, device: Device | None = None):
+ """Wrapper class to change inputs and outputs of environment to PyTorch tensors.
+
+ Args:
+ env: The NumPy-based environment to wrap
+ device: The device the torch Tensors should be moved to
+ """
+ super().__init__(env=env, env_xp=np, target_xp=torch, target_device=device)
+
+ self.device: Device | None = device
+
+
+"""A collections of rendering-based wrappers.
+
+* ``RenderCollection`` - Collects rendered frames into a list
+* ``RecordVideo`` - Records a video of the environments
+* ``HumanRendering`` - Provides human rendering of environments with ``"rgb_array"``
+* ``AddWhiteNoise`` - Randomly replaces pixels with white noise
+* ``ObstructView`` - Randomly places patches of white noise to obstruct the pixel rendering
+"""
+
+from __future__ import annotations
+
+import gc
+import os
+from collections.abc import Callable
+from copy import deepcopy
+from typing import Any, Generic, SupportsFloat
+
+import numpy as np
+
+import gymnasium as gym
+from gymnasium import error, logger
+from gymnasium.core import ActType, ObsType, RenderFrame
+from gymnasium.error import DependencyNotInstalled, InvalidProbability
+
+
+__all__ = [
+ "RenderCollection",
+ "RecordVideo",
+ "HumanRendering",
+ "AddWhiteNoise",
+ "ObstructView",
+]
+
+
+
+[docs]
+class RenderCollection(
+ gym.Wrapper[ObsType, ActType, ObsType, ActType],
+ Generic[ObsType, ActType, RenderFrame],
+ gym.utils.RecordConstructorArgs,
+):
+ """Collect rendered frames of an environment such that ``render`` returns a ``list[RenderedFrame]``.
+
+ No vector version of the wrapper exists.
+
+ Example:
+ Return the list of frames for the number of steps ``render`` wasn't called.
+ >>> import gymnasium as gym
+ >>> env = gym.make("LunarLander-v3", render_mode="rgb_array")
+ >>> env = RenderCollection(env)
+ >>> _ = env.reset(seed=123)
+ >>> for _ in range(5):
+ ... _ = env.step(env.action_space.sample())
+ ...
+ >>> frames = env.render()
+ >>> len(frames)
+ 6
+
+ >>> frames = env.render()
+ >>> len(frames)
+ 0
+
+ Return the list of frames for the number of steps the episode was running.
+ >>> import gymnasium as gym
+ >>> env = gym.make("LunarLander-v3", render_mode="rgb_array")
+ >>> env = RenderCollection(env, pop_frames=False)
+ >>> _ = env.reset(seed=123)
+ >>> for _ in range(5):
+ ... _ = env.step(env.action_space.sample())
+ ...
+ >>> frames = env.render()
+ >>> len(frames)
+ 6
+
+ >>> frames = env.render()
+ >>> len(frames)
+ 6
+
+ Collect all frames for all episodes, without clearing them when render is called
+ >>> import gymnasium as gym
+ >>> env = gym.make("LunarLander-v3", render_mode="rgb_array")
+ >>> env = RenderCollection(env, pop_frames=False, reset_clean=False)
+ >>> _ = env.reset(seed=123)
+ >>> for _ in range(5):
+ ... _ = env.step(env.action_space.sample())
+ ...
+ >>> _ = env.reset(seed=123)
+ >>> for _ in range(5):
+ ... _ = env.step(env.action_space.sample())
+ ...
+ >>> frames = env.render()
+ >>> len(frames)
+ 12
+
+ >>> frames = env.render()
+ >>> len(frames)
+ 12
+
+ Change logs:
+ * v0.26.2 - Initially added
+ """
+
+ def __init__(
+ self,
+ env: gym.Env[ObsType, ActType],
+ pop_frames: bool = True,
+ reset_clean: bool = True,
+ ):
+ """Initialize a :class:`RenderCollection` instance.
+
+ Args:
+ env: The environment that is being wrapped
+ pop_frames (bool): If true, clear the collection frames after ``meth:render`` is called. Default value is ``True``.
+ reset_clean (bool): If true, clear the collection frames when ``meth:reset`` is called. Default value is ``True``.
+ """
+ gym.utils.RecordConstructorArgs.__init__(
+ self, pop_frames=pop_frames, reset_clean=reset_clean
+ )
+ gym.Wrapper.__init__(self, env)
+
+ assert env.render_mode is not None
+ assert not env.render_mode.endswith("_list")
+
+ self.frame_list: list[RenderFrame] = []
+ self.pop_frames = pop_frames
+ self.reset_clean = reset_clean
+
+ self.metadata = deepcopy(self.env.metadata)
+ if f"{self.env.render_mode}_list" not in self.metadata["render_modes"]:
+ self.metadata["render_modes"].append(f"{self.env.render_mode}_list")
+
+ @property
+ def render_mode(self):
+ """Returns the collection render_mode name."""
+ return f"{self.env.render_mode}_list"
+
+ def step(
+ self, action: ActType
+ ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
+ """Perform a step in the base environment and collect a frame."""
+ output = super().step(action)
+ self.frame_list.append(super().render())
+ return output
+
+ def reset(
+ self, *, seed: int | None = None, options: dict[str, Any] | None = None
+ ) -> tuple[ObsType, dict[str, Any]]:
+ """Reset the base environment, eventually clear the frame_list, and collect a frame."""
+ output = super().reset(seed=seed, options=options)
+
+ if self.reset_clean:
+ self.frame_list = []
+ self.frame_list.append(super().render())
+
+ return output
+
+ def render(self) -> list[RenderFrame]:
+ """Returns the collection of frames and, if pop_frames = True, clears it."""
+ frames = self.frame_list
+ if self.pop_frames:
+ self.frame_list = []
+
+ return frames
+
+
+
+
+[docs]
+class RecordVideo(
+ gym.Wrapper[ObsType, ActType, ObsType, ActType],
+ Generic[ObsType, ActType, RenderFrame],
+ gym.utils.RecordConstructorArgs,
+):
+ """Records videos of environment episodes using the environment's render function.
+
+ .. py:currentmodule:: gymnasium.utils.save_video
+
+ Usually, you only want to record episodes intermittently, say every hundredth episode or at every thousandth environment step.
+ To do this, you can specify ``episode_trigger`` or ``step_trigger``.
+ They should be functions returning a boolean that indicates whether a recording should be started at the
+ current episode or step, respectively.
+
+ The ``episode_trigger`` should return ``True`` on the episode when recording should start.
+ The ``step_trigger`` should return ``True`` on the n-th environment step that the recording should be started, where n sums over all previous episodes.
+ If neither :attr:`episode_trigger` nor ``step_trigger`` is passed, a default ``episode_trigger`` will be employed, i.e. :func:`capped_cubic_video_schedule`.
+ This function starts a video at every episode that is a power of 3 until 1000 and then every 1000 episodes.
+ By default, the recording will be stopped once reset is called.
+ However, you can also create recordings of fixed length (possibly spanning several episodes)
+ by passing a strictly positive value for ``video_length``.
+
+ No vector version of the wrapper exists.
+
+ Examples - Run the environment for 50 episodes, and save the video every 10 episodes starting from the 0th:
+ >>> import os
+ >>> import gymnasium as gym
+ >>> env = gym.make("LunarLander-v3", render_mode="rgb_array")
+ >>> trigger = lambda t: t % 10 == 0
+ >>> env = RecordVideo(env, video_folder="./save_videos1", episode_trigger=trigger, disable_logger=True)
+ >>> for i in range(50):
+ ... termination, truncation = False, False
+ ... _ = env.reset(seed=123)
+ ... while not (termination or truncation):
+ ... obs, rew, termination, truncation, info = env.step(env.action_space.sample())
+ ...
+ >>> env.close()
+ >>> len(os.listdir("./save_videos1"))
+ 5
+
+ Examples - Run the environment for 5 episodes, start a recording every 200th step, making sure each video is 100 frames long:
+ >>> import os
+ >>> import gymnasium as gym
+ >>> env = gym.make("LunarLander-v3", render_mode="rgb_array")
+ >>> trigger = lambda t: t % 200 == 0
+ >>> env = RecordVideo(env, video_folder="./save_videos2", step_trigger=trigger, video_length=100, disable_logger=True)
+ >>> for i in range(5):
+ ... termination, truncation = False, False
+ ... _ = env.reset(seed=123)
+ ... _ = env.action_space.seed(123)
+ ... while not (termination or truncation):
+ ... obs, rew, termination, truncation, info = env.step(env.action_space.sample())
+ ...
+ >>> env.close()
+ >>> len(os.listdir("./save_videos2"))
+ 2
+
+ Examples - Run 3 episodes, record everything, but in chunks of 1000 frames:
+ >>> import os
+ >>> import gymnasium as gym
+ >>> env = gym.make("LunarLander-v3", render_mode="rgb_array")
+ >>> env = RecordVideo(env, video_folder="./save_videos3", video_length=1000, disable_logger=True)
+ >>> for i in range(3):
+ ... termination, truncation = False, False
+ ... _ = env.reset(seed=123)
+ ... while not (termination or truncation):
+ ... obs, rew, termination, truncation, info = env.step(env.action_space.sample())
+ ...
+ >>> env.close()
+ >>> len(os.listdir("./save_videos3"))
+ 2
+
+ Change logs:
+ * v0.25.0 - Initially added to replace ``wrappers.monitoring.VideoRecorder``
+ """
+
+ def __init__(
+ self,
+ env: gym.Env[ObsType, ActType],
+ video_folder: str,
+ episode_trigger: Callable[[int], bool] | None = None,
+ step_trigger: Callable[[int], bool] | None = None,
+ video_length: int = 0,
+ name_prefix: str = "rl-video",
+ fps: int | None = None,
+ disable_logger: bool = True,
+ gc_trigger: Callable[[int], bool] | None = lambda episode: True,
+ ):
+ """Wrapper records videos of rollouts.
+
+ Args:
+ env: The environment that will be wrapped
+ video_folder (str): The folder where the recordings will be stored
+ episode_trigger: Function that accepts an integer and returns ``True`` iff a recording should be started at this episode
+ step_trigger: Function that accepts an integer and returns ``True`` iff a recording should be started at this step
+ video_length (int): The length of recorded episodes. If 0, entire episodes are recorded.
+ Otherwise, snippets of the specified length are captured
+ name_prefix (str): Will be prepended to the filename of the recordings
+ fps (int): The frame per second in the video. Provides a custom video fps for environment, if ``None`` then
+ the environment metadata ``render_fps`` key is used if it exists, otherwise a default value of 30 is used.
+ disable_logger (bool): Whether to disable moviepy logger or not, default it is disabled
+ gc_trigger: Function that accepts an integer and returns ``True`` iff garbage collection should be performed after this episode
+ """
+ gym.utils.RecordConstructorArgs.__init__(
+ self,
+ video_folder=video_folder,
+ episode_trigger=episode_trigger,
+ step_trigger=step_trigger,
+ video_length=video_length,
+ name_prefix=name_prefix,
+ disable_logger=disable_logger,
+ )
+ gym.Wrapper.__init__(self, env)
+
+ if env.render_mode in {None, "human", "ansi"}:
+ raise ValueError(
+ f"Render mode is {env.render_mode}, which is incompatible with RecordVideo.",
+ "Initialize your environment with a render_mode that returns an image, such as rgb_array.",
+ )
+
+ if episode_trigger is None and step_trigger is None:
+ from gymnasium.utils.save_video import capped_cubic_video_schedule
+
+ episode_trigger = capped_cubic_video_schedule
+
+ self.episode_trigger = episode_trigger
+ self.step_trigger = step_trigger
+ self.disable_logger = disable_logger
+ self.gc_trigger = gc_trigger
+
+ self.video_folder = os.path.abspath(video_folder)
+ if os.path.isdir(self.video_folder):
+ logger.warn(
+ f"Overwriting existing videos at {self.video_folder} folder "
+ f"(try specifying a different `video_folder` for the `RecordVideo` wrapper if this is not desired)"
+ )
+ os.makedirs(self.video_folder, exist_ok=True)
+
+ if fps is None:
+ fps = self.metadata.get("render_fps", 30)
+ self.frames_per_sec: int = fps
+ self.name_prefix: str = name_prefix
+ self._video_name: str | None = None
+ self.video_length: int = video_length if video_length != 0 else float("inf")
+ self.recording: bool = False
+ self.recorded_frames: list[RenderFrame] = []
+ self.render_history: list[RenderFrame] = []
+
+ self.step_id = -1
+ self.episode_id = -1
+
+ try:
+ import moviepy # noqa: F401
+ except ImportError as e:
+ raise error.DependencyNotInstalled(
+ 'MoviePy is not installed, run `pip install "gymnasium[other]"`'
+ ) from e
+
+ def _capture_frame(self):
+ assert self.recording, "Cannot capture a frame, recording wasn't started."
+
+ frame = self.env.render()
+ if isinstance(frame, list):
+ if len(frame) == 0: # render was called
+ return
+ self.render_history += frame
+ frame = frame[-1]
+
+ if isinstance(frame, np.ndarray):
+ self.recorded_frames.append(frame)
+ else:
+ self.stop_recording()
+ logger.warn(
+ f"Recording stopped: expected type of frame returned by render to be a numpy array, got instead {type(frame)}."
+ )
+
+ def reset(
+ self, *, seed: int | None = None, options: dict[str, Any] | None = None
+ ) -> tuple[ObsType, dict[str, Any]]:
+ """Reset the environment and eventually starts a new recording."""
+ obs, info = super().reset(seed=seed, options=options)
+ self.episode_id += 1
+
+ if self.recording and self.video_length == float("inf"):
+ self.stop_recording()
+
+ if self.episode_trigger and self.episode_trigger(self.episode_id):
+ self.start_recording(f"{self.name_prefix}-episode-{self.episode_id}")
+ if self.recording:
+ self._capture_frame()
+ if len(self.recorded_frames) > self.video_length:
+ self.stop_recording()
+
+ return obs, info
+
+ def step(
+ self, action: ActType
+ ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
+ """Steps through the environment using action, recording observations if :attr:`self.recording`."""
+ obs, rew, terminated, truncated, info = self.env.step(action)
+ self.step_id += 1
+
+ if self.step_trigger and self.step_trigger(self.step_id):
+ self.start_recording(f"{self.name_prefix}-step-{self.step_id}")
+ if self.recording:
+ self._capture_frame()
+
+ if len(self.recorded_frames) > self.video_length:
+ self.stop_recording()
+
+ return obs, rew, terminated, truncated, info
+
+ def render(self) -> RenderFrame | list[RenderFrame]:
+ """Compute the render frames as specified by render_mode attribute during initialization of the environment."""
+ render_out = super().render()
+ if self.recording and isinstance(render_out, list):
+ self.recorded_frames += render_out
+
+ if len(self.render_history) > 0:
+ tmp_history = self.render_history
+ self.render_history = []
+ return tmp_history + render_out
+ else:
+ return render_out
+
+ def close(self):
+ """Closes the wrapper then the video recorder."""
+ super().close()
+ if self.recording:
+ self.stop_recording()
+
+ def start_recording(self, video_name: str):
+ """Start a new recording. If it is already recording, stops the current recording before starting the new one."""
+ if self.recording:
+ self.stop_recording()
+
+ self.recording = True
+ self._video_name = video_name
+
+ def stop_recording(self):
+ """Stop current recording and saves the video."""
+ assert self.recording, "stop_recording was called, but no recording was started"
+
+ if len(self.recorded_frames) == 0:
+ logger.warn("Ignored saving a video as there were zero frames to save.")
+ else:
+ try:
+ from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
+ except ImportError as e:
+ raise error.DependencyNotInstalled(
+ 'MoviePy is not installed, run `pip install "gymnasium[other]"`'
+ ) from e
+
+ clip = ImageSequenceClip(self.recorded_frames, fps=self.frames_per_sec)
+ moviepy_logger = None if self.disable_logger else "bar"
+ path = os.path.join(self.video_folder, f"{self._video_name}.mp4")
+ clip.write_videofile(path, logger=moviepy_logger)
+
+ self.recorded_frames = []
+ self.recording = False
+ self._video_name = None
+
+ if self.gc_trigger and self.gc_trigger(self.episode_id):
+ gc.collect()
+
+ def __del__(self):
+ """Warn the user in case last video wasn't saved."""
+ if len(self.recorded_frames) > 0:
+ logger.warn("Unable to save last video! Did you call close()?")
+
+
+
+
+[docs]
class HumanRendering(
    gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
    """Allows human like rendering for environments that support "rgb_array" rendering.

    This wrapper is particularly useful when you have implemented an environment that can produce
    RGB images but haven't implemented any code to render the images to the screen.
    If you want to use this wrapper with your environments, remember to specify ``"render_fps"``
    in the metadata of your environment.

    The ``render_mode`` of the wrapped environment must be either ``'rgb_array'`` or ``'rgb_array_list'``.

    No vector version of the wrapper exists.

    Example:
        >>> import gymnasium as gym
        >>> from gymnasium.wrappers import HumanRendering
        >>> env = gym.make("LunarLander-v3", render_mode="rgb_array")
        >>> wrapped = HumanRendering(env)
        >>> obs, _ = wrapped.reset()     # This will start rendering to the screen

    The wrapper can also be applied directly when the environment is instantiated, simply by passing
    ``render_mode="human"`` to ``make``. The wrapper will only be applied if the environment does not
    implement human-rendering natively (i.e. ``render_mode`` does not contain ``"human"``).

    >>> env = gym.make("phys2d/CartPole-v1", render_mode="human")      # CartPoleJax-v1 doesn't implement human-rendering natively
    >>> obs, _ = env.reset()     # This will start rendering to the screen

    Warning: If the base environment uses ``render_mode="rgb_array_list"``, its (i.e. the *base environment's*) render method
    will always return an empty list:

    >>> env = gym.make("LunarLander-v3", render_mode="rgb_array_list")
    >>> wrapped = HumanRendering(env)
    >>> obs, _ = wrapped.reset()
    >>> env.render()     # env.render() will always return an empty list!
    []

    Change logs:
     * v0.25.0 - Initially added
    """

    # Render modes of the base environment that this wrapper can display.
    ACCEPTED_RENDER_MODES = [
        "rgb_array",
        "rgb_array_list",
        "depth_array",
        "depth_array_list",
    ]

    def __init__(self, env: gym.Env[ObsType, ActType]):
        """Initialize a :class:`HumanRendering` instance.

        Args:
            env: The environment that is being wrapped

        Raises:
            AssertionError: If the base environment's render mode is unsupported
                or its metadata lacks ``"render_fps"``.
        """
        gym.utils.RecordConstructorArgs.__init__(self)
        gym.Wrapper.__init__(self, env)

        self.screen_size = None  # fixed on first frame; later frames must match
        self.window = None  # Has to be initialized before asserts, as self.window is used in auto close
        self.clock = None

        assert (
            self.env.render_mode in self.ACCEPTED_RENDER_MODES
        ), f"Expected env.render_mode to be one of {self.ACCEPTED_RENDER_MODES} but got '{env.render_mode}'"
        assert (
            "render_fps" in self.env.metadata
        ), "The base environment must specify 'render_fps' to be used with the HumanRendering wrapper"

        # Advertise "human" support without mutating the base environment's metadata.
        if "human" not in self.metadata["render_modes"]:
            self.metadata = deepcopy(self.env.metadata)
            self.metadata["render_modes"].append("human")

    @property
    def render_mode(self):
        """Always returns ``'human'``."""
        return "human"

    def step(self, action: ActType) -> tuple[ObsType, SupportsFloat, bool, bool, dict]:
        """Perform a step in the base environment and render a frame to the screen."""
        result = super().step(action)
        self._render_frame()
        return result

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[ObsType, dict[str, Any]]:
        """Reset the base environment and render a frame to the screen."""
        result = super().reset(seed=seed, options=options)
        self._render_frame()
        return result

    def render(self) -> None:
        """This method doesn't do much, actual rendering is performed in :meth:`step` and :meth:`reset`."""
        return None

    def _render_frame(self):
        """Fetch the last frame from the base environment and render it to the screen.

        Raises:
            DependencyNotInstalled: If pygame is not installed.
        """
        try:
            import pygame
        except ImportError as e:
            # Fix: chain the original ImportError (`from e`) for a complete traceback,
            # consistent with the MoviePy handling in RecordVideo.
            raise DependencyNotInstalled(
                'pygame is not installed, run `pip install "gymnasium[classic-control]"`'
            ) from e
        assert self.env.render_mode is not None
        if self.env.render_mode.endswith("_list"):
            # List-mode render accumulates frames; display only the newest one.
            last_rgb_array = self.env.render()
            assert isinstance(last_rgb_array, list)
            last_rgb_array = last_rgb_array[-1]
        else:
            last_rgb_array = self.env.render()

        assert isinstance(
            last_rgb_array, np.ndarray
        ), f"Expected `env.render()` to return a numpy array, actually returned {type(last_rgb_array)}"

        # pygame surfaces are (width, height); render output is (height, width, channels).
        rgb_array = np.transpose(last_rgb_array, axes=(1, 0, 2))

        if self.screen_size is None:
            self.screen_size = rgb_array.shape[:2]

        assert (
            self.screen_size == rgb_array.shape[:2]
        ), f"The shape of the rgb array has changed from {self.screen_size} to {rgb_array.shape[:2]}"

        # Lazily create the display window and frame-rate clock on first use.
        if self.window is None:
            pygame.init()
            pygame.display.init()
            self.window = pygame.display.set_mode(self.screen_size)

        if self.clock is None:
            self.clock = pygame.time.Clock()

        surf = pygame.surfarray.make_surface(rgb_array)
        self.window.blit(surf, (0, 0))
        pygame.event.pump()
        # Throttle to the environment's advertised frame rate.
        self.clock.tick(self.metadata["render_fps"])
        pygame.display.flip()

    def close(self):
        """Close the rendering window."""
        if self.window is not None:
            import pygame

            pygame.display.quit()
            pygame.quit()
        super().close()
+
+
+
class AddWhiteNoise(
    gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
    """Randomly replaces pixels with white noise.

    If used with ``render_mode="rgb_array"`` and ``AddRenderObservation``, it will
    make observations noisy.
    The environment may also become partially-observable, turning the MDP into a POMDP.

    Example - Every pixel will be replaced by white noise with probability 0.5:
        >>> env = gym.make("LunarLander-v3", render_mode="rgb_array")
        >>> env = AddWhiteNoise(env, probability_of_noise_per_pixel=0.5)
        >>> env = HumanRendering(env)
        >>> obs, _ = env.reset(seed=123)
        >>> obs, *_ = env.step(env.action_space.sample())
    """

    def __init__(
        self,
        env: gym.Env[ObsType, ActType],
        probability_of_noise_per_pixel: float,
        is_noise_grayscale: bool = False,
    ):
        """Wrapper replaces random pixels with white noise.

        Args:
            env: The environment that is being wrapped
            probability_of_noise_per_pixel: the probability that a pixel is white noise
            is_noise_grayscale: if True, RGB noise is converted to grayscale

        Raises:
            InvalidProbability: If ``probability_of_noise_per_pixel`` is outside ``[0, 1)``.
        """
        # Note: the valid interval is half-open — a probability of exactly 1 is rejected.
        if not 0 <= probability_of_noise_per_pixel < 1:
            raise InvalidProbability(
                f"probability_of_noise_per_pixel should be in the interval [0,1). Received {probability_of_noise_per_pixel}"
            )

        gym.utils.RecordConstructorArgs.__init__(
            self,
            probability_of_noise_per_pixel=probability_of_noise_per_pixel,
            is_noise_grayscale=is_noise_grayscale,
        )
        gym.Wrapper.__init__(self, env)

        self.probability_of_noise_per_pixel = probability_of_noise_per_pixel
        self.is_noise_grayscale = is_noise_grayscale

    def render(self) -> RenderFrame:
        """Compute the render frames as specified by render_mode attribute during initialization of the environment, then add white noise.

        The output depends on the exact sequence of ``self.np_random`` draws
        below (noise first, then the pixel mask) — reordering them changes the
        result for a given seed.
        """
        # assumes render_out is an (H, W, 3) uint8 array, i.e. render_mode="rgb_array"
        # — TODO confirm; a list-returning mode such as "rgb_array_list" has no `.shape`.
        render_out = super().render()

        if self.is_noise_grayscale:
            # Draw per-channel noise bounded by the ITU-R BT.601 luma weights,
            # then collapse to a single luminance value replicated across 3 channels.
            noise = (
                self.np_random.integers(
                    (0, 0, 0),
                    255 * np.array([0.2989, 0.5870, 0.1140]),
                    size=render_out.shape,
                    dtype=np.uint8,
                )
                .sum(-1, keepdims=True)
                .repeat(3, -1)
            )
        else:
            noise = self.np_random.integers(
                0,
                255,
                size=render_out.shape,
                dtype=np.uint8,
            )

        # One Bernoulli draw per pixel (not per channel): the mask is 2-D.
        mask = (
            self.np_random.random(render_out.shape[0:2])
            < self.probability_of_noise_per_pixel
        )

        # Broadcast the 2-D mask over the channel axis so whole pixels are replaced.
        return np.where(mask[..., None], noise, render_out)
+
+
class ObstructView(
    gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
    """Randomly obstructs rendering with white noise patches.

    If used with ``render_mode="rgb_array"`` and ``AddRenderObservation``, it will
    make observations noisy.
    The number of patches depends on how many pixels we want to obstruct.
    Depending on the size of the patches, the environment may become
    partially-observable, turning the MDP into a POMDP.

    Example - Obstruct 50% of the pixels with patches of size 50x50 pixels:
        >>> env = gym.make("LunarLander-v3", render_mode="rgb_array")
        >>> env = ObstructView(env, obstructed_pixels_ratio=0.5, obstruction_width=50)
        >>> env = HumanRendering(env)
        >>> obs, _ = env.reset(seed=123)
        >>> obs, *_ = env.step(env.action_space.sample())
    """

    def __init__(
        self,
        env: gym.Env[ObsType, ActType],
        obstructed_pixels_ratio: float,
        obstruction_width: int,
        is_noise_grayscale: bool = False,
    ):
        """Wrapper obstructs pixels with white noise patches.

        Args:
            env: The environment that is being wrapped
            obstructed_pixels_ratio: the percentage of pixels obstructed with white noise
            obstruction_width: the width of the obstruction patches
            is_noise_grayscale: if True, RGB noise is converted to grayscale

        Raises:
            ValueError: If ``obstructed_pixels_ratio`` is outside ``[0, 1)`` or
                ``obstruction_width`` is smaller than 1.
        """
        if not 0 <= obstructed_pixels_ratio < 1:
            raise ValueError(
                f"obstructed_pixels_ratio should be in the interval [0,1). Received {obstructed_pixels_ratio}"
            )

        if obstruction_width < 1:
            raise ValueError(
                f"obstruction_width should be larger or equal than 1. Received {obstruction_width}"
            )

        gym.utils.RecordConstructorArgs.__init__(
            self,
            obstructed_pixels_ratio=obstructed_pixels_ratio,
            obstruction_width=obstruction_width,
            is_noise_grayscale=is_noise_grayscale,
        )
        gym.Wrapper.__init__(self, env)

        # Each patch covers obstruction_width**2 pixels, so the density of patch
        # *centers* is the pixel ratio divided by the patch area.
        self.obstruction_centers_ratio = obstructed_pixels_ratio / obstruction_width**2
        self.obstruction_width = obstruction_width
        self.is_noise_grayscale = is_noise_grayscale

    def render(self) -> RenderFrame:
        """Compute the render frames as specified by render_mode attribute during initialization of the environment, then add white noise patches.

        The output depends on the exact sequence of ``self.np_random`` draws
        (patch centers first, then the noise) — reordering them changes the
        result for a given seed.
        """
        # assumes render_out is an (H, W, 3) uint8 array, i.e. render_mode="rgb_array"
        # — TODO confirm; a list-returning mode such as "rgb_array_list" has no `.shape`.
        render_out = super().render()

        render_shape = render_out.shape
        n_pixels = render_shape[0] * render_shape[1]
        n_obstructions = int(n_pixels * self.obstruction_centers_ratio)
        # Centers are drawn with replacement and patches may overlap or be clipped
        # at the borders, so the realised obstructed ratio is approximate.
        centers = self.np_random.integers(0, n_pixels, n_obstructions)
        centers = np.unravel_index(centers, (render_shape[0], render_shape[1]))
        mask = np.zeros((render_shape[0], render_shape[1]), dtype=bool)
        # Split the patch width around the center; for odd widths the extra pixel
        # goes to the high side (low = width // 2, high = width - low).
        low = self.obstruction_width // 2
        high = self.obstruction_width - low
        for x, y in zip(*centers):
            # Clamp each patch to the image bounds.
            mask[
                max(x - low, 0) : min(x + high, render_shape[0]),
                max(y - low, 0) : min(y + high, render_shape[1]),
            ] = True

        if self.is_noise_grayscale:
            # Draw per-channel noise bounded by the ITU-R BT.601 luma weights,
            # then collapse to a single luminance value replicated across 3 channels.
            noise = (
                self.np_random.integers(
                    (0, 0, 0),
                    255 * np.array([0.2989, 0.5870, 0.1140]),
                    size=render_out.shape,
                    dtype=np.uint8,
                )
                .sum(-1, keepdims=True)
                .repeat(3, -1)
            )
        else:
            noise = self.np_random.integers(
                0,
                255,
                size=render_out.shape,
                dtype=np.uint8,
            )

        # Broadcast the 2-D mask over the channel axis so whole pixels are replaced.
        return np.where(mask[..., None], noise, render_out)
+
+"""``StickyAction`` wrapper - There is a probability that the action is taken again."""
+
+from __future__ import annotations
+
+from typing import Any
+
+import numpy as np
+
+import gymnasium as gym
+from gymnasium.core import ActType, ObsType
+from gymnasium.error import InvalidBound, InvalidProbability
+
+
+__all__ = ["StickyAction"]
+
+
+
+[docs]
class StickyAction(
    gym.ActionWrapper[ObsType, ActType, ActType], gym.utils.RecordConstructorArgs
):
    """Adds a probability that the action is repeated for the same ``step`` function.

    This wrapper follows the implementation proposed by `Machado et al., 2018 <https://arxiv.org/pdf/1709.06009.pdf>`_
    in Section 5.2 on page 12, and adds the possibility to repeat the action for
    more than one step.

    No vector version of the wrapper exists.

    Example:
        >>> import gymnasium as gym
        >>> env = gym.make("CartPole-v1")
        >>> env = StickyAction(env, repeat_action_probability=0.9)
        >>> env.reset(seed=123)
        (array([ 0.01823519, -0.0446179 , -0.02796401, -0.03156282], dtype=float32), {})
        >>> env.step(1)
        (array([ 0.01734283,  0.15089367, -0.02859527, -0.33293587], dtype=float32), 1.0, False, False, {})
        >>> env.step(0)
        (array([ 0.0203607 ,  0.34641072, -0.03525399, -0.6344974 ], dtype=float32), 1.0, False, False, {})
        >>> env.step(1)
        (array([ 0.02728892,  0.5420062 , -0.04794393, -0.9380709 ], dtype=float32), 1.0, False, False, {})
        >>> env.step(0)
        (array([ 0.03812904,  0.34756234, -0.06670535, -0.6608303 ], dtype=float32), 1.0, False, False, {})

    Change logs:
     * v1.0.0 - Initially added
     * v1.1.0 - Add `repeat_action_duration` argument for dynamic number of sticky actions
    """

    def __init__(
        self,
        env: gym.Env[ObsType, ActType],
        repeat_action_probability: float,
        repeat_action_duration: int | tuple[int, int] = 1,
    ):
        """Initialize StickyAction wrapper.

        Args:
            env (Env): the wrapped environment,
            repeat_action_probability (int | float): a probability of repeating the old action,
            repeat_action_duration (int | tuple[int, int]): the number of steps
                the action is repeated. It can be either an int (for deterministic
                repeats) or a tuple[int, int] for a range of stochastic number of repeats.

        Raises:
            InvalidProbability: If ``repeat_action_probability`` is outside ``[0, 1)``.
            InvalidBound: If the duration range is inverted (low > high).
            ValueError: If ``repeat_action_duration`` has the wrong type/length or values below 1.
        """
        if not 0 <= repeat_action_probability < 1:
            raise InvalidProbability(
                f"`repeat_action_probability` should be in the interval [0,1). Received {repeat_action_probability}"
            )

        # Normalise a scalar duration to a degenerate (low, high) range.
        if isinstance(repeat_action_duration, int):
            repeat_action_duration = (repeat_action_duration, repeat_action_duration)

        if not isinstance(repeat_action_duration, tuple):
            raise ValueError(
                f"`repeat_action_duration` should be either an integer or a tuple. Received {repeat_action_duration}"
            )
        elif len(repeat_action_duration) != 2:
            raise ValueError(
                f"`repeat_action_duration` should be a tuple or a list of two integers. Received {repeat_action_duration}"
            )
        elif repeat_action_duration[0] > repeat_action_duration[1]:
            raise InvalidBound(
                f"`repeat_action_duration` is not a valid bound. Received {repeat_action_duration}"
            )
        elif np.any(np.array(repeat_action_duration) < 1):
            raise ValueError(
                f"`repeat_action_duration` should be larger or equal than 1. Received {repeat_action_duration}"
            )

        # Fix: also record `repeat_action_duration` (the original recorded only the
        # probability, so reconstructing the wrapper from its spec silently dropped
        # the duration setting).
        gym.utils.RecordConstructorArgs.__init__(
            self,
            repeat_action_probability=repeat_action_probability,
            repeat_action_duration=repeat_action_duration,
        )
        gym.ActionWrapper.__init__(self, env)

        self.repeat_action_probability = repeat_action_probability
        self.repeat_action_duration_range = repeat_action_duration

        self.last_action: ActType | None = None
        self.is_sticky_actions: bool = False  # if sticky actions are taken
        self.num_repeats: int = 0  # number of sticky action repeats
        self.repeats_taken: int = 0  # number of sticky actions taken

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[ObsType, dict[str, Any]]:
        """Reset the environment, clearing any in-progress sticky-action series."""
        self.last_action = None
        self.is_sticky_actions = False
        self.num_repeats = 0
        self.repeats_taken = 0

        return super().reset(seed=seed, options=options)

    def action(self, action: ActType) -> ActType:
        """Execute the action, possibly replacing it with the previous (sticky) action."""
        # either the agent was already "stuck" into repeats, or a new series of repeats is triggered
        if self.is_sticky_actions or (
            self.last_action is not None
            and self.np_random.uniform() < self.repeat_action_probability
        ):
            # if a new series starts, randomly sample its duration
            if self.num_repeats == 0:
                self.num_repeats = self.np_random.integers(
                    self.repeat_action_duration_range[0],
                    self.repeat_action_duration_range[1] + 1,
                )
            action = self.last_action
            self.is_sticky_actions = True
            self.repeats_taken += 1

        # repeats are done, reset "stuck" status
        if self.is_sticky_actions and self.num_repeats == self.repeats_taken:
            self.is_sticky_actions = False
            self.num_repeats = 0
            self.repeats_taken = 0

        self.last_action = action
        return action
+
+
+"""A collection of stateful observation wrappers.
+
+* ``DelayObservation`` - A wrapper for delaying the returned observation
+* ``TimeAwareObservation`` - A wrapper for adding time aware observations to environment observation
+* ``FrameStackObservation`` - Frame stack the observations
+* ``NormalizeObservation`` - Normalizes the observations to have unit variance with a moving mean
+* ``MaxAndSkipObservation`` - Return only every ``skip``-th frame (frameskipping) and return the max between the two last frames.
+"""
+
+from __future__ import annotations
+
+from collections import deque
+from copy import deepcopy
+from typing import Any, Final, SupportsFloat
+
+import numpy as np
+
+import gymnasium as gym
+import gymnasium.spaces as spaces
+from gymnasium.core import ActType, ObsType, WrapperActType, WrapperObsType
+from gymnasium.spaces import Box, Dict, Tuple
+from gymnasium.vector.utils import batch_space, concatenate, create_empty_array
+from gymnasium.wrappers.utils import RunningMeanStd, create_zero_array
+
+
+__all__ = [
+ "DelayObservation",
+ "TimeAwareObservation",
+ "FrameStackObservation",
+ "NormalizeObservation",
+ "MaxAndSkipObservation",
+]
+
+
+
+[docs]
class DelayObservation(
    gym.ObservationWrapper[ObsType, ActType, ObsType], gym.utils.RecordConstructorArgs
):
    """Adds a delay to the returned observation from the environment.

    Before reaching the :attr:`delay` number of timesteps, returned observations is an array of zeros with
    the same shape as the observation space.

    No vector version of the wrapper exists.

    Note:
        This does not support random delay values, if users are interested, please raise an issue or pull request to add this feature.

    Example:
        >>> import gymnasium as gym
        >>> env = gym.make("CartPole-v1")
        >>> env.reset(seed=123)
        (array([ 0.01823519, -0.0446179 , -0.02796401, -0.03156282], dtype=float32), {})

        >>> env = DelayObservation(env, delay=2)
        >>> env.reset(seed=123)
        (array([0., 0., 0., 0.], dtype=float32), {})
        >>> env.step(env.action_space.sample())
        (array([0., 0., 0., 0.], dtype=float32), 1.0, False, False, {})
        >>> env.step(env.action_space.sample())
        (array([ 0.01823519, -0.0446179 , -0.02796401, -0.03156282], dtype=float32), 1.0, False, False, {})

    Change logs:
     * v1.0.0 - Initially added
    """

    def __init__(self, env: gym.Env[ObsType, ActType], delay: int):
        """Initialises the DelayObservation wrapper with an integer.

        Args:
            env: The environment to wrap
            delay: The number of timesteps to delay observations

        Raises:
            TypeError: If ``delay`` is not an integer.
            ValueError: If ``delay`` is negative.
        """
        if not np.issubdtype(type(delay), np.integer):
            raise TypeError(
                f"The delay is expected to be an integer, actual type: {type(delay)}"
            )
        if delay < 0:
            # Fix: the check accepts delay == 0, but the original message claimed the
            # delay must be "greater than zero" — corrected to match the actual bound.
            raise ValueError(
                f"The delay needs to be greater than or equal to zero, actual value: {delay}"
            )

        gym.utils.RecordConstructorArgs.__init__(self, delay=delay)
        gym.ObservationWrapper.__init__(self, env)

        self.delay: Final[int] = int(delay)
        # FIFO buffer of observations not yet released to the agent.
        self.observation_queue: Final[deque] = deque()

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[ObsType, dict[str, Any]]:
        """Resets the environment, clearing the observation queue."""
        self.observation_queue.clear()

        return super().reset(seed=seed, options=options)

    def observation(self, observation: ObsType) -> ObsType:
        """Return the delayed observation.

        Until more than :attr:`delay` observations have been buffered, a zero
        observation of the same structure as the observation space is returned.
        """
        self.observation_queue.append(observation)

        if len(self.observation_queue) > self.delay:
            return self.observation_queue.popleft()
        else:
            return create_zero_array(self.observation_space)
+
+
+
+
+[docs]
+class TimeAwareObservation(
+ gym.ObservationWrapper[WrapperObsType, ActType, ObsType],
+ gym.utils.RecordConstructorArgs,
+):
+ """Augment the observation with the number of time steps taken within an episode.
+
+ The :attr:`normalize_time` if ``True`` represents time as a normalized value between [0,1]
+ otherwise if ``False``, the current timestep is an integer.
+
+ For environments with ``Dict`` observation spaces, the time information is automatically
+ added in the key `"time"` (can be changed through :attr:`dict_time_key`) and for environments with ``Tuple``
+ observation space, the time information is added as the final element in the tuple.
+ Otherwise, the observation space is transformed into a ``Dict`` observation space with two keys,
+ `"obs"` for the base environment's observation and `"time"` for the time information.
+
+ To flatten the observation, use the :attr:`flatten` parameter which will use the
+ :func:`gymnasium.spaces.utils.flatten` function.
+
+ No vector version of the wrapper exists.
+
+ Example:
+ >>> import gymnasium as gym
+ >>> from gymnasium.wrappers import TimeAwareObservation
+ >>> env = gym.make("CartPole-v1")
+ >>> env = TimeAwareObservation(env)
+ >>> env.observation_space
+ Box([-4.80000019 -inf -0.41887903 -inf 0. ], [4.80000019e+00 inf 4.18879032e-01 inf
+ 5.00000000e+02], (5,), float64)
+ >>> env.reset(seed=42)[0]
+ array([ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 , 0. ])
+ >>> _ = env.action_space.seed(42)
+ >>> env.step(env.action_space.sample())[0]
+ array([ 0.02727336, -0.20172954, 0.03625453, 0.32351476, 1. ])
+
+ Normalize time observation space example:
+ >>> env = gym.make('CartPole-v1')
+ >>> env = TimeAwareObservation(env, normalize_time=True)
+ >>> env.observation_space
+ Box([-4.8 -inf -0.41887903 -inf 0. ], [4.8 inf 0.41887903 inf 1. ], (5,), float32)
+ >>> env.reset(seed=42)[0]
+ array([ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 , 0. ],
+ dtype=float32)
+ >>> _ = env.action_space.seed(42)
+ >>> env.step(env.action_space.sample())[0]
+ array([ 0.02727336, -0.20172954, 0.03625453, 0.32351476, 0.002 ],
+ dtype=float32)
+
+ Flatten observation space example:
+ >>> env = gym.make("CartPole-v1")
+ >>> env = TimeAwareObservation(env, flatten=False)
+ >>> env.observation_space
+ Dict('obs': Box([-4.8 -inf -0.41887903 -inf], [4.8 inf 0.41887903 inf], (4,), float32), 'time': Box(0, 500, (1,), int32))
+ >>> env.reset(seed=42)[0]
+ {'obs': array([ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 ], dtype=float32), 'time': array([0], dtype=int32)}
+ >>> _ = env.action_space.seed(42)
+ >>> env.step(env.action_space.sample())[0]
+ {'obs': array([ 0.02727336, -0.20172954, 0.03625453, 0.32351476], dtype=float32), 'time': array([1], dtype=int32)}
+
+ Change logs:
+ * v0.18.0 - Initially added
+ * v1.0.0 - Remove vector environment support, add ``flatten`` and ``normalize_time`` parameters
+ """
+
    def __init__(
        self,
        env: gym.Env[ObsType, ActType],
        flatten: bool = True,
        normalize_time: bool = False,
        *,
        dict_time_key: str = "time",
    ):
        """Initialize :class:`TimeAwareObservation`.

        Args:
            env: The environment to apply the wrapper
            flatten: Flatten the observation to a `Box` of a single dimension
            normalize_time: if `True` return time in the range [0,1],
                otherwise return the elapsed timestep count as an integer
            dict_time_key: For environment with a ``Dict`` observation space, the key for the time space. By default, `"time"`.

        Raises:
            ValueError: If neither ``env.spec.max_episode_steps`` is set nor a
                ``TimeLimit`` wrapper is found in the environment stack.
        """
        gym.utils.RecordConstructorArgs.__init__(
            self,
            flatten=flatten,
            normalize_time=normalize_time,
            dict_time_key=dict_time_key,
        )
        gym.ObservationWrapper.__init__(self, env)

        self.flatten: Final[bool] = flatten
        self.normalize_time: Final[bool] = normalize_time

        # We don't need to keep if a TimeLimit wrapper exists as `spec` will do that work for us now
        if env.spec is not None and env.spec.max_episode_steps is not None:
            self.max_timesteps = env.spec.max_episode_steps
        else:
            # else we need to loop through the environment stack to check if a `TimeLimit` wrapper exists
            wrapped_env = env
            while isinstance(wrapped_env, gym.Wrapper):
                if isinstance(wrapped_env, gym.wrappers.TimeLimit):
                    self.max_timesteps = wrapped_env._max_episode_steps
                    break
                wrapped_env = wrapped_env.env

            if not isinstance(wrapped_env, gym.wrappers.TimeLimit):
                raise ValueError(
                    "The environment must be wrapped by a TimeLimit wrapper or the spec specify a `max_episode_steps`."
                )

        # Elapsed steps in the current episode; incremented by `step`.
        self.timesteps: int = 0

        # Find the normalized time space
        if self.normalize_time:
            # Normalized mode: time is the fraction of the episode elapsed, in [0, 1].
            self._time_preprocess_func = lambda time: np.array(
                [time / self.max_timesteps], dtype=np.float32
            )
            time_space = Box(0.0, 1.0)
        else:
            # Raw mode: time is the elapsed step count as an int32.
            self._time_preprocess_func = lambda time: np.array([time], dtype=np.int32)
            time_space = Box(0, self.max_timesteps, dtype=np.int32)

        # Find the observation space
        if isinstance(env.observation_space, Dict):
            assert dict_time_key not in env.observation_space.keys()
            observation_space = Dict(
                {dict_time_key: time_space, **env.observation_space.spaces}
            )
            self._append_data_func = lambda obs, time: {dict_time_key: time, **obs}
        elif isinstance(env.observation_space, Tuple):
            # Tuple spaces get the time appended as the final element.
            observation_space = Tuple(env.observation_space.spaces + (time_space,))
            self._append_data_func = lambda obs, time: obs + (time,)
        else:
            # Any other space is wrapped into a Dict with "obs" and "time" keys.
            observation_space = Dict(obs=env.observation_space, time=time_space)
            self._append_data_func = lambda obs, time: {"obs": obs, "time": time}

        # If to flatten the observation space
        if self.flatten:
            self.observation_space: gym.Space[WrapperObsType] = spaces.flatten_space(
                observation_space
            )
            self._obs_postprocess_func = lambda obs: spaces.flatten(
                observation_space, obs
            )
        else:
            self.observation_space: gym.Space[WrapperObsType] = observation_space
            self._obs_postprocess_func = lambda obs: obs
+
+ def observation(self, observation: ObsType) -> WrapperObsType:
+ """Adds to the observation with the current time information.
+
+ Args:
+ observation: The observation to add the time step to
+
+ Returns:
+ The observation with the time information appended to it
+ """
+ return self._obs_postprocess_func(
+ self._append_data_func(
+ observation, self._time_preprocess_func(self.timesteps)
+ )
+ )
+
+ def step(
+ self, action: ActType
+ ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
+ """Steps through the environment, incrementing the time step.
+
+ Args:
+ action: The action to take
+
+ Returns:
+ The environment's step using the action with the next observation containing the timestep info
+ """
+ self.timesteps += 1
+
+ return super().step(action)
+
+ def reset(
+ self, *, seed: int | None = None, options: dict[str, Any] | None = None
+ ) -> tuple[WrapperObsType, dict[str, Any]]:
+ """Reset the environment setting the time to zero.
+
+ Args:
+ seed: The seed to reset the environment
+ options: The options used to reset the environment
+
+ Returns:
+ Resets the environment with the initial timestep info added the observation
+ """
+ self.timesteps = 0
+
+ return super().reset(seed=seed, options=options)
+
+
+
+
+[docs]
class FrameStackObservation(
    gym.Wrapper[WrapperObsType, ActType, ObsType, ActType],
    gym.utils.RecordConstructorArgs,
):
    """Stacks the observations from the last ``N`` time steps in a rolling manner.

    For example, if the number of stacks is 4, then the returned observation contains
    the most recent 4 observations. For environment 'Pendulum-v1', the original observation
    is an array with shape [3], so if we stack 4 observations, the processed observation
    has shape [4, 3].

    Users have options for the padded observation used:

    * "reset" (default) - The reset value is repeated
    * "zero" - A "zero"-like instance of the observation space
    * custom - An instance of the observation space

    No vector version of the wrapper exists.

    Example:
        >>> import gymnasium as gym
        >>> from gymnasium.wrappers import FrameStackObservation
        >>> env = gym.make("CarRacing-v3")
        >>> env = FrameStackObservation(env, stack_size=4)
        >>> env.observation_space
        Box(0, 255, (4, 96, 96, 3), uint8)
        >>> obs, _ = env.reset()
        >>> obs.shape
        (4, 96, 96, 3)

    Example with different padding observations:
        >>> env = gym.make("CartPole-v1")
        >>> env.reset(seed=123)
        (array([ 0.01823519, -0.0446179 , -0.02796401, -0.03156282], dtype=float32), {})
        >>> stacked_env = FrameStackObservation(env, 3)   # the default is padding_type="reset"
        >>> stacked_env.reset(seed=123)
        (array([[ 0.01823519, -0.0446179 , -0.02796401, -0.03156282],
               [ 0.01823519, -0.0446179 , -0.02796401, -0.03156282],
               [ 0.01823519, -0.0446179 , -0.02796401, -0.03156282]],
              dtype=float32), {})


        >>> stacked_env = FrameStackObservation(env, 3, padding_type="zero")
        >>> stacked_env.reset(seed=123)
        (array([[ 0.        ,  0.        ,  0.        ,  0.        ],
               [ 0.        ,  0.        ,  0.        ,  0.        ],
               [ 0.01823519, -0.0446179 , -0.02796401, -0.03156282]],
              dtype=float32), {})
        >>> stacked_env = FrameStackObservation(env, 3, padding_type=np.array([1, -1, 0, 2], dtype=np.float32))
        >>> stacked_env.reset(seed=123)
        (array([[ 1.        , -1.        ,  0.        ,  2.        ],
               [ 1.        , -1.        ,  0.        ,  2.        ],
               [ 0.01823519, -0.0446179 , -0.02796401, -0.03156282]],
              dtype=float32), {})

    Change logs:
     * v0.15.0 - Initially add as ``FrameStack`` with support for lz4
     * v1.0.0 - Rename to ``FrameStackObservation`` and remove lz4 and ``LazyFrame`` support
                along with adding the ``padding_type`` parameter

    """

    def __init__(
        self,
        env: gym.Env[ObsType, ActType],
        stack_size: int,
        *,
        padding_type: str | ObsType = "reset",
    ):
        """Observation wrapper that stacks the observations in a rolling manner.

        Args:
            env: The environment to apply the wrapper
            stack_size: The number of frames to stack.
            padding_type: The padding type to use when stacking the observations, options: "reset", "zero", custom obs

        Raises:
            TypeError: If ``stack_size`` is not an integer.
            ValueError: If ``stack_size`` is not positive, or ``padding_type`` is neither
                ``"reset"``, ``"zero"`` nor a sample from the environment's observation space.
        """
        gym.utils.RecordConstructorArgs.__init__(
            self, stack_size=stack_size, padding_type=padding_type
        )
        gym.Wrapper.__init__(self, env)

        if not np.issubdtype(type(stack_size), np.integer):
            raise TypeError(
                f"The stack_size is expected to be an integer, actual type: {type(stack_size)}"
            )
        if not 0 < stack_size:
            raise ValueError(
                f"The stack_size needs to be greater than zero, actual value: {stack_size}"
            )
        if isinstance(padding_type, str) and (
            padding_type == "reset" or padding_type == "zero"
        ):
            # For "reset" padding the zero array is only a placeholder; it is
            # replaced with the actual reset observation in `reset`.
            self.padding_value: ObsType = create_zero_array(env.observation_space)
        elif padding_type in env.observation_space:
            # A concrete observation instance was supplied as padding; keep it
            # and normalise the type label to the internal "_custom" marker.
            self.padding_value = padding_type
            padding_type = "_custom"
        else:
            if isinstance(padding_type, str):
                raise ValueError(  # we are guessing that the user just entered the "reset" or "zero" wrong
                    f"Unexpected `padding_type`, expected 'reset', 'zero' or a custom observation space, actual value: {padding_type!r}"
                )
            else:
                raise ValueError(
                    f"Unexpected `padding_type`, expected 'reset', 'zero' or a custom observation space, actual value: {padding_type!r} not an instance of env observation ({env.observation_space})"
                )

        # The stacked space is the environment's space batched `stack_size` times.
        self.observation_space = batch_space(env.observation_space, n=stack_size)
        self.stack_size: Final[int] = stack_size
        self.padding_type: Final[str] = padding_type

        # Rolling window of the most recent observations, pre-filled with padding;
        # `maxlen` makes appends drop the oldest entry automatically.
        self.obs_queue = deque(
            [self.padding_value for _ in range(self.stack_size)], maxlen=self.stack_size
        )
        # Reusable output buffer that `concatenate` writes into on every call.
        self.stacked_obs = create_empty_array(env.observation_space, n=self.stack_size)

    def step(
        self, action: WrapperActType
    ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
        """Steps through the environment, appending the observation to the frame buffer.

        Args:
            action: The action to step through the environment with

        Returns:
            Stacked observations, reward, terminated, truncated, and info from the environment
        """
        obs, reward, terminated, truncated, info = self.env.step(action)
        self.obs_queue.append(obs)

        # Deep-copy because `self.stacked_obs` is a shared output buffer: without
        # the copy, observations returned on earlier calls would be mutated in place.
        updated_obs = deepcopy(
            concatenate(self.env.observation_space, self.obs_queue, self.stacked_obs)
        )
        return updated_obs, reward, terminated, truncated, info

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[WrapperObsType, dict[str, Any]]:
        """Reset the environment, returning the stacked observation and info.

        Args:
            seed: The environment seed
            options: The reset options

        Returns:
            The stacked observations and info
        """
        obs, info = self.env.reset(seed=seed, options=options)

        # With "reset" padding, the freshly reset observation becomes the pad value.
        if self.padding_type == "reset":
            self.padding_value = obs
        # Fill the window with padding, then append the real first observation.
        for _ in range(self.stack_size - 1):
            self.obs_queue.append(self.padding_value)
        self.obs_queue.append(obs)

        # See `step` for why the concatenated buffer must be deep-copied.
        updated_obs = deepcopy(
            concatenate(self.env.observation_space, self.obs_queue, self.stacked_obs)
        )
        return updated_obs, info
+
+
+
+
+[docs]
class NormalizeObservation(
    gym.ObservationWrapper[WrapperObsType, ActType, ObsType],
    gym.utils.RecordConstructorArgs,
):
    """Normalizes observations to be centered at the mean with unit variance.

    The property :attr:`update_running_mean` allows freezing or resuming the running
    calculation of the observation statistics. When ``True`` (default), the
    ``RunningMeanStd`` is updated on every ``step``/``reset`` call; when ``False``,
    the stored statistics are applied without being updated, e.g. during evaluation.

    A vector version of the wrapper exists :class:`gymnasium.wrappers.vector.NormalizeObservation`.

    Note:
        The normalization depends on past trajectories and observations will not be normalized correctly if the wrapper was
        newly instantiated or the policy was changed recently.

    Example:
        >>> import numpy as np
        >>> import gymnasium as gym
        >>> env = gym.make("CartPole-v1")
        >>> obs, info = env.reset(seed=123)
        >>> term, trunc = False, False
        >>> while not (term or trunc):
        ...     obs, _, term, trunc, _ = env.step(1)
        ...
        >>> obs
        array([ 0.1511158 ,  1.7183299 , -0.25533703, -2.8914354 ], dtype=float32)
        >>> env = gym.make("CartPole-v1")
        >>> env = NormalizeObservation(env)
        >>> obs, info = env.reset(seed=123)
        >>> term, trunc = False, False
        >>> while not (term or trunc):
        ...     obs, _, term, trunc, _ = env.step(1)
        >>> obs
        array([ 2.0059888,  1.5676788, -1.9944268, -1.6120394], dtype=float32)

    Change logs:
     * v0.21.0 - Initially add
     * v1.0.0 - Add `update_running_mean` attribute to allow disabling of updating the running mean / standard, particularly useful for evaluation time.
        Casts all observations to `np.float32` and sets the observation space with low/high of `-np.inf` and `np.inf` and dtype as `np.float32`
    """

    def __init__(self, env: gym.Env[ObsType, ActType], epsilon: float = 1e-8):
        """This wrapper will normalize observations such that each observation is centered with unit variance.

        Args:
            env (Env): The environment to apply the wrapper
            epsilon: A stability parameter that is used when scaling the observations.
        """
        gym.utils.RecordConstructorArgs.__init__(self, epsilon=epsilon)
        gym.ObservationWrapper.__init__(self, env)

        assert env.observation_space.shape is not None
        # Normalized observations are unbounded float32, regardless of the
        # original space's bounds or dtype.
        self.observation_space = gym.spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=env.observation_space.shape,
            dtype=np.float32,
        )

        # Running estimate of the per-component observation mean/variance.
        self.obs_rms = RunningMeanStd(
            shape=self.observation_space.shape, dtype=self.observation_space.dtype
        )
        self.epsilon = epsilon
        self._update_running_mean = True

    @property
    def update_running_mean(self) -> bool:
        """Whether the running observation statistics are currently being updated."""
        return self._update_running_mean

    @update_running_mean.setter
    def update_running_mean(self, setting: bool):
        """Freeze (``False``) or resume (``True``) updating the running observation statistics."""
        self._update_running_mean = setting

    def observation(self, observation: ObsType) -> WrapperObsType:
        """Normalises the observation using the running mean and variance of the observations."""
        if self._update_running_mean:
            # The update expects a batch dimension, hence the wrapping array.
            self.obs_rms.update(np.array([observation]))
        centered = observation - self.obs_rms.mean
        scale = np.sqrt(self.obs_rms.var + self.epsilon)
        return np.float32(centered / scale)
+
+
+
+
+[docs]
class MaxAndSkipObservation(
    gym.Wrapper[WrapperObsType, ActType, ObsType, ActType],
    gym.utils.RecordConstructorArgs,
):
    """Skips the N-th frame (observation) and return the max values between the two last observations.

    No vector version of the wrapper exists.

    Note:
        This wrapper is based on the wrapper from [stable-baselines3](https://stable-baselines3.readthedocs.io/en/master/_modules/stable_baselines3/common/atari_wrappers.html#MaxAndSkipEnv)

    Example:
        >>> import gymnasium as gym
        >>> env = gym.make("CartPole-v1")
        >>> obs0, *_ = env.reset(seed=123)
        >>> obs1, *_ = env.step(1)
        >>> obs2, *_ = env.step(1)
        >>> obs3, *_ = env.step(1)
        >>> obs4, *_ = env.step(1)
        >>> skip_and_max_obs = np.max(np.stack([obs3, obs4], axis=0), axis=0)
        >>> env = gym.make("CartPole-v1")
        >>> wrapped_env = MaxAndSkipObservation(env)
        >>> wrapped_obs0, *_ = wrapped_env.reset(seed=123)
        >>> wrapped_obs1, *_ = wrapped_env.step(1)
        >>> np.all(obs0 == wrapped_obs0)
        np.True_
        >>> np.all(wrapped_obs1 == skip_and_max_obs)
        np.True_

    Change logs:
     * v1.0.0 - Initially add
    """

    def __init__(self, env: gym.Env[ObsType, ActType], skip: int = 4):
        """This wrapper will return only every ``skip``-th frame (frameskipping) and return the max between the two last frames.

        Args:
            env (Env): The environment to apply the wrapper
            skip: The number of frames to skip

        Raises:
            TypeError: If ``skip`` is not an integer.
            ValueError: If ``skip`` is less than two, or the observation space has no shape.
        """
        gym.utils.RecordConstructorArgs.__init__(self, skip=skip)
        gym.Wrapper.__init__(self, env)

        if not np.issubdtype(type(skip), np.integer):
            raise TypeError(
                f"The skip is expected to be an integer, actual type: {type(skip)}"
            )
        if skip < 2:
            raise ValueError(
                f"The skip value needs to be equal or greater than two, actual value: {skip}"
            )
        if env.observation_space.shape is None:
            raise ValueError("The observation space must have the shape attribute.")

        self._skip = skip
        # Buffer for the last two raw observations of each skip window;
        # index 0 holds the second-to-last frame, index 1 the last.
        self._obs_buffer = np.zeros(
            (2, *env.observation_space.shape), dtype=env.observation_space.dtype
        )

    def step(
        self, action: WrapperActType
    ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
        """Step the environment with the given action for ``skip`` steps.

        Repeat action, sum reward, and max over last observations.

        Args:
            action: The action to step through the environment with
        Returns:
            Max of the last two observations, reward, terminated, truncated, and info from the environment
        """
        total_reward = 0.0
        terminated = truncated = False
        info = {}
        for i in range(self._skip):
            obs, reward, terminated, truncated, info = self.env.step(action)
            # Only the final two frames of the window are recorded.
            if i == self._skip - 2:
                self._obs_buffer[0] = obs
            if i == self._skip - 1:
                self._obs_buffer[1] = obs
            total_reward += float(reward)
            # NOTE(review): if the episode ends before the last two frames are
            # reached, the buffer may still contain frames from an earlier call —
            # this mirrors the stable-baselines3 wrapper; confirm it is intentional.
            if terminated or truncated:
                break
        # Element-wise maximum over the two recorded frames.
        max_frame = np.max(self._obs_buffer, axis=0)

        return max_frame, total_reward, terminated, truncated, info
+
+
+"""A collection of wrappers for modifying the reward with an internal state.
+
+* ``NormalizeReward`` - Normalizes the rewards to a mean and standard deviation
+"""
+
+from __future__ import annotations
+
+from typing import Any, SupportsFloat
+
+import numpy as np
+
+import gymnasium as gym
+from gymnasium.core import ActType, ObsType
+from gymnasium.wrappers.utils import RunningMeanStd
+
+
+__all__ = ["NormalizeReward"]
+
+
+
+[docs]
class NormalizeReward(
    gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
    r"""Normalizes immediate rewards such that their exponential moving average has an approximately fixed variance.

    The property :attr:`update_running_mean` allows freezing or resuming the running
    calculation of the reward statistics. If ``True`` (default), the ``RunningMeanStd``
    is updated every time ``step`` is called. If ``False``, the calculated statistics
    are used but not updated anymore; this may be used during evaluation.

    A vector version of the wrapper exists :class:`gymnasium.wrappers.vector.NormalizeReward`.

    Note:
        In v0.27, NormalizeReward was updated as the forward discounted reward estimate was incorrectly computed in Gym v0.25+.
        For more detail, read [#3152](https://github.com/openai/gym/pull/3152).

    Note:
        The scaling depends on past trajectories and rewards will not be scaled correctly if the wrapper was newly
        instantiated or the policy was changed recently.

    Example without the normalize reward wrapper:
        >>> import numpy as np
        >>> import gymnasium as gym
        >>> env = gym.make("MountainCarContinuous-v0")
        >>> _ = env.reset(seed=123)
        >>> _ = env.action_space.seed(123)
        >>> episode_rewards = []
        >>> terminated, truncated = False, False
        >>> while not (terminated or truncated):
        ...     observation, reward, terminated, truncated, info = env.step(env.action_space.sample())
        ...     episode_rewards.append(reward)
        ...
        >>> env.close()
        >>> np.var(episode_rewards)
        np.float64(0.0008876301247721108)

    Example with the normalize reward wrapper:
        >>> import numpy as np
        >>> import gymnasium as gym
        >>> env = gym.make("MountainCarContinuous-v0")
        >>> env = NormalizeReward(env, gamma=0.99, epsilon=1e-8)
        >>> _ = env.reset(seed=123)
        >>> _ = env.action_space.seed(123)
        >>> episode_rewards = []
        >>> terminated, truncated = False, False
        >>> while not (terminated or truncated):
        ...     observation, reward, terminated, truncated, info = env.step(env.action_space.sample())
        ...     episode_rewards.append(reward)
        ...
        >>> env.close()
        >>> np.var(episode_rewards)
        np.float64(0.010162116476634746)

    Change logs:
     * v0.21.0 - Initially added
     * v1.0.0 - Add `update_running_mean` attribute to allow disabling of updating the running mean / standard
    """

    def __init__(
        self,
        env: gym.Env[ObsType, ActType],
        gamma: float = 0.99,
        epsilon: float = 1e-8,
    ):
        """This wrapper will normalize immediate rewards s.t. their exponential moving average has an approximately fixed variance.

        Args:
            env (env): The environment to apply the wrapper
            epsilon (float): A stability parameter
            gamma (float): The discount factor that is used in the exponential moving average.
        """
        gym.utils.RecordConstructorArgs.__init__(self, gamma=gamma, epsilon=epsilon)
        gym.Wrapper.__init__(self, env)

        # Running variance of the discounted return (scalar statistics).
        self.return_rms = RunningMeanStd(shape=())
        # NOTE(review): the running discounted return is zeroed only at
        # construction and on termination, never by `reset` — confirm intentional.
        self.discounted_reward = np.array([0.0])
        self.gamma = gamma
        self.epsilon = epsilon
        self._update_running_mean = True

    @property
    def update_running_mean(self) -> bool:
        """Property to freeze/continue the running mean calculation of the reward statistics."""
        return self._update_running_mean

    @update_running_mean.setter
    def update_running_mean(self, setting: bool):
        """Sets the property to freeze/continue the running mean calculation of the reward statistics."""
        self._update_running_mean = setting

    def step(
        self, action: ActType
    ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
        """Steps through the environment, normalizing the reward returned."""
        obs, reward, terminated, truncated, info = super().step(action)

        # The variance is estimated on the *discounted return* (zeroed on
        # termination), not the raw reward; kept for backward compatibility.
        self.discounted_reward = self.discounted_reward * self.gamma * (
            1 - terminated
        ) + float(reward)
        if self._update_running_mean:
            self.return_rms.update(self.discounted_reward)

        # We don't (reward - self.return_rms.mean) see https://github.com/openai/baselines/issues/538
        normalized_reward = reward / np.sqrt(self.return_rms.var + self.epsilon)
        return obs, normalized_reward, terminated, truncated, info
+
+
+"""A collection of wrappers that all use the LambdaAction class.
+
+* ``TransformAction`` - Transforms the actions based on a function
+* ``ClipAction`` - Clips the action within a bounds
+* ``RescaleAction`` - Rescales the action within a minimum and maximum actions
+"""
+
+from __future__ import annotations
+
+from collections.abc import Callable
+
+import numpy as np
+
+import gymnasium as gym
+from gymnasium.core import ActType, ObsType, WrapperActType
+from gymnasium.spaces import Box, Space
+
+
+__all__ = ["TransformAction", "ClipAction", "RescaleAction"]
+
+from gymnasium.wrappers.utils import rescale_box
+
+
+
+[docs]
class TransformAction(
    gym.ActionWrapper[ObsType, WrapperActType, ActType], gym.utils.RecordConstructorArgs
):
    """Applies a user-supplied function to every ``action`` before it reaches the wrapped environment's ``step``.

    A vector version of the wrapper exists :class:`gymnasium.wrappers.vector.TransformAction`.

    Example:
        >>> import numpy as np
        >>> import gymnasium as gym
        >>> env = gym.make("MountainCarContinuous-v0")
        >>> _ = env.reset(seed=123)
        >>> obs, *_= env.step(np.array([0.0, 1.0]))
        >>> obs
        array([-4.6397772e-01, -4.4808415e-04], dtype=float32)
        >>> env = gym.make("MountainCarContinuous-v0")
        >>> env = TransformAction(env, lambda a: 0.5 * a + 0.1, env.action_space)
        >>> _ = env.reset(seed=123)
        >>> obs, *_= env.step(np.array([0.0, 1.0]))
        >>> obs
        array([-4.6382770e-01, -2.9808417e-04], dtype=float32)

    Change logs:
     * v1.0.0 - Initially added
    """

    def __init__(
        self,
        env: gym.Env[ObsType, ActType],
        func: Callable[[WrapperActType], ActType],
        action_space: Space[WrapperActType] | None,
    ):
        """Initialize TransformAction.

        Args:
            env: The environment to wrap
            func: Function applied to every action passed to :meth:`step`
            action_space: The action space of the wrapper; ``None`` keeps the environment's own space.
        """
        gym.utils.RecordConstructorArgs.__init__(
            self, func=func, action_space=action_space
        )
        gym.Wrapper.__init__(self, env)

        self.func = func
        # Only override the inherited action space when an updated one is supplied.
        if action_space is not None:
            self.action_space = action_space

    def action(self, action: WrapperActType) -> ActType:
        """Return the stored function applied to ``action``."""
        return self.func(action)
+
+
+
+
+[docs]
class ClipAction(
    TransformAction[ObsType, WrapperActType, ActType], gym.utils.RecordConstructorArgs
):
    """Clips the ``action`` passed to ``step`` to be within the environment's `action_space`.

    A vector version of the wrapper exists :class:`gymnasium.wrappers.vector.ClipAction`.

    Example:
        >>> import gymnasium as gym
        >>> from gymnasium.wrappers import ClipAction
        >>> import numpy as np
        >>> env = gym.make("Hopper-v4", disable_env_checker=True)
        >>> env = ClipAction(env)
        >>> env.action_space
        Box(-inf, inf, (3,), float32)
        >>> _ = env.reset(seed=42)
        >>> _ = env.step(np.array([5.0, -2.0, 0.0], dtype=np.float32))
        ... # Executes the action np.array([1.0, -1.0, 0]) in the base environment

    Change logs:
     * v0.12.6 - Initially added
     * v1.0.0 - Action space is updated to infinite bounds as is technically correct
    """

    def __init__(self, env: gym.Env[ObsType, ActType]):
        """A wrapper for clipping continuous actions within the valid bound.

        Args:
            env: The environment to wrap
        """
        assert isinstance(env.action_space, Box)

        gym.utils.RecordConstructorArgs.__init__(self)

        # The wrapper advertises unbounded actions and clips them back into the
        # environment's real bounds before they reach the environment.
        unbounded_space = Box(
            -np.inf,
            np.inf,
            shape=env.action_space.shape,
            dtype=env.action_space.dtype,
        )
        TransformAction.__init__(
            self,
            env=env,
            func=lambda action: np.clip(
                action, env.action_space.low, env.action_space.high
            ),
            action_space=unbounded_space,
        )
+
+
+
+
+[docs]
class RescaleAction(
    TransformAction[ObsType, WrapperActType, ActType], gym.utils.RecordConstructorArgs
):
    """Affinely (linearly) rescales a ``Box`` action space of the environment to within the range of ``[min_action, max_action]``.

    The base environment :attr:`env` must have an action space of type :class:`spaces.Box`. If :attr:`min_action`
    or :attr:`max_action` are numpy arrays, the shape must match the shape of the environment's action space.

    A vector version of the wrapper exists :class:`gymnasium.wrappers.vector.RescaleAction`.

    Example:
        >>> import gymnasium as gym
        >>> from gymnasium.wrappers import RescaleAction
        >>> import numpy as np
        >>> env = gym.make("Hopper-v4", disable_env_checker=True)
        >>> _ = env.reset(seed=42)
        >>> obs, _, _, _, _ = env.step(np.array([1, 1, 1], dtype=np.float32))
        >>> _ = env.reset(seed=42)
        >>> min_action = -0.5
        >>> max_action = np.array([0.0, 0.5, 0.75], dtype=np.float32)
        >>> wrapped_env = RescaleAction(env, min_action=min_action, max_action=max_action)
        >>> wrapped_env_obs, _, _, _, _ = wrapped_env.step(max_action)
        >>> np.all(obs == wrapped_env_obs)
        np.True_

    Change logs:
     * v0.15.4 - Initially added
    """

    def __init__(
        self,
        env: gym.Env[ObsType, ActType],
        min_action: np.floating | np.integer | np.ndarray,
        max_action: np.floating | np.integer | np.ndarray,
    ):
        """Constructor for the Rescale Action wrapper.

        Args:
            env (Env): The environment to wrap
            min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar.
            max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar.
        """
        assert isinstance(env.action_space, Box)

        gym.utils.RecordConstructorArgs.__init__(
            self, min_action=min_action, max_action=max_action
        )

        # `rescale_box` returns the rescaled space, a forward transform (unused
        # here) and the inverse transform that maps wrapper actions back into
        # the environment's original bounds.
        rescaled_space, _, to_env_action = rescale_box(
            env.action_space, min_action, max_action
        )
        TransformAction.__init__(
            self,
            env=env,
            func=to_env_action,
            action_space=rescaled_space,
        )
+
+
+"""A collection of observation wrappers using a lambda function.
+
+* ``TransformObservation`` - Transforms the observation with a function
+* ``FilterObservation`` - Filters a ``Tuple`` or ``Dict`` to only include certain keys
+* ``FlattenObservation`` - Flattens the observations
* ``GrayscaleObservation`` - Converts an RGB observation to a grayscale observation
+* ``ResizeObservation`` - Resizes an array-based observation (normally a RGB observation)
+* ``ReshapeObservation`` - Reshapes an array-based observation
+* ``RescaleObservation`` - Rescales an observation to between a minimum and maximum value
+* ``DtypeObservation`` - Convert an observation to a dtype
* ``AddRenderObservation`` - Adds the rendered frame to the observation
+"""
+
+from __future__ import annotations
+
+from collections.abc import Callable, Sequence
+from typing import Any, Final
+
+import numpy as np
+
+import gymnasium as gym
+from gymnasium import spaces
+from gymnasium.core import ActType, ObsType, WrapperObsType
+from gymnasium.error import DependencyNotInstalled
+
+
+__all__ = [
+ "TransformObservation",
+ "FilterObservation",
+ "FlattenObservation",
+ "GrayscaleObservation",
+ "ResizeObservation",
+ "ReshapeObservation",
+ "RescaleObservation",
+ "DtypeObservation",
+ "AddRenderObservation",
+]
+
+from gymnasium.wrappers.utils import rescale_box
+
+
+
+[docs]
class TransformObservation(
    gym.ObservationWrapper[WrapperObsType, ActType, ObsType],
    gym.utils.RecordConstructorArgs,
):
    """Applies a function to the ``observation`` received from the environment's :meth:`Env.reset` and :meth:`Env.step` that is passed back to the user.

    The function :attr:`func` is applied to every observation produced by the
    wrapped environment. Provide an updated :attr:`observation_space` whenever the
    transformed observations no longer fit the bounds of ``env``'s space.

    A vector version of the wrapper exists :class:`gymnasium.wrappers.vector.TransformObservation`.

    Example:
        >>> import gymnasium as gym
        >>> from gymnasium.wrappers import TransformObservation
        >>> import numpy as np
        >>> np.random.seed(0)
        >>> env = gym.make("CartPole-v1")
        >>> env.reset(seed=42)
        (array([ 0.0273956 , -0.00611216,  0.03585979,  0.0197368 ], dtype=float32), {})
        >>> env = gym.make("CartPole-v1")
        >>> env = TransformObservation(env, lambda obs: obs + 0.1 * np.random.random(obs.shape), env.observation_space)
        >>> env.reset(seed=42)
        (array([0.08227695, 0.06540678, 0.09613613, 0.07422512]), {})

    Change logs:
     * v0.15.4 - Initially added
     * v1.0.0 - Add requirement of ``observation_space``
    """

    def __init__(
        self,
        env: gym.Env[ObsType, ActType],
        func: Callable[[ObsType], Any],
        observation_space: gym.Space[WrapperObsType] | None,
    ):
        """Constructor for the transform observation wrapper.

        Args:
            env: The environment to wrap
            func: A function that will transform an observation. If this transformed observation is outside the observation space of ``env.observation_space`` then provide an `observation_space`.
            observation_space: The observation spaces of the wrapper; if ``None``, ``env.observation_space`` is kept.
        """
        gym.utils.RecordConstructorArgs.__init__(
            self, func=func, observation_space=observation_space
        )
        gym.ObservationWrapper.__init__(self, env)

        self.func = func
        # Only override the inherited observation space when one is supplied.
        if observation_space is not None:
            self.observation_space = observation_space

    def observation(self, observation: ObsType) -> Any:
        """Return the stored function applied to ``observation``."""
        return self.func(observation)
+
+
+
+
+[docs]
class FilterObservation(
    TransformObservation[WrapperObsType, ActType, ObsType],
    gym.utils.RecordConstructorArgs,
):
    """Filters a Dict or Tuple observation spaces by a set of keys or indexes.

    A vector version of the wrapper exists :class:`gymnasium.wrappers.vector.FilterObservation`.

    Example:
        >>> import gymnasium as gym
        >>> from gymnasium.wrappers import FilterObservation
        >>> env = gym.make("CartPole-v1")
        >>> env = gym.wrappers.TimeAwareObservation(env, flatten=False)
        >>> env.observation_space
        Dict('obs': Box([-4.8               -inf -0.41887903        -inf], [4.8               inf 0.41887903        inf], (4,), float32), 'time': Box(0, 500, (1,), int32))
        >>> env.reset(seed=42)
        ({'obs': array([ 0.0273956 , -0.00611216,  0.03585979,  0.0197368 ], dtype=float32), 'time': array([0], dtype=int32)}, {})
        >>> env = FilterObservation(env, filter_keys=['time'])
        >>> env.reset(seed=42)
        ({'time': array([0], dtype=int32)}, {})
        >>> env.step(0)
        ({'time': array([1], dtype=int32)}, 1.0, False, False, {})

    Change logs:
     * v0.12.3 - Initially added, originally called `FilterObservationWrapper`
     * v1.0.0 - Rename to `FilterObservation` and add support for tuple observation spaces with integer ``filter_keys``
    """

    def __init__(
        self, env: gym.Env[ObsType, ActType], filter_keys: Sequence[str | int]
    ):
        """Constructor for the filter observation wrapper.

        Args:
            env: The environment to wrap
            filter_keys: The set of subspaces to be *included*, use a list of strings for ``Dict`` and integers for ``Tuple`` spaces

        Raises:
            TypeError: If ``filter_keys`` is not a sequence.
            ValueError: If a key/index is not part of the observation space, or if
                filtering would remove every subspace, or the observation space is
                neither ``Dict`` nor ``Tuple``.
        """
        if not isinstance(filter_keys, Sequence):
            raise TypeError(
                f"Expects `filter_keys` to be a Sequence, actual type: {type(filter_keys)}"
            )
        gym.utils.RecordConstructorArgs.__init__(self, filter_keys=filter_keys)

        # Filters for dictionary space
        if isinstance(env.observation_space, spaces.Dict):
            assert all(isinstance(key, str) for key in filter_keys)

            missing_keys = [
                key
                for key in filter_keys
                if key not in env.observation_space.spaces
            ]
            if missing_keys:
                raise ValueError(
                    "All the `filter_keys` must be included in the observation space.\n"
                    f"Filter keys: {filter_keys}\n"
                    f"Observation keys: {list(env.observation_space.spaces.keys())}\n"
                    f"Missing keys: {missing_keys}"
                )

            new_observation_space = spaces.Dict(
                {key: env.observation_space[key] for key in filter_keys}
            )
            if len(new_observation_space) == 0:
                raise ValueError(
                    "The observation space is empty due to filtering all of the keys."
                )

            TransformObservation.__init__(
                self,
                env=env,
                func=lambda obs: {key: obs[key] for key in filter_keys},
                observation_space=new_observation_space,
            )
        # Filter for tuple observation
        elif isinstance(env.observation_space, spaces.Tuple):
            assert all(isinstance(key, int) for key in filter_keys)
            assert len(set(filter_keys)) == len(
                filter_keys
            ), f"Duplicate keys exist, filter_keys: {filter_keys}"

            num_subspaces = len(env.observation_space)
            # Bug fix: the previous check (`0 < key and key >= len(...)`) only
            # flagged *positive* out-of-range indexes, silently accepting any
            # negative index — even ones below ``-len(...)`` which would fail
            # later with an IndexError during indexing. In-range negative
            # indexes (standard Python indexing) remain accepted.
            missing_index = [
                key
                for key in filter_keys
                if not -num_subspaces <= key < num_subspaces
            ]
            if missing_index:
                raise ValueError(
                    "All the `filter_keys` must be included in the length of the observation space.\n"
                    f"Filter keys: {filter_keys}, length of observation: {len(env.observation_space)}, "
                    f"missing indexes: {missing_index}"
                )

            new_observation_spaces = spaces.Tuple(
                env.observation_space[key] for key in filter_keys
            )
            if len(new_observation_spaces) == 0:
                raise ValueError(
                    "The observation space is empty due to filtering all keys."
                )

            TransformObservation.__init__(
                self,
                env=env,
                func=lambda obs: tuple(obs[key] for key in filter_keys),
                observation_space=new_observation_spaces,
            )
        else:
            raise ValueError(
                f"FilterObservation wrapper is only usable with `Dict` and `Tuple` observations, actual type: {type(env.observation_space)}"
            )

        self.filter_keys: Final[Sequence[str | int]] = filter_keys
+
+
+
+
+[docs]
class FlattenObservation(
    TransformObservation[WrapperObsType, ActType, ObsType],
    gym.utils.RecordConstructorArgs,
):
    """Flattens the environment's observation space, flattening every observation returned by ``reset`` and ``step``.

    A vector version of the wrapper exists :class:`gymnasium.wrappers.vector.FlattenObservation`.

    Example:
        >>> import gymnasium as gym
        >>> from gymnasium.wrappers import FlattenObservation
        >>> env = gym.make("CarRacing-v3")
        >>> env.observation_space.shape
        (96, 96, 3)
        >>> env = FlattenObservation(env)
        >>> env.observation_space.shape
        (27648,)
        >>> obs, _ = env.reset()
        >>> obs.shape
        (27648,)

    Change logs:
        * v0.15.0 - Initially added
    """

    def __init__(self, env: gym.Env[ObsType, ActType]):
        """Constructor for any environment whose observation space supports ``spaces.utils.flatten_space`` and ``spaces.utils.flatten``.

        Args:
            env: The environment to wrap
        """
        gym.utils.RecordConstructorArgs.__init__(self)

        # The flat space is computed once up front; the per-observation function
        # only flattens values against the original (unflattened) space.
        flat_space = spaces.utils.flatten_space(env.observation_space)
        TransformObservation.__init__(
            self,
            env=env,
            func=lambda obs: spaces.utils.flatten(env.observation_space, obs),
            observation_space=flat_space,
        )
+
+
+
+
+[docs]
class GrayscaleObservation(
    TransformObservation[WrapperObsType, ActType, ObsType],
    gym.utils.RecordConstructorArgs,
):
    """Converts an image observation computed by ``reset`` and ``step`` from RGB to Grayscale.

    The :attr:`keep_dim` will keep the channel dimension.

    A vector version of the wrapper exists :class:`gymnasium.wrappers.vector.GrayscaleObservation`.

    Example:
        >>> import gymnasium as gym
        >>> from gymnasium.wrappers import GrayscaleObservation
        >>> env = gym.make("CarRacing-v3")
        >>> env.observation_space.shape
        (96, 96, 3)
        >>> grayscale_env = GrayscaleObservation(env)
        >>> grayscale_env.observation_space.shape
        (96, 96)
        >>> grayscale_env = GrayscaleObservation(env, keep_dim=True)
        >>> grayscale_env.observation_space.shape
        (96, 96, 1)

    Change logs:
        * v0.15.0 - Initially added, originally called ``GrayScaleObservation``
        * v1.0.0 - Renamed to ``GrayscaleObservation``
    """

    def __init__(self, env: gym.Env[ObsType, ActType], keep_dim: bool = False):
        """Constructor for an RGB image based environments to make the image grayscale.

        Args:
            env: The environment to wrap
            keep_dim: If to keep the channel in the observation, if ``True``, ``obs.shape == 3`` else ``obs.shape == 2``
        """
        # Only uint8 RGB image spaces with bounds [0, 255] and a trailing
        # 3-channel axis are supported.
        assert isinstance(env.observation_space, spaces.Box)
        assert (
            len(env.observation_space.shape) == 3
            and env.observation_space.shape[-1] == 3
        )
        assert (
            np.all(env.observation_space.low == 0)
            and np.all(env.observation_space.high == 255)
            and env.observation_space.dtype == np.uint8
        )
        gym.utils.RecordConstructorArgs.__init__(self, keep_dim=keep_dim)

        self.keep_dim: Final[bool] = keep_dim

        # Per-channel luminance weights applied over the trailing (R, G, B) axis.
        luma_weights = np.array([0.2125, 0.7154, 0.0721])

        def rgb_to_gray(obs):
            # Weighted channel sum, truncated back to uint8.
            return np.sum(np.multiply(obs, luma_weights), axis=-1).astype(np.uint8)

        height_width = env.observation_space.shape[:2]
        if keep_dim:
            # Keep a singleton channel axis: (H, W) -> (H, W, 1).
            obs_space = spaces.Box(
                low=0, high=255, shape=height_width + (1,), dtype=np.uint8
            )
            obs_func = lambda obs: np.expand_dims(rgb_to_gray(obs), axis=-1)  # noqa: E731
        else:
            obs_space = spaces.Box(low=0, high=255, shape=height_width, dtype=np.uint8)
            obs_func = rgb_to_gray

        TransformObservation.__init__(
            self, env=env, func=obs_func, observation_space=obs_space
        )
+
+
+
+
+[docs]
class ResizeObservation(
    TransformObservation[WrapperObsType, ActType, ObsType],
    gym.utils.RecordConstructorArgs,
):
    """Resizes image observations using OpenCV to a specified shape.

    A vector version of the wrapper exists :class:`gymnasium.wrappers.vector.ResizeObservation`.

    Example:
        >>> import gymnasium as gym
        >>> from gymnasium.wrappers import ResizeObservation
        >>> env = gym.make("CarRacing-v3")
        >>> env.observation_space.shape
        (96, 96, 3)
        >>> resized_env = ResizeObservation(env, (32, 32))
        >>> resized_env.observation_space.shape
        (32, 32, 3)

    Change logs:
        * v0.12.6 - Initially added
        * v1.0.0 - Requires ``shape`` with a tuple of two integers
    """

    def __init__(self, env: gym.Env[ObsType, ActType], shape: tuple[int, int]):
        """Constructor that requires an image environment observation space with a shape.

        Args:
            env: The environment to wrap
            shape: The resized observation shape, ``(rows, columns)`` of the output image
        """
        # The wrapped space must be a uint8 image space (2D grayscale or 3D
        # channelled) with bounds [0, 255].
        assert isinstance(env.observation_space, spaces.Box)
        assert len(env.observation_space.shape) in {2, 3}
        assert np.all(env.observation_space.low == 0) and np.all(
            env.observation_space.high == 255
        )
        assert env.observation_space.dtype == np.uint8

        # The target shape must be a pair of positive integers.
        assert isinstance(shape, tuple)
        assert len(shape) == 2
        assert all(np.issubdtype(type(elem), np.integer) for elem in shape)
        assert all(x > 0 for x in shape)

        try:
            import cv2
        except ImportError as e:
            raise DependencyNotInstalled(
                'opencv (cv2) is not installed, run `pip install "gymnasium[other]"`'
            ) from e

        self.shape: Final[tuple[int, int]] = tuple(shape)
        # cv2.resize takes its `dsize` argument as (width, height), i.e. the
        # reverse of numpy's (rows, columns) ordering, hence the swap here.
        self.cv2_shape: Final[tuple[int, int]] = (shape[1], shape[0])

        resized_space = spaces.Box(
            low=0,
            high=255,
            # Preserve the channel axis (if any) from the original space.
            shape=self.shape + env.observation_space.shape[2:],
            dtype=np.uint8,
        )

        gym.utils.RecordConstructorArgs.__init__(self, shape=shape)
        TransformObservation.__init__(
            self,
            env=env,
            func=lambda obs: cv2.resize(
                obs, self.cv2_shape, interpolation=cv2.INTER_AREA
            ),
            observation_space=resized_space,
        )
+
+
+
+
+[docs]
class ReshapeObservation(
    TransformObservation[WrapperObsType, ActType, ObsType],
    gym.utils.RecordConstructorArgs,
):
    """Reshapes Array based observations to a specified shape.

    A vector version of the wrapper exists :class:`gymnasium.wrappers.vector.ReshapeObservation`.

    Example:
        >>> import gymnasium as gym
        >>> from gymnasium.wrappers import ReshapeObservation
        >>> env = gym.make("CarRacing-v3")
        >>> env.observation_space.shape
        (96, 96, 3)
        >>> reshape_env = ReshapeObservation(env, (24, 4, 96, 1, 3))
        >>> reshape_env.observation_space.shape
        (24, 4, 96, 1, 3)

    Change logs:
        * v1.0.0 - Initially added
    """

    def __init__(self, env: gym.Env[ObsType, ActType], shape: int | tuple[int, ...]):
        """Constructor for env with ``Box`` observation space that has a shape product equal to the new shape product.

        Args:
            env: The environment to wrap
            shape: The reshaped observation space, either a tuple of ints or a
                single int (treated as a 1-dimensional shape)
        """
        assert isinstance(env.observation_space, spaces.Box)

        # The annotation allows a bare int; normalise it to a 1-tuple so the
        # remaining validation and reshaping operate on a single form.
        if np.issubdtype(type(shape), np.integer):
            shape = (int(shape),)

        # Validate the target shape *before* comparing element counts so a
        # malformed argument fails on the shape check rather than on np.prod.
        assert isinstance(shape, tuple)
        assert all(np.issubdtype(type(elem), np.integer) for elem in shape)
        assert all(x > 0 or x == -1 for x in shape)
        # Reshaping must preserve the total number of elements.
        assert np.prod(shape) == np.prod(env.observation_space.shape)

        # Reshape the per-element bounds alongside the values so the new space
        # keeps exactly the same (elementwise) low/high limits.
        new_observation_space = spaces.Box(
            low=np.reshape(np.ravel(env.observation_space.low), shape),
            high=np.reshape(np.ravel(env.observation_space.high), shape),
            shape=shape,
            dtype=env.observation_space.dtype,
        )
        self.shape = shape

        gym.utils.RecordConstructorArgs.__init__(self, shape=shape)
        TransformObservation.__init__(
            self,
            env=env,
            func=lambda obs: np.reshape(obs, shape),
            observation_space=new_observation_space,
        )
+
+
+
+
+[docs]
class RescaleObservation(
    TransformObservation[WrapperObsType, ActType, ObsType],
    gym.utils.RecordConstructorArgs,
):
    """Affinely (linearly) rescales a ``Box`` observation space of the environment to within the range of ``[min_obs, max_obs]``.

    For unbounded components in the original observation space, the corresponding target bounds must also be infinite and vice versa.

    A vector version of the wrapper exists :class:`gymnasium.wrappers.vector.RescaleObservation`.

    Example:
        >>> import gymnasium as gym
        >>> from gymnasium.wrappers import RescaleObservation
        >>> env = gym.make("Pendulum-v1")
        >>> env.observation_space
        Box([-1. -1. -8.], [1. 1. 8.], (3,), float32)
        >>> env = RescaleObservation(env, np.array([-2, -1, -10], dtype=np.float32), np.array([1, 0, 1], dtype=np.float32))
        >>> env.observation_space
        Box([ -2.  -1. -10.], [1. 0. 1.], (3,), float32)

    Change logs:
        * v1.0.0 - Initially added
    """

    def __init__(
        self,
        env: gym.Env[ObsType, ActType],
        min_obs: np.floating | np.integer | np.ndarray,
        max_obs: np.floating | np.integer | np.ndarray,
    ):
        """Constructor that requires the env observation spaces to be a :class:`Box`.

        Args:
            env: The environment to wrap
            min_obs: The new minimum observation bound
            max_obs: The new maximum observation bound
        """
        # Rescaling is only defined for Box observation spaces.
        assert isinstance(env.observation_space, spaces.Box)

        gym.utils.RecordConstructorArgs.__init__(self, min_obs=min_obs, max_obs=max_obs)

        # ``rescale_box`` returns the rescaled space plus two callables; only
        # the first (forward) transform is needed by this wrapper, the third
        # element of the tuple is unused here.
        rescaled_space, forward_transform, _ = rescale_box(
            env.observation_space, min_obs, max_obs
        )
        TransformObservation.__init__(
            self,
            env=env,
            func=forward_transform,
            observation_space=rescaled_space,
        )
+
+
+
+
+[docs]
class DtypeObservation(
    TransformObservation[WrapperObsType, ActType, ObsType],
    gym.utils.RecordConstructorArgs,
):
    """Modifies the dtype of an observation array to a specified dtype.

    Note:
        This is only compatible with :class:`Box`, :class:`Discrete`, :class:`MultiDiscrete` and :class:`MultiBinary` observation spaces

    A vector version of the wrapper exists :class:`gymnasium.wrappers.vector.DtypeObservation`.

    Change logs:
        * v1.0.0 - Initially added
    """

    def __init__(self, env: gym.Env[ObsType, ActType], dtype: Any):
        """Constructor for Dtype observation wrapper.

        Args:
            env: The environment to wrap
            dtype: The new dtype of the observation
        """
        assert isinstance(
            env.observation_space,
            (spaces.Box, spaces.Discrete, spaces.MultiDiscrete, spaces.MultiBinary),
        )

        self.dtype = dtype
        if isinstance(env.observation_space, spaces.Box):
            # Same bounds and shape, only the dtype changes.
            new_observation_space = spaces.Box(
                low=env.observation_space.low,
                high=env.observation_space.high,
                shape=env.observation_space.shape,
                dtype=self.dtype,
            )
        elif isinstance(env.observation_space, spaces.Discrete):
            # ``Box`` bounds are inclusive and ``Discrete(n, start)`` produces
            # values in ``[start, start + n - 1]``, so the inclusive upper bound
            # is ``start + n - 1`` (previously ``start + n``, an off-by-one that
            # admitted a value the wrapped space can never emit).
            new_observation_space = spaces.Box(
                low=env.observation_space.start,
                high=env.observation_space.start + env.observation_space.n - 1,
                shape=(),
                dtype=self.dtype,
            )
        elif isinstance(env.observation_space, spaces.MultiDiscrete):
            new_observation_space = spaces.MultiDiscrete(
                env.observation_space.nvec, dtype=dtype
            )
        elif isinstance(env.observation_space, spaces.MultiBinary):
            # MultiBinary values are 0/1 arrays, represented as a Box in the
            # requested dtype.
            new_observation_space = spaces.Box(
                low=0,
                high=1,
                shape=env.observation_space.shape,
                dtype=self.dtype,
            )
        else:
            raise TypeError(
                "DtypeObservation is only compatible with value / array-based observations."
            )

        gym.utils.RecordConstructorArgs.__init__(self, dtype=dtype)
        TransformObservation.__init__(
            self,
            env=env,
            # The dtype is applied as a callable, e.g. ``np.float32(obs)``.
            func=lambda obs: dtype(obs),
            observation_space=new_observation_space,
        )
+
+
+
+
+[docs]
class AddRenderObservation(
    TransformObservation[WrapperObsType, ActType, ObsType],
    gym.utils.RecordConstructorArgs,
):
    """Includes the rendered observations in the environment's observations.

    Notes:
        This was previously called ``PixelObservationWrapper``.

    No vector version of the wrapper exists.

    Example - Replace the observation with the rendered image:
        >>> env = gym.make("CartPole-v1", render_mode="rgb_array")
        >>> env = AddRenderObservation(env, render_only=True)
        >>> env.observation_space
        Box(0, 255, (400, 600, 3), uint8)
        >>> obs, _ = env.reset(seed=123)
        >>> image = env.render()
        >>> np.all(obs == image)
        np.True_
        >>> obs, *_ = env.step(env.action_space.sample())
        >>> image = env.render()
        >>> np.all(obs == image)
        np.True_

    Example - Add the rendered image to the original observation as a dictionary item:
        >>> env = gym.make("CartPole-v1", render_mode="rgb_array")
        >>> env = AddRenderObservation(env, render_only=False)
        >>> env.observation_space
        Dict('pixels': Box(0, 255, (400, 600, 3), uint8), 'state': Box([-4.8               -inf -0.41887903        -inf], [4.8               inf 0.41887903        inf], (4,), float32))
        >>> obs, info = env.reset(seed=123)
        >>> obs.keys()
        dict_keys(['state', 'pixels'])
        >>> obs["state"]
        array([ 0.01823519, -0.0446179 , -0.02796401, -0.03156282], dtype=float32)
        >>> np.all(obs["pixels"] == env.render())
        np.True_
        >>> obs, reward, terminates, truncates, info = env.step(env.action_space.sample())
        >>> image = env.render()
        >>> np.all(obs["pixels"] == image)
        np.True_

    Change logs:
        * v0.15.0 - Initially added as ``PixelObservationWrapper``
        * v1.0.0 - Renamed to ``AddRenderObservation``
    """

    def __init__(
        self,
        env: gym.Env[ObsType, ActType],
        render_only: bool = True,
        render_key: str = "pixels",
        obs_key: str = "state",
    ):
        """Constructor of the add render observation wrapper.

        Args:
            env: The environment to wrap.
            render_only (bool): If ``True`` (default), the original observation returned
                by the wrapped environment will be discarded, and a dictionary
                observation will only include pixels. If ``False``, the
                observation dictionary will contain both the original
                observations and the pixel observations.
            render_key: Optional custom string specifying the pixel key. Defaults to "pixels"
            obs_key: Optional custom string specifying the obs key. Defaults to "state"
        """
        # NOTE(review): the recorded kwarg names (``pixels_only``/``pixels_key``)
        # differ from this constructor's parameter names
        # (``render_only``/``render_key``) — presumably kept from the old
        # ``PixelObservationWrapper`` API; verify before relying on
        # reconstruction from the recorded constructor args.
        gym.utils.RecordConstructorArgs.__init__(
            self,
            pixels_only=render_only,
            pixels_key=render_key,
            obs_key=obs_key,
        )

        # A frame-producing render mode is required ("human" renders to a
        # window and returns no array).
        assert env.render_mode is not None and env.render_mode != "human"
        # Probe one frame to discover the rendered image's shape; note this
        # resets the wrapped environment as a construction side effect.
        env.reset()
        pixels = env.render()
        assert pixels is not None and isinstance(pixels, np.ndarray)
        pixel_space = spaces.Box(low=0, high=255, shape=pixels.shape, dtype=np.uint8)

        if render_only:
            # The rendered frame fully replaces the original observation.
            obs_space = pixel_space
            TransformObservation.__init__(
                self, env=env, func=lambda _: self.render(), observation_space=obs_space
            )
        elif isinstance(env.observation_space, spaces.Dict):
            # Merge the frame into the existing Dict observation; the render
            # key must not collide with an existing key.
            assert render_key not in env.observation_space.spaces.keys()

            obs_space = spaces.Dict(
                {render_key: pixel_space, **env.observation_space.spaces}
            )
            TransformObservation.__init__(
                self,
                env=env,
                func=lambda obs: {render_key: self.render(), **obs},
                observation_space=obs_space,
            )
        else:
            # Non-dict observations are wrapped in a new Dict under ``obs_key``
            # alongside the rendered frame under ``render_key``.
            obs_space = spaces.Dict(
                {obs_key: env.observation_space, render_key: pixel_space}
            )
            TransformObservation.__init__(
                self,
                env=env,
                func=lambda obs: {obs_key: obs, render_key: self.render()},
                observation_space=obs_space,
            )
+
+
+"""A collection of wrappers for modifying the reward.
+
+* ``TransformReward`` - Transforms the reward by a function
+* ``ClipReward`` - Clips the reward between a minimum and maximum value
+"""
+
+from __future__ import annotations
+
+from collections.abc import Callable
+from typing import SupportsFloat
+
+import numpy as np
+
+import gymnasium as gym
+from gymnasium.core import ActType, ObsType
+from gymnasium.error import InvalidBound
+
+
+__all__ = ["TransformReward", "ClipReward"]
+
+
+
+[docs]
class TransformReward(
    gym.RewardWrapper[ObsType, ActType], gym.utils.RecordConstructorArgs
):
    """Applies a function to the ``reward`` received from the environment's ``step``.

    A vector version of the wrapper exists :class:`gymnasium.wrappers.vector.TransformReward`.

    Example:
        >>> import gymnasium as gym
        >>> from gymnasium.wrappers import TransformReward
        >>> env = gym.make("CartPole-v1")
        >>> env = TransformReward(env, lambda r: 2 * r + 1)
        >>> _ = env.reset()
        >>> _, rew, _, _, _ = env.step(0)
        >>> rew
        3.0

    Change logs:
        * v0.15.0 - Initially added
    """

    def __init__(
        self,
        env: gym.Env[ObsType, ActType],
        func: Callable[[SupportsFloat], SupportsFloat],
    ):
        """Initialize TransformReward wrapper.

        Args:
            env (Env): The environment to wrap
            func: (Callable): The function applied to every reward
        """
        gym.utils.RecordConstructorArgs.__init__(self, func=func)
        gym.RewardWrapper.__init__(self, env)

        # Stored so ``reward`` can apply it on every step.
        self.func = func

    def reward(self, reward: SupportsFloat) -> SupportsFloat:
        """Returns the reward transformed by the wrapper's function.

        Args:
            reward (Union[float, int, np.ndarray]): environment's reward
        """
        transformed = self.func(reward)
        return transformed
+
+
+
+
+[docs]
class ClipReward(TransformReward[ObsType, ActType], gym.utils.RecordConstructorArgs):
    """Clips the rewards for an environment between an upper and lower bound.

    A vector version of the wrapper exists :class:`gymnasium.wrappers.vector.ClipReward`.

    Example:
        >>> import gymnasium as gym
        >>> from gymnasium.wrappers import ClipReward
        >>> env = gym.make("CartPole-v1")
        >>> env = ClipReward(env, 0, 0.5)
        >>> _ = env.reset()
        >>> _, rew, _, _, _ = env.step(1)
        >>> rew
        np.float64(0.5)

    Change logs:
        * v1.0.0 - Initially added
    """

    def __init__(
        self,
        env: gym.Env[ObsType, ActType],
        min_reward: float | np.ndarray | None = None,
        max_reward: float | np.ndarray | None = None,
    ):
        """Initialize ClipRewards wrapper.

        Args:
            env (Env): The environment to wrap
            min_reward (Union[float, np.ndarray]): lower bound to apply
            max_reward (Union[float, np.ndarray]): higher bound to apply
        """
        # Guard clauses: at least one bound must be given, and when both are
        # given every component of the interval must be non-empty.
        if min_reward is None and max_reward is None:
            raise InvalidBound("Both `min_reward` and `max_reward` cannot be None")
        if (
            min_reward is not None
            and max_reward is not None
            and np.any(max_reward - min_reward < 0)
        ):
            raise InvalidBound(
                f"Min reward ({min_reward}) must be smaller than max reward ({max_reward})"
            )

        gym.utils.RecordConstructorArgs.__init__(
            self, min_reward=min_reward, max_reward=max_reward
        )
        # np.clip treats a None bound as unbounded on that side.
        TransformReward.__init__(
            self, env=env, func=lambda x: np.clip(x, a_min=min_reward, a_max=max_reward)
        )
+
+
+"""Vector wrapper for converting between Array API compatible frameworks."""
+
+from __future__ import annotations
+
+from types import ModuleType
+from typing import Any
+
+import gymnasium as gym
+from gymnasium.core import ActType, ObsType
+from gymnasium.vector import VectorEnv, VectorWrapper
+from gymnasium.vector.vector_env import ArrayType
+from gymnasium.wrappers.array_conversion import (
+ Device,
+ array_conversion,
+ module_name_to_namespace,
+)
+
+
+__all__ = ["ArrayConversion"]
+
+
+
+[docs]
class ArrayConversion(VectorWrapper, gym.utils.RecordConstructorArgs):
    """Wraps a vector environment returning Array API compatible arrays so that it can be interacted with through a specific framework.

    Popular Array API frameworks include ``numpy``, ``torch``, ``jax.numpy``, ``cupy`` etc. With this wrapper, you can convert outputs from your environment to
    any of these frameworks. Conversely, actions are automatically mapped back to the environment framework, if possible without moving the
    data or device transfers.

    Notes:
        A vectorized version of :class:`gymnasium.wrappers.ArrayConversion`

    Example:
        >>> import gymnasium as gym                         # doctest: +SKIP
        >>> envs = gym.make_vec("JaxEnv-vx", 3)             # doctest: +SKIP
        >>> envs = ArrayConversion(envs, xp=np)             # doctest: +SKIP
    """

    def __init__(
        self,
        env: VectorEnv,
        env_xp: ModuleType | str,
        target_xp: ModuleType | str,
        env_device: Device | None = None,
        target_device: Device | None = None,
    ):
        """Wrapper class to change inputs and outputs of environment to any Array API framework.

        Args:
            env: The Array API compatible environment to wrap
            env_xp: The Array API framework the environment is on (module or module name)
            target_xp: The Array API framework to convert to (module or module name)
            env_device: The device the environment is on
            target_device: The device on which Arrays should be returned
        """
        gym.utils.RecordConstructorArgs.__init__(self)
        VectorWrapper.__init__(self, env)
        # The signature accepts module names as strings, but the rest of the
        # class (conversion calls and ``__getstate__``'s ``__name__`` access)
        # requires namespace modules — normalise strings here, otherwise a
        # string argument crashes later with an AttributeError.
        self._env_xp = (
            module_name_to_namespace(env_xp) if isinstance(env_xp, str) else env_xp
        )
        self._target_xp = (
            module_name_to_namespace(target_xp)
            if isinstance(target_xp, str)
            else target_xp
        )
        self._env_device = env_device
        self._target_device = target_device

    def step(
        self, actions: ActType
    ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
        """Transforms the action to the specified xp module array type.

        Args:
            actions: The action to perform

        Returns:
            A tuple containing xp versions of the next observation, reward, termination, truncation, and extra info.
        """
        # Actions move to the environment's framework, outputs to the target's.
        actions = array_conversion(actions, xp=self._env_xp, device=self._env_device)
        obs, reward, terminated, truncated, info = self.env.step(actions)

        return (
            array_conversion(obs, xp=self._target_xp, device=self._target_device),
            array_conversion(reward, xp=self._target_xp, device=self._target_device),
            array_conversion(
                terminated, xp=self._target_xp, device=self._target_device
            ),
            array_conversion(truncated, xp=self._target_xp, device=self._target_device),
            array_conversion(info, xp=self._target_xp, device=self._target_device),
        )

    def reset(
        self,
        *,
        seed: int | list[int] | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[ObsType, dict[str, Any]]:
        """Resets the environment returning xp-based observation and info.

        Args:
            seed: The seed for resetting the environment
            options: The options for resetting the environment, these are converted to xp arrays.

        Returns:
            xp-based observations and info
        """
        if options:
            options = array_conversion(
                options, xp=self._env_xp, device=self._env_device
            )

        return array_conversion(
            self.env.reset(seed=seed, options=options),
            xp=self._target_xp,
            device=self._target_device,
        )

    def __getstate__(self):
        """Returns the object pickle state with args and kwargs."""
        # Modules are not picklable; store their names instead (stripping the
        # array_api_compat wrapper prefix so the name round-trips).
        env_xp_name = self._env_xp.__name__.replace("array_api_compat.", "")
        target_xp_name = self._target_xp.__name__.replace("array_api_compat.", "")
        env_device = self._env_device
        target_device = self._target_device
        return {
            "env_xp_name": env_xp_name,
            "target_xp_name": target_xp_name,
            "env_device": env_device,
            "target_device": target_device,
            "env": self.env,
        }

    def __setstate__(self, d):
        """Sets the object pickle state using d."""
        self.env = d["env"]
        # Re-resolve the namespaces from the stored module names.
        self._env_xp = module_name_to_namespace(d["env_xp_name"])
        self._target_xp = module_name_to_namespace(d["target_xp_name"])
        self._env_device = d["env_device"]
        self._target_device = d["target_device"]
+
+
+"""Wrapper that tracks the cumulative rewards and episode lengths."""
+
+from __future__ import annotations
+
+import time
+from collections import deque
+
+import numpy as np
+
+from gymnasium.core import ActType, ObsType
+from gymnasium.logger import warn
+from gymnasium.vector.vector_env import (
+ ArrayType,
+ AutoresetMode,
+ VectorEnv,
+ VectorWrapper,
+)
+
+
+__all__ = ["RecordEpisodeStatistics"]
+
+
+
+[docs]
class RecordEpisodeStatistics(VectorWrapper):
    """This wrapper will keep track of cumulative rewards and episode lengths.

    At the end of any episode within the vectorized env, the statistics of the episode
    will be added to ``info`` using the key ``episode``, and the ``_episode`` key
    is used to indicate the environment index which has a terminated or truncated episode.

        >>> infos = {  # doctest: +SKIP
        ...     ...
        ...     "episode": {
        ...         "r": "<array of cumulative reward for each done sub-environment>",
        ...         "l": "<array of episode length for each done sub-environment>",
        ...         "t": "<array of elapsed time since beginning of episode for each done sub-environment>"
        ...     },
        ...     "_episode": "<boolean array of length num-envs>"
        ... }

    Moreover, the most recent rewards and episode lengths are stored in buffers that can be accessed via
    :attr:`wrapped_env.return_queue` and :attr:`wrapped_env.length_queue` respectively.

    Attributes:
        return_queue: The cumulative rewards of the last ``deque_size``-many episodes
        length_queue: The lengths of the last ``deque_size``-many episodes

    Example:
        >>> from pprint import pprint
        >>> import gymnasium as gym
        >>> envs = gym.make_vec("CartPole-v1", num_envs=3)
        >>> envs = RecordEpisodeStatistics(envs)
        >>> obs, info = envs.reset(123)
        >>> _ = envs.action_space.seed(123)
        >>> end = False
        >>> while not end:
        ...     obs, rew, term, trunc, info = envs.step(envs.action_space.sample())
        ...     end = term.any() or trunc.any()
        ...
        >>> envs.close()
        >>> pprint(info) # doctest: +SKIP
        {'_episode': array([ True, False, False]),
         '_final_info': array([ True, False, False]),
         '_final_observation': array([ True, False, False]),
         'episode': {'l': array([11,  0,  0], dtype=int32),
                     'r': array([11.,  0.,  0.], dtype=float32),
                     't': array([0.007812, 0.      , 0.      ], dtype=float32)},
         'final_info': array([{}, None, None], dtype=object),
         'final_observation': array([array([ 0.11448676,  0.9416149 , -0.20946532, -1.7619033 ], dtype=float32),
                                     None, None], dtype=object)}
    """

    def __init__(
        self,
        env: VectorEnv,
        buffer_length: int = 100,
        stats_key: str = "episode",
    ):
        """This wrapper will keep track of cumulative rewards and episode lengths.

        Args:
            env (Env): The environment to apply the wrapper
            buffer_length: The size of the buffers :attr:`return_queue`, :attr:`length_queue` and :attr:`time_queue`
            stats_key: The info key to save the data
        """
        super().__init__(env)
        self._stats_key = stats_key
        if "autoreset_mode" not in self.env.metadata:
            # Replaces the previous placeholder warning ("todo") with an
            # actionable message.
            warn(
                "`RecordEpisodeStatistics` is unable to detect the autoreset mode as the environment's metadata is missing the `autoreset_mode` key, assuming `AutoresetMode.NEXT_STEP`."
            )
            self._autoreset_mode = AutoresetMode.NEXT_STEP
        else:
            assert isinstance(self.env.metadata["autoreset_mode"], AutoresetMode)
            self._autoreset_mode = self.env.metadata["autoreset_mode"]

        # Total number of completed episodes across all sub-environments.
        self.episode_count = 0

        # Per-sub-environment running statistics.
        self.episode_start_times: np.ndarray = np.zeros((self.num_envs,))
        self.episode_returns: np.ndarray = np.zeros((self.num_envs,))
        self.episode_lengths: np.ndarray = np.zeros((self.num_envs,), dtype=int)
        # Which sub-environments finished on the previous step (their stats are
        # reset on the next step, matching NEXT_STEP autoreset semantics).
        self.prev_dones: np.ndarray = np.zeros((self.num_envs,), dtype=bool)

        self.time_queue = deque(maxlen=buffer_length)
        self.return_queue = deque(maxlen=buffer_length)
        self.length_queue = deque(maxlen=buffer_length)

    def reset(
        self,
        seed: int | list[int] | None = None,
        options: dict | None = None,
    ):
        """Resets the environment using kwargs and resets the episode returns and lengths."""
        obs, info = super().reset(seed=seed, options=options)

        if options is not None and "reset_mask" in options:
            # NOTE: pop mutates the caller's options dict (after the wrapped
            # reset has already consumed it) — preserved for compatibility.
            reset_mask = options.pop("reset_mask")
            assert isinstance(
                reset_mask, np.ndarray
            ), f"`options['reset_mask': mask]` must be a numpy array, got {type(reset_mask)}"
            assert reset_mask.shape == (
                self.num_envs,
            ), f"`options['reset_mask': mask]` must have shape `({self.num_envs},)`, got {reset_mask.shape}"
            assert (
                reset_mask.dtype == np.bool_
            ), f"`options['reset_mask': mask]` must have `dtype=np.bool_`, got {reset_mask.dtype}"
            assert np.any(
                reset_mask
            ), f"`options['reset_mask': mask]` must contain a boolean array, got reset_mask={reset_mask}"

            # Only the masked sub-environments have their statistics cleared.
            self.episode_start_times[reset_mask] = time.perf_counter()
            self.episode_returns[reset_mask] = 0
            self.episode_lengths[reset_mask] = 0
            self.prev_dones[reset_mask] = False
        else:
            self.episode_start_times = np.full(self.num_envs, time.perf_counter())
            self.episode_returns = np.zeros(self.num_envs)
            self.episode_lengths = np.zeros(self.num_envs, dtype=int)
            self.prev_dones = np.zeros(self.num_envs, dtype=bool)

        return obs, info

    def step(
        self, actions: ActType
    ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
        """Steps through the environment, recording the episode statistics."""
        (
            observations,
            rewards,
            terminations,
            truncations,
            infos,
        ) = self.env.step(actions)

        assert isinstance(
            infos, dict
        ), f"`vector.RecordEpisodeStatistics` requires `info` type to be `dict`, its actual type is {type(infos)}. This may be due to usage of other wrappers in the wrong order."

        # Sub-environments that finished last step are restarting now: zero
        # their stats and exclude their (reset-step) reward from the totals.
        self.episode_returns[self.prev_dones] = 0
        self.episode_returns[np.logical_not(self.prev_dones)] += rewards[
            np.logical_not(self.prev_dones)
        ]

        self.episode_lengths[self.prev_dones] = 0
        self.episode_lengths[~self.prev_dones] += 1

        self.episode_start_times[self.prev_dones] = time.perf_counter()

        self.prev_dones = dones = np.logical_or(terminations, truncations)
        num_dones = np.sum(dones)

        if num_dones:
            if self._stats_key in infos or f"_{self._stats_key}" in infos:
                raise ValueError(
                    f"Attempted to add episode stats with key '{self._stats_key}' but this key already exists in info: {list(infos.keys())}"
                )
            else:
                episode_time_length = np.round(
                    time.perf_counter() - self.episode_start_times, 6
                )
                infos[self._stats_key] = {
                    "r": np.where(dones, self.episode_returns, 0.0),
                    "l": np.where(dones, self.episode_lengths, 0),
                    "t": np.where(dones, episode_time_length, 0.0),
                }
                infos[f"_{self._stats_key}"] = dones

            self.episode_count += num_dones

            # ``np.where(dones)`` returns a 1-tuple of index arrays; index the
            # stat arrays directly with the indexes rather than iterating over
            # that tuple (the previous ``for i in np.where(dones)`` only worked
            # because the tuple has exactly one element).
            done_indexes = np.where(dones)[0]
            self.time_queue.extend(episode_time_length[done_indexes])
            self.return_queue.extend(self.episode_returns[done_indexes])
            self.length_queue.extend(self.episode_lengths[done_indexes])

        return (
            observations,
            rewards,
            terminations,
            truncations,
            infos,
        )
+
+
+"""Wrapper that converts the info format for vec envs into the list format."""
+
+from __future__ import annotations
+
+from typing import Any
+
+import numpy as np
+
+from gymnasium.core import ActType, ObsType
+from gymnasium.vector.vector_env import ArrayType, VectorEnv, VectorWrapper
+
+
+__all__ = ["DictInfoToList"]
+
+
+
+[docs]
class DictInfoToList(VectorWrapper):
    """Converts infos of vectorized environments from ``dict`` to ``List[dict]``.

    This wrapper converts the info format of a
    vector environment from a dictionary to a list of dictionaries.
    This wrapper is intended to be used around vectorized
    environments. If using other wrappers that perform
    operation on info like `RecordEpisodeStatistics` this
    need to be the outermost wrapper.

    i.e. ``DictInfoToList(RecordEpisodeStatistics(vector_env))``

    Example:
        >>> import numpy as np
        >>> dict_info = {
        ...      "k": np.array([0., 0., 0.5, 0.3]),
        ...      "_k": np.array([False, False, True, True])
        ...  }
        ...
        >>> list_info = [{}, {}, {"k": 0.5}, {"k": 0.3}]

    Example for vector environments:
        >>> import numpy as np
        >>> import gymnasium as gym
        >>> envs = gym.make_vec("CartPole-v1", num_envs=3)
        >>> obs, info = envs.reset(seed=123)
        >>> info
        {}
        >>> envs = DictInfoToList(envs)
        >>> obs, info = envs.reset(seed=123)
        >>> info
        [{}, {}, {}]

    Another example for vector environments:
        >>> import numpy as np
        >>> import gymnasium as gym
        >>> envs = gym.make_vec("HalfCheetah-v4", num_envs=2)
        >>> _ = envs.reset(seed=123)
        >>> _ = envs.action_space.seed(123)
        >>> _, _, _, _, infos = envs.step(envs.action_space.sample())
        >>> infos
        {'x_position': array([0.03332211, 0.10172355]), '_x_position': array([ True,  True]), 'x_velocity': array([-0.06296527,  0.89345848]), '_x_velocity': array([ True,  True]), 'reward_run': array([-0.06296527,  0.89345848]), '_reward_run': array([ True,  True]), 'reward_ctrl': array([-0.24503504, -0.21944423], dtype=float32), '_reward_ctrl': array([ True,  True])}
        >>> envs = DictInfoToList(envs)
        >>> _ = envs.reset(seed=123)
        >>> _ = envs.action_space.seed(123)
        >>> _, _, _, _, infos = envs.step(envs.action_space.sample())
        >>> infos
        [{'x_position': np.float64(0.0333221090036294), 'x_velocity': np.float64(-0.06296527291998574), 'reward_run': np.float64(-0.06296527291998574), 'reward_ctrl': np.float32(-0.24503504)}, {'x_position': np.float64(0.10172354684460168), 'x_velocity': np.float64(0.8934584807363618), 'reward_run': np.float64(0.8934584807363618), 'reward_ctrl': np.float32(-0.21944423)}]

    Change logs:
        * v0.24.0 - Initially added as ``VectorListInfo``
        * v1.0.0 - Renamed to ``DictInfoToList``
    """

    def __init__(self, env: VectorEnv):
        """This wrapper will convert the info into the list format.

        Args:
            env (Env): The environment to apply the wrapper
        """
        super().__init__(env)

    def step(
        self, actions: ActType
    ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, list[dict[str, Any]]]:
        """Steps through the environment, convert dict info to list."""
        observation, reward, terminated, truncated, infos = self.env.step(actions)
        assert isinstance(infos, dict)
        return (
            observation,
            reward,
            terminated,
            truncated,
            self._convert_info_to_list(infos),
        )

    def reset(
        self,
        *,
        seed: int | list[int] | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[ObsType, list[dict[str, Any]]]:
        """Resets the environment using kwargs."""
        obs, infos = self.env.reset(seed=seed, options=options)
        assert isinstance(infos, dict)
        return obs, self._convert_info_to_list(infos)

    def _convert_info_to_list(self, vector_infos: dict) -> list[dict[str, Any]]:
        """Convert the dict info to list.

        Convert the dict info of the vectorized environment
        into a list of dictionaries where the i-th dictionary
        has the info of the i-th environment.

        Args:
            vector_infos (dict): info dict coming from the env.

        Returns:
            list_info (list): converted info.
        """
        per_env_infos: list[dict[str, Any]] = [{} for _ in range(self.num_envs)]

        for key, value in vector_infos.items():
            # Keys prefixed with "_" are presence masks; they are consumed
            # alongside their data key below, never emitted themselves.
            if key.startswith("_"):
                continue

            presence_mask = vector_infos[f"_{key}"]
            if isinstance(value, dict):
                # Nested info dicts are converted recursively, then distributed
                # to the environments whose mask entry is set.
                nested_infos = self._convert_info_to_list(value)
                for env_idx, (env_info, has_info) in enumerate(
                    zip(nested_infos, presence_mask)
                ):
                    if has_info:
                        per_env_infos[env_idx][key] = env_info
            else:
                assert isinstance(value, np.ndarray)
                for env_idx, has_info in enumerate(presence_mask):
                    if has_info:
                        per_env_infos[env_idx][key] = value[env_idx]

        return per_env_infos
+
+
+"""Vector wrapper for converting between NumPy and Jax."""
+
+from __future__ import annotations
+
+import jax.numpy as jnp
+import numpy as np
+
+from gymnasium.error import DependencyNotInstalled
+from gymnasium.vector import VectorEnv
+from gymnasium.wrappers.vector.array_conversion import ArrayConversion
+
+
+__all__ = ["JaxToNumpy"]
+
+
+
+[docs]
class JaxToNumpy(ArrayConversion):
    """Wraps a jax vector environment so that it can be interacted with through numpy arrays.

    Notes:
        A vectorized version of :class:`gymnasium.wrappers.JaxToNumpy`

    Actions must be provided as numpy arrays and observations, rewards, terminations and truncations will be returned as numpy arrays.

    Example:
        >>> import gymnasium as gym                         # doctest: +SKIP
        >>> envs = gym.make_vec("JaxEnv-vx", 3)             # doctest: +SKIP
        >>> envs = JaxToNumpy(envs)                         # doctest: +SKIP
    """

    def __init__(self, env: VectorEnv):
        """Wraps a jax vector environment so that inputs and outputs are numpy arrays.

        Args:
            env: the vector jax environment to wrap

        Raises:
            DependencyNotInstalled: if jax is not available (``jnp`` resolved to ``None``).
        """
        # Guard against jax being absent before delegating the array conversion.
        if jnp is None:
            raise DependencyNotInstalled(
                'Jax is not installed, run `pip install "gymnasium[jax]"`'
            )
        super().__init__(env, env_xp=jnp, target_xp=np)
+
+
+"""Vector wrapper class for converting between PyTorch and Jax."""
+
+from __future__ import annotations
+
+import jax.numpy as jnp
+import torch
+
+from gymnasium.vector import VectorEnv
+from gymnasium.wrappers.jax_to_torch import Device
+from gymnasium.wrappers.vector.array_conversion import ArrayConversion
+
+
+__all__ = ["JaxToTorch"]
+
+
+
+[docs]
class JaxToTorch(ArrayConversion):
    """Wraps a Jax-based vector environment so that it can be interacted with through PyTorch Tensors.

    Actions must be provided as PyTorch Tensors and observations, rewards, terminations and truncations will be returned as PyTorch Tensors.

    Example:
        >>> import gymnasium as gym                         # doctest: +SKIP
        >>> envs = gym.make_vec("JaxEnv-vx", 3)             # doctest: +SKIP
        >>> envs = JaxToTorch(envs)                         # doctest: +SKIP
    """

    def __init__(self, env: VectorEnv, device: Device | None = None):
        """Vector wrapper that converts inputs and outputs to PyTorch tensors.

        Args:
            env: The Jax-based vector environment to wrap
            device: The device the torch Tensors should be moved to
        """
        super().__init__(env, env_xp=jnp, target_xp=torch, target_device=device)

        # Remembered so callers can inspect where tensors are placed.
        self.device: Device | None = device
+
+
+"""Wrapper for converting NumPy environments to PyTorch."""
+
+from __future__ import annotations
+
+import numpy as np
+import torch
+
+from gymnasium.vector import VectorEnv
+from gymnasium.wrappers.numpy_to_torch import Device
+from gymnasium.wrappers.vector.array_conversion import ArrayConversion
+
+
+__all__ = ["NumpyToTorch"]
+
+
+
+[docs]
class NumpyToTorch(ArrayConversion):
    """Wraps a numpy-based vector environment so that it can be interacted with through PyTorch Tensors.

    Example:
        >>> import torch
        >>> import gymnasium as gym
        >>> from gymnasium.wrappers.vector import NumpyToTorch
        >>> envs = gym.make_vec("CartPole-v1", 3)
        >>> envs = NumpyToTorch(envs)
        >>> obs, _ = envs.reset(seed=123)
        >>> type(obs)
        <class 'torch.Tensor'>
        >>> action = torch.tensor(envs.action_space.sample())
        >>> obs, reward, terminated, truncated, info = envs.step(action)
        >>> envs.close()
        >>> type(obs)
        <class 'torch.Tensor'>
        >>> type(reward)
        <class 'torch.Tensor'>
        >>> type(terminated)
        <class 'torch.Tensor'>
        >>> type(truncated)
        <class 'torch.Tensor'>
    """

    def __init__(self, env: VectorEnv, device: Device | None = None):
        """Vector wrapper that converts inputs and outputs of the environment to PyTorch tensors.

        Args:
            env: The NumPy-based vector environment to wrap
            device: The device the torch Tensors should be moved to
        """
        super().__init__(env, env_xp=np, target_xp=torch, target_device=device)

        # Remembered so callers can inspect where tensors are placed.
        self.device: Device | None = device
+
+
+"""A collection of stateful observation wrappers.
+
+* ``NormalizeObservation`` - Normalize the observations
+"""
+
+from __future__ import annotations
+
+from typing import Any
+
+import numpy as np
+
+import gymnasium as gym
+from gymnasium.core import ObsType
+from gymnasium.logger import warn
+from gymnasium.vector.vector_env import (
+ AutoresetMode,
+ VectorEnv,
+ VectorObservationWrapper,
+)
+from gymnasium.wrappers.utils import RunningMeanStd
+
+
+__all__ = ["NormalizeObservation"]
+
+
+
+[docs]
class NormalizeObservation(VectorObservationWrapper, gym.utils.RecordConstructorArgs):
    """Normalizes observations so that each coordinate is centered with unit variance.

    The property :attr:`update_running_mean` freezes/continues the running-mean
    computation of the observation statistics. When ``True`` (default), the
    ``RunningMeanStd`` is updated on every step and reset call. When ``False``,
    the stored statistics are applied but no longer updated; this may be used
    during evaluation.

    Note:
        The normalization depends on past trajectories, so observations will not be
        normalized correctly if the wrapper was newly instantiated or the policy was
        changed recently.

    Example without the normalize observation wrapper:
        >>> import gymnasium as gym
        >>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
        >>> obs, info = envs.reset(seed=123)
        >>> _ = envs.action_space.seed(123)
        >>> for _ in range(100):
        ...     obs, *_ = envs.step(envs.action_space.sample())
        >>> np.mean(obs)
        np.float32(0.024251968)
        >>> np.std(obs)
        np.float32(0.62259156)
        >>> envs.close()

    Example with the normalize observation wrapper:
        >>> import gymnasium as gym
        >>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
        >>> envs = NormalizeObservation(envs)
        >>> obs, info = envs.reset(seed=123)
        >>> _ = envs.action_space.seed(123)
        >>> for _ in range(100):
        ...     obs, *_ = envs.step(envs.action_space.sample())
        >>> np.mean(obs)
        np.float32(-0.2359734)
        >>> np.std(obs)
        np.float32(1.1938739)
        >>> envs.close()
    """

    def __init__(self, env: VectorEnv, epsilon: float = 1e-8):
        """Normalizes observations so that each coordinate is centered with unit variance.

        Args:
            env (Env): The environment to apply the wrapper
            epsilon: A stability parameter that is used when scaling the observations.
        """
        gym.utils.RecordConstructorArgs.__init__(self, epsilon=epsilon)
        VectorObservationWrapper.__init__(self, env)

        # Only the NextStep autoreset behavior is supported; warn when the
        # wrapped environment does not declare its autoreset mode at all.
        if "autoreset_mode" in self.env.metadata:
            assert self.env.metadata["autoreset_mode"] in {AutoresetMode.NEXT_STEP}
        else:
            warn(
                f"{self} is missing `autoreset_mode` data. Assuming that the vector environment it follows the `NextStep` autoreset api or autoreset is disabled. Read todo for more details."
            )

        self.epsilon = epsilon
        self._update_running_mean = True
        # Running statistics are tracked per-coordinate of a single observation.
        self.obs_rms = RunningMeanStd(
            shape=self.single_observation_space.shape,
            dtype=self.single_observation_space.dtype,
        )

    @property
    def update_running_mean(self) -> bool:
        """Property to freeze/continue the running mean calculation of the observation statistics."""
        return self._update_running_mean

    @update_running_mean.setter
    def update_running_mean(self, setting: bool):
        """Sets the property to freeze/continue the running mean calculation of the observation statistics."""
        self._update_running_mean = setting

    def reset(
        self,
        *,
        seed: int | list[int] | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[ObsType, dict[str, Any]]:
        """Reset function for `NormalizeObservationWrapper` which is disabled for partial resets."""
        # Partial (masked) resets are not supported by this wrapper.
        if options is not None and "reset_mask" in options:
            assert np.all(options["reset_mask"])
        return super().reset(seed=seed, options=options)

    def observations(self, observations: ObsType) -> ObsType:
        """Defines the vector observation normalization function.

        Args:
            observations: A vector observation from the environment

        Returns:
            the normalized observation
        """
        if self._update_running_mean:
            self.obs_rms.update(observations)
        centered = observations - self.obs_rms.mean
        return centered / np.sqrt(self.obs_rms.var + self.epsilon)
+
+
+"""A collection of wrappers for modifying the reward with an internal state.
+
+* ``NormalizeReward`` - Normalizes the rewards to a mean and standard deviation
+"""
+
+from __future__ import annotations
+
+from typing import Any, SupportsFloat
+
+import numpy as np
+
+import gymnasium as gym
+from gymnasium.core import ActType, ObsType
+from gymnasium.vector.vector_env import ArrayType, VectorEnv, VectorWrapper
+from gymnasium.wrappers.utils import RunningMeanStd
+
+
+__all__ = ["NormalizeReward"]
+
+
+
+[docs]
class NormalizeReward(VectorWrapper, gym.utils.RecordConstructorArgs):
    r"""Scales rewards so that their exponential moving average has an approximately fixed variance.

    The property :attr:`update_running_mean` freezes/continues the running-mean
    computation of the reward statistics. When ``True`` (default), the
    ``RunningMeanStd`` is updated each time :meth:`normalize` is called. When
    ``False``, the stored statistics are applied but no longer updated; this may
    be used during evaluation.

    Note:
        The scaling depends on past trajectories, so rewards will not be scaled
        correctly if the wrapper was newly instantiated or the policy was changed
        recently.

    Example without the normalize reward wrapper:
        >>> import gymnasium as gym
        >>> import numpy as np
        >>> envs = gym.make_vec("MountainCarContinuous-v0", 3)
        >>> _ = envs.reset(seed=123)
        >>> _ = envs.action_space.seed(123)
        >>> episode_rewards = []
        >>> for _ in range(100):
        ...     observation, reward, *_ = envs.step(envs.action_space.sample())
        ...     episode_rewards.append(reward)
        ...
        >>> envs.close()
        >>> np.mean(episode_rewards)
        np.float64(-0.03359492141887935)
        >>> np.std(episode_rewards)
        np.float64(0.029028230434438706)

    Example with the normalize reward wrapper:
        >>> import gymnasium as gym
        >>> import numpy as np
        >>> envs = gym.make_vec("MountainCarContinuous-v0", 3)
        >>> envs = NormalizeReward(envs)
        >>> _ = envs.reset(seed=123)
        >>> _ = envs.action_space.seed(123)
        >>> episode_rewards = []
        >>> for _ in range(100):
        ...     observation, reward, *_ = envs.step(envs.action_space.sample())
        ...     episode_rewards.append(reward)
        ...
        >>> envs.close()
        >>> np.mean(episode_rewards)
        np.float64(-0.1598639586606745)
        >>> np.std(episode_rewards)
        np.float64(0.27800309628058434)
    """

    def __init__(
        self,
        env: VectorEnv,
        gamma: float = 0.99,
        epsilon: float = 1e-8,
    ):
        """Normalizes immediate rewards so their exponential moving average has an approximately fixed variance.

        Args:
            env (env): The environment to apply the wrapper
            epsilon (float): A stability parameter
            gamma (float): The discount factor that is used in the exponential moving average.
        """
        gym.utils.RecordConstructorArgs.__init__(self, gamma=gamma, epsilon=epsilon)
        VectorWrapper.__init__(self, env)

        self.gamma = gamma
        self.epsilon = epsilon
        self._update_running_mean = True
        # Scalar running statistics over the discounted return estimate.
        self.return_rms = RunningMeanStd(shape=())
        # Per-environment discounted return accumulator, reset on termination.
        self.accumulated_reward = np.zeros((self.num_envs,), dtype=np.float32)

    @property
    def update_running_mean(self) -> bool:
        """Property to freeze/continue the running mean calculation of the reward statistics."""
        return self._update_running_mean

    @update_running_mean.setter
    def update_running_mean(self, setting: bool):
        """Sets the property to freeze/continue the running mean calculation of the reward statistics."""
        self._update_running_mean = setting

    def step(
        self, actions: ActType
    ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
        """Steps through the environment, normalizing the reward returned."""
        observations, rewards, terminations, truncations, infos = super().step(actions)
        # Decay the accumulator by gamma, zeroing it for terminated sub-envs,
        # then add the newest reward (expression order kept for float stability).
        self.accumulated_reward = (
            self.accumulated_reward * self.gamma * (1 - terminations) + rewards
        )
        return (
            observations,
            self.normalize(rewards),
            terminations,
            truncations,
            infos,
        )

    def normalize(self, reward: SupportsFloat):
        """Normalizes the rewards with the running mean rewards and their variance."""
        if self._update_running_mean:
            self.return_rms.update(self.accumulated_reward)
        return reward / np.sqrt(self.return_rms.var + self.epsilon)
+
+
+"""Vectorizes action wrappers to work for `VectorEnv`."""
+
+from __future__ import annotations
+
+from collections.abc import Callable
+from copy import deepcopy
+from typing import Any
+
+import numpy as np
+
+from gymnasium import Space
+from gymnasium.core import ActType, Env
+from gymnasium.logger import warn
+from gymnasium.vector import VectorActionWrapper, VectorEnv
+from gymnasium.vector.utils import batch_space, concatenate, create_empty_array, iterate
+from gymnasium.wrappers import transform_action
+
+
+
+[docs]
class TransformAction(VectorActionWrapper):
    """Transforms an action via a function provided to the wrapper.

    The function :attr:`func` is applied to every vector action.
    If the transformed actions fall outside the bounds of the ``env``'s action
    space, provide an :attr:`action_space` that specifies the action space for
    the vectorized environment.

    Example - Without action transformation:
        >>> import gymnasium as gym
        >>> envs = gym.make_vec("MountainCarContinuous-v0", num_envs=3)
        >>> _ = envs.action_space.seed(123)
        >>> obs, info = envs.reset(seed=123)
        >>> for _ in range(10):
        ...     obs, rew, term, trunc, info = envs.step(envs.action_space.sample())
        ...
        >>> envs.close()
        >>> obs
        array([[-0.46553135, -0.00142543],
               [-0.498371  , -0.00715587],
               [-0.46515748, -0.00624371]], dtype=float32)

    Example - With action transformation:
        >>> import gymnasium as gym
        >>> from gymnasium.spaces import Box
        >>> def shrink_action(act):
        ...     return act * 0.3
        ...
        >>> envs = gym.make_vec("MountainCarContinuous-v0", num_envs=3)
        >>> new_action_space = Box(low=shrink_action(envs.action_space.low), high=shrink_action(envs.action_space.high))
        >>> envs = TransformAction(env=envs, func=shrink_action, action_space=new_action_space)
        >>> _ = envs.action_space.seed(123)
        >>> obs, info = envs.reset(seed=123)
        >>> for _ in range(10):
        ...     obs, rew, term, trunc, info = envs.step(envs.action_space.sample())
        ...
        >>> envs.close()
        >>> obs
        array([[-0.48468155, -0.00372536],
               [-0.47599354, -0.00545912],
               [-0.46543318, -0.00615723]], dtype=float32)
    """

    def __init__(
        self,
        env: VectorEnv,
        func: Callable[[ActType], Any],
        action_space: Space | None = None,
        single_action_space: Space | None = None,
    ):
        """Constructor for the lambda action wrapper.

        Args:
            env: The vector environment to wrap
            func: A function that will transform an action. If this transformed action is outside the action space of ``env.action_space`` then provide an ``action_space``.
            action_space: The action spaces of the wrapper. If None, then it is computed from ``single_action_space``. If ``single_action_space`` is not provided either, then it is assumed to be the same as ``env.action_space``.
            single_action_space: The action space of the non-vectorized environment. If None, then it is assumed the same as ``env.single_action_space``.
        """
        super().__init__(env)

        if action_space is not None:
            # An explicit batched space was supplied; sanity-check it against
            # the batched single space and warn on mismatch.
            self.action_space = action_space
            if single_action_space is not None:
                self.single_action_space = single_action_space
            # TODO: We could compute single_action_space from the action_space if only the latter is provided and avoid the warning below.
            if self.action_space != batch_space(self.single_action_space, self.num_envs):
                warn(
                    f"For {env}, the action space and the batched single action space don't match as expected, action_space={env.action_space}, batched single_action_space={batch_space(self.single_action_space, self.num_envs)}"
                )
        elif single_action_space is not None:
            # Derive the batched space from the provided single-env space.
            self.single_action_space = single_action_space
            self.action_space = batch_space(single_action_space, self.num_envs)

        self.func = func

    def actions(self, actions: ActType) -> ActType:
        """Applies the :attr:`func` to the actions."""
        return self.func(actions)
+
+
+
+
+[docs]
class VectorizeTransformAction(VectorActionWrapper):
    """Vectorizes a single-agent transform action wrapper for vector environments.

    Example - Without action transformation:
        >>> import gymnasium as gym
        >>> envs = gym.make_vec("MountainCarContinuous-v0", num_envs=3)
        >>> _ = envs.action_space.seed(123)
        >>> obs, info = envs.reset(seed=123)
        >>> obs, rew, term, trunc, info = envs.step(envs.action_space.sample())
        >>> envs.close()
        >>> obs
        array([[-4.6343064e-01,  9.8971417e-05],
               [-4.4488689e-01, -1.9375233e-03],
               [-4.3118435e-01, -1.5342437e-03]], dtype=float32)

    Example - Adding a transform that applies a ReLU to the action:
        >>> import gymnasium as gym
        >>> from gymnasium.wrappers import TransformAction
        >>> envs = gym.make_vec("MountainCarContinuous-v0", num_envs=3)
        >>> envs = VectorizeTransformAction(envs, wrapper=TransformAction, func=lambda x: (x > 0.0) * x, action_space=envs.single_action_space)
        >>> _ = envs.action_space.seed(123)
        >>> obs, info = envs.reset(seed=123)
        >>> obs, rew, term, trunc, info = envs.step(envs.action_space.sample())
        >>> envs.close()
        >>> obs
        array([[-4.6343064e-01,  9.8971417e-05],
               [-4.4354835e-01, -5.9898634e-04],
               [-4.3034542e-01, -6.9532328e-04]], dtype=float32)
    """

    class _SingleEnv(Env):
        """Fake single-agent environment used for the single-agent wrapper."""

        def __init__(self, action_space: Space):
            """Stores only the action space; nothing else is needed by the wrapped wrapper."""
            self.action_space = action_space

    def __init__(
        self,
        env: VectorEnv,
        wrapper: type[transform_action.TransformAction],
        **kwargs: Any,
    ):
        """Constructor for the vectorized lambda action wrapper.

        Args:
            env: The vector environment to wrap
            wrapper: The wrapper to vectorize
            **kwargs: Arguments for the LambdaAction wrapper
        """
        super().__init__(env)

        # Instantiate the single-agent wrapper around a stub env so we can
        # reuse its ``func`` and derived action space.
        self.wrapper = wrapper(self._SingleEnv(self.env.single_action_space), **kwargs)
        self.single_action_space = self.wrapper.action_space
        self.action_space = batch_space(self.single_action_space, self.num_envs)

        # When the transformed batched space equals the env's, results can be
        # written straight into the incoming actions buffer.
        self.same_out = self.action_space == self.env.action_space
        self.out = create_empty_array(self.env.single_action_space, self.num_envs)

    def actions(self, actions: ActType) -> ActType:
        """Applies the wrapper to each of the action.

        Args:
            actions: The actions to apply the function to

        Returns:
            The updated actions using the wrapper func
        """
        # The per-env transforms are fully materialized before concatenation.
        transformed = tuple(
            self.wrapper.func(single_action)
            for single_action in iterate(self.action_space, actions)
        )
        if self.same_out:
            return concatenate(self.env.single_action_space, transformed, actions)
        return deepcopy(
            concatenate(self.env.single_action_space, transformed, self.out)
        )
+
+
+
+
+[docs]
class ClipAction(VectorizeTransformAction):
    """Clips continuous actions to the valid :class:`Box` bounds of the action space.

    Example - Passing an out-of-bounds action to the environment to be clipped.
        >>> import numpy as np
        >>> import gymnasium as gym
        >>> envs = gym.make_vec("MountainCarContinuous-v0", num_envs=3)
        >>> envs = ClipAction(envs)
        >>> _ = envs.action_space.seed(123)
        >>> obs, info = envs.reset(seed=123)
        >>> obs, rew, term, trunc, info = envs.step(np.array([5.0, -5.0, 2.0]))
        >>> envs.close()
        >>> obs
        array([[-0.4624777 ,  0.00105192],
               [-0.44504836, -0.00209899],
               [-0.42884544,  0.00080468]], dtype=float32)
    """

    def __init__(self, env: VectorEnv):
        """Vectorizes the single-agent :class:`ClipAction` wrapper.

        Args:
            env: The vector environment to wrap
        """
        super().__init__(env, transform_action.ClipAction)
+
+
+
+
+[docs]
class RescaleAction(VectorizeTransformAction):
    """Affinely rescales the continuous action space of the environment to the range [min_action, max_action].

    Example - Without action scaling:
        >>> import numpy as np
        >>> import gymnasium as gym
        >>> envs = gym.make_vec("MountainCarContinuous-v0", num_envs=3)
        >>> _ = envs.action_space.seed(123)
        >>> obs, info = envs.reset(seed=123)
        >>> for _ in range(10):
        ...     obs, rew, term, trunc, info = envs.step(0.5 * np.ones((3, 1)))
        ...
        >>> envs.close()
        >>> obs
        array([[-0.44799727,  0.00266526],
               [-0.4351738 ,  0.00133522],
               [-0.42683297,  0.00048403]], dtype=float32)

    Example - With action scaling:
        >>> import numpy as np
        >>> import gymnasium as gym
        >>> envs = gym.make_vec("MountainCarContinuous-v0", num_envs=3)
        >>> envs = RescaleAction(envs, 0.0, 1.0)
        >>> _ = envs.action_space.seed(123)
        >>> obs, info = envs.reset(seed=123)
        >>> for _ in range(10):
        ...     obs, rew, term, trunc, info = envs.step(0.5 * np.ones((3, 1)))
        ...
        >>> envs.close()
        >>> obs
        array([[-0.48657528, -0.00395268],
               [-0.47377947, -0.00529102],
               [-0.46546045, -0.00614867]], dtype=float32)
    """

    def __init__(
        self,
        env: VectorEnv,
        min_action: float | int | np.ndarray,
        max_action: float | int | np.ndarray,
    ):
        """Vectorizes the single-agent :class:`RescaleAction` wrapper.

        Args:
            env (Env): The vector environment to wrap
            min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar.
            max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar.
        """
        super().__init__(
            env,
            transform_action.RescaleAction,
            min_action=min_action,
            max_action=max_action,
        )
+
+
+"""Vectorizes observation wrappers to works for `VectorEnv`."""
+
+from __future__ import annotations
+
+from collections.abc import Callable, Sequence
+from copy import deepcopy
+from typing import Any
+
+import numpy as np
+
+from gymnasium import Space
+from gymnasium.core import ActType, Env, ObsType
+from gymnasium.logger import warn
+from gymnasium.vector import VectorEnv, VectorObservationWrapper
+from gymnasium.vector.utils import batch_space, concatenate, create_empty_array, iterate
+from gymnasium.vector.vector_env import ArrayType, AutoresetMode
+from gymnasium.wrappers import transform_observation
+
+
+
+[docs]
class TransformObservation(VectorObservationWrapper):
    """Transforms an observation via a function provided to the wrapper.

    This function allows the manual specification of the vector-observation function as well as the single-observation function.
    This is desirable when, for example, it is possible to process vector observations in parallel or via other more optimized methods.
    Otherwise, the ``VectorizeTransformObservation`` should be used instead, where only ``single_func`` needs to be defined.

    Example - Without observation transformation:
        >>> import gymnasium as gym
        >>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
        >>> obs, info = envs.reset(seed=123)
        >>> obs
        array([[ 0.01823519, -0.0446179 , -0.02796401, -0.03156282],
               [ 0.02852531,  0.02858594,  0.0469136 ,  0.02480598],
               [ 0.03517495, -0.000635  , -0.01098382, -0.03203924]],
              dtype=float32)
        >>> envs.close()

    Example - With observation transformation:
        >>> import gymnasium as gym
        >>> from gymnasium.spaces import Box
        >>> def scale_and_shift(obs):
        ...     return (obs - 1.0) * 2.0
        ...
        >>> import gymnasium as gym
        >>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
        >>> new_obs_space = Box(low=envs.observation_space.low, high=envs.observation_space.high)
        >>> envs = TransformObservation(envs, func=scale_and_shift, observation_space=new_obs_space)
        >>> obs, info = envs.reset(seed=123)
        >>> obs
        array([[-1.9635296, -2.0892358, -2.055928 , -2.0631256],
               [-1.9429494, -1.9428282, -1.9061728, -1.9503881],
               [-1.9296501, -2.00127  , -2.0219676, -2.0640786]], dtype=float32)
        >>> envs.close()
    """

    def __init__(
        self,
        env: VectorEnv,
        func: Callable[[ObsType], Any],
        observation_space: Space | None = None,
        single_observation_space: Space | None = None,
    ):
        """Constructor for the transform observation wrapper.

        Args:
            env: The vector environment to wrap
            func: A function that will transform the vector observation. If this transformed observation is outside the observation space of ``env.observation_space`` then provide an ``observation_space``.
            observation_space: The observation spaces of the wrapper. If None, then it is computed from ``single_observation_space``. If ``single_observation_space`` is not provided either, then it is assumed to be the same as ``env.observation_space``.
            single_observation_space: The observation space of the non-vectorized environment. If None, then it is assumed the same as ``env.single_observation_space``.
        """
        super().__init__(env)

        if observation_space is None:
            if single_observation_space is not None:
                self.single_observation_space = single_observation_space
                self.observation_space = batch_space(
                    single_observation_space, self.num_envs
                )
        else:
            self.observation_space = observation_space
            if single_observation_space is not None:
                # Use the public attribute (not the private backing field) for
                # consistency with the ``if`` branch above and with the sibling
                # ``TransformAction`` wrapper.
                self.single_observation_space = single_observation_space
            # TODO: We could compute single_observation_space from the observation_space if only the latter is provided and avoid the warning below.
            if self.observation_space != batch_space(
                self.single_observation_space, self.num_envs
            ):
                warn(
                    f"For {env}, the observation space and the batched single observation space don't match as expected, observation_space={env.observation_space}, batched single_observation_space={batch_space(self.single_observation_space, self.num_envs)}"
                )

        self.func = func

    def observations(self, observations: ObsType) -> ObsType:
        """Apply function to the vector observation."""
        return self.func(observations)
+
+
+
+
+[docs]
class VectorizeTransformObservation(VectorObservationWrapper):
    """Vectorizes a single-agent transform observation wrapper for vector environments.

    Most of the lambda observation wrappers for single agent environments have vectorized implementations,
    it is advised that users simply use those instead via importing from `gymnasium.wrappers.vector...`.
    The following example illustrate use-cases where a custom lambda observation wrapper is required.

    Example - The normal observation:
        >>> import gymnasium as gym
        >>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
        >>> obs, info = envs.reset(seed=123)
        >>> envs.close()
        >>> obs
        array([[ 0.01823519, -0.0446179 , -0.02796401, -0.03156282],
               [ 0.02852531,  0.02858594,  0.0469136 ,  0.02480598],
               [ 0.03517495, -0.000635  , -0.01098382, -0.03203924]],
              dtype=float32)

    Example - Applying a custom lambda observation wrapper that duplicates the observation from the environment
        >>> import numpy as np
        >>> import gymnasium as gym
        >>> from gymnasium.spaces import Box
        >>> from gymnasium.wrappers import TransformObservation
        >>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
        >>> old_space = envs.single_observation_space
        >>> new_space = Box(low=np.array([old_space.low, old_space.low]), high=np.array([old_space.high, old_space.high]))
        >>> envs = VectorizeTransformObservation(envs, wrapper=TransformObservation, func=lambda x: np.array([x, x]), observation_space=new_space)
        >>> obs, info = envs.reset(seed=123)
        >>> envs.close()
        >>> obs
        array([[[ 0.01823519, -0.0446179 , -0.02796401, -0.03156282],
                [ 0.01823519, -0.0446179 , -0.02796401, -0.03156282]],
        <BLANKLINE>
               [[ 0.02852531,  0.02858594,  0.0469136 ,  0.02480598],
                [ 0.02852531,  0.02858594,  0.0469136 ,  0.02480598]],
        <BLANKLINE>
               [[ 0.03517495, -0.000635  , -0.01098382, -0.03203924],
                [ 0.03517495, -0.000635  , -0.01098382, -0.03203924]]],
              dtype=float32)
    """

    class _SingleEnv(Env):
        """Fake single-agent environment used for the single-agent wrapper."""

        def __init__(self, observation_space: Space):
            """Constructor for the fake environment; only the observation space is needed."""
            self.observation_space = observation_space

    def __init__(
        self,
        env: VectorEnv,
        wrapper: type[transform_observation.TransformObservation],
        **kwargs: Any,
    ):
        """Constructor for the vectorized transform observation wrapper.

        Args:
            env: The vector environment to wrap.
            wrapper: The wrapper to vectorize
            **kwargs: Keyword argument for the wrapper
        """
        super().__init__(env)

        # The autoreset mode determines whether `step` must also transform
        # `final_obs` entries in the info dict (SAME_STEP mode only).
        if "autoreset_mode" not in env.metadata:
            warn(
                f"Vector environment ({env}) is missing `autoreset_mode` metadata key."
            )
            # Fall back to NEXT_STEP, where no `final_obs` handling is required.
            self.autoreset_mode = AutoresetMode.NEXT_STEP
        else:
            assert isinstance(env.metadata["autoreset_mode"], AutoresetMode)
            self.autoreset_mode = env.metadata["autoreset_mode"]

        # Instantiate the single-agent wrapper around a stub env so that its
        # `observation` function and derived observation space can be reused.
        self.wrapper = wrapper(
            self._SingleEnv(self.env.single_observation_space), **kwargs
        )
        self.single_observation_space = self.wrapper.observation_space
        self.observation_space = batch_space(
            self.single_observation_space, self.num_envs
        )

        # When the transformed batched space equals the env's, results can be
        # written straight into the incoming observation buffer; otherwise a
        # separate output buffer is pre-allocated and deep-copied per call.
        self.same_out = self.observation_space == self.env.observation_space
        self.out = create_empty_array(self.single_observation_space, self.num_envs)

    def step(
        self, actions: ActType
    ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
        """Steps through the vector environments, transforming the observation and for final obs individually transformed."""
        obs, rewards, terminations, truncations, infos = self.env.step(actions)
        obs = self.observations(obs)

        # In SAME_STEP autoreset mode, terminated sub-envs report their last
        # observation via `infos["final_obs"]`; transform those individually
        # with the single-agent wrapper (mask `_final_obs` marks valid entries).
        if self.autoreset_mode == AutoresetMode.SAME_STEP and "final_obs" in infos:
            final_obs = infos["final_obs"]

            for i, (sub_obs, has_final_obs) in enumerate(
                zip(final_obs, infos["_final_obs"])
            ):
                if has_final_obs:
                    final_obs[i] = self.wrapper.observation(sub_obs)

        return obs, rewards, terminations, truncations, infos

    def observations(self, observations: ObsType) -> ObsType:
        """Iterates over the vector observations applying the single-agent wrapper ``observation`` then concatenates the observations together again."""
        if self.same_out:
            # Safe to reuse `observations` as the output buffer: the transformed
            # tuple is fully materialized before `concatenate` writes into it.
            return concatenate(
                self.single_observation_space,
                tuple(
                    self.wrapper.func(obs)
                    for obs in iterate(self.observation_space, observations)
                ),
                observations,
            )
        else:
            # Incoming observations live in the env's (different) space, so
            # iterate that space and deep-copy out of the shared buffer.
            return deepcopy(
                concatenate(
                    self.single_observation_space,
                    tuple(
                        self.wrapper.func(obs)
                        for obs in iterate(self.env.observation_space, observations)
                    ),
                    self.out,
                )
            )
+
+
+
+
+[docs]
class FilterObservation(VectorizeTransformObservation):
    """Vector wrapper for filtering dict or tuple observation spaces.

    Example - Create a vectorized environment with a Dict space to demonstrate how to filter keys:
        >>> import numpy as np
        >>> import gymnasium as gym
        >>> from gymnasium.spaces import Dict, Box
        >>> from gymnasium.wrappers import TransformObservation
        >>> from gymnasium.wrappers.vector import VectorizeTransformObservation, FilterObservation
        >>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
        >>> make_dict = lambda x: {"obs": x, "junk": np.array([0.0])}
        >>> new_space = Dict({"obs": envs.single_observation_space, "junk": Box(low=-1.0, high=1.0)})
        >>> envs = VectorizeTransformObservation(env=envs, wrapper=TransformObservation, func=make_dict, observation_space=new_space)
        >>> envs = FilterObservation(envs, ["obs"])
        >>> obs, info = envs.reset(seed=123)
        >>> envs.close()
        >>> obs
        {'obs': array([[ 0.01823519, -0.0446179 , -0.02796401, -0.03156282],
               [ 0.02852531,  0.02858594,  0.0469136 ,  0.02480598],
               [ 0.03517495, -0.000635  , -0.01098382, -0.03203924]],
              dtype=float32)}
    """

    def __init__(self, env: VectorEnv, filter_keys: Sequence[str | int]):
        """Vectorizes the single-agent :class:`FilterObservation` wrapper.

        Args:
            env: The vector environment to wrap
            filter_keys: The subspaces to keep; use strings for ``Dict`` spaces and integers for ``Tuple`` spaces, respectively
        """
        super().__init__(
            env, transform_observation.FilterObservation, filter_keys=filter_keys
        )
+
+
+
+
+[docs]
class FlattenObservation(VectorizeTransformObservation):
    """Observation wrapper that flattens the observation.

    Example:
        >>> import gymnasium as gym
        >>> envs = gym.make_vec("CarRacing-v3", num_envs=3, vectorization_mode="sync")
        >>> obs, info = envs.reset(seed=123)
        >>> obs.shape
        (3, 96, 96, 3)
        >>> envs = FlattenObservation(envs)
        >>> obs, info = envs.reset(seed=123)
        >>> obs.shape
        (3, 27648)
        >>> envs.close()
    """

    def __init__(self, env: VectorEnv):
        """Constructor for any environment's observation space that implements ``spaces.utils.flatten_space`` and ``spaces.utils.flatten``.

        Args:
            env: The vector environment to wrap
        """
        # Thin delegation: vectorizes the single-agent FlattenObservation wrapper.
        super().__init__(env, transform_observation.FlattenObservation)
+
+
+
+
+[docs]
class GrayscaleObservation(VectorizeTransformObservation):
    """Observation wrapper that converts an RGB image to grayscale.

    Example:
        >>> import gymnasium as gym
        >>> envs = gym.make_vec("CarRacing-v3", num_envs=3, vectorization_mode="sync")
        >>> obs, info = envs.reset(seed=123)
        >>> obs.shape
        (3, 96, 96, 3)
        >>> envs = GrayscaleObservation(envs)
        >>> obs, info = envs.reset(seed=123)
        >>> obs.shape
        (3, 96, 96)
        >>> envs.close()
    """

    def __init__(self, env: VectorEnv, keep_dim: bool = False):
        """Constructor for an RGB image based environments to make the image grayscale.

        Args:
            env: The vector environment to wrap
            keep_dim: Whether to keep the channel dimension in the observation; if ``True`` each single observation has 3 dimensions, else 2 (as in the example above)
        """
        # Thin delegation: vectorizes the single-agent GrayscaleObservation wrapper.
        super().__init__(
            env, transform_observation.GrayscaleObservation, keep_dim=keep_dim
        )
+
+
+
+
+[docs]
class ResizeObservation(VectorizeTransformObservation):
    """Resizes image observations using OpenCV to shape.

    Example:
        >>> import gymnasium as gym
        >>> envs = gym.make_vec("CarRacing-v3", num_envs=3, vectorization_mode="sync")
        >>> obs, info = envs.reset(seed=123)
        >>> obs.shape
        (3, 96, 96, 3)
        >>> envs = ResizeObservation(envs, shape=(28, 28))
        >>> obs, info = envs.reset(seed=123)
        >>> obs.shape
        (3, 28, 28, 3)
        >>> envs.close()
    """

    def __init__(self, env: VectorEnv, shape: tuple[int, ...]):
        """Constructor that requires an image environment observation space with a shape.

        Args:
            env: The vector environment to wrap
            shape: The resized observation shape (height, width), applied per sub-environment
        """
        # Thin delegation: vectorizes the single-agent ResizeObservation wrapper.
        super().__init__(env, transform_observation.ResizeObservation, shape=shape)
+
+
+
+
+[docs]
class ReshapeObservation(VectorizeTransformObservation):
    """Reshapes array based observations to shapes.

    Example:
        >>> import gymnasium as gym
        >>> envs = gym.make_vec("CarRacing-v3", num_envs=3, vectorization_mode="sync")
        >>> obs, info = envs.reset(seed=123)
        >>> obs.shape
        (3, 96, 96, 3)
        >>> envs = ReshapeObservation(envs, shape=(9216, 3))
        >>> obs, info = envs.reset(seed=123)
        >>> obs.shape
        (3, 9216, 3)
        >>> envs.close()
    """

    def __init__(self, env: VectorEnv, shape: int | tuple[int, ...]):
        """Constructor for env with Box observation space that has a shape product equal to the new shape product.

        Args:
            env: The vector environment to wrap
            shape: The new per-environment observation shape (its element count must match the original shape's)
        """
        # Thin delegation: vectorizes the single-agent ReshapeObservation wrapper.
        super().__init__(env, transform_observation.ReshapeObservation, shape=shape)
+
+
+
+
+[docs]
class RescaleObservation(VectorizeTransformObservation):
    """Linearly rescales observation to between a minimum and maximum value.

    Example:
        >>> import gymnasium as gym
        >>> envs = gym.make_vec("MountainCar-v0", num_envs=3, vectorization_mode="sync")
        >>> obs, info = envs.reset(seed=123)
        >>> obs.min()
        np.float32(-0.46352962)
        >>> obs.max()
        np.float32(0.0)
        >>> envs = RescaleObservation(envs, min_obs=-5.0, max_obs=5.0)
        >>> obs, info = envs.reset(seed=123)
        >>> obs.min()
        np.float32(-0.90849805)
        >>> obs.max()
        np.float32(0.0)
        >>> envs.close()
    """

    def __init__(
        self,
        env: VectorEnv,
        min_obs: np.floating | np.integer | np.ndarray,
        max_obs: np.floating | np.integer | np.ndarray,
    ):
        """Constructor that requires the env observation spaces to be a :class:`Box`.

        Args:
            env: The vector environment to wrap
            min_obs: The new minimum observation bound
            max_obs: The new maximum observation bound
        """
        # Thin delegation: vectorizes the single-agent RescaleObservation wrapper.
        super().__init__(
            env,
            transform_observation.RescaleObservation,
            min_obs=min_obs,
            max_obs=max_obs,
        )
+
+
+
+
+[docs]
class DtypeObservation(VectorizeTransformObservation):
    """Observation wrapper for transforming the dtype of an observation.

    Example:
        >>> import numpy as np
        >>> import gymnasium as gym
        >>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
        >>> obs, info = envs.reset(seed=123)
        >>> obs.dtype
        dtype('float32')
        >>> envs = DtypeObservation(envs, dtype=np.float64)
        >>> obs, info = envs.reset(seed=123)
        >>> obs.dtype
        dtype('float64')
        >>> envs.close()
    """

    def __init__(self, env: VectorEnv, dtype: Any):
        """Constructor for Dtype observation wrapper.

        Args:
            env: The vector environment to wrap
            dtype: The new dtype of the observation
        """
        # Thin delegation: vectorizes the single-agent DtypeObservation wrapper.
        super().__init__(env, transform_observation.DtypeObservation, dtype=dtype)
+
+
+"""Vectorizes reward function to work with `VectorEnv`."""
+
+from __future__ import annotations
+
+from collections.abc import Callable
+from typing import Any
+
+import numpy as np
+
+from gymnasium import Env
+from gymnasium.vector import VectorEnv, VectorRewardWrapper
+from gymnasium.vector.vector_env import ArrayType
+from gymnasium.wrappers import transform_reward
+
+
+
+[docs]
class TransformReward(VectorRewardWrapper):
    """A reward wrapper that allows a custom function to modify the step reward.

    Example with reward transformation:
        >>> import gymnasium as gym
        >>> from gymnasium.spaces import Box
        >>> def scale_and_shift(rew):
        ...     return (rew - 1.0) * 2.0
        ...
        >>> envs = gym.make_vec("MountainCarContinuous-v0", num_envs=3)
        >>> envs = TransformReward(env=envs, func=scale_and_shift)
        >>> _ = envs.action_space.seed(123)
        >>> obs, info = envs.reset(seed=123)
        >>> obs, rew, term, trunc, info = envs.step(envs.action_space.sample())
        >>> envs.close()
        >>> obs
        array([[-4.6343064e-01,  9.8971417e-05],
               [-4.4488689e-01, -1.9375233e-03],
               [-4.3118435e-01, -1.5342437e-03]], dtype=float32)
    """

    def __init__(self, env: VectorEnv, func: Callable[[ArrayType], ArrayType]):
        """Initialize TransformReward wrapper.

        Args:
            env (Env): The vector environment to wrap
            func (Callable): The function to apply to reward; it receives the
                batched reward array for all sub-environments at once
        """
        super().__init__(env)

        self.func = func

    def rewards(self, reward: ArrayType) -> ArrayType:
        """Apply function to reward."""
        return self.func(reward)
+
+
+
+
+[docs]
class VectorizeTransformReward(VectorRewardWrapper):
    """Vectorizes a single-agent transform reward wrapper for vector environments.

    An example such that applies a ReLU to the reward:
        >>> import gymnasium as gym
        >>> from gymnasium.wrappers import TransformReward
        >>> envs = gym.make_vec("MountainCarContinuous-v0", num_envs=3)
        >>> envs = VectorizeTransformReward(envs, wrapper=TransformReward, func=lambda x: (x > 0.0) * x)
        >>> _ = envs.action_space.seed(123)
        >>> obs, info = envs.reset(seed=123)
        >>> obs, rew, term, trunc, info = envs.step(envs.action_space.sample())
        >>> envs.close()
        >>> rew
        array([-0., -0., -0.])
    """

    def __init__(
        self,
        env: VectorEnv,
        wrapper: type[transform_reward.TransformReward],
        **kwargs: Any,
    ):
        """Constructor for the vectorized lambda reward wrapper.

        Args:
            env: The vector environment to wrap.
            wrapper: The wrapper to vectorize
            **kwargs: Keyword argument for the wrapper
        """
        super().__init__(env)

        # Instantiate the single-agent wrapper around a dummy ``Env`` purely to
        # capture its configured reward function (``self.wrapper.func``); the
        # dummy env itself is never stepped.
        self.wrapper = wrapper(Env(), **kwargs)

    def rewards(self, reward: ArrayType) -> ArrayType:
        """Iterates over the reward updating each with the wrapper func."""
        # NOTE: mutates ``reward`` in place, applying the single-agent function
        # element-wise, then returns the same array.
        for i, r in enumerate(reward):
            reward[i] = self.wrapper.func(r)
        return reward
+
+
+
+
+[docs]
class ClipReward(VectorizeTransformReward):
    """A wrapper that clips the rewards for an environment between an upper and lower bound.

    Example with clipped rewards:
        >>> import numpy as np
        >>> import gymnasium as gym
        >>> envs = gym.make_vec("MountainCarContinuous-v0", num_envs=3)
        >>> envs = ClipReward(envs, 0.0, 2.0)
        >>> _ = envs.action_space.seed(123)
        >>> obs, info = envs.reset(seed=123)
        >>> for _ in range(10):
        ...     obs, rew, term, trunc, info = envs.step(0.5 * np.ones((3, 1)))
        ...
        >>> envs.close()
        >>> rew
        array([0., 0., 0.])
    """

    def __init__(
        self,
        env: VectorEnv,
        min_reward: float | np.ndarray | None = None,
        max_reward: float | np.ndarray | None = None,
    ):
        """Constructor for ClipReward wrapper.

        Args:
            env: The vector environment to wrap
            min_reward: The minimum reward for each step (``None`` leaves the lower bound unclipped)
            max_reward: The maximum reward for each step (``None`` leaves the upper bound unclipped)
        """
        # Thin delegation: vectorizes the single-agent ClipReward wrapper.
        super().__init__(
            env,
            transform_reward.ClipReward,
            min_reward=min_reward,
            max_reward=max_reward,
        )
+
+
' + + '' + + _("Hide Search Matches") + + "
" + ) + ); + }, + + /** + * helper function to hide the search marks again + */ + hideSearchWords: () => { + document + .querySelectorAll("#searchbox .highlight-link") + .forEach((el) => el.remove()); + document + .querySelectorAll("span.highlighted") + .forEach((el) => el.classList.remove("highlighted")); + localStorage.removeItem("sphinx_highlight_terms") + }, + + initEscapeListener: () => { + // only install a listener if it is really needed + if (!DOCUMENTATION_OPTIONS.ENABLE_SEARCH_SHORTCUTS) return; + + document.addEventListener("keydown", (event) => { + // bail for input elements + if (BLACKLISTED_KEY_CONTROL_ELEMENTS.has(document.activeElement.tagName)) return; + // bail with special keys + if (event.shiftKey || event.altKey || event.ctrlKey || event.metaKey) return; + if (DOCUMENTATION_OPTIONS.ENABLE_SEARCH_SHORTCUTS && (event.key === "Escape")) { + SphinxHighlight.hideSearchWords(); + event.preventDefault(); + } + }); + }, +}; + +_ready(() => { + /* Do not call highlightSearchWords() when we are on the search page. + * It will highlight words from the *previous* search query. 
+ */ + if (typeof Search === "undefined") SphinxHighlight.highlightSearchWords(); + SphinxHighlight.initEscapeListener(); +}); diff --git a/v1.2.0/_static/styles/furo-extensions.css b/v1.2.0/_static/styles/furo-extensions.css new file mode 100644 index 000000000..ee651ce3d --- /dev/null +++ b/v1.2.0/_static/styles/furo-extensions.css @@ -0,0 +1,2 @@ +#furo-sidebar-ad-placement{padding:var(--sidebar-item-spacing-vertical) var(--sidebar-item-spacing-horizontal)}#furo-sidebar-ad-placement .ethical-sidebar{background:var(--color-background-secondary);border:none;box-shadow:none}#furo-sidebar-ad-placement .ethical-sidebar:hover{background:var(--color-background-hover)}#furo-sidebar-ad-placement .ethical-sidebar a{color:var(--color-foreground-primary)}#furo-sidebar-ad-placement .ethical-callout a{color:var(--color-foreground-secondary)!important}#furo-readthedocs-versions{background:transparent;display:block;position:static;width:100%}#furo-readthedocs-versions .rst-versions{background:#1a1c1e}#furo-readthedocs-versions .rst-current-version{background:var(--color-sidebar-item-background);cursor:unset}#furo-readthedocs-versions .rst-current-version:hover{background:var(--color-sidebar-item-background)}#furo-readthedocs-versions .rst-current-version .fa-book{color:var(--color-foreground-primary)}#furo-readthedocs-versions>.rst-other-versions{padding:0}#furo-readthedocs-versions>.rst-other-versions small{opacity:1}#furo-readthedocs-versions .injected .rst-versions{position:unset}#furo-readthedocs-versions:focus-within,#furo-readthedocs-versions:hover{box-shadow:0 0 0 1px var(--color-sidebar-background-border)}#furo-readthedocs-versions:focus-within .rst-current-version,#furo-readthedocs-versions:hover .rst-current-version{background:#1a1c1e;font-size:inherit;height:auto;line-height:inherit;padding:12px;text-align:right}#furo-readthedocs-versions:focus-within .rst-current-version .fa-book,#furo-readthedocs-versions:hover .rst-current-version 
.fa-book{color:#fff;float:left}#furo-readthedocs-versions:focus-within .fa-caret-down,#furo-readthedocs-versions:hover .fa-caret-down{display:none}#furo-readthedocs-versions:focus-within .injected,#furo-readthedocs-versions:focus-within .rst-current-version,#furo-readthedocs-versions:focus-within .rst-other-versions,#furo-readthedocs-versions:hover .injected,#furo-readthedocs-versions:hover .rst-current-version,#furo-readthedocs-versions:hover .rst-other-versions{display:block}#furo-readthedocs-versions:focus-within>.rst-current-version,#furo-readthedocs-versions:hover>.rst-current-version{display:none}.highlight:hover button.copybtn{color:var(--color-code-foreground)}.highlight button.copybtn{align-items:center;background-color:var(--color-code-background);border:none;color:var(--color-background-item);cursor:pointer;height:1.25em;opacity:1;right:.5rem;top:.625rem;transition:color .3s,opacity .3s;width:1.25em}.highlight button.copybtn:hover{background-color:var(--color-code-background);color:var(--color-brand-content)}.highlight button.copybtn:after{background-color:transparent;color:var(--color-code-foreground);display:none}.highlight button.copybtn.success{color:#22863a;transition:color 0ms}.highlight button.copybtn.success:after{display:block}.highlight button.copybtn 
svg{padding:0}body{--sd-color-primary:var(--color-brand-primary);--sd-color-primary-highlight:var(--color-brand-content);--sd-color-primary-text:var(--color-background-primary);--sd-color-shadow:rgba(0,0,0,.05);--sd-color-card-border:var(--color-card-border);--sd-color-card-border-hover:var(--color-brand-content);--sd-color-card-background:var(--color-card-background);--sd-color-card-text:var(--color-foreground-primary);--sd-color-card-header:var(--color-card-marginals-background);--sd-color-card-footer:var(--color-card-marginals-background);--sd-color-tabs-label-active:var(--color-brand-content);--sd-color-tabs-label-hover:var(--color-foreground-muted);--sd-color-tabs-label-inactive:var(--color-foreground-muted);--sd-color-tabs-underline-active:var(--color-brand-content);--sd-color-tabs-underline-hover:var(--color-foreground-border);--sd-color-tabs-underline-inactive:var(--color-background-border);--sd-color-tabs-overline:var(--color-background-border);--sd-color-tabs-underline:var(--color-background-border)}.sd-tab-content{box-shadow:0 -2px var(--sd-color-tabs-overline),0 1px var(--sd-color-tabs-underline)}.sd-card{box-shadow:0 .1rem .25rem var(--sd-color-shadow),0 0 .0625rem rgba(0,0,0,.1)}.sd-shadow-sm{box-shadow:0 .1rem .25rem var(--sd-color-shadow),0 0 .0625rem rgba(0,0,0,.1)!important}.sd-shadow-md{box-shadow:0 .3rem .75rem var(--sd-color-shadow),0 0 .0625rem rgba(0,0,0,.1)!important}.sd-shadow-lg{box-shadow:0 .6rem 1.5rem var(--sd-color-shadow),0 0 .0625rem 
rgba(0,0,0,.1)!important}.sd-card-hover:hover{transform:none}.sd-cards-carousel{gap:.25rem;padding:.25rem}body{--tabs--label-text:var(--color-foreground-muted);--tabs--label-text--hover:var(--color-foreground-muted);--tabs--label-text--active:var(--color-brand-content);--tabs--label-text--active--hover:var(--color-brand-content);--tabs--label-background:transparent;--tabs--label-background--hover:transparent;--tabs--label-background--active:transparent;--tabs--label-background--active--hover:transparent;--tabs--padding-x:0.25em;--tabs--margin-x:1em;--tabs--border:var(--color-background-border);--tabs--label-border:transparent;--tabs--label-border--hover:var(--color-foreground-muted);--tabs--label-border--active:var(--color-brand-content);--tabs--label-border--active--hover:var(--color-brand-content)}[role=main] .container{max-width:none;padding-left:0;padding-right:0}.shadow.docutils{border:none;box-shadow:0 .2rem .5rem rgba(0,0,0,.05),0 0 .0625rem rgba(0,0,0,.1)!important}.sphinx-bs .card{background-color:var(--color-background-secondary);color:var(--color-foreground)}h1{font-size:2.2rem}h2{font-size:1.7rem}h3{font-size:1.4rem}html:has(.farama-header-menu.active){visibility:hidden}.farama-hidden[aria-hidden=true]{visibility:hidden}.farama-hidden[aria-hidden=false]{visibility:visible}.cookie-alert{background-color:var(--color-background-secondary);border-top:1px solid var(--color-background-border);bottom:0;color:var(--color-foreground-primary);display:flex;left:0;min-height:70px;position:fixed;width:100%;z-index:99999}.cookie-alert__container{align-items:center;display:flex;margin:auto;max-width:calc(100% - 28px);width:700px}.cookie-alert__button{margin-left:14px}.cookie-alert p{flex:1}.farama-btn{background:var(--color-farama-button-background);border:none;border-radius:6px;cursor:pointer;padding:10px 26px;transition:background-color .2s 
ease}.farama-btn:hover{background:var(--color-farama-button-background-hover)}article[role=main]:has(.farama-env-icon-container) .farama-env-icon-container{display:flex;margin-top:7px;position:absolute}article[role=main]:has(.farama-env-icon-container) .section h1:first-child,article[role=main]:has(.farama-env-icon-container) .section h2:first-child,article[role=main]:has(.farama-env-icon-container) section h1:first-child,article[role=main]:has(.farama-env-icon-container) section h2:first-child{margin-left:34px}.farama-env-icon{height:32px}.env-grid{box-sizing:border-box;display:flex;flex-wrap:wrap;justify-content:center;width:100%}.env-grid__cell{display:flex;flex-direction:column;height:180px;padding:10px;width:180px}.cell__image-container{display:flex;height:148px;justify-content:center}.cell__image-container img{max-height:100%;-o-object-fit:contain;object-fit:contain}.cell__title{align-items:flex-end;display:flex;height:32px;justify-content:center;line-height:16px;text-align:center}.more-btn{display:block;margin:12px auto;width:240px}html:has(.farama-header-menu.active){overflow:hidden}body{--farama-header-height:52px;--farama-header-logo-margin:10px;--farama-sidebar-logo-margin:2px 10px}.farama-header{background-color:var(--color-background-secondary);border-bottom:1px solid var(--color-header-border);box-sizing:border-box;display:flex;height:var(--farama-header-height);padding:0 36px 0 24px;position:absolute;width:100%;z-index:95}.farama-header .farama-header__container{display:flex;justify-content:space-between;margin:0 auto;max-width:1400px;width:100%}.farama-header a{color:var(--color-foreground-primary);text-decoration:none;transition:color .125s ease}.farama-header a:hover{color:var(--color-foreground-secondary)}.farama-header .farama-header__logo{margin:var(--farama-header-logo-margin);max-height:calc(var(--farama-header-height) - var(--farama-header-logo-margin))}.farama-header 
.farama-header__title{align-self:center;font-size:var(--font-size--normal);font-weight:400;margin:0 0 2px;padding:0 0 0 4px}.farama-header .farama-header__left,.farama-header .farama-header__left a{display:flex}.farama-header .farama-header__left--mobile{display:none}.farama-header .farama-header__left--mobile .nav-overlay-icon svg{stroke:var(--color-foreground-primary);fill:var(--color-foreground-primary);stroke-width:2px;padding:0 6px;width:20px}.farama-header .farama-header__right{align-items:center;display:flex;z-index:2}.farama-header .farama-header__right .farama-header__nav{display:flex;height:100%;list-style:none}.farama-header .farama-header__right .farama-header__nav li{align-items:center;cursor:pointer;display:flex;margin-left:20px;text-decoration:none}.farama-header .farama-header__right .farama-header__nav li a{align-items:center;display:flex;height:100%}.farama-header .farama-header__right .farama-header__nav li .farama-header__dropdown-container{align-items:center;display:flex;height:100%;position:relative}.farama-header .farama-header__right .farama-header__nav li .farama-header__dropdown-container:hover .farama-header__dropdown-menu{display:block}.farama-header .farama-header__right .farama-header__nav li .farama-header__dropdown-container svg{fill:var(--color-foreground-primary);width:32px}.farama-header .farama-header__right .farama-header__nav li .farama-header__dropdown-container .farama-header__dropdown-menu{background:var(--color-background-hover);border:1px solid var(--color-background-border);display:none;position:absolute;right:0;top:var(--farama-header-height);z-index:9999}.farama-header .farama-header__right .farama-header__nav li .farama-header__dropdown-container .farama-header__dropdown-menu ul{display:inherit;margin:0;padding:6px 14px}.farama-header .farama-header__right .farama-header__nav li .farama-header__dropdown-container .farama-header__dropdown-menu li{margin:0;padding:6px 0}.farama-header .farama-header__right 
.farama-header-menu{display:flex;justify-content:center;position:relative}.farama-header .farama-header__right .farama-header-menu .farama-header-menu__btn{background:none;border:none;cursor:pointer;display:flex}.farama-header .farama-header__right .farama-header-menu .farama-header-menu__btn img{width:26px}.farama-header .farama-header__right .farama-header-menu .farama-header-menu__btn svg{stroke:var(--color-foreground-primary);stroke-width:2px;align-self:center;width:14px}.farama-header .farama-header__right .farama-header-menu.active .farama-header-menu-container{transform:translateY(100vh)}.farama-header .farama-header__right .farama-header-menu .farama-header-menu-container{background-color:var(--color-background-secondary);border-left:1px solid var(--color-background-border);box-sizing:border-box;height:100%;overflow:auto;position:fixed;right:0;top:-100vh;transform:translateY(0);transition:transform .2s ease-in;width:100%;z-index:99}.farama-header .farama-header__right .farama-header-menu .farama-header-menu-container .farama-header-menu__header{align-items:center;border-bottom:1px solid var(--color-background-border);box-sizing:border-box;display:flex;margin:0 auto;max-width:1400px;padding:7px 52px;position:relative;width:100%}.farama-header .farama-header__right .farama-header-menu .farama-header-menu-container .farama-header-menu__header a{align-items:center;display:flex}.farama-header .farama-header__right .farama-header-menu .farama-header-menu-container .farama-header-menu__header a .farama-header-menu__logo{width:36px}.farama-header .farama-header__right .farama-header-menu .farama-header-menu-container .farama-header-menu__header a span{color:var(--color-sidebar-brand-text);padding-left:8px}.farama-header .farama-header__right .farama-header-menu .farama-header-menu-container .farama-header-menu__header .farama-header-menu-header__right{padding-right:inherit;position:absolute;right:0}.farama-header .farama-header__right .farama-header-menu 
.farama-header-menu-container .farama-header-menu__header .farama-header-menu-header__right button{background:none;border:none;cursor:pointer;display:flex}.farama-header .farama-header__right .farama-header-menu .farama-header-menu-container .farama-header-menu__header .farama-header-menu-header__right button svg{color:var(--color-foreground-primary);width:20px}.farama-header .farama-header__right .farama-header-menu .farama-header-menu-container .farama-header-menu__body{box-sizing:border-box;display:flex;flex-wrap:wrap;margin:0 auto;max-width:1500px;padding:22px 52px;width:100%}.farama-header .farama-header__right .farama-header-menu .farama-header-menu-container .farama-header-menu__body .farama-header-menu__section{margin-bottom:24px;min-width:220px;padding-left:18px}.farama-header .farama-header__right .farama-header-menu .farama-header-menu-container .farama-header-menu__body .farama-header-menu__section .farama-header-menu__section-title{display:block;font-size:var(--font-size--small);font-weight:600;padding:0 12px 12px;text-transform:uppercase}.farama-header .farama-header__right .farama-header-menu .farama-header-menu-container .farama-header-menu__body .farama-header-menu__section .farama-header-menu__subsections-container .farama-header-menu__subsection{min-width:210px}.farama-header .farama-header__right .farama-header-menu .farama-header-menu-container .farama-header-menu__body .farama-header-menu__section .farama-header-menu__subsections-container .farama-header-menu__subsection:not(:last-child){margin-right:12px}.farama-header .farama-header__right .farama-header-menu .farama-header-menu-container .farama-header-menu__body .farama-header-menu__section .farama-header-menu__subsections-container .farama-header-menu__subsection .farama-header-menu__subsection-title{color:var(--color-foreground-secondary);display:block;font-size:var(--font-size--small--3);font-weight:700;padding:20px 12px 10px;text-transform:uppercase}.farama-header .farama-header__right 
.farama-header-menu .farama-header-menu-container .farama-header-menu__body .farama-header-menu__section .farama-header-menu-list{display:inherit;list-style:none;margin:0;padding:0}.farama-header .farama-header__right .farama-header-menu .farama-header-menu-container .farama-header-menu__body .farama-header-menu__section .farama-header-menu-list li{border-radius:var(--sidebar-item-border-radius)}.farama-header .farama-header__right .farama-header-menu .farama-header-menu-container .farama-header-menu__body .farama-header-menu__section .farama-header-menu-list li:hover{background-color:var(--color-farama-header-background-hover)}.farama-header .farama-header__right .farama-header-menu .farama-header-menu-container .farama-header-menu__body .farama-header-menu__section .farama-header-menu-list li a{align-items:center;display:flex;padding:12px 14px}.farama-header .farama-header__right .farama-header-menu .farama-header-menu-container .farama-header-menu__body .farama-header-menu__section .farama-header-menu-list li a:hover{color:inherit}.farama-header .farama-header__right .farama-header-menu .farama-header-menu-container .farama-header-menu__body .farama-header-menu__section .farama-header-menu-list li a img{margin-right:10px;width:26px}.farama-sidebar__title{align-items:center;display:flex;margin-left:var(--sidebar-search-space-lateral);margin-top:.6rem;min-height:calc(52px - var(--sidebar-search-space-above));padding-right:4px;text-decoration:none}.farama-sidebar__title img{height:calc(var(--farama-header-height) - 20px);margin:var(--farama-sidebar-logo-margin)}.farama-sidebar__title span{color:var(--color-foreground-primary)}.farama-sidebar__title span:hover{text-decoration:none}.sidebar-brand{align-items:center;flex-direction:row;padding:var(--sidebar-item-spacing-vertical)}.sidebar-brand .sidebar-logo-container{display:flex;height:auto;max-width:55px}.sidebar-brand .sidebar-brand-text{font-size:1.3rem;padding-left:11px}.farama-sidebar-donate{margin:0 
auto;padding:8px 16px 20px;width:76%}.farama-sidebar-donate .farama-donate-btn{background:linear-gradient(to right top,#765e3e,#054f5b);background-blend-mode:color;background-color:transparent;border:none;border-radius:6px;color:#fff;cursor:pointer;padding:8px 12px;transition:background-color .2s ease;width:100%}.farama-sidebar-donate .farama-donate-btn:hover{background-color:hsla(0,0%,100%,.15)}.farama-donate-banner{background-color:var(--color-highlighted-background);box-sizing:border-box;display:none;padding:16px 3em;width:100%}.farama-donate-banner.active{display:flex}.farama-donate-banner .farama-donate-banner__text{align-items:center;display:flex;flex:1;font-size:1.1em;justify-content:center}.farama-donate-banner .farama-donate-banner__btns{align-items:center;display:flex}.farama-donate-banner .farama-donate-banner__btns a{text-decoration:none}.farama-donate-banner .farama-donate-banner__btns button{align-items:center;border:none;border-radius:6px;cursor:pointer;display:flex;height:36px;justify-content:center;margin-left:22px;position:relative}.farama-donate-banner .farama-donate-banner__btns .farama-donate-banner__go{background:linear-gradient(to right top,#765e3e,#054f5b);background-blend-mode:color;background-color:transparent;color:#fff;padding:0 26px;transition:background-color .2s ease}.farama-donate-banner .farama-donate-banner__btns .farama-donate-banner__go:hover{background-color:hsla(0,0%,100%,.1)}.farama-donate-banner .farama-donate-banner__btns .farama-donate-banner__cancel{transition:background-color .2s ease}.farama-donate-banner .farama-donate-banner__btns .farama-donate-banner__cancel svg{height:26px}@media(prefers-color-scheme:dark){body:not([data-theme=light]) .farama-donate-banner__cancel{background-color:rgba(0,0,0,.1)}body:not([data-theme=light]) .farama-donate-banner__cancel:hover{background:rgba(0,0,0,.2)}body:not([data-theme=light]) .farama-donate-banner__cancel svg{stroke:#fff}body[data-theme=light] 
.farama-donate-banner__cancel{background-color:rgba(25,25,25,.1)}body[data-theme=light] .farama-donate-banner__cancel:hover{background:hsla(0,0%,100%,.2)}body[data-theme=light] .farama-donate-banner__cancel svg{stroke:#666}}@media(prefers-color-scheme:light){body:not([data-theme=dark]) .farama-donate-banner__cancel{background-color:rgba(25,25,25,.1)}body:not([data-theme=dark]) .farama-donate-banner__cancel:hover{background:hsla(0,0%,100%,.2)}body:not([data-theme=dark]) .farama-donate-banner__cancel svg{stroke:#666}body[data-theme=dark] .farama-donate-banner__cancel{background-color:rgba(0,0,0,.1)}body[data-theme=dark] .farama-donate-banner__cancel:hover{background:rgba(0,0,0,.2)}body[data-theme=dark] .farama-donate-banner__cancel svg{stroke:#fff}}.farama-project-logo{margin:1.5rem 0 .8rem!important}.farama-project-heading{margin:0;padding:0 0 1.6rem;text-align:center}.farama-project-logo img{width:65%}.mobile-header .header-center{opacity:0;transition:opacity easy-in .2s}.mobile-header.scrolled .header-center{opacity:1}.sphx-glr-script-out{color:var(--color-foreground-secondary);display:flex;gap:.5em}.sphx-glr-script-out:before{content:"Out:";line-height:1.4;padding-top:10px}.sphx-glr-script-out .highlight{overflow-x:auto}.sphx-glr-thumbcontainer{z-index:1}div.sphx-glr-download a{background:#0f4a65;box-sizing:border-box;max-width:100%;width:340px}div.sphx-glr-download a:hover{background:#0d3a4e;box-shadow:none}@media(prefers-color-scheme:dark){body:not([data-theme=light]) div.sphx-glr-download a{background:#0f4a65}body:not([data-theme=light]) div.sphx-glr-download a:hover{background:#0d3a4e}body[data-theme=light] div.sphx-glr-download a{background:#f9d4a1}body[data-theme=light] div.sphx-glr-download a:hover{background:#d9b481}}@media(prefers-color-scheme:light){body:not([data-theme=dark]) div.sphx-glr-download a{background:#f9d4a1}body:not([data-theme=dark]) div.sphx-glr-download a:hover{background:#d9b481}body[data-theme=dark] div.sphx-glr-download 
a{background:#0f4a65}body[data-theme=dark] div.sphx-glr-download a:hover{background:#0d3a4e}}body[data-theme=light] div.sphx-glr-download a{background:#f9d4a1}body[data-theme=light] div.sphx-glr-download a:hover{background:#d9b481}.sphx-glr-thumbcontainer img{background-color:#fff;border-radius:4px}.tab-content>[class^=highlight-]:first-child .highlight{background:var(--color-api-background);border-radius:6px}.tab-set>input+label{font-weight:600}.tab-set>input:checked+label,.tab-set>input:checked+label:hover{border-color:var(--color-brand-secondary);color:var(--color-brand-secondary)}div.jupyter_container{background:var(--color-api-background);border:none;box-shadow:none}div.jupyter_container div.code_cell,div.jupyter_container div.highlight{border:none;border-radius:0}div.jupyter_container div.code_cell pre{padding:.625rem .875rem}@media(prefers-color-scheme:dark){body:not([data-theme=light]) div.jupyter_container div.highlight{background:#202020;color:#d0d0d0}}body[data-theme=dark] div.jupyter_container div.highlight{background:#202020;color:#d0d0d0}@media(max-width:950px){.farama-header .farama-header__right .farama-header-menu .farama-header-menu-container .farama-header-menu__header{padding:7px 42px}.farama-header .farama-header-menu__btn-name{display:none}}@media(max-width:600px){.farama-header{padding:0 4px}.farama-header .farama-header__title{font-size:var(--font-size--small)}.farama-header .farama-header__right .farama-header-menu .farama-header-menu-container .farama-header-menu__header{padding:8px 12px}.farama-header .farama-header__right .farama-header-menu .farama-header-menu-container .farama-header-menu__body{padding:18px 12px}.farama-donate-banner{flex-direction:column}.farama-donate-banner .farama-donate-banner__btns{justify-content:end;margin-top:1em}.farama-donate-banner .farama-donate-banner__btns button{height:36px;margin-left:12px}.farama-donate-banner .farama-donate-banner__btns .farama-donate-banner__go{padding:0 20px}.farama-donate-banner 
.farama-donate-banner__btns .farama-donate-banner__cancel svg{height:26px}}@media(max-width:480px){.farama-header .farama-header__title{width:110px}.farama-header .farama-header-menu__btn-name{text-align:right;width:100px}.farama-project-heading{text-align:left}.farama-header-menu__subsections-container{display:block!important}}@media(min-width:1260px){div.highlight{max-width:60vw;width:-moz-fit-content;width:fit-content}}@media(min-width:2160px){div.highlight{max-width:50vw}}@media(prefers-color-scheme:light){body:not([data-theme=dark]) .farama-white-logo-invert,body[data-theme=dark] .farama-black-logo-invert{filter:invert(1)}}@media(prefers-color-scheme:dark){body:not([data-theme=light]) img[src*="//render.githubusercontent.com/render/math"]{filter:invert(90%)}body:not([data-theme=light]) .farama-black-logo-invert,body[data-theme=light] .farama-white-logo-invert{filter:invert(1)}} +/*# sourceMappingURL=furo-extensions.css.map*/ \ No newline at end of file diff --git a/v1.2.0/_static/styles/furo-extensions.css.map b/v1.2.0/_static/styles/furo-extensions.css.map new file mode 100644 index 000000000..866dff103 --- /dev/null +++ b/v1.2.0/_static/styles/furo-extensions.css.map @@ -0,0 +1 @@ 
+{"version":3,"file":"styles/furo-extensions.css","mappings":"AAGA,2BACE,oFACA,4CAKE,6CAHA,YACA,eAEA,CACA,kDACE,yCAEF,8CACE,sCAEJ,8CACE,kDAEJ,2BAGE,uBACA,cAHA,gBACA,UAEA,CAGA,yCACE,mBAEF,gDAEE,gDADA,YACA,CACA,sDACE,gDACF,yDACE,sCAEJ,+CACE,UACA,qDACE,UAGF,mDACE,eAEJ,yEAEE,4DAEA,mHASE,mBAPA,kBAEA,YADA,oBAGA,aADA,gBAIA,CAEA,qIAEE,WADA,UACA,CAEJ,uGACE,aAEF,iUAGE,cAEF,mHACE,aC1EJ,gCACE,mCAEF,0BAKE,mBAUA,8CACA,YAFA,mCAKA,eAZA,cALA,UASA,YADA,YAYA,iCAdA,YAcA,CAEA,gCAEE,8CADA,gCACA,CAEF,gCAGE,6BADA,mCADA,YAEA,CAEF,kCAEE,cADA,oBACA,CACA,wCACE,cAEJ,8BACE,UC5CN,KAEE,6CAA8C,CAC9C,uDAAwD,CACxD,uDAAwD,CAGxD,iCAAsC,CAGtC,+CAAgD,CAChD,uDAAwD,CACxD,uDAAwD,CACxD,oDAAqD,CACrD,6DAA8D,CAC9D,6DAA8D,CAG9D,uDAAwD,CACxD,yDAA0D,CAC1D,4DAA6D,CAC7D,2DAA4D,CAC5D,8DAA+D,CAC/D,iEAAkE,CAClE,uDAAwD,CACxD,wDAAyD,CAG3D,gBACE,qFAGF,SACE,6EAEF,cACE,uFAEF,cACE,uFAEF,cACE,uFAGF,qBACE,eAEF,mBACE,WACA,eChDF,KACE,gDAAiD,CACjD,uDAAwD,CACxD,qDAAsD,CACtD,4DAA6D,CAC7D,oCAAqC,CACrC,2CAA4C,CAC5C,4CAA6C,CAC7C,mDAAoD,CACpD,wBAAyB,CACzB,oBAAqB,CACrB,6CAA8C,CAC9C,gCAAiC,CACjC,yDAA0D,CAC1D,uDAAwD,CACxD,8DAA+D,CCbjE,uBACE,eACA,eACA,gBAGF,iBACE,YACA,+EAGF,iBACE,mDACA,8BCbF,GACI,iBACJ,GACI,iBACJ,GACI,iBAGJ,qCACI,kBAEJ,iCACI,kBAEJ,kCACI,mBAIJ,cAKI,mDAEA,oDACA,SAFA,sCAJA,aAOA,OALA,gBAHA,eAEA,WAOA,cAEJ,yBAEI,mBADA,aAIA,YADA,4BADA,WAEA,CAEJ,sBACI,iBAEJ,gBACI,OAIJ,YACI,iDAGA,YADA,kBAGA,eAJA,kBAGA,oCACA,CAEA,kBACI,uDAKJ,8EAEI,aACA,eAFA,iBAEA,CAEJ,0SACI,iBAER,iBACI,YAIJ,UAKI,sBAJA,aACA,eACA,uBACA,UACA,CAEJ,gBACI,aACA,sBAEA,aACA,aAFA,WAEA,CAEJ,uBACI,aACA,aACA,uBAEJ,2BACI,gBACA,yCAEJ,aAII,qBAHA,aAIA,YAHA,uBAIA,iBAHA,iBAGA,CAEJ,UAGI,cADA,iBADA,WAEA,CAIJ,qCACI,gBAEJ,KACI,2BAA4B,CAC5B,gCAAiC,CACjC,qCAAsC,CAE1C,eAMI,mDADA,mDAGA,sBANA,aAEA,mCAGA,sBANA,kBAEA,WAMA,WAEA,yCAGI,aAEA,8BADA,cAFA,iBADA,UAIA,CAEJ,iBACI,sCACA,qBACA,4BAEA,uBACI,wCAER,oCAEI,wCADA,+EACA,CAEJ,qCAKI,kBAJA,mCACA,gBACA,eACA,iBACA,CAKA,0EACI,aAER,4CACI,aAEA,kEAEI,uCACA,qCACA,iBACA,cAJA,UAIA,CAER,qCAEI,mBADA,aAEA,UAEA,yDACI,aAEA,YADA,eACA,CAEA,4DAII,mBACA,eAFA,aADA,iBADA
,oBAIA,CAEA,8DAGI,mBADA,aADA,WAEA,CAEJ,+FAGI,mBADA,aAEA,YAHA,iBAGA,CAGI,mIACI,cAER,mGAEI,qCADA,UACA,CAEJ,6HAKI,yCADA,gDAGA,aANA,kBAEA,QADA,gCAIA,YACA,CAEA,gIACI,gBACA,SACA,iBAEJ,gIACI,SACA,cAEpB,yDAEI,aACA,uBAFA,iBAEA,CAEA,kFAEI,gBACA,YACA,eAHA,YAGA,CAEA,sFACI,WACJ,sFAEI,uCACA,iBACA,kBAHA,UAGA,CAGR,8FACI,4BAEJ,uFAUI,mDACA,qDAHA,sBAFA,YAMA,cAXA,eAEA,QACA,WAGA,wBAEA,iCAJA,WAHA,UAUA,CAEA,mHASI,mBAFA,uDAHA,sBAIA,aAHA,cAFA,iBAGA,iBALA,kBACA,UAOA,CAEA,qHAEI,mBADA,YACA,CAEA,+IACI,WAEJ,0HACI,sCACA,iBAER,qJAGI,sBAFA,kBACA,OACA,CAEA,4JAEI,gBACA,YACA,eAHA,YAGA,CAEA,gKAEI,sCADA,UACA,CAEhB,iHAKI,sBAJA,aAMA,eADA,cAHA,iBACA,kBAFA,UAKA,CAEA,8IAEI,mBADA,gBAEA,kBAEA,iLACI,cACA,kCACA,gBAEA,oBADA,wBACA,CAIA,yNACI,gBAEA,0OACI,kBAEJ,+PAGI,wCAFA,cACA,qCAEA,gBAEA,uBADA,wBACA,CAEZ,uKACI,gBAGA,gBAFA,SACA,SACA,CAEA,0KACI,gDAEA,gLACI,6DAEJ,4KAGI,mBAFA,aACA,iBACA,CAEA,kLACI,cAEJ,gLAEI,kBADA,UACA,CAExC,uBAEI,mBADA,aAKA,gDADA,iBADA,0DADA,kBAIA,qBAEA,2BACI,gDACA,yCACJ,4BACI,sCACA,kCACI,qBAEZ,eAGI,mBAFA,mBACA,4CACA,CAEA,uCACI,aAEA,YADA,cACA,CAEJ,mCACI,iBACA,kBAER,uBAGI,cADA,sBADA,SAEA,CAEA,0CAII,wDAnZY,CAoZZ,4BACA,6BAEA,YACA,kBANA,WAOA,eARA,iBAKA,qCANA,UASA,CAEA,gDACI,qCAEZ,sBAKI,qDADA,sBAHA,aACA,iBACA,UAEA,CAEA,6BACI,aAEJ,kDAII,mBAFA,aADA,OAIA,gBAFA,sBAEA,CAEJ,kDAEI,mBADA,YACA,CAEA,oDACI,qBAEJ,yDAQI,mBAJA,YACA,kBAIA,eAHA,aAJA,YAKA,uBANA,iBAEA,iBAMA,CAEJ,4EAEI,wDAncQ,CAocR,4BACA,6BAHA,WAKA,eADA,oCACA,CAEA,kFACI,oCAER,gFACI,qCAEA,oFACI,YAEhB,kCAEQ,2DACI,gCACA,iEACI,0BACJ,+DACI,YAGR,qDACI,mCACA,2DACQ,8BACR,yDACI,aAEhB,mCAEQ,0DACI,mCACA,gEACQ,8BACR,8DACI,YAGR,oDACI,gCACA,0DACI,0BACJ,wDACI,aAKhB,qBACI,gCAEJ,wBAGI,QAAO,CADP,mBADA,iBAEA,CAEJ,yBACI,UAGA,8BAEI,SAAQ,CADR,8BACA,CAGJ,uCACI,UAIR,qBACI,wCACA,aACA,SAEJ,4BACI,eACA,gBACA,iBAEJ,gCACI,gBAEJ,yBACI,UAEJ,wBAII,mBADA,sBADA,eADA,WAGA,CAEJ,8BACI,mBACA,gBAEJ,kCAEQ,qDACI,mBACJ,2DACI,mBAEJ,+CACI,mBACJ,qDACI,oBAEZ,mCAEQ,oDACI,mBACJ,0DACI,mBAEJ,8CACI,mBACJ,oDACI,oBAGR,+CACI,mBACJ,qDACI,mBAER,6BACI,sBACA,kBAIJ,wDACI,uCACA,kBAEJ,qBACI,gBAEJ,gEAEI,0CADA,kCACA,CAI
J,sBACI,uCACA,YACA,gBAEA,wEACI,YACA,gBAEJ,wCACI,wBAER,kCACI,iEACI,mBACA,eAGJ,0DACI,mBACA,cAKR,wBACI,mHACI,iBAEJ,6CACI,cAGR,wBACI,eACI,cAEA,qCACI,kCAGA,mHACI,iBACJ,iHACI,kBAEZ,sBACI,sBAGA,kDAEI,oBADA,cACA,CAEA,yDAEI,YADA,gBACA,CAEJ,4EACI,eAGA,oFACI,aAGpB,wBAEQ,qCACI,YAEJ,6CACI,iBACA,YAER,wBACI,gBAEJ,2CACI,yBAER,yBACI,cAEI,eADA,wCACA,EAER,yBACI,cACI,gBAER,mCAMQ,sGACI,kBAGZ,kCAGQ,oFACI,mBAMJ,wGACI","sources":["webpack:///./src/furo/assets/styles/extensions/_readthedocs.sass","webpack:///./src/furo/assets/styles/extensions/_copybutton.sass","webpack:///./src/furo/assets/styles/extensions/_sphinx-design.sass","webpack:///./src/furo/assets/styles/extensions/_sphinx-inline-tabs.sass","webpack:///./src/furo/assets/styles/extensions/_sphinx-panels.sass","webpack:///./src/furo/assets/styles/extensions/_farama.sass"],"sourcesContent":["// This file contains the styles used for tweaking how ReadTheDoc's embedded\n// contents would show up inside the theme.\n\n#furo-sidebar-ad-placement\n padding: var(--sidebar-item-spacing-vertical) var(--sidebar-item-spacing-horizontal)\n .ethical-sidebar\n // Remove the border and box-shadow.\n border: none\n box-shadow: none\n // Manage the background colors.\n background: var(--color-background-secondary)\n &:hover\n background: var(--color-background-hover)\n // Ensure the text is legible.\n a\n color: var(--color-foreground-primary)\n\n .ethical-callout a\n color: var(--color-foreground-secondary) !important\n\n#furo-readthedocs-versions\n position: static\n width: 100%\n background: transparent\n display: block\n\n // Make the background color fit with the theme's aesthetic.\n .rst-versions\n background: rgb(26, 28, 30)\n\n .rst-current-version\n cursor: unset\n background: var(--color-sidebar-item-background)\n &:hover\n background: var(--color-sidebar-item-background)\n .fa-book\n color: var(--color-foreground-primary)\n\n > .rst-other-versions\n padding: 0\n small\n opacity: 1\n\n .injected\n .rst-versions\n position: unset\n\n 
&:hover,\n &:focus-within\n box-shadow: 0 0 0 1px var(--color-sidebar-background-border)\n\n .rst-current-version\n // Undo the tweaks done in RTD's CSS\n font-size: inherit\n line-height: inherit\n height: auto\n text-align: right\n padding: 12px\n\n // Match the rest of the body\n background: #1a1c1e\n\n .fa-book\n float: left\n color: white\n\n .fa-caret-down\n display: none\n\n .rst-current-version,\n .rst-other-versions,\n .injected\n display: block\n\n > .rst-current-version\n display: none\n",".highlight\n &:hover button.copybtn\n color: var(--color-code-foreground)\n\n button.copybtn\n // Make it visible\n opacity: 1\n\n // Align things correctly\n align-items: center\n\n height: 1.25em\n width: 1.25em\n\n top: 0.625rem // $code-spacing-vertical\n right: 0.5rem\n\n // Make it look better\n color: var(--color-background-item)\n background-color: var(--color-code-background)\n border: none\n\n // Change to cursor to make it obvious that you can click on it\n cursor: pointer\n\n // Transition smoothly, for aesthetics\n transition: color 300ms, opacity 300ms\n\n &:hover\n color: var(--color-brand-content)\n background-color: var(--color-code-background)\n\n &::after\n display: none\n color: var(--color-code-foreground)\n background-color: transparent\n\n &.success\n transition: color 0ms\n color: #22863a\n &::after\n display: block\n\n svg\n padding: 0\n","body\n // Colors\n --sd-color-primary: var(--color-brand-primary)\n --sd-color-primary-highlight: var(--color-brand-content)\n --sd-color-primary-text: var(--color-background-primary)\n\n // Shadows\n --sd-color-shadow: rgba(0, 0, 0, 0.05)\n\n // Cards\n --sd-color-card-border: var(--color-card-border)\n --sd-color-card-border-hover: var(--color-brand-content)\n --sd-color-card-background: var(--color-card-background)\n --sd-color-card-text: var(--color-foreground-primary)\n --sd-color-card-header: var(--color-card-marginals-background)\n --sd-color-card-footer: var(--color-card-marginals-background)\n\n // 
Tabs\n --sd-color-tabs-label-active: var(--color-brand-content)\n --sd-color-tabs-label-hover: var(--color-foreground-muted)\n --sd-color-tabs-label-inactive: var(--color-foreground-muted)\n --sd-color-tabs-underline-active: var(--color-brand-content)\n --sd-color-tabs-underline-hover: var(--color-foreground-border)\n --sd-color-tabs-underline-inactive: var(--color-background-border)\n --sd-color-tabs-overline: var(--color-background-border)\n --sd-color-tabs-underline: var(--color-background-border)\n\n// Tabs\n.sd-tab-content\n box-shadow: 0 -2px var(--sd-color-tabs-overline), 0 1px var(--sd-color-tabs-underline)\n\n// Shadows\n.sd-card // Have a shadow by default\n box-shadow: 0 0.1rem 0.25rem var(--sd-color-shadow), 0 0 0.0625rem rgba(0, 0, 0, 0.1)\n\n.sd-shadow-sm\n box-shadow: 0 0.1rem 0.25rem var(--sd-color-shadow), 0 0 0.0625rem rgba(0, 0, 0, 0.1) !important\n\n.sd-shadow-md\n box-shadow: 0 0.3rem 0.75rem var(--sd-color-shadow), 0 0 0.0625rem rgba(0, 0, 0, 0.1) !important\n\n.sd-shadow-lg\n box-shadow: 0 0.6rem 1.5rem var(--sd-color-shadow), 0 0 0.0625rem rgba(0, 0, 0, 0.1) !important\n\n// Cards\n.sd-card-hover:hover // Don't change scale on hover\n transform: none\n\n.sd-cards-carousel // Have a bit of gap in the carousel by default\n gap: 0.25rem\n padding: 0.25rem\n","// This file contains styles to tweak sphinx-inline-tabs to work well with Furo.\n\nbody\n --tabs--label-text: var(--color-foreground-muted)\n --tabs--label-text--hover: var(--color-foreground-muted)\n --tabs--label-text--active: var(--color-brand-content)\n --tabs--label-text--active--hover: var(--color-brand-content)\n --tabs--label-background: transparent\n --tabs--label-background--hover: transparent\n --tabs--label-background--active: transparent\n --tabs--label-background--active--hover: transparent\n --tabs--padding-x: 0.25em\n --tabs--margin-x: 1em\n --tabs--border: var(--color-background-border)\n --tabs--label-border: transparent\n --tabs--label-border--hover: 
var(--color-foreground-muted)\n --tabs--label-border--active: var(--color-brand-content)\n --tabs--label-border--active--hover: var(--color-brand-content)\n","// This file contains styles to tweak sphinx-panels to work well with Furo.\n\n// sphinx-panels includes Bootstrap 4, which uses .container which can conflict\n// with docutils' `.. container::` directive.\n[role=\"main\"] .container\n max-width: initial\n padding-left: initial\n padding-right: initial\n\n// Make the panels look nicer!\n.shadow.docutils\n border: none\n box-shadow: 0 0.2rem 0.5rem rgba(0, 0, 0, 0.05), 0 0 0.0625rem rgba(0, 0, 0, 0.1) !important\n\n// Make panel colors respond to dark mode\n.sphinx-bs .card\n background-color: var(--color-background-secondary)\n color: var(--color-foreground)\n","// Farama Base\n\n$farama-background: linear-gradient(to right top, #765e3e, #054f5b)\n\nh1\n font-size: 2.2rem\nh2\n font-size: 1.7rem\nh3\n font-size: 1.4rem\n\n// If menu is active then all elements except the menu are not visible (i.e. 
only element with aria-hidden=\"true\")\nhtml:has(.farama-header-menu.active)\n visibility: hidden\n\n.farama-hidden[aria-hidden=\"true\"]\n visibility: hidden\n\n.farama-hidden[aria-hidden=\"false\"]\n visibility: visible\n\n// Cookies Alert\n\n.cookie-alert\n position: fixed\n display: flex\n width: 100%\n min-height: 70px\n background-color: var(--color-background-secondary)\n color: var(--color-foreground-primary)\n border-top: 1px solid var(--color-background-border)\n bottom: 0\n left: 0\n z-index: 99999\n\n.cookie-alert__container\n display: flex\n align-items: center\n width: 700px\n max-width: calc(100% - 28px)\n margin: auto\n\n.cookie-alert__button\n margin-left: 14px\n\n.cookie-alert p\n flex: 1\n\n// Farama default button style\n\n.farama-btn\n background: var(--color-farama-button-background)\n padding: 10px 26px\n border-radius: 6px\n border: none\n transition: background-color 0.2s ease\n cursor: pointer\n\n &:hover\n background: var(--color-farama-button-background-hover)\n\n// Env Icons\n\narticle[role=main]:has(.farama-env-icon-container)\n .farama-env-icon-container\n position: absolute\n display: flex\n margin-top: 7px\n\n .section h1:first-child, .section h2:first-child, section h1:first-child, section h2:first-child\n margin-left: 34px\n\n.farama-env-icon\n height: 32px\n\n// Envinronments grid\n\n.env-grid\n display: flex\n flex-wrap: wrap\n justify-content: center\n width: 100%\n box-sizing: border-box\n\n.env-grid__cell\n display: flex\n flex-direction: column\n width: 180px\n height: 180px\n padding: 10px\n\n.cell__image-container\n display: flex\n height: 148px\n justify-content: center\n\n.cell__image-container img\n max-height: 100%\n object-fit: contain\n\n.cell__title\n display: flex\n justify-content: center\n text-align: center\n align-items: flex-end\n height: 32px\n line-height: 16px\n\n.more-btn\n width: 240px\n margin: 12px auto\n display: block\n\n// Farama Header\n\nhtml:has(.farama-header-menu.active)\n overflow: 
hidden\n\nbody\n --farama-header-height: 52px\n --farama-header-logo-margin: 10px\n --farama-sidebar-logo-margin: 2px 10px\n\n.farama-header\n position: absolute\n display: flex\n width: 100%\n height: var(--farama-header-height)\n border-bottom: 1px solid var(--color-header-border)\n background-color: var(--color-background-secondary)\n padding: 0 36px 0 24px\n box-sizing: border-box\n z-index: 95\n\n .farama-header__container\n width: 100%\n max-width: 1400px\n display: flex\n margin: 0 auto\n justify-content: space-between\n\n a\n color: var(--color-foreground-primary)\n text-decoration: none\n transition: color 0.125s ease\n\n &:hover\n color: var(--color-foreground-secondary)\n\n .farama-header__logo\n max-height: calc(var(--farama-header-height) - var(--farama-header-logo-margin))\n margin: var(--farama-header-logo-margin)\n\n .farama-header__title\n font-size: var(--font-size--normal)\n font-weight: normal\n margin: 0 0 2px 0\n padding: 0 0 0 4px\n align-self: center\n\n .farama-header__left\n display: flex\n\n a\n display: flex\n\n .farama-header__left--mobile\n display: none\n\n .nav-overlay-icon svg\n width: 20px\n stroke: var(--color-foreground-primary)\n fill: var(--color-foreground-primary)\n stroke-width: 2px\n padding: 0 6px\n\n .farama-header__right\n display: flex\n align-items: center\n z-index: 2\n\n .farama-header__nav\n display: flex\n list-style: none\n height: 100%\n\n li\n text-decoration: none\n margin-left: 20px\n display: flex\n align-items: center\n cursor: pointer\n\n a\n height: 100%\n display: flex\n align-items: center\n\n .farama-header__dropdown-container\n position: relative\n display: flex\n align-items: center\n height: 100%\n\n &:hover\n .farama-header__dropdown-menu\n display: block\n\n svg\n width: 32px\n fill: var(--color-foreground-primary)\n\n .farama-header__dropdown-menu\n position: absolute\n top: var(--farama-header-height)\n right: 0\n border: 1px solid var(--color-background-border)\n background: 
var(--color-background-hover)\n z-index: 9999\n display: none\n\n ul\n display: inherit\n margin: 0\n padding: 6px 14px\n\n li\n margin: 0\n padding: 6px 0\n\n .farama-header-menu\n position: relative\n display: flex\n justify-content: center\n\n .farama-header-menu__btn\n display: flex\n background: none\n border: none\n cursor: pointer\n\n img\n width: 26px\n svg\n width: 14px\n stroke: var(--color-foreground-primary)\n stroke-width: 2px\n align-self: center\n\n\n &.active .farama-header-menu-container\n transform: translateY(100vh)\n\n .farama-header-menu-container\n position: fixed\n z-index: 99\n right: 0\n top: -100vh\n width: 100%\n height: calc(100vh - calc(100vh - 100%))\n transform: translateY(0)\n box-sizing: border-box\n transition: transform 0.2s ease-in\n background-color: var(--color-background-secondary)\n border-left: 1px solid var(--color-background-border)\n overflow: auto\n\n .farama-header-menu__header\n position: relative\n width: 100%\n max-width: 1400px\n box-sizing: border-box\n margin: 0 auto\n padding: 7px 52px\n border-bottom: 1px solid var(--color-background-border)\n display: flex\n align-items: center\n\n a\n display: flex\n align-items: center\n\n .farama-header-menu__logo\n width: 36px\n\n span\n color: var(--color-sidebar-brand-text)\n padding-left: 8px\n\n .farama-header-menu-header__right\n position: absolute\n right: 0\n padding-right: inherit\n\n button\n display: flex\n background: none\n border: none\n cursor: pointer\n\n svg\n width: 20px\n color: var(--color-foreground-primary)\n\n .farama-header-menu__body\n display: flex\n width: 100%\n max-width: 1500px\n padding: 22px 52px\n box-sizing: border-box\n margin: 0 auto\n flex-wrap: wrap\n\n .farama-header-menu__section\n min-width: 220px\n margin-bottom: 24px\n padding-left: 18px\n\n .farama-header-menu__section-title\n display: block\n font-size: var(--font-size--small)\n font-weight: 600\n text-transform: uppercase\n padding: 0 12px 12px\n\n 
.farama-header-menu__subsections-container\n\n .farama-header-menu__subsection\n min-width: 210px\n\n &:not(:last-child)\n margin-right: 12px\n\n .farama-header-menu__subsection-title\n display: block\n font-size: var(--font-size--small--3)\n color: var(--color-foreground-secondary)\n font-weight: 700\n text-transform: uppercase\n padding: 20px 12px 10px\n\n .farama-header-menu-list\n display: inherit\n margin: 0\n padding: 0\n list-style: none\n\n li\n border-radius: var(--sidebar-item-border-radius)\n\n &:hover\n background-color: var(--color-farama-header-background-hover)\n\n a\n display: flex\n padding: 12px 14px\n align-items: center\n\n &:hover\n color: inherit\n\n img\n width: 26px\n margin-right: 10px\n\n.farama-sidebar__title\n display: flex\n align-items: center\n padding-right: 4px\n min-height: calc(52px - var(--sidebar-search-space-above))\n margin-top: 0.6rem\n margin-left: var(--sidebar-search-space-lateral)\n text-decoration: none\n\n img\n height: calc(var(--farama-header-height) - 20px)\n margin: var(--farama-sidebar-logo-margin)\n span\n color: var(--color-foreground-primary)\n &:hover\n text-decoration: none\n\n.sidebar-brand\n flex-direction: row\n padding: var(--sidebar-item-spacing-vertical)\n align-items: center\n\n .sidebar-logo-container\n display: flex\n max-width: 55px\n height: auto\n\n .sidebar-brand-text\n font-size: 1.3rem\n padding-left: 11px\n\n.farama-sidebar-donate\n width: 76%\n padding: 8px 16px 20px\n margin: 0 auto\n\n .farama-donate-btn\n width: 100%\n padding: 8px 12px\n color: #fff\n background: $farama-background\n background-blend-mode: color\n background-color: transparent\n transition: background-color 0.2s ease\n border: none\n border-radius: 6px\n cursor: pointer\n\n &:hover\n background-color: rgb(255 255 255 / 15%)\n\n.farama-donate-banner\n display: none\n padding: 16px 3em\n width: 100%\n box-sizing: border-box\n background-color: var(--color-highlighted-background)\n\n &.active\n display: flex\n\n 
.farama-donate-banner__text\n flex: 1\n display: flex\n justify-content: center\n align-items: center\n font-size: 1.1em\n\n .farama-donate-banner__btns\n display: flex\n align-items: center\n\n a\n text-decoration: none\n\n button\n margin-left: 22px\n height: 36px\n position: relative\n border: none\n border-radius: 6px\n display: flex\n justify-content: center\n align-items: center\n cursor: pointer\n\n .farama-donate-banner__go\n color: #fff\n background: $farama-background\n background-blend-mode: color\n background-color: transparent\n transition: background-color 0.2s ease\n padding: 0 26px\n\n &:hover\n background-color: rgb(255 255 255 / 10%)\n\n .farama-donate-banner__cancel\n transition: background-color 0.2s ease\n\n svg\n height: 26px\n\n@media (prefers-color-scheme: dark)\n body:not([data-theme=\"light\"])\n .farama-donate-banner__cancel\n background-color: rgb(0 0 0 / 10%)\n &:hover\n background: rgb(0 0 0 / 20%)\n svg\n stroke: #fff\n\n body[data-theme=\"light\"]\n .farama-donate-banner__cancel\n background-color: rgb(25 25 25 / 10%)\n &:hover\n background: rgb(255 255 255 / 20%)\n svg\n stroke: #666\n\n@media (prefers-color-scheme: light)\n body:not([data-theme=\"dark\"])\n .farama-donate-banner__cancel\n background-color: rgb(25 25 25 / 10%)\n &:hover\n background: rgb(255 255 255 / 20%)\n svg\n stroke: #666\n\n body[data-theme=\"dark\"]\n .farama-donate-banner__cancel\n background-color: rgb(0 0 0 / 10%)\n &:hover\n background: rgb(0 0 0 / 20%)\n svg\n stroke: #fff\n\n\n// Farama custom directives\n\n.farama-project-logo\n margin: 1.5rem 0 0.8rem !important\n\n.farama-project-heading\n text-align: center\n padding: 0 0 1.6rem 0\n margin: 0\n\n.farama-project-logo img\n width: 65%\n\n.mobile-header\n .header-center\n transition: opacity 0.2s easy-in\n opacity: 0\n\n.mobile-header.scrolled\n .header-center\n opacity: 1\n\n// Sphinx Gallery\n\n.sphx-glr-script-out\n color: var(--color-foreground-secondary)\n display: flex\n gap: 
0.5em\n\n.sphx-glr-script-out::before\n content: \"Out:\"\n line-height: 1.4\n padding-top: 10px\n\n.sphx-glr-script-out .highlight\n overflow-x: auto\n\n.sphx-glr-thumbcontainer\n z-index: 1\n\ndiv.sphx-glr-download a\n width: 340px\n max-width: 100%\n box-sizing: border-box\n background: #0f4a65\n\ndiv.sphx-glr-download a:hover\n background: #0d3a4e\n box-shadow: none\n\n@media (prefers-color-scheme: dark)\n body:not([data-theme=\"light\"])\n div.sphx-glr-download a\n background: #0f4a65\n div.sphx-glr-download a:hover\n background: #0d3a4e\n body[data-theme=\"light\"]\n div.sphx-glr-download a\n background: #f9d4a1\n div.sphx-glr-download a:hover\n background: #d9b481\n\n@media (prefers-color-scheme: light)\n body:not([data-theme=\"dark\"])\n div.sphx-glr-download a\n background: #f9d4a1\n div.sphx-glr-download a:hover\n background: #d9b481\n body[data-theme=\"dark\"]\n div.sphx-glr-download a\n background: #0f4a65\n div.sphx-glr-download a:hover\n background: #0d3a4e\n\nbody[data-theme=\"light\"]\n div.sphx-glr-download a\n background: #f9d4a1\n div.sphx-glr-download a:hover\n background: #d9b481\n\n.sphx-glr-thumbcontainer img\n background-color: white\n border-radius: 4px\n\n// Override Tabs styles\n\n.tab-content > [class^=\"highlight-\"]:first-child .highlight\n background: var(--color-api-background)\n border-radius: 6px\n\n.tab-set > input + label\n font-weight: 600\n\n.tab-set > input:checked + label, .tab-set > input:checked + label:hover\n color: var(--color-brand-secondary)\n border-color: var(--color-brand-secondary)\n\n// Sphinx Jupyter\n\ndiv.jupyter_container\n background: var(--color-api-background)\n border: none\n box-shadow: none\n\n div.code_cell, div.highlight\n border: none\n border-radius: 0\n\n div.code_cell pre\n padding: 0.625rem 0.875rem\n\n@media (prefers-color-scheme: dark)\n body:not([data-theme=\"light\"]) div.jupyter_container div.highlight\n background: #202020\n color: #d0d0d0\n\nbody[data-theme=\"dark\"]\n div.jupyter_container 
div.highlight\n background: #202020\n color: #d0d0d0\n\n\n\n\n@media (max-width: 950px)\n .farama-header .farama-header__right .farama-header-menu .farama-header-menu-container .farama-header-menu__header\n padding: 7px 42px\n\n .farama-header .farama-header-menu__btn-name\n display: none\n\n\n@media (max-width: 600px)\n .farama-header\n padding: 0 4px\n\n .farama-header__title\n font-size: var(--font-size--small)\n\n .farama-header__right .farama-header-menu .farama-header-menu-container\n .farama-header-menu__header\n padding: 8px 12px\n .farama-header-menu__body\n padding: 18px 12px\n\n .farama-donate-banner\n flex-direction: column\n\n\n .farama-donate-banner__btns\n margin-top: 1em\n justify-content: end\n\n button\n margin-left: 12px\n height: 36px\n\n .farama-donate-banner__go\n padding: 0 20px\n\n .farama-donate-banner__cancel\n svg\n height: 26px\n\n\n@media (max-width: 480px)\n .farama-header\n .farama-header__title\n width: 110px\n\n .farama-header-menu__btn-name\n text-align: right\n width: 100px\n\n .farama-project-heading\n text-align: left\n\n .farama-header-menu__subsections-container\n display: block !important\n\n@media (min-width: 1260px)\n div.highlight\n width: fit-content\n max-width: 60vw\n\n@media (min-width: 2160px)\n div.highlight\n max-width: 50vw\n\n@media (prefers-color-scheme: light)\n body:not([data-theme=\"dark\"])\n .farama-white-logo-invert\n filter: invert(1)\n\n body[data-theme=\"dark\"]\n .farama-black-logo-invert\n filter: invert(1)\n\n\n@media (prefers-color-scheme: dark)\n body:not([data-theme=\"light\"])\n // Github math render\n img[src*=\"//render.githubusercontent.com/render/math\"]\n filter: invert(90%)\n\n .farama-black-logo-invert\n filter: invert(1)\n\n body[data-theme=\"light\"]\n .farama-white-logo-invert\n filter: invert(1)\n"],"names":[],"sourceRoot":""} \ No newline at end of file diff --git a/v1.2.0/_static/styles/furo.css b/v1.2.0/_static/styles/furo.css new file mode 100644 index 000000000..04165d589 --- 
/dev/null +++ b/v1.2.0/_static/styles/furo.css @@ -0,0 +1,2 @@ +/*! normalize.css v8.0.1 | MIT License | github.com/necolas/normalize.css */html{-webkit-text-size-adjust:100%;line-height:1.15}body{margin:0}main{display:block}h1{font-size:2em;margin:.67em 0}hr{box-sizing:content-box;height:0;overflow:visible}pre{font-family:monospace,monospace;font-size:1em}a{background-color:transparent}abbr[title]{border-bottom:none;text-decoration:underline;text-decoration:underline dotted}b,strong{font-weight:bolder}code,kbd,samp{font-family:monospace,monospace;font-size:1em}sub,sup{font-size:75%;line-height:0;position:relative;vertical-align:baseline}sub{bottom:-.25em}sup{top:-.5em}img{border-style:none}button,input,optgroup,select,textarea{font-family:inherit;font-size:100%;line-height:1.15;margin:0}button,input{overflow:visible}button,select{text-transform:none}[type=button],[type=reset],[type=submit],button{-webkit-appearance:button}[type=button]::-moz-focus-inner,[type=reset]::-moz-focus-inner,[type=submit]::-moz-focus-inner,button::-moz-focus-inner{border-style:none;padding:0}[type=button]:-moz-focusring,[type=reset]:-moz-focusring,[type=submit]:-moz-focusring,button:-moz-focusring{outline:1px dotted ButtonText}fieldset{padding:.35em .75em .625em}legend{box-sizing:border-box;color:inherit;display:table;max-width:100%;padding:0;white-space:normal}progress{vertical-align:baseline}textarea{overflow:auto}[type=checkbox],[type=radio]{box-sizing:border-box;padding:0}[type=number]::-webkit-inner-spin-button,[type=number]::-webkit-outer-spin-button{height:auto}[type=search]{-webkit-appearance:textfield;outline-offset:-2px}[type=search]::-webkit-search-decoration{-webkit-appearance:none}::-webkit-file-upload-button{-webkit-appearance:button;font:inherit}details{display:block}summary{display:list-item}[hidden],template{display:none}@media print{.content-icon-container,.headerlink,.mobile-header,.related-pages{display:none!important}.highlight{border:.1pt solid 
var(--color-foreground-border)}a,blockquote,dl,ol,pre,table,ul{page-break-inside:avoid}caption,figure,h1,h2,h3,h4,h5,h6,img{page-break-after:avoid;page-break-inside:avoid}dl,ol,ul{page-break-before:avoid}}.visually-hidden{clip:rect(0,0,0,0)!important;border:0!important;height:1px!important;margin:-1px!important;overflow:hidden!important;padding:0!important;position:absolute!important;white-space:nowrap!important;width:1px!important}:-moz-focusring{outline:auto}body{--font-stack:-apple-system,BlinkMacSystemFont,Segoe UI,Helvetica,Arial,sans-serif,Apple Color Emoji,Segoe UI Emoji;--font-stack--monospace:"SFMono-Regular",Menlo,Consolas,Monaco,Liberation Mono,Lucida Console,monospace;--font-size--normal:100%;--font-size--small:87.5%;--font-size--small--2:81.25%;--font-size--small--3:75%;--font-size--small--4:62.5%;--sidebar-caption-font-size:var(--font-size--small);--sidebar-item-font-size:var(--font-size--small);--sidebar-search-input-font-size:var(--font-size--small);--toc-font-size:var(--font-size--small--2);--toc-font-size--mobile:var(--font-size--normal);--toc-title-font-size:var(--font-size--small--4);--admonition-font-size:0.8125rem;--admonition-title-font-size:0.8125rem;--code-font-size:var(--font-size--small--2);--api-font-size:var(--font-size--small);--header-height:calc(var(--sidebar-item-line-height) + var(--sidebar-item-spacing-vertical)*4);--header-padding:0.5rem;--sidebar-tree-space-above:1.2rem;--sidebar-tree-space-horizontal:0.5rem;--sidebar-caption-space-above:1rem;--sidebar-item-line-height:1rem;--sidebar-item-spacing-vertical:0.5rem;--sidebar-item-spacing-horizontal:1rem;--sidebar-item-height:calc(var(--sidebar-item-line-height) + 
var(--sidebar-item-spacing-vertical)*2);--sidebar-expander-width:var(--sidebar-item-height);--sidebar-search-space-above:1.2rem;--sidebar-search-space-lateral:0.7rem;--sidebar-search-input-spacing-vertical:0.5rem;--sidebar-search-input-spacing-horizontal:0.5rem;--sidebar-search-input-height:1.2rem;--sidebar-search-icon-size:var(--sidebar-search-input-height);--toc-title-padding:0.25rem 0;--toc-spacing-vertical:4.5rem;--toc-spacing-horizontal:1.5rem;--toc-item-spacing-vertical:0.4rem;--toc-item-spacing-horizontal:1rem;--sidebar-item-border-radius:8px;--sidebar-search-border-radius:8px;--icon-search:url('data:image/svg+xml;charset=utf-8,