Merge v1.0.0 (#682)

Co-authored-by: Kallinteris Andreas <30759571+Kallinteris-Andreas@users.noreply.github.com>
Co-authored-by: Jet <38184875+jjshoots@users.noreply.github.com>
Co-authored-by: Omar Younis <42100908+younik@users.noreply.github.com>
Mark Towers
2023-11-07 13:27:25 +00:00
committed by GitHub
parent cf5f588433
commit 27f8e85051
256 changed files with 7051 additions and 13421 deletions

@@ -1,15 +1,19 @@
name: Build main branch documentation website
on:
  push:
    branches: [main]
permissions:
  contents: write
jobs:
  docs:
    name: Generate Website
    runs-on: ubuntu-latest
    env:
      SPHINX_GITHUB_CHANGELOG_TOKEN: ${{ secrets.GITHUB_TOKEN }}
    steps:
      - uses: actions/checkout@v3
@@ -32,4 +32,4 @@ jobs:
          --tag gymnasium-necessary-docker .
      - name: Run tests
        run: |
-          docker run gymnasium-necessary-docker pytest tests/test_core.py tests/envs/test_compatibility.py tests/envs/test_envs.py tests/spaces
+          docker run gymnasium-necessary-docker pytest tests/test_core.py tests/envs/test_envs.py tests/spaces

@@ -1,4 +1,5 @@
name: Manual Docs Versioning
on:
  workflow_dispatch:
    inputs:
@@ -14,6 +15,7 @@ on:
permissions:
  contents: write
jobs:
  docs:
    name: Generate Website for new version

@@ -1,10 +1,13 @@
name: Docs Versioning
on:
  push:
    tags:
      - 'v?*.*.*'
-
permissions:
  contents: write
jobs:
  docs:
    name: Generate Website for new version

@@ -51,7 +51,7 @@ repos:
    rev: 6.3.0
    hooks:
      - id: pydocstyle
-        exclude: ^(gymnasium/envs/box2d)|(gymnasium/envs/classic_control)|(gymnasium/envs/mujoco)|(gymnasium/envs/toy_text)|(tests/envs)|(tests/spaces)|(tests/utils)|(tests/vector)|(tests/wrappers)|(docs/)
+        exclude: ^(gymnasium/envs/box2d)|(gymnasium/envs/classic_control)|(gymnasium/envs/mujoco)|(gymnasium/envs/toy_text)|(tests/envs)|(tests/spaces)|(tests/utils)|(tests/vector)|(docs/)
        args:
          - --source
          - --explain

@@ -0,0 +1,73 @@
import os.path

import gymnasium as gym

# Wrapper submodules that should not appear in the single-environment table.
exclude_wrappers = {"vector"}


def generate_wrappers():
    wrapper_table = ""
    for wrapper_name in sorted(gym.wrappers.__all__):
        if wrapper_name not in exclude_wrappers:
            # The first line of the class docstring is used as the description.
            wrapper_doc = getattr(gym.wrappers, wrapper_name).__doc__.split("\n")[0]
            wrapper_table += f"""   * - :class:`{wrapper_name}`
     - {wrapper_doc}
"""
    return wrapper_table


def generate_vector_wrappers():
    # Only document vector wrappers that lack a single-environment equivalent.
    unique_vector_wrappers = set(gym.wrappers.vector.__all__) - set(
        gym.wrappers.__all__
    )
    vector_table = ""
    for vector_name in sorted(unique_vector_wrappers):
        vector_doc = getattr(gym.wrappers.vector, vector_name).__doc__.split("\n")[0]
        vector_table += f"""   * - :class:`{vector_name}`
     - {vector_doc}
"""
    return vector_table


if __name__ == "__main__":
    gen_wrapper_table = generate_wrappers()
    gen_vector_table = generate_vector_wrappers()

    page = f"""
# List of Gymnasium Wrappers

Gymnasium provides a number of commonly used wrappers listed below. More information on each wrapper can be found in the page for its wrapper type.

```{{eval-rst}}
.. py:currentmodule:: gymnasium.wrappers

.. list-table::
   :header-rows: 1

   * - Name
     - Description
{gen_wrapper_table}
```

## Vector only Wrappers

```{{eval-rst}}
.. py:currentmodule:: gymnasium.wrappers.vector

.. list-table::
   :header-rows: 1

   * - Name
     - Description
{gen_vector_table}
```
"""

    filename = os.path.join(
        os.path.dirname(__file__), "..", "api", "wrappers", "table.md"
    )
    with open(filename, "w") as file:
        file.write(page)

@@ -1,29 +1,26 @@
---
-title: Utils
+title: Env
---

# Env

-## gymnasium.Env
```{eval-rst}
.. autoclass:: gymnasium.Env
```

-### Methods
+## Methods
```{eval-rst}
-.. autofunction:: gymnasium.Env.step
-.. autofunction:: gymnasium.Env.reset
-.. autofunction:: gymnasium.Env.render
+.. automethod:: gymnasium.Env.step
+.. automethod:: gymnasium.Env.reset
+.. automethod:: gymnasium.Env.render
+.. automethod:: gymnasium.Env.close
```

-### Attributes
+## Attributes
```{eval-rst}
.. autoattribute:: gymnasium.Env.action_space

-   The Space object corresponding to valid actions; all valid actions should be contained within the space. For example, if the action space is of type `Discrete` and gives the value `Discrete(2)`, this means there are two valid discrete actions: 0 & 1.
+   The Space object corresponding to valid actions; all valid actions should be contained within the space. For example, if the action space is of type `Discrete` and gives the value `Discrete(2)`, this means there are two valid discrete actions: `0` & `1`.

   .. code::
@@ -51,29 +48,26 @@ title: Utils
   The render mode of the environment determined at initialisation

-.. autoattribute:: gymnasium.Env.reward_range
-
-   A tuple corresponding to the minimum and maximum possible rewards for an agent over an episode. The default reward range is set to :math:`(-\infty,+\infty)`.
-
.. autoattribute:: gymnasium.Env.spec

-   The ``EnvSpec`` of the environment normally set during :py:meth:`gymnasium.make`
+   The :class:`EnvSpec` of the environment normally set during :py:meth:`gymnasium.make`
-```
-
-### Additional Methods
-```{eval-rst}
-.. autofunction:: gymnasium.Env.close
.. autoproperty:: gymnasium.Env.unwrapped
.. autoproperty:: gymnasium.Env.np_random
```

-### Implementing environments
+## Implementing environments
```{eval-rst}
.. py:currentmodule:: gymnasium

-When implementing an environment, the :meth:`Env.reset` and :meth:`Env.step` functions must be created, describing the
-dynamics of the environment.
-For more information see the environment creation tutorial.
+When implementing an environment, the :meth:`Env.reset` and :meth:`Env.step` functions must be created, describing the dynamics of the environment. For more information see the environment creation tutorial.
+```
+
+## Creating environments
+```{eval-rst}
+.. py:currentmodule:: gymnasium
+
+To create an environment, gymnasium provides :meth:`make` to initialise the environment along with several important wrappers. Furthermore, gymnasium provides :meth:`make_vec` for creating vector environments; to view all the environments that can be created, use :meth:`pprint_registry`.
```
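As an illustration of the interface this page documents, here is a minimal sketch of a custom environment; the `GridTargetEnv` class and its dynamics are hypothetical, for illustration only.

```python
import numpy as np
import gymnasium as gym
from gymnasium import spaces


class GridTargetEnv(gym.Env):
    """Hypothetical 1D grid: move left/right until the last cell is reached."""

    def __init__(self, size: int = 8):
        self.size = size
        self.observation_space = spaces.Discrete(size)
        self.action_space = spaces.Discrete(2)  # 0: left, 1: right
        self._agent = 0

    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed)  # seeds self.np_random
        self._agent = int(self.np_random.integers(self.size - 1))
        return self._agent, {}  # observation, info

    def step(self, action):
        move = 1 if action == 1 else -1
        self._agent = int(np.clip(self._agent + move, 0, self.size - 1))
        terminated = self._agent == self.size - 1
        reward = 1.0 if terminated else -0.1
        # observation, reward, terminated, truncated, info
        return self._agent, reward, terminated, False, {}
```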

@@ -1,157 +0,0 @@
---
title: Experimental
---

# Experimental

```{toctree}
:hidden:

experimental/functional
experimental/wrappers
experimental/vector
experimental/vector_wrappers
experimental/vector_utils
```

## Functional Environments

```{eval-rst}
The gymnasium ``Env`` provides high flexibility for the implementation of individual environments; however, this can complicate the parallelism of environments. Therefore, we propose the :class:`gymnasium.experimental.FuncEnv` where each part of the environment has its own function related to it.
```

## Wrappers

Gymnasium already contains a large collection of wrappers, but we believe that the wrappers can be improved to:

* (Work in progress) Support arbitrarily complex observation / action spaces. As RL has advanced, action and observation spaces are becoming more complex and the current wrappers were not implemented with this in mind.
* Support Jax-based environments. With hardware-accelerated environments, i.e. Brax, written in Jax, and similar PyTorch-based programs, NumPy is not the only game in town anymore. Therefore, these upgrades will use [Jumpy](https://github.com/farama-Foundation/jumpy), a project developed by the Farama Foundation to provide automatic compatibility for NumPy, Jax and, in the future, PyTorch data for a large subset of the NumPy functions.
* Provide more wrappers. Projects like [Supersuit](https://github.com/farama-Foundation/supersuit) aimed to bring more wrappers for RL; however, many users were not aware of them, so we plan to move the wrappers into Gymnasium. If we are missing common wrappers from the list provided above, please create an issue.
* Add versioning. Like environments, the implementation details of wrappers can cause changes in agent performance. Therefore, we propose adding version numbers to all wrappers, i.e., `LambdaActionV0`. Similar to environment version numbers, we don't expect these to change regularly; they should ensure that all users know when significant changes could affect their agent's performance. Additionally, we hope that this will improve the reproducibility of RL in the future, which is critical for academia.
* In v0.28, we aim to rewrite the VectorEnv to not inherit from Env; as a result, new vectorized versions of the wrappers will be provided.

We aimed to replace the wrappers in gymnasium v0.30.0 with these experimental wrappers.

### Observation Wrappers

```{eval-rst}
.. py:currentmodule:: gymnasium

.. list-table::
   :header-rows: 1

   * - Old name
     - New name
   * - :class:`wrappers.TransformObservation`
     - :class:`experimental.wrappers.LambdaObservationV0`
   * - :class:`wrappers.FilterObservation`
     - :class:`experimental.wrappers.FilterObservationV0`
   * - :class:`wrappers.FlattenObservation`
     - :class:`experimental.wrappers.FlattenObservationV0`
   * - :class:`wrappers.GrayScaleObservation`
     - :class:`experimental.wrappers.GrayscaleObservationV0`
   * - :class:`wrappers.ResizeObservation`
     - :class:`experimental.wrappers.ResizeObservationV0`
   * - `supersuit.reshape_v0 <https://github.com/Farama-Foundation/SuperSuit/blob/314831a7d18e7254f455b181694c049908f95155/supersuit/generic_wrappers/basic_wrappers.py#L40>`_
     - :class:`experimental.wrappers.ReshapeObservationV0`
   * - Not Implemented
     - :class:`experimental.wrappers.RescaleObservationV0`
   * - `supersuit.dtype_v0 <https://github.com/Farama-Foundation/SuperSuit/blob/314831a7d18e7254f455b181694c049908f95155/supersuit/generic_wrappers/basic_wrappers.py#L32>`_
     - :class:`experimental.wrappers.DtypeObservationV0`
   * - :class:`wrappers.PixelObservationWrapper`
     - :class:`experimental.wrappers.PixelObservationV0`
   * - :class:`wrappers.NormalizeObservation`
     - :class:`experimental.wrappers.NormalizeObservationV0`
   * - :class:`wrappers.TimeAwareObservation`
     - :class:`experimental.wrappers.TimeAwareObservationV0`
   * - :class:`wrappers.FrameStack`
     - :class:`experimental.wrappers.FrameStackObservationV0`
   * - `supersuit.delay_observations_v0 <https://github.com/Farama-Foundation/SuperSuit/blob/314831a7d18e7254f455b181694c049908f95155/supersuit/generic_wrappers/delay_observations.py#L6>`_
     - :class:`experimental.wrappers.DelayObservationV0`
```

### Action Wrappers

```{eval-rst}
.. py:currentmodule:: gymnasium

.. list-table::
   :header-rows: 1

   * - Old name
     - New name
   * - `supersuit.action_lambda_v1 <https://github.com/Farama-Foundation/SuperSuit/blob/314831a7d18e7254f455b181694c049908f95155/supersuit/lambda_wrappers/action_lambda.py#L73>`_
     - :class:`experimental.wrappers.LambdaActionV0`
   * - :class:`wrappers.ClipAction`
     - :class:`experimental.wrappers.ClipActionV0`
   * - :class:`wrappers.RescaleAction`
     - :class:`experimental.wrappers.RescaleActionV0`
   * - `supersuit.sticky_actions_v0 <https://github.com/Farama-Foundation/SuperSuit/blob/314831a7d18e7254f455b181694c049908f95155/supersuit/generic_wrappers/sticky_actions.py#L6>`_
     - :class:`experimental.wrappers.StickyActionV0`
```

### Reward Wrappers

```{eval-rst}
.. py:currentmodule:: gymnasium

.. list-table::
   :header-rows: 1

   * - Old name
     - New name
   * - :class:`wrappers.TransformReward`
     - :class:`experimental.wrappers.LambdaRewardV0`
   * - `supersuit.clip_reward_v0 <https://github.com/Farama-Foundation/SuperSuit/blob/314831a7d18e7254f455b181694c049908f95155/supersuit/generic_wrappers/basic_wrappers.py#L74>`_
     - :class:`experimental.wrappers.ClipRewardV0`
   * - :class:`wrappers.NormalizeReward`
     - :class:`experimental.wrappers.NormalizeRewardV1`
```

### Common Wrappers

```{eval-rst}
.. py:currentmodule:: gymnasium

.. list-table::
   :header-rows: 1

   * - Old name
     - New name
   * - :class:`wrappers.AutoResetWrapper`
     - :class:`experimental.wrappers.AutoresetV0`
   * - :class:`wrappers.PassiveEnvChecker`
     - :class:`experimental.wrappers.PassiveEnvCheckerV0`
   * - :class:`wrappers.OrderEnforcing`
     - :class:`experimental.wrappers.OrderEnforcingV0`
   * - :class:`wrappers.EnvCompatibility`
     - Moved to `shimmy <https://github.com/Farama-Foundation/Shimmy/blob/main/shimmy/openai_gym_compatibility.py>`_
   * - :class:`wrappers.RecordEpisodeStatistics`
     - :class:`experimental.wrappers.RecordEpisodeStatisticsV0`
   * - :class:`wrappers.AtariPreprocessing`
     - :class:`experimental.wrappers.AtariPreprocessingV0`
```

### Rendering Wrappers

```{eval-rst}
.. py:currentmodule:: gymnasium

.. list-table::
   :header-rows: 1

   * - Old name
     - New name
   * - :class:`wrappers.RecordVideo`
     - :class:`experimental.wrappers.RecordVideoV0`
   * - :class:`wrappers.HumanRendering`
     - :class:`experimental.wrappers.HumanRenderingV0`
   * - :class:`wrappers.RenderCollection`
     - :class:`experimental.wrappers.RenderCollectionV0`
```

### Environment data conversion

```{eval-rst}
.. py:currentmodule:: gymnasium

* :class:`experimental.wrappers.JaxToNumpyV0`
* :class:`experimental.wrappers.JaxToTorchV0`
* :class:`experimental.wrappers.NumpyToTorchV0`
```


@@ -1,37 +0,0 @@
---
title: Functional
---
# Functional Environment
## gymnasium.experimental.FuncEnv
```{eval-rst}
.. autoclass:: gymnasium.experimental.functional.FuncEnv
.. autofunction:: gymnasium.experimental.functional.FuncEnv.initial
.. autofunction:: gymnasium.experimental.functional.FuncEnv.transition
.. autofunction:: gymnasium.experimental.functional.FuncEnv.observation
.. autofunction:: gymnasium.experimental.functional.FuncEnv.reward
.. autofunction:: gymnasium.experimental.functional.FuncEnv.terminal
.. autofunction:: gymnasium.experimental.functional.FuncEnv.state_info
.. autofunction:: gymnasium.experimental.functional.FuncEnv.step_info
.. autofunction:: gymnasium.experimental.functional.FuncEnv.transform
.. autofunction:: gymnasium.experimental.functional.FuncEnv.render_image
.. autofunction:: gymnasium.experimental.functional.FuncEnv.render_init
.. autofunction:: gymnasium.experimental.functional.FuncEnv.render_close
```
## gymnasium.experimental.func2env.FunctionalJaxCompatibilityEnv
```{eval-rst}
.. autoclass:: gymnasium.experimental.functional_jax_env.FunctionalJaxEnv
.. autofunction:: gymnasium.experimental.functional_jax_env.FunctionalJaxEnv.reset
.. autofunction:: gymnasium.experimental.functional_jax_env.FunctionalJaxEnv.step
.. autofunction:: gymnasium.experimental.functional_jax_env.FunctionalJaxEnv.render
```


@@ -1,39 +0,0 @@
---
title: Vector
---
# Vectorizing Environment
## gymnasium.experimental.VectorEnv
```{eval-rst}
.. autoclass:: gymnasium.experimental.vector.VectorEnv
.. autofunction:: gymnasium.experimental.vector.VectorEnv.reset
.. autofunction:: gymnasium.experimental.vector.VectorEnv.step
.. autofunction:: gymnasium.experimental.vector.VectorEnv.close
.. autofunction:: gymnasium.experimental.vector.VectorEnv.reset
```
## gymnasium.experimental.vector.AsyncVectorEnv
```{eval-rst}
.. autoclass:: gymnasium.experimental.vector.AsyncVectorEnv
.. autofunction:: gymnasium.experimental.vector.AsyncVectorEnv.reset
.. autofunction:: gymnasium.experimental.vector.AsyncVectorEnv.step
.. autofunction:: gymnasium.experimental.vector.AsyncVectorEnv.close
.. autofunction:: gymnasium.experimental.vector.AsyncVectorEnv.call
.. autofunction:: gymnasium.experimental.vector.AsyncVectorEnv.get_attr
.. autofunction:: gymnasium.experimental.vector.AsyncVectorEnv.set_attr
```
## gymnasium.experimental.vector.SyncVectorEnv
```{eval-rst}
.. autoclass:: gymnasium.experimental.vector.SyncVectorEnv
.. autofunction:: gymnasium.experimental.vector.SyncVectorEnv.reset
.. autofunction:: gymnasium.experimental.vector.SyncVectorEnv.step
.. autofunction:: gymnasium.experimental.vector.SyncVectorEnv.close
.. autofunction:: gymnasium.experimental.vector.SyncVectorEnv.call
.. autofunction:: gymnasium.experimental.vector.SyncVectorEnv.get_attr
.. autofunction:: gymnasium.experimental.vector.SyncVectorEnv.set_attr
```


@@ -1,29 +0,0 @@
---
title: Vector Utility
---
# Utility functions for vectorisation
## Spaces utility functions
```{eval-rst}
.. autofunction:: gymnasium.experimental.vector.utils.batch_space
.. autofunction:: gymnasium.experimental.vector.utils.concatenate
.. autofunction:: gymnasium.experimental.vector.utils.iterate
.. autofunction:: gymnasium.experimental.vector.utils.create_empty_array
```
## Shared memory functions
```{eval-rst}
.. autofunction:: gymnasium.experimental.vector.utils.create_shared_memory
.. autofunction:: gymnasium.experimental.vector.utils.read_from_shared_memory
.. autofunction:: gymnasium.experimental.vector.utils.write_to_shared_memory
```
## Miscellaneous
```{eval-rst}
.. autofunction:: gymnasium.experimental.vector.utils.CloudpickleWrapper
.. autofunction:: gymnasium.experimental.vector.utils.clear_mpi_env_vars
```


@@ -1,53 +0,0 @@
---
title: Vector Wrappers
---
# Vector Environment Wrappers
```{eval-rst}
.. autoclass:: gymnasium.experimental.vector.VectorWrapper
```
## Vector Observation Wrappers
```{eval-rst}
.. autoclass:: gymnasium.experimental.vector.VectorObservationWrapper
.. autoclass:: gymnasium.experimental.wrappers.vector.LambdaObservationV0
.. autoclass:: gymnasium.experimental.wrappers.vector.FilterObservationV0
.. autoclass:: gymnasium.experimental.wrappers.vector.FlattenObservationV0
.. autoclass:: gymnasium.experimental.wrappers.vector.GrayscaleObservationV0
.. autoclass:: gymnasium.experimental.wrappers.vector.ResizeObservationV0
.. autoclass:: gymnasium.experimental.wrappers.vector.ReshapeObservationV0
.. autoclass:: gymnasium.experimental.wrappers.vector.RescaleObservationV0
.. autoclass:: gymnasium.experimental.wrappers.vector.DtypeObservationV0
```
## Vector Action Wrappers
```{eval-rst}
.. autoclass:: gymnasium.experimental.vector.VectorActionWrapper
.. autoclass:: gymnasium.experimental.wrappers.vector.LambdaActionV0
.. autoclass:: gymnasium.experimental.wrappers.vector.ClipActionV0
.. autoclass:: gymnasium.experimental.wrappers.vector.RescaleActionV0
```
## Vector Reward Wrappers
```{eval-rst}
.. autoclass:: gymnasium.experimental.vector.VectorRewardWrapper
.. autoclass:: gymnasium.experimental.wrappers.vector.LambdaRewardV0
.. autoclass:: gymnasium.experimental.wrappers.vector.ClipRewardV0
```
## More Vector Wrappers
```{eval-rst}
.. autoclass:: gymnasium.experimental.wrappers.vector.RecordEpisodeStatisticsV0
.. autoclass:: gymnasium.experimental.wrappers.vector.DictInfoToListV0
.. autoclass:: gymnasium.experimental.wrappers.vector.VectorizeLambdaObservationV0
.. autoclass:: gymnasium.experimental.wrappers.vector.VectorizeLambdaActionV0
.. autoclass:: gymnasium.experimental.wrappers.vector.VectorizeLambdaRewardV0
.. autoclass:: gymnasium.experimental.wrappers.vector.JaxToNumpyV0
.. autoclass:: gymnasium.experimental.wrappers.vector.JaxToTorchV0
.. autoclass:: gymnasium.experimental.wrappers.vector.NumpyToTorchV0
```


@@ -1,65 +0,0 @@
---
title: Wrappers
---
# Wrappers
## Observation Wrappers
```{eval-rst}
.. autoclass:: gymnasium.experimental.wrappers.LambdaObservationV0
.. autoclass:: gymnasium.experimental.wrappers.FilterObservationV0
.. autoclass:: gymnasium.experimental.wrappers.FlattenObservationV0
.. autoclass:: gymnasium.experimental.wrappers.GrayscaleObservationV0
.. autoclass:: gymnasium.experimental.wrappers.ResizeObservationV0
.. autoclass:: gymnasium.experimental.wrappers.ReshapeObservationV0
.. autoclass:: gymnasium.experimental.wrappers.RescaleObservationV0
.. autoclass:: gymnasium.experimental.wrappers.DtypeObservationV0
.. autoclass:: gymnasium.experimental.wrappers.PixelObservationV0
.. autoclass:: gymnasium.experimental.wrappers.NormalizeObservationV0
.. autoclass:: gymnasium.experimental.wrappers.TimeAwareObservationV0
.. autoclass:: gymnasium.experimental.wrappers.FrameStackObservationV0
.. autoclass:: gymnasium.experimental.wrappers.DelayObservationV0
.. autoclass:: gymnasium.experimental.wrappers.AtariPreprocessingV0
```
## Action Wrappers
```{eval-rst}
.. autoclass:: gymnasium.experimental.wrappers.LambdaActionV0
.. autoclass:: gymnasium.experimental.wrappers.ClipActionV0
.. autoclass:: gymnasium.experimental.wrappers.RescaleActionV0
.. autoclass:: gymnasium.experimental.wrappers.StickyActionV0
```
## Reward Wrappers
```{eval-rst}
.. autoclass:: gymnasium.experimental.wrappers.LambdaRewardV0
.. autoclass:: gymnasium.experimental.wrappers.ClipRewardV0
.. autoclass:: gymnasium.experimental.wrappers.NormalizeRewardV1
```
## Other Wrappers
```{eval-rst}
.. autoclass:: gymnasium.experimental.wrappers.AutoresetV0
.. autoclass:: gymnasium.experimental.wrappers.PassiveEnvCheckerV0
.. autoclass:: gymnasium.experimental.wrappers.OrderEnforcingV0
.. autoclass:: gymnasium.experimental.wrappers.RecordEpisodeStatisticsV0
```
## Rendering Wrappers
```{eval-rst}
.. autoclass:: gymnasium.experimental.wrappers.RecordVideoV0
.. autoclass:: gymnasium.experimental.wrappers.HumanRenderingV0
.. autoclass:: gymnasium.experimental.wrappers.RenderCollectionV0
```
## Environment data conversion
```{eval-rst}
.. autoclass:: gymnasium.experimental.wrappers.JaxToNumpyV0
.. autoclass:: gymnasium.experimental.wrappers.JaxToTorchV0
.. autoclass:: gymnasium.experimental.wrappers.NumpyToTorchV0
```

docs/api/functional.md (new file)

@@ -0,0 +1,34 @@
---
title: Functional
---
# Functional Env
```{eval-rst}
.. autoclass:: gymnasium.functional.FuncEnv
.. automethod:: gymnasium.functional.FuncEnv.transform
.. automethod:: gymnasium.functional.FuncEnv.initial
.. automethod:: gymnasium.functional.FuncEnv.initial_info
.. automethod:: gymnasium.functional.FuncEnv.transition
.. automethod:: gymnasium.functional.FuncEnv.observation
.. automethod:: gymnasium.functional.FuncEnv.reward
.. automethod:: gymnasium.functional.FuncEnv.terminal
.. automethod:: gymnasium.functional.FuncEnv.transition_info
.. automethod:: gymnasium.functional.FuncEnv.render_image
.. automethod:: gymnasium.functional.FuncEnv.render_initialise
.. automethod:: gymnasium.functional.FuncEnv.render_close
```
## Converting Jax-based Functional environments to standard Env
```{eval-rst}
.. autoclass:: gymnasium.utils.functional_jax_env.FunctionalJaxEnv
.. automethod:: gymnasium.utils.functional_jax_env.FunctionalJaxEnv.reset
.. automethod:: gymnasium.utils.functional_jax_env.FunctionalJaxEnv.step
.. automethod:: gymnasium.utils.functional_jax_env.FunctionalJaxEnv.render
```
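To make the functional decomposition above concrete, here is a rough sketch of the idea as plain pure functions over an explicit state; the names mirror the methods listed above, but the signatures are simplified assumptions rather than the exact `FuncEnv` API.

```python
# Sketch of the functional pattern: the environment is a set of pure functions
# over an explicit state, rather than a stateful Env object.
# Signatures are simplified assumptions, not the exact FuncEnv API.

def initial(rng):
    """Starting state: a simple counter at zero."""
    return 0

def transition(state, action, rng):
    """Next state is a pure function of (state, action); nothing is mutated."""
    return state + (1 if action == 1 else -1)

def observation(state):
    return state

def reward(state, action, next_state):
    return 1.0 if next_state >= 10 else 0.0

def terminal(state):
    return state >= 10
```

Because no function touches hidden state, the whole step can be jit-compiled or vectorized across many states in Jax, which is what the `FunctionalJaxEnv` conversion above exploits.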

@@ -2,14 +2,13 @@
title: Registry
---

-# Register and Make
+# Make and register

```{eval-rst}
Gymnasium allows users to automatically load environments, pre-wrapped with several important wrappers, through the :meth:`gymnasium.make` function. To do this, the environment must first be registered with :meth:`gymnasium.register`. To get the environment specifications for a registered environment, use :meth:`gymnasium.spec` and to print the whole registry, use :meth:`gymnasium.pprint_registry`.
-```
-
-```{eval-rst}
.. autofunction:: gymnasium.make
+.. autofunction:: gymnasium.make_vec
.. autofunction:: gymnasium.register
.. autofunction:: gymnasium.spec
.. autofunction:: gymnasium.pprint_registry
@@ -19,6 +18,7 @@ Gymnasium allows users to automatically load environments, pre-wrapped with several important wrappers

```{eval-rst}
.. autoclass:: gymnasium.envs.registration.EnvSpec
+.. autoclass:: gymnasium.envs.registration.WrapperSpec
.. attribute:: gymnasium.envs.registration.registry

   The global registry for gymnasium, where environment specifications are stored by :meth:`gymnasium.register` and from which :meth:`gymnasium.make` creates environments.
@@ -36,5 +36,4 @@ Gymnasium allows users to automatically load environments, pre-wrapped with several important wrappers
.. autofunction:: gymnasium.envs.registration.find_highest_version
.. autofunction:: gymnasium.envs.registration.namespace
.. autofunction:: gymnasium.envs.registration.load_env_creator
-.. autofunction:: gymnasium.envs.registration.load_plugin_envs
```
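As a quick illustration of the registration flow (the `MyNamespace/MyEnv-v0` id and `my_package.envs:MyEnv` entry point below are hypothetical):

```python
import gymnasium as gym

# Store a spec in the global registry under a hypothetical id / entry point.
gym.register(
    id="MyNamespace/MyEnv-v0",
    entry_point="my_package.envs:MyEnv",
    max_episode_steps=200,
)

env = gym.make("MyNamespace/MyEnv-v0")   # looks the spec up in the registry
print(gym.spec("MyNamespace/MyEnv-v0"))  # the stored EnvSpec
```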

@@ -9,39 +9,38 @@ title: Spaces
spaces/fundamental
spaces/composite
spaces/utils
-spaces/vector_utils
+vector/utils
```

```{eval-rst}
.. automodule:: gymnasium.spaces
-```
-
-## The Base Class
-
-```{eval-rst}
.. autoclass:: gymnasium.spaces.Space
```

-### Attributes
+## Attributes

```{eval-rst}
-.. autoproperty:: gymnasium.spaces.space.Space.shape
+.. py:currentmodule:: gymnasium.spaces
+
+.. autoproperty:: Space.shape
.. property:: Space.dtype

   Return the data type of this space.

-.. autoproperty:: gymnasium.spaces.space.Space.is_np_flattenable
+.. autoproperty:: Space.is_np_flattenable
+.. autoproperty:: Space.np_random
```

-### Methods
+## Methods

Each space implements the following functions:

```{eval-rst}
-.. autofunction:: gymnasium.spaces.space.Space.sample
-.. autofunction:: gymnasium.spaces.space.Space.contains
-.. autofunction:: gymnasium.spaces.space.Space.seed
-.. autofunction:: gymnasium.spaces.space.Space.to_jsonable
-.. autofunction:: gymnasium.spaces.space.Space.from_jsonable
+.. py:currentmodule:: gymnasium.spaces
+
+.. automethod:: Space.sample
+.. automethod:: Space.contains
+.. automethod:: Space.seed
+.. automethod:: Space.to_jsonable
+.. automethod:: Space.from_jsonable
```

## Fundamental Spaces
@@ -49,13 +48,13 @@ Each space implements the following functions:
Gymnasium has a number of fundamental spaces that are used as building blocks for more complex spaces.

```{eval-rst}
-.. currentmodule:: gymnasium.spaces
+.. py:currentmodule:: gymnasium.spaces

-* :py:class:`Box` - Supports continuous (and discrete) vectors or matrices, used for vector observations, images, etc
-* :py:class:`Discrete` - Supports a single discrete number of values with an optional start for the values
-* :py:class:`MultiBinary` - Supports single or matrices of binary values, used for holding down a button or if an agent has an object
-* :py:class:`MultiDiscrete` - Supports multiple discrete values with multiple axes, used for controller actions
-* :py:class:`Text` - Supports strings, used for passing agent messages, mission details, etc
+* :class:`Box` - Supports continuous (and discrete) vectors or matrices, used for vector observations, images, etc
+* :class:`Discrete` - Supports a single discrete number of values with an optional start for the values
+* :class:`MultiBinary` - Supports single or matrices of binary values, used for holding down a button or if an agent has an object
+* :class:`MultiDiscrete` - Supports multiple discrete values with multiple axes, used for controller actions
+* :class:`Text` - Supports strings, used for passing agent messages, mission details, etc
```

## Composite Spaces
@@ -63,37 +62,41 @@ Gymnasium has a number of fundamental spaces that are used as building blocks
Often environment spaces require joining fundamental spaces together for vectorised environments, separate agents or readability of the space.

```{eval-rst}
-* :py:class:`Dict` - Supports a dictionary of keys and subspaces, used for a fixed number of unordered spaces
-* :py:class:`Tuple` - Supports a tuple of subspaces, used for a fixed number of ordered spaces
-* :py:class:`Sequence` - Supports a variable number of instances of a single subspace, used for entity spaces or selecting a variable number of actions
+.. py:currentmodule:: gymnasium.spaces
+
+* :class:`Dict` - Supports a dictionary of keys and subspaces, used for a fixed number of unordered spaces
+* :class:`Tuple` - Supports a tuple of subspaces, used for a fixed number of ordered spaces
+* :class:`Sequence` - Supports a variable number of instances of a single subspace, used for entity spaces or selecting a variable number of actions
* :py:class:`Graph` - Supports graph based actions or observations with discrete or continuous nodes and edge values.
```

-## Utils
+## Utility functions

Gymnasium contains a number of helpful utility functions for flattening and unflattening spaces.
This can be important for passing information to neural networks.

```{eval-rst}
-* :py:class:`utils.flatdim` - The number of dimensions the flattened space will contain
-* :py:class:`utils.flatten_space` - Flattens a space so that each flattened sample instance is contained within it
-* :py:class:`utils.flatten` - Flattens an instance of a space that is contained within the flattened version of the space
-* :py:class:`utils.unflatten` - The reverse of the `flatten` function
+.. py:currentmodule:: gymnasium.spaces
+
+* :class:`utils.flatdim` - The number of dimensions the flattened space will contain
+* :class:`utils.flatten_space` - Flattens a space so that each flattened sample instance is contained within it
+* :class:`utils.flatten` - Flattens an instance of a space that is contained within the flattened version of the space
+* :class:`utils.unflatten` - The reverse of the :class:`utils.flatten` function
```

-## Vector Utils
+## Vector Utility functions

When vectorizing environments, it is necessary to modify the observation and action spaces for the new batched space sizes.
Therefore, Gymnasium provides a number of additional functions used when using a space with a Vector environment.

```{eval-rst}
-.. currentmodule:: gymnasium
+.. py:currentmodule:: gymnasium

-* :py:class:`vector.utils.batch_space`
-* :py:class:`vector.utils.concatenate`
-* :py:class:`vector.utils.iterate`
-* :py:class:`vector.utils.create_empty_array`
-* :py:class:`vector.utils.create_shared_memory`
-* :py:class:`vector.utils.read_from_shared_memory`
-* :py:class:`vector.utils.write_to_shared_memory`
+* :class:`vector.utils.batch_space` - Transforms a space into the equivalent space for ``n`` users
+* :class:`vector.utils.concatenate` - Concatenates a space's samples into a pre-generated space
+* :class:`vector.utils.iterate` - Iterates over the batched space's samples
+* :class:`vector.utils.create_empty_array` - Creates an empty sample for a space (generally used with ``concatenate``)
+* :class:`vector.utils.create_shared_memory` - Creates shared memory for an asynchronous (multiprocessing) environment
+* :class:`vector.utils.read_from_shared_memory` - Reads from shared memory for an asynchronous (multiprocessing) environment
+* :class:`vector.utils.write_to_shared_memory` - Writes to shared memory for an asynchronous (multiprocessing) environment
```
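A short sketch of composing and flattening spaces (the `Dict` layout here is arbitrary):

```python
import numpy as np
from gymnasium.spaces import Box, Dict, Discrete, utils

space = Dict({
    "position": Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32),
    "gear": Discrete(3),
})
sample = space.sample()
assert space.contains(sample)

flat_space = utils.flatten_space(space)  # a Box with 2 + 3 (one-hot) = 5 dims
flat = utils.flatten(space, sample)      # a 5-dim array inside flat_space
assert flat_space.contains(flat)
restored = utils.unflatten(space, flat)  # back to the original Dict sample
```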

@@ -1,35 +1,22 @@
# Composite Spaces

-## Dict
```{eval-rst}
.. autoclass:: gymnasium.spaces.Dict
.. automethod:: gymnasium.spaces.Dict.sample
.. automethod:: gymnasium.spaces.Dict.seed
-```
-
-## Tuple
-
-```{eval-rst}
.. autoclass:: gymnasium.spaces.Tuple
.. automethod:: gymnasium.spaces.Tuple.sample
.. automethod:: gymnasium.spaces.Tuple.seed
-```
-
-## Sequence
-
-```{eval-rst}
.. autoclass:: gymnasium.spaces.Sequence
.. automethod:: gymnasium.spaces.Sequence.sample
.. automethod:: gymnasium.spaces.Sequence.seed
-```
-
-## Graph
-
-```{eval-rst}
.. autoclass:: gymnasium.spaces.Graph
.. automethod:: gymnasium.spaces.Graph.sample

@@ -4,44 +4,28 @@ title: Fundamental Spaces
# Fundamental Spaces

-## Box
```{eval-rst}
.. autoclass:: gymnasium.spaces.Box
.. automethod:: gymnasium.spaces.Box.sample
.. automethod:: gymnasium.spaces.Box.seed
.. automethod:: gymnasium.spaces.Box.is_bounded
-```
-
-## Discrete
-
-```{eval-rst}
.. autoclass:: gymnasium.spaces.Discrete
.. automethod:: gymnasium.spaces.Discrete.sample
.. automethod:: gymnasium.spaces.Discrete.seed
-```
-
-## MultiBinary
-
-```{eval-rst}
.. autoclass:: gymnasium.spaces.MultiBinary
.. automethod:: gymnasium.spaces.MultiBinary.sample
.. automethod:: gymnasium.spaces.MultiBinary.seed
-```
-
-## MultiDiscrete
-
-```{eval-rst}
.. autoclass:: gymnasium.spaces.MultiDiscrete
.. automethod:: gymnasium.spaces.MultiDiscrete.sample
.. automethod:: gymnasium.spaces.MultiDiscrete.seed
-```
-
-## Text
-
-```{eval-rst}
.. autoclass:: gymnasium.spaces.Text
.. automethod:: gymnasium.spaces.Text.sample

@@ -1,8 +1,20 @@
---
-title: Utils
+title: Utility functions
---

-# Utils
+# Utility functions

+## Seeding
+
+```{eval-rst}
+.. autofunction:: gymnasium.utils.seeding.np_random
+```
+
+## Environment Checking
+
+```{eval-rst}
+.. autofunction:: gymnasium.utils.env_checker.check_env
+```

## Visualization
@@ -17,6 +29,12 @@ title: Utils
.. automethod:: process_event
```

+## Environment pickling
+
+```{eval-rst}
+.. autoclass:: gymnasium.utils.ezpickle.EzPickle
+```

## Save Rendering Videos

```{eval-rst}
@@ -31,15 +49,3 @@ title: Utils
.. autofunction:: gymnasium.utils.step_api_compatibility.convert_to_terminated_truncated_step_api
.. autofunction:: gymnasium.utils.step_api_compatibility.convert_to_done_step_api
```

-## Seeding
-
-```{eval-rst}
-.. autofunction:: gymnasium.utils.seeding.np_random
-```
-
-## Environment Checking
-
-```{eval-rst}
-.. autofunction:: gymnasium.utils.env_checker.check_env
-```
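A short sketch of the two utilities now documented first on this page (`CartPole-v1` is used only as an example):

```python
import gymnasium as gym
from gymnasium.utils.env_checker import check_env
from gymnasium.utils.seeding import np_random

# np_random returns a seeded NumPy generator plus the seed actually used.
rng, seed = np_random(42)
print(rng.integers(10), seed)

# check_env raises (or warns) if an environment deviates from the Gymnasium API.
env = gym.make("CartPole-v1").unwrapped
check_env(env)
```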

@@ -2,7 +2,15 @@
title: Vector
---

-# Vector
+# Vector environments

+```{toctree}
+:hidden:
+
+vector/wrappers
+vector/async_vector_env
+vector/sync_vector_env
+vector/utils
+```

## Gymnasium.vector.VectorEnv
@@ -14,62 +22,47 @@ title: Vector
```{eval-rst}
.. automethod:: gymnasium.vector.VectorEnv.reset
.. automethod:: gymnasium.vector.VectorEnv.step
.. automethod:: gymnasium.vector.VectorEnv.close
```

### Attributes

```{eval-rst}
-.. attribute:: action_space
+.. autoattribute:: gymnasium.vector.VectorEnv.num_envs

-   The (batched) action space. The input actions of `step` must be valid elements of `action_space`.::
+   The number of sub-environments in the vector environment.

-   >>> envs = gymnasium.vector.make("CartPole-v1", num_envs=3)
-   >>> envs.action_space
-   MultiDiscrete([2 2 2])
+.. autoattribute:: gymnasium.vector.VectorEnv.action_space

-.. attribute:: observation_space
+   The (batched) action space. The input actions of `step` must be valid elements of `action_space`.

-   The (batched) observation space. The observations returned by `reset` and `step` are valid elements of `observation_space`.::
+.. autoattribute:: gymnasium.vector.VectorEnv.observation_space

-   >>> envs = gymnasium.vector.make("CartPole-v1", num_envs=3)
-   >>> envs.observation_space
-   Box([[-4.8 ...]], [[4.8 ...]], (3, 4), float32)
+   The (batched) observation space. The observations returned by `reset` and `step` are valid elements of `observation_space`.

-.. attribute:: single_action_space
+.. autoattribute:: gymnasium.vector.VectorEnv.single_action_space

-   The action space of an environment copy.::
+   The action space of a sub-environment.

-   >>> envs = gymnasium.vector.make("CartPole-v1", num_envs=3)
-   >>> envs.single_action_space
-   Discrete(2)
+.. autoattribute:: gymnasium.vector.VectorEnv.single_observation_space

-.. attribute:: single_observation_space
+   The observation space of a sub-environment.

-   The observation space of an environment copy.::
+.. autoattribute:: gymnasium.vector.VectorEnv.spec

-   >>> envs = gymnasium.vector.make("CartPole-v1", num_envs=3)
-   >>> envs.single_observation_space
-   Box([-4.8 ...], [4.8 ...], (4,), float32)
+   The ``EnvSpec`` of the environment normally set during :py:meth:`gymnasium.make_vec`
+```

+### Additional Methods
+
+```{eval-rst}
+.. autoproperty:: gymnasium.vector.VectorEnv.unwrapped
+.. autoproperty:: gymnasium.vector.VectorEnv.np_random
```

## Making Vector Environments

```{eval-rst}
-.. autofunction:: gymnasium.vector.make
+To create vector environments, gymnasium provides :func:`gymnasium.make_vec` as an equivalent function to :func:`gymnasium.make`.
-```
-
-## Async Vector Env
-
-```{eval-rst}
-.. autoclass:: gymnasium.vector.AsyncVectorEnv
-```
-
-## Sync Vector Env
-
-```{eval-rst}
-.. autoclass:: gymnasium.vector.SyncVectorEnv
```
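A quick sketch of creating and stepping a vector environment (`CartPole-v1` is just an example; the `vectorization_mode` keyword is assumed to accept `"sync"`/`"async"`):

```python
import gymnasium as gym

envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
observations, infos = envs.reset(seed=42)  # batched observations, shape (3, 4)
actions = envs.action_space.sample()       # one action per sub-environment
observations, rewards, terminations, truncations, infos = envs.step(actions)
envs.close()
```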


@@ -0,0 +1,13 @@
# AsyncVectorEnv
```{eval-rst}
.. autoclass:: gymnasium.vector.AsyncVectorEnv
.. automethod:: gymnasium.vector.AsyncVectorEnv.reset
.. automethod:: gymnasium.vector.AsyncVectorEnv.step
.. automethod:: gymnasium.vector.AsyncVectorEnv.close
.. automethod:: gymnasium.vector.AsyncVectorEnv.call
.. automethod:: gymnasium.vector.AsyncVectorEnv.get_attr
.. automethod:: gymnasium.vector.AsyncVectorEnv.set_attr
```


@@ -0,0 +1,13 @@
# SyncVectorEnv
```{eval-rst}
.. autoclass:: gymnasium.vector.SyncVectorEnv
.. automethod:: gymnasium.vector.SyncVectorEnv.reset
.. automethod:: gymnasium.vector.SyncVectorEnv.step
.. automethod:: gymnasium.vector.SyncVectorEnv.close
.. automethod:: gymnasium.vector.SyncVectorEnv.call
.. automethod:: gymnasium.vector.SyncVectorEnv.get_attr
.. automethod:: gymnasium.vector.SyncVectorEnv.set_attr
```

@@ -1,20 +1,25 @@
----
-title: Vector Utils
----
+# Utility functions

-# Spaces Vector Utils
+## Vectorizing Spaces

```{eval-rst}
.. autofunction:: gymnasium.vector.utils.batch_space
.. autofunction:: gymnasium.vector.utils.concatenate
.. autofunction:: gymnasium.vector.utils.iterate
+.. autofunction:: gymnasium.vector.utils.create_empty_array
```

-## Shared Memory Utils
+## Shared Memory for a Space

```{eval-rst}
-.. autofunction:: gymnasium.vector.utils.create_empty_array
.. autofunction:: gymnasium.vector.utils.create_shared_memory
.. autofunction:: gymnasium.vector.utils.read_from_shared_memory
.. autofunction:: gymnasium.vector.utils.write_to_shared_memory
```

+## Miscellaneous
+
+```{eval-rst}
+.. autofunction:: gymnasium.vector.utils.CloudpickleWrapper
+.. autofunction:: gymnasium.vector.utils.clear_mpi_env_vars
+```
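A brief sketch of how these utilities fit together (shapes shown for a simple `Box`):

```python
import numpy as np
from gymnasium.spaces import Box
from gymnasium.vector.utils import batch_space, concatenate, create_empty_array

space = Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float32)
batched = batch_space(space, n=3)     # Box(-1.0, 1.0, (3, 4), float32)

out = create_empty_array(space, n=3)  # pre-allocated (3, 4) array
samples = [space.sample() for _ in range(3)]
out = concatenate(space, samples, out)  # writes the three samples into `out`
assert batched.contains(out)
```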


@@ -0,0 +1,26 @@
---
title: Vector Wrappers
---
# Vector Wrappers
```{eval-rst}
.. autoclass:: gymnasium.vector.VectorWrapper
.. automethod:: gymnasium.vector.VectorWrapper.step
.. automethod:: gymnasium.vector.VectorWrapper.reset
.. automethod:: gymnasium.vector.VectorWrapper.close
.. autoclass:: gymnasium.vector.VectorObservationWrapper
.. automethod:: gymnasium.vector.VectorObservationWrapper.vector_observation
.. automethod:: gymnasium.vector.VectorObservationWrapper.single_observation
.. autoclass:: gymnasium.vector.VectorActionWrapper
.. automethod:: gymnasium.vector.VectorActionWrapper.actions
.. autoclass:: gymnasium.vector.VectorRewardWrapper
.. automethod:: gymnasium.vector.VectorRewardWrapper.rewards
```

@@ -6,134 +6,47 @@ title: Wrapper
```{toctree}
:hidden:

+wrappers/table
wrappers/misc_wrappers
wrappers/action_wrappers
wrappers/observation_wrappers
wrappers/reward_wrappers
+wrappers/vector_wrappers
```

```{eval-rst}
.. automodule:: gymnasium.wrappers
```

-## gymnasium.Wrapper
```{eval-rst}
.. autoclass:: gymnasium.Wrapper
```

-### Methods
+## Methods
```{eval-rst}
-.. autofunction:: gymnasium.Wrapper.step
-.. autofunction:: gymnasium.Wrapper.reset
-.. autofunction:: gymnasium.Wrapper.close
+.. automethod:: gymnasium.Wrapper.step
+.. automethod:: gymnasium.Wrapper.reset
+.. automethod:: gymnasium.Wrapper.render
+.. automethod:: gymnasium.Wrapper.close
+.. automethod:: gymnasium.Wrapper.wrapper_spec
+.. automethod:: gymnasium.Wrapper.get_wrapper_attr
+.. automethod:: gymnasium.Wrapper.set_wrapper_attr
```

-### Attributes
+## Attributes
```{eval-rst}
-.. autoproperty:: gymnasium.Wrapper.action_space
-.. autoproperty:: gymnasium.Wrapper.observation_space
-.. autoproperty:: gymnasium.Wrapper.reward_range
-.. autoproperty:: gymnasium.Wrapper.spec
-.. autoproperty:: gymnasium.Wrapper.metadata
-.. autoproperty:: gymnasium.Wrapper.np_random
-.. attribute:: gymnasium.Wrapper.env
+.. autoattribute:: gymnasium.Wrapper.env

   The environment (one level underneath) this wrapper.

-   This may itself be a wrapped environment.
-   To obtain the environment underneath all layers of wrappers, use :attr:`gymnasium.Wrapper.unwrapped`.
+   This may itself be a wrapped environment. To obtain the environment underneath all layers of wrappers, use :attr:`gymnasium.Wrapper.unwrapped`.

+.. autoproperty:: gymnasium.Wrapper.action_space
+.. autoproperty:: gymnasium.Wrapper.observation_space
+.. autoproperty:: gymnasium.Wrapper.spec
+.. autoproperty:: gymnasium.Wrapper.metadata
+.. autoproperty:: gymnasium.Wrapper.np_random
.. autoproperty:: gymnasium.Wrapper.unwrapped
```
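A short sketch of stacking wrappers around a single environment:

```python
import gymnasium as gym
from gymnasium.wrappers import ClipAction, TimeLimit

env = gym.make("Pendulum-v1")
env = TimeLimit(env, max_episode_steps=100)  # truncate episodes after 100 steps
env = ClipAction(env)                        # clip actions to the action-space bounds

print(type(env.unwrapped))  # the bare environment beneath all wrapper layers
obs, info = env.reset(seed=0)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
```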
-## Gymnasium Wrappers
-
-Gymnasium provides a number of commonly used wrappers listed below. More information on each wrapper can be found in the page for its wrapper type.
-
-```{eval-rst}
-.. py:currentmodule:: gymnasium.wrappers
-
-.. list-table::
-   :header-rows: 1
-
-   * - Name
-     - Type
-     - Description
-   * - :class:`AtariPreprocessing`
-     - Misc Wrapper
-     - Implements the common preprocessing applied to Atari environments
-   * - :class:`AutoResetWrapper`
-     - Misc Wrapper
-     - The wrapped environment will automatically reset when the terminated or truncated state is reached.
-   * - :class:`ClipAction`
-     - Action Wrapper
-     - Clips the continuous action to the valid bound specified by the environment's `action_space`
-   * - :class:`EnvCompatibility`
-     - Misc Wrapper
-     - Provides compatibility for environments written in the OpenAI Gym v0.21 API to look like Gymnasium environments
-   * - :class:`FilterObservation`
-     - Observation Wrapper
-     - Filters a dictionary observation space to only the requested keys
-   * - :class:`FlattenObservation`
-     - Observation Wrapper
-     - An observation wrapper that flattens the observation
-   * - :class:`FrameStack`
-     - Observation Wrapper
-     - An observation wrapper that stacks the observations in a rolling manner.
-   * - :class:`GrayScaleObservation`
-     - Observation Wrapper
-     - Converts the image observation from RGB to gray scale.
-   * - :class:`HumanRendering`
-     - Misc Wrapper
-     - Allows human-like rendering for environments that support "rgb_array" rendering
-   * - :class:`NormalizeObservation`
-     - Observation Wrapper
-     - This wrapper will normalize observations s.t. each coordinate is centered with unit variance.
-   * - :class:`NormalizeReward`
-     - Reward Wrapper
-     - This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance.
-   * - :class:`OrderEnforcing`
-     - Misc Wrapper
-     - This will produce an error if `step` or `render` is called before `reset`
-   * - :class:`PixelObservationWrapper`
-     - Observation Wrapper
-     - Augments observations by pixel values obtained via `render` that can be added to or replace the environment's observation.
-   * - :class:`RecordEpisodeStatistics`
-     - Misc Wrapper
-     - This will keep track of cumulative rewards and episode lengths, returning them at the end.
-   * - :class:`RecordVideo`
-     - Misc Wrapper
-     - This wrapper will record videos of rollouts.
-   * - :class:`RenderCollection`
-     - Misc Wrapper
-     - Enables list versions of render modes, i.e. "rgb_array_list" for "rgb_array", such that the rendering for each step is saved in a list until `render` is called.
-   * - :class:`RescaleAction`
-     - Action Wrapper
-     - Rescales the continuous action space of the environment to a range \[`min_action`, `max_action`], where `min_action` and `max_action` are numpy arrays or floats.
-   * - :class:`ResizeObservation`
-     - Observation Wrapper
-     - This wrapper works on environments with image observations (or more generally observations of shape AxBxC) and resizes the observation to the shape given by the tuple `shape`.
-   * - :class:`StepAPICompatibility`
-     - Misc Wrapper
-     - Modifies an environment step function from (old) done to the (new) termination / truncation API.
-   * - :class:`TimeAwareObservation`
-     - Observation Wrapper
-     - Augments the observation with the current time step in the trajectory (by appending it to the observation).
-   * - :class:`TimeLimit`
-     - Misc Wrapper
-     - This wrapper will emit a truncated signal if the specified number of steps is exceeded in an episode.
-   * - :class:`TransformObservation`
-     - Observation Wrapper
-     - This wrapper will apply a function to observations
-   * - :class:`TransformReward`
-     - Reward Wrapper
-     - This wrapper will apply a function to rewards
-   * - :class:`VectorListInfo`
-     - Misc Wrapper
-     - This wrapper will convert the info of a vectorized environment from the `dict` format to a `list` of dictionaries where the i-th dictionary contains the info of the i-th environment.
-```

@@ -11,6 +11,8 @@
## Available Action Wrappers

```{eval-rst}
+.. autoclass:: gymnasium.wrappers.TransformAction
.. autoclass:: gymnasium.wrappers.ClipAction
.. autoclass:: gymnasium.wrappers.RescaleAction
+.. autoclass:: gymnasium.wrappers.StickyAction
```
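For example, `RescaleAction` lets an agent emit actions in a normalized range (`Pendulum-v1`'s native action space is `Box(-2.0, 2.0, (1,))`):

```python
import gymnasium as gym
from gymnasium.wrappers import RescaleAction

env = gym.make("Pendulum-v1")
env = RescaleAction(env, min_action=-1.0, max_action=1.0)
print(env.action_space)  # Box(-1.0, 1.0, (1,), float32)
```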

@@ -1,16 +1,33 @@
----
-title: Misc Wrappers
----
-
# Misc Wrappers

+## Common Wrappers
+
```{eval-rst}
+.. autoclass:: gymnasium.wrappers.TimeLimit
+.. autoclass:: gymnasium.wrappers.RecordVideo
+.. autoclass:: gymnasium.wrappers.RecordEpisodeStatistics
.. autoclass:: gymnasium.wrappers.AtariPreprocessing
-.. autoclass:: gymnasium.wrappers.AutoResetWrapper
-.. autoclass:: gymnasium.wrappers.EnvCompatibility
-.. autoclass:: gymnasium.wrappers.StepAPICompatibility
+```
+
+## Uncommon Wrappers
+
+```{eval-rst}
+.. autoclass:: gymnasium.wrappers.Autoreset
.. autoclass:: gymnasium.wrappers.PassiveEnvChecker
.. autoclass:: gymnasium.wrappers.HumanRendering
.. autoclass:: gymnasium.wrappers.OrderEnforcing
-.. autoclass:: gymnasium.wrappers.RecordEpisodeStatistics
-.. autoclass:: gymnasium.wrappers.RecordVideo
.. autoclass:: gymnasium.wrappers.RenderCollection
-.. autoclass:: gymnasium.wrappers.TimeLimit
-.. autoclass:: gymnasium.wrappers.VectorListInfo
+```
+
+## Data Conversion Wrappers
+
+```{eval-rst}
+.. autoclass:: gymnasium.wrappers.JaxToNumpy
+.. autoclass:: gymnasium.wrappers.JaxToTorch
+.. autoclass:: gymnasium.wrappers.NumpyToTorch
```
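As a small usage sketch, `RecordEpisodeStatistics` exposes the episode return and length through `info["episode"]` when an episode ends:

```python
import gymnasium as gym
from gymnasium.wrappers import RecordEpisodeStatistics

env = RecordEpisodeStatistics(gym.make("CartPole-v1"))
obs, info = env.reset(seed=1)
done = False
while not done:
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    done = terminated or truncated
print(info["episode"])  # cumulative reward, episode length, elapsed time
```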

@@ -1,23 +1,26 @@
# Observation Wrappers

-## Base Class
```{eval-rst}
.. autoclass:: gymnasium.ObservationWrapper
.. automethod:: gymnasium.ObservationWrapper.observation
```

-## Available Observation Wrappers
+## Implemented Wrappers
```{eval-rst}
.. autoclass:: gymnasium.wrappers.TransformObservation
+.. autoclass:: gymnasium.wrappers.DelayObservation
+.. autoclass:: gymnasium.wrappers.DtypeObservation
.. autoclass:: gymnasium.wrappers.FilterObservation
.. autoclass:: gymnasium.wrappers.FlattenObservation
-.. autoclass:: gymnasium.wrappers.FrameStack
-.. autoclass:: gymnasium.wrappers.GrayScaleObservation
+.. autoclass:: gymnasium.wrappers.FrameStackObservation
+.. autoclass:: gymnasium.wrappers.GrayscaleObservation
+.. autoclass:: gymnasium.wrappers.MaxAndSkipObservation
.. autoclass:: gymnasium.wrappers.NormalizeObservation
-.. autoclass:: gymnasium.wrappers.PixelObservationWrapper
+.. autoclass:: gymnasium.wrappers.RenderObservation
.. autoclass:: gymnasium.wrappers.ResizeObservation
+.. autoclass:: gymnasium.wrappers.ReshapeObservation
+.. autoclass:: gymnasium.wrappers.RescaleObservation
.. autoclass:: gymnasium.wrappers.TimeAwareObservation
```
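For instance, `FrameStackObservation` turns a `(4,)` CartPole observation into a rolling stack of recent observations (the `stack_size` keyword is assumed from the v1.0 API):

```python
import gymnasium as gym
from gymnasium.wrappers import FrameStackObservation

env = gym.make("CartPole-v1")                   # observation shape (4,)
env = FrameStackObservation(env, stack_size=4)
obs, info = env.reset(seed=0)
print(obs.shape)                                # (4, 4): 4 stacked observations
```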

@@ -1,17 +1,19 @@
----
-title: Reward Wrappers
----
-
# Reward Wrappers

-## Base Class
```{eval-rst}
.. autoclass:: gymnasium.RewardWrapper
.. automethod:: gymnasium.RewardWrapper.reward
```

-## Available Reward Wrappers
+## Implemented Wrappers
```{eval-rst}
.. autoclass:: gymnasium.wrappers.TransformReward
.. autoclass:: gymnasium.wrappers.NormalizeReward
+.. autoclass:: gymnasium.wrappers.ClipReward
```
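For example, `TransformReward` applies an arbitrary function to every reward:

```python
import gymnasium as gym
from gymnasium.wrappers import TransformReward

env = gym.make("CartPole-v1")
env = TransformReward(env, lambda r: 0.1 * r)  # scale every reward by 0.1
obs, info = env.reset(seed=0)
_, reward, _, _, _ = env.step(env.action_space.sample())
print(reward)  # 0.1 instead of CartPole's usual 1.0
```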

docs/api/wrappers/table.md (new file)

@@ -0,0 +1,102 @@
# List of Wrappers
Gymnasium provides a number of commonly used wrappers listed below. More information can be found on the particular
wrapper in the page on the wrapper type
```{eval-rst}
.. py:currentmodule:: gymnasium.wrappers
.. list-table::
:header-rows: 1
* - Name
- Description
* - :class:`AtariPreprocessing`
- Implements the common preprocessing techniques for Atari environments (excluding frame stacking).
* - :class:`Autoreset`
- The wrapped environment is automatically reset when an terminated or truncated state is reached.
* - :class:`ClipAction`
- Clips the ``action`` pass to ``step`` to be within the environment's `action_space`.
* - :class:`ClipReward`
- Clips the rewards for an environment between an upper and lower bound.
* - :class:`DelayObservation`
- Adds a delay to the returned observation from the environment.
* - :class:`DtypeObservation`
- Modifies the dtype of an observation array to a specified dtype.
* - :class:`FilterObservation`
- Filters a Dict or Tuple observation spaces by a set of keys or indexes.
* - :class:`FlattenObservation`
- Flattens the environment's observation space and each observation from ``reset`` and ``step`` functions.
* - :class:`FrameStackObservation`
- Stacks the observations from the last ``N`` time steps in a rolling manner.
* - :class:`GrayscaleObservation`
- Converts an image observation computed by ``reset`` and ``step`` from RGB to Grayscale.
* - :class:`HumanRendering`
- Allows human like rendering for environments that support "rgb_array" rendering.
* - :class:`JaxToNumpy`
- Wraps a Jax-based environment such that it can be interacted with NumPy arrays.
* - :class:`JaxToTorch`
- Wraps a Jax-based environment so that it can be interacted with PyTorch Tensors.
* - :class:`MaxAndSkipObservation`
- Skips the N-th frame (observation) and return the max values between the two last observations.
* - :class:`NormalizeObservation`
- Normalizes observations to be centered at the mean with unit variance.
* - :class:`NormalizeReward`
- Normalizes immediate rewards such that their exponential moving average has a fixed variance.
* - :class:`NumpyToTorch`
- Wraps a NumPy-based environment such that it can be interacted with PyTorch Tensors.
* - :class:`OrderEnforcing`
- Will produce an error if ``step`` or ``render`` is called before ``render``.
* - :class:`PassiveEnvChecker`
- A passive environment checker wrapper that surrounds the ``step``, ``reset`` and ``render`` functions to check they follows gymnasium's API.
* - :class:`RecordEpisodeStatistics`
- This wrapper will keep track of cumulative rewards and episode lengths.
* - :class:`RecordVideo`
- Records videos of environment episodes using the environment's render function.
* - :class:`RenderCollection`
- Collect rendered frames of an environment such ``render`` returns a ``list[RenderedFrame]``.
* - :class:`RenderObservation`
- Includes the rendered observations in the environment's observations.
* - :class:`RescaleAction`
- Affinely (linearly) rescales a ``Box`` action space of the environment to within the range of ``[min_action, max_action]``.
* - :class:`RescaleObservation`
- Affinely (linearly) rescales a ``Box`` observation space of the environment to within the range of ``[min_obs, max_obs]``.
* - :class:`ReshapeObservation`
- Reshapes Array based observations to a specified shape.
* - :class:`ResizeObservation`
- Resizes image observations using OpenCV to a specified shape.
* - :class:`StickyAction`
- Adds a probability that the action is repeated for the same ``step`` function.
* - :class:`TimeAwareObservation`
- Augment the observation with the number of time steps taken within an episode.
* - :class:`TimeLimit`
- Limits the number of steps for an environment through truncating the environment if a maximum number of timesteps is exceeded.
* - :class:`TransformAction`
- Applies a function to the ``action`` before passing the modified value to the environment ``step`` function.
* - :class:`TransformObservation`
- Applies a function to the ``observation`` received from the environment's ``reset`` and ``step`` that is passed back to the user.
* - :class:`TransformReward`
- Applies a function to the ``reward`` received from the environment's ``step``.
```
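As a quick illustration, here is a minimal sketch (not part of the generated table above) chaining two of these wrappers around ``Pendulum-v1``, chosen because it has a ``Box`` action space:

```python
import gymnasium as gym
from gymnasium.wrappers import ClipAction, RecordEpisodeStatistics

# Pendulum-v1 has a Box action space, so ClipAction is applicable.
env = gym.make("Pendulum-v1")
env = ClipAction(env)               # out-of-range actions are clipped before reaching the env
env = RecordEpisodeStatistics(env)  # episodic return/length are reported in `info`

observation, info = env.reset(seed=0)
observation, reward, terminated, truncated, info = env.step(env.action_space.sample())
env.close()
```

Wrappers compose in the order they are applied, so here the episode statistics are recorded over the clipped-action environment.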
## Vector only Wrappers
```{eval-rst}
.. py:currentmodule:: gymnasium.wrappers.vector
.. list-table::
:header-rows: 1
* - Name
- Description
* - :class:`DictInfoToList`
- Converts infos of vectorized environments from ``dict`` to ``List[dict]``.
* - :class:`VectorizeTransformAction`
- Vectorizes a single-agent transform action wrapper for vector environments.
* - :class:`VectorizeTransformObservation`
- Vectorizes a single-agent transform observation wrapper for vector environments.
* - :class:`VectorizeTransformReward`
- Vectorizes a single-agent transform reward wrapper for vector environments.
```
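For example, a minimal sketch of :class:`DictInfoToList` around a synchronous vector environment (``CartPole-v1`` and the ``"sync"`` vectorization mode are used purely for illustration):

```python
import gymnasium as gym
from gymnasium.wrappers.vector import DictInfoToList

envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
envs = DictInfoToList(envs)  # `infos` becomes a list with one dict per sub-environment

observations, infos = envs.reset(seed=42)
assert isinstance(infos, list) and len(infos) == envs.num_envs
envs.close()
```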

View File

@@ -0,0 +1,19 @@
---
title: Vector Wrappers
---
# Vector wrappers
## Vector only wrappers
```{eval-rst}
.. autoclass:: gymnasium.wrappers.vector.DictInfoToList
```
## Vectorize Transform Wrappers to Vector Wrappers
```{eval-rst}
.. autoclass:: gymnasium.wrappers.vector.VectorizeTransformObservation
.. autoclass:: gymnasium.wrappers.vector.VectorizeTransformAction
.. autoclass:: gymnasium.wrappers.vector.VectorizeTransformReward
```
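As an illustrative sketch of the vectorize wrappers, this assumes they accept the single-agent wrapper class together with its keyword arguments and apply it to every sub-environment's output:

```python
import gymnasium as gym
from gymnasium.wrappers import RescaleObservation
from gymnasium.wrappers.vector import VectorizeTransformObservation

envs = gym.make_vec("MountainCarContinuous-v0", num_envs=2, vectorization_mode="sync")
# Apply the single-agent RescaleObservation wrapper to each sub-environment.
envs = VectorizeTransformObservation(
    envs, wrapper=RescaleObservation, min_obs=-1.0, max_obs=1.0
)

observations, infos = envs.reset(seed=0)
envs.close()
```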

View File

@@ -40,10 +40,10 @@ release = gymnasium.__version__
# ones. # ones.
extensions = [ extensions = [
"sphinx.ext.napoleon", "sphinx.ext.napoleon",
"sphinx.ext.doctest",
"sphinx.ext.autodoc", "sphinx.ext.autodoc",
"sphinx.ext.githubpages", "sphinx.ext.githubpages",
"sphinx.ext.viewcode", "sphinx.ext.viewcode",
"sphinx.ext.coverage",
"myst_parser", "myst_parser",
"furo.gen_tutorials", "furo.gen_tutorials",
"sphinx_gallery.gen_gallery", "sphinx_gallery.gen_gallery",

View File

@@ -1,134 +0,0 @@
---
layout: "contents"
title: Basic Usage
firstpage:
---
# Basic Usage
Gymnasium is a project that provides an API for all single agent reinforcement learning environments, and includes implementations of common environments: cartpole, pendulum, mountain-car, mujoco, atari, and more.
The API contains four key functions: ``make``, ``reset``, ``step`` and ``render``, which this basic usage will introduce you to. At the core of Gymnasium is ``Env``, a high-level Python class representing a Markov decision process (MDP) from reinforcement learning theory (this is not a perfect reconstruction, and is missing several components of MDPs). Within gymnasium, environments (MDPs) are implemented as ``Env`` classes, along with ``Wrappers``, which provide helpful utilities and can change the results passed to the user.
## Initializing Environments
Initializing environments is very easy in Gymnasium and can be done via the ``make`` function:
```python
import gymnasium as gym
env = gym.make('CartPole-v1')
```
This will return an ``Env`` for users to interact with. To see all environments you can create, use ``gymnasium.envs.registry.keys()``. ``make`` includes a number of additional parameters for adding wrappers, specifying keywords for the environment and more.
## Interacting with the Environment
The classic "agent-environment loop" pictured below is a simplified representation of reinforcement learning that Gymnasium implements.
```{image} /_static/diagrams/AE_loop.png
:width: 50%
:align: center
:class: only-light
```
```{image} /_static/diagrams/AE_loop_dark.png
:width: 50%
:align: center
:class: only-dark
```
This loop is implemented using the following gymnasium code
```python
import gymnasium as gym
env = gym.make("LunarLander-v2", render_mode="human")
observation, info = env.reset()
for _ in range(1000):
action = env.action_space.sample() # agent policy that uses the observation and info
observation, reward, terminated, truncated, info = env.step(action)
if terminated or truncated:
observation, info = env.reset()
env.close()
```
The output should look something like this:
```{figure} https://user-images.githubusercontent.com/15806078/153222406-af5ce6f0-4696-4a24-a683-46ad4939170c.gif
:width: 50%
:align: center
```
### Explaining the code
First, an environment is created using ``make`` with an additional keyword `"render_mode"` that specifies how the environment should be visualised. See ``render`` for details on the default meaning of different render modes. In this example, we use the ``"LunarLander"`` environment where the agent controls a spaceship that needs to land safely.
After initializing the environment, we ``reset`` the environment to get the first observation of the environment. For initializing the environment with a particular random seed or options (see environment documentation for possible values) use the ``seed`` or ``options`` parameters with ``reset``.
Next, the agent performs an action in the environment via ``step``; this can be imagined as moving a robot or pressing a button on a game's controller that causes a change within the environment. As a result, the agent receives a new observation from the updated environment along with a reward for taking the action. This reward could, for instance, be positive for destroying an enemy or negative for moving into lava. One such action-observation exchange is referred to as a *timestep*.
However, after some timesteps, the environment may end; this is called the terminal state. For instance, the robot may have crashed, or the agent may have succeeded in completing a task; the environment then needs to stop as the agent cannot continue. In gymnasium, if the environment has terminated, this is returned by ``step``. Similarly, we may also want the environment to end after a fixed number of timesteps; in this case, the environment issues a truncated signal. If either of ``terminated`` or ``truncated`` are `true` then ``reset`` should be called next to restart the environment.
## Action and observation spaces
Every environment specifies the format of valid actions and observations with the ``env.action_space`` and ``env.observation_space`` attributes. This is helpful for both knowing the expected input and output of the environment, as all valid actions and observations should be contained within the respective space.
In the example, we sampled random actions via ``env.action_space.sample()`` instead of using an agent policy that maps observations to actions, which is what users will ultimately want. See one of the agent tutorials for an example of creating and training an agent policy.
Every environment should have the attributes ``action_space`` and ``observation_space``, both of which should be instances of classes that inherit from ``Space``. Gymnasium has support for a majority of possible spaces users might need:
- ``Box``: describes an n-dimensional continuous space. It's a bounded space where we can define the upper and lower
limits which describe the valid values our observations can take.
- ``Discrete``: describes a discrete space where {0, 1, ..., n-1} are the possible values our observation or action can take.
Values can be shifted to {a, a+1, ..., a+n-1} using an optional argument.
- ``Dict``: represents a dictionary of simple spaces.
- ``Tuple``: represents a tuple of simple spaces.
- ``MultiBinary``: creates an n-shape binary space. Argument n can be a number or a list of numbers.
- ``MultiDiscrete``: consists of a series of ``Discrete`` action spaces with a different number of actions in each element.
For example usage of spaces, see their [documentation](/api/spaces) along with [utility functions](/api/spaces/utils). There are a couple of more niche spaces: ``Graph``, ``Sequence`` and ``Text``.
## Modifying the environment
Wrappers are a convenient way to modify an existing environment without having to alter the underlying code directly. Using wrappers will allow you to avoid a lot of boilerplate code and make your environment more modular. Wrappers can also be chained to combine their effects. Most environments that are generated via ``gymnasium.make`` will already be wrapped by default using the ``TimeLimit``, ``OrderEnforcing`` and ``PassiveEnvChecker``.
In order to wrap an environment, you must first initialize a base environment. Then you can pass this environment along with (possibly optional) parameters to the wrapper's constructor:
```python
>>> import gymnasium as gym
>>> from gymnasium.wrappers import FlattenObservation
>>> env = gym.make("CarRacing-v2")
>>> env.observation_space.shape
(96, 96, 3)
>>> wrapped_env = FlattenObservation(env)
>>> wrapped_env.observation_space.shape
(27648,)
```
Gymnasium already provides many commonly used wrappers for you. Some examples:
- `TimeLimit`: Issue a truncated signal if a maximum number of timesteps has been exceeded (or the base environment has issued a truncated signal).
- `ClipAction`: Clip the action such that it lies in the action space (of type `Box`).
- `RescaleAction`: Rescale actions to lie in a specified interval.
- `TimeAwareObservation`: Add information about the index of timestep to observation. In some cases helpful to ensure that transitions are Markov.
For a full list of implemented wrappers in gymnasium, see [wrappers](/api/wrappers).
If you have a wrapped environment, and you want to get the unwrapped environment underneath all the layers of wrappers (so that you can manually call a function or change some underlying aspect of the environment), you can use the `.unwrapped` attribute. If the environment is already a base environment, the `.unwrapped` attribute will just return itself.
```python
>>> wrapped_env
<FlattenObservation<TimeLimit<OrderEnforcing<PassiveEnvChecker<CarRacing<CarRacing-v2>>>>>>
>>> wrapped_env.unwrapped
<gymnasium.envs.box2d.car_racing.CarRacing object at 0x7f04efcb8850>
```
## More information
* [Making a Custom environment using the Gymnasium API](/tutorials/gymnasium_basics/environment_creation/)
* [Training an agent to play blackjack](/tutorials/training_agents/blackjack_tutorial)
* [Compatibility with OpenAI Gym](/content/gym_compatibility)

View File

@@ -1,122 +0,0 @@
---
layout: "contents"
title: Migration Guide
---
# v21 to v26 Migration Guide
```{eval-rst}
.. py:currentmodule:: gymnasium.wrappers
Gymnasium is a fork of `OpenAI Gym v26 <https://github.com/openai/gym/releases/tag/0.26.2>`_, which introduced a large breaking change from `Gym v21 <https://github.com/openai/gym/releases/tag/v0.21.0>`_.
In this guide, we briefly outline the API changes from Gym v21 - which a number of tutorials have been written for - to Gym v26.
For environments still stuck in the v21 API, users can use the :class:`EnvCompatibility` wrapper to convert them to be v26 compliant.
For more information, see the `guide </content/gym_compatibility>`_
```
### Example code for v21
```python
import gym
env = gym.make("LunarLander-v2", options={})
env.seed(123)
observation = env.reset()
done = False
while not done:
action = env.action_space.sample() # agent policy that uses the observation and info
observation, reward, done, info = env.step(action)
env.render(mode="human")
env.close()
```
### Example code for v26
```python
import gym
env = gym.make("LunarLander-v2", render_mode="human")
observation, info = env.reset(seed=123, options={})
done = False
while not done:
action = env.action_space.sample() # agent policy that uses the observation and info
observation, reward, terminated, truncated, info = env.step(action)
done = terminated or truncated
env.close()
```
## Seed and random number generator
```{eval-rst}
.. py:currentmodule:: gymnasium.Env
The ``Env.seed()`` has been removed from the Gym v26 environments in favour of ``Env.reset(seed=seed)``.
This allows seeding to only be changed on environment reset.
The decision to remove ``seed`` was because some environments use emulators whose random number generators cannot be changed within an episode and must be set at the beginning of a new episode.
We are aware of cases where controlling the random number generator is important, in these cases, if the environment uses the built-in random number generator, users can set the seed manually with the attribute :attr:`np_random`.
Gymnasium v26 changed to using ``numpy.random.Generator`` instead of a custom random number generator.
This means that several functions such as ``randint`` were removed in favour of ``integers``.
While some environments might use an external random number generator, we recommend using the attribute :attr:`np_random` that wrappers and external users can access and utilise.
```
## Environment Reset
```{eval-rst}
In v26, :meth:`reset` takes two optional parameters and returns one value.
This contrasts to v21 which takes no parameters and returns ``None``.
The two parameters are ``seed`` for setting the random number generator and ``options`` which allows additional data to be passed to the environment on reset.
For example, in classic control, the ``options`` parameter now allows users to modify the range of the state bound.
See the original `PR <https://github.com/openai/gym/pull/2921>`_ for more details.
:meth:`reset` further returns ``info``, similar to the ``info`` returned by :meth:`step`.
This is important because ``info`` can include metrics or a valid action mask that is used or saved for the next step.
To update older environments, we highly recommend that ``super().reset(seed=seed)`` is called on the first line of :meth:`reset`.
This will automatically update the :attr:`np_random` with the seed value.
```
## Environment Step
```{eval-rst}
In v21, the type definition of :meth:`step` is ``tuple[ObsType, SupportsFloat, bool, dict[str, Any]]`` representing the next observation, the reward from the step, if the episode is done and additional info from the step.
Due to reproducibility issues that will be expanded on in a blog post soon, we have changed the type definition to ``tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]`` adding an extra boolean value.
This extra bool corresponds to the older `done` now changed to `terminated` and `truncated`.
These changes were introduced in Gym `v26 <https://github.com/openai/gym/releases/tag/0.26.0>`_ (turned off by default in `v25 <https://github.com/openai/gym/releases/tag/0.25.0>`_).
For users wishing to update, in most cases, replacing ``done`` with ``terminated`` and ``truncated=False`` in :meth:`step` should address most issues.
However, environments that have reasons for episode truncation rather than termination should read through the associated `PR <https://github.com/openai/gym/pull/2752>`_.
For users looping through an environment, they should modify ``done = terminated or truncated`` as shown in the example code.
For training libraries, the primary difference is to change ``done`` to ``terminated``, indicating whether bootstrapping should or shouldn't happen.
```
## TimeLimit Wrapper
```{eval-rst}
In v21, the :class:`TimeLimit` wrapper added an extra key in the ``info`` dictionary ``TimeLimit.truncated`` whenever the agent reached the time limit without reaching a terminal state.
In v26, this information is instead communicated through the `truncated` return value described in the previous section, which is `True` if the agent reaches the time limit, whether or not it reaches a terminal state. The old dictionary entry is equivalent to ``truncated and not terminated``.
```
## Environment Render
```{eval-rst}
In v26, a new render API was introduced such that the render mode is fixed at initialisation as some environments don't allow on-the-fly render mode changes. Therefore, users should now specify the :attr:`render_mode` within ``gym.make`` as shown in the v26 example code above.
For a more complete explanation of the changes, please refer to this `summary <https://younis.dev/blog/render-api/>`_.
```
## Removed code
```{eval-rst}
.. py:currentmodule:: gymnasium.wrappers
* GoalEnv - This was removed, users needing it should reimplement the environment or use Gymnasium Robotics which contains an implementation of this environment.
* ``from gym.envs.classic_control import rendering`` - This was removed in favour of users implementing their own rendering systems. Gymnasium environments are coded using pygame.
* Robotics environments - The robotics environments have been moved to the `Gymnasium Robotics <https://robotics.farama.org/>`_ project.
* Monitor wrapper - This wrapper was replaced with two separate wrappers, :class:`RecordVideo` and :class:`RecordEpisodeStatistics`.
```

View File

@@ -17,17 +17,25 @@ An API standard for reinforcement learning with a diverse collection of referenc
:width: 500 :width: 500
``` ```
**Gymnasium is a maintained fork of OpenAIs Gym library.** The Gymnasium interface is simple, pythonic, and capable of representing general RL problems, and has a [compatibility wrapper](content/gym_compatibility) for old Gym environments: **Gymnasium is a maintained fork of OpenAIs Gym library.** The Gymnasium interface is simple, pythonic, and capable of representing general RL problems, and has a [compatibility wrapper](introduction/gym_compatibility) for old Gym environments:
```{code-block} python ```{code-block} python
import gymnasium as gym import gymnasium as gym
# Initialise the environment
env = gym.make("LunarLander-v2", render_mode="human") env = gym.make("LunarLander-v2", render_mode="human")
# Reset the environment to generate the first observation
observation, info = env.reset(seed=42) observation, info = env.reset(seed=42)
for _ in range(1000): for _ in range(1000):
action = env.action_space.sample() # this is where you would insert your policy # this is where you would insert your policy
action = env.action_space.sample()
# step (transition) through the environment with the action
# receiving the next observation, reward and if the episode has terminated or truncated
observation, reward, terminated, truncated, info = env.step(action) observation, reward, terminated, truncated, info = env.step(action)
# If the episode has ended then we can reset to start a new episode
if terminated or truncated: if terminated or truncated:
observation, info = env.reset() observation, info = env.reset()
@@ -38,9 +46,9 @@ env.close()
:hidden: :hidden:
:caption: Introduction :caption: Introduction
content/basic_usage introduction/basic_usage
content/gym_compatibility introduction/gym_compatibility
content/migration-guide introduction/migration-guide
``` ```
```{toctree} ```{toctree}
@@ -53,7 +61,7 @@ api/spaces
api/wrappers api/wrappers
api/vector api/vector
api/utils api/utils
api/experimental api/functional
``` ```
```{toctree} ```{toctree}

View File

@@ -0,0 +1,172 @@
---
layout: "contents"
title: Basic Usage
firstpage:
---
# Basic Usage
```{eval-rst}
.. py:currentmodule:: gymnasium
Gymnasium is a project that provides an API for all single agent reinforcement learning environments, and includes implementations of common environments: cartpole, pendulum, mountain-car, mujoco, atari, and more.
The API contains four key functions: :meth:`make`, :meth:`Env.reset`, :meth:`Env.step` and :meth:`Env.render`, which this basic usage will introduce you to. At the core of Gymnasium is :class:`Env`, a high-level Python class representing a Markov decision process (MDP) from reinforcement learning theory (this is not a perfect reconstruction, and is missing several components of MDPs). Within gymnasium, environments (MDPs) are implemented as :class:`Env` classes, along with :class:`Wrapper` classes, which provide helpful utilities to change actions passed to the environment and to modify the observations, rewards, termination or truncation conditions passed back to the user.
```
## Initializing Environments
```{eval-rst}
.. py:currentmodule:: gymnasium
Initializing environments is very easy in Gymnasium and can be done via the :meth:`make` function:
```
```python
import gymnasium as gym
env = gym.make('CartPole-v1')
```
```{eval-rst}
.. py:currentmodule:: gymnasium
This will return an :class:`Env` for users to interact with. To see all environments you can create, use :meth:`pprint_registry`. Furthermore, :meth:`make` provides a number of additional arguments for specifying keywords to the environment, adding more or fewer wrappers, etc.
```
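For example, a short sketch of these extra arguments (``render_mode`` and ``max_episode_steps`` are standard ``make`` keywords; any other keyword is forwarded to the environment constructor):

```python
import gymnasium as gym

# `max_episode_steps` configures the TimeLimit wrapper that `make` applies,
# while `render_mode` is passed through to the environment itself.
env = gym.make("CartPole-v1", render_mode="rgb_array", max_episode_steps=200)

gym.pprint_registry()  # prints every registered environment id
```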
## Interacting with the Environment
The classic "agent-environment loop" pictured below is a simplified representation of reinforcement learning that Gymnasium implements.
```{image} /_static/diagrams/AE_loop.png
:width: 50%
:align: center
:class: only-light
```
```{image} /_static/diagrams/AE_loop_dark.png
:width: 50%
:align: center
:class: only-dark
```
This loop is implemented using the following gymnasium code
```python
import gymnasium as gym
env = gym.make("LunarLander-v2", render_mode="human")
observation, info = env.reset()
for _ in range(1000):
action = env.action_space.sample() # agent policy that uses the observation and info
observation, reward, terminated, truncated, info = env.step(action)
if terminated or truncated:
observation, info = env.reset()
env.close()
```
The output should look something like this:
```{figure} https://user-images.githubusercontent.com/15806078/153222406-af5ce6f0-4696-4a24-a683-46ad4939170c.gif
:width: 50%
:align: center
```
### Explaining the code
```{eval-rst}
.. py:currentmodule:: gymnasium
First, an environment is created using :meth:`make` with an additional keyword ``"render_mode"`` that specifies how the environment should be visualised.
.. py:currentmodule:: gymnasium.Env
See :meth:`render` for details on the default meaning of different render modes. In this example, we use the ``"LunarLander"`` environment where the agent controls a spaceship that needs to land safely.
After initializing the environment, we :meth:`reset` the environment to get the first observation of the environment. For initializing the environment with a particular random seed or options (see environment documentation for possible values) use the ``seed`` or ``options`` parameters with :meth:`reset`.
Next, the agent performs an action in the environment via :meth:`step`; this can be imagined as moving a robot or pressing a button on a game's controller that causes a change within the environment. As a result, the agent receives a new observation from the updated environment along with a reward for taking the action. This reward could, for instance, be positive for destroying an enemy or negative for moving into lava. One such action-observation exchange is referred to as a **timestep**.
However, after some timesteps, the environment may end; this is called the terminal state. For instance, the robot may have crashed, or the agent may have succeeded in completing a task; the environment then needs to stop as the agent cannot continue. In gymnasium, if the environment has terminated, this is returned by :meth:`step`. Similarly, we may also want the environment to end after a fixed number of timesteps; in this case, the environment issues a truncated signal. If either of ``terminated`` or ``truncated`` are ``True`` then :meth:`reset` should be called next to restart the environment.
```
## Action and observation spaces
```{eval-rst}
.. py:currentmodule:: gymnasium.Env
Every environment specifies the format of valid actions and observations with the :attr:`action_space` and :attr:`observation_space` attributes. This is helpful for both knowing the expected input and output of the environment, as all valid actions and observations should be contained within the respective space.
In the example, we sampled random actions via ``env.action_space.sample()`` instead of using an agent policy that maps observations to actions, which is what users will ultimately want. See one of the agent tutorials for an example of creating and training an agent policy.
.. py:currentmodule:: gymnasium
Every environment should have the attributes :attr:`Env.action_space` and :attr:`Env.observation_space`, both of which should be instances of classes that inherit from :class:`spaces.Space`. Gymnasium has support for a majority of possible spaces users might need:
.. py:currentmodule:: gymnasium.spaces
- :class:`Box`: describes an n-dimensional continuous space. It's a bounded space where we can define the upper and lower
limits which describe the valid values our observations can take.
- :class:`Discrete`: describes a discrete space where ``{0, 1, ..., n-1}`` are the possible values our observation or action can take.
Values can be shifted to ``{a, a+1, ..., a+n-1}`` using an optional argument.
- :class:`Dict`: represents a dictionary of simple spaces.
- :class:`Tuple`: represents a tuple of simple spaces.
- :class:`MultiBinary`: creates an n-shape binary space. Argument n can be a number or a list of numbers.
- :class:`MultiDiscrete`: consists of a series of :class:`Discrete` action spaces with a different number of actions in each element.
For example usage of spaces, see their `documentation </api/spaces>`_ along with `utility functions </api/spaces/utils>`_; a short sketch constructing a few spaces follows below. There are a couple of more niche spaces: :class:`Graph`, :class:`Sequence` and :class:`Text`.
```
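A minimal sketch constructing and sampling a few of these spaces (the variable names are illustrative):

```python
import numpy as np
from gymnasium.spaces import Box, Dict, Discrete

observation_space = Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)
action_space = Discrete(4)  # valid actions are 0, 1, 2 and 3
nested_space = Dict({"position": Box(-1.0, 1.0, (2,)), "mode": Discrete(2)})

sample = observation_space.sample()        # a random valid observation
assert observation_space.contains(sample)  # every sample lies within the space
```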
## Modifying the environment
```{eval-rst}
.. py:currentmodule:: gymnasium.wrappers
Wrappers are a convenient way to modify an existing environment without having to alter the underlying code directly. Using wrappers will allow you to avoid a lot of boilerplate code and make your environment more modular. Wrappers can also be chained to combine their effects. Most environments that are generated via ``gymnasium.make`` will already be wrapped by default using the :class:`TimeLimit`, :class:`OrderEnforcing` and :class:`PassiveEnvChecker` wrappers.
In order to wrap an environment, you must first initialize a base environment. Then you can pass this environment along with (possibly optional) parameters to the wrapper's constructor:
```
```python
>>> import gymnasium as gym
>>> from gymnasium.wrappers import FlattenObservation
>>> env = gym.make("CarRacing-v2")
>>> env.observation_space.shape
(96, 96, 3)
>>> wrapped_env = FlattenObservation(env)
>>> wrapped_env.observation_space.shape
(27648,)
```
```{eval-rst}
.. py:currentmodule:: gymnasium.wrappers
Gymnasium already provides many commonly used wrappers for you. Some examples:
- :class:`TimeLimit`: Issue a truncated signal if a maximum number of timesteps has been exceeded (or the base environment has issued a truncated signal).
- :class:`ClipAction`: Clip the action such that it lies in the action space (of type ``Box``).
- :class:`RescaleAction`: Rescale actions to lie in a specified interval.
- :class:`TimeAwareObservation`: Add information about the index of the timestep to the observation. In some cases helpful to ensure that transitions are Markov.
```
For a full list of implemented wrappers in gymnasium, see [wrappers](/api/wrappers).
```{eval-rst}
.. py:currentmodule:: gymnasium.Env
If you have a wrapped environment, and you want to get the unwrapped environment underneath all the layers of wrappers (so that you can manually call a function or change some underlying aspect of the environment), you can use the :attr:`unwrapped` attribute. If the environment is already a base environment, the :attr:`unwrapped` attribute will just return itself.
```
```python
>>> wrapped_env
<FlattenObservation<TimeLimit<OrderEnforcing<PassiveEnvChecker<CarRacing<CarRacing-v2>>>>>>
>>> wrapped_env.unwrapped
<gymnasium.envs.box2d.car_racing.CarRacing object at 0x7f04efcb8850>
```
## More information
* [Making a Custom environment using the Gymnasium API](/tutorials/gymnasium_basics/environment_creation/)
* [Training an agent to play blackjack](/tutorials/training_agents/blackjack_tutorial)
* [Compatibility with OpenAI Gym](/introduction/gym_compatibility)

View File

@@ -12,9 +12,7 @@ Gymnasium provides a number of compatibility methods for a range of Environment
```{eval-rst} ```{eval-rst}
.. py:currentmodule:: gymnasium.wrappers .. py:currentmodule:: gymnasium.wrappers
For environments that are registered solely in OpenAI Gym and not in Gymnasium, Gymnasium v0.26.3 and above allows importing them through either a special environment or a wrapper. For environments that are registered solely in OpenAI Gym and not in Gymnasium, Gymnasium v0.26.3 and above allows importing them through either a special environment or a wrapper. The ``"GymV26Environment-v0"`` environment was introduced in Gymnasium v0.26.3, and allows importing of Gym environments through the ``env_name`` argument along with other relevant kwargs environment kwargs. To perform conversion through a wrapper, the environment itself can be passed to the wrapper :class:`EnvCompatibility` through the ``env`` kwarg.
The ``"GymV26Environment-v0"`` environment was introduced in Gymnasium v0.26.3, and allows importing of Gym environments through the ``env_name`` argument along with other relevant kwargs environment kwargs.
To perform conversion through a wrapper, the environment itself can be passed to the wrapper :class:`EnvCompatibility` through the ``env`` kwarg.
``` ```
An example of this is atari 0.8.0 which does not have a gymnasium implementation. An example of this is atari 0.8.0 which does not have a gymnasium implementation.
@@ -29,9 +27,7 @@ env = gym.make("GymV26Environment-v0", env_id="ALE/Pong-v5")
```{eval-rst} ```{eval-rst}
.. py:currentmodule:: gymnasium .. py:currentmodule:: gymnasium
A number of environments have not updated to the recent Gym changes, in particular since v0.21. A number of environments have not updated to the recent Gym changes, in particular since v0.21. This update is significant for the introduction of ``termination`` and ``truncation`` signatures in favour of the previously used ``done``. To allow backward compatibility, Gym and Gymnasium v0.26+ include an ``apply_api_compatibility`` kwarg when calling :meth:`make` that automatically converts a v0.21 API compliant environment to one that is compatible with v0.26+.
This update is significant for the introduction of ``termination`` and ``truncation`` signatures in favour of the previously used ``done``.
To allow backward compatibility, Gym and Gymnasium v0.26+ include an ``apply_api_compatibility`` kwarg when calling :meth:`make` that automatically converts a v0.21 API compliant environment to one that is compatible with v0.26+.
``` ```
```python ```python

View File

@@ -0,0 +1,103 @@
---
layout: "contents"
title: Migration Guide
---
# v0.21 to v0.26 Migration Guide
```{eval-rst}
.. py:currentmodule:: gymnasium.wrappers
Gymnasium is a fork of `OpenAI Gym v0.26 <https://github.com/openai/gym/releases/tag/0.26.2>`_, which introduced a large breaking change from `Gym v0.21 <https://github.com/openai/gym/releases/tag/v0.21.0>`_. In this guide, we briefly outline the API changes from Gym v0.21 - which a number of tutorials have been written for - to Gym v0.26. For environments still stuck in the v0.21 API, users can use the :class:`EnvCompatibility` wrapper to convert them to be v0.26 compliant.
For more information, see the `guide </introduction/gym_compatibility>`_.
```
## Example code for v0.21
```python
import gym
env = gym.make("LunarLander-v2", options={})
env.seed(123)
observation = env.reset()
done = False
while not done:
action = env.action_space.sample() # agent policy that uses the observation and info
observation, reward, done, info = env.step(action)
env.render(mode="human")
env.close()
```
## Example code for v0.26
```python
import gym
env = gym.make("LunarLander-v2", render_mode="human")
observation, info = env.reset(seed=123, options={})
done = False
while not done:
action = env.action_space.sample() # agent policy that uses the observation and info
observation, reward, terminated, truncated, info = env.step(action)
done = terminated or truncated
env.close()
```
## Seed and random number generator
```{eval-rst}
.. py:currentmodule:: gymnasium.Env
The ``Env.seed()`` has been removed from the Gym v0.26 environments in favour of ``Env.reset(seed=seed)``. This allows seeding to only be changed on environment reset. The decision to remove ``seed`` was because some environments use emulators whose random number generators cannot be changed within an episode and must be set at the beginning of a new episode. We are aware of cases where controlling the random number generator is important; in these cases, if the environment uses the built-in random number generator, users can set the seed manually with the attribute :attr:`np_random`.
Gymnasium v0.26 changed to using ``numpy.random.Generator`` instead of a custom random number generator. This means that several functions such as ``randint`` were removed in favour of ``integers``. While some environments might use an external random number generator, we recommend using the attribute :attr:`np_random` that wrappers and external users can access and utilise.
```
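A short sketch of the new seeding pattern (``LunarLander-v2`` is used as in the examples above; any built-in environment behaves the same way):

```python
import gymnasium as gym

env = gym.make("LunarLander-v2")
observation, info = env.reset(seed=123)  # the generator is seeded once, on reset

# v0.26 uses numpy.random.Generator, so e.g. `integers` replaces `randint`
random_int = env.unwrapped.np_random.integers(0, 10)
```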
## Environment Reset
```{eval-rst}
In v0.26, :meth:`reset` takes two optional parameters and returns one value. This contrasts to v0.21 which takes no parameters and returns ``None``. The two parameters are ``seed`` for setting the random number generator and ``options`` which allows additional data to be passed to the environment on reset. For example, in classic control, the ``options`` parameter now allows users to modify the range of the state bound. See the original `PR <https://github.com/openai/gym/pull/2921>`_ for more details.
:meth:`reset` further returns ``info``, similar to the ``info`` returned by :meth:`step`. This is important because ``info`` can include metrics or a valid action mask that is used or saved for the next step.
To update older environments, we highly recommend that ``super().reset(seed=seed)`` is called on the first line of :meth:`reset`. This will automatically update the :attr:`np_random` with the seed value.
```
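A minimal sketch of an updated :meth:`reset` for a hypothetical custom environment (the class and its spaces are illustrative only):

```python
import gymnasium as gym
from gymnasium.spaces import Discrete


class MyEnv(gym.Env):
    """Hypothetical environment showing the v0.26 reset signature."""

    def __init__(self):
        self.observation_space = Discrete(2)
        self.action_space = Discrete(2)

    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed)  # seeds self.np_random when a seed is given
        observation = self.observation_space.sample()
        return observation, {}  # reset now also returns an `info` dict
```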
## Environment Step
```{eval-rst}
In v0.21, the type definition of :meth:`step` is ``tuple[ObsType, SupportsFloat, bool, dict[str, Any]]`` representing the next observation, the reward from the step, if the episode is done and additional info from the step. Due to reproducibility issues that will be expanded on in a blog post soon, we have changed the type definition to ``tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]`` adding an extra boolean value. This extra bool corresponds to the older ``done``, now changed to ``terminated`` and ``truncated``. These changes were introduced in Gym `v0.26 <https://github.com/openai/gym/releases/tag/0.26.0>`_ (turned off by default in `v0.25 <https://github.com/openai/gym/releases/tag/0.25.0>`_).
For users wishing to update, in most cases, replacing ``done`` with ``terminated`` and ``truncated=False`` in :meth:`step` should address most issues. However, environments that have reasons for episode truncation rather than termination should read through the associated `PR <https://github.com/openai/gym/pull/2752>`_. For users looping through an environment, they should modify ``done = terminated or truncated`` as shown in the example code. For training libraries, the primary difference is to change ``done`` to ``terminated``, indicating whether bootstrapping should or shouldn't happen.
```
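A minimal sketch of an updated loop (``CartPole-v1`` is purely illustrative):

```python
import gymnasium as gym

env = gym.make("CartPole-v1")
observation, info = env.reset(seed=0)

done = False
while not done:
    action = env.action_space.sample()
    observation, reward, terminated, truncated, info = env.step(action)
    # For value bootstrapping, only `terminated` marks a true terminal state;
    # `truncated` means the episode was cut short (e.g. by a time limit).
    done = terminated or truncated
env.close()
```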
## TimeLimit Wrapper
```{eval-rst}
In v0.21, the :class:`TimeLimit` wrapper added an extra key in the ``info`` dictionary ``TimeLimit.truncated`` whenever the agent reached the time limit without reaching a terminal state.
In v0.26, this information is instead communicated through the ``truncated`` return value described in the previous section, which is ``True`` if the agent reaches the time limit, whether or not it reaches a terminal state. The old dictionary entry is equivalent to ``truncated and not terminated``.
```
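For instance, a sketch recovering the old flag, using a deliberately short ``max_episode_steps`` to force truncation:

```python
import gymnasium as gym

env = gym.make("CartPole-v1", max_episode_steps=10)  # short limit to force truncation
observation, info = env.reset(seed=0)
terminated = truncated = False
while not (terminated or truncated):
    observation, reward, terminated, truncated, info = env.step(env.action_space.sample())

# Equivalent of the old `info["TimeLimit.truncated"]` entry:
old_style_truncated = truncated and not terminated
```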
## Environment Render
```{eval-rst}
In v0.26, a new render API was introduced such that the render mode is fixed at initialisation as some environments don't allow on-the-fly render mode changes. Therefore, users should now specify the :attr:`render_mode` within ``gym.make`` as shown in the v0.26 example code above.
For a more complete explanation of the changes, please refer to this `summary <https://younis.dev/blog/render-api/>`_.
```
## Removed code
```{eval-rst}
.. py:currentmodule:: gymnasium.wrappers
* GoalEnv - This was removed, users needing it should reimplement the environment or use Gymnasium Robotics which contains an implementation of this environment.
* ``from gym.envs.classic_control import rendering`` - This was removed in favour of users implementing their own rendering systems. Gymnasium environments are coded using pygame.
* Robotics environments - The robotics environments have been moved to the `Gymnasium Robotics <https://robotics.farama.org/>`_ project.
* Monitor wrapper - This wrapper was replaced with two separate wrappers, :class:`RecordVideo` and :class:`RecordEpisodeStatistics`.
```
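A sketch of the Monitor replacement (the ``videos`` folder name is illustrative, and video recording may additionally require the ``moviepy`` package):

```python
import gymnasium as gym
from gymnasium.wrappers import RecordEpisodeStatistics, RecordVideo

env = gym.make("CartPole-v1", render_mode="rgb_array")
env = RecordEpisodeStatistics(env)             # episodic return/length appear in `info`
env = RecordVideo(env, video_folder="videos")  # saves videos of rendered episodes
```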

View File

@@ -21,8 +21,7 @@ from gymnasium.envs.registration import (
# necessary for `envs.__init__` which registers all gymnasium environments and loads plugins # necessary for `envs.__init__` which registers all gymnasium environments and loads plugins
from gymnasium import envs from gymnasium import envs
from gymnasium import spaces, utils, vector, wrappers, error, logger from gymnasium import spaces, utils, vector, wrappers, error, logger, functional
from gymnasium import experimental
__all__ = [ __all__ = [
@@ -43,15 +42,15 @@ __all__ = [
"register_envs", "register_envs",
# module folders # module folders
"envs", "envs",
"experimental",
"spaces", "spaces",
"utils", "utils",
"vector", "vector",
"wrappers", "wrappers",
"error", "error",
"logger", "logger",
"functional",
] ]
__version__ = "0.29.0" __version__ = "1.0.0rc1"
# Initializing pygame initializes audio connections through SDL. SDL uses alsa by default on all Linux systems # Initializing pygame initializes audio connections through SDL. SDL uses alsa by default on all Linux systems

View File

@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any, Generic, SupportsFloat, TypeVar
import numpy as np import numpy as np
from gymnasium import logger, spaces from gymnasium import spaces
from gymnasium.utils import RecordConstructorArgs, seeding from gymnasium.utils import RecordConstructorArgs, seeding
@@ -37,12 +37,10 @@ class Env(Generic[ObsType, ActType]):
- :attr:`action_space` - The Space object corresponding to valid actions, all valid actions should be contained within the space. - :attr:`action_space` - The Space object corresponding to valid actions, all valid actions should be contained within the space.
- :attr:`observation_space` - The Space object corresponding to valid observations, all valid observations should be contained within the space. - :attr:`observation_space` - The Space object corresponding to valid observations, all valid observations should be contained within the space.
- :attr:`reward_range` - A tuple corresponding to the minimum and maximum possible rewards for an agent over an episode.
The default reward range is set to :math:`(-\infty,+\infty)`.
- :attr:`spec` - An environment spec that contains the information used to initialize the environment from :meth:`gymnasium.make` - :attr:`spec` - An environment spec that contains the information used to initialize the environment from :meth:`gymnasium.make`
- :attr:`metadata` - The metadata of the environment, i.e. render modes, render fps - :attr:`metadata` - The metadata of the environment, i.e. render modes, render fps
- :attr:`np_random` - The random number generator for the environment. This is automatically assigned during - :attr:`np_random` - The random number generator for the environment. This is automatically assigned during
``super().reset(seed=seed)`` and when assessing ``self.np_random``. ``super().reset(seed=seed)`` and when assessing :attr:`np_random`.
.. seealso:: For modifying or extending environments use the :py:class:`gymnasium.Wrapper` class .. seealso:: For modifying or extending environments use the :py:class:`gymnasium.Wrapper` class
@@ -54,7 +52,6 @@ class Env(Generic[ObsType, ActType]):
metadata: dict[str, Any] = {"render_modes": []} metadata: dict[str, Any] = {"render_modes": []}
# define render_mode if your environment supports rendering # define render_mode if your environment supports rendering
render_mode: str | None = None render_mode: str | None = None
reward_range = (-float("inf"), float("inf"))
spec: EnvSpec | None = None spec: EnvSpec | None = None
# Set these in ALL subclasses # Set these in ALL subclasses
@@ -238,6 +235,10 @@ class Env(Generic[ObsType, ActType]):
"""Gets the attribute `name` from the environment.""" """Gets the attribute `name` from the environment."""
return getattr(self, name) return getattr(self, name)
def set_wrapper_attr(self, name: str, value: Any):
"""Sets the attribute `name` on the environment with `value`."""
setattr(self, name, value)
WrapperObsType = TypeVar("WrapperObsType") WrapperObsType = TypeVar("WrapperObsType")
WrapperActType = TypeVar("WrapperActType") WrapperActType = TypeVar("WrapperActType")
@@ -268,56 +269,41 @@ class Wrapper(
env: The environment to wrap env: The environment to wrap
""" """
self.env = env self.env = env
assert isinstance(env, Env)
self._action_space: spaces.Space[WrapperActType] | None = None self._action_space: spaces.Space[WrapperActType] | None = None
self._observation_space: spaces.Space[WrapperObsType] | None = None self._observation_space: spaces.Space[WrapperObsType] | None = None
self._reward_range: tuple[SupportsFloat, SupportsFloat] | None = None
self._metadata: dict[str, Any] | None = None self._metadata: dict[str, Any] | None = None
self._cached_spec: EnvSpec | None = None self._cached_spec: EnvSpec | None = None
def __getattr__(self, name: str) -> Any: def step(
"""Returns an attribute with ``name``, unless ``name`` starts with an underscore. self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Uses the :meth:`step` of the :attr:`env` that can be overwritten to change the returned data."""
return self.env.step(action)
Args: def reset(
name: The variable name self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
"""Uses the :meth:`reset` of the :attr:`env` that can be overwritten to change the returned data."""
return self.env.reset(seed=seed, options=options)
Returns: def render(self) -> RenderFrame | list[RenderFrame] | None:
The value of the variable in the wrapper stack """Uses the :meth:`render` of the :attr:`env` that can be overwritten to change the returned data."""
return self.env.render()
Warnings: def close(self):
This feature is deprecated and removed in v1.0 and replaced with `env.get_attr(name)` """Closes the wrapper and :attr:`env`."""
return self.env.close()
@property
def unwrapped(self) -> Env[ObsType, ActType]:
"""Returns the base environment of the wrapper.
This will be the bare :class:`gymnasium.Env` environment, underneath all layers of wrappers.
""" """
if name == "_np_random": return self.env.unwrapped
raise AttributeError(
"Can't access `_np_random` of a wrapper, use `self.unwrapped._np_random` or `self.np_random`."
)
elif name.startswith("_"):
raise AttributeError(f"accessing private attribute '{name}' is prohibited")
logger.warn(
f"env.{name} to get variables from other wrappers is deprecated and will be removed in v1.0, "
f"to get this variable you can do `env.unwrapped.{name}` for environment variables or `env.get_attr('{name}')` that will search the remaining wrappers."
)
return getattr(self.env, name)
def get_wrapper_attr(self, name: str) -> Any:
"""Gets an attribute from the wrapper and lower environments if `name` doesn't exist in this object.
Args:
name: The variable name to get
Returns:
The variable with name in wrapper or lower environments
"""
if hasattr(self, name):
return getattr(self, name)
else:
try:
return self.env.get_wrapper_attr(name)
except AttributeError as e:
raise AttributeError(
f"wrapper {self.class_name()} has no attribute {name!r}"
) from e
@property @property
def spec(self) -> EnvSpec | None: def spec(self) -> EnvSpec | None:
@@ -362,6 +348,53 @@ class Wrapper(
kwargs=kwargs, kwargs=kwargs,
) )
def get_wrapper_attr(self, name: str) -> Any:
"""Gets an attribute from the wrapper and lower environments if `name` doesn't exist in this object.
Args:
name: The variable name to get
Returns:
The variable with name in wrapper or lower environments
"""
if hasattr(self, name):
return getattr(self, name)
else:
try:
return self.env.get_wrapper_attr(name)
except AttributeError as e:
raise AttributeError(
f"wrapper {self.class_name()} has no attribute {name!r}"
) from e
def set_wrapper_attr(self, name: str, value: Any):
"""Sets an attribute on this wrapper or lower environment if `name` is already defined.
Args:
name: The variable name
value: The new variable value
"""
sub_env = self.env
attr_set = False
while attr_set is False and isinstance(sub_env, Wrapper):
if hasattr(sub_env, name):
setattr(sub_env, name, value)
attr_set = True
else:
sub_env = sub_env.env
if attr_set is False:
setattr(sub_env, name, value)
def __str__(self):
"""Returns the wrapper name and the :attr:`env` representation string."""
return f"<{type(self).__name__}{self.env}>"
def __repr__(self):
"""Returns the string representation of the wrapper."""
return str(self)
@classmethod @classmethod
def class_name(cls) -> str: def class_name(cls) -> str:
"""Returns the class name of the wrapper.""" """Returns the class name of the wrapper."""
@@ -393,18 +426,6 @@ class Wrapper(
def observation_space(self, space: spaces.Space[WrapperObsType]): def observation_space(self, space: spaces.Space[WrapperObsType]):
self._observation_space = space self._observation_space = space
@property
def reward_range(self) -> tuple[SupportsFloat, SupportsFloat]:
"""Return the :attr:`Env` :attr:`reward_range` unless overwritten then the wrapper :attr:`reward_range` is used."""
if self._reward_range is None:
return self.env.reward_range
logger.warn("The `reward_range` is deprecated and will be removed in v1.0")
return self._reward_range
@reward_range.setter
def reward_range(self, value: tuple[SupportsFloat, SupportsFloat]):
self._reward_range = value
@property @property
def metadata(self) -> dict[str, Any]: def metadata(self) -> dict[str, Any]:
"""Returns the :attr:`Env` :attr:`metadata`.""" """Returns the :attr:`Env` :attr:`metadata`."""
@@ -440,54 +461,15 @@ class Wrapper(
"Can't access `_np_random` of a wrapper, use `.unwrapped._np_random` or `.np_random`." "Can't access `_np_random` of a wrapper, use `.unwrapped._np_random` or `.np_random`."
) )
def step(
self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Uses the :meth:`step` of the :attr:`env` that can be overwritten to change the returned data."""
return self.env.step(action)
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
"""Uses the :meth:`reset` of the :attr:`env` that can be overwritten to change the returned data."""
return self.env.reset(seed=seed, options=options)
def render(self) -> RenderFrame | list[RenderFrame] | None:
"""Uses the :meth:`render` of the :attr:`env` that can be overwritten to change the returned data."""
return self.env.render()
def close(self):
"""Closes the wrapper and :attr:`env`."""
return self.env.close()
def __str__(self):
"""Returns the wrapper name and the :attr:`env` representation string."""
return f"<{type(self).__name__}{self.env}>"
def __repr__(self):
"""Returns the string representation of the wrapper."""
return str(self)
@property
def unwrapped(self) -> Env[ObsType, ActType]:
"""Returns the base environment of the wrapper.
This will be the bare :class:`gymnasium.Env` environment, underneath all layers of wrappers.
"""
return self.env.unwrapped
class ObservationWrapper(Wrapper[WrapperObsType, ActType, ObsType, ActType]): class ObservationWrapper(Wrapper[WrapperObsType, ActType, ObsType, ActType]):
"""Superclass of wrappers that can modify observations using :meth:`observation` for :meth:`reset` and :meth:`step`. """Modify observations from :meth:`Env.reset` and :meth:`Env.step` using :meth:`observation` function.
If you would like to apply a function to only the observation before If you would like to apply a function to only the observation before
passing it to the learning code, you can simply inherit from :class:`ObservationWrapper` and overwrite the method passing it to the learning code, you can simply inherit from :class:`ObservationWrapper` and overwrite the method
:meth:`observation` to implement that transformation. The transformation defined in that method must be :meth:`observation` to implement that transformation. The transformation defined in that method must be
reflected by the :attr:`env` observation space. Otherwise, you need to specify the new observation space of the reflected by the :attr:`env` observation space. Otherwise, you need to specify the new observation space of the
wrapper by setting :attr:`self.observation_space` in the :meth:`__init__` method of your wrapper. wrapper by setting :attr:`self.observation_space` in the :meth:`__init__` method of your wrapper.
Among others, Gymnasium provides the observation wrapper :class:`TimeAwareObservation`, which adds information about the
index of the timestep to the observation.
""" """
def __init__(self, env: Env[ObsType, ActType]): def __init__(self, env: Env[ObsType, ActType]):

View File

@@ -1,14 +1,7 @@
"""Registers the internal gym envs then loads the env plugins for module using the entry point.""" """Registers the internal gym envs then loads the env plugins for module using the entry point."""
from typing import Any from typing import Any
from gymnasium.envs.registration import ( from gymnasium.envs.registration import make, pprint_registry, register, registry, spec
load_plugin_envs,
make,
pprint_registry,
register,
registry,
spec,
)
# Classic # Classic
@@ -459,7 +452,3 @@ def _raise_shimmy_error(*args: Any, **kwargs: Any):
# When installed, shimmy will re-register these environments with the correct entry_point # When installed, shimmy will re-register these environments with the correct entry_point
register(id="GymV21Environment-v0", entry_point=_raise_shimmy_error) register(id="GymV21Environment-v0", entry_point=_raise_shimmy_error)
register(id="GymV26Environment-v0", entry_point=_raise_shimmy_error) register(id="GymV26Environment-v0", entry_point=_raise_shimmy_error)
# Hook to load plugins from entry points
load_plugin_envs()

View File

@@ -840,7 +840,7 @@ def heuristic(env, s):
-(s[3]) * 0.5 -(s[3]) * 0.5
) # override to reduce fall speed, that's all we need after contact ) # override to reduce fall speed, that's all we need after contact
if env.continuous: if env.unwrapped.continuous:
a = np.array([hover_todo * 20 - 1, -angle_todo * 20]) a = np.array([hover_todo * 20 - 1, -angle_todo * 20])
a = np.clip(a, -1, +1) a = np.clip(a, -1, +1)
else: else:

View File

@@ -12,7 +12,7 @@ import gymnasium as gym
from gymnasium import logger, spaces from gymnasium import logger, spaces
from gymnasium.envs.classic_control import utils from gymnasium.envs.classic_control import utils
from gymnasium.error import DependencyNotInstalled from gymnasium.error import DependencyNotInstalled
from gymnasium.experimental.vector import VectorEnv from gymnasium.vector import VectorEnv
from gymnasium.vector.utils import batch_space from gymnasium.vector.utils import batch_space
@@ -74,13 +74,29 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
## Arguments ## Arguments
```python Cartpole only has ``render_mode`` as a keyword for ``gymnasium.make``.
import gymnasium as gym On reset, the `options` parameter allows the user to change the bounds used to determine the new random state.
gym.make('CartPole-v1')
```
On reset, the `options` parameter allows the user to change the bounds used to determine Examples:
the new random state. >>> import gymnasium as gym
>>> env = gym.make("CartPole-v1", render_mode="rgb_array")
>>> env
<TimeLimit<OrderEnforcing<PassiveEnvChecker<CartPoleEnv<CartPole-v1>>>>>
>>> env.reset(seed=123, options={"low": 0, "high": 1})
(array([0.6823519 , 0.05382102, 0.22035988, 0.18437181], dtype=float32), {})
## Vectorized environment
To increase steps per second, users can use a custom vector environment or an environment vectorizer.
Examples:
>>> import gymnasium as gym
>>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="vector_entry_point")
>>> envs
CartPoleVectorEnv(CartPole-v1, num_envs=3)
>>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
>>> envs
SyncVectorEnv(CartPole-v1, num_envs=3)
""" """
metadata = { metadata = {
@@ -328,8 +344,10 @@ class CartPoleVectorEnv(VectorEnv):
max_episode_steps: int = 500, max_episode_steps: int = 500,
render_mode: Optional[str] = None, render_mode: Optional[str] = None,
): ):
super().__init__()
self.num_envs = num_envs self.num_envs = num_envs
self.max_episode_steps = max_episode_steps
self.render_mode = render_mode
self.gravity = 9.8 self.gravity = 9.8
self.masscart = 1.0 self.masscart = 1.0
self.masspole = 0.1 self.masspole = 0.1
@@ -339,7 +357,6 @@ class CartPoleVectorEnv(VectorEnv):
self.force_mag = 10.0 self.force_mag = 10.0
self.tau = 0.02 # seconds between state updates self.tau = 0.02 # seconds between state updates
self.kinematics_integrator = "euler" self.kinematics_integrator = "euler"
self.max_episode_steps = max_episode_steps
self.steps = np.zeros(num_envs, dtype=np.int32) self.steps = np.zeros(num_envs, dtype=np.int32)
@@ -367,8 +384,6 @@ class CartPoleVectorEnv(VectorEnv):
self.single_observation_space = spaces.Box(-high, high, dtype=np.float32) self.single_observation_space = spaces.Box(-high, high, dtype=np.float32)
self.observation_space = batch_space(self.single_observation_space, num_envs) self.observation_space = batch_space(self.single_observation_space, num_envs)
self.render_mode = render_mode
self.screen_width = 600 self.screen_width = 600
self.screen_height = 400 self.screen_height = 400
self.screens = None self.screens = None
@@ -464,6 +479,7 @@ class CartPoleVectorEnv(VectorEnv):
def render(self): def render(self):
if self.render_mode is None: if self.render_mode is None:
assert self.spec is not None
gym.logger.warn( gym.logger.warn(
"You are calling render method without specifying any render mode. " "You are calling render method without specifying any render mode. "
"You can specify the render_mode at initialization, " "You can specify the render_mode at initialization, "

View File

@@ -10,10 +10,10 @@ import numpy as np
import gymnasium as gym import gymnasium as gym
from gymnasium.envs.registration import EnvSpec from gymnasium.envs.registration import EnvSpec
from gymnasium.experimental.functional import ActType, FuncEnv, StateType from gymnasium.functional import ActType, FuncEnv, StateType
from gymnasium.experimental.wrappers.jax_to_numpy import jax_to_numpy
from gymnasium.utils import seeding from gymnasium.utils import seeding
from gymnasium.vector.utils import batch_space from gymnasium.vector.utils import batch_space
from gymnasium.wrappers.jax_to_numpy import jax_to_numpy
class FunctionalJaxEnv(gym.Env): class FunctionalJaxEnv(gym.Env):
@@ -89,7 +89,7 @@ class FunctionalJaxEnv(gym.Env):
observation = self.func_env.observation(next_state) observation = self.func_env.observation(next_state)
reward = self.func_env.reward(self.state, action, next_state) reward = self.func_env.reward(self.state, action, next_state)
terminated = self.func_env.terminal(next_state) terminated = self.func_env.terminal(next_state)
info = self.func_env.step_info(self.state, action, next_state) info = self.func_env.transition_info(self.state, action, next_state)
self.state = next_state self.state = next_state
observation = jax_to_numpy(observation) observation = jax_to_numpy(observation)
@@ -113,7 +113,7 @@ class FunctionalJaxEnv(gym.Env):
self.render_state = None self.render_state = None
class FunctionalJaxVectorEnv(gym.experimental.vector.VectorEnv): class FunctionalJaxVectorEnv(gym.vector.VectorEnv):
"""A vector env implementation for functional Jax envs.""" """A vector env implementation for functional Jax envs."""
state: StateType state: StateType
@@ -211,7 +211,7 @@ class FunctionalJaxVectorEnv(gym.experimental.vector.VectorEnv):
else jnp.zeros_like(terminated) else jnp.zeros_like(terminated)
) )
info = self.func_env.step_info(self.state, action, next_state) info = self.func_env.transition_info(self.state, action, next_state)
done = jnp.logical_or(terminated, truncated) done = jnp.logical_or(terminated, truncated)
if jnp.any(done): if jnp.any(done):

View File

@@ -9,12 +9,9 @@ import numpy as np
from jax.random import PRNGKey
import gymnasium as gym
+from gymnasium.envs.functional_jax_env import FunctionalJaxEnv, FunctionalJaxVectorEnv
from gymnasium.error import DependencyNotInstalled
-from gymnasium.experimental.functional import ActType, FuncEnv, StateType
+from gymnasium.functional import ActType, FuncEnv, StateType
-from gymnasium.experimental.functional_jax_env import (
-FunctionalJaxEnv,
-FunctionalJaxVectorEnv,
-)
from gymnasium.utils import EzPickle


@@ -10,12 +10,9 @@ import numpy as np
from jax.random import PRNGKey
import gymnasium as gym
+from gymnasium.envs.functional_jax_env import FunctionalJaxEnv, FunctionalJaxVectorEnv
from gymnasium.error import DependencyNotInstalled
-from gymnasium.experimental.functional import ActType, FuncEnv, StateType
+from gymnasium.functional import ActType, FuncEnv, StateType
-from gymnasium.experimental.functional_jax_env import (
-FunctionalJaxEnv,
-FunctionalJaxVectorEnv,
-)
from gymnasium.utils import EzPickle


@@ -10,7 +10,6 @@ import importlib.util
import json
import re
import sys
import traceback
from collections import defaultdict
from dataclasses import dataclass, field
from types import ModuleType
@@ -86,11 +85,13 @@ class EnvSpec:
* **nondeterministic**: If the observation of an environment cannot be repeated with the same initial state, random number generator state and actions.
* **max_episode_steps**: The max number of steps that the environment can take before truncation
* **order_enforce**: If to enforce the order of :meth:`gymnasium.Env.reset` before :meth:`gymnasium.Env.step` and :meth:`gymnasium.Env.render` functions
* **autoreset**: If to automatically reset the environment on episode end
* **disable_env_checker**: If to disable the environment checker wrapper in :meth:`gymnasium.make`, by default False (runs the environment checker)
* **kwargs**: Additional keyword arguments passed to the environment during initialisation
* **additional_wrappers**: A tuple of additional wrappers applied to the environment (WrapperSpec)
* **vector_entry_point**: The location of the vectorized environment to create from
Changelogs:
v1.0.0 - Autoreset attribute removed
"""
id: str
@@ -103,9 +104,7 @@ class EnvSpec:
# Wrappers
max_episode_steps: int | None = field(default=None)
order_enforce: bool = field(default=True)
autoreset: bool = field(default=False)
disable_env_checker: bool = field(default=False)
apply_api_compatibility: bool = field(default=False)
# Environment arguments
kwargs: dict = field(default_factory=dict)
@@ -224,12 +223,8 @@ class EnvSpec:
output += f"\nmax_episode_steps={self.max_episode_steps}"
if print_all or self.order_enforce is not True:
output += f"\norder_enforce={self.order_enforce}"
if print_all or self.autoreset is not False:
output += f"\nautoreset={self.autoreset}"
if print_all or self.disable_env_checker is not False:
output += f"\ndisable_env_checker={self.disable_env_checker}"
if print_all or self.apply_api_compatibility is not False:
output += f"\napplied_api_compatibility={self.apply_api_compatibility}"
if print_all or self.additional_wrappers:
wrapper_output: list[str] = []
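As a quick sanity check on the slimmed-down spec, the remaining fields can be inspected through the registry; a small sketch, assuming the standard `CartPole-v1` registration:

import gymnasium as gym

spec = gym.spec("CartPole-v1")  # fetch the EnvSpec from the registry
print(spec.max_episode_steps)   # TimeLimit bound stored on the spec
print(spec.order_enforce)       # True by default; note there is no `spec.autoreset` in v1.0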
@@ -547,55 +542,6 @@ def load_env_creator(name: str) -> EnvCreator | VectorEnvCreator:
return fn
def load_plugin_envs(entry_point: str = "gymnasium.envs"):
"""Load modules (plugins) using the gymnasium entry points in order to register external module's environments on ``import gymnasium``.
Args:
entry_point: The string for the entry point.
"""
# Load third-party environments
for plugin in metadata.entry_points(group=entry_point):
# Python 3.8 doesn't support plugin.module, plugin.attr
# So we'll have to try and parse this ourselves
module, attr = None, None
try:
module, attr = plugin.module, plugin.attr # type: ignore ## error: Cannot access member "attr" for type "EntryPoint"
except AttributeError:
if ":" in plugin.value:
module, attr = plugin.value.split(":", maxsplit=1)
else:
module, attr = plugin.value, None
except Exception as e:
logger.warn(
f"While trying to load plugin `{plugin}` from {entry_point}, an exception occurred: {e}"
)
module, attr = None, None
finally:
if attr is None:
raise error.Error(
f"Gymnasium environment plugin `{module}` must specify a function to execute, not a root module"
)
context = namespace(plugin.name)
if plugin.name.startswith("__") and plugin.name.endswith("__"):
# `__internal__` is an artifact of the plugin system when the root namespace had an allow-list.
# The allow-list is now removed and plugins can register environments in the root namespace with the `__root__` magic key.
if plugin.name == "__root__" or plugin.name == "__internal__":
context = contextlib.nullcontext()
else:
logger.warn(
f"The environment namespace magic key `{plugin.name}` is unsupported. "
"To register an environment at the root namespace you should specify the `__root__` namespace."
)
with context:
fn = plugin.load()
try:
fn()
except Exception:
logger.warn(f"plugin: {plugin.value} raised {traceback.format_exc()}")
def register_envs(env_module: ModuleType):
"""A No-op function such that it can appear to IDEs that a module is used."""
pass
@@ -618,9 +564,7 @@ def register(
nondeterministic: bool = False,
max_episode_steps: int | None = None,
order_enforce: bool = True,
autoreset: bool = False,
disable_env_checker: bool = False,
apply_api_compatibility: bool = False,
additional_wrappers: tuple[WrapperSpec, ...] = (),
vector_entry_point: VectorEnvCreator | str | None = None,
**kwargs: Any,
@@ -640,13 +584,13 @@ def register(
max_episode_steps: The maximum number of episodes steps before truncation. Used by the :class:`gymnasium.wrappers.TimeLimit` wrapper if not ``None``.
order_enforce: If to enable the order enforcer wrapper to ensure users run functions in the correct order.
If ``True``, then the :class:`gymnasium.wrappers.OrderEnforcing` is applied to the environment.
autoreset: If to add the :class:`gymnasium.wrappers.AutoResetWrapper` such that on ``(terminated or truncated) is True``, :meth:`gymnasium.Env.reset` is called.
disable_env_checker: If to disable the :class:`gymnasium.wrappers.PassiveEnvChecker` to the environment.
apply_api_compatibility: If to apply the :class:`gymnasium.wrappers.StepAPICompatibility` wrapper to the environment.
Use if the environment is implemented in the gym v0.21 environment API.
additional_wrappers: Additional wrappers to apply the environment.
vector_entry_point: The entry point for creating the vector environment
**kwargs: arbitrary keyword arguments which are passed to the environment constructor on initialisation.
Changelogs:
v1.0.0 - `autoreset` and `apply_api_compatibility` parameter was removed
"""
assert (
entry_point is not None or vector_entry_point is not None
@@ -669,11 +613,6 @@ def register(
ns_id = ns
full_env_id = get_env_id(ns_id, name, version)
if autoreset is True:
logger.warn(
"`gymnasium.register(..., autoreset=True)` is deprecated and will be removed in v1.0. If users wish to use it then add the auto reset wrapper in the `addition_wrappers` argument."
)
new_spec = EnvSpec(
id=full_env_id,
@@ -681,9 +620,7 @@ def register(
nondeterministic=nondeterministic,
max_episode_steps=max_episode_steps,
order_enforce=order_enforce,
autoreset=autoreset,
disable_env_checker=disable_env_checker,
apply_api_compatibility=apply_api_compatibility,
**kwargs,
additional_wrappers=additional_wrappers,
vector_entry_point=vector_entry_point,
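A minimal registration under the v1.0 signature might look as follows; the module path and class name are hypothetical, used purely for illustration:

import gymnasium as gym

gym.register(
    id="MyNamespace/MyEnv-v0",
    entry_point="my_package.envs:MyEnv",  # hypothetical entry point
    max_episode_steps=200,  # applied via gymnasium.wrappers.TimeLimit
)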
@@ -698,8 +635,6 @@ def make(
def make(
id: str | EnvSpec,
max_episode_steps: int | None = None,
autoreset: bool | None = None,
apply_api_compatibility: bool | None = None,
disable_env_checker: bool | None = None,
**kwargs: Any,
) -> Env:
@@ -710,12 +645,9 @@ def make(
Args:
id: A string for the environment id or a :class:`EnvSpec`. Optionally if using a string, a module to import can be included, e.g. ``'module:Env-v0'``.
This is equivalent to importing the module first to register the environment followed by making the environment.
-max_episode_steps: Maximum length of an episode, can override the registered :class:`EnvSpec` ``max_episode_steps``.
-The value is used by :class:`gymnasium.wrappers.TimeLimit`.
+max_episode_steps: Maximum length of an episode, can override the registered :class:`EnvSpec` ``max_episode_steps``
+with the value being passed to :class:`gymnasium.wrappers.TimeLimit`.
autoreset: Whether to automatically reset the environment after each episode (:class:`gymnasium.wrappers.AutoResetWrapper`). Using ``max_episode_steps=-1`` will not apply the wrapper to the environment.
apply_api_compatibility: Whether to wrap the environment with the :class:`gymnasium.wrappers.StepAPICompatibility` wrapper that
converts the environment step from a done bool to return termination and truncation bools.
By default, the argument is None in which the :class:`EnvSpec` ``apply_api_compatibility`` is used, otherwise this variable is used in favor.
disable_env_checker: If to add :class:`gymnasium.wrappers.PassiveEnvChecker`, ``None`` will default to the
:class:`EnvSpec` ``disable_env_checker`` value otherwise use this value will be used.
kwargs: Additional arguments to pass to the environment constructor.
@@ -725,6 +657,9 @@ def make(
Raises:
Error: If the ``id`` doesn't exist in the :attr:`registry`
Changelogs:
v1.0.0 - `autoreset` and `apply_api_compatibility` was removed
"""
if isinstance(id, EnvSpec):
env_spec = id
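Correspondingly, `make` no longer accepts `autoreset` or `apply_api_compatibility`; a sketch of typical v1.0 usage with the built-in `CartPole-v1`:

import gymnasium as gym

env = gym.make("CartPole-v1", max_episode_steps=100)  # overrides the spec's TimeLimit bound
obs, info = env.reset(seed=42)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
env.close()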
@@ -790,14 +725,6 @@ def make(
f"that is not in the possible render_modes ({render_modes})."
)
if apply_api_compatibility or (
apply_api_compatibility is None and env_spec.apply_api_compatibility
):
# If we use the compatibility layer, we treat the render mode explicitly and don't pass it to the env creator
render_mode = env_spec_kwargs.pop("render_mode", None)
else:
render_mode = None
try:
env = env_creator(**env_spec_kwargs)
except TypeError as e:
@@ -823,9 +750,7 @@ def make(
nondeterministic=env_spec.nondeterministic,
max_episode_steps=None,
order_enforce=False,
autoreset=False,
disable_env_checker=True,
apply_api_compatibility=False,
kwargs=env_spec_kwargs,
additional_wrappers=(),
vector_entry_point=env_spec.vector_entry_point,
@@ -845,15 +770,6 @@ def make(
f"The environment's wrapper spec {recreated_wrapper_spec} is different from the saved `EnvSpec` additional wrapper {env_spec_wrapper_spec}"
)
# Add step API wrapper
if apply_api_compatibility is True or (
apply_api_compatibility is None and env_spec.apply_api_compatibility is True
):
logger.warn(
"`gymnasium.make(..., apply_api_compatibility=True)` and `env_spec.apply_api_compatibility` is deprecated and will be removed in v1.0"
)
env = gym.wrappers.EnvCompatibility(env, render_mode)
# Run the environment checker as the lowest level wrapper
if disable_env_checker is False or (
disable_env_checker is None and env_spec.disable_env_checker is False
@@ -865,19 +781,12 @@ def make(
env = gym.wrappers.OrderEnforcing(env)
# Add the time limit wrapper
if max_episode_steps != -1:
if max_episode_steps is not None:
env = gym.wrappers.TimeLimit(env, max_episode_steps)
elif env_spec.max_episode_steps is not None:
env = gym.wrappers.TimeLimit(env, env_spec.max_episode_steps)
# Add the auto-reset wrapper
if autoreset is True or (autoreset is None and env_spec.autoreset is True):
env = gym.wrappers.AutoResetWrapper(env)
logger.warn(
"`gymnasium.make(..., autoreset=True)` is deprecated and will be removed in v1.0"
)
for wrapper_spec in env_spec.additional_wrappers[num_prior_wrappers:]:
if wrapper_spec.kwargs is None:
raise ValueError(
@@ -898,25 +807,25 @@ def make(
def make_vec(
id: str | EnvSpec,
num_envs: int = 1,
-vectorization_mode: str = "async",
+vectorization_mode: str | None = None,
vector_kwargs: dict[str, Any] | None = None,
wrappers: Sequence[Callable[[Env], Wrapper]] | None = None,
**kwargs,
-) -> gym.experimental.vector.VectorEnv:
+) -> gym.vector.VectorEnv:
"""Create a vector environment according to the given ID.
+Note: To find all available environments use :func:`gymnasium.pprint_registry` or ``gymnasium.registry.keys()`` for all valid ids.
-This feature is experimental, and is likely to change in future releases.
+We refer to the Vector environment as the vectorizor while the environment being vectorized is the base or vectorized environment (``vectorizor(vectorized env)``).
-To find all available environments use `gymnasium.envs.registry.keys()` for all valid ids.
Args:
id: Name of the environment. Optionally, a module to import can be included, eg. 'module:Env-v0'
num_envs: Number of environments to create
-vectorization_mode: How to vectorize the environment. Can be either "async", "sync" or "custom"
-vector_kwargs: Additional arguments to pass to the vectorized environment constructor.
-wrappers: A sequence of wrapper functions to apply to the environment. Can only be used in "sync" or "async" mode.
-**kwargs: Additional arguments to pass to the environment constructor.
+vectorization_mode: The vectorization method used, defaults to ``None`` such that if a ``vector_entry_point`` exists,
+this is first used otherwise defaults to ``sync`` to use the :class:`gymnasium.vector.SyncVectorEnv`.
+Valid modes are ``"async"``, ``"sync"`` or ``"vector_entry_point"``.
+vector_kwargs: Additional arguments to pass to the vectorizor environment constructor, i.e., ``SyncVectorEnv(..., **vector_kwargs)``.
+wrappers: A sequence of wrapper functions to apply to the base environment. Can only be used in ``"sync"`` or ``"async"`` mode.
+**kwargs: Additional arguments passed to the base environment constructor.
Returns:
An instance of the environment.
@@ -926,87 +835,93 @@ def make_vec(
"""
if vector_kwargs is None:
vector_kwargs = {}
if wrappers is None:
wrappers = []
-if isinstance(id, EnvSpec):
-spec_ = id
-else:
-spec_ = _find_spec(id)
-_kwargs = spec_.kwargs.copy()
-_kwargs.update(kwargs)
-# Check if we have the necessary entry point
-if vectorization_mode in ("sync", "async"):
-if spec_.entry_point is None:
-raise error.Error(
-f"Cannot create vectorized environment for {id} because it doesn't have an entry point defined."
-)
-entry_point = spec_.entry_point
-elif vectorization_mode in ("custom",):
-if spec_.vector_entry_point is None:
-raise error.Error(
-f"Cannot create vectorized environment for {id} because it doesn't have a vector entry point defined."
-)
-entry_point = spec_.vector_entry_point
-else:
-raise error.Error(f"Invalid vectorization mode: {vectorization_mode}")
-if callable(entry_point):
-env_creator = entry_point
-else:
-# Assume it's a string
-env_creator = load_env_creator(entry_point)
-def _create_env():
-# Env creator for use with sync and async modes
-_kwargs_copy = _kwargs.copy()
-render_mode = _kwargs.get("render_mode", None)
-if render_mode is not None:
-inner_render_mode = (
-render_mode[: -len("_list")]
-if render_mode.endswith("_list")
-else render_mode
-)
-_kwargs_copy["render_mode"] = inner_render_mode
-_env = env_creator(**_kwargs_copy)
-_env.spec = spec_
-if spec_.max_episode_steps is not None:
-_env = gym.wrappers.TimeLimit(_env, spec_.max_episode_steps)
-if render_mode is not None and render_mode.endswith("_list"):
-_env = gym.wrappers.RenderCollection(_env)
-for wrapper in wrappers:
-_env = wrapper(_env)
-return _env
-if vectorization_mode == "sync":
-env = gym.experimental.vector.SyncVectorEnv(
-env_fns=[_create_env for _ in range(num_envs)],
-**vector_kwargs,
-)
-elif vectorization_mode == "async":
-env = gym.experimental.vector.AsyncVectorEnv(
-env_fns=[_create_env for _ in range(num_envs)],
-**vector_kwargs,
-)
-elif vectorization_mode == "custom":
-if len(wrappers) > 0:
-raise error.Error("Cannot use custom vectorization mode with wrappers.")
-vector_kwargs["max_episode_steps"] = spec_.max_episode_steps
-env = env_creator(num_envs=num_envs, **vector_kwargs)
-else:
-raise error.Error(f"Invalid vectorization mode: {vectorization_mode}")
-# Copies the environment creation specification and kwargs to add to the environment specification details
-spec_ = copy.deepcopy(spec_)
-spec_.kwargs = _kwargs
-env.unwrapped.spec = spec_
-return env
+if isinstance(id, EnvSpec):
+id_env_spec = id
+env_spec_kwargs = id_env_spec.kwargs.copy()
+num_envs = env_spec_kwargs.pop("num_envs", num_envs)
+vectorization_mode = env_spec_kwargs.pop(
+"vectorization_mode", vectorization_mode
+)
+vector_kwargs = env_spec_kwargs.pop("vector_kwargs", vector_kwargs)
+wrappers = env_spec_kwargs.pop("wrappers", wrappers)
+else:
+id_env_spec = _find_spec(id)
+env_spec_kwargs = id_env_spec.kwargs.copy()
+env_spec_kwargs.update(kwargs)
+# Update the vectorization_mode if None
+if vectorization_mode is None:
+if id_env_spec.vector_entry_point is not None:
+vectorization_mode = "vector_entry_point"
+else:
+vectorization_mode = "sync"
+def create_single_env() -> Env:
+single_env = make(id_env_spec.id, **env_spec_kwargs.copy())
+for wrapper in wrappers:
+single_env = wrapper(single_env)
+return single_env
+if vectorization_mode == "sync":
+if id_env_spec.entry_point is None:
+raise error.Error(
+f"Cannot create vectorized environment for {id_env_spec.id} because it doesn't have an entry point defined."
+)
+env = gym.vector.SyncVectorEnv(
+env_fns=(create_single_env for _ in range(num_envs)),
+**vector_kwargs,
+)
+elif vectorization_mode == "async":
+if id_env_spec.entry_point is None:
+raise error.Error(
+f"Cannot create vectorized environment for {id_env_spec.id} because it doesn't have an entry point defined."
+)
+env = gym.vector.AsyncVectorEnv(
+env_fns=[create_single_env for _ in range(num_envs)],
+**vector_kwargs,
+)
+elif vectorization_mode == "vector_entry_point":
+entry_point = id_env_spec.vector_entry_point
+if entry_point is None:
+raise error.Error(
+f"Cannot create vectorized environment for {id} because it doesn't have a vector entry point defined."
+)
+elif callable(entry_point):
+env_creator = entry_point
+else:  # Assume it's a string
+env_creator = load_env_creator(entry_point)
+if len(wrappers) > 0:
+raise error.Error(
+"Cannot use `vector_entry_point` vectorization mode with the wrappers argument."
+)
+if "max_episode_steps" not in vector_kwargs:
+vector_kwargs["max_episode_steps"] = id_env_spec.max_episode_steps
+env = env_creator(num_envs=num_envs, **vector_kwargs)
+else:
+raise error.Error(f"Invalid vectorization mode: {vectorization_mode}")
+# Copies the environment creation specification and kwargs to add to the environment specification details
+copied_id_spec = copy.deepcopy(id_env_spec)
+copied_id_spec.kwargs = env_spec_kwargs
+if num_envs != 1:
+copied_id_spec.kwargs["num_envs"] = num_envs
+if vectorization_mode != "async":
+copied_id_spec.kwargs["vectorization_mode"] = vectorization_mode
+if vector_kwargs is not None:
+copied_id_spec.kwargs["vector_kwargs"] = vector_kwargs
+if wrappers is not None:
+copied_id_spec.kwargs["wrappers"] = wrappers
+env.unwrapped.spec = copied_id_spec
+return env
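A sketch of the rewritten `make_vec` in use, following the mode-selection logic above (with the built-in `CartPole-v1`):

import gymnasium as gym

# Explicit mode; with vectorization_mode=None (the default), a registered
# vector_entry_point is preferred and "sync" is the fallback, per the branch above.
envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
obs, info = envs.reset(seed=42)
obs, rewards, terminations, truncations, infos = envs.step(envs.action_space.sample())
envs.close()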


@@ -12,9 +12,9 @@ from jax import random
from jax.random import PRNGKey
from gymnasium import spaces
+from gymnasium.envs.functional_jax_env import FunctionalJaxEnv
from gymnasium.error import DependencyNotInstalled
-from gymnasium.experimental.functional import ActType, FuncEnv, StateType
+from gymnasium.functional import ActType, FuncEnv, StateType
-from gymnasium.experimental.functional_jax_env import FunctionalJaxEnv
from gymnasium.utils import EzPickle, seeding
from gymnasium.wrappers import HumanRendering


@@ -12,9 +12,9 @@ import numpy as np
from jax.random import PRNGKey
from gymnasium import spaces
+from gymnasium.envs.functional_jax_env import FunctionalJaxEnv
from gymnasium.error import DependencyNotInstalled
-from gymnasium.experimental.functional import ActType, FuncEnv, StateType
+from gymnasium.functional import ActType, FuncEnv, StateType
-from gymnasium.experimental.functional_jax_env import FunctionalJaxEnv
from gymnasium.utils import EzPickle
from gymnasium.wrappers import HumanRendering


@@ -1,25 +0,0 @@
"""Root __init__ of the gym experimental wrappers."""
from gymnasium.experimental import functional, vector, wrappers
# from gymnasium.experimental.functional import FuncEnv
# from gymnasium.experimental.vector.async_vector_env import AsyncVectorEnv
# from gymnasium.experimental.vector.sync_vector_env import SyncVectorEnv
# from gymnasium.experimental.vector.vector_env import VectorEnv, VectorWrapper
__all__ = [
# Functional
# "FuncEnv",
"functional",
# Vector
# "VectorEnv",
# "VectorWrapper",
# "SyncVectorEnv",
# "AsyncVectorEnv",
# wrappers
"wrappers",
"vector",
]


@@ -1,23 +0,0 @@
"""Experimental vector env API."""
from gymnasium.experimental.vector import utils
from gymnasium.experimental.vector.async_vector_env import AsyncVectorEnv
from gymnasium.experimental.vector.sync_vector_env import SyncVectorEnv
from gymnasium.experimental.vector.vector_env import (
VectorActionWrapper,
VectorEnv,
VectorObservationWrapper,
VectorRewardWrapper,
VectorWrapper,
)
__all__ = [
"VectorEnv",
"VectorWrapper",
"VectorObservationWrapper",
"VectorActionWrapper",
"VectorRewardWrapper",
"SyncVectorEnv",
"AsyncVectorEnv",
"utils",
]


@@ -1,685 +0,0 @@
"""An async vector environment."""
from __future__ import annotations
import multiprocessing
import sys
import time
from copy import deepcopy
from enum import Enum
from multiprocessing import Queue
from multiprocessing.connection import Connection
from typing import Any, Callable, Sequence
import numpy as np
from gymnasium import logger
from gymnasium.core import Env, ObsType
from gymnasium.error import (
AlreadyPendingCallError,
ClosedEnvironmentError,
CustomSpaceError,
NoAsyncCallError,
)
from gymnasium.experimental.vector.utils import (
CloudpickleWrapper,
batch_space,
clear_mpi_env_vars,
concatenate,
create_empty_array,
create_shared_memory,
iterate,
read_from_shared_memory,
write_to_shared_memory,
)
from gymnasium.experimental.vector.vector_env import VectorEnv
__all__ = ["AsyncVectorEnv"]
class AsyncState(Enum):
DEFAULT = "default"
WAITING_RESET = "reset"
WAITING_STEP = "step"
WAITING_CALL = "call"
class AsyncVectorEnv(VectorEnv):
"""Vectorized environment that runs multiple environments in parallel.
It uses ``multiprocessing`` processes, and pipes for communication.
Example:
>>> import gymnasium as gym
>>> env = gym.vector.AsyncVectorEnv([
... lambda: gym.make("Pendulum-v1", g=9.81),
... lambda: gym.make("Pendulum-v1", g=1.62)
... ])
>>> env.reset(seed=42)
(array([[-0.14995256, 0.9886932 , -0.12224312],
[ 0.5760367 , 0.8174238 , -0.91244936]], dtype=float32), {})
"""
def __init__(
self,
env_fns: Sequence[Callable[[], Env]],
shared_memory: bool = True,
copy: bool = True,
context: str | None = None,
daemon: bool = True,
worker: callable | None = None,
):
"""Vectorized environment that runs multiple environments in parallel.
Args:
env_fns: Functions that create the environments.
shared_memory: If ``True``, then the observations from the worker processes are communicated back through
shared variables. This can improve the efficiency if the observations are large (e.g. images).
copy: If ``True``, then the :meth:`~AsyncVectorEnv.reset` and :meth:`~AsyncVectorEnv.step` methods
return a copy of the observations.
context: Context for `multiprocessing`_. If ``None``, then the default context is used.
daemon: If ``True``, then subprocesses have ``daemon`` flag turned on; that is, they will quit if
the head process quits. However, ``daemon=True`` prevents subprocesses to spawn children,
so for some environments you may want to have it set to ``False``.
worker: If set, then use that worker in a subprocess instead of a default one.
Can be useful to override some inner vector env logic, for instance, how resets on termination or truncation are handled.
Warnings:
worker is an advanced mode option. It provides a high degree of flexibility and a high chance
to shoot yourself in the foot; thus, if you are writing your own worker, it is recommended to start
from the code for ``_worker`` (or ``_worker_shared_memory``) method, and add changes.
Raises:
RuntimeError: If the observation space of some sub-environment does not match observation_space
(or, by default, the observation space of the first sub-environment).
ValueError: If observation_space is a custom space (i.e. not a default space in Gym,
such as gymnasium.spaces.Box, gymnasium.spaces.Discrete, or gymnasium.spaces.Dict) and shared_memory is True.
"""
super().__init__()
ctx = multiprocessing.get_context(context)
self.env_fns = env_fns
self.num_envs = len(env_fns)
self.shared_memory = shared_memory
self.copy = copy
# This would be nice to get rid of, but without it there's a deadlock between shared memory and pipes
dummy_env = env_fns[0]()
self.metadata = dummy_env.metadata
self.single_observation_space = dummy_env.observation_space
self.single_action_space = dummy_env.action_space
self.observation_space = batch_space(
self.single_observation_space, self.num_envs
)
self.action_space = batch_space(self.single_action_space, self.num_envs)
dummy_env.close()
del dummy_env
if self.shared_memory:
try:
_obs_buffer = create_shared_memory(
self.single_observation_space, n=self.num_envs, ctx=ctx
)
self.observations = read_from_shared_memory(
self.single_observation_space, _obs_buffer, n=self.num_envs
)
except CustomSpaceError as e:
raise ValueError(
"Using `shared_memory=True` in `AsyncVectorEnv` "
"is incompatible with non-standard Gymnasium observation spaces "
"(i.e. custom spaces inheriting from `gymnasium.Space`), and is "
"only compatible with default Gymnasium spaces (e.g. `Box`, "
"`Tuple`, `Dict`) for batching. Set `shared_memory=False` "
"if you use custom observation spaces."
) from e
else:
_obs_buffer = None
self.observations = create_empty_array(
self.single_observation_space, n=self.num_envs, fn=np.zeros
)
self.parent_pipes, self.processes = [], []
self.error_queue = ctx.Queue()
target = worker or _worker
with clear_mpi_env_vars():
for idx, env_fn in enumerate(self.env_fns):
parent_pipe, child_pipe = ctx.Pipe()
process = ctx.Process(
target=target,
name=f"Worker<{type(self).__name__}>-{idx}",
args=(
idx,
CloudpickleWrapper(env_fn),
child_pipe,
parent_pipe,
_obs_buffer,
self.error_queue,
),
)
self.parent_pipes.append(parent_pipe)
self.processes.append(process)
process.daemon = daemon
process.start()
child_pipe.close()
self._state = AsyncState.DEFAULT
self._check_spaces()
def reset_async(
self,
seed: int | list[int] | None = None,
options: dict | None = None,
):
"""Send calls to the :obj:`reset` methods of the sub-environments.
To get the results of these calls, you may invoke :meth:`reset_wait`.
Args:
seed: List of seeds for each environment
options: The reset option
Raises:
ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
AlreadyPendingCallError: If the environment is already waiting for a pending call to another
method (e.g. :meth:`step_async`). This can be caused by two consecutive
calls to :meth:`reset_async`, with no call to :meth:`reset_wait` in between.
"""
self._assert_is_running()
if seed is None:
seed = [None for _ in range(self.num_envs)]
if isinstance(seed, int):
seed = [seed + i for i in range(self.num_envs)]
assert len(seed) == self.num_envs
if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
f"Calling `reset_async` while waiting for a pending call to `{self._state.value}` to complete",
str(self._state.value),
)
for pipe, single_seed in zip(self.parent_pipes, seed):
single_kwargs = {}
if single_seed is not None:
single_kwargs["seed"] = single_seed
if options is not None:
single_kwargs["options"] = options
pipe.send(("reset", single_kwargs))
self._state = AsyncState.WAITING_RESET
def reset_wait(
self,
timeout: int | float | None = None,
) -> tuple[ObsType, list[dict]]:
"""Waits for the calls triggered by :meth:`reset_async` to finish and returns the results.
Args:
timeout: Number of seconds before the call to `reset_wait` times out. If `None`, the call to `reset_wait` never times out.
Returns:
A tuple of batched observations and list of dictionaries
Raises:
ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
NoAsyncCallError: If :meth:`reset_wait` was called without any prior call to :meth:`reset_async`.
TimeoutError: If :meth:`reset_wait` timed out.
"""
self._assert_is_running()
if self._state != AsyncState.WAITING_RESET:
raise NoAsyncCallError(
"Calling `reset_wait` without any prior " "call to `reset_async`.",
AsyncState.WAITING_RESET.value,
)
if not self._poll(timeout):
self._state = AsyncState.DEFAULT
raise multiprocessing.TimeoutError(
f"The call to `reset_wait` has timed out after {timeout} second(s)."
)
results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
self._raise_if_errors(successes)
self._state = AsyncState.DEFAULT
infos = {}
results, info_data = zip(*results)
for i, info in enumerate(info_data):
infos = self._add_info(infos, info, i)
if not self.shared_memory:
self.observations = concatenate(
self.single_observation_space, results, self.observations
)
return (deepcopy(self.observations) if self.copy else self.observations), infos
def reset(
self,
*,
seed: int | list[int] | None = None,
options: dict | None = None,
):
"""Reset all parallel environments and return a batch of initial observations and info.
Args:
seed: The environment reset seeds
options: If to return the options
Returns:
A batch of observations and info from the vectorized environment.
"""
self.reset_async(seed=seed, options=options)
return self.reset_wait()
def step_async(self, actions: np.ndarray):
"""Send the calls to :obj:`step` to each sub-environment.
Args:
actions: Batch of actions. element of :attr:`~VectorEnv.action_space`
Raises:
ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
AlreadyPendingCallError: If the environment is already waiting for a pending call to another
method (e.g. :meth:`reset_async`). This can be caused by two consecutive
calls to :meth:`step_async`, with no call to :meth:`step_wait` in
between.
"""
self._assert_is_running()
if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
f"Calling `step_async` while waiting for a pending call to `{self._state.value}` to complete.",
str(self._state.value),
)
actions = iterate(self.action_space, actions)
for pipe, action in zip(self.parent_pipes, actions):
pipe.send(("step", action))
self._state = AsyncState.WAITING_STEP
def step_wait(
self, timeout: int | float | None = None
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, dict]:
"""Wait for the calls to :obj:`step` in each sub-environment to finish.
Args:
timeout: Number of seconds before the call to :meth:`step_wait` times out. If ``None``, the call to :meth:`step_wait` never times out.
Returns:
The batched environment step information, (obs, reward, terminated, truncated, info)
Raises:
ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
NoAsyncCallError: If :meth:`step_wait` was called without any prior call to :meth:`step_async`.
TimeoutError: If :meth:`step_wait` timed out.
"""
self._assert_is_running()
if self._state != AsyncState.WAITING_STEP:
raise NoAsyncCallError(
"Calling `step_wait` without any prior call " "to `step_async`.",
AsyncState.WAITING_STEP.value,
)
if not self._poll(timeout):
self._state = AsyncState.DEFAULT
raise multiprocessing.TimeoutError(
f"The call to `step_wait` has timed out after {timeout} second(s)."
)
observations_list, rewards, terminateds, truncateds, infos = [], [], [], [], {}
successes = []
for i, pipe in enumerate(self.parent_pipes):
result, success = pipe.recv()
obs, rew, terminated, truncated, info = result
successes.append(success)
if success:
observations_list.append(obs)
rewards.append(rew)
terminateds.append(terminated)
truncateds.append(truncated)
infos = self._add_info(infos, info, i)
self._raise_if_errors(successes)
self._state = AsyncState.DEFAULT
if not self.shared_memory:
self.observations = concatenate(
self.single_observation_space,
observations_list,
self.observations,
)
return (
deepcopy(self.observations) if self.copy else self.observations,
np.array(rewards),
np.array(terminateds, dtype=np.bool_),
np.array(truncateds, dtype=np.bool_),
infos,
)
def step(self, actions):
"""Take an action for each parallel environment.
Args:
actions: element of :attr:`action_space` Batch of actions.
Returns:
Batch of (observations, rewards, terminations, truncations, infos)
"""
self.step_async(actions)
return self.step_wait()
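The paired `*_async`/`*_wait` methods above can also be driven by hand, which is exactly what `reset` and `step` do internally; a sketch using the stable `gym.vector.AsyncVectorEnv` path (per this commit's relocation of the class):

import numpy as np
import gymnasium as gym

envs = gym.vector.AsyncVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(2)])
envs.reset_async(seed=123)
obs, infos = envs.reset_wait(timeout=10)
actions = np.array([0, 1])  # one discrete action per sub-environment
envs.step_async(actions)    # dispatch to the workers, returns immediately
obs, rewards, terminations, truncations, infos = envs.step_wait(timeout=10)
envs.close()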
def call_async(self, name: str, *args, **kwargs):
"""Calls the method with name asynchronously and apply args and kwargs to the method.
Args:
name: Name of the method or property to call.
*args: Arguments to apply to the method call.
**kwargs: Keyword arguments to apply to the method call.
Raises:
ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
AlreadyPendingCallError: Calling `call_async` while waiting for a pending call to complete
"""
self._assert_is_running()
if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
"Calling `call_async` while waiting "
f"for a pending call to `{self._state.value}` to complete.",
str(self._state.value),
)
for pipe in self.parent_pipes:
pipe.send(("_call", (name, args, kwargs)))
self._state = AsyncState.WAITING_CALL
def call_wait(self, timeout: int | float | None = None) -> list:
"""Calls all parent pipes and waits for the results.
Args:
timeout: Number of seconds before the call to `step_wait` times out.
If `None` (default), the call to `step_wait` never times out.
Returns:
List of the results of the individual calls to the method or property for each environment.
Raises:
NoAsyncCallError: Calling `call_wait` without any prior call to `call_async`.
TimeoutError: The call to `call_wait` has timed out after timeout second(s).
"""
self._assert_is_running()
if self._state != AsyncState.WAITING_CALL:
raise NoAsyncCallError(
"Calling `call_wait` without any prior call to `call_async`.",
AsyncState.WAITING_CALL.value,
)
if not self._poll(timeout):
self._state = AsyncState.DEFAULT
raise multiprocessing.TimeoutError(
f"The call to `call_wait` has timed out after {timeout} second(s)."
)
results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
self._raise_if_errors(successes)
self._state = AsyncState.DEFAULT
return results
def call(self, name: str, *args, **kwargs) -> list[Any]:
"""Call a method, or get a property, from each parallel environment.
Args:
name (str): Name of the method or property to call.
*args: Arguments to apply to the method call.
**kwargs: Keyword arguments to apply to the method call.
Returns:
List of the results of the individual calls to the method or property for each environment.
"""
self.call_async(name, *args, **kwargs)
return self.call_wait()
def get_attr(self, name: str):
"""Get a property from each parallel environment.
Args:
name (str): Name of the property to be get from each individual environment.
Returns:
The property with name
"""
return self.call(name)
def set_attr(self, name: str, values: list[Any] | tuple[Any] | object):
"""Sets an attribute of the sub-environments.
Args:
name: Name of the property to be set in each individual environment.
values: Values of the property to be set to. If ``values`` is a list or
tuple, then it corresponds to the values for each individual
environment, otherwise a single value is set for all environments.
Raises:
ValueError: Values must be a list or tuple with length equal to the number of environments.
AlreadyPendingCallError: Calling `set_attr` while waiting for a pending call to complete.
"""
self._assert_is_running()
if not isinstance(values, (list, tuple)):
values = [values for _ in range(self.num_envs)]
if len(values) != self.num_envs:
raise ValueError(
"Values must be a list or tuple with length equal to the "
f"number of environments. Got `{len(values)}` values for "
f"{self.num_envs} environments."
)
if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
"Calling `set_attr` while waiting "
f"for a pending call to `{self._state.value}` to complete.",
str(self._state.value),
)
for pipe, value in zip(self.parent_pipes, values):
pipe.send(("_setattr", (name, value)))
_, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
self._raise_if_errors(successes)
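A sketch of the attribute plumbing above; `sub_env_attr` is a hypothetical attribute used purely for illustration:

import gymnasium as gym

envs = gym.vector.AsyncVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(2)])
specs = envs.get_attr("spec")           # broadcast getattr across all workers
envs.set_attr("sub_env_attr", [1, 2])   # one value per sub-environment
values = envs.get_attr("sub_env_attr")  # reads back [1, 2]
envs.close()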
def close_extras(self, timeout: int | float | None = None, terminate: bool = False):
"""Close the environments & clean up the extra resources (processes and pipes).
Args:
timeout: Number of seconds before the call to :meth:`close` times out. If ``None``,
the call to :meth:`close` never times out. If the call to :meth:`close`
times out, then all processes are terminated.
terminate: If ``True``, then the :meth:`close` operation is forced and all processes are terminated.
Raises:
TimeoutError: If :meth:`close` timed out.
"""
timeout = 0 if terminate else timeout
try:
if self._state != AsyncState.DEFAULT:
logger.warn(
f"Calling `close` while waiting for a pending call to `{self._state.value}` to complete."
)
function = getattr(self, f"{self._state.value}_wait")
function(timeout)
except multiprocessing.TimeoutError:
terminate = True
if terminate:
for process in self.processes:
if process.is_alive():
process.terminate()
else:
for pipe in self.parent_pipes:
if (pipe is not None) and (not pipe.closed):
pipe.send(("close", None))
for pipe in self.parent_pipes:
if (pipe is not None) and (not pipe.closed):
pipe.recv()
for pipe in self.parent_pipes:
if pipe is not None:
pipe.close()
for process in self.processes:
process.join()
def _poll(self, timeout=None):
self._assert_is_running()
if timeout is None:
return True
end_time = time.perf_counter() + timeout
delta = None
for pipe in self.parent_pipes:
delta = max(end_time - time.perf_counter(), 0)
if pipe is None:
return False
if pipe.closed or (not pipe.poll(delta)):
return False
return True
def _check_spaces(self):
self._assert_is_running()
spaces = (self.single_observation_space, self.single_action_space)
for pipe in self.parent_pipes:
pipe.send(("_check_spaces", spaces))
results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
self._raise_if_errors(successes)
same_observation_spaces, same_action_spaces = zip(*results)
if not all(same_observation_spaces):
raise RuntimeError(
f"Some environments have an observation space different from `{self.single_observation_space}`. "
"In order to batch observations, the observation spaces from all environments must be equal."
)
if not all(same_action_spaces):
raise RuntimeError(
f"Some environments have an action space different from `{self.single_action_space}`. "
"In order to batch actions, the action spaces from all environments must be equal."
)
def _assert_is_running(self):
if self.closed:
raise ClosedEnvironmentError(
f"Trying to operate on `{type(self).__name__}`, after a call to `close()`."
)
def _raise_if_errors(self, successes: list[bool]):
if all(successes):
return
num_errors = self.num_envs - sum(successes)
assert num_errors > 0
for i in range(num_errors):
index, exctype, value = self.error_queue.get()
logger.error(
f"Received the following error from Worker-{index}: {exctype.__name__}: {value}"
)
logger.error(f"Shutting down Worker-{index}.")
self.parent_pipes[index].close()
self.parent_pipes[index] = None
if i == num_errors - 1:
logger.error("Raising the last exception back to the main process.")
raise exctype(value)
def __del__(self):
"""On deleting the object, checks that the vector environment is closed."""
if not getattr(self, "closed", True) and hasattr(self, "_state"):
self.close(terminate=True)
def _worker(
index: int,
env_fn: callable,
pipe: Connection,
parent_pipe: Connection,
shared_memory: bool,
error_queue: Queue,
):
env = env_fn()
observation_space = env.observation_space
action_space = env.action_space
parent_pipe.close()
try:
while True:
command, data = pipe.recv()
if command == "reset":
observation, info = env.reset(**data)
if shared_memory:
write_to_shared_memory(
observation_space, index, observation, shared_memory
)
observation = None
pipe.send(((observation, info), True))
elif command == "step":
(
observation,
reward,
terminated,
truncated,
info,
) = env.step(data)
if terminated or truncated:
old_observation, old_info = observation, info
observation, info = env.reset()
info["final_observation"] = old_observation
info["final_info"] = old_info
if shared_memory:
write_to_shared_memory(
observation_space, index, observation, shared_memory
)
observation = None
pipe.send(((observation, reward, terminated, truncated, info), True))
elif command == "seed":
env.seed(data)
pipe.send((None, True))
elif command == "close":
pipe.send((None, True))
break
elif command == "_call":
name, args, kwargs = data
if name in ["reset", "step", "seed", "close"]:
raise ValueError(
f"Trying to call function `{name}` with "
f"`_call`. Use `{name}` directly instead."
)
function = getattr(env, name)
if callable(function):
pipe.send((function(*args, **kwargs), True))
else:
pipe.send((function, True))
elif command == "_setattr":
name, value = data
setattr(env, name, value)
pipe.send((None, True))
elif command == "_check_spaces":
pipe.send(
(
(data[0] == observation_space, data[1] == action_space),
True,
)
)
else:
raise RuntimeError(
f"Received unknown command `{command}`. Must "
"be one of {`reset`, `step`, `seed`, `close`, `_call`, "
"`_setattr`, `_check_spaces`}."
)
except (KeyboardInterrupt, Exception):
error_queue.put((index,) + sys.exc_info()[:2])
pipe.send((None, False))
finally:
env.close()
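This module is deleted because the async vector env graduated out of `gymnasium.experimental`; the class docstring's own example still applies at the stable location:

import gymnasium as gym

envs = gym.vector.AsyncVectorEnv(
    [lambda: gym.make("Pendulum-v1", g=9.81), lambda: gym.make("Pendulum-v1", g=1.62)]
)
obs, info = envs.reset(seed=42)
envs.close()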


@@ -1,229 +0,0 @@
"""A synchronous vector environment."""
from __future__ import annotations
from copy import deepcopy
from typing import Any, Callable, Iterator
import numpy as np
from gymnasium import Env
from gymnasium.experimental.vector.utils import (
batch_space,
concatenate,
create_empty_array,
iterate,
)
from gymnasium.experimental.vector.vector_env import VectorEnv
__all__ = ["SyncVectorEnv"]
class SyncVectorEnv(VectorEnv):
"""Vectorized environment that serially runs multiple environments.
Example:
>>> import gymnasium as gym
>>> env = gym.vector.SyncVectorEnv([
... lambda: gym.make("Pendulum-v1", g=9.81),
... lambda: gym.make("Pendulum-v1", g=1.62)
... ])
>>> env.reset(seed=42)
(array([[-0.14995256, 0.9886932 , -0.12224312],
[ 0.5760367 , 0.8174238 , -0.91244936]], dtype=float32), {})
"""
def __init__(
self,
env_fns: Iterator[Callable[[], Env]],
copy: bool = True,
):
"""Vectorized environment that serially runs multiple environments.
Args:
env_fns: iterable of callable functions that create the environments.
copy: If ``True``, then the :meth:`reset` and :meth:`step` methods return a copy of the observations.
Raises:
RuntimeError: If the observation space of some sub-environment does not match observation_space
(or, by default, the observation space of the first sub-environment).
"""
super().__init__()
self.env_fns = env_fns
self.envs = [env_fn() for env_fn in env_fns]
self.num_envs = len(self.envs)
self.copy = copy
self.metadata = self.envs[0].metadata
self.spec = self.envs[0].spec
self.single_observation_space = self.envs[0].observation_space
self.single_action_space = self.envs[0].action_space
self.observation_space = batch_space(
self.single_observation_space, self.num_envs
)
self.action_space = batch_space(self.single_action_space, self.num_envs)
self._check_spaces()
self.observations = create_empty_array(
self.single_observation_space, n=self.num_envs, fn=np.zeros
)
self._rewards = np.zeros((self.num_envs,), dtype=np.float64)
self._terminateds = np.zeros((self.num_envs,), dtype=np.bool_)
self._truncateds = np.zeros((self.num_envs,), dtype=np.bool_)
def reset(
self,
seed: int | list[int] | None = None,
options: dict | None = None,
):
"""Waits for the calls triggered by :meth:`reset_async` to finish and returns the results.
Args:
seed: The reset environment seed
options: Option information for the environment reset
Returns:
The reset observation of the environment and reset information
"""
if seed is None:
seed = [None for _ in range(self.num_envs)]
if isinstance(seed, int):
seed = [seed + i for i in range(self.num_envs)]
assert len(seed) == self.num_envs
self._terminateds[:] = False
self._truncateds[:] = False
observations = []
infos = {}
for i, (env, single_seed) in enumerate(zip(self.envs, seed)):
kwargs = {}
if single_seed is not None:
kwargs["seed"] = single_seed
if options is not None:
kwargs["options"] = options
observation, info = env.reset(**kwargs)
observations.append(observation)
infos = self._add_info(infos, info, i)
self.observations = concatenate(
self.single_observation_space, observations, self.observations
)
return (deepcopy(self.observations) if self.copy else self.observations), infos
def step(self, actions):
"""Steps through each of the environments returning the batched results.
Returns:
The batched environment step results
"""
actions = iterate(self.action_space, actions)
observations, infos = [], {}
for i, (env, action) in enumerate(zip(self.envs, actions)):
(
observation,
self._rewards[i],
self._terminateds[i],
self._truncateds[i],
info,
) = env.step(action)
if self._terminateds[i] or self._truncateds[i]:
old_observation, old_info = observation, info
observation, info = env.reset()
info["final_observation"] = old_observation
info["final_info"] = old_info
observations.append(observation)
infos = self._add_info(infos, info, i)
self.observations = concatenate(
self.single_observation_space, observations, self.observations
)
return (
deepcopy(self.observations) if self.copy else self.observations,
np.copy(self._rewards),
np.copy(self._terminateds),
np.copy(self._truncateds),
infos,
)
def call(self, name, *args, **kwargs) -> tuple:
"""Calls the method with name and applies args and kwargs.
Args:
name: The method name
*args: The method args
**kwargs: The method kwargs
Returns:
Tuple of results
"""
results = []
for env in self.envs:
function = getattr(env, name)
if callable(function):
results.append(function(*args, **kwargs))
else:
results.append(function)
return tuple(results)
def get_attr(self, name: str):
"""Get a property from each parallel environment.
Args:
name (str): Name of the property to be get from each individual environment.
Returns:
The property with name
"""
return self.call(name)
def set_attr(self, name: str, values: list | tuple | Any):
"""Sets an attribute of the sub-environments.
Args:
name: The property name to change
values: Values of the property to be set to. If ``values`` is a list or
tuple, then it corresponds to the values for each individual
environment, otherwise, a single value is set for all environments.
Raises:
ValueError: Values must be a list or tuple with length equal to the number of environments.
"""
if not isinstance(values, (list, tuple)):
values = [values for _ in range(self.num_envs)]
if len(values) != self.num_envs:
raise ValueError(
"Values must be a list or tuple with length equal to the "
f"number of environments. Got `{len(values)}` values for "
f"{self.num_envs} environments."
)
for env, value in zip(self.envs, values):
setattr(env, name, value)
def close_extras(self, **kwargs):
"""Close the environments."""
[env.close() for env in self.envs]
def _check_spaces(self) -> bool:
for env in self.envs:
if not (env.observation_space == self.single_observation_space):
raise RuntimeError(
"Some environments have an observation space different from "
f"`{self.single_observation_space}`. In order to batch observations, "
"the observation spaces from all environments must be equal."
)
if not (env.action_space == self.single_action_space):
raise RuntimeError(
"Some environments have an action space different from "
f"`{self.single_action_space}`. In order to batch actions, the "
"action spaces from all environments must be equal."
)
return True


@@ -1,30 +0,0 @@
"""Module for gymnasium experimental vector utility functions."""
from gymnasium.experimental.vector.utils.misc import (
CloudpickleWrapper,
clear_mpi_env_vars,
)
from gymnasium.experimental.vector.utils.shared_memory import (
create_shared_memory,
read_from_shared_memory,
write_to_shared_memory,
)
from gymnasium.experimental.vector.utils.space_utils import (
batch_space,
concatenate,
create_empty_array,
iterate,
)
__all__ = [
"batch_space",
"iterate",
"concatenate",
"create_empty_array",
"create_shared_memory",
"read_from_shared_memory",
"write_to_shared_memory",
"CloudpickleWrapper",
"clear_mpi_env_vars",
]


@@ -1,61 +0,0 @@
"""Miscellaneous utilities."""
from __future__ import annotations
import contextlib
import os
from collections.abc import Callable
from gymnasium.core import Env
__all__ = ["CloudpickleWrapper", "clear_mpi_env_vars"]
class CloudpickleWrapper:
"""Wrapper that uses cloudpickle to pickle and unpickle the result."""
def __init__(self, fn: Callable[[], Env]):
"""Cloudpickle wrapper for a function."""
self.fn = fn
def __getstate__(self):
"""Get the state using `cloudpickle.dumps(self.fn)`."""
import cloudpickle
return cloudpickle.dumps(self.fn)
def __setstate__(self, ob):
"""Sets the state with obs."""
import pickle
self.fn = pickle.loads(ob)
def __call__(self):
"""Calls the function `self.fn` with no arguments."""
return self.fn()
@contextlib.contextmanager
def clear_mpi_env_vars():
"""Clears the MPI of environment variables.
`from mpi4py import MPI` will call `MPI_Init` by default.
If the child process has MPI environment variables, MPI will think that the child process
is an MPI process just like the parent and do bad things such as hang.
This context manager is a hacky way to clear those environment variables
temporarily such as when we are starting multiprocessing Processes.
Yields:
Yields for the context manager
"""
removed_environment = {}
for k, v in list(os.environ.items()):
for prefix in ["OMPI_", "PMI_"]:
if k.startswith(prefix):
removed_environment[k] = v
del os.environ[k]
try:
yield
finally:
os.environ.update(removed_environment)
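For context, a sketch of why `CloudpickleWrapper` exists: plain `pickle` cannot serialize the lambda env factories handed to the vector envs, while cloudpickle can. This assumes the `cloudpickle` package is installed and uses the stable import path the class moved to in this commit:

import pickle

import gymnasium as gym
from gymnasium.vector.utils import CloudpickleWrapper

wrapped = CloudpickleWrapper(lambda: gym.make("CartPole-v1"))
payload = pickle.dumps(wrapped)   # __getstate__ defers to cloudpickle.dumps
restored = pickle.loads(payload)  # __setstate__ restores the factory via pickle.loads
env = restored()                  # calling the wrapper calls the wrapped factory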


@@ -1,255 +0,0 @@
"""Utility functions for vector environments to share memory between processes."""
from __future__ import annotations
import multiprocessing as mp
from collections import OrderedDict
from ctypes import c_bool
from functools import singledispatch
from typing import Any
import numpy as np
from gymnasium.error import CustomSpaceError
from gymnasium.spaces import (
Box,
Dict,
Discrete,
Graph,
MultiBinary,
MultiDiscrete,
Sequence,
Space,
Text,
Tuple,
flatten,
)
__all__ = ["create_shared_memory", "read_from_shared_memory", "write_to_shared_memory"]
@singledispatch
def create_shared_memory(
space: Space[Any], n: int = 1, ctx=mp
) -> dict[str, Any] | tuple[Any, ...] | mp.Array:
"""Create a shared memory object, to be shared across processes.
This eventually contains the observations from the vectorized environment.
Args:
space: Observation space of a single environment in the vectorized environment.
n: Number of environments in the vectorized environment (i.e. the number of processes).
ctx: The multiprocess module
Returns:
shared_memory for the shared object across processes.
Raises:
CustomSpaceError: Space is not a valid :class:`gymnasium.Space` instance
"""
if isinstance(space, Space):
raise CustomSpaceError(
f"Space of type `{type(space)}` doesn't have an registered `create_shared_memory` function. Register `{type(space)}` for `create_shared_memory` to support it."
)
else:
raise TypeError(
f"The space provided to `create_shared_memory` is not a gymnasium Space instance, type: {type(space)}, {space}"
)
@create_shared_memory.register(Box)
@create_shared_memory.register(Discrete)
@create_shared_memory.register(MultiDiscrete)
@create_shared_memory.register(MultiBinary)
def _create_base_shared_memory(
space: Box | Discrete | MultiDiscrete | MultiBinary, n: int = 1, ctx=mp
):
assert space.dtype is not None
dtype = space.dtype.char
if dtype in "?":
dtype = c_bool
return ctx.Array(dtype, n * int(np.prod(space.shape)))
@create_shared_memory.register(Tuple)
def _create_tuple_shared_memory(space: Tuple, n: int = 1, ctx=mp):
return tuple(
create_shared_memory(subspace, n=n, ctx=ctx) for subspace in space.spaces
)
@create_shared_memory.register(Dict)
def _create_dict_shared_memory(space: Dict, n: int = 1, ctx=mp):
return OrderedDict(
[
(key, create_shared_memory(subspace, n=n, ctx=ctx))
for (key, subspace) in space.spaces.items()
]
)
@create_shared_memory.register(Text)
def _create_text_shared_memory(space: Text, n: int = 1, ctx=mp):
return ctx.Array(np.dtype(np.int32).char, n * space.max_length)
@create_shared_memory.register(Graph)
@create_shared_memory.register(Sequence)
def _create_dynamic_shared_memory(space: Graph | Sequence, n: int = 1, ctx=mp):
raise TypeError(
f"As {space} has a dynamic shape then it is not possible to make a static shared memory."
)
@singledispatch
def read_from_shared_memory(
space: Space, shared_memory: dict | tuple | mp.Array, n: int = 1
) -> dict[str, Any] | tuple[Any, ...] | np.ndarray:
"""Read the batch of observations from shared memory as a numpy array.
.. note::
The numpy array objects returned by `read_from_shared_memory` share the
memory of `shared_memory`. Any changes to `shared_memory` are forwarded
to `observations`, and vice-versa. To avoid any side-effect, use `np.copy`.
Args:
space: Observation space of a single environment in the vectorized environment.
shared_memory: Shared object across processes. This contains the observations from the vectorized environment.
This object is created with `create_shared_memory`.
n: Number of environments in the vectorized environment (i.e. the number of processes).
Returns:
Batch of observations as a (possibly nested) numpy array.
Raises:
CustomSpaceError: Space is not a valid :class:`gymnasium.Space` instance
"""
if isinstance(space, Space):
raise CustomSpaceError(
f"Space of type `{type(space)}` doesn't have an registered `read_from_shared_memory` function. Register `{type(space)}` for `read_from_shared_memory` to support it."
)
else:
raise TypeError(
f"The space provided to `read_from_shared_memory` is not a gymnasium Space instance, type: {type(space)}, {space}"
)
@read_from_shared_memory.register(Box)
@read_from_shared_memory.register(Discrete)
@read_from_shared_memory.register(MultiDiscrete)
@read_from_shared_memory.register(MultiBinary)
def _read_base_from_shared_memory(
space: Box | Discrete | MultiDiscrete | MultiBinary, shared_memory, n: int = 1
):
return np.frombuffer(shared_memory.get_obj(), dtype=space.dtype).reshape(
(n,) + space.shape
)
@read_from_shared_memory.register(Tuple)
def _read_tuple_from_shared_memory(space: Tuple, shared_memory, n: int = 1):
return tuple(
read_from_shared_memory(subspace, memory, n=n)
for (memory, subspace) in zip(shared_memory, space.spaces)
)
@read_from_shared_memory.register(Dict)
def _read_dict_from_shared_memory(space: Dict, shared_memory, n: int = 1):
return OrderedDict(
[
(key, read_from_shared_memory(subspace, shared_memory[key], n=n))
for (key, subspace) in space.spaces.items()
]
)
@read_from_shared_memory.register(Text)
def _read_text_from_shared_memory(space: Text, shared_memory, n: int = 1) -> tuple[str, ...]:
data = np.frombuffer(shared_memory.get_obj(), dtype=np.int32).reshape(
(n, space.max_length)
)
return tuple(
"".join(
[
space.character_list[val]
for val in values
if val < len(space.character_set)
]
)
for values in data
)
@singledispatch
def write_to_shared_memory(
space: Space,
index: int,
value: np.ndarray,
shared_memory: dict[str, Any] | tuple[Any, ...] | mp.Array,
):
"""Write the observation of a single environment into shared memory.
Args:
space: Observation space of a single environment in the vectorized environment.
index: Index of the environment (must be in `[0, num_envs)`).
value: Observation of the single environment to write to shared memory.
shared_memory: Shared object across processes. This contains the observations from the vectorized environment.
This object is created with `create_shared_memory`.
Raises:
CustomSpaceError: Space is not a valid :class:`gymnasium.Space` instance
"""
if isinstance(space, Space):
raise CustomSpaceError(
f"Space of type `{type(space)}` doesn't have an registered `write_to_shared_memory` function. Register `{type(space)}` for `write_to_shared_memory` to support it."
)
else:
raise TypeError(
f"The space provided to `write_to_shared_memory` is not a gymnasium Space instance, type: {type(space)}, {space}"
)
@write_to_shared_memory.register(Box)
@write_to_shared_memory.register(Discrete)
@write_to_shared_memory.register(MultiDiscrete)
@write_to_shared_memory.register(MultiBinary)
def _write_base_to_shared_memory(
space: Box | Discrete | MultiDiscrete | MultiBinary,
index: int,
value,
shared_memory,
):
size = int(np.prod(space.shape))
destination = np.frombuffer(shared_memory.get_obj(), dtype=space.dtype)
np.copyto(
destination[index * size : (index + 1) * size],
np.asarray(value, dtype=space.dtype).flatten(),
)
@write_to_shared_memory.register(Tuple)
def _write_tuple_to_shared_memory(
space: Tuple, index: int, values: tuple[Any, ...], shared_memory
):
for value, memory, subspace in zip(values, shared_memory, space.spaces):
write_to_shared_memory(subspace, index, value, memory)
@write_to_shared_memory.register(Dict)
def _write_dict_to_shared_memory(
space: Dict, index: int, values: dict[str, Any], shared_memory
):
for key, subspace in space.spaces.items():
write_to_shared_memory(subspace, index, values[key], shared_memory[key])
@write_to_shared_memory.register(Text)
def _write_text_to_shared_memory(space: Text, index: int, values: str, shared_memory):
size = space.max_length
destination = np.frombuffer(shared_memory.get_obj(), dtype=np.int32)
np.copyto(
destination[index * size : (index + 1) * size],
flatten(space, values),
)
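A round-trip sketch of the three functions above (illustrative, not part of the deleted file): allocate shared memory for two environments, write one observation, and read the whole batch back as a numpy view.

import numpy as np
from gymnasium.spaces import Box

space = Box(low=0, high=255, shape=(2, 2), dtype=np.uint8)
shm = create_shared_memory(space, n=2)
write_to_shared_memory(space, 0, space.sample(), shm)   # write env 0's observation
batch = read_from_shared_memory(space, shm, n=2)        # shared view, shape (2, 2, 2)
assert batch.shape == (2,) + space.shape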


@@ -1,486 +0,0 @@
"""Base class for vectorized environments."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Generic, TypeVar
import numpy as np
import gymnasium as gym
from gymnasium.core import ActType, ObsType
from gymnasium.utils import seeding
if TYPE_CHECKING:
from gymnasium.envs.registration import EnvSpec
ArrayType = TypeVar("ArrayType")
__all__ = [
"VectorEnv",
"VectorWrapper",
"VectorObservationWrapper",
"VectorActionWrapper",
"VectorRewardWrapper",
"ArrayType",
]
class VectorEnv(Generic[ObsType, ActType, ArrayType]):
"""Base class for vectorized environments to run multiple independent copies of the same environment in parallel.
Vector environments can provide a linear speed-up in the steps taken per second by sampling multiple
sub-environments at the same time. To prevent terminated environments from waiting until all sub-environments have
terminated or truncated, the vector environments autoreset sub-environments after they terminate or truncate.
As a result, the final step's observation and info are overwritten by the reset's observation and info.
The observation and info for the final step of a sub-environment are therefore stored in the info parameter,
using `"final_observation"` and `"final_info"` respectively. See :meth:`step` for more information.
The vector environments batch `observations`, `rewards`, `terminations`, `truncations` and `info` for each
parallel environment. In addition, :meth:`step` expects to receive a batch of actions for each parallel environment.
Gymnasium contains two types of Vector environments: :class:`AsyncVectorEnv` and :class:`SyncVectorEnv`.
The Vector Environments have the additional attributes for users to understand the implementation
- :attr:`num_envs` - The number of sub-environments in the vector environment
- :attr:`observation_space` - The batched observation space of the vector environment
- :attr:`single_observation_space` - The observation space of a single sub-environment
- :attr:`action_space` - The batched action space of the vector environment
- :attr:`single_action_space` - The action space of a single sub-environment
Note:
Prior to OpenAI Gym v25, the info parameter of :meth:`reset` and :meth:`step` was a list of dictionaries, one
for each sub-environment. In OpenAI Gym v25+ and in Gymnasium, this was changed to a single dictionary with a
NumPy array for each key. To use the old info style, use the :class:`VectorListInfo` wrapper.
Note:
To render the sub-environments, use :meth:`call` with "render" arguments. Remember to set the `render_mode`
for all the sub-environments during initialization.
Note:
All parallel environments should share the identical observation and action spaces.
In other words, a vector of multiple different environments is not supported.
"""
spec: EnvSpec
observation_space: gym.Space
action_space: gym.Space
single_observation_space: gym.Space
single_action_space: gym.Space
num_envs: int
closed = False
_np_random: np.random.Generator | None = None
def reset(
self,
*,
seed: int | list[int] | None = None,
options: dict[str, Any] | None = None,
) -> tuple[ObsType, dict[str, Any]]: # type: ignore
"""Reset all parallel environments and return a batch of initial observations and info.
Args:
seed: The environment reset seeds
options: The reset options forwarded to each sub-environment
Returns:
A batch of observations and info from the vectorized environment.
Example:
>>> import gymnasium as gym
>>> envs = gym.vector.make("CartPole-v1", num_envs=3)
>>> envs.reset(seed=42)
(array([[ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 ],
[ 0.01522993, -0.04562247, -0.04799704, 0.03392126],
[-0.03774345, -0.02418869, -0.00942293, 0.0469184 ]],
dtype=float32), {})
"""
if seed is not None:
self._np_random, seed = seeding.np_random(seed)
def step(
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
"""Take an action for each parallel environment.
Args:
actions: Batch of actions, an element of :attr:`action_space`.
Returns:
Batch of (observations, rewards, terminations, truncations, infos)
Note:
As the vector environments autoreset terminated and truncated sub-environments,
the returned observation and info are not the final step's observation or info; those are instead stored in
info as `"final_observation"` and `"final_info"`.
Example:
>>> import gymnasium as gym
>>> import numpy as np
>>> envs = gym.vector.make("CartPole-v1", num_envs=3)
>>> _ = envs.reset(seed=42)
>>> actions = np.array([1, 0, 1])
>>> observations, rewards, termination, truncation, infos = envs.step(actions)
>>> observations
array([[ 0.02727336, 0.18847767, 0.03625453, -0.26141977],
[ 0.01431748, -0.24002443, -0.04731862, 0.3110827 ],
[-0.03822722, 0.1710671 , -0.00848456, -0.2487226 ]],
dtype=float32)
>>> rewards
array([1., 1., 1.])
>>> termination
array([False, False, False])
>>> truncation
array([False, False, False])
>>> infos
{}
"""
pass
def close_extras(self, **kwargs):
"""Clean up the extra resources e.g. beyond what's in this base class."""
pass
def close(self, **kwargs):
"""Close all parallel environments and release resources.
It also closes all the existing image viewers, then calls :meth:`close_extras` and sets
:attr:`closed` to ``True``.
Warnings:
This function itself does not close the environments, it should be handled
in :meth:`close_extras`. This is generic for both synchronous and asynchronous
vectorized environments.
Note:
This will be called automatically when the environment is garbage collected or the program exits.
Args:
**kwargs: Keyword arguments passed to :meth:`close_extras`
"""
if self.closed:
return
self.close_extras(**kwargs)
self.closed = True
@property
def np_random(self) -> np.random.Generator:
"""Returns the environment's internal :attr:`_np_random` that if not set will initialise with a random seed.
Returns:
Instances of `np.random.Generator`
"""
if self._np_random is None:
self._np_random, seed = seeding.np_random()
return self._np_random
@np_random.setter
def np_random(self, value: np.random.Generator):
self._np_random = value
@property
def unwrapped(self):
"""Return the base environment."""
return self
def _add_info(self, infos: dict, info: dict, env_num: int) -> dict:
"""Add env info to the info dictionary of the vectorized environment.
Given the `info` of a single environment add it to the `infos` dictionary
which represents all the infos of the vectorized environment.
Every `key` of `info` is paired with a boolean mask `_key` representing
whether or not the i-indexed environment has this `info`.
Args:
infos (dict): the infos of the vectorized environment
info (dict): the info coming from the single environment
env_num (int): the index of the single environment
Returns:
infos (dict): the (updated) infos of the vectorized environment
"""
for k in info.keys():
if k not in infos:
info_array, array_mask = self._init_info_arrays(type(info[k]))
else:
info_array, array_mask = infos[k], infos[f"_{k}"]
info_array[env_num], array_mask[env_num] = info[k], True
infos[k], infos[f"_{k}"] = info_array, array_mask
return infos
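# Illustrative sketch (comment only): with num_envs == 3, if only sub-environment
# 1 reports info {"lives": 2}, `_add_info` produces
#     infos == {"lives": array([0, 2, 0]), "_lives": array([False, True, False])}
# so the boolean mask `_lives` marks which entries of `lives` are valid.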
def _init_info_arrays(self, dtype: type) -> tuple[np.ndarray, np.ndarray]:
"""Initialize the info array.
Initialize the info array. If the dtype is numeric
the info array will have the same dtype, otherwise
will be an array of `None`. Also, a boolean array
of the same length is returned. It will be used for
assessing which environment has info data.
Args:
dtype (type): data type of the info coming from the env.
Returns:
array (np.ndarray): the initialized info array.
array_mask (np.ndarray): the initialized boolean array.
"""
if dtype in [int, float, bool] or issubclass(dtype, np.number):
array = np.zeros(self.num_envs, dtype=dtype)
else:
array = np.zeros(self.num_envs, dtype=object)
array[:] = None
array_mask = np.zeros(self.num_envs, dtype=bool)
return array, array_mask
def __del__(self):
"""Closes the vector environment."""
if not getattr(self, "closed", True):
self.close()
def __repr__(self) -> str:
"""Returns a string representation of the vector environment.
Returns:
A string containing the class name, number of environments and environment spec id
"""
if getattr(self, "spec", None) is None:
return f"{self.__class__.__name__}({self.num_envs})"
else:
return f"{self.__class__.__name__}({self.spec.id}, {self.num_envs})"
class VectorWrapper(VectorEnv):
"""Wraps the vectorized environment to allow a modular transformation.
This class is the base class for all wrappers for vectorized environments. The subclass
could override some methods to change the behavior of the original vectorized environment
without touching the original code.
Note:
Don't forget to call ``super().__init__(env)`` if the subclass overrides :meth:`__init__`.
"""
_observation_space: gym.Space | None = None
_action_space: gym.Space | None = None
_single_observation_space: gym.Space | None = None
_single_action_space: gym.Space | None = None
def __init__(self, env: VectorEnv):
"""Initialize the vectorized environment wrapper."""
super().__init__()
assert isinstance(env, VectorEnv)
self.env = env
# explicitly forward the methods defined in VectorEnv
# to self.env (instead of the base class)
def reset(
self,
*,
seed: int | list[int] | None = None,
options: dict[str, Any] | None = None,
) -> tuple[ObsType, dict[str, Any]]:
"""Reset all environment using seed and options."""
return self.env.reset(seed=seed, options=options)
def step(
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
"""Step all environments."""
return self.env.step(actions)
def close(self, **kwargs: Any):
"""Close all environments."""
return self.env.close(**kwargs)
def close_extras(self, **kwargs: Any):
"""Close all extra resources."""
return self.env.close_extras(**kwargs)
# implicitly forward all other methods and attributes to self.env
def __getattr__(self, name: str) -> Any:
"""Forward all other attributes to the base environment."""
if name.startswith("_"):
raise AttributeError(f"attempted to get missing private attribute '{name}'")
return getattr(self.env, name)
@property
def unwrapped(self):
"""Return the base non-wrapped environment."""
return self.env.unwrapped
def __repr__(self):
"""Return the string representation of the vectorized environment."""
return f"<{self.__class__.__name__}, {self.env}>"
def __del__(self):
"""Close the vectorized environment."""
self.env.__del__()
@property
def spec(self) -> EnvSpec | None:
"""Gets the specification of the wrapped environment."""
return self.env.spec
@property
def observation_space(self) -> gym.Space:
"""Gets the observation space of the vector environment."""
if self._observation_space is None:
return self.env.observation_space
return self._observation_space
@observation_space.setter
def observation_space(self, space: gym.Space):
"""Sets the observation space of the vector environment."""
self._observation_space = space
@property
def action_space(self) -> gym.Space:
"""Gets the action space of the vector environment."""
if self._action_space is None:
return self.env.action_space
return self._action_space
@action_space.setter
def action_space(self, space: gym.Space):
"""Sets the action space of the vector environment."""
self._action_space = space
@property
def single_observation_space(self) -> gym.Space:
"""Gets the single observation space of the vector environment."""
if self._single_observation_space is None:
return self.env.single_observation_space
return self._single_observation_space
@single_observation_space.setter
def single_observation_space(self, space: gym.Space):
"""Sets the single observation space of the vector environment."""
self._single_observation_space = space
@property
def single_action_space(self) -> gym.Space:
"""Gets the single action space of the vector environment."""
if self._single_action_space is None:
return self.env.single_action_space
return self._single_action_space
@single_action_space.setter
def single_action_space(self, space):
"""Sets the single action space of the vector environment."""
self._single_action_space = space
@property
def num_envs(self) -> int:
"""Gets the wrapped vector environment's num of the sub-environments."""
return self.env.num_envs
class VectorObservationWrapper(VectorWrapper):
"""Wraps the vectorized environment to allow a modular transformation of the observation. Equivalent to :class:`gym.ObservationWrapper` for vectorized environments."""
def reset(
self,
*,
seed: int | list[int] | None = None,
options: dict[str, Any] | None = None,
) -> tuple[ObsType, dict[str, Any]]:
"""Modifies the observation returned from the environment ``reset`` using the :meth:`observation`."""
obs, info = self.env.reset(seed=seed, options=options)
return self.vector_observation(obs), info
def step(
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
"""Modifies the observation returned from the environment ``step`` using the :meth:`observation`."""
observation, reward, termination, truncation, info = self.env.step(actions)
return (
self.vector_observation(observation),
reward,
termination,
truncation,
self.update_final_obs(info),
)
def vector_observation(self, observation: ObsType) -> ObsType:
"""Defines the vector observation transformation.
Args:
observation: A vector observation from the environment
Returns:
the transformed observation
"""
raise NotImplementedError
def single_observation(self, observation: ObsType) -> ObsType:
"""Defines the single observation transformation.
Args:
observation: A single observation from the environment
Returns:
The transformed observation
"""
raise NotImplementedError
def update_final_obs(self, info: dict[str, Any]) -> dict[str, Any]:
"""Updates the `final_obs` in the info using `single_observation`."""
if "final_observation" in info:
for i, obs in enumerate(info["final_observation"]):
if obs is not None:
info["final_observation"][i] = self.single_observation(obs)
return info
class VectorActionWrapper(VectorWrapper):
"""Wraps the vectorized environment to allow a modular transformation of the actions. Equivalent of :class:`~gym.ActionWrapper` for vectorized environments."""
def step(
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
"""Steps through the environment using a modified action by :meth:`action`."""
return self.env.step(self.actions(actions))
def actions(self, actions: ActType) -> ActType:
"""Transform the actions before sending them to the environment.
Args:
actions (ActType): the actions to transform
Returns:
ActType: the transformed actions
"""
raise NotImplementedError
class VectorRewardWrapper(VectorWrapper):
"""Wraps the vectorized environment to allow a modular transformation of the reward. Equivalent of :class:`~gym.RewardWrapper` for vectorized environments."""
def step(
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
"""Steps through the environment returning a reward modified by :meth:`reward`."""
observation, reward, termination, truncation, info = self.env.step(actions)
return observation, self.reward(reward), termination, truncation, info
def reward(self, reward: ArrayType) -> ArrayType:
"""Transform the reward before returning it.
Args:
reward (array): the reward to transform
Returns:
array: the transformed reward
"""
raise NotImplementedError
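A sketch of the smallest useful subclass (illustrative, not part of the deleted file): a `VectorRewardWrapper` that scales every sub-environment's reward.

import numpy as np

class ScaleReward(VectorRewardWrapper):
    def __init__(self, env: VectorEnv, scale: float = 0.1):
        super().__init__(env)
        self.scale = scale

    def reward(self, reward):
        # rewards arrive as a batched array, one entry per sub-environment
        return np.asarray(reward) * self.scale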


@@ -1,164 +0,0 @@
"""`__init__` for experimental wrappers, to avoid loading the wrappers if unnecessary, we can hack python."""
# pyright: reportUnsupportedDunderAll=false
import importlib
import re
from gymnasium.error import DeprecatedWrapper
from gymnasium.experimental.wrappers import vector
from gymnasium.experimental.wrappers.atari_preprocessing import AtariPreprocessingV0
from gymnasium.experimental.wrappers.common import (
AutoresetV0,
OrderEnforcingV0,
PassiveEnvCheckerV0,
RecordEpisodeStatisticsV0,
)
from gymnasium.experimental.wrappers.lambda_action import (
ClipActionV0,
LambdaActionV0,
RescaleActionV0,
)
from gymnasium.experimental.wrappers.lambda_observation import (
DtypeObservationV0,
FilterObservationV0,
FlattenObservationV0,
GrayscaleObservationV0,
LambdaObservationV0,
PixelObservationV0,
RescaleObservationV0,
ReshapeObservationV0,
ResizeObservationV0,
)
from gymnasium.experimental.wrappers.lambda_reward import ClipRewardV0, LambdaRewardV0
from gymnasium.experimental.wrappers.rendering import (
HumanRenderingV0,
RecordVideoV0,
RenderCollectionV0,
)
from gymnasium.experimental.wrappers.stateful_action import StickyActionV0
from gymnasium.experimental.wrappers.stateful_observation import (
DelayObservationV0,
FrameStackObservationV0,
MaxAndSkipObservationV0,
NormalizeObservationV0,
TimeAwareObservationV0,
)
from gymnasium.experimental.wrappers.stateful_reward import NormalizeRewardV1
# Todo - Add legacy wrapper to new wrapper error for users when merged into gymnasium.wrappers
__all__ = [
"vector",
# --- Observation wrappers ---
"AtariPreprocessingV0",
"DelayObservationV0",
"DtypeObservationV0",
"FilterObservationV0",
"FlattenObservationV0",
"FrameStackObservationV0",
"GrayscaleObservationV0",
"LambdaObservationV0",
"MaxAndSkipObservationV0",
"NormalizeObservationV0",
"PixelObservationV0",
"ResizeObservationV0",
"ReshapeObservationV0",
"RescaleObservationV0",
"TimeAwareObservationV0",
# --- Action Wrappers ---
"ClipActionV0",
"LambdaActionV0",
"RescaleActionV0",
# "NanAction",
"StickyActionV0",
# --- Reward wrappers ---
"ClipRewardV0",
"LambdaRewardV0",
"NormalizeRewardV1",
# --- Common ---
"AutoresetV0",
"PassiveEnvCheckerV0",
"OrderEnforcingV0",
"RecordEpisodeStatisticsV0",
# --- Rendering ---
"RenderCollectionV0",
"RecordVideoV0",
"HumanRenderingV0",
# --- Conversion ---
"JaxToNumpyV0",
"JaxToTorchV0",
"NumpyToTorchV0",
]
# As these wrappers requires `jax` or `torch`, they are loaded by runtime for users trying to access them
# to avoid `import jax` or `import torch` on `import gymnasium`.
_wrapper_to_class = {
# data converters
"JaxToNumpyV0": "jax_to_numpy",
"JaxToTorchV0": "jax_to_torch",
"NumpyToTorchV0": "numpy_to_torch",
}
def __getattr__(wrapper_name: str):
"""Load a wrapper by name.
This optimizes the loading of gymnasium wrappers by only loading the wrapper if it is used.
Errors will be raised if the wrapper does not exist or if the version is not the latest.
Args:
wrapper_name: The name of a wrapper to load.
Returns:
The specified wrapper.
Raises:
AttributeError: If the wrapper does not exist.
DeprecatedWrapper: If the version is not the latest.
"""
# Check if the requested wrapper is in the _wrapper_to_class dictionary
if wrapper_name in _wrapper_to_class:
import_stmt = (
f"gymnasium.experimental.wrappers.{_wrapper_to_class[wrapper_name]}"
)
module = importlib.import_module(import_stmt)
return getattr(module, wrapper_name)
# Define a regex pattern to match the integer suffix (version number) of the wrapper
int_suffix_pattern = r"(\d+)$"
version_match = re.search(int_suffix_pattern, wrapper_name)
# If a version number is found, extract it and the base wrapper name
if version_match:
version = int(version_match.group())
base_name = wrapper_name[: -len(version_match.group())]
else:
version = float("inf")
base_name = wrapper_name[:-2]
# Filter the list of all wrappers to include only those with the same base name
matching_wrappers = [name for name in __all__ if name.startswith(base_name)]
# If no matching wrappers are found, raise an AttributeError
if not matching_wrappers:
raise AttributeError(f"module {__name__!r} has no attribute {wrapper_name!r}")
# Find the latest version of the matching wrappers
latest_wrapper = max(
matching_wrappers, key=lambda s: int(re.findall(int_suffix_pattern, s)[0])
)
latest_version = int(re.findall(int_suffix_pattern, latest_wrapper)[0])
# If the requested wrapper is an older version, raise a DeprecatedWrapper exception
if version < latest_version:
raise DeprecatedWrapper(
f"{wrapper_name!r} is now deprecated, use {latest_wrapper!r} instead.\n"
f"To see the changes made, go to "
f"https://gymnasium.farama.org/api/experimental/wrappers/#gymnasium.experimental.wrappers.{latest_wrapper}"
)
# If the requested version is invalid, raise an AttributeError
else:
raise AttributeError(
f"module {__name__!r} has no attribute {wrapper_name!r}, did you mean {latest_wrapper!r}"
)
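A behavioural sketch of the lazy loading above (illustrative, not part of the deleted file); accessing `JaxToNumpyV0` assumes `jax` is installed:

import gymnasium.experimental.wrappers as wrappers
from gymnasium.error import DeprecatedWrapper

wrappers.JaxToNumpyV0  # imported lazily from `jax_to_numpy` on first access
try:
    wrappers.NormalizeRewardV0  # an older version than NormalizeRewardV1
except DeprecatedWrapper as err:
    print(err)  # points the user at NormalizeRewardV1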


@@ -1,206 +0,0 @@
"""Implementation of Atari 2600 Preprocessing following the guidelines of Machado et al., 2018."""
import numpy as np
import gymnasium as gym
from gymnasium.spaces import Box
__all__ = ["AtariPreprocessingV0"]
class AtariPreprocessingV0(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""Atari 2600 preprocessing wrapper.
This class follows the guidelines in Machado et al. (2018),
"Revisiting the Arcade Learning Environment: Evaluation Protocols and Open Problems for General Agents".
Specifically, the following preprocessing stages apply to the atari environment:
- Noop Reset: Obtains the initial state by taking a random number of no-ops on reset, default max 30 no-ops.
- Frame skipping: The number of frames skipped between steps, 4 by default
- Max-pooling: Pools over the most recent two observations from the frame skips
- Termination signal when a life is lost: When the agent loses a life during the environment, the environment terminates.
Turned off by default. Not recommended by Machado et al. (2018).
- Resize to a square image: Resizes the atari environment's original observation shape from 210x160 to 84x84 by default
- Grayscale observation: Whether the observation is returned in greyscale or colour; greyscale by default.
- Scale observation: Whether to scale the observation to the range [0, 1) instead of keeping it in [0, 255]; not scaled by default.
"""
def __init__(
self,
env: gym.Env,
noop_max: int = 30,
frame_skip: int = 4,
screen_size: int = 84,
terminal_on_life_loss: bool = False,
grayscale_obs: bool = True,
grayscale_newaxis: bool = False,
scale_obs: bool = False,
):
"""Wrapper for Atari 2600 preprocessing.
Args:
env (Env): The environment to apply the preprocessing
noop_max (int): For no-op reset, the maximum number of no-op actions taken on reset; set to 0 to turn off.
frame_skip (int): The number of frames between new observations, affecting the frequency at which the agent experiences the game.
screen_size (int): Resize the Atari frame to a square image of this size.
terminal_on_life_loss (bool): `if True`, then :meth:`step()` returns `terminated=True` whenever a
life is lost.
grayscale_obs (bool): if True, then gray scale observation is returned, otherwise, RGB observation
is returned.
grayscale_newaxis (bool): `if True and grayscale_obs=True`, then a channel axis is added to
grayscale observations to make them 3-dimensional.
scale_obs (bool): if True, then observations normalized to the range [0, 1) are returned. This also limits
the memory optimization benefits of the FrameStack wrapper.
Raises:
DependencyNotInstalled: opencv-python package not installed
ValueError: Disable frame-skipping in the original env
"""
gym.utils.RecordConstructorArgs.__init__(
self,
noop_max=noop_max,
frame_skip=frame_skip,
screen_size=screen_size,
terminal_on_life_loss=terminal_on_life_loss,
grayscale_obs=grayscale_obs,
grayscale_newaxis=grayscale_newaxis,
scale_obs=scale_obs,
)
gym.Wrapper.__init__(self, env)
try:
import cv2 # noqa: F401
except ImportError:
raise gym.error.DependencyNotInstalled(
"opencv-python package not installed, run `pip install gymnasium[other]` to get dependencies for atari"
)
assert frame_skip > 0
assert screen_size > 0
assert noop_max >= 0
if frame_skip > 1:
if (
env.spec is not None
and "NoFrameskip" not in env.spec.id
and getattr(env.unwrapped, "_frameskip", None) != 1
):
raise ValueError(
"Disable frame-skipping in the original env. Otherwise, more than one "
"frame-skip will happen as through this wrapper"
)
self.noop_max = noop_max
assert env.unwrapped.get_action_meanings()[0] == "NOOP"
self.frame_skip = frame_skip
self.screen_size = screen_size
self.terminal_on_life_loss = terminal_on_life_loss
self.grayscale_obs = grayscale_obs
self.grayscale_newaxis = grayscale_newaxis
self.scale_obs = scale_obs
# buffer of most recent two observations for max pooling
assert isinstance(env.observation_space, Box)
if grayscale_obs:
self.obs_buffer = [
np.empty(env.observation_space.shape[:2], dtype=np.uint8),
np.empty(env.observation_space.shape[:2], dtype=np.uint8),
]
else:
self.obs_buffer = [
np.empty(env.observation_space.shape, dtype=np.uint8),
np.empty(env.observation_space.shape, dtype=np.uint8),
]
self.lives = 0
self.game_over = False
_low, _high, _obs_dtype = (
(0, 255, np.uint8) if not scale_obs else (0, 1, np.float32)
)
_shape = (screen_size, screen_size, 1 if grayscale_obs else 3)
if grayscale_obs and not grayscale_newaxis:
_shape = _shape[:-1] # Remove channel axis
self.observation_space = Box(
low=_low, high=_high, shape=_shape, dtype=_obs_dtype
)
@property
def ale(self):
"""Make ale as a class property to avoid serialization error."""
return self.env.unwrapped.ale
def step(self, action):
"""Applies the preprocessing for an :meth:`env.step`."""
total_reward, terminated, truncated, info = 0.0, False, False, {}
for t in range(self.frame_skip):
_, reward, terminated, truncated, info = self.env.step(action)
total_reward += reward
self.game_over = terminated
if self.terminal_on_life_loss:
new_lives = self.ale.lives()
terminated = terminated or new_lives < self.lives
self.game_over = terminated
self.lives = new_lives
if terminated or truncated:
break
if t == self.frame_skip - 2:
if self.grayscale_obs:
self.ale.getScreenGrayscale(self.obs_buffer[1])
else:
self.ale.getScreenRGB(self.obs_buffer[1])
elif t == self.frame_skip - 1:
if self.grayscale_obs:
self.ale.getScreenGrayscale(self.obs_buffer[0])
else:
self.ale.getScreenRGB(self.obs_buffer[0])
return self._get_obs(), total_reward, terminated, truncated, info
def reset(self, **kwargs):
"""Resets the environment using preprocessing."""
# NoopReset
_, reset_info = self.env.reset(**kwargs)
noops = (
self.env.unwrapped.np_random.integers(1, self.noop_max + 1)
if self.noop_max > 0
else 0
)
for _ in range(noops):
_, _, terminated, truncated, step_info = self.env.step(0)
reset_info.update(step_info)
if terminated or truncated:
_, reset_info = self.env.reset(**kwargs)
self.lives = self.ale.lives()
if self.grayscale_obs:
self.ale.getScreenGrayscale(self.obs_buffer[0])
else:
self.ale.getScreenRGB(self.obs_buffer[0])
self.obs_buffer[1].fill(0)
return self._get_obs(), reset_info
def _get_obs(self):
if self.frame_skip > 1: # more efficient in-place pooling
np.maximum(self.obs_buffer[0], self.obs_buffer[1], out=self.obs_buffer[0])
import cv2
obs = cv2.resize(
self.obs_buffer[0],
(self.screen_size, self.screen_size),
interpolation=cv2.INTER_AREA,
)
if self.scale_obs:
obs = np.asarray(obs, dtype=np.float32) / 255.0
else:
obs = np.asarray(obs, dtype=np.uint8)
if self.grayscale_obs and self.grayscale_newaxis:
obs = np.expand_dims(obs, axis=-1) # Add a channel axis
return obs
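A usage sketch (illustrative, not part of the deleted file); assumes `ale-py`'s Atari environments are registered and `opencv-python` is installed:

import gymnasium as gym
from gymnasium.experimental.wrappers import AtariPreprocessingV0

env = gym.make("BreakoutNoFrameskip-v4")
env = AtariPreprocessingV0(env, frame_skip=4, grayscale_obs=True)
obs, info = env.reset(seed=0)
assert obs.shape == (84, 84)  # max-pooled, resized, grayscale frame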


@@ -1,315 +0,0 @@
"""A collection of common wrappers.
* ``AutoresetV0`` - Auto-resets the environment
* ``PassiveEnvCheckerV0`` - Passive environment checker that does not modify any environment data
* ``OrderEnforcingV0`` - Enforces the order of function calls to environments
* ``RecordEpisodeStatisticsV0`` - Records the episode statistics
"""
from __future__ import annotations
import time
from collections import deque
from typing import Any, SupportsFloat
import numpy as np
import gymnasium as gym
from gymnasium.core import ActType, ObsType, RenderFrame
from gymnasium.error import ResetNeeded
from gymnasium.utils.passive_env_checker import (
check_action_space,
check_observation_space,
env_render_passive_checker,
env_reset_passive_checker,
env_step_passive_checker,
)
__all__ = [
"AutoresetV0",
"PassiveEnvCheckerV0",
"OrderEnforcingV0",
"RecordEpisodeStatisticsV0",
]
class AutoresetV0(
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
"""A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`."""
def __init__(self, env: gym.Env[ObsType, ActType]):
"""A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`.
Args:
env (gym.Env): The environment to apply the wrapper
"""
gym.utils.RecordConstructorArgs.__init__(self)
gym.Wrapper.__init__(self, env)
self._episode_ended: bool = False
self._reset_options: dict[str, Any] | None = None
def step(
self, action: ActType
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Steps through the environment with action and resets the environment if a terminated or truncated signal is encountered in the previous step.
Args:
action: The action to take
Returns:
The autoreset environment :meth:`step`
"""
if self._episode_ended:
obs, info = self.env.reset(options=self._reset_options)
self._episode_ended = False  # the new episode has just started
return obs, 0, False, False, info
else:
obs, reward, terminated, truncated, info = super().step(action)
self._episode_ended = terminated or truncated
return obs, reward, terminated, truncated, info
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[ObsType, dict[str, Any]]:
"""Resets the environment, saving the options used."""
self._episode_ended = False
self._reset_options = options
return super().reset(seed=seed, options=self._reset_options)
class PassiveEnvCheckerV0(
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
"""A passive environment checker wrapper that surrounds the step, reset and render functions to check they follow the gymnasium API."""
def __init__(self, env: gym.Env[ObsType, ActType]):
"""Initialises the wrapper with the environments, run the observation and action space tests."""
gym.utils.RecordConstructorArgs.__init__(self)
gym.Wrapper.__init__(self, env)
assert hasattr(
env, "action_space"
), "The environment must specify an action space. https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/"
check_action_space(env.action_space)
assert hasattr(
env, "observation_space"
), "The environment must specify an observation space. https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/"
check_observation_space(env.observation_space)
self._checked_reset: bool = False
self._checked_step: bool = False
self._checked_render: bool = False
def step(
self, action: ActType
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Steps through the environment that on the first call will run the `passive_env_step_check`."""
if self._checked_step is False:
self._checked_step = True
return env_step_passive_checker(self.env, action)
else:
return self.env.step(action)
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[ObsType, dict[str, Any]]:
"""Resets the environment that on the first call will run the `passive_env_reset_check`."""
if self._checked_reset is False:
self._checked_reset = True
return env_reset_passive_checker(self.env, seed=seed, options=options)
else:
return self.env.reset(seed=seed, options=options)
def render(self) -> RenderFrame | list[RenderFrame] | None:
"""Renders the environment that on the first call will run the `passive_env_render_check`."""
if self._checked_render is False:
self._checked_render = True
return env_render_passive_checker(self.env)
else:
return self.env.render()
class OrderEnforcingV0(
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
"""A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`.
Example:
>>> import gymnasium as gym
>>> from gymnasium.experimental.wrappers import OrderEnforcingV0
>>> env = gym.make("CartPole-v1", render_mode="human")
>>> env = OrderEnforcingV0(env)
>>> env.step(0)
Traceback (most recent call last):
...
gymnasium.error.ResetNeeded: Cannot call env.step() before calling env.reset()
>>> env.render()
Traceback (most recent call last):
...
gymnasium.error.ResetNeeded: Cannot call `env.render()` before calling `env.reset()`, if this is an intended action, set `disable_render_order_enforcing=True` on the OrderEnforcer wrapper.
>>> _ = env.reset()
>>> env.render()
>>> _ = env.step(0)
>>> env.close()
"""
def __init__(
self,
env: gym.Env[ObsType, ActType],
disable_render_order_enforcing: bool = False,
):
"""A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`.
Args:
env: The environment to wrap
disable_render_order_enforcing: If to disable render order enforcing
"""
gym.utils.RecordConstructorArgs.__init__(
self, disable_render_order_enforcing=disable_render_order_enforcing
)
gym.Wrapper.__init__(self, env)
self._has_reset: bool = False
self._disable_render_order_enforcing: bool = disable_render_order_enforcing
def step(self, action: ActType) -> tuple[ObsType, SupportsFloat, bool, bool, dict]:
"""Steps through the environment."""
if not self._has_reset:
raise ResetNeeded("Cannot call env.step() before calling env.reset()")
return super().step(action)
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[ObsType, dict[str, Any]]:
"""Resets the environment with `kwargs`."""
self._has_reset = True
return super().reset(seed=seed, options=options)
def render(self) -> RenderFrame | list[RenderFrame] | None:
"""Renders the environment with `kwargs`."""
if not self._disable_render_order_enforcing and not self._has_reset:
raise ResetNeeded(
"Cannot call `env.render()` before calling `env.reset()`, if this is a intended action, "
"set `disable_render_order_enforcing=True` on the OrderEnforcer wrapper."
)
return super().render()
@property
def has_reset(self):
"""Returns if the environment has been reset before."""
return self._has_reset
class RecordEpisodeStatisticsV0(
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
"""This wrapper will keep track of cumulative rewards and episode lengths.
At the end of an episode, the statistics of the episode will be added to ``info``
using the key ``episode``. If using a vectorized environment, the key
``_episode`` is also used, indicating whether the env at the respective index
has episode statistics.
After the completion of an episode, ``info`` will look like this::
>>> info = {
... "episode": {
... "r": "<cumulative reward>",
... "l": "<episode length>",
... "t": "<elapsed time since beginning of episode>"
... },
... }
For a vectorized environments the output will be in the form of::
>>> infos = {
... "final_observation": "<array of length num-envs>",
... "_final_observation": "<boolean array of length num-envs>",
... "final_info": "<array of length num-envs>",
... "_final_info": "<boolean array of length num-envs>",
... "episode": {
... "r": "<array of cumulative reward>",
... "l": "<array of episode length>",
... "t": "<array of elapsed time since beginning of episode>"
... },
... "_episode": "<boolean array of length num-envs>"
... }
Moreover, the most recent rewards and episode lengths are stored in buffers that can be accessed via
:attr:`wrapped_env.episode_reward_buffer` and :attr:`wrapped_env.episode_length_buffer` respectively.
Attributes:
episode_reward_buffer: The cumulative rewards of the last ``deque_size``-many episodes
episode_length_buffer: The lengths of the last ``deque_size``-many episodes
"""
def __init__(
self,
env: gym.Env[ObsType, ActType],
buffer_length: int | None = 100,
stats_key: str = "episode",
):
"""This wrapper will keep track of cumulative rewards and episode lengths.
Args:
env (Env): The environment to apply the wrapper
buffer_length: The size of the buffers :attr:`episode_reward_buffer` and :attr:`episode_length_buffer`
stats_key: The info key for the episode statistics
"""
gym.utils.RecordConstructorArgs.__init__(self)
gym.Wrapper.__init__(self, env)
self._stats_key = stats_key
self.episode_count = 0
self.episode_start_time: float = -1
self.episode_reward: float = -1
self.episode_length: int = -1
self.episode_time_length_buffer: deque[int] = deque(maxlen=buffer_length)
self.episode_reward_buffer: deque[float] = deque(maxlen=buffer_length)
self.episode_length_buffer: deque[int] = deque(maxlen=buffer_length)
def step(
self, action: ActType
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Steps through the environment, recording the episode statistics."""
obs, reward, terminated, truncated, info = super().step(action)
self.episode_reward += reward
self.episode_length += 1
if terminated or truncated:
assert self._stats_key not in info
episode_time_length = np.round(
time.perf_counter() - self.episode_start_time, 6
)
info[self._stats_key] = {
"r": self.episode_reward,
"l": self.episode_length,
"t": episode_time_length,
}
self.episode_time_length_buffer.append(episode_time_length)
self.episode_reward_buffer.append(self.episode_reward)
self.episode_length_buffer.append(self.episode_length)
self.episode_count += 1
return obs, reward, terminated, truncated, info
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[ObsType, dict[str, Any]]:
"""Resets the environment using seed and options and resets the episode rewards and lengths."""
obs, info = super().reset(seed=seed, options=options)
self.episode_start_time = time.perf_counter()
self.episode_reward = 0
self.episode_length = 0
return obs, info
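A usage sketch (illustrative, not part of the deleted file): run one episode and read the statistics the wrapper appends to ``info``:

import gymnasium as gym
from gymnasium.experimental.wrappers import RecordEpisodeStatisticsV0

env = RecordEpisodeStatisticsV0(gym.make("CartPole-v1"))
obs, info = env.reset(seed=0)
done = False
while not done:
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    done = terminated or truncated
print(info["episode"])  # {'r': <return>, 'l': <length>, 't': <elapsed seconds>}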


@@ -1,620 +0,0 @@
"""A collection of observation wrappers using a lambda function.
* ``LambdaObservationV0`` - Transforms the observation with a function
* ``FilterObservationV0`` - Filters a ``Tuple`` or ``Dict`` to only include certain keys
* ``FlattenObservationV0`` - Flattens the observations
* ``GrayscaleObservationV0`` - Converts an RGB observation to a grayscale observation
* ``ResizeObservationV0`` - Resizes an array-based observation (normally an RGB observation)
* ``ReshapeObservationV0`` - Reshapes an array-based observation
* ``RescaleObservationV0`` - Rescales an observation to between a minimum and maximum value
* ``DtypeObservationV0`` - Convert an observation to a dtype
* ``PixelObservationV0`` - Adds the rendered frame to the environment's observations
"""
from __future__ import annotations
from typing import Any, Callable, Final, Sequence
import numpy as np
import gymnasium as gym
from gymnasium import spaces
from gymnasium.core import ActType, ObsType, WrapperObsType
from gymnasium.error import DependencyNotInstalled
__all__ = [
"LambdaObservationV0",
"FilterObservationV0",
"FlattenObservationV0",
"GrayscaleObservationV0",
"ResizeObservationV0",
"ReshapeObservationV0",
"RescaleObservationV0",
"DtypeObservationV0",
"PixelObservationV0",
]
class LambdaObservationV0(
gym.ObservationWrapper[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""Transforms an observation via a function provided to the wrapper.
The function :attr:`func` will be applied to all observations.
If the observations from :attr:`func` are outside the bounds of the ``env``'s observation space, provide an :attr:`observation_space`.
Example:
>>> import gymnasium as gym
>>> from gymnasium.experimental.wrappers import LambdaObservationV0
>>> import numpy as np
>>> np.random.seed(0)
>>> env = gym.make("CartPole-v1")
>>> env = LambdaObservationV0(env, lambda obs: obs + 0.1 * np.random.random(obs.shape), env.observation_space)
>>> env.reset(seed=42)
(array([0.08227695, 0.06540678, 0.09613613, 0.07422512]), {})
"""
def __init__(
self,
env: gym.Env[ObsType, ActType],
func: Callable[[ObsType], Any],
observation_space: gym.Space[WrapperObsType] | None,
):
"""Constructor for the lambda observation wrapper.
Args:
env: The environment to wrap
func: A function that will transform an observation. If this transformed observation is outside the observation space of ``env.observation_space`` then provide an `observation_space`.
observation_space: The observation spaces of the wrapper, if None, then it is assumed the same as ``env.observation_space``.
"""
gym.utils.RecordConstructorArgs.__init__(
self, func=func, observation_space=observation_space
)
gym.ObservationWrapper.__init__(self, env)
if observation_space is not None:
self.observation_space = observation_space
self.func = func
def observation(self, observation: ObsType) -> Any:
"""Apply function to the observation."""
return self.func(observation)
class FilterObservationV0(
LambdaObservationV0[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""Filters Dict or Tuple observation space by the keys or indexes.
Example:
>>> import gymnasium as gym
>>> from gymnasium.wrappers import TransformObservation
>>> from gymnasium.experimental.wrappers import FilterObservationV0
>>> env = gym.make("CartPole-v1")
>>> env = gym.wrappers.TransformObservation(env, lambda obs: {'obs': obs, 'time': 0})
>>> env.observation_space = gym.spaces.Dict(obs=env.observation_space, time=gym.spaces.Discrete(1))
>>> env.reset(seed=42)
({'obs': array([ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 ], dtype=float32), 'time': 0}, {})
>>> env = FilterObservationV0(env, filter_keys=['time'])
>>> env.reset(seed=42)
({'time': 0}, {})
>>> env.step(0)
({'time': 0}, 1.0, False, False, {})
"""
def __init__(
self, env: gym.Env[ObsType, ActType], filter_keys: Sequence[str | int]
):
"""Constructor for the filter observation wrapper.
Args:
env: The environment to wrap
filter_keys: The subspaces to be included, use a list of strings or integers for ``Dict`` and ``Tuple`` spaces respectively
"""
assert isinstance(filter_keys, Sequence)
gym.utils.RecordConstructorArgs.__init__(self, filter_keys=filter_keys)
# Filters for dictionary space
if isinstance(env.observation_space, spaces.Dict):
assert all(isinstance(key, str) for key in filter_keys)
if any(
key not in env.observation_space.spaces.keys() for key in filter_keys
):
missing_keys = [
key
for key in filter_keys
if key not in env.observation_space.spaces.keys()
]
raise ValueError(
"All the `filter_keys` must be included in the observation space.\n"
f"Filter keys: {filter_keys}\n"
f"Observation keys: {list(env.observation_space.spaces.keys())}\n"
f"Missing keys: {missing_keys}"
)
new_observation_space = spaces.Dict(
{key: env.observation_space[key] for key in filter_keys}
)
if len(new_observation_space) == 0:
raise ValueError(
"The observation space is empty due to filtering all keys."
)
LambdaObservationV0.__init__(
self,
env=env,
func=lambda obs: {key: obs[key] for key in filter_keys},
observation_space=new_observation_space,
)
# Filter for tuple observation
elif isinstance(env.observation_space, spaces.Tuple):
assert all(isinstance(key, int) for key in filter_keys)
assert len(set(filter_keys)) == len(
filter_keys
), f"Duplicate keys exist, filter_keys: {filter_keys}"
if any(
key < 0 or key >= len(env.observation_space) for key in filter_keys
):
missing_index = [
key
for key in filter_keys
if key < 0 or key >= len(env.observation_space)
]
raise ValueError(
"All the `filter_keys` must be included in the length of the observation space.\n"
f"Filter keys: {filter_keys}, length of observation: {len(env.observation_space)}, "
f"missing indexes: {missing_index}"
)
new_observation_spaces = spaces.Tuple(
env.observation_space[key] for key in filter_keys
)
if len(new_observation_spaces) == 0:
raise ValueError(
"The observation space is empty due to filtering all keys."
)
LambdaObservationV0.__init__(
self,
env=env,
func=lambda obs: tuple(obs[key] for key in filter_keys),
observation_space=new_observation_spaces,
)
else:
raise ValueError(
f"FilterObservation wrapper is only usable with `Dict` and `Tuple` observations, actual type: {type(env.observation_space)}"
)
self.filter_keys: Final[Sequence[str | int]] = filter_keys
class FlattenObservationV0(
LambdaObservationV0[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""Observation wrapper that flattens the observation.
Example:
>>> import gymnasium as gym
>>> from gymnasium.experimental.wrappers import FlattenObservationV0
>>> env = gym.make("CarRacing-v2")
>>> env.observation_space.shape
(96, 96, 3)
>>> env = FlattenObservationV0(env)
>>> env.observation_space.shape
(27648,)
>>> obs, _ = env.reset()
>>> obs.shape
(27648,)
"""
def __init__(self, env: gym.Env[ObsType, ActType]):
"""Constructor for any environment's observation space that implements ``spaces.utils.flatten_space`` and ``spaces.utils.flatten``.
Args:
env: The environment to wrap
"""
gym.utils.RecordConstructorArgs.__init__(self)
LambdaObservationV0.__init__(
self,
env=env,
func=lambda obs: spaces.utils.flatten(env.observation_space, obs),
observation_space=spaces.utils.flatten_space(env.observation_space),
)
class GrayscaleObservationV0(
LambdaObservationV0[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""Observation wrapper that converts an RGB image to grayscale.
The :attr:`keep_dim` will keep the channel dimension
Example:
>>> import gymnasium as gym
>>> from gymnasium.experimental.wrappers import GrayscaleObservationV0
>>> env = gym.make("CarRacing-v2")
>>> env.observation_space.shape
(96, 96, 3)
>>> grayscale_env = GrayscaleObservationV0(env)
>>> grayscale_env.observation_space.shape
(96, 96)
>>> grayscale_env = GrayscaleObservationV0(env, keep_dim=True)
>>> grayscale_env.observation_space.shape
(96, 96, 1)
"""
def __init__(self, env: gym.Env[ObsType, ActType], keep_dim: bool = False):
"""Constructor for an RGB image based environments to make the image grayscale.
Args:
env: The environment to wrap
keep_dim: Whether to keep the channel dimension in the observation; if ``True``, ``len(obs.shape) == 3``, else ``len(obs.shape) == 2``
"""
assert isinstance(env.observation_space, spaces.Box)
assert (
len(env.observation_space.shape) == 3
and env.observation_space.shape[-1] == 3
)
assert (
np.all(env.observation_space.low == 0)
and np.all(env.observation_space.high == 255)
and env.observation_space.dtype == np.uint8
)
gym.utils.RecordConstructorArgs.__init__(self, keep_dim=keep_dim)
self.keep_dim: Final[bool] = keep_dim
if keep_dim:
new_observation_space = spaces.Box(
low=0,
high=255,
shape=env.observation_space.shape[:2] + (1,),
dtype=np.uint8,
)
LambdaObservationV0.__init__(
self,
env=env,
func=lambda obs: np.expand_dims(
np.sum(
np.multiply(obs, np.array([0.2125, 0.7154, 0.0721])), axis=-1
).astype(np.uint8),
axis=-1,
),
observation_space=new_observation_space,
)
else:
new_observation_space = spaces.Box(
low=0, high=255, shape=env.observation_space.shape[:2], dtype=np.uint8
)
LambdaObservationV0.__init__(
self,
env=env,
func=lambda obs: np.sum(
np.multiply(obs, np.array([0.2125, 0.7154, 0.0721])), axis=-1
).astype(np.uint8),
observation_space=new_observation_space,
)
class ResizeObservationV0(
LambdaObservationV0[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""Resizes image observations using OpenCV to shape.
Example:
>>> import gymnasium as gym
>>> from gymnasium.experimental.wrappers import ResizeObservationV0
>>> env = gym.make("CarRacing-v2")
>>> env.observation_space.shape
(96, 96, 3)
>>> resized_env = ResizeObservationV0(env, (32, 32))
>>> resized_env.observation_space.shape
(32, 32, 3)
"""
def __init__(self, env: gym.Env[ObsType, ActType], shape: tuple[int, ...]):
"""Constructor that requires an image environment observation space with a shape.
Args:
env: The environment to wrap
shape: The resized observation shape
"""
assert isinstance(env.observation_space, spaces.Box)
assert len(env.observation_space.shape) in [2, 3]
assert np.all(env.observation_space.low == 0) and np.all(
env.observation_space.high == 255
)
assert env.observation_space.dtype == np.uint8
assert isinstance(shape, tuple)
assert all(np.issubdtype(type(elem), np.integer) for elem in shape)
assert all(x > 0 for x in shape)
try:
import cv2
except ImportError as e:
raise DependencyNotInstalled(
"opencv (cv2) is not installed, run `pip install gymnasium[other]`"
) from e
self.shape: Final[tuple[int, ...]] = tuple(shape)
new_observation_space = spaces.Box(
low=0,
high=255,
shape=self.shape + env.observation_space.shape[2:],
dtype=np.uint8,
)
gym.utils.RecordConstructorArgs.__init__(self, shape=shape)
LambdaObservationV0.__init__(
self,
env=env,
func=lambda obs: cv2.resize(obs, self.shape, interpolation=cv2.INTER_AREA),
observation_space=new_observation_space,
)
class ReshapeObservationV0(
LambdaObservationV0[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""Reshapes array based observations to shapes.
Example:
>>> import gymnasium as gym
>>> from gymnasium.experimental.wrappers import ReshapeObservationV0
>>> env = gym.make("CarRacing-v2")
>>> env.observation_space.shape
(96, 96, 3)
>>> reshape_env = ReshapeObservationV0(env, (24, 4, 96, 1, 3))
>>> reshape_env.observation_space.shape
(24, 4, 96, 1, 3)
"""
def __init__(self, env: gym.Env[ObsType, ActType], shape: int | tuple[int, ...]):
"""Constructor for env with ``Box`` observation space that has a shape product equal to the new shape product.
Args:
env: The environment to wrap
shape: The reshaped observation space
"""
assert isinstance(env.observation_space, spaces.Box)
assert np.prod(shape) == np.prod(env.observation_space.shape)
assert isinstance(shape, tuple)
assert all(np.issubdtype(type(elem), np.integer) for elem in shape)
assert all(x > 0 or x == -1 for x in shape)
new_observation_space = spaces.Box(
low=np.reshape(np.ravel(env.observation_space.low), shape),
high=np.reshape(np.ravel(env.observation_space.high), shape),
shape=shape,
dtype=env.observation_space.dtype,
)
self.shape = shape
gym.utils.RecordConstructorArgs.__init__(self, shape=shape)
LambdaObservationV0.__init__(
self,
env=env,
func=lambda obs: np.reshape(obs, shape),
observation_space=new_observation_space,
)
class RescaleObservationV0(
LambdaObservationV0[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""Linearly rescales observation to between a minimum and maximum value.
Example:
>>> import gymnasium as gym
>>> from gymnasium.experimental.wrappers import RescaleObservationV0
>>> env = gym.make("Pendulum-v1")
>>> env.observation_space
Box([-1. -1. -8.], [1. 1. 8.], (3,), float32)
>>> env = RescaleObservationV0(env, np.array([-2, -1, -10], dtype=np.float32), np.array([1, 0, 1], dtype=np.float32))
>>> env.observation_space
Box([ -2. -1. -10.], [1. 0. 1.], (3,), float32)
"""
def __init__(
self,
env: gym.Env[ObsType, ActType],
min_obs: np.floating | np.integer | np.ndarray,
max_obs: np.floating | np.integer | np.ndarray,
):
"""Constructor that requires the env observation spaces to be a :class:`Box`.
Args:
env: The environment to wrap
min_obs: The new minimum observation bound
max_obs: The new maximum observation bound
"""
assert isinstance(env.observation_space, spaces.Box)
assert not np.any(env.observation_space.low == -np.inf) and not np.any(
env.observation_space.high == np.inf
)
if not isinstance(min_obs, np.ndarray):
assert np.issubdtype(type(min_obs), np.integer) or np.issubdtype(
type(min_obs), np.floating
)
min_obs = np.full(env.observation_space.shape, min_obs)
assert (
min_obs.shape == env.observation_space.shape
), f"{min_obs.shape}, {env.observation_space.shape}, {min_obs}, {env.observation_space.low}"
assert not np.any(np.isinf(min_obs))
if not isinstance(max_obs, np.ndarray):
assert np.issubdtype(type(max_obs), np.integer) or np.issubdtype(
type(max_obs), np.floating
)
max_obs = np.full(env.observation_space.shape, max_obs)
assert max_obs.shape == env.observation_space.shape
assert not np.any(np.isinf(max_obs))
self.min_obs = min_obs
self.max_obs = max_obs
# Imagine the x-axis as the old Box bounds and the y-axis as the new Box bounds
gradient = (max_obs - min_obs) / (
env.observation_space.high - env.observation_space.low
)
intercept = gradient * -env.observation_space.low + min_obs
gym.utils.RecordConstructorArgs.__init__(self, min_obs=min_obs, max_obs=max_obs)
LambdaObservationV0.__init__(
self,
env=env,
func=lambda obs: gradient * obs + intercept,
observation_space=spaces.Box(
low=min_obs,
high=max_obs,
shape=env.observation_space.shape,
dtype=env.observation_space.dtype,
),
)
class DtypeObservationV0(
LambdaObservationV0[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""Observation wrapper for transforming the dtype of an observation.
Note:
This is only compatible with :class:`Box`, :class:`Discrete`, :class:`MultiDiscrete` and :class:`MultiBinary` observation spaces
"""
def __init__(self, env: gym.Env[ObsType, ActType], dtype: Any):
"""Constructor for Dtype observation wrapper.
Args:
env: The environment to wrap
dtype: The new dtype of the observation
"""
assert isinstance(
env.observation_space,
(spaces.Box, spaces.Discrete, spaces.MultiDiscrete, spaces.MultiBinary),
)
self.dtype = dtype
if isinstance(env.observation_space, spaces.Box):
new_observation_space = spaces.Box(
low=env.observation_space.low,
high=env.observation_space.high,
shape=env.observation_space.shape,
dtype=self.dtype,
)
elif isinstance(env.observation_space, spaces.Discrete):
new_observation_space = spaces.Box(
low=env.observation_space.start,
high=env.observation_space.start + env.observation_space.n - 1,  # Box bounds are inclusive
shape=(),
dtype=self.dtype,
)
elif isinstance(env.observation_space, spaces.MultiDiscrete):
new_observation_space = spaces.MultiDiscrete(
env.observation_space.nvec, dtype=dtype
)
elif isinstance(env.observation_space, spaces.MultiBinary):
new_observation_space = spaces.Box(
low=0,
high=1,
shape=env.observation_space.shape,
dtype=self.dtype,
)
else:
raise TypeError(
"DtypeObservation is only compatible with value / array-based observations."
)
gym.utils.RecordConstructorArgs.__init__(self, dtype=dtype)
LambdaObservationV0.__init__(
self,
env=env,
func=lambda obs: dtype(obs),
observation_space=new_observation_space,
)
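A short usage sketch of the dtype wrapper above (assuming the experimental import path of this commit):

import numpy as np
import gymnasium as gym
from gymnasium.experimental.wrappers import DtypeObservationV0

env = gym.make("CartPole-v1")                    # float32 Box observations
env = DtypeObservationV0(env, dtype=np.float64)
obs, _ = env.reset(seed=0)
assert obs.dtype == np.float64                   # numpy scalar types cast whole arrays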
class PixelObservationV0(
LambdaObservationV0[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""Includes the rendered observations to the environment's observations.
Observations of this wrapper will be dictionaries of images.
You can also choose to add the observation of the base environment to this dictionary.
In that case, if the base environment has an observation space of type :class:`Dict`, the dictionary
of rendered images will be updated with the base environment's observation. If, however, the observation
space is of type :class:`Box`, the base environment's observation (which will be an element of the :class:`Box`
space) will be added to the dictionary under the key "state".
"""
def __init__(
self,
env: gym.Env[ObsType, ActType],
pixels_only: bool = True,
pixels_key: str = "pixels",
obs_key: str = "state",
):
"""Constructor of the pixel observation wrapper.
Args:
env: The environment to wrap.
pixels_only (bool): If ``True`` (default), the original observation returned
by the wrapped environment will be discarded, and a dictionary
observation will only include pixels. If ``False``, the
observation dictionary will contain both the original
observations and the pixel observations.
pixels_key: Optional custom string specifying the pixel key. Defaults to "pixels"
obs_key: Optional custom string specifying the obs key. Defaults to "state"
"""
gym.utils.RecordConstructorArgs.__init__(
self, pixels_only=pixels_only, pixels_key=pixels_key, obs_key=obs_key
)
assert env.render_mode is not None and env.render_mode != "human"
env.reset()
pixels = env.render()
assert pixels is not None and isinstance(pixels, np.ndarray)
pixel_space = spaces.Box(low=0, high=255, shape=pixels.shape, dtype=np.uint8)
if pixels_only:
obs_space = pixel_space
LambdaObservationV0.__init__(
self, env=env, func=lambda _: self.render(), observation_space=obs_space
)
elif isinstance(env.observation_space, spaces.Dict):
assert pixels_key not in env.observation_space.spaces.keys()
obs_space = spaces.Dict(
{pixels_key: pixel_space, **env.observation_space.spaces}
)
LambdaObservationV0.__init__(
self,
env=env,
func=lambda obs: {pixels_key: self.render(), **obs},
observation_space=obs_space,
)
else:
obs_space = spaces.Dict(
{obs_key: env.observation_space, pixels_key: pixel_space}
)
LambdaObservationV0.__init__(
self,
env=env,
func=lambda obs: {obs_key: obs, pixels_key: self.render()},
observation_space=obs_space,
)
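A minimal sketch of the pixel wrapper above; the environment must be constructed with a non-"human" render mode (import path assumed from this commit):

import gymnasium as gym
from gymnasium.experimental.wrappers import PixelObservationV0

env = gym.make("CartPole-v1", render_mode="rgb_array")
env = PixelObservationV0(env, pixels_only=False)
obs, _ = env.reset(seed=0)
# obs is a dict: {"state": the original Box observation, "pixels": a uint8 RGB frame}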

View File

@@ -1,102 +0,0 @@
"""A collection of wrappers for modifying the reward.
* ``LambdaRewardV0`` - Transforms the reward by a function
* ``ClipRewardV0`` - Clips the reward between a minimum and maximum value
"""
from __future__ import annotations
from typing import Callable, SupportsFloat
import numpy as np
import gymnasium as gym
from gymnasium.core import ActType, ObsType
from gymnasium.error import InvalidBound
__all__ = ["LambdaRewardV0", "ClipRewardV0"]
class LambdaRewardV0(
gym.RewardWrapper[ObsType, ActType], gym.utils.RecordConstructorArgs
):
"""A reward wrapper that allows a custom function to modify the step reward.
Example:
>>> import gymnasium as gym
>>> from gymnasium.experimental.wrappers import LambdaRewardV0
>>> env = gym.make("CartPole-v1")
>>> env = LambdaRewardV0(env, lambda r: 2 * r + 1)
>>> _ = env.reset()
>>> _, rew, _, _, _ = env.step(0)
>>> rew
3.0
"""
def __init__(
self,
env: gym.Env[ObsType, ActType],
func: Callable[[SupportsFloat], SupportsFloat],
):
"""Initialize LambdaRewardV0 wrapper.
Args:
env (Env): The environment to wrap
func (Callable): The function to apply to the reward
"""
gym.utils.RecordConstructorArgs.__init__(self, func=func)
gym.RewardWrapper.__init__(self, env)
self.func = func
def reward(self, reward: SupportsFloat) -> SupportsFloat:
"""Apply function to reward.
Args:
reward (Union[float, int, np.ndarray]): environment's reward
"""
return self.func(reward)
class ClipRewardV0(LambdaRewardV0[ObsType, ActType], gym.utils.RecordConstructorArgs):
"""A wrapper that clips the rewards for an environment between an upper and lower bound.
Example:
>>> import gymnasium as gym
>>> from gymnasium.experimental.wrappers import ClipRewardV0
>>> env = gym.make("CartPole-v1")
>>> env = ClipRewardV0(env, 0, 0.5)
>>> _ = env.reset()
>>> _, rew, _, _, _ = env.step(1)
>>> rew
0.5
"""
def __init__(
self,
env: gym.Env[ObsType, ActType],
min_reward: float | np.ndarray | None = None,
max_reward: float | np.ndarray | None = None,
):
"""Initialize ClipRewardsV0 wrapper.
Args:
env (Env): The environment to wrap
min_reward (Union[float, np.ndarray]): lower bound to apply
max_reward (Union[float, np.ndarray]): higher bound to apply
"""
if min_reward is None and max_reward is None:
raise InvalidBound("`min_reward` and `max_reward` cannot both be None")
elif max_reward is not None and min_reward is not None:
if np.any(max_reward - min_reward < 0):
raise InvalidBound(
f"Min reward ({min_reward}) must be smaller than max reward ({max_reward})"
)
gym.utils.RecordConstructorArgs.__init__(
self, min_reward=min_reward, max_reward=max_reward
)
LambdaRewardV0.__init__(
self, env=env, func=lambda x: np.clip(x, a_min=min_reward, a_max=max_reward)
)

View File

@@ -1,146 +0,0 @@
"""Wrappers for vector environments."""
# pyright: reportUnsupportedDunderAll=false
import importlib
import re
from gymnasium.error import DeprecatedWrapper
from gymnasium.experimental.wrappers.vector.dict_info_to_list import DictInfoToListV0
from gymnasium.experimental.wrappers.vector.record_episode_statistics import (
RecordEpisodeStatisticsV0,
)
from gymnasium.experimental.wrappers.vector.vectorize_action import (
ClipActionV0,
LambdaActionV0,
RescaleActionV0,
VectorizeLambdaActionV0,
)
from gymnasium.experimental.wrappers.vector.vectorize_observation import (
DtypeObservationV0,
FilterObservationV0,
FlattenObservationV0,
GrayscaleObservationV0,
LambdaObservationV0,
RescaleObservationV0,
ReshapeObservationV0,
ResizeObservationV0,
VectorizeLambdaObservationV0,
)
from gymnasium.experimental.wrappers.vector.vectorize_reward import (
ClipRewardV0,
LambdaRewardV0,
VectorizeLambdaRewardV0,
)
__all__ = [
# --- Vector only wrappers
"VectorizeLambdaObservationV0",
"VectorizeLambdaActionV0",
"VectorizeLambdaRewardV0",
"DictInfoToListV0",
# --- Observation wrappers ---
"LambdaObservationV0",
"FilterObservationV0",
"FlattenObservationV0",
"GrayscaleObservationV0",
"ResizeObservationV0",
"ReshapeObservationV0",
"RescaleObservationV0",
"DtypeObservationV0",
# "PixelObservationV0",
# "NormalizeObservationV0",
# "TimeAwareObservationV0",
# "FrameStackObservationV0",
# "DelayObservationV0",
# --- Action Wrappers ---
"LambdaActionV0",
"ClipActionV0",
"RescaleActionV0",
# --- Reward wrappers ---
"LambdaRewardV0",
"ClipRewardV0",
# "NormalizeRewardV1",
# --- Common ---
"RecordEpisodeStatisticsV0",
# --- Rendering ---
# "RenderCollectionV0",
# "RecordVideoV0",
# "HumanRenderingV0",
# --- Conversion ---
"JaxToNumpyV0",
"JaxToTorchV0",
"NumpyToTorchV0",
]
# As these wrappers require `jax` or `torch`, they are loaded at runtime when users try to access them
# to avoid `import jax` or `import torch` on `import gymnasium`.
_wrapper_to_class = {
# data converters
"JaxToNumpyV0": "jax_to_numpy",
"JaxToTorchV0": "jax_to_torch",
"NumpyToTorchV0": "numpy_to_torch",
}
def __getattr__(wrapper_name: str):
"""Load a wrapper by name.
This optimizes the loading of gymnasium wrappers by only loading the wrapper if it is used.
Errors will be raised if the wrapper does not exist or if the version is not the latest.
Args:
wrapper_name: The name of a wrapper to load.
Returns:
The specified wrapper.
Raises:
AttributeError: If the wrapper does not exist.
DeprecatedWrapper: If the version is not the latest.
"""
# Check if the requested wrapper is in the _wrapper_to_class dictionary
if wrapper_name in _wrapper_to_class:
import_stmt = (
f"gymnasium.experimental.wrappers.vector.{_wrapper_to_class[wrapper_name]}"
)
module = importlib.import_module(import_stmt)
return getattr(module, wrapper_name)
# Define a regex pattern to match the integer suffix (version number) of the wrapper
int_suffix_pattern = r"(\d+)$"
version_match = re.search(int_suffix_pattern, wrapper_name)
# If a version number is found, extract it and the base wrapper name
if version_match:
version = int(version_match.group())
base_name = wrapper_name[: -len(version_match.group())]
else:
version = float("inf")
base_name = wrapper_name[:-2]
# Filter the list of all wrappers to include only those with the same base name
matching_wrappers = [name for name in __all__ if name.startswith(base_name)]
# If no matching wrappers are found, raise an AttributeError
if not matching_wrappers:
raise AttributeError(f"module {__name__!r} has no attribute {wrapper_name!r}")
# Find the latest version of the matching wrappers
latest_wrapper = max(
matching_wrappers, key=lambda s: int(re.findall(int_suffix_pattern, s)[0])
)
latest_version = int(re.findall(int_suffix_pattern, latest_wrapper)[0])
# If the requested wrapper is an older version, raise a DeprecatedWrapper exception
if version < latest_version:
raise DeprecatedWrapper(
f"{wrapper_name!r} is now deprecated, use {latest_wrapper!r} instead.\n"
f"To see the changes made, go to "
f"https://gymnasium.farama.org/api/experimental/vector-wrappers/#gymnasium.experimental.wrappers.vector.{latest_wrapper}"
)
# If the requested version is invalid, raise an AttributeError
else:
raise AttributeError(
f"module {__name__!r} has no attribute {wrapper_name!r}, did you mean {latest_wrapper!r}"
)
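A hypothetical interaction with this lazy loader, using names from the ``__all__`` list above (``JaxToNumpyV0`` still requires ``jax`` to be installed when it is first accessed):

import gymnasium.experimental.wrappers.vector as vector_wrappers

converter_cls = vector_wrappers.JaxToNumpyV0  # lazily imported from jax_to_numpy on first access
vector_wrappers.ClipRewardV1                  # raises AttributeError: did you mean 'ClipRewardV0'?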

View File

@@ -1,86 +0,0 @@
"""Wrapper that converts the info format for vec envs into the list format."""
from __future__ import annotations
from typing import Any
from gymnasium.core import ActType, ObsType
from gymnasium.experimental.vector.vector_env import ArrayType, VectorEnv, VectorWrapper
__all__ = ["DictInfoToListV0"]
class DictInfoToListV0(VectorWrapper):
"""Converts infos of vectorized environments from dict to List[dict].
This wrapper converts the info format of a
vector environment from a dictionary to a list of dictionaries.
This wrapper is intended to be used around vectorized
environments. If using other wrappers that perform
operations on info, such as `RecordEpisodeStatistics`, this
needs to be the outermost wrapper,
i.e. ``DictInfoToListV0(RecordEpisodeStatisticsV0(vector_env))``
Example::
>>> import numpy as np
>>> dict_info = {
... "k": np.array([0., 0., 0.5, 0.3]),
... "_k": np.array([False, False, True, True])
... }
>>> list_info = [{}, {}, {"k": 0.5}, {"k": 0.3}]
"""
def __init__(self, env: VectorEnv):
"""This wrapper will convert the info into the list format.
Args:
env (Env): The environment to apply the wrapper
"""
super().__init__(env)
def step(
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, list[dict[str, Any]]]:
"""Steps through the environment, convert dict info to list."""
observation, reward, terminated, truncated, infos = self.env.step(actions)
list_info = self._convert_info_to_list(infos)
return observation, reward, terminated, truncated, list_info
def reset(
self,
*,
seed: int | list[int] | None = None,
options: dict[str, Any] | None = None,
) -> tuple[ObsType, list[dict[str, Any]]]:
"""Resets the environment using kwargs."""
obs, infos = self.env.reset(seed=seed, options=options)
list_info = self._convert_info_to_list(infos)
return obs, list_info
def _convert_info_to_list(self, infos: dict) -> list[dict[str, Any]]:
"""Convert the dict info to list.
Convert the dict info of the vectorized environment
into a list of dictionaries where the i-th dictionary
has the info of the i-th environment.
Args:
infos (dict): info dict coming from the env.
Returns:
list_info (list): converted info.
"""
list_info = [{} for _ in range(self.num_envs)]
for k in infos:
if k.startswith("_"):
continue
for i, has_info in enumerate(infos[f"_{k}"]):
if has_info:
list_info[i][k] = infos[k][i]
return list_info
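A minimal usage sketch of the wrapper above (``gym.make_vec`` is assumed to be available as the vector-environment constructor):

import gymnasium as gym
from gymnasium.experimental.wrappers.vector import DictInfoToListV0

envs = DictInfoToListV0(gym.make_vec("CartPole-v1", num_envs=3))
obs, infos = envs.reset(seed=123)
assert isinstance(infos, list) and len(infos) == 3  # one info dict per sub-environment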

View File

@@ -1,143 +0,0 @@
"""Vectorizes action wrappers to work for `VectorEnv`."""
from __future__ import annotations
from copy import deepcopy
from typing import Any, Callable
import numpy as np
from gymnasium import Space
from gymnasium.core import ActType, Env
from gymnasium.experimental.vector import VectorActionWrapper, VectorEnv
from gymnasium.experimental.wrappers import lambda_action
from gymnasium.vector.utils import batch_space, concatenate, create_empty_array, iterate
class LambdaActionV0(VectorActionWrapper):
"""Transforms an action via a function provided to the wrapper.
The function :attr:`func` will be applied to all vector actions.
If the transformed actions from :attr:`func` are outside the bounds of the ``env``'s action space, provide an :attr:`action_space`.
"""
def __init__(
self,
env: VectorEnv,
func: Callable[[ActType], Any],
action_space: Space | None = None,
):
"""Constructor for the lambda action wrapper.
Args:
env: The vector environment to wrap
func: A function that will transform an action. If this transformed action is outside the action space of ``env.action_space`` then provide an ``action_space``.
action_space: The action spaces of the wrapper, if None, then it is assumed the same as ``env.action_space``.
"""
super().__init__(env)
if action_space is not None:
self.action_space = action_space
self.func = func
def actions(self, actions: ActType) -> ActType:
"""Applies the :attr:`func` to the actions."""
return self.func(actions)
class VectorizeLambdaActionV0(VectorActionWrapper):
"""Vectorizes a single-agent lambda action wrapper for vector environments."""
class VectorizedEnv(Env):
"""Fake single-agent environment uses for the single-agent wrapper."""
def __init__(self, action_space: Space):
"""Constructor for the fake environment."""
self.action_space = action_space
def __init__(
self, env: VectorEnv, wrapper: type[lambda_action.LambdaActionV0], **kwargs: Any
):
"""Constructor for the vectorized lambda action wrapper.
Args:
env: The vector environment to wrap
wrapper: The wrapper to vectorize
**kwargs: Arguments for the LambdaActionV0 wrapper
"""
super().__init__(env)
self.wrapper = wrapper(
self.VectorizedEnv(self.env.single_action_space), **kwargs
)
self.single_action_space = self.wrapper.action_space
self.action_space = batch_space(self.single_action_space, self.num_envs)
self.same_out = self.action_space == self.env.action_space
self.out = create_empty_array(self.single_action_space, self.num_envs)
def actions(self, actions: ActType) -> ActType:
"""Applies the wrapper to each of the action.
Args:
actions: The actions to apply the function to
Returns:
The updated actions using the wrapper func
"""
if self.same_out:
return concatenate(
self.single_action_space,
tuple(
self.wrapper.func(action)
for action in iterate(self.action_space, actions)
),
actions,
)
else:
return deepcopy(
concatenate(
self.single_action_space,
tuple(
self.wrapper.func(action)
for action in iterate(self.action_space, actions)
),
self.out,
)
)
class ClipActionV0(VectorizeLambdaActionV0):
"""Clip the continuous action within the valid :class:`Box` observation space bound."""
def __init__(self, env: VectorEnv):
"""Constructor for the Clip Action wrapper.
Args:
env: The vector environment to wrap
"""
super().__init__(env, lambda_action.ClipActionV0)
class RescaleActionV0(VectorizeLambdaActionV0):
"""Affinely rescales the continuous action space of the environment to the range [min_action, max_action]."""
def __init__(
self,
env: VectorEnv,
min_action: float | int | np.ndarray,
max_action: float | int | np.ndarray,
):
"""Initializes the :class:`RescaleAction` wrapper.
Args:
env (Env): The vector environment to wrap
min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar.
max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar.
"""
super().__init__(
env,
lambda_action.RescaleActionV0,
min_action=min_action,
max_action=max_action,
)
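A hedged usage sketch of the vectorized action wrappers above (again assuming ``gym.make_vec`` as the vector-environment constructor):

import gymnasium as gym
from gymnasium.experimental.wrappers.vector import ClipActionV0, RescaleActionV0

envs = gym.make_vec("Pendulum-v1", num_envs=4)                 # native actions in [-2, 2]
envs = RescaleActionV0(envs, min_action=-1.0, max_action=1.0)  # agent now acts in [-1, 1]
envs = ClipActionV0(envs)                                      # out-of-range actions are clipped first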

View File

@@ -1,222 +0,0 @@
"""Vectorizes observation wrappers to works for `VectorEnv`."""
from __future__ import annotations
from copy import deepcopy
from typing import Any, Callable, Sequence
import numpy as np
from gymnasium import Space
from gymnasium.core import Env, ObsType
from gymnasium.experimental.vector import VectorEnv, VectorObservationWrapper
from gymnasium.experimental.vector.utils import batch_space, concatenate, iterate
from gymnasium.experimental.wrappers import lambda_observation
from gymnasium.vector.utils import create_empty_array
class LambdaObservationV0(VectorObservationWrapper):
"""Transforms an observation via a function provided to the wrapper.
The function :attr:`func` will be applied to all vector observations.
If the observations from :attr:`func` are outside the bounds of the ``env``'s observation space, provide an :attr:`observation_space`.
"""
def __init__(
self,
env: VectorEnv,
vector_func: Callable[[ObsType], Any],
single_func: Callable[[ObsType], Any],
observation_space: Space | None = None,
):
"""Constructor for the lambda observation wrapper.
Args:
env: The vector environment to wrap
vector_func: A function that will transform the vector observation. If this transformed observation is outside the observation space of ``env.observation_space`` then provide an ``observation_space``.
single_func: A function that will transform an individual observation.
observation_space: The observation spaces of the wrapper, if None, then it is assumed the same as ``env.observation_space``.
"""
super().__init__(env)
if observation_space is not None:
self.observation_space = observation_space
self.vector_func = vector_func
self.single_func = single_func
def vector_observation(self, observation: ObsType) -> ObsType:
"""Apply function to the vector observation."""
return self.vector_func(observation)
def single_observation(self, observation: ObsType) -> ObsType:
"""Apply function to the single observation."""
return self.single_func(observation)
class VectorizeLambdaObservationV0(VectorObservationWrapper):
"""Vectori`es a single-agent lambda observation wrapper for vector environments."""
class VectorizedEnv(Env):
"""Fake single-agent environment uses for the single-agent wrapper."""
def __init__(self, observation_space: Space):
"""Constructor for the fake environment."""
self.observation_space = observation_space
def __init__(
self,
env: VectorEnv,
wrapper: type[lambda_observation.LambdaObservationV0],
**kwargs: Any,
):
"""Constructor for the vectorized lambda observation wrapper.
Args:
env: The vector environment to wrap.
wrapper: The wrapper to vectorize
**kwargs: Keyword argument for the wrapper
"""
super().__init__(env)
self.wrapper = wrapper(
self.VectorizedEnv(self.env.single_observation_space), **kwargs
)
self.single_observation_space = self.wrapper.observation_space
self.observation_space = batch_space(
self.single_observation_space, self.num_envs
)
self.same_out = self.observation_space == self.env.observation_space
self.out = create_empty_array(self.single_observation_space, self.num_envs)
def vector_observation(self, observation: ObsType) -> ObsType:
"""Iterates over the vector observations applying the single-agent wrapper ``observation`` then concatenates the observations together again."""
if self.same_out:
return concatenate(
self.single_observation_space,
tuple(
self.wrapper.func(obs)
for obs in iterate(self.observation_space, observation)
),
observation,
)
else:
return deepcopy(
concatenate(
self.single_observation_space,
tuple(
self.wrapper.func(obs)
for obs in iterate(self.observation_space, observation)
),
self.out,
)
)
def single_observation(self, observation: ObsType) -> ObsType:
"""Transforms a single observation using the wrapper transformation function."""
return self.wrapper.func(observation)
class FilterObservationV0(VectorizeLambdaObservationV0):
"""Vector wrapper for filtering dict or tuple observation spaces."""
def __init__(self, env: VectorEnv, filter_keys: Sequence[str | int]):
"""Constructor for the filter observation wrapper.
Args:
env: The vector environment to wrap
filter_keys: The subspaces to be included, use a list of strings or integers for ``Dict`` and ``Tuple`` spaces respectively
"""
super().__init__(
env, lambda_observation.FilterObservationV0, filter_keys=filter_keys
)
class FlattenObservationV0(VectorizeLambdaObservationV0):
"""Observation wrapper that flattens the observation."""
def __init__(self, env: VectorEnv):
"""Constructor for any environment's observation space that implements ``spaces.utils.flatten_space`` and ``spaces.utils.flatten``.
Args:
env: The vector environment to wrap
"""
super().__init__(env, lambda_observation.FlattenObservationV0)
class GrayscaleObservationV0(VectorizeLambdaObservationV0):
"""Observation wrapper that converts an RGB image to grayscale."""
def __init__(self, env: VectorEnv, keep_dim: bool = False):
"""Constructor for an RGB image based environments to make the image grayscale.
Args:
env: The vector environment to wrap
keep_dim: Whether to keep the channel dimension in the observation; if ``True``, ``len(obs.shape) == 3``, else ``len(obs.shape) == 2``
"""
super().__init__(
env, lambda_observation.GrayscaleObservationV0, keep_dim=keep_dim
)
class ResizeObservationV0(VectorizeLambdaObservationV0):
"""Resizes image observations using OpenCV to shape."""
def __init__(self, env: VectorEnv, shape: tuple[int, ...]):
"""Constructor that requires an image environment observation space with a shape.
Args:
env: The vector environment to wrap
shape: The resized observation shape
"""
super().__init__(env, lambda_observation.ResizeObservationV0, shape=shape)
class ReshapeObservationV0(VectorizeLambdaObservationV0):
"""Reshapes array based observations to shapes."""
def __init__(self, env: VectorEnv, shape: int | tuple[int, ...]):
"""Constructor for env with Box observation space that has a shape product equal to the new shape product.
Args:
env: The vector environment to wrap
shape: The reshaped observation space
"""
super().__init__(env, lambda_observation.ReshapeObservationV0, shape=shape)
class RescaleObservationV0(VectorizeLambdaObservationV0):
"""Linearly rescales observation to between a minimum and maximum value."""
def __init__(
self,
env: VectorEnv,
min_obs: np.floating | np.integer | np.ndarray,
max_obs: np.floating | np.integer | np.ndarray,
):
"""Constructor that requires the env observation spaces to be a :class:`Box`.
Args:
env: The vector environment to wrap
min_obs: The new minimum observation bound
max_obs: The new maximum observation bound
"""
super().__init__(
env,
lambda_observation.RescaleObservationV0,
min_obs=min_obs,
max_obs=max_obs,
)
class DtypeObservationV0(VectorizeLambdaObservationV0):
"""Observation wrapper for transforming the dtype of an observation."""
def __init__(self, env: VectorEnv, dtype: Any):
"""Constructor for Dtype observation wrapper.
Args:
env: The vector environment to wrap
dtype: The new dtype of the observation
"""
super().__init__(env, lambda_observation.DtypeObservationV0, dtype=dtype)
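A sketch chaining the vectorized observation wrappers above (``gym.make_vec`` and the box2d-based CarRacing-v2 environment assumed):

import gymnasium as gym
from gymnasium.experimental.wrappers.vector import GrayscaleObservationV0, ResizeObservationV0

envs = gym.make_vec("CarRacing-v2", num_envs=2)     # batched (2, 96, 96, 3) observations
envs = ResizeObservationV0(envs, shape=(64, 64))    # -> (2, 64, 64, 3)
envs = GrayscaleObservationV0(envs, keep_dim=True)  # -> (2, 64, 64, 1)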

View File

@@ -1,78 +0,0 @@
"""Vectorizes reward function to work with `VectorEnv`."""
from __future__ import annotations
from typing import Any, Callable
import numpy as np
from gymnasium import Env
from gymnasium.experimental.vector import VectorEnv, VectorRewardWrapper
from gymnasium.experimental.vector.vector_env import ArrayType
from gymnasium.experimental.wrappers import lambda_reward
class LambdaRewardV0(VectorRewardWrapper):
"""A reward wrapper that allows a custom function to modify the step reward."""
def __init__(self, env: VectorEnv, func: Callable[[ArrayType], ArrayType]):
"""Initialize LambdaRewardV0 wrapper.
Args:
env (Env): The vector environment to wrap
func (Callable): The function to apply to the reward
"""
super().__init__(env)
self.func = func
def reward(self, reward: ArrayType) -> ArrayType:
"""Apply function to reward."""
return self.func(reward)
class VectorizeLambdaRewardV0(VectorRewardWrapper):
"""Vectorizes a single-agent lambda reward wrapper for vector environments."""
def __init__(
self, env: VectorEnv, wrapper: type[lambda_reward.LambdaRewardV0], **kwargs: Any
):
"""Constructor for the vectorized lambda reward wrapper.
Args:
env: The vector environment to wrap.
wrapper: The wrapper to vectorize
**kwargs: Keyword argument for the wrapper
"""
super().__init__(env)
self.wrapper = wrapper(Env(), **kwargs)
def reward(self, reward: ArrayType) -> ArrayType:
"""Iterates over the reward updating each with the wrapper func."""
for i, r in enumerate(reward):
reward[i] = self.wrapper.func(r)
return reward
class ClipRewardV0(VectorizeLambdaRewardV0):
"""A wrapper that clips the rewards for an environment between an upper and lower bound."""
def __init__(
self,
env: VectorEnv,
min_reward: float | np.ndarray | None = None,
max_reward: float | np.ndarray | None = None,
):
"""Constructor for ClipReward wrapper.
Args:
env: The vector environment to wrap
min_reward: The min reward for each step
max_reward: the max reward for each step
"""
super().__init__(
env,
lambda_reward.ClipRewardV0,
min_reward=min_reward,
max_reward=max_reward,
)
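And a matching sketch for the vectorized reward wrapper (``gym.make_vec`` assumed):

import gymnasium as gym
from gymnasium.experimental.wrappers.vector import ClipRewardV0

envs = ClipRewardV0(gym.make_vec("CartPole-v1", num_envs=4), min_reward=0.0, max_reward=0.5)
_ = envs.reset(seed=0)
_, rewards, _, _, _ = envs.step(envs.action_space.sample())  # every entry clipped to [0, 0.5]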

View File

@@ -24,13 +24,14 @@ class FuncEnv(
This API is meant to be used in a stateless manner, with the environment state being passed around explicitly.
That being said, nothing here prevents users from using the environment statefully, it's just not recommended.
A functional env consists of the following functions (in this case, instance methods):
* initial: returns the initial state of the POMDP
* observation: returns the observation in a given state
* transition: returns the next state after taking an action in a given state
* reward: returns the reward for a given (state, action, next_state) tuple
* terminal: returns whether a given state is terminal
* state_info: optional, returns a dict of info about a given state
* step_info: optional, returns a dict of info about a given (state, action, next_state) tuple
The class-based structure serves the purpose of allowing environment constants to be defined in the class,
and then using them by name in the code itself.
@@ -47,32 +48,32 @@ class FuncEnv(
self.__dict__.update(options or {})
def initial(self, rng: Any) -> StateType:
"""Generates the initial state of the environment with a random number generator."""
raise NotImplementedError
def transition(self, state: StateType, action: ActType, rng: Any) -> StateType:
"""Updates (transitions) the state with an action and random number generator."""
raise NotImplementedError
def observation(self, state: StateType) -> ObsType:
"""Generates an observation for a given state of an environment."""
raise NotImplementedError
def reward(
self, state: StateType, action: ActType, next_state: StateType
) -> RewardType:
"""Computes the reward for a given transition between `state`, `action` to `next_state`."""
raise NotImplementedError
def terminal(self, state: StateType) -> TerminalType:
"""Returns if the state is a final terminal state."""
raise NotImplementedError
def state_info(self, state: StateType) -> dict:
"""Info dict about a single state."""
return {}
def transition_info(
self, state: StateType, action: ActType, next_state: StateType
) -> dict:
"""Info dict about a full transition."""
@@ -82,11 +83,13 @@ class FuncEnv(
"""Functional transformations.""" """Functional transformations."""
self.initial = func(self.initial) self.initial = func(self.initial)
self.transition = func(self.transition) self.transition = func(self.transition)
self.observation = func(self.observation) self.observation = func(self.observation)
self.reward = func(self.reward) self.reward = func(self.reward)
self.terminal = func(self.terminal) self.terminal = func(self.terminal)
self.state_info = func(self.state_info) self.state_info = func(self.state_info)
self.step_info = func(self.step_info) self.transition_info = func(self.transition_info)
def render_image( def render_image(
self, state: StateType, render_state: RenderStateType self, state: StateType, render_state: RenderStateType

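To make the functional API described above concrete, here is a hypothetical toy subclass (an illustration only, not an environment shipped with the library; the import path is assumed from this commit):

import numpy as np
from gymnasium.experimental.functional import FuncEnv

class MoveToOrigin(FuncEnv):
    """Toy 1-D task: the state is a position and actions nudge it toward zero."""

    def initial(self, rng):
        return rng.uniform(-1.0, 1.0)       # random start position

    def transition(self, state, action, rng):
        return state + 0.1 * float(action)  # deterministic dynamics

    def observation(self, state):
        return np.array([state], dtype=np.float32)

    def reward(self, state, action, next_state):
        return -abs(next_state)             # closer to zero is better

    def terminal(self, state):
        return abs(state) < 0.05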
View File

@@ -274,7 +274,7 @@ class Box(Space[NDArray[Any]]):
return (
isinstance(other, Box)
and (self.shape == other.shape)
and (self.dtype == other.dtype)
and np.allclose(self.low, other.low)
and np.allclose(self.high, other.high)
)

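The commit re-enables the dtype check in ``Box.__eq__``; a quick illustration of the behavioural change:

import numpy as np
from gymnasium.spaces import Box

Box(0, 1, (2,), np.float32) == Box(0, 1, (2,), np.float32)  # True
Box(0, 1, (2,), np.float32) == Box(0, 1, (2,), np.float64)  # now False, as the dtypes differ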
View File

@@ -45,7 +45,7 @@ class Dict(Space[typing.Dict[str, Any]], typing.Mapping[str, Space[Any]]):
It can be convenient to use :class:`Dict` spaces if you want to make complex observations or actions more human-readable.
Usually, it will not be possible to use elements of this space directly in learning code. However, you can easily
convert :class:`Dict` observations to flat arrays by using a :class:`gymnasium.wrappers.FlattenObservation` wrapper.
Similar wrappers can be implemented to deal with :class:`Dict` actions.
"""

View File

@@ -62,8 +62,8 @@ class Discrete(Space[np.int64]):
Args:
mask: An optional mask for if an action can be selected.
Expected `np.ndarray` of shape ``(n,)`` and dtype ``np.int8`` where ``1`` represents valid actions and ``0`` invalid / infeasible actions.
If there are no possible actions (i.e. ``np.all(mask == 0)``) then ``space.start`` will be returned.
Returns:
A sampled integer from the space

View File

@@ -27,7 +27,7 @@ class GraphInstance(NamedTuple):
class Graph(Space[GraphInstance]):
r"""A space representing graph information as a series of ``nodes`` connected with ``edges`` according to an adjacency matrix represented as a series of ``edge_links``.
Example:
>>> from gymnasium.spaces import Graph, Box, Discrete
@@ -122,14 +122,14 @@ class Graph(Space[GraphInstance]):
num_nodes: int = 10,
num_edges: int | None = None,
) -> GraphInstance:
"""Generates a single sample graph with num_nodes between ``1`` and ``10`` sampled from the Graph.
Args:
mask: An optional tuple of optional node and edge mask that is only possible with Discrete spaces
(Box spaces don't support sample masks).
If no ``num_edges`` is provided then the ``edge_mask`` is multiplied by the number of edges
num_nodes: The number of nodes that will be sampled, the default is `10` nodes
num_edges: An optional number of edges, otherwise, a random number between `0` and :math:`num_nodes^2`
Returns:
A :class:`GraphInstance` with attributes `.nodes`, `.edges`, and `.edge_links`.
@@ -212,7 +212,7 @@ class Graph(Space[GraphInstance]):
def __repr__(self) -> str:
"""A string representation of this space.
The representation will include ``node_space`` and ``edge_space``
Returns:
A representation of the space

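For context, sampling from a ``Graph`` space with a fixed node count looks roughly like this:

from gymnasium.spaces import Graph, Box, Discrete

space = Graph(node_space=Box(low=-1, high=1, shape=(2,)), edge_space=Discrete(3))
sample = space.sample(num_nodes=4)  # a GraphInstance with .nodes, .edges and .edge_links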
View File

@@ -65,8 +65,8 @@ class MultiBinary(Space[NDArray[np.int8]]):
Args:
mask: An optional np.ndarray to mask samples with expected shape of ``space.shape``.
For ``mask == 0`` the samples will be ``0`` and for ``mask == 1`` random samples will be generated.
The expected mask shape is the space shape and mask dtype is ``np.int8``.
Returns:
Sampled values from space

View File

@@ -87,12 +87,12 @@ class MultiDiscrete(Space[NDArray[np.integer]]):
"""Generates a single random sample this space. """Generates a single random sample this space.
Args: Args:
mask: An optional mask for multi-discrete, expects tuples with a `np.ndarray` mask in the position of each mask: An optional mask for multi-discrete, expects tuples with a ``np.ndarray`` mask in the position of each
action with shape `(n,)` where `n` is the number of actions and `dtype=np.int8`. action with shape ``(n,)`` where ``n`` is the number of actions and ``dtype=np.int8``.
Only mask values == 1 are possible to sample unless all mask values for an action are 0 then the default action `self.start` (the smallest element) is sampled. Only ``mask values == 1`` are possible to sample unless all mask values for an action are ``0`` then the default action ``self.start`` (the smallest element) is sampled.
Returns: Returns:
An `np.ndarray` of shape `space.shape` An ``np.ndarray`` of :meth:`Space.shape`
""" """
if mask is not None: if mask is not None:
@@ -206,6 +206,7 @@ class MultiDiscrete(Space[NDArray[np.integer]]):
"""Check whether ``other`` is equivalent to this instance.""" """Check whether ``other`` is equivalent to this instance."""
return bool( return bool(
isinstance(other, MultiDiscrete) isinstance(other, MultiDiscrete)
and self.dtype == other.dtype
and np.all(self.nvec == other.nvec) and np.all(self.nvec == other.nvec)
and np.all(self.start == other.start) and np.all(self.start == other.start)
) )

View File

@@ -38,7 +38,7 @@ class Sequence(Space[Union[typing.Tuple[Any, ...], Any]]):
Args:
space: Elements in the sequences this space represent must belong to this space.
seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.
stack: If ``True`` then the resulting samples would be stacked.
"""
assert isinstance(
space, Space
@@ -78,14 +78,13 @@ class Sequence(Space[Union[typing.Tuple[Any, ...], Any]]):
Args:
mask: An optional mask for (optionally) the length of the sequence and (optionally) the values in the sequence.
If you specify ``mask``, it is expected to be a tuple of the form ``(length_mask, sample_mask)`` where ``length_mask`` is
* ``None`` The length will be randomly drawn from a geometric distribution
* ``np.ndarray`` of integers, in which case the length of the sampled sequence is randomly drawn from this array.
* ``int`` for a fixed length sample
The second element of the mask tuple ``sample`` mask specifies a mask that is applied when
sampling elements from the base space. The mask is applied for each feature space sample.
Returns:

View File

@@ -78,13 +78,13 @@ class Text(Space[str]):
self,
mask: None | (tuple[int | None, NDArray[np.int8] | None]) = None,
) -> str:
"""Generates a single random sample from this space with, by default, a random length between ``min_length`` and ``max_length`` and sampled from the ``charset``.
Args:
mask: An optional tuple of length and mask for the text.
The length is expected to be between the ``min_length`` and ``max_length`` otherwise a random integer between ``min_length`` and ``max_length`` is selected.
For the mask, we expect a numpy array of length of the charset passed with ``dtype == np.int8``.
If the charlist mask is all zero then an empty string is returned no matter the ``min_length``
Returns:
A sampled string from the space

View File

@@ -53,8 +53,8 @@ class Tuple(Space[typing.Tuple[Any, ...]], typing.Sequence[Any]):
Depending on the type of seed, the subspaces will be seeded differently
* ``None`` - All the subspaces will use a random initial seed
* ``Int`` - The integer is used to seed the :class:`Tuple` space that is used to generate seed values for each of the subspaces. Warning, this does not guarantee unique seeds for all the subspaces.
* ``List`` - Values used to seed the subspaces. This allows the seeding of multiple composite subspaces ``[42, 54, ...]``.
Args:
seed: An optional list of ints or int to seed the (sub-)spaces.

View File

@@ -428,9 +428,7 @@ def flatten_space(space: Space[Any]) -> Box | Dict | Sequence | Tuple | Graph:
Raises:
NotImplementedError: if the space is not defined in :mod:`gymnasium.spaces`.
Example - Flatten spaces.Box:
>>> from gymnasium.spaces import Box
>>> box = Box(0.0, 1.0, shape=(3, 4, 5))
>>> box
@@ -440,8 +438,7 @@ def flatten_space(space: Space[Any]) -> Box | Dict | Sequence | Tuple | Graph:
>>> flatten(box, box.sample()) in flatten_space(box)
True
Example - Flatten spaces.Discrete:
>>> from gymnasium.spaces import Discrete
>>> discrete = Discrete(5)
>>> flatten_space(discrete)
@@ -449,8 +446,7 @@ def flatten_space(space: Space[Any]) -> Box | Dict | Sequence | Tuple | Graph:
>>> flatten(discrete, discrete.sample()) in flatten_space(discrete)
True
Example - Flatten spaces.Dict:
>>> from gymnasium.spaces import Dict, Discrete, Box
>>> space = Dict({"position": Discrete(2), "velocity": Box(0, 1, shape=(2, 2))})
>>> flatten_space(space)
@@ -458,8 +454,7 @@ def flatten_space(space: Space[Any]) -> Box | Dict | Sequence | Tuple | Graph:
>>> flatten(space, space.sample()) in flatten_space(space)
True
Example - Flatten spaces.Graph:
>>> from gymnasium.spaces import Graph, Discrete, Box
>>> space = Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5))
>>> flatten_space(space)

View File

@@ -1,4 +1,4 @@
"""A set of functions for checking an environment details. """A set of functions for checking an environment implementation.
This file is originally from the Stable Baselines3 repository hosted on GitHub This file is originally from the Stable Baselines3 repository hosted on GitHub
(https://github.com/DLR-RM/stable-baselines3/) (https://github.com/DLR-RM/stable-baselines3/)
@@ -63,7 +63,7 @@ def data_equivalence(data_1, data_2) -> bool:
return False
def check_reset_seed(env: gym.Env):
"""Check that the environment can be reset with a seed.
Args:
@@ -132,7 +132,7 @@ def check_reset_seed(env: gym.Env) -> None:
)
def check_reset_options(env: gym.Env):
"""Check that the environment can be reset with options.
Args:
@@ -160,7 +160,7 @@ def check_reset_options(env: gym.Env) -> None:
)
def check_reset_return_info_deprecation(env: gym.Env):
"""Makes sure support for deprecated `return_info` argument is dropped.
Args:
@@ -177,7 +177,7 @@ def check_reset_return_info_deprecation(env: gym.Env) -> None:
)
def check_seed_deprecation(env: gym.Env):
"""Makes sure support for deprecated function `seed` is dropped.
Args:
@@ -193,7 +193,7 @@ def check_seed_deprecation(env: gym.Env) -> None:
)
def check_reset_return_type(env: gym.Env):
"""Checks that :meth:`reset` correctly returns a tuple of the form ``(obs, info)``.
Args:
@@ -218,7 +218,7 @@ def check_reset_return_type(env: gym.Env) -> None:
), f"The second element returned by `env.reset()` was not a dictionary, actual type: {type(info)}" ), f"The second element returned by `env.reset()` was not a dictionary, actual type: {type(info)}"
def check_space_limit(space: spaces.Space, space_type: str) -> None: def check_space_limit(space, space_type: str):
"""Check the space limit for only the Box space as a test that only runs as part of `check_env`.""" """Check the space limit for only the Box space as a test that only runs as part of `check_env`."""
if isinstance(space, spaces.Box): if isinstance(space, spaces.Box):
if np.any(np.equal(space.low, -np.inf)): if np.any(np.equal(space.low, -np.inf)):
@@ -256,18 +256,19 @@ def check_space_limit(space: spaces.Space, space_type: str) -> None:
check_space_limit(subspace, space_type)
def check_env(env: gym.Env, warn: bool = None, skip_render_check: bool = False):
"""Check that an environment follows Gymnasium's API.
.. py:currentmodule:: gymnasium.Env
To ensure that an environment is implemented "correctly", ``check_env`` checks that the :attr:`observation_space` and :attr:`action_space` are correct.
Furthermore, the function will call the :meth:`reset`, :meth:`step` and :meth:`render` functions with a variety of values.
We highly recommend calling this function after an environment is constructed and within a project's continuous integration to keep the environment up to date with Gymnasium's API.
Args:
env: The Gym environment that will be checked
warn: Ignored, previously silenced particular warnings
skip_render_check: Whether to skip the checks for the render method. True by default (useful for the CI)
"""
if warn is not None:

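The usage recommended by the new docstring, e.g. inside a project's test suite (a sketch):

import gymnasium as gym
from gymnasium.utils.env_checker import check_env

env = gym.make("CartPole-v1").unwrapped
check_env(env, skip_render_check=True)  # raises or warns if the API contract is violated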
View File

@@ -12,6 +12,8 @@ __all__ = [
"env_render_passive_checker", "env_render_passive_checker",
"env_reset_passive_checker", "env_reset_passive_checker",
"env_step_passive_checker", "env_step_passive_checker",
"check_action_space",
"check_observation_space",
] ]

View File

@@ -1,6 +1,8 @@
"""Utilities of visualising an environment.""" """Utilities of visualising an environment."""
from __future__ import annotations
from collections import deque from collections import deque
from typing import Callable, Dict, List, Optional, Tuple, Union from typing import Callable, List
import numpy as np import numpy as np
@@ -40,8 +42,8 @@ class PlayableGame:
def __init__(
self,
env: Env,
keys_to_action: dict[tuple[int, ...], int] | None = None,
zoom: float | None = None,
):
"""Wraps an environment with a dictionary of keyboard buttons to action and whether to zoom in on the environment.
@@ -66,7 +68,7 @@ class PlayableGame:
self.running = True
def _get_relevant_keys(
self, keys_to_action: dict[tuple[int], int] | None = None
) -> set:
if keys_to_action is None:
if hasattr(self.env, "get_keys_to_action"):
@@ -83,7 +85,7 @@ class PlayableGame:
relevant_keys = set(sum((list(k) for k in keys_to_action.keys()), []))
return relevant_keys
def _get_video_size(self, zoom: float | None = None) -> tuple[int, int]:
rendered = self.env.render()
if isinstance(rendered, List):
rendered = rendered[-1] rendered = rendered[-1]
@@ -123,7 +125,7 @@ class PlayableGame:
def display_arr( def display_arr(
screen: Surface, arr: np.ndarray, video_size: Tuple[int, int], transpose: bool screen: Surface, arr: np.ndarray, video_size: tuple[int, int], transpose: bool
): ):
"""Displays a numpy array on screen. """Displays a numpy array on screen.
@@ -147,15 +149,15 @@ def display_arr(
def play( def play(
env: Env, env: Env,
transpose: Optional[bool] = True, transpose: bool | None = True,
fps: Optional[int] = None, fps: int | None = None,
zoom: Optional[float] = None, zoom: float | None = None,
callback: Optional[Callable] = None, callback: Callable | None = None,
keys_to_action: Optional[Dict[Union[Tuple[Union[str, int]], str], ActType]] = None, keys_to_action: dict[tuple[str | int] | str, ActType] | None = None,
seed: Optional[int] = None, seed: int | None = None,
noop: ActType = 0, noop: ActType = 0,
): ):
"""Allows one to play the game using keyboard. """Allows the user to play the environment using a keyboard.
Args: Args:
env: Environment to use for playing. env: Environment to use for playing.
@@ -164,13 +166,14 @@ def play(
``env.metadata["render_fps"]`` (or 30, if the environment does not specify "render_fps") is used.
zoom: Zoom the observation in, ``zoom`` amount, should be positive float
callback: If a callback is provided, it will be executed after every step. It takes the following input:
* obs_t: observation before performing action
* obs_tp1: observation after performing action
* action: action that was executed
* rew: reward that was received
* terminated: whether the environment is terminated or not
* truncated: whether the environment is truncated or not
* info: debug info
keys_to_action: Mapping from keys pressed to action performed.
Different formats are supported: Key combinations can either be expressed as a tuple of unicode code
points of the keys, as a tuple of characters, or as a string where each character of the string represents
@@ -205,9 +208,9 @@ def play(
noop: The action used when no key input has been entered, or the entered key combination is unknown.
Example:
>>> import gymnasium as gym
>>> from gymnasium.utils.play import play
>>> play(gym.make("CarRacing-v2", render_mode="rgb_array"),  # doctest: +SKIP
...     keys_to_action={
...         "w": np.array([0, 0.7, 0]),
...         "a": np.array([-1, 0, 0]),
...         "s": np.array([0, 0, 1]),
@@ -216,17 +219,18 @@ def play(
... "dw": np.array([1, 0.7, 0]), ... "dw": np.array([1, 0.7, 0]),
... "ds": np.array([1, 0, 1]), ... "ds": np.array([1, 0, 1]),
... "as": np.array([-1, 0, 1]), ... "as": np.array([-1, 0, 1]),
... }, noop=np.array([0,0,0])) ... },
... noop=np.array([0, 0, 0])
... )
Above code works also if the environment is wrapped, so it's particularly useful in Above code works also if the environment is wrapped, so it's particularly useful in
verifying that the frame-level preprocessing does not render the game verifying that the frame-level preprocessing does not render the game
unplayable. unplayable.
If you wish to plot real time statistics as you play, you can use If you wish to plot real time statistics as you play, you can use
:class:`gym.utils.play.PlayPlot`. Here's a sample code for plotting the reward :class:`PlayPlot`. Here's a sample code for plotting the reward
for last 150 steps. for last 150 steps.
>>> import gymnasium as gym
>>> from gymnasium.utils.play import PlayPlot, play >>> from gymnasium.utils.play import PlayPlot, play
>>> def callback(obs_t, obs_tp1, action, rew, terminated, truncated, info): >>> def callback(obs_t, obs_tp1, action, rew, terminated, truncated, info):
... return [rew,] ... return [rew,]
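For reference, a runnable sketch wiring the callback above into :class:`PlayPlot` and handing it to ``play``; the ``CarRacing-v2`` id and keybindings come from the docstring, and Box2D extras are assumed installed:

    import gymnasium as gym
    import numpy as np
    from gymnasium.utils.play import PlayPlot, play

    def callback(obs_t, obs_tp1, action, rew, terminated, truncated, info):
        # plot a single statistic: the per-step reward
        return [rew]

    plotter = PlayPlot(callback, horizon_timesteps=150, plot_names=["reward"])
    play(
        gym.make("CarRacing-v2", render_mode="rgb_array"),
        keys_to_action={"w": np.array([0, 0.7, 0])},
        noop=np.array([0, 0, 0]),
        callback=plotter.callback,
    )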
@@ -321,7 +325,7 @@ class PlayPlot:
""" """
def __init__( def __init__(
self, callback: Callable, horizon_timesteps: int, plot_names: List[str] self, callback: Callable, horizon_timesteps: int, plot_names: list[str]
): ):
"""Constructor of :class:`PlayPlot`. """Constructor of :class:`PlayPlot`.
@@ -355,7 +359,7 @@ class PlayPlot:
for axis, name in zip(self.ax, plot_names): for axis, name in zip(self.ax, plot_names):
axis.set_title(name) axis.set_title(name)
self.t = 0 self.t = 0
self.cur_plot: List[Optional[plt.Axes]] = [None for _ in range(num_plots)] self.cur_plot: list[plt.Axes | None] = [None for _ in range(num_plots)]
self.data = [deque(maxlen=horizon_timesteps) for _ in range(num_plots)] self.data = [deque(maxlen=horizon_timesteps) for _ in range(num_plots)]
def callback( def callback(


@@ -15,9 +15,9 @@ except ImportError as e:
def capped_cubic_video_schedule(episode_id: int) -> bool: def capped_cubic_video_schedule(episode_id: int) -> bool:
"""The default episode trigger. r"""The default episode trigger.
This function will trigger recordings at the episode indices 0, 1, 4, 8, 27, ..., :math:`k^3`, ..., 729, 1000, 2000, 3000, ... This function will trigger recordings at the episode indices :math:`\{0, 1, 4, 8, 27, ..., k^3, ..., 729, 1000, 2000, 3000, ...\}`
Args: Args:
episode_id: The episode number episode_id: The episode number
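A plausible implementation of the :math:`k^3` trigger named above (the exact rounding in the merged code may differ; treat this as a sketch):

    def capped_cubic_video_schedule(episode_id: int) -> bool:
        # record at cube indices 0, 1, 8, 27, ... until 1000, then every 1000 episodes
        if episode_id < 1000:
            return round(episode_id ** (1.0 / 3)) ** 3 == episode_id
        return episode_id % 1000 == 0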


@@ -1,22 +1,29 @@
"""Set of random number generator functions: seeding, generator, hashing seeds.""" """Set of random number generator functions: seeding, generator, hashing seeds."""
from typing import Any, Optional, Tuple from __future__ import annotations
import numpy as np import numpy as np
from gymnasium import error from gymnasium import error
def np_random(seed: Optional[int] = None) -> Tuple[np.random.Generator, Any]: def np_random(seed: int | None = None) -> tuple[np.random.Generator, int]:
"""Generates a random number generator from the seed and returns the Generator and seed. """Returns a NumPy random number generator (RNG) along with seed value from the inputted seed.
If ``seed`` is ``None`` then a **random** seed will be generated as the RNG's initial seed.
This randomly selected seed is returned as the second value of the tuple.
.. py:currentmodule:: gymnasium.Env
This function is called in :meth:`reset` to reset an environment's initial RNG.
Args: Args:
seed: The seed used to create the generator seed: The seed used to create the generator
Returns: Returns:
The generator and resulting seed A NumPy-based random number generator and the seed it was created with
Raises: Raises:
Error: Seed must be a non-negative integer or omitted Error: Seed must be a non-negative integer
""" """
if seed is not None and not (isinstance(seed, int) and 0 <= seed): if seed is not None and not (isinstance(seed, int) and 0 <= seed):
if isinstance(seed, int) is False: if isinstance(seed, int) is False:
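A usage sketch for the seeding helper documented above; the return values follow the signature shown (a ``np.random.Generator`` plus the seed that was actually used):

    from gymnasium.utils.seeding import np_random

    rng, used_seed = np_random(42)       # deterministic generator; used_seed == 42
    x = rng.uniform(0.0, 1.0)            # standard numpy Generator API
    rng2, random_seed = np_random(None)  # a random seed is drawn and returned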


@@ -1,4 +1,6 @@
"""Contains methods for step compatibility, from old-to-new and new-to-old API.""" """Contains methods for step compatibility, from old-to-new and new-to-old API."""
from __future__ import annotations
from typing import SupportsFloat, Tuple, Union from typing import SupportsFloat, Tuple, Union
import numpy as np import numpy as np
@@ -23,13 +25,15 @@ TerminatedTruncatedStepType = Tuple[
def convert_to_terminated_truncated_step_api( def convert_to_terminated_truncated_step_api(
step_returns: Union[DoneStepType, TerminatedTruncatedStepType], is_vector_env=False step_returns: DoneStepType | TerminatedTruncatedStepType, is_vector_env=False
) -> TerminatedTruncatedStepType: ) -> TerminatedTruncatedStepType:
"""Function to transform step returns to new step API irrespective of input API. """Function to transform step returns to new step API irrespective of input API.
.. py:currentmodule:: gymnasium.Env
Args: Args:
step_returns (tuple): Items returned by step(). Can be (obs, rew, done, info) or (obs, rew, terminated, truncated, info) step_returns (tuple): Items returned by :meth:`step`. Can be ``(obs, rew, done, info)`` or ``(obs, rew, terminated, truncated, info)``
is_vector_env (bool): Whether the step_returns are from a vector environment is_vector_env (bool): Whether the ``step_returns`` are from a vector environment
""" """
if len(step_returns) == 5: if len(step_returns) == 5:
return step_returns return step_returns
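To illustrate the old-to-new conversion above, a hedged sketch; the ``TimeLimit.truncated`` info key is the conventional carrier of truncation in the old API and is assumed here:

    from gymnasium.utils.step_api_compatibility import (
        convert_to_terminated_truncated_step_api,
    )

    obs = [0.0]  # dummy observation
    old_step = (obs, 1.0, True, {"TimeLimit.truncated": False})  # (obs, rew, done, info)
    obs, reward, terminated, truncated, info = convert_to_terminated_truncated_step_api(old_step)
    # done=True with TimeLimit.truncated=False maps to terminated=True, truncated=False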
@@ -75,14 +79,16 @@ def convert_to_terminated_truncated_step_api(
def convert_to_done_step_api( def convert_to_done_step_api(
step_returns: Union[TerminatedTruncatedStepType, DoneStepType], step_returns: TerminatedTruncatedStepType | DoneStepType,
is_vector_env: bool = False, is_vector_env: bool = False,
) -> DoneStepType: ) -> DoneStepType:
"""Function to transform step returns to old step API irrespective of input API. """Function to transform step returns to old step API irrespective of input API.
.. py:currentmodule:: gymnasium.Env
Args: Args:
step_returns (tuple): Items returned by step(). Can be (obs, rew, done, info) or (obs, rew, terminated, truncated, info) step_returns (tuple): Items returned by :meth:`step`. Can be ``(obs, rew, done, info)`` or ``(obs, rew, terminated, truncated, info)``
is_vector_env (bool): Whether the step_returns are from a vector environment is_vector_env (bool): Whether the ``step_returns`` are from a vector environment
""" """
if len(step_returns) == 4: if len(step_returns) == 4:
return step_returns return step_returns
@@ -130,38 +136,41 @@ def convert_to_done_step_api(
def step_api_compatibility( def step_api_compatibility(
step_returns: Union[TerminatedTruncatedStepType, DoneStepType], step_returns: TerminatedTruncatedStepType | DoneStepType,
output_truncation_bool: bool = True, output_truncation_bool: bool = True,
is_vector_env: bool = False, is_vector_env: bool = False,
) -> Union[TerminatedTruncatedStepType, DoneStepType]: ) -> TerminatedTruncatedStepType | DoneStepType:
"""Function to transform step returns to the API specified by `output_truncation_bool` bool. """Function to transform step returns to the API specified by ``output_truncation_bool``.
Done (old) step API refers to step() method returning (observation, reward, done, info) .. py:currentmodule:: gymnasium.Env
Terminated Truncated (new) step API refers to step() method returning (observation, reward, terminated, truncated, info)
Done (old) step API refers to :meth:`step` method returning ``(observation, reward, done, info)``
Terminated Truncated (new) step API refers to :meth:`step` method returning ``(observation, reward, terminated, truncated, info)``
(Refer to docs for details on the API change) (Refer to docs for details on the API change)
Args: Args:
step_returns (tuple): Items returned by step(). Can be (obs, rew, done, info) or (obs, rew, terminated, truncated, info) step_returns (tuple): Items returned by :meth:`step`. Can be ``(obs, rew, done, info)`` or ``(obs, rew, terminated, truncated, info)``
output_truncation_bool (bool): Whether the output should return two booleans (new API) or one (old) (True by default) output_truncation_bool (bool): Whether the output should return two booleans (new API) or one (old) (``True`` by default)
is_vector_env (bool): Whether the step_returns are from a vector environment is_vector_env (bool): Whether the ``step_returns`` are from a vector environment
Returns: Returns:
step_returns (tuple): Depending on `output_truncation_bool` bool, it can return `(obs, rew, done, info)` or `(obs, rew, terminated, truncated, info)` step_returns (tuple): Depending on ``output_truncation_bool``, it can return ``(obs, rew, done, info)`` or ``(obs, rew, terminated, truncated, info)``
Example: Example:
This function can be used to ensure compatibility in step interfaces with conflicting API. Eg. if env is written in old API, This function can be used to ensure compatibility in step interfaces with conflicting API. E.g. if env is written in old API,
wrapper is written in new API, and the final step output is desired to be in old API. wrapper is written in new API, and the final step output is desired to be in old API.
>>> import gymnasium as gym >>> import gymnasium as gym
>>> env = gym.make("CartPole-v0") >>> env = gym.make("CartPole-v0")
>>> _ = env.reset() >>> _, _ = env.reset()
>>> obs, rewards, done, info = step_api_compatibility(env.step(0), output_truncation_bool=False) >>> obs, reward, done, info = step_api_compatibility(env.step(0), output_truncation_bool=False)
>>> obs, rewards, terminated, truncated, info = step_api_compatibility(env.step(0), output_truncation_bool=True) >>> obs, reward, terminated, truncated, info = step_api_compatibility(env.step(0), output_truncation_bool=True)
>>> vec_env = gym.vector.make("CartPole-v0") >>> vec_env = gym.make_vec("CartPole-v0", vectorization_mode="sync")
>>> _ = vec_env.reset() >>> _, _ = vec_env.reset()
>>> obs, rewards, dones, infos = step_api_compatibility(vec_env.step([0]), is_vector_env=True, output_truncation_bool=False) >>> obs, rewards, dones, infos = step_api_compatibility(vec_env.step([0]), is_vector_env=True, output_truncation_bool=False)
>>> obs, rewards, terminated, truncated, info = step_api_compatibility(vec_env.step([0]), is_vector_env=True, output_truncation_bool=True) >>> obs, rewards, terminations, truncations, infos = step_api_compatibility(vec_env.step([0]), is_vector_env=True, output_truncation_bool=True)
""" """
if output_truncation_bool: if output_truncation_bool:
return convert_to_terminated_truncated_step_api(step_returns, is_vector_env) return convert_to_terminated_truncated_step_api(step_returns, is_vector_env)


@@ -1,85 +1,23 @@
"""Module for vector environments.""" """Experimental vector env API."""
from typing import Callable, Iterable, List, Optional, Union
import gymnasium as gym
from gymnasium.core import Env
from gymnasium.vector import utils from gymnasium.vector import utils
from gymnasium.vector.async_vector_env import AsyncVectorEnv from gymnasium.vector.async_vector_env import AsyncVectorEnv
from gymnasium.vector.sync_vector_env import SyncVectorEnv from gymnasium.vector.sync_vector_env import SyncVectorEnv
from gymnasium.vector.vector_env import VectorEnv, VectorEnvWrapper from gymnasium.vector.vector_env import (
VectorActionWrapper,
VectorEnv,
VectorObservationWrapper,
VectorRewardWrapper,
VectorWrapper,
)
__all__ = [ __all__ = [
"AsyncVectorEnv",
"SyncVectorEnv",
"VectorEnv", "VectorEnv",
"VectorEnvWrapper", "VectorWrapper",
"make", "VectorObservationWrapper",
"VectorActionWrapper",
"VectorRewardWrapper",
"SyncVectorEnv",
"AsyncVectorEnv",
"utils", "utils",
] ]
def make(
id: str,
num_envs: int = 1,
asynchronous: bool = True,
wrappers: Optional[Union[Callable[[Env], Env], List[Callable[[Env], Env]]]] = None,
disable_env_checker: Optional[bool] = None,
**kwargs,
) -> VectorEnv:
"""Create a vectorized environment from multiple copies of an environment, from its id.
Args:
id: The environment ID. This must be a valid ID from the registry.
num_envs: Number of copies of the environment.
asynchronous: If `True`, wraps the environments in an :class:`AsyncVectorEnv` (which uses `multiprocessing` to run the environments in parallel). If ``False``, wraps the environments in a :class:`SyncVectorEnv`.
wrappers: If not ``None``, then apply the wrappers to each internal environment during creation.
disable_env_checker: Whether to run the env checker for the first environment only. None will default to the environment spec `disable_env_checker` parameter
(that is by default False), otherwise will run according to this argument (True = not run, False = run)
**kwargs: Keywords arguments applied during `gym.make`
Returns:
The vectorized environment.
Example:
>>> import gymnasium as gym
>>> env = gym.vector.make('CartPole-v1', num_envs=3)
>>> env.reset(seed=42)
(array([[ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 ],
[ 0.01522993, -0.04562247, -0.04799704, 0.03392126],
[-0.03774345, -0.02418869, -0.00942293, 0.0469184 ]],
dtype=float32), {})
"""
gym.logger.warn(
"`gymnasium.vector.make(...)` is deprecated and will be replaced by `gymnasium.make_vec(...)` in v1.0"
)
def create_env(env_num: int) -> Callable[[], Env]:
"""Creates an environment that can enable or disable the environment checker."""
# If the env_num > 0 then disable the environment checker otherwise use the parameter
_disable_env_checker = True if env_num > 0 else disable_env_checker
def _make_env() -> Env:
env = gym.envs.registration.make(
id,
disable_env_checker=_disable_env_checker,
**kwargs,
)
if wrappers is not None:
if callable(wrappers):
env = wrappers(env)
elif isinstance(wrappers, Iterable) and all(
[callable(w) for w in wrappers]
):
for wrapper in wrappers:
env = wrapper(env)
else:
raise NotImplementedError
return env
return _make_env
env_fns = [
create_env(disable_env_checker or env_num > 0) for env_num in range(num_envs)
]
return AsyncVectorEnv(env_fns) if asynchronous else SyncVectorEnv(env_fns)
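The replacement named in the deprecation warning above, as a short sketch:

    import gymnasium as gym

    # gym.make_vec replaces gym.vector.make; vectorization_mode selects sync vs async
    envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="async")
    obs, infos = envs.reset(seed=42)
    envs.close()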


@@ -1,17 +1,19 @@
"""An async vector environment.""" """An async vector environment."""
import multiprocessing as mp from __future__ import annotations
import multiprocessing
import sys import sys
import time import time
from copy import deepcopy from copy import deepcopy
from enum import Enum from enum import Enum
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union from multiprocessing import Queue
from multiprocessing.connection import Connection
from typing import Any, Callable, Sequence
import numpy as np import numpy as np
from numpy.typing import NDArray
import gymnasium as gym
from gymnasium import logger from gymnasium import logger
from gymnasium.core import Env, ObsType from gymnasium.core import ActType, Env, ObsType, RenderFrame
from gymnasium.error import ( from gymnasium.error import (
AlreadyPendingCallError, AlreadyPendingCallError,
ClosedEnvironmentError, ClosedEnvironmentError,
@@ -20,6 +22,7 @@ from gymnasium.error import (
) )
from gymnasium.vector.utils import ( from gymnasium.vector.utils import (
CloudpickleWrapper, CloudpickleWrapper,
batch_space,
clear_mpi_env_vars, clear_mpi_env_vars,
concatenate, concatenate,
create_empty_array, create_empty_array,
@@ -28,13 +31,15 @@ from gymnasium.vector.utils import (
read_from_shared_memory, read_from_shared_memory,
write_to_shared_memory, write_to_shared_memory,
) )
from gymnasium.vector.vector_env import VectorEnv from gymnasium.vector.vector_env import ArrayType, VectorEnv
__all__ = ["AsyncVectorEnv"] __all__ = ["AsyncVectorEnv", "AsyncState"]
class AsyncState(Enum): class AsyncState(Enum):
"""The AsyncVectorEnv possible states given the different actions."""
DEFAULT = "default" DEFAULT = "default"
WAITING_RESET = "reset" WAITING_RESET = "reset"
WAITING_STEP = "step" WAITING_STEP = "step"
@@ -48,39 +53,57 @@ class AsyncVectorEnv(VectorEnv):
Example: Example:
>>> import gymnasium as gym >>> import gymnasium as gym
>>> env = gym.vector.AsyncVectorEnv([ >>> envs = gym.make_vec("Pendulum-v1", num_envs=2, vectorization_mode="async")
>>> envs
AsyncVectorEnv(Pendulum-v1, num_envs=2)
>>> envs = gym.vector.AsyncVectorEnv([
... lambda: gym.make("Pendulum-v1", g=9.81), ... lambda: gym.make("Pendulum-v1", g=9.81),
... lambda: gym.make("Pendulum-v1", g=1.62) ... lambda: gym.make("Pendulum-v1", g=1.62)
... ]) ... ])
>>> env.reset(seed=42) >>> envs
(array([[-0.14995256, 0.9886932 , -0.12224312], AsyncVectorEnv(num_envs=2)
[ 0.5760367 , 0.8174238 , -0.91244936]], dtype=float32), {}) >>> observations, infos = envs.reset(seed=42)
>>> observations
array([[-0.14995256, 0.9886932 , -0.12224312],
[ 0.5760367 , 0.8174238 , -0.91244936]], dtype=float32)
>>> infos
{}
>>> _ = envs.action_space.seed(123)
>>> observations, rewards, terminations, truncations, infos = envs.step(envs.action_space.sample())
>>> observations
array([[-0.1851753 , 0.98270553, 0.714599 ],
[ 0.6193494 , 0.7851154 , -1.0808398 ]], dtype=float32)
>>> rewards
array([-2.96495728, -1.00214607])
>>> terminations
array([False, False])
>>> truncations
array([False, False])
>>> infos
{}
""" """
def __init__( def __init__(
self, self,
env_fns: Sequence[Callable[[], Env]], env_fns: Sequence[Callable[[], Env]],
observation_space: Optional[gym.Space] = None,
action_space: Optional[gym.Space] = None,
shared_memory: bool = True, shared_memory: bool = True,
copy: bool = True, copy: bool = True,
context: Optional[str] = None, context: str | None = None,
daemon: bool = True, daemon: bool = True,
worker: Optional[Callable] = None, worker: Callable[
[int, Callable[[], Env], Connection, Connection, bool, Queue], None
]
| None = None,
): ):
"""Vectorized environment that runs multiple environments in parallel. """Vectorized environment that runs multiple environments in parallel.
Args: Args:
env_fns: Functions that create the environments. env_fns: Functions that create the environments.
observation_space: Observation space of a single environment. If ``None``,
then the observation space of the first environment is taken.
action_space: Action space of a single environment. If ``None``,
then the action space of the first environment is taken.
shared_memory: If ``True``, then the observations from the worker processes are communicated back through shared_memory: If ``True``, then the observations from the worker processes are communicated back through
shared variables. This can improve the efficiency if the observations are large (e.g. images). shared variables. This can improve the efficiency if the observations are large (e.g. images).
copy: If ``True``, then the :meth:`~AsyncVectorEnv.reset` and :meth:`~AsyncVectorEnv.step` methods copy: If ``True``, then the :meth:`AsyncVectorEnv.reset` and :meth:`AsyncVectorEnv.step` methods
return a copy of the observations. return a copy of the observations.
context: Context for `multiprocessing`_. If ``None``, then the default context is used. context: Context for `multiprocessing`. If ``None``, then the default context is used.
daemon: If ``True``, then subprocesses have ``daemon`` flag turned on; that is, they will quit if daemon: If ``True``, then subprocesses have ``daemon`` flag turned on; that is, they will quit if
the head process quits. However, ``daemon=True`` prevents subprocesses from spawning children, the head process quits. However, ``daemon=True`` prevents subprocesses from spawning children,
so for some environments you may want to have it set to ``False``. so for some environments you may want to have it set to ``False``.
@@ -98,24 +121,33 @@ class AsyncVectorEnv(VectorEnv):
ValueError: If observation_space is a custom space (i.e. not a default space in Gym, ValueError: If observation_space is a custom space (i.e. not a default space in Gym,
such as gymnasium.spaces.Box, gymnasium.spaces.Discrete, or gymnasium.spaces.Dict) and shared_memory is True. such as gymnasium.spaces.Box, gymnasium.spaces.Discrete, or gymnasium.spaces.Dict) and shared_memory is True.
""" """
ctx = mp.get_context(context)
self.env_fns = env_fns self.env_fns = env_fns
self.shared_memory = shared_memory self.shared_memory = shared_memory
self.copy = copy self.copy = copy
dummy_env = env_fns[0]()
self.metadata = dummy_env.metadata
if (observation_space is None) or (action_space is None): self.num_envs = len(env_fns)
observation_space = observation_space or dummy_env.observation_space
action_space = action_space or dummy_env.action_space # This would be nice to get rid of, but without it there's a deadlock between shared memory and pipes
# Create a dummy environment to gather the metadata and observation / action space of the environment
dummy_env = env_fns[0]()
# As we support `make_vec(spec)`, we can't include a `spec = dummy_env.spec` as this doesn't guarantee we can actually recreate the vector env.
self.metadata = dummy_env.metadata
self.render_mode = dummy_env.render_mode
self.single_observation_space = dummy_env.observation_space
self.single_action_space = dummy_env.action_space
self.observation_space = batch_space(
self.single_observation_space, self.num_envs
)
self.action_space = batch_space(self.single_action_space, self.num_envs)
dummy_env.close() dummy_env.close()
del dummy_env del dummy_env
super().__init__(
num_envs=len(env_fns),
observation_space=observation_space,
action_space=action_space,
)
# Generate the multiprocessing context for the observation buffer
ctx = multiprocessing.get_context(context)
if self.shared_memory: if self.shared_memory:
try: try:
_obs_buffer = create_shared_memory( _obs_buffer = create_shared_memory(
@@ -126,12 +158,9 @@ class AsyncVectorEnv(VectorEnv):
) )
except CustomSpaceError as e: except CustomSpaceError as e:
raise ValueError( raise ValueError(
"Using `shared_memory=True` in `AsyncVectorEnv` " "Using `shared_memory=True` in `AsyncVectorEnv` is incompatible with non-standard Gymnasium observation spaces (i.e. custom spaces inheriting from `gymnasium.Space`), "
"is incompatible with non-standard Gymnasium observation spaces " "and is only compatible with default Gymnasium spaces (e.g. `Box`, `Tuple`, `Dict`) for batching. "
"(i.e. custom spaces inheriting from `gymnasium.Space`), and is " "Set `shared_memory=False` if you use custom observation spaces."
"only compatible with default Gymnasium spaces (e.g. `Box`, "
"`Tuple`, `Dict`) for batching. Set `shared_memory=False` "
"if you use custom observation spaces."
) from e ) from e
else: else:
_obs_buffer = None _obs_buffer = None
@@ -141,8 +170,7 @@ class AsyncVectorEnv(VectorEnv):
self.parent_pipes, self.processes = [], [] self.parent_pipes, self.processes = [], []
self.error_queue = ctx.Queue() self.error_queue = ctx.Queue()
target = _worker_shared_memory if self.shared_memory else _worker target = worker or _async_worker
target = worker or target
with clear_mpi_env_vars(): with clear_mpi_env_vars():
for idx, env_fn in enumerate(self.env_fns): for idx, env_fn in enumerate(self.env_fns):
parent_pipe, child_pipe = ctx.Pipe() parent_pipe, child_pipe = ctx.Pipe()
@@ -169,10 +197,28 @@ class AsyncVectorEnv(VectorEnv):
self._state = AsyncState.DEFAULT self._state = AsyncState.DEFAULT
self._check_spaces() self._check_spaces()
def reset(
self,
*,
seed: int | list[int] | None = None,
options: dict[str, Any] | None = None,
) -> tuple[ObsType, dict[str, Any]]:
"""Resets all sub-environments in parallel and return a batch of concatenated observations and info.
Args:
seed: The environment reset seed(s)
options: Reset options passed to each sub-environment
Returns:
A batch of observations and info from the vectorized environment.
"""
self.reset_async(seed=seed, options=options)
return self.reset_wait()
def reset_async( def reset_async(
self, self,
seed: Optional[Union[int, List[int]]] = None, seed: int | list[int] | None = None,
options: Optional[dict] = None, options: dict | None = None,
): ):
"""Send calls to the :obj:`reset` methods of the sub-environments. """Send calls to the :obj:`reset` methods of the sub-environments.
@@ -192,38 +238,29 @@ class AsyncVectorEnv(VectorEnv):
if seed is None: if seed is None:
seed = [None for _ in range(self.num_envs)] seed = [None for _ in range(self.num_envs)]
if isinstance(seed, int): elif isinstance(seed, int):
seed = [seed + i for i in range(self.num_envs)] seed = [seed + i for i in range(self.num_envs)]
assert len(seed) == self.num_envs assert len(seed) == self.num_envs
if self._state != AsyncState.DEFAULT: if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError( raise AlreadyPendingCallError(
f"Calling `reset_async` while waiting for a pending call to `{self._state.value}` to complete", f"Calling `reset_async` while waiting for a pending call to `{self._state.value}` to complete",
self._state.value, str(self._state.value),
) )
for pipe, single_seed in zip(self.parent_pipes, seed): for pipe, env_seed in zip(self.parent_pipes, seed):
single_kwargs = {} env_kwargs = {"seed": env_seed, "options": options}
if single_seed is not None: pipe.send(("reset", env_kwargs))
single_kwargs["seed"] = single_seed
if options is not None:
single_kwargs["options"] = options
pipe.send(("reset", single_kwargs))
self._state = AsyncState.WAITING_RESET self._state = AsyncState.WAITING_RESET
def reset_wait( def reset_wait(
self, self,
timeout: Optional[Union[int, float]] = None, timeout: int | float | None = None,
seed: Optional[int] = None, ) -> tuple[ObsType, dict[str, Any]]:
options: Optional[dict] = None,
) -> Union[ObsType, Tuple[ObsType, dict]]:
"""Waits for the calls triggered by :meth:`reset_async` to finish and returns the results. """Waits for the calls triggered by :meth:`reset_async` to finish and returns the results.
Args: Args:
timeout: Number of seconds before the call to `reset_wait` times out. If `None`, the call to `reset_wait` never times out. timeout: Number of seconds before the call to ``reset_wait`` times out. If `None`, the call to ``reset_wait`` never times out.
seed: ignored
options: ignored
Returns: Returns:
A tuple of batched observations and list of dictionaries A tuple of batched observations and list of dictionaries
@@ -240,15 +277,14 @@ class AsyncVectorEnv(VectorEnv):
AsyncState.WAITING_RESET.value, AsyncState.WAITING_RESET.value,
) )
if not self._poll(timeout): if not self._poll_pipe_envs(timeout):
self._state = AsyncState.DEFAULT self._state = AsyncState.DEFAULT
raise mp.TimeoutError( raise multiprocessing.TimeoutError(
f"The call to `reset_wait` has timed out after {timeout} second(s)." f"The call to `reset_wait` has timed out after {timeout} second(s)."
) )
results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes]) results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
self._raise_if_errors(successes) self._raise_if_errors(successes)
self._state = AsyncState.DEFAULT
infos = {} infos = {}
results, info_data = zip(*results) results, info_data = zip(*results)
@@ -260,13 +296,28 @@ class AsyncVectorEnv(VectorEnv):
self.single_observation_space, results, self.observations self.single_observation_space, results, self.observations
) )
self._state = AsyncState.DEFAULT
return (deepcopy(self.observations) if self.copy else self.observations), infos return (deepcopy(self.observations) if self.copy else self.observations), infos
def step_async(self, actions: np.ndarray): def step(
"""Send the calls to :obj:`step` to each sub-environment. self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
"""Take an action for each parallel environment.
Args: Args:
actions: Batch of actions. element of :attr:`~VectorEnv.action_space` actions: A batch of actions, an element of :attr:`action_space`.
Returns:
Batch of (observations, rewards, terminations, truncations, infos)
"""
self.step_async(actions)
return self.step_wait()
def step_async(self, actions: np.ndarray):
"""Send the calls to :meth:`Env.step` to each sub-environment.
Args:
actions: Batch of actions, an element of :attr:`VectorEnv.action_space`
Raises: Raises:
ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called). ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
@@ -279,17 +330,17 @@ class AsyncVectorEnv(VectorEnv):
if self._state != AsyncState.DEFAULT: if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError( raise AlreadyPendingCallError(
f"Calling `step_async` while waiting for a pending call to `{self._state.value}` to complete.", f"Calling `step_async` while waiting for a pending call to `{self._state.value}` to complete.",
self._state.value, str(self._state.value),
) )
actions = iterate(self.action_space, actions) iter_actions = iterate(self.action_space, actions)
for pipe, action in zip(self.parent_pipes, actions): for pipe, action in zip(self.parent_pipes, iter_actions):
pipe.send(("step", action)) pipe.send(("step", action))
self._state = AsyncState.WAITING_STEP self._state = AsyncState.WAITING_STEP
def step_wait( def step_wait(
self, timeout: Optional[Union[int, float]] = None self, timeout: int | float | None = None
) -> Tuple[Any, NDArray[Any], NDArray[Any], NDArray[Any], dict]: ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, dict]:
"""Wait for the calls to :obj:`step` in each sub-environment to finish. """Wait for the calls to :obj:`step` in each sub-environment to finish.
Args: Args:
@@ -310,44 +361,61 @@ class AsyncVectorEnv(VectorEnv):
AsyncState.WAITING_STEP.value, AsyncState.WAITING_STEP.value,
) )
if not self._poll(timeout): if not self._poll_pipe_envs(timeout):
self._state = AsyncState.DEFAULT self._state = AsyncState.DEFAULT
raise mp.TimeoutError( raise multiprocessing.TimeoutError(
f"The call to `step_wait` has timed out after {timeout} second(s)." f"The call to `step_wait` has timed out after {timeout} second(s)."
) )
observations_list, rewards, terminateds, truncateds, infos = [], [], [], [], {} observations, rewards, terminations, truncations, infos = [], [], [], [], {}
successes = [] successes = []
for i, pipe in enumerate(self.parent_pipes): for env_idx, pipe in enumerate(self.parent_pipes):
result, success = pipe.recv() env_step_return, success = pipe.recv()
successes.append(success) successes.append(success)
if success: if success:
obs, rew, terminated, truncated, info = result observations.append(env_step_return[0])
rewards.append(env_step_return[1])
observations_list.append(obs) terminations.append(env_step_return[2])
rewards.append(rew) truncations.append(env_step_return[3])
terminateds.append(terminated) infos = self._add_info(infos, env_step_return[4], env_idx)
truncateds.append(truncated)
infos = self._add_info(infos, info, i)
self._raise_if_errors(successes) self._raise_if_errors(successes)
self._state = AsyncState.DEFAULT
if not self.shared_memory: if not self.shared_memory:
self.observations = concatenate( self.observations = concatenate(
self.single_observation_space, self.single_observation_space,
observations_list, observations,
self.observations, self.observations,
) )
self._state = AsyncState.DEFAULT
return ( return (
deepcopy(self.observations) if self.copy else self.observations, deepcopy(self.observations) if self.copy else self.observations,
np.array(rewards), np.array(rewards, dtype=np.float64),
np.array(terminateds, dtype=np.bool_), np.array(terminations, dtype=np.bool_),
np.array(truncateds, dtype=np.bool_), np.array(truncations, dtype=np.bool_),
infos, infos,
) )
def call(self, name: str, *args: Any, **kwargs: Any) -> tuple[Any, ...]:
"""Call a method from each parallel environment with args and kwargs.
Args:
name (str): Name of the method or property to call.
*args: Position arguments to apply to the method call.
**kwargs: Keyword arguments to apply to the method call.
Returns:
List of the results of the individual calls to the method or property for each environment.
"""
self.call_async(name, *args, **kwargs)
return self.call_wait()
def render(self) -> tuple[RenderFrame, ...] | None:
"""Returns a list of rendered frames from the environments."""
return self.call("render")
def call_async(self, name: str, *args, **kwargs): def call_async(self, name: str, *args, **kwargs):
"""Calls the method with name asynchronously and apply args and kwargs to the method. """Calls the method with name asynchronously and apply args and kwargs to the method.
@@ -363,28 +431,27 @@ class AsyncVectorEnv(VectorEnv):
self._assert_is_running() self._assert_is_running()
if self._state != AsyncState.DEFAULT: if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError( raise AlreadyPendingCallError(
"Calling `call_async` while waiting " f"Calling `call_async` while waiting for a pending call to `{self._state.value}` to complete.",
f"for a pending call to `{self._state.value}` to complete.", str(self._state.value),
self._state.value,
) )
for pipe in self.parent_pipes: for pipe in self.parent_pipes:
pipe.send(("_call", (name, args, kwargs))) pipe.send(("_call", (name, args, kwargs)))
self._state = AsyncState.WAITING_CALL self._state = AsyncState.WAITING_CALL
def call_wait(self, timeout: Optional[Union[int, float]] = None) -> list: def call_wait(self, timeout: int | float | None = None) -> tuple[Any, ...]:
"""Calls all parent pipes and waits for the results. """Calls all parent pipes and waits for the results.
Args: Args:
timeout: Number of seconds before the call to `call_wait` times out. timeout: Number of seconds before the call to :meth:`call_wait` times out.
If `None` (default), the call to `call_wait` never times out. If ``None`` (default), the call to :meth:`call_wait` never times out.
Returns: Returns:
List of the results of the individual calls to the method or property for each environment. List of the results of the individual calls to the method or property for each environment.
Raises: Raises:
NoAsyncCallError: Calling `call_wait` without any prior call to `call_async`. NoAsyncCallError: Calling :meth:`call_wait` without any prior call to :meth:`call_async`.
TimeoutError: The call to `call_wait` has timed out after timeout second(s). TimeoutError: The call to :meth:`call_wait` has timed out after timeout second(s).
""" """
self._assert_is_running() self._assert_is_running()
if self._state != AsyncState.WAITING_CALL: if self._state != AsyncState.WAITING_CALL:
@@ -393,9 +460,9 @@ class AsyncVectorEnv(VectorEnv):
AsyncState.WAITING_CALL.value, AsyncState.WAITING_CALL.value,
) )
if not self._poll(timeout): if not self._poll_pipe_envs(timeout):
self._state = AsyncState.DEFAULT self._state = AsyncState.DEFAULT
raise mp.TimeoutError( raise multiprocessing.TimeoutError(
f"The call to `call_wait` has timed out after {timeout} second(s)." f"The call to `call_wait` has timed out after {timeout} second(s)."
) )
@@ -405,7 +472,18 @@ class AsyncVectorEnv(VectorEnv):
return results return results
def set_attr(self, name: str, values: Union[list, tuple, object]): def get_attr(self, name: str):
"""Get a property from each parallel environment.
Args:
name (str): Name of the property to get from each individual environment.
Returns:
The property with the given name from each environment
"""
return self.call(name)
def set_attr(self, name: str, values: list[Any] | tuple[Any] | object):
"""Sets an attribute of the sub-environments. """Sets an attribute of the sub-environments.
Args: Args:
@@ -416,23 +494,21 @@ class AsyncVectorEnv(VectorEnv):
Raises: Raises:
ValueError: Values must be a list or tuple with length equal to the number of environments. ValueError: Values must be a list or tuple with length equal to the number of environments.
AlreadyPendingCallError: Calling `set_attr` while waiting for a pending call to complete. AlreadyPendingCallError: Calling :meth:`set_attr` while waiting for a pending call to complete.
""" """
self._assert_is_running() self._assert_is_running()
if not isinstance(values, (list, tuple)): if not isinstance(values, (list, tuple)):
values = [values for _ in range(self.num_envs)] values = [values for _ in range(self.num_envs)]
if len(values) != self.num_envs: if len(values) != self.num_envs:
raise ValueError( raise ValueError(
"Values must be a list or tuple with length equal to the " "Values must be a list or tuple with length equal to the number of environments. "
f"number of environments. Got `{len(values)}` values for " f"Got `{len(values)}` values for {self.num_envs} environments."
f"{self.num_envs} environments."
) )
if self._state != AsyncState.DEFAULT: if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError( raise AlreadyPendingCallError(
"Calling `set_attr` while waiting " f"Calling `set_attr` while waiting for a pending call to `{self._state.value}` to complete.",
f"for a pending call to `{self._state.value}` to complete.", str(self._state.value),
self._state.value,
) )
for pipe, value in zip(self.parent_pipes, values): for pipe, value in zip(self.parent_pipes, values):
@@ -440,9 +516,7 @@ class AsyncVectorEnv(VectorEnv):
_, successes = zip(*[pipe.recv() for pipe in self.parent_pipes]) _, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
self._raise_if_errors(successes) self._raise_if_errors(successes)
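A usage sketch for the ``get_attr`` / ``set_attr`` pair above (``g`` is Pendulum's gravity attribute, used here purely as an example):

    import gymnasium as gym

    envs = gym.make_vec("Pendulum-v1", num_envs=2, vectorization_mode="async")
    gravities = envs.get_attr("g")    # queries each sub-environment, e.g. (10.0, 10.0)
    envs.set_attr("g", [9.81, 1.62])  # a per-environment list, or one value broadcast to all
    envs.close()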
def close_extras( def close_extras(self, timeout: int | float | None = None, terminate: bool = False):
self, timeout: Optional[Union[int, float]] = None, terminate: bool = False
):
"""Close the environments & clean up the extra resources (processes and pipes). """Close the environments & clean up the extra resources (processes and pipes).
Args: Args:
@@ -462,7 +536,7 @@ class AsyncVectorEnv(VectorEnv):
) )
function = getattr(self, f"{self._state.value}_wait") function = getattr(self, f"{self._state.value}_wait")
function(timeout) function(timeout)
except mp.TimeoutError: except multiprocessing.TimeoutError:
terminate = True terminate = True
if terminate: if terminate:
@@ -483,14 +557,16 @@ class AsyncVectorEnv(VectorEnv):
for process in self.processes: for process in self.processes:
process.join() process.join()
def _poll(self, timeout=None): def _poll_pipe_envs(self, timeout: int | None = None):
self._assert_is_running() self._assert_is_running()
if timeout is None: if timeout is None:
return True return True
end_time = time.perf_counter() + timeout end_time = time.perf_counter() + timeout
delta = None
for pipe in self.parent_pipes: for pipe in self.parent_pipes:
delta = max(end_time - time.perf_counter(), 0) delta = max(end_time - time.perf_counter(), 0)
if pipe is None: if pipe is None:
return False return False
if pipe.closed or (not pipe.poll(delta)): if pipe.closed or (not pipe.poll(delta)):
@@ -500,22 +576,23 @@ class AsyncVectorEnv(VectorEnv):
def _check_spaces(self): def _check_spaces(self):
self._assert_is_running() self._assert_is_running()
spaces = (self.single_observation_space, self.single_action_space) spaces = (self.single_observation_space, self.single_action_space)
for pipe in self.parent_pipes: for pipe in self.parent_pipes:
pipe.send(("_check_spaces", spaces)) pipe.send(("_check_spaces", spaces))
results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes]) results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
self._raise_if_errors(successes) self._raise_if_errors(successes)
same_observation_spaces, same_action_spaces = zip(*results) same_observation_spaces, same_action_spaces = zip(*results)
if not all(same_observation_spaces): if not all(same_observation_spaces):
raise RuntimeError( raise RuntimeError(
"Some environments have an observation space different from " f"Some environments have an observation space different from `{self.single_observation_space}`. "
f"`{self.single_observation_space}`. In order to batch observations, " "In order to batch observations, the observation spaces from all environments must be equal."
"the observation spaces from all environments must be equal."
) )
if not all(same_action_spaces): if not all(same_action_spaces):
raise RuntimeError( raise RuntimeError(
"Some environments have an action space different from " f"Some environments have an action space different from `{self.single_action_space}`. "
f"`{self.single_action_space}`. In order to batch actions, the " "In order to batch actions, the action spaces from all environments must be equal."
"action spaces from all environments must be equal."
) )
def _assert_is_running(self): def _assert_is_running(self):
@@ -524,7 +601,7 @@ class AsyncVectorEnv(VectorEnv):
f"Trying to operate on `{type(self).__name__}`, after a call to `close()`." f"Trying to operate on `{type(self).__name__}`, after a call to `close()`."
) )
def _raise_if_errors(self, successes): def _raise_if_errors(self, successes: list[bool]):
if all(successes): if all(successes):
return return
@@ -532,10 +609,12 @@ class AsyncVectorEnv(VectorEnv):
assert num_errors > 0 assert num_errors > 0
for i in range(num_errors): for i in range(num_errors):
index, exctype, value = self.error_queue.get() index, exctype, value = self.error_queue.get()
logger.error( logger.error(
f"Received the following error from Worker-{index}: {exctype.__name__}: {value}" f"Received the following error from Worker-{index}: {exctype.__name__}: {value}"
) )
logger.error(f"Shutting down Worker-{index}.") logger.error(f"Shutting down Worker-{index}.")
self.parent_pipes[index].close() self.parent_pipes[index].close()
self.parent_pipes[index] = None self.parent_pipes[index] = None
@@ -549,17 +628,32 @@ class AsyncVectorEnv(VectorEnv):
self.close(terminate=True) self.close(terminate=True)
def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue): def _async_worker(
assert shared_memory is None index: int,
env_fn: callable,
pipe: Connection,
parent_pipe: Connection,
shared_memory: bool,
error_queue: Queue,
):
env = env_fn() env = env_fn()
observation_space = env.observation_space
action_space = env.action_space
parent_pipe.close() parent_pipe.close()
try: try:
while True: while True:
command, data = pipe.recv() command, data = pipe.recv()
if command == "reset": if command == "reset":
observation, info = env.reset(**data) observation, info = env.reset(**data)
if shared_memory:
write_to_shared_memory(
observation_space, index, observation, shared_memory
)
observation = None
pipe.send(((observation, info), True)) pipe.send(((observation, info), True))
elif command == "step": elif command == "step":
( (
observation, observation,
@@ -573,112 +667,43 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
observation, info = env.reset() observation, info = env.reset()
info["final_observation"] = old_observation info["final_observation"] = old_observation
info["final_info"] = old_info info["final_info"] = old_info
if shared_memory:
write_to_shared_memory(
observation_space, index, observation, shared_memory
)
observation = None
pipe.send(((observation, reward, terminated, truncated, info), True)) pipe.send(((observation, reward, terminated, truncated, info), True))
elif command == "seed":
env.seed(data)
pipe.send((None, True))
elif command == "close": elif command == "close":
pipe.send((None, True)) pipe.send((None, True))
break break
elif command == "_call": elif command == "_call":
name, args, kwargs = data name, args, kwargs = data
if name in ["reset", "step", "seed", "close"]: if name in ["reset", "step", "close", "set_wrapper_attr"]:
raise ValueError( raise ValueError(
f"Trying to call function `{name}` with " f"Trying to call function `{name}` with `call`, use `{name}` directly instead."
f"`_call`. Use `{name}` directly instead."
) )
function = getattr(env, name)
if callable(function): attr = env.get_wrapper_attr(name)
pipe.send((function(*args, **kwargs), True)) if callable(attr):
pipe.send((attr(*args, **kwargs), True))
else: else:
pipe.send((function, True)) pipe.send((attr, True))
elif command == "_setattr": elif command == "_setattr":
name, value = data name, value = data
setattr(env, name, value) env.set_wrapper_attr(name, value)
pipe.send((None, True)) pipe.send((None, True))
elif command == "_check_spaces": elif command == "_check_spaces":
pipe.send( pipe.send(
( (
(data[0] == env.observation_space, data[1] == env.action_space), (data[0] == observation_space, data[1] == action_space),
True, True,
) )
) )
else: else:
raise RuntimeError( raise RuntimeError(
f"Received unknown command `{command}`. Must " f"Received unknown command `{command}`. Must be one of [`reset`, `step`, `close`, `_call`, `_setattr`, `_check_spaces`]."
"be one of {`reset`, `step`, `seed`, `close`, `_call`, "
"`_setattr`, `_check_spaces`}."
)
except (KeyboardInterrupt, Exception):
error_queue.put((index,) + sys.exc_info()[:2])
pipe.send((None, False))
finally:
env.close()
def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
assert shared_memory is not None
env = env_fn()
observation_space = env.observation_space
parent_pipe.close()
try:
while True:
command, data = pipe.recv()
if command == "reset":
observation, info = env.reset(**data)
write_to_shared_memory(
observation_space, index, observation, shared_memory
)
pipe.send(((None, info), True))
elif command == "step":
(
observation,
reward,
terminated,
truncated,
info,
) = env.step(data)
if terminated or truncated:
old_observation, old_info = observation, info
observation, info = env.reset()
info["final_observation"] = old_observation
info["final_info"] = old_info
write_to_shared_memory(
observation_space, index, observation, shared_memory
)
pipe.send(((None, reward, terminated, truncated, info), True))
elif command == "seed":
env.seed(data)
pipe.send((None, True))
elif command == "close":
pipe.send((None, True))
break
elif command == "_call":
name, args, kwargs = data
if name in ["reset", "step", "seed", "close"]:
raise ValueError(
f"Trying to call function `{name}` with "
f"`_call`. Use `{name}` directly instead."
)
function = getattr(env, name)
if callable(function):
pipe.send((function(*args, **kwargs), True))
else:
pipe.send((function, True))
elif command == "_setattr":
name, value = data
setattr(env, name, value)
pipe.send((None, True))
elif command == "_check_spaces":
pipe.send(
((data[0] == observation_space, data[1] == env.action_space), True)
)
else:
raise RuntimeError(
f"Received unknown command `{command}`. Must "
"be one of {`reset`, `step`, `seed`, `close`, `_call`, "
"`_setattr`, `_check_spaces`}."
) )
except (KeyboardInterrupt, Exception): except (KeyboardInterrupt, Exception):
error_queue.put((index,) + sys.exc_info()[:2]) error_queue.put((index,) + sys.exc_info()[:2])
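Since ``worker`` is now an explicit constructor argument (typed above), a custom worker can wrap the default one. A hypothetical sketch that only adds logging, assuming ``_async_worker`` keeps the signature shown:

    import gymnasium as gym
    from gymnasium.vector.async_vector_env import AsyncVectorEnv, _async_worker

    def logging_worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
        # runs in the child process; delegates everything to the default worker
        print(f"worker {index}: starting")
        _async_worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue)

    if __name__ == "__main__":  # guard needed for spawn-based multiprocessing
        envs = AsyncVectorEnv(
            [lambda: gym.make("CartPole-v1") for _ in range(2)],
            worker=logging_worker,
        )
        envs.close()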


@@ -1,14 +1,15 @@
"""A synchronous vector environment.""" """Implementation of a synchronous (for loop) vectorization method of any environment."""
from __future__ import annotations
from copy import deepcopy from copy import deepcopy
from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Union from typing import Any, Callable, Iterator, Sequence
import numpy as np import numpy as np
from numpy.typing import NDArray
from gymnasium import Env from gymnasium import Env
from gymnasium.spaces import Space from gymnasium.core import ActType, ObsType, RenderFrame
from gymnasium.vector.utils import concatenate, create_empty_array, iterate from gymnasium.vector.utils import batch_space, concatenate, create_empty_array, iterate
from gymnasium.vector.vector_env import VectorEnv from gymnasium.vector.vector_env import ArrayType, VectorEnv
__all__ = ["SyncVectorEnv"] __all__ = ["SyncVectorEnv"]
@@ -19,156 +20,175 @@ class SyncVectorEnv(VectorEnv):
Example: Example:
>>> import gymnasium as gym >>> import gymnasium as gym
>>> env = gym.vector.SyncVectorEnv([ >>> envs = gym.make_vec("Pendulum-v1", num_envs=2, vectorization_mode="sync")
>>> envs
SyncVectorEnv(Pendulum-v1, num_envs=2)
>>> envs = gym.vector.SyncVectorEnv([
... lambda: gym.make("Pendulum-v1", g=9.81), ... lambda: gym.make("Pendulum-v1", g=9.81),
... lambda: gym.make("Pendulum-v1", g=1.62) ... lambda: gym.make("Pendulum-v1", g=1.62)
... ]) ... ])
>>> env.reset(seed=42) >>> envs
(array([[-0.14995256, 0.9886932 , -0.12224312], SyncVectorEnv(num_envs=2)
[ 0.5760367 , 0.8174238 , -0.91244936]], dtype=float32), {}) >>> obs, infos = envs.reset(seed=42)
>>> obs
array([[-0.14995256, 0.9886932 , -0.12224312],
[ 0.5760367 , 0.8174238 , -0.91244936]], dtype=float32)
>>> infos
{}
>>> _ = envs.action_space.seed(42)
>>> actions = envs.action_space.sample()
>>> obs, rewards, terminates, truncates, infos = envs.step(actions)
>>> obs
array([[-0.1878752 , 0.98219293, 0.7695615 ],
[ 0.6102389 , 0.79221743, -0.8498053 ]], dtype=float32)
>>> rewards
array([-2.96562607, -0.99902063])
>>> terminates
array([False, False])
>>> truncates
array([False, False])
>>> infos
{}
>>> envs.close()
""" """
def __init__( def __init__(
self, self,
env_fns: Iterable[Callable[[], Env]], env_fns: Iterator[Callable[[], Env]] | Sequence[Callable[[], Env]],
observation_space: Space = None,
action_space: Space = None,
copy: bool = True, copy: bool = True,
): ):
"""Vectorized environment that serially runs multiple environments. """Vectorized environment that serially runs multiple environments.
Args: Args:
env_fns: iterable of callable functions that create the environments. env_fns: iterable of callable functions that create the environments.
observation_space: Observation space of a single environment. If ``None``,
then the observation space of the first environment is taken.
action_space: Action space of a single environment. If ``None``,
then the action space of the first environment is taken.
copy: If ``True``, then the :meth:`reset` and :meth:`step` methods return a copy of the observations. copy: If ``True``, then the :meth:`reset` and :meth:`step` methods return a copy of the observations.
Raises: Raises:
RuntimeError: If the observation space of some sub-environment does not match observation_space RuntimeError: If the observation space of some sub-environment does not match observation_space
(or, by default, the observation space of the first sub-environment). (or, by default, the observation space of the first sub-environment).
""" """
self.env_fns = env_fns
self.envs = [env_fn() for env_fn in env_fns]
self.copy = copy self.copy = copy
self.env_fns = env_fns
# Initialise all sub-environments
self.envs = [env_fn() for env_fn in env_fns]
# Define core attributes using the sub-environments
# As we support `make_vec(spec)`, we can't include a `spec = self.envs[0].spec` as this doesn't guarantee we can actually recreate the vector env.
self.num_envs = len(self.envs)
self.metadata = self.envs[0].metadata self.metadata = self.envs[0].metadata
self.render_mode = self.envs[0].render_mode
if (observation_space is None) or (action_space is None): # Initialises the single spaces from the sub-environments
observation_space = observation_space or self.envs[0].observation_space self.single_observation_space = self.envs[0].observation_space
action_space = action_space or self.envs[0].action_space self.single_action_space = self.envs[0].action_space
super().__init__(
num_envs=len(self.envs),
observation_space=observation_space,
action_space=action_space,
)
self._check_spaces() self._check_spaces()
self.observations = create_empty_array(
# Initialise the obs and action space based on the single versions and num of sub-environments
self.observation_space = batch_space(
self.single_observation_space, self.num_envs
)
self.action_space = batch_space(self.single_action_space, self.num_envs)
# Initialise attributes used in `step` and `reset`
self._observations = create_empty_array(
self.single_observation_space, n=self.num_envs, fn=np.zeros self.single_observation_space, n=self.num_envs, fn=np.zeros
) )
self._rewards = np.zeros((self.num_envs,), dtype=np.float64) self._rewards = np.zeros((self.num_envs,), dtype=np.float64)
self._terminateds = np.zeros((self.num_envs,), dtype=np.bool_) self._terminations = np.zeros((self.num_envs,), dtype=np.bool_)
self._truncateds = np.zeros((self.num_envs,), dtype=np.bool_) self._truncations = np.zeros((self.num_envs,), dtype=np.bool_)
self._actions = None
def seed(self, seed: Optional[Union[int, Sequence[int]]] = None): def reset(
"""Sets the seed in all sub-environments.
Args:
seed: The seed
"""
super().seed(seed=seed)
if seed is None:
seed = [None for _ in range(self.num_envs)]
if isinstance(seed, int):
seed = [seed + i for i in range(self.num_envs)]
assert len(seed) == self.num_envs
for env, single_seed in zip(self.envs, seed):
env.seed(single_seed)
def reset_wait(
self, self,
seed: Optional[Union[int, List[int]]] = None, *,
options: Optional[dict] = None, seed: int | list[int] | None = None,
): options: dict[str, Any] | None = None,
"""Waits for the calls triggered by :meth:`reset_async` to finish and returns the results. ) -> tuple[ObsType, dict[str, Any]]:
"""Resets each of the sub-environments and concatenate the results together.
Args: Args:
seed: The reset environment seed seed: Seeds used to reset the sub-environments, either
options: Option information for the environment reset * ``None`` - random seeds for all environment
* ``int`` - ``[seed, seed+1, ..., seed+n]``
* List of ints - ``[1, 2, 3, ..., n]``
options: Option information used for each sub-environment
Returns: Returns:
The reset observation of the environment and reset information Concatenated observations and info from each sub-environment
""" """
if seed is None: if seed is None:
seed = [None for _ in range(self.num_envs)] seed = [None for _ in range(self.num_envs)]
if isinstance(seed, int): elif isinstance(seed, int):
seed = [seed + i for i in range(self.num_envs)] seed = [seed + i for i in range(self.num_envs)]
assert len(seed) == self.num_envs assert len(seed) == self.num_envs
self._terminateds[:] = False self._terminations = np.zeros((self.num_envs,), dtype=np.bool_)
self._truncateds[:] = False self._truncations = np.zeros((self.num_envs,), dtype=np.bool_)
observations = []
infos = {} observations, infos = [], {}
for i, (env, single_seed) in enumerate(zip(self.envs, seed)): for i, (env, single_seed) in enumerate(zip(self.envs, seed)):
kwargs = {} env_obs, env_info = env.reset(seed=single_seed, options=options)
if single_seed is not None:
kwargs["seed"] = single_seed
if options is not None:
kwargs["options"] = options
observation, info = env.reset(**kwargs) observations.append(env_obs)
observations.append(observation) infos = self._add_info(infos, env_info, i)
infos = self._add_info(infos, info, i)
self.observations = concatenate( # Concatenate the observations
self.single_observation_space, observations, self.observations self._observations = concatenate(
self.single_observation_space, observations, self._observations
) )
return (deepcopy(self.observations) if self.copy else self.observations), infos
def step_async(self, actions): return deepcopy(self._observations) if self.copy else self._observations, infos
"""Sets :attr:`_actions` for use by the :meth:`step_wait` by converting the ``actions`` to an iterable version."""
self._actions = iterate(self.action_space, actions)
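To make the seeding rules documented in ``reset`` concrete, a short sketch (CartPole is an arbitrary choice):

    import gymnasium as gym

    envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
    obs, infos = envs.reset(seed=7)          # sub-environment seeds: 7, 8, 9
    obs, infos = envs.reset(seed=[1, 2, 3])  # one explicit seed per sub-environment
    obs, infos = envs.reset(seed=None)       # fresh random seeds everywhere
    envs.close()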
def step_wait(self) -> Tuple[Any, NDArray[Any], NDArray[Any], NDArray[Any], dict]: def step(
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
"""Steps through each of the environments returning the batched results. """Steps through each of the environments returning the batched results.
Returns: Returns:
The batched environment step results The batched environment step results
""" """
actions = iterate(self.action_space, actions)
observations, infos = [], {} observations, infos = [], {}
for i, (env, action) in enumerate(zip(self.envs, self._actions)): for i, (env, action) in enumerate(zip(self.envs, actions)):
( (
observation, env_obs,
self._rewards[i], self._rewards[i],
self._terminateds[i], self._terminations[i],
self._truncateds[i], self._truncations[i],
info, env_info,
) = env.step(action) ) = env.step(action)
if self._terminateds[i] or self._truncateds[i]: # If a sub-environment terminates or truncates, save the obs and info to the batched info
old_observation, old_info = observation, info if self._terminations[i] or self._truncations[i]:
observation, info = env.reset() old_observation, old_info = env_obs, env_info
info["final_observation"] = old_observation env_obs, env_info = env.reset()
info["final_info"] = old_info
observations.append(observation) env_info["final_observation"] = old_observation
infos = self._add_info(infos, info, i) env_info["final_info"] = old_info
self.observations = concatenate(
self.single_observation_space, observations, self.observations observations.append(env_obs)
infos = self._add_info(infos, env_info, i)
# Concatenate the observations
self._observations = concatenate(
self.single_observation_space, observations, self._observations
) )
return ( return (
deepcopy(self.observations) if self.copy else self.observations, deepcopy(self._observations) if self.copy else self._observations,
np.copy(self._rewards), np.copy(self._rewards),
np.copy(self._terminateds), np.copy(self._terminations),
np.copy(self._truncateds), np.copy(self._truncations),
infos, infos,
) )
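The autoreset branch above surfaces the pre-reset state through the info dict. A sketch of consuming it; the ``_final_observation`` mask key follows the ``_add_info`` convention and is assumed here:

    import gymnasium as gym

    envs = gym.make_vec("CartPole-v1", num_envs=4, vectorization_mode="sync")
    obs, infos = envs.reset(seed=0)
    for _ in range(200):
        obs, rewards, terminations, truncations, infos = envs.step(envs.action_space.sample())
        if "final_observation" in infos:
            mask = infos["_final_observation"]           # True where an episode just ended
            last_obs = infos["final_observation"][mask]  # observations from before the reset
    envs.close()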
def call(self, name, *args, **kwargs) -> tuple: def render(self) -> tuple[RenderFrame, ...] | None:
"""Calls the method with name and applies args and kwargs. """Returns the rendered frames from the environments."""
return tuple(env.render() for env in self.envs)
def call(self, name: str, *args: Any, **kwargs: Any) -> tuple[Any, ...]:
"""Calls a sub-environment method with name and applies args and kwargs.
Args: Args:
name: The method name name: The method name
@@ -180,7 +200,8 @@ class SyncVectorEnv(VectorEnv):
""" """
results = [] results = []
for env in self.envs: for env in self.envs:
function = getattr(env, name) function = env.get_wrapper_attr(name)
if callable(function): if callable(function):
results.append(function(*args, **kwargs)) results.append(function(*args, **kwargs))
else: else:
@@ -188,7 +209,18 @@ class SyncVectorEnv(VectorEnv):
        return tuple(results)

    def get_attr(self, name: str) -> Any:
        """Get a property from each parallel environment.

        Args:
            name (str): Name of the property to get from each individual environment.

        Returns:
            The property with name
        """
        return self.call(name)

    def set_attr(self, name: str, values: list[Any] | tuple[Any, ...] | Any):
        """Sets an attribute of the sub-environments.

        Args:
@@ -202,34 +234,33 @@ class SyncVectorEnv(VectorEnv):
""" """
if not isinstance(values, (list, tuple)): if not isinstance(values, (list, tuple)):
values = [values for _ in range(self.num_envs)] values = [values for _ in range(self.num_envs)]
if len(values) != self.num_envs: if len(values) != self.num_envs:
raise ValueError( raise ValueError(
"Values must be a list or tuple with length equal to the " "Values must be a list or tuple with length equal to the number of environments. "
f"number of environments. Got `{len(values)}` values for " f"Got `{len(values)}` values for {self.num_envs} environments."
f"{self.num_envs} environments."
) )
for env, value in zip(self.envs, values): for env, value in zip(self.envs, values):
setattr(env, name, value) env.set_wrapper_attr(name, value)
def close_extras(self, **kwargs): def close_extras(self, **kwargs: Any):
"""Close the environments.""" """Close the environments."""
[env.close() for env in self.envs] [env.close() for env in self.envs]
    def _check_spaces(self) -> bool:
        """Check that each of the environments obs and action spaces are equivalent to the single obs and action space."""
        for env in self.envs:
            if not (env.observation_space == self.single_observation_space):
                raise RuntimeError(
                    f"Some environments have an observation space different from `{self.single_observation_space}`. "
                    "In order to batch observations, the observation spaces from all environments must be equal."
                )

            if not (env.action_space == self.single_action_space):
                raise RuntimeError(
                    f"Some environments have an action space different from `{self.single_action_space}`. "
                    "In order to batch actions, the action spaces from all environments must be equal."
                )

        return True
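Taken together, `call`, `get_attr` and `set_attr` broadcast attribute access over `self.envs` via `get_wrapper_attr`/`set_wrapper_attr`. A short sketch (`my_flag` is a hypothetical attribute used only for illustration):

```python
# Sketch: broadcasting attribute access across sub-environments (illustrative).
import gymnasium as gym

envs = gym.vector.SyncVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(2)])
modes = envs.get_attr("render_mode")     # tuple with one entry per sub-environment
envs.set_attr("my_flag", [True, False])  # per-env values; a scalar would be broadcast
renders = envs.call("render")            # calls `render()` on every sub-environment
envs.close()
```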

View File

@@ -1,26 +1,27 @@
"""Module for gymnasium vector utils.""" """Module for gymnasium experimental vector utility functions."""
from gymnasium.vector.utils.misc import CloudpickleWrapper, clear_mpi_env_vars from gymnasium.vector.utils.misc import CloudpickleWrapper, clear_mpi_env_vars
from gymnasium.vector.utils.numpy_utils import concatenate, create_empty_array
from gymnasium.vector.utils.shared_memory import ( from gymnasium.vector.utils.shared_memory import (
create_shared_memory, create_shared_memory,
read_from_shared_memory, read_from_shared_memory,
write_to_shared_memory, write_to_shared_memory,
) )
from gymnasium.vector.utils.spaces import ( from gymnasium.vector.utils.space_utils import (
_BaseGymSpaces, # pyright: ignore[reportPrivateUsage] batch_space,
concatenate,
create_empty_array,
iterate,
) )
from gymnasium.vector.utils.spaces import BaseGymSpaces, batch_space, iterate
__all__ = [ __all__ = [
"CloudpickleWrapper", "batch_space",
"clear_mpi_env_vars", "iterate",
"concatenate", "concatenate",
"create_empty_array", "create_empty_array",
"create_shared_memory", "create_shared_memory",
"read_from_shared_memory", "read_from_shared_memory",
"write_to_shared_memory", "write_to_shared_memory",
"BaseGymSpaces", "CloudpickleWrapper",
"batch_space", "clear_mpi_env_vars",
"iterate",
] ]

View File

@@ -39,7 +39,7 @@ class CloudpickleWrapper:
def clear_mpi_env_vars():
    """Clears the MPI environment variables.

    ``from mpi4py import MPI`` will call ``MPI_Init`` by default.
    If the child process has MPI environment variables, MPI will think that the child process
    is an MPI process just like the parent and do bad things such as hang.

View File

@@ -1,146 +0,0 @@
"""Numpy utility functions: concatenate space samples and create empty array."""
from collections import OrderedDict
from functools import singledispatch
from typing import Callable, Iterable, Union
import numpy as np
from gymnasium.spaces import (
Box,
Dict,
Discrete,
MultiBinary,
MultiDiscrete,
Space,
Tuple,
)
__all__ = ["concatenate", "create_empty_array"]
@singledispatch
def concatenate(
space: Space, items: Iterable, out: Union[tuple, dict, np.ndarray]
) -> Union[tuple, dict, np.ndarray]:
"""Concatenate multiple samples from space into a single object.
Args:
space: Observation space of a single environment in the vectorized environment.
items: Samples to be concatenated.
out: The output object. This object is a (possibly nested) numpy array.
Returns:
The output object. This object is a (possibly nested) numpy array.
Raises:
ValueError: Space is not a valid :class:`gym.Space` instance
Example:
>>> from gymnasium.spaces import Box
>>> import numpy as np
>>> space = Box(low=0, high=1, shape=(3,), seed=42, dtype=np.float32)
>>> out = np.zeros((2, 3), dtype=np.float32)
>>> items = [space.sample() for _ in range(2)]
>>> concatenate(space, items, out)
array([[0.77395606, 0.43887845, 0.85859793],
[0.697368 , 0.09417735, 0.97562236]], dtype=float32)
"""
raise ValueError(
f"Space of type `{type(space)}` is not a valid `gymnasium.Space` instance."
)
@concatenate.register(Box)
@concatenate.register(Discrete)
@concatenate.register(MultiDiscrete)
@concatenate.register(MultiBinary)
def _concatenate_base(space, items, out):
return np.stack(items, axis=0, out=out)
@concatenate.register(Tuple)
def _concatenate_tuple(space, items, out):
return tuple(
concatenate(subspace, [item[i] for item in items], out[i])
for (i, subspace) in enumerate(space.spaces)
)
@concatenate.register(Dict)
def _concatenate_dict(space, items, out):
return OrderedDict(
[
(key, concatenate(subspace, [item[key] for item in items], out[key]))
for (key, subspace) in space.spaces.items()
]
)
@concatenate.register(Space)
def _concatenate_custom(space, items, out):
return tuple(items)
@singledispatch
def create_empty_array(
space: Space, n: int = 1, fn: Callable[..., np.ndarray] = np.zeros
) -> Union[tuple, dict, np.ndarray]:
"""Create an empty (possibly nested) numpy array.
Args:
space: Observation space of a single environment in the vectorized environment.
n: Number of environments in the vectorized environment. If `None`, creates an empty sample from `space`.
fn: Function to apply when creating the empty numpy array. Examples of such functions are `np.empty` or `np.zeros`.
Returns:
The output object. This object is a (possibly nested) numpy array.
Raises:
ValueError: Space is not a valid :class:`gym.Space` instance
Example:
>>> from gymnasium.spaces import Box, Dict
>>> import numpy as np
>>> space = Dict({
... 'position': Box(low=0, high=1, shape=(3,), dtype=np.float32),
... 'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32)})
>>> create_empty_array(space, n=2, fn=np.zeros)
OrderedDict([('position', array([[0., 0., 0.],
[0., 0., 0.]], dtype=float32)), ('velocity', array([[0., 0.],
[0., 0.]], dtype=float32))])
"""
raise ValueError(
f"Space of type `{type(space)}` is not a valid `gymnasium.Space` instance."
)
# It is possible for some of the Box lows to be greater than 0, in which case the array is not in the space
@create_empty_array.register(Box)
# If the Discrete start > 0 or start + length < 0, then the array is not in the space
@create_empty_array.register(Discrete)
@create_empty_array.register(MultiDiscrete)
@create_empty_array.register(MultiBinary)
def _create_empty_array_base(space, n=1, fn=np.zeros):
shape = space.shape if (n is None) else (n,) + space.shape
return fn(shape, dtype=space.dtype)
@create_empty_array.register(Tuple)
def _create_empty_array_tuple(space, n=1, fn=np.zeros):
return tuple(create_empty_array(subspace, n=n, fn=fn) for subspace in space.spaces)
@create_empty_array.register(Dict)
def _create_empty_array_dict(space, n=1, fn=np.zeros):
return OrderedDict(
[
(key, create_empty_array(subspace, n=n, fn=fn))
for (key, subspace) in space.spaces.items()
]
)
@create_empty_array.register(Space)
def _create_empty_array_custom(space, n=1, fn=np.zeros):
return None
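These two helpers do not disappear with this file; per the `__init__` diff above they are re-exported from `gymnasium.vector.utils` (now defined in `space_utils`). A minimal sketch of the usual pairing, preallocating once and stacking into the buffer:

```python
# Sketch: preallocate a batched buffer, then stack per-env samples into it.
import numpy as np
from gymnasium.spaces import Box
from gymnasium.vector.utils import concatenate, create_empty_array

space = Box(low=0, high=1, shape=(3,), seed=42, dtype=np.float32)
out = create_empty_array(space, n=2, fn=np.zeros)  # (2, 3) zero-filled buffer
samples = [space.sample() for _ in range(2)]
batch = concatenate(space, samples, out)           # writes into `out` and returns it
assert batch.shape == (2, 3)
```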

View File

@@ -1,9 +1,11 @@
"""Utility functions for vector environments to share memory between processes.""" """Utility functions for vector environments to share memory between processes."""
from __future__ import annotations
import multiprocessing as mp import multiprocessing as mp
from collections import OrderedDict from collections import OrderedDict
from ctypes import c_bool from ctypes import c_bool
from functools import singledispatch from functools import singledispatch
from typing import Union from typing import Any
import numpy as np import numpy as np
@@ -12,10 +14,14 @@ from gymnasium.spaces import (
    Box,
    Dict,
    Discrete,
    Graph,
    MultiBinary,
    MultiDiscrete,
    Sequence,
    Space,
    Text,
    Tuple,
    flatten,
)
@@ -24,8 +30,8 @@ __all__ = ["create_shared_memory", "read_from_shared_memory", "write_to_shared_m
@singledispatch
def create_shared_memory(
    space: Space[Any], n: int = 1, ctx=mp
) -> dict[str, Any] | tuple[Any, ...] | mp.Array:
    """Create a shared memory object, to be shared across processes.

    This eventually contains the observations from the vectorized environment.
@@ -41,12 +47,13 @@ def create_shared_memory(
    Raises:
        CustomSpaceError: Space is not a valid :class:`gymnasium.Space` instance
    """
    if isinstance(space, Space):
        raise CustomSpaceError(
            f"Space of type `{type(space)}` doesn't have a registered `create_shared_memory` function. Register `{type(space)}` for `create_shared_memory` to support it."
        )
    else:
        raise TypeError(
            f"The space provided to `create_shared_memory` is not a gymnasium Space instance, type: {type(space)}, {space}"
        )
@@ -54,7 +61,10 @@ def create_shared_memory(
@create_shared_memory.register(Discrete)
@create_shared_memory.register(MultiDiscrete)
@create_shared_memory.register(MultiBinary)
def _create_base_shared_memory(
    space: Box | Discrete | MultiDiscrete | MultiBinary, n: int = 1, ctx=mp
):
    assert space.dtype is not None
    dtype = space.dtype.char
    if dtype in "?":
        dtype = c_bool
@@ -62,14 +72,14 @@ def _create_base_shared_memory(space, n: int = 1, ctx=mp):
@create_shared_memory.register(Tuple)
def _create_tuple_shared_memory(space: Tuple, n: int = 1, ctx=mp):
    return tuple(
        create_shared_memory(subspace, n=n, ctx=ctx) for subspace in space.spaces
    )


@create_shared_memory.register(Dict)
def _create_dict_shared_memory(space: Dict, n: int = 1, ctx=mp):
    return OrderedDict(
        [
            (key, create_shared_memory(subspace, n=n, ctx=ctx))
@@ -78,10 +88,23 @@ def _create_dict_shared_memory(space, n=1, ctx=mp):
    )
@create_shared_memory.register(Text)
def _create_text_shared_memory(space: Text, n: int = 1, ctx=mp):
return ctx.Array(np.dtype(np.int32).char, n * space.max_length)
@create_shared_memory.register(Graph)
@create_shared_memory.register(Sequence)
def _create_dynamic_shared_memory(space: Graph | Sequence, n: int = 1, ctx=mp):
raise TypeError(
f"As {space} has a dynamic shape then it is not possible to make a static shared memory."
)
@singledispatch
def read_from_shared_memory(
    space: Space, shared_memory: dict | tuple | mp.Array, n: int = 1
) -> dict[str, Any] | tuple[Any, ...] | np.ndarray:
    """Read the batch of observations from shared memory as a numpy array.

    .. note::
@@ -101,12 +124,13 @@ def read_from_shared_memory(
    Raises:
        CustomSpaceError: Space is not a valid :class:`gymnasium.Space` instance
    """
    if isinstance(space, Space):
        raise CustomSpaceError(
            f"Space of type `{type(space)}` doesn't have a registered `read_from_shared_memory` function. Register `{type(space)}` for `read_from_shared_memory` to support it."
        )
    else:
        raise TypeError(
            f"The space provided to `read_from_shared_memory` is not a gymnasium Space instance, type: {type(space)}, {space}"
        )
@@ -114,14 +138,16 @@ def read_from_shared_memory(
@read_from_shared_memory.register(Discrete)
@read_from_shared_memory.register(MultiDiscrete)
@read_from_shared_memory.register(MultiBinary)
def _read_base_from_shared_memory(
    space: Box | Discrete | MultiDiscrete | MultiBinary, shared_memory, n: int = 1
):
    return np.frombuffer(shared_memory.get_obj(), dtype=space.dtype).reshape(
        (n,) + space.shape
    )


@read_from_shared_memory.register(Tuple)
def _read_tuple_from_shared_memory(space: Tuple, shared_memory, n: int = 1):
    return tuple(
        read_from_shared_memory(subspace, memory, n=n)
        for (memory, subspace) in zip(shared_memory, space.spaces)
@@ -129,7 +155,7 @@ def _read_tuple_from_shared_memory(space, shared_memory, n: int = 1):
@read_from_shared_memory.register(Dict)
def _read_dict_from_shared_memory(space: Dict, shared_memory, n: int = 1):
    return OrderedDict(
        [
            (key, read_from_shared_memory(subspace, shared_memory[key], n=n))
@@ -138,12 +164,30 @@ def _read_dict_from_shared_memory(space, shared_memory, n: int = 1):
    )
@read_from_shared_memory.register(Text)
def _read_text_from_shared_memory(space: Text, shared_memory, n: int = 1) -> tuple[str]:
data = np.frombuffer(shared_memory.get_obj(), dtype=np.int32).reshape(
(n, space.max_length)
)
return tuple(
"".join(
[
space.character_list[val]
for val in values
if val < len(space.character_set)
]
)
for values in data
)
@singledispatch
def write_to_shared_memory(
    space: Space,
    index: int,
    value: np.ndarray,
    shared_memory: dict[str, Any] | tuple[Any, ...] | mp.Array,
):
    """Write the observation of a single environment into shared memory.
@@ -157,12 +201,13 @@ def write_to_shared_memory(
    Raises:
        CustomSpaceError: Space is not a valid :class:`gymnasium.Space` instance
    """
    if isinstance(space, Space):
        raise CustomSpaceError(
            f"Space of type `{type(space)}` doesn't have a registered `write_to_shared_memory` function. Register `{type(space)}` for `write_to_shared_memory` to support it."
        )
    else:
        raise TypeError(
            f"The space provided to `write_to_shared_memory` is not a gymnasium Space instance, type: {type(space)}, {space}"
        )
@@ -170,7 +215,12 @@ def write_to_shared_memory(
@write_to_shared_memory.register(Discrete)
@write_to_shared_memory.register(MultiDiscrete)
@write_to_shared_memory.register(MultiBinary)
def _write_base_to_shared_memory(
    space: Box | Discrete | MultiDiscrete | MultiBinary,
    index: int,
    value,
    shared_memory,
):
    size = int(np.prod(space.shape))
    destination = np.frombuffer(shared_memory.get_obj(), dtype=space.dtype)
    np.copyto(
@@ -180,12 +230,26 @@ def _write_base_to_shared_memory(space, index, value, shared_memory):
@write_to_shared_memory.register(Tuple)
def _write_tuple_to_shared_memory(
    space: Tuple, index: int, values: tuple[Any, ...], shared_memory
):
    for value, memory, subspace in zip(values, shared_memory, space.spaces):
        write_to_shared_memory(subspace, index, value, memory)


@write_to_shared_memory.register(Dict)
def _write_dict_to_shared_memory(
    space: Dict, index: int, values: dict[str, Any], shared_memory
):
    for key, subspace in space.spaces.items():
        write_to_shared_memory(subspace, index, values[key], shared_memory[key])
@write_to_shared_memory.register(Text)
def _write_text_to_shared_memory(space: Text, index: int, values: str, shared_memory):
size = space.max_length
destination = np.frombuffer(shared_memory.get_obj(), dtype=np.int32)
np.copyto(
destination[index * size : (index + 1) * size],
flatten(space, values),
)
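The three singledispatch families above compose into a create/write/read round trip. A minimal sketch with a `Box` space, kept in a single process for brevity (in practice the array is shared with worker processes):

```python
# Sketch: shared-memory round trip for a Box space (single process for brevity).
import numpy as np
from gymnasium.spaces import Box
from gymnasium.vector.utils import (
    create_shared_memory,
    read_from_shared_memory,
    write_to_shared_memory,
)

space = Box(low=0, high=1, shape=(4,), dtype=np.float32)
n_envs = 2

shm = create_shared_memory(space, n=n_envs)            # one flat mp.Array for all envs
for index in range(n_envs):
    write_to_shared_memory(space, index, space.sample(), shm)

batch = read_from_shared_memory(space, shm, n=n_envs)  # numpy view of shape (2, 4)
assert batch.shape == (n_envs,) + space.shape
```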

View File

@@ -150,7 +150,7 @@ def iterate(space: Space[T_cov], items: Iterable[T_cov]) -> Iterator:
        The output object. This object is a (possibly nested) numpy array.

    Raises:
        ValueError: Space is not an instance of :class:`gymnasium.Space`

    Example:
        >>> from gymnasium.spaces import Box, Dict
@@ -311,14 +311,14 @@ def create_empty_array(
    Args:
        space: Observation space of a single environment in the vectorized environment.
        n: Number of environments in the vectorized environment. If ``None``, creates an empty sample from ``space``.
        fn: Function to apply when creating the empty numpy array. Examples of such functions are ``np.empty`` or ``np.zeros``.

    Returns:
        The output object. This object is a (possibly nested) numpy array.

    Raises:
        ValueError: Space is not a valid :class:`gymnasium.Space` instance

    Example:
        >>> from gymnasium.spaces import Box, Dict

View File

@@ -1,215 +0,0 @@
"""Utility functions for gymnasium spaces: batch space and iterator."""
from collections import OrderedDict
from copy import deepcopy
from functools import singledispatch
from typing import Iterator
import numpy as np
from gymnasium.error import CustomSpaceError
from gymnasium.spaces import (
Box,
Dict,
Discrete,
MultiBinary,
MultiDiscrete,
Space,
Tuple,
)
BaseGymSpaces = (Box, Discrete, MultiDiscrete, MultiBinary)
_BaseGymSpaces = BaseGymSpaces
__all__ = ["BaseGymSpaces", "_BaseGymSpaces", "batch_space", "iterate"]
@singledispatch
def batch_space(space: Space, n: int = 1) -> Space:
"""Create a (batched) space, containing multiple copies of a single space.
Args:
space: Space (e.g. the observation space) for a single environment in the vectorized environment.
n: Number of environments in the vectorized environment.
Returns:
Space (e.g. the observation space) for a batch of environments in the vectorized environment.
Raises:
ValueError: Cannot batch space that is not a valid :class:`gym.Space` instance
Example:
>>> from gymnasium.spaces import Box, Dict
>>> import numpy as np
>>> space = Dict({
... 'position': Box(low=0, high=1, shape=(3,), dtype=np.float32),
... 'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32)
... })
>>> batch_space(space, n=5)
Dict('position': Box(0.0, 1.0, (5, 3), float32), 'velocity': Box(0.0, 1.0, (5, 2), float32))
"""
raise ValueError(
f"Cannot batch space with type `{type(space)}`. The space must be a valid `gymnasium.Space` instance."
)
@batch_space.register(Box)
def _batch_space_box(space, n=1):
repeats = tuple([n] + [1] * space.low.ndim)
low, high = np.tile(space.low, repeats), np.tile(space.high, repeats)
return Box(low=low, high=high, dtype=space.dtype, seed=deepcopy(space.np_random))
@batch_space.register(Discrete)
def _batch_space_discrete(space, n=1):
return MultiDiscrete(
np.full((n,), space.n, dtype=space.dtype),
dtype=space.dtype,
seed=deepcopy(space.np_random),
start=np.full((n,), space.start, dtype=space.dtype),
)
@batch_space.register(MultiDiscrete)
def _batch_space_multidiscrete(space, n=1):
repeats = tuple([n] + [1] * space.nvec.ndim)
low = np.tile(space.start, repeats)
high = low + np.tile(space.nvec, repeats) - 1
return Box(
low=low,
high=high,
dtype=space.dtype,
seed=deepcopy(space.np_random),
)
@batch_space.register(MultiBinary)
def _batch_space_multibinary(space, n=1):
return Box(
low=0,
high=1,
shape=(n,) + space.shape,
dtype=space.dtype,
seed=deepcopy(space.np_random),
)
@batch_space.register(Tuple)
def _batch_space_tuple(space, n=1):
return Tuple(
tuple(batch_space(subspace, n=n) for subspace in space.spaces),
seed=deepcopy(space.np_random),
)
@batch_space.register(Dict)
def _batch_space_dict(space, n=1):
return Dict(
OrderedDict(
[
(key, batch_space(subspace, n=n))
for (key, subspace) in space.spaces.items()
]
),
seed=deepcopy(space.np_random),
)
@batch_space.register(Space)
def _batch_space_custom(space, n=1):
    # Without deepcopy, space.np_random would be batched_space.spaces[0].np_random
# Which is an issue if you are sampling actions of both the original space and the batched space
batched_space = Tuple(
tuple(deepcopy(space) for _ in range(n)), seed=deepcopy(space.np_random)
)
new_seeds = list(map(int, batched_space.np_random.integers(0, 1e8, n)))
batched_space.seed(new_seeds)
return batched_space
@singledispatch
def iterate(space: Space, items) -> Iterator:
"""Iterate over the elements of a (batched) space.
Args:
space: Space to which `items` belong to.
items: Items to be iterated over.
Returns:
Iterator over the elements in `items`.
Raises:
ValueError: Space is not an instance of :class:`gym.Space`
Example:
>>> from gymnasium.spaces import Box, Dict
>>> import numpy as np
>>> space = Dict({
... 'position': Box(low=0, high=1, shape=(2, 3), seed=42, dtype=np.float32),
... 'velocity': Box(low=0, high=1, shape=(2, 2), seed=42, dtype=np.float32)})
>>> items = space.sample()
>>> it = iterate(space, items)
>>> next(it)
OrderedDict([('position', array([0.77395606, 0.43887845, 0.85859793], dtype=float32)), ('velocity', array([0.77395606, 0.43887845], dtype=float32))])
>>> next(it)
OrderedDict([('position', array([0.697368 , 0.09417735, 0.97562236], dtype=float32)), ('velocity', array([0.85859793, 0.697368 ], dtype=float32))])
>>> next(it)
Traceback (most recent call last):
...
StopIteration
"""
raise ValueError(
f"Space of type `{type(space)}` is not a valid `gymnasium.Space` instance."
)
@iterate.register(Discrete)
def _iterate_discrete(space, items):
raise TypeError("Unable to iterate over a space of type `Discrete`.")
@iterate.register(Box)
@iterate.register(MultiDiscrete)
@iterate.register(MultiBinary)
def _iterate_base(space, items):
try:
return iter(items)
except TypeError as e:
raise TypeError(
f"Unable to iterate over the following elements: {items}"
) from e
@iterate.register(Tuple)
def _iterate_tuple(space, items):
# If this is a tuple of custom subspaces only, then simply iterate over items
if all(
isinstance(subspace, Space)
and (not isinstance(subspace, BaseGymSpaces + (Tuple, Dict)))
for subspace in space.spaces
):
return iter(items)
return zip(
*[iterate(subspace, items[i]) for i, subspace in enumerate(space.spaces)]
)
@iterate.register(Dict)
def _iterate_dict(space, items):
keys, values = zip(
*[
(key, iterate(subspace, items[key]))
for key, subspace in space.spaces.items()
]
)
for item in zip(*values):
yield OrderedDict([(key, value) for (key, value) in zip(keys, item)])
@iterate.register(Space)
def _iterate_custom(space, items):
raise CustomSpaceError(
f"Unable to iterate over {items}, since {space} "
"is a custom `gymnasium.Space` instance (i.e. not one of "
"`Box`, `Dict`, etc...)."
)
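As with the numpy utilities, `batch_space` and `iterate` survive this deletion in `gymnasium.vector.utils` (`space_utils`). A minimal sketch of how they pair: batch a single-env space outwards, then iterate a batched sample back into per-env items:

```python
# Sketch: batch a single-env space, then split a batched sample per sub-env.
import numpy as np
from gymnasium.spaces import Box
from gymnasium.vector.utils import batch_space, iterate

single = Box(low=0, high=1, shape=(3,), seed=42, dtype=np.float32)
batched = batch_space(single, n=4)        # Box with shape (4, 3)

sample = batched.sample()                 # one batched sample
per_env = list(iterate(batched, sample))  # four (3,) arrays, one per sub-env
assert all(item in single for item in per_env)
```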

View File

@@ -1,30 +1,46 @@
"""Base class for vectorized environments.""" """Base class for vectorized environments."""
from typing import Any, List, Optional, Tuple, Union from __future__ import annotations
from typing import TYPE_CHECKING, Any, Generic, TypeVar
import numpy as np import numpy as np
from numpy.typing import NDArray
import gymnasium as gym import gymnasium as gym
from gymnasium.vector.utils.spaces import batch_space from gymnasium.core import ActType, ObsType, RenderFrame
from gymnasium.utils import seeding
__all__ = ["VectorEnv"] if TYPE_CHECKING:
from gymnasium.envs.registration import EnvSpec
ArrayType = TypeVar("ArrayType")
class VectorEnv(gym.Env): __all__ = [
"VectorEnv",
"VectorWrapper",
"VectorObservationWrapper",
"VectorActionWrapper",
"VectorRewardWrapper",
"ArrayType",
]
class VectorEnv(Generic[ObsType, ActType, ArrayType]):
"""Base class for vectorized environments to run multiple independent copies of the same environment in parallel. """Base class for vectorized environments to run multiple independent copies of the same environment in parallel.
Vector environments can provide a linear speed-up in the steps taken per second through sampling multiple Vector environments can provide a linear speed-up in the steps taken per second through sampling multiple
sub-environments at the same time. To prevent terminated environments waiting until all sub-environments have sub-environments at the same time. To prevent terminated environments waiting until all sub-environments have
terminated or truncated, the vector environments autoreset sub-environments after they terminate or truncated. terminated or truncated, the vector environments automatically reset sub-environments after they terminate or truncated (within the same step call).
As a result, the final step's observation and info are overwritten by the reset's observation and info. As a result, the step's observation and info are overwritten by the reset's observation and info.
Therefore, the observation and info for the final step of a sub-environment is stored in the info parameter, To preserve this data, the observation and info for the final step of a sub-environment is stored in the info parameter,
using `"final_observation"` and `"final_info"` respectively. See :meth:`step` for more information. using `"final_observation"` and `"final_info"` respectively. See :meth:`step` for more information.
The vector environments batch `observations`, `rewards`, `terminations`, `truncations` and `info` for each The vector environments batches `observations`, `rewards`, `terminations`, `truncations` and `info` for each
parallel environment. In addition, :meth:`step` expects to receive a batch of actions for each parallel environment. sub-environment. In addition, :meth:`step` expects to receive a batch of actions for each parallel environment.
Gymnasium contains two types of Vector environments: :class:`AsyncVectorEnv` and :class:`SyncVectorEnv`. Gymnasium contains two generalised Vector environments: :class:`AsyncVectorEnv` and :class:`SyncVectorEnv` along with
several custom vector environment implementations.
The Vector Environments have the additional attributes for users to understand the implementation The Vector Environments have the additional attributes for users to understand the implementation
@@ -34,89 +50,67 @@ class VectorEnv(gym.Env):
    - :attr:`action_space` - The batched action space of the vector environment
    - :attr:`single_action_space` - The action space of a single sub-environment

    Examples:
        >>> import gymnasium as gym
        >>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync", wrappers=(gym.wrappers.TimeAwareObservation,))
        >>> envs = gym.wrappers.vector.ClipReward(envs, min_reward=0.2, max_reward=0.8)
        >>> envs
        <ClipReward, SyncVectorEnv(CartPole-v1, num_envs=3)>
        >>> observations, infos = envs.reset(seed=123)
        >>> observations
        array([[ 0.01823519, -0.0446179 , -0.02796401, -0.03156282,  0.        ],
               [ 0.02852531,  0.02858594,  0.0469136 ,  0.02480598,  0.        ],
               [ 0.03517495, -0.000635  , -0.01098382, -0.03203924,  0.        ]])
        >>> infos
        {}
        >>> _ = envs.action_space.seed(123)
        >>> observations, rewards, terminations, truncations, infos = envs.step(envs.action_space.sample())
        >>> observations
        array([[ 0.01734283,  0.15089367, -0.02859527, -0.33293587,  1.        ],
               [ 0.02909703, -0.16717631,  0.04740972,  0.3319138 ,  1.        ],
               [ 0.03516225, -0.19559774, -0.01162461,  0.25715804,  1.        ]])
        >>> rewards
        array([0.8, 0.8, 0.8])
        >>> terminations
        array([False, False, False])
        >>> truncations
        array([False, False, False])
        >>> infos
        {}
        >>> envs.close()

    Note:
        The info parameter of :meth:`reset` and :meth:`step` was originally implemented before v0.25 as a list
        of dictionaries for each sub-environment. However, this was modified in v0.25+ to be a
        dictionary with a NumPy array for each key. To use the old info style, utilise the :class:`DictInfoToList` wrapper.

    Note:
        All parallel environments should share the identical observation and action spaces.
        In other words, a vector of multiple different environments is not supported.

    Note:
        :func:`make_vec` is the equivalent function to :func:`make` for vector environments.
    """
    spec: EnvSpec | None = None
    render_mode: str | None = None
    closed: bool = False

    observation_space: gym.Space
    action_space: gym.Space
    single_observation_space: gym.Space
    single_action_space: gym.Space

    num_envs: int

    _np_random: np.random.Generator | None = None
    def reset(
        self,
        *,
        seed: int | list[int] | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[ObsType, dict[str, Any]]:  # type: ignore
        """Reset all parallel environments and return a batch of initial observations and info.

        Args:
@@ -128,47 +122,26 @@ class VectorEnv(gym.Env):
        Example:
            >>> import gymnasium as gym
            >>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
            >>> observations, infos = envs.reset(seed=42)
            >>> observations
            array([[ 0.0273956 , -0.00611216,  0.03585979,  0.0197368 ],
                   [ 0.01522993, -0.04562247, -0.04799704,  0.03392126],
                   [-0.03774345, -0.02418869, -0.00942293,  0.0469184 ]],
                  dtype=float32)
            >>> infos
            {}
        """
        if seed is not None:
            self._np_random, seed = seeding.np_random(seed)
    def step(
        self, actions: ActType
    ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
        """Take an action for each parallel environment.

        Args:
            actions: Batch of actions with the :attr:`action_space` shape.

        Returns:
            Batch of (observations, rewards, terminations, truncations, infos)
@@ -181,10 +154,10 @@ class VectorEnv(gym.Env):
        Example:
            >>> import gymnasium as gym
            >>> import numpy as np
            >>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
            >>> _ = envs.reset(seed=42)
            >>> actions = np.array([1, 0, 1], dtype=np.int32)
            >>> observations, rewards, terminations, truncations, infos = envs.step(actions)
            >>> observations
            array([[ 0.02727336,  0.18847767,  0.03625453, -0.26141977],
                   [ 0.01431748, -0.24002443, -0.04731862,  0.3110827 ],
@@ -192,62 +165,25 @@ class VectorEnv(gym.Env):
                  dtype=float32)
            >>> rewards
            array([1., 1., 1.])
            >>> terminations
            array([False, False, False])
            >>> truncations
            array([False, False, False])
            >>> infos
            {}
        """
    def render(self) -> tuple[RenderFrame, ...] | None:
        """Returns the rendered frames from the parallel environments.

        Returns:
            A tuple of rendered frames from the parallel environments
        """
        raise NotImplementedError(
            f"{self.__str__()} render function is not implemented."
        )
    def close(self, **kwargs: Any):
        """Close all parallel environments and release resources.

        It also closes all the existing image viewers, then calls :meth:`close_extras` and set
@@ -266,12 +202,37 @@ class VectorEnv(gym.Env):
""" """
if self.closed: if self.closed:
return return
if self.viewer is not None:
self.viewer.close()
self.close_extras(**kwargs) self.close_extras(**kwargs)
self.closed = True self.closed = True
def _add_info(self, infos: dict, info: dict, env_num: int) -> dict: def close_extras(self, **kwargs: Any):
"""Clean up the extra resources e.g. beyond what's in this base class."""
pass
@property
def np_random(self) -> np.random.Generator:
"""Returns the environment's internal :attr:`_np_random` that if not set will initialise with a random seed.
Returns:
Instances of `np.random.Generator`
"""
if self._np_random is None:
self._np_random, seed = seeding.np_random()
return self._np_random
@np_random.setter
def np_random(self, value: np.random.Generator):
self._np_random = value
@property
def unwrapped(self):
"""Return the base environment."""
return self
def _add_info(
self, infos: dict[str, Any], info: dict[str, Any], env_num: int
) -> dict[str, Any]:
"""Add env info to the info dictionary of the vectorized environment. """Add env info to the info dictionary of the vectorized environment.
Given the `info` of a single environment add it to the `infos` dictionary Given the `info` of a single environment add it to the `infos` dictionary
@@ -298,7 +259,7 @@ class VectorEnv(gym.Env):
infos[k], infos[f"_{k}"] = info_array, array_mask infos[k], infos[f"_{k}"] = info_array, array_mask
return infos return infos
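The convention `_add_info` implements is easiest to see on a concrete value: every reported key becomes a batched array in `infos`, paired with a boolean mask under `_key` marking which sub-environments actually supplied it. A hypothetical sketch of the resulting structure (`lives` is an invented key):

```python
# Sketch: the batched info layout _add_info produces (values are hypothetical).
# Suppose only sub-env 1 of 3 reported {"lives": 3} this step.
import numpy as np

infos = {
    "lives": np.array([0, 3, 0]),              # zero-filled where nothing was reported
    "_lives": np.array([False, True, False]),  # mask of envs that actually reported
}
```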
    def _init_info_arrays(self, dtype: type) -> tuple[np.ndarray, np.ndarray]:
        """Initialize the info array.

        Initialize the info array. If the dtype is numeric
@@ -335,12 +296,14 @@ class VectorEnv(gym.Env):
            A string containing the class name, number of environments and environment spec id
        """
        if self.spec is None:
            return f"{self.__class__.__name__}(num_envs={self.num_envs})"
        else:
            return (
                f"{self.__class__.__name__}({self.spec.id}, num_envs={self.num_envs})"
            )


class VectorWrapper(VectorEnv):
    """Wraps the vectorized environment to allow a modular transformation.

    This class is the base class for all wrappers for vectorized environments. The subclass
@@ -352,47 +315,223 @@ class VectorEnvWrapper(VectorEnv):
""" """
def __init__(self, env: VectorEnv): def __init__(self, env: VectorEnv):
assert isinstance(env, VectorEnv) """Initialize the vectorized environment wrapper.
Args:
env: The environment to wrap
"""
self.env = env self.env = env
assert isinstance(env, VectorEnv)
# explicitly forward the methods defined in VectorEnv self._observation_space: gym.Space | None = None
# to self.env (instead of the base class) self._action_space: gym.Space | None = None
def reset_async(self, **kwargs): self._single_observation_space: gym.Space | None = None
return self.env.reset_async(**kwargs) self._single_action_space: gym.Space | None = None
def reset_wait(self, **kwargs): def reset(
return self.env.reset_wait(**kwargs) self,
*,
seed: int | list[int] | None = None,
options: dict[str, Any] | None = None,
) -> tuple[ObsType, dict[str, Any]]:
"""Reset all environment using seed and options."""
return self.env.reset(seed=seed, options=options)
def step_async(self, actions): def step(
return self.env.step_async(actions) self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
"""Step through all environments using the actions returning the batched data."""
return self.env.step(actions)
def step_wait(self): def render(self) -> tuple[RenderFrame, ...] | None:
return self.env.step_wait() """Returns the render mode from the base vector environment."""
return self.env.render()
def close(self, **kwargs): def close(self, **kwargs: Any):
"""Close all environments."""
return self.env.close(**kwargs) return self.env.close(**kwargs)
def close_extras(self, **kwargs): def close_extras(self, **kwargs: Any):
"""Close all extra resources."""
return self.env.close_extras(**kwargs) return self.env.close_extras(**kwargs)
def call(self, name, *args, **kwargs):
return self.env.call(name, *args, **kwargs)
def set_attr(self, name, values):
return self.env.set_attr(name, values)
# implicitly forward all other methods and attributes to self.env
def __getattr__(self, name):
if name.startswith("_"):
raise AttributeError(f"attempted to get missing private attribute '{name}'")
return getattr(self.env, name)
@property @property
def unwrapped(self): def unwrapped(self):
"""Return the base non-wrapped environment."""
return self.env.unwrapped return self.env.unwrapped
def __repr__(self): def __repr__(self):
"""Return the string representation of the vectorized environment."""
return f"<{self.__class__.__name__}, {self.env}>" return f"<{self.__class__.__name__}, {self.env}>"
def __del__(self): @property
self.env.__del__() def spec(self) -> EnvSpec | None:
"""Gets the specification of the wrapped environment."""
return self.env.spec
@property
def observation_space(self) -> gym.Space:
"""Gets the observation space of the vector environment."""
if self._observation_space is None:
return self.env.observation_space
return self._observation_space
@observation_space.setter
def observation_space(self, space: gym.Space):
"""Sets the observation space of the vector environment."""
self._observation_space = space
@property
def action_space(self) -> gym.Space:
"""Gets the action space of the vector environment."""
if self._action_space is None:
return self.env.action_space
return self._action_space
@action_space.setter
def action_space(self, space: gym.Space):
"""Sets the action space of the vector environment."""
self._action_space = space
@property
def single_observation_space(self) -> gym.Space:
"""Gets the single observation space of the vector environment."""
if self._single_observation_space is None:
return self.env.single_observation_space
return self._single_observation_space
@single_observation_space.setter
def single_observation_space(self, space: gym.Space):
"""Sets the single observation space of the vector environment."""
self._single_observation_space = space
@property
def single_action_space(self) -> gym.Space:
"""Gets the single action space of the vector environment."""
if self._single_action_space is None:
return self.env.single_action_space
return self._single_action_space
@single_action_space.setter
def single_action_space(self, space):
"""Sets the single action space of the vector environment."""
self._single_action_space = space
@property
def num_envs(self) -> int:
"""Gets the wrapped vector environment's num of the sub-environments."""
return self.env.num_envs
@property
    def render_mode(self) -> str | None:
"""Returns the `render_mode` from the base environment."""
return self.env.render_mode
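These property pairs mean a wrapper only assigns the spaces it actually changes; anything left unset falls through to the wrapped environment. A minimal sketch (hypothetical wrapper name, assuming `flatten_space` from `gymnasium.spaces` and that `VectorWrapper` is re-exported from `gymnasium.vector`):

```python
# Sketch: override just the observation spaces; everything else falls through.
from gymnasium.spaces import flatten_space
from gymnasium.vector import VectorEnv, VectorWrapper


class FlattenSpaces(VectorWrapper):  # hypothetical wrapper, for illustration
    def __init__(self, env: VectorEnv):
        super().__init__(env)
        # The setters store into _observation_space / _single_observation_space;
        # the action spaces stay unset and resolve to env.action_space.
        # A real wrapper would also transform the observations in reset/step.
        self.observation_space = flatten_space(env.observation_space)
        self.single_observation_space = flatten_space(env.single_observation_space)
```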
class VectorObservationWrapper(VectorWrapper):
"""Wraps the vectorized environment to allow a modular transformation of the observation.
Equivalent to :class:`gymnasium.ObservationWrapper` for vectorized environments.
"""
def reset(
self,
*,
seed: int | list[int] | None = None,
options: dict[str, Any] | None = None,
) -> tuple[ObsType, dict[str, Any]]:
"""Modifies the observation returned from the environment ``reset`` using the :meth:`observation`."""
obs, info = self.env.reset(seed=seed, options=options)
return self.vector_observation(obs), info
def step(
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
"""Modifies the observation returned from the environment ``step`` using the :meth:`observation`."""
observation, reward, termination, truncation, info = self.env.step(actions)
return (
self.vector_observation(observation),
reward,
termination,
truncation,
self.update_final_obs(info),
)
def vector_observation(self, observation: ObsType) -> ObsType:
"""Defines the vector observation transformation.
Args:
observation: A vector observation from the environment
Returns:
the transformed observation
"""
raise NotImplementedError
def single_observation(self, observation: ObsType) -> ObsType:
"""Defines the single observation transformation.
Args:
observation: A single observation from the environment
Returns:
The transformed observation
"""
raise NotImplementedError
def update_final_obs(self, info: dict[str, Any]) -> dict[str, Any]:
"""Updates the `final_obs` in the info using `single_observation`."""
if "final_observation" in info:
for i, obs in enumerate(info["final_observation"]):
if obs is not None:
info["final_observation"][i] = self.single_observation(obs)
return info
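A minimal subclass sketch (illustrative; assumes array observations and that `VectorObservationWrapper` is re-exported from `gymnasium.vector` as the `__all__` above suggests):

```python
# Sketch: a vector observation wrapper that rescales observations.
from gymnasium.vector import VectorEnv, VectorObservationWrapper


class ScaleObservation(VectorObservationWrapper):  # illustrative name
    def __init__(self, env: VectorEnv, factor: float = 0.5):
        super().__init__(env)
        self.factor = factor

    def vector_observation(self, observation):
        return observation * self.factor  # applied to the batched observation

    def single_observation(self, observation):
        return observation * self.factor  # applied to each `final_observation`
```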
class VectorActionWrapper(VectorWrapper):
"""Wraps the vectorized environment to allow a modular transformation of the actions.
Equivalent of :class:`gymnasium.ActionWrapper` for vectorized environments.
"""
def step(
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
"""Steps through the environment using a modified action by :meth:`action`."""
return self.env.step(self.actions(actions))
def actions(self, actions: ActType) -> ActType:
"""Transform the actions before sending them to the environment.
Args:
actions (ActType): the actions to transform
Returns:
ActType: the transformed actions
"""
raise NotImplementedError
class VectorRewardWrapper(VectorWrapper):
"""Wraps the vectorized environment to allow a modular transformation of the reward.
Equivalent of :class:`gymnasium.RewardWrapper` for vectorized environments.
"""
def step(
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
"""Steps through the environment returning a reward modified by :meth:`reward`."""
observation, reward, termination, truncation, info = self.env.step(actions)
return observation, self.rewards(reward), termination, truncation, info
def rewards(self, reward: ArrayType) -> ArrayType:
"""Transform the reward before returning it.
Args:
reward (array): the reward to transform
Returns:
array: the transformed reward
"""
raise NotImplementedError
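And the reward counterpart is equally small (illustrative name):

```python
# Sketch: a vector reward wrapper that scales the (num_envs,) reward array.
from gymnasium.vector import VectorEnv, VectorRewardWrapper


class ScaleReward(VectorRewardWrapper):  # illustrative name
    def __init__(self, env: VectorEnv, factor: float = 0.1):
        super().__init__(env)
        self.factor = factor

    def rewards(self, reward):
        return reward * self.factor
```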

View File

@@ -1,18 +0,0 @@
# Wrappers
Wrappers are used to transform an environment in a modular way:
```python
import gymnasium as gym
env = gym.make('CartPole-v1')
env = MyWrapper(env)
```
## Quick tips for writing your own wrapper
- Don't forget to call `super(class_name, self).__init__(env)` if you override the wrapper's `__init__` function
- You can access the inner environment with `self.unwrapped`
- You can access the previous wrapper using `self.env`
- The variables `metadata`, `action_space`, `observation_space`, `reward_range`, and `spec` are copied to `self` from the previous layer
- Create a wrapped function for at least one of the following: `__init__(self, env)`, `step`, `reset`, `render`, `close`, or `seed`
- Your layered function should take its input from the previous layer (`self.env`) and/or the inner layer (`self.unwrapped`)

View File

@@ -1,9 +1,8 @@
"""Module of wrapper classes. """Wrappers are a convenient way to modify an existing environment without having to alter the underlying code directly.
Wrappers are a convenient way to modify an existing environment without having to alter the underlying code directly. Using wrappers will allow you to avoid a lot of boilerplate code and make your environment more modular.
Using wrappers will allow you to avoid a lot of boilerplate code and make your environment more modular. Wrappers can Importantly wrappers can be chained to combine their effects and most environments that are generated via
also be chained to combine their effects. :meth:`gymnasium.make` will already be wrapped by default.
Most environments that are generated via :meth:`gymnasium.make` will already be wrapped by default.
In order to wrap an environment, you must first initialize a base environment. Then you can pass this environment along In order to wrap an environment, you must first initialize a base environment. Then you can pass this environment along
with (possibly optional) parameters to the wrapper's constructor. with (possibly optional) parameters to the wrapper's constructor.
@@ -46,27 +45,135 @@ If you need a wrapper to do more complicated tasks, you can inherit from the :cl
If you'd like to implement your own custom wrapper, check out `the corresponding tutorial <../../tutorials/implementing_custom_wrappers>`_.
"""
# pyright: reportUnsupportedDunderAll=false
import importlib
import re
from gymnasium.error import DeprecatedWrapper
from gymnasium.wrappers import vector
from gymnasium.wrappers.atari_preprocessing import AtariPreprocessing
from gymnasium.wrappers.common import (
    Autoreset,
    OrderEnforcing,
    PassiveEnvChecker,
    RecordEpisodeStatistics,
    TimeLimit,
)
from gymnasium.wrappers.rendering import HumanRendering, RecordVideo, RenderCollection
from gymnasium.wrappers.stateful_action import StickyAction
from gymnasium.wrappers.stateful_observation import (
    DelayObservation,
    FrameStackObservation,
    MaxAndSkipObservation,
    NormalizeObservation,
    TimeAwareObservation,
)
from gymnasium.wrappers.stateful_reward import NormalizeReward
from gymnasium.wrappers.transform_action import (
    ClipAction,
    RescaleAction,
    TransformAction,
)
from gymnasium.wrappers.transform_observation import (
DtypeObservation,
FilterObservation,
FlattenObservation,
GrayscaleObservation,
RenderObservation,
RescaleObservation,
ReshapeObservation,
ResizeObservation,
TransformObservation,
)
from gymnasium.wrappers.transform_reward import ClipReward, TransformReward
__all__ = [
"vector",
# --- Observation wrappers ---
"AtariPreprocessing",
"DelayObservation",
"DtypeObservation",
"FilterObservation",
"FlattenObservation",
"FrameStackObservation",
"GrayscaleObservation",
"TransformObservation",
"MaxAndSkipObservation",
"NormalizeObservation",
"RenderObservation",
"ResizeObservation",
"ReshapeObservation",
"RescaleObservation",
"TimeAwareObservation",
# --- Action Wrappers ---
"ClipAction",
"TransformAction",
"RescaleAction",
# "NanAction",
"StickyAction",
# --- Reward wrappers ---
"ClipReward",
"TransformReward",
"NormalizeReward",
# --- Common ---
"TimeLimit",
"Autoreset",
"PassiveEnvChecker",
"OrderEnforcing",
"RecordEpisodeStatistics",
# --- Rendering ---
"RenderCollection",
"RecordVideo",
"HumanRendering",
# --- Conversion ---
"JaxToNumpy",
"JaxToTorch",
"NumpyToTorch",
]
# As these wrappers require `jax` or `torch`, they are loaded at runtime when users try to access them,
# avoiding `import jax` or `import torch` on `import gymnasium`.
_wrapper_to_class = {
# data converters
"JaxToNumpy": "jax_to_numpy",
"JaxToTorch": "jax_to_torch",
"NumpyToTorch": "numpy_to_torch",
}
_renamed_wrapper = {
"AutoResetWrapper": "Autoreset",
"FrameStack": "FrameStackObservation",
"PixelObservationWrapper": "RenderObservation",
"VectorListInfo": "vector.DictInfoToList",
}
def __getattr__(wrapper_name: str):
"""Load a wrapper by name.
This optimizes the loading of gymnasium wrappers by only loading the wrapper if it is used.
Errors will be raised if the wrapper does not exist or if the version is not the latest.
Args:
wrapper_name: The name of a wrapper to load.
Returns:
The specified wrapper.
Raises:
AttributeError: If the wrapper does not exist.
DeprecatedWrapper: If the version is not the latest.
"""
# Check if the requested wrapper is in the _wrapper_to_class dictionary
if wrapper_name in _wrapper_to_class:
import_stmt = f"gymnasium.wrappers.{_wrapper_to_class[wrapper_name]}"
module = importlib.import_module(import_stmt)
return getattr(module, wrapper_name)
elif wrapper_name in _renamed_wrapper:
raise AttributeError(
f"{wrapper_name!r} has been renamed with `wrappers.{_renamed_wrapper[wrapper_name]}`"
)
raise AttributeError(f"module {__name__!r} has no attribute {wrapper_name!r}")

gymnasium/wrappers/atari_preprocessing.py View File

@@ -5,14 +5,14 @@ import gymnasium as gym
 from gymnasium.spaces import Box
-try:
-    import cv2
-except ImportError:
-    cv2 = None
+__all__ = ["AtariPreprocessing"]
 class AtariPreprocessing(gym.Wrapper, gym.utils.RecordConstructorArgs):
-    """Atari 2600 preprocessing wrapper.
+    """Implements the common preprocessing techniques for Atari environments (excluding frame stacking).
+    For frame stacking use :class:`gymnasium.wrappers.FrameStackObservation`.
+    No vector version of the wrapper exists.
     This class follows the guidelines in Machado et al. (2018),
     "Revisiting the Arcade Learning Environment: Evaluation Protocols and Open Problems for General Agents".
@@ -20,13 +20,22 @@ class AtariPreprocessing(gym.Wrapper, gym.utils.RecordConstructorArgs):
 Specifically, the following preprocessing stages apply to the atari environment:
 - Noop Reset: Obtains the initial state by taking a random number of no-ops on reset, default max 30 no-ops.
-- Frame skipping: The number of frames skipped between steps, 4 by default
-- Max-pooling: Pools over the most recent two observations from the frame skips
+- Frame skipping: The number of frames skipped between steps, 4 by default.
+- Max-pooling: Pools over the most recent two observations from the frame skips.
 - Termination signal when a life is lost: When the agent loses a life during the environment, the environment is terminated.
   Turned off by default. Not recommended by Machado et al. (2018).
-- Resize to a square image: Resizes the atari environment original observation shape from 210x180 to 84x84 by default
-- Grayscale observation: If the observation is colour or greyscale, by default, greyscale.
-- Scale observation: If to scale the observation between [0, 1) or [0, 255), by default, not scaled.
+- Resize to a square image: Resizes the atari environment original observation shape from 210x160 to 84x84 by default.
+- Grayscale observation: Makes the observation greyscale, enabled by default.
+- Grayscale new axis: Extends the last channel of the observation such that the image is 3-dimensional, not enabled by default.
+- Scale observation: Whether to scale the observation between [0, 1) or [0, 255), not scaled by default.
+
+Example:
+    >>> import gymnasium as gym  # doctest: +SKIP
+    >>> env = gym.make("ALE/Adventure-v5")  # doctest: +SKIP
+    >>> env = AtariPreprocessing(env, noop_max=10, frame_skip=0, screen_size=84, terminal_on_life_loss=True, grayscale_obs=False, grayscale_newaxis=False)  # doctest: +SKIP
+
+Change logs:
+ * Added in gym v0.12.2 (gym #1455)
 """
 def __init__(
@@ -46,7 +55,7 @@ class AtariPreprocessing(gym.Wrapper, gym.utils.RecordConstructorArgs):
 env (Env): The environment to apply the preprocessing
 noop_max (int): For No-op reset, the max number of no-op actions taken at reset; to turn off, set to 0.
 frame_skip (int): The number of frames between new observations, affecting the frequency at which the agent experiences the game.
-screen_size (int): resize Atari frame
+screen_size (int): resize Atari frame.
 terminal_on_life_loss (bool): `if True`, then :meth:`step()` returns `terminated=True` whenever a
     life is lost.
 grayscale_obs (bool): if True, then grayscale observations are returned; otherwise, RGB observations
@@ -72,10 +81,13 @@ class AtariPreprocessing(gym.Wrapper, gym.utils.RecordConstructorArgs):
 )
 gym.Wrapper.__init__(self, env)
-if cv2 is None:
+try:
+    import cv2  # noqa: F401
+except ImportError:
     raise gym.error.DependencyNotInstalled(
         "opencv-python package not installed, run `pip install gymnasium[other]` to get dependencies for atari"
     )
 assert frame_skip > 0
 assert screen_size > 0
 assert noop_max >= 0
@@ -187,7 +199,9 @@ class AtariPreprocessing(gym.Wrapper, gym.utils.RecordConstructorArgs):
 def _get_obs(self):
     if self.frame_skip > 1:  # more efficient in-place pooling
         np.maximum(self.obs_buffer[0], self.obs_buffer[1], out=self.obs_buffer[0])
-    assert cv2 is not None
+    import cv2
     obs = cv2.resize(
         self.obs_buffer[0],
         (self.screen_size, self.screen_size),
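Since frame stacking is no longer part of `AtariPreprocessing`, the two wrappers are now composed explicitly. A hedged sketch of the intended pipeline, assuming an ALE environment such as `ALE/Breakout-v5` is available (`pip install gymnasium[atari]`) along with opencv-python (`pip install gymnasium[other]`):

import gymnasium as gym
from gymnasium.wrappers import AtariPreprocessing, FrameStackObservation

env = gym.make("ALE/Breakout-v5", frameskip=1)  # disable ALE's internal frame skip
env = AtariPreprocessing(env, frame_skip=4, screen_size=84, grayscale_obs=True)
env = FrameStackObservation(env, stack_size=4)  # stacking now lives in its own wrapper

obs, info = env.reset(seed=0)
print(obs.shape)  # (4, 84, 84): four stacked 84x84 grayscale frames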

gymnasium/wrappers/autoreset.py View File

@@ -1,86 +0,0 @@
"""Wrapper that autoreset environments when `terminated=True` or `truncated=True`."""
from __future__ import annotations
from copy import deepcopy
from typing import TYPE_CHECKING
import gymnasium as gym
if TYPE_CHECKING:
from gymnasium.envs.registration import EnvSpec
class AutoResetWrapper(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`.
When calling step causes :meth:`Env.step` to return `terminated=True` or `truncated=True`, :meth:`Env.reset` is called,
and the return format of :meth:`self.step` is as follows: ``(new_obs, final_reward, final_terminated, final_truncated, info)``
with new step API and ``(new_obs, final_reward, final_done, info)`` with the old step API.
- ``new_obs`` is the first observation after calling :meth:`self.env.reset`
- ``final_reward`` is the reward after calling :meth:`self.env.step`, prior to calling :meth:`self.env.reset`.
- ``final_terminated`` is the terminated value before calling :meth:`self.env.reset`.
- ``final_truncated`` is the truncated value before calling :meth:`self.env.reset`. ``final_terminated`` and ``final_truncated`` cannot both be False.
- ``info`` is a dict containing all the keys from the info dict returned by the call to :meth:`self.env.reset`,
with an additional key "final_observation" containing the observation returned by the last call to :meth:`self.env.step`
and "final_info" containing the info dict returned by the last call to :meth:`self.env.step`.
Warning:
When using this wrapper to collect rollouts, note that when :meth:`Env.step` returns `terminated` or `truncated`, a
new observation from after calling :meth:`Env.reset` is returned by :meth:`Env.step` alongside the
final reward, terminated and truncated state from the previous episode.
If you need the final state from the previous episode, you need to retrieve it via the
"final_observation" key in the info dict.
Make sure you know what you're doing if you use this wrapper!
"""
def __init__(self, env: gym.Env):
"""A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`.
Args:
env (gym.Env): The environment to apply the wrapper
"""
gym.utils.RecordConstructorArgs.__init__(self)
gym.Wrapper.__init__(self, env)
def step(self, action):
"""Steps through the environment with action and resets the environment if a terminated or truncated signal is encountered.
Args:
action: The action to take
Returns:
The autoreset environment :meth:`step`
"""
obs, reward, terminated, truncated, info = self.env.step(action)
if terminated or truncated:
new_obs, new_info = self.env.reset()
assert (
"final_observation" not in new_info
), 'info dict cannot contain key "final_observation" '
assert (
"final_info" not in new_info
), 'info dict cannot contain key "final_info" '
new_info["final_observation"] = obs
new_info["final_info"] = info
obs = new_obs
info = new_info
return obs, reward, terminated, truncated, info
@property
def spec(self) -> EnvSpec | None:
"""Modifies the environment spec to specify the `autoreset=True`."""
if self._cached_spec is not None:
return self._cached_spec
env_spec = self.env.spec
if env_spec is not None:
env_spec = deepcopy(env_spec)
env_spec.autoreset = True
self._cached_spec = env_spec
return env_spec

gymnasium/wrappers/clip_action.py View File

@@ -1,43 +0,0 @@
"""Wrapper for clipping actions within a valid bound."""
import numpy as np
import gymnasium as gym
from gymnasium.spaces import Box
class ClipAction(gym.ActionWrapper, gym.utils.RecordConstructorArgs):
"""Clip the continuous action within the valid :class:`Box` observation space bound.
Example:
>>> import gymnasium as gym
>>> from gymnasium.wrappers import ClipAction
>>> env = gym.make("Hopper-v4")
>>> env = ClipAction(env)
>>> env.action_space
Box(-1.0, 1.0, (3,), float32)
>>> _ = env.reset(seed=42)
>>> _ = env.step(np.array([5.0, -2.0, 0.0]))
... # Executes the action np.array([1.0, -1.0, 0]) in the base environment
"""
def __init__(self, env: gym.Env):
"""A wrapper for clipping continuous actions within the valid bound.
Args:
env: The environment to apply the wrapper
"""
assert isinstance(env.action_space, Box)
gym.utils.RecordConstructorArgs.__init__(self)
gym.ActionWrapper.__init__(self, env)
def action(self, action):
"""Clips the action within the valid bounds.
Args:
action: The action to clip
Returns:
The clipped action
"""
return np.clip(action, self.action_space.low, self.action_space.high)
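For reference, `ClipAction` survives the reorganisation unchanged but is now defined in `gymnasium.wrappers.transform_action`. A short usage sketch, assuming a MuJoCo environment such as `Hopper-v4` is available (`pip install gymnasium[mujoco]`):

import numpy as np
import gymnasium as gym
from gymnasium.wrappers import ClipAction  # re-exported from transform_action

env = ClipAction(gym.make("Hopper-v4"))
env.reset(seed=42)
# The out-of-range action is clipped to [1.0, -1.0, 0.0] before reaching the base env.
env.step(np.array([5.0, -2.0, 0.0], dtype=np.float32))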

gymnasium/wrappers/common.py View File

@@ -0,0 +1,536 @@
"""A collection of common wrappers.
* ``TimeLimit`` - Provides a time limit on the number of steps for an environment before it truncates
* ``Autoreset`` - Auto-resets the environment
* ``PassiveEnvChecker`` - Passive environment checker that does not modify any environment data
* ``OrderEnforcing`` - Enforces the order of function calls to environments
* ``RecordEpisodeStatistics`` - Records the episode statistics
"""
from __future__ import annotations
import time
from collections import deque
from copy import deepcopy
from typing import TYPE_CHECKING, Any, SupportsFloat
import gymnasium as gym
from gymnasium import logger
from gymnasium.core import ActType, ObsType, RenderFrame
from gymnasium.error import ResetNeeded
from gymnasium.utils.passive_env_checker import (
check_action_space,
check_observation_space,
env_render_passive_checker,
env_reset_passive_checker,
env_step_passive_checker,
)
if TYPE_CHECKING:
from gymnasium.envs.registration import EnvSpec
__all__ = [
"TimeLimit",
"Autoreset",
"PassiveEnvChecker",
"OrderEnforcing",
"RecordEpisodeStatistics",
]
class TimeLimit(
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
"""Limits the number of steps for an environment through truncating the environment if a maximum number of timesteps is exceeded.
If a truncation is not defined inside the environment itself, this is the only place that the truncation signal is issued.
Critically, this is different from the `terminated` signal that originates from the underlying environment as part of the MDP.
No vector wrapper exists.
Example using the TimeLimit wrapper:
>>> from gymnasium.wrappers import TimeLimit
>>> from gymnasium.envs.classic_control import CartPoleEnv
>>> spec = gym.spec("CartPole-v1")
>>> spec.max_episode_steps
500
>>> env = gym.make("CartPole-v1")
>>> env # TimeLimit is included within the environment stack
<TimeLimit<OrderEnforcing<PassiveEnvChecker<CartPoleEnv<CartPole-v1>>>>>
>>> env.spec # doctest: +ELLIPSIS
EnvSpec(id='CartPole-v1', ..., max_episode_steps=500, ...)
>>> env = gym.make("CartPole-v1", max_episode_steps=3)
>>> env.spec # doctest: +ELLIPSIS
EnvSpec(id='CartPole-v1', ..., max_episode_steps=3, ...)
>>> env = TimeLimit(CartPoleEnv(), max_episode_steps=10)
>>> env
<TimeLimit<CartPoleEnv instance>>
Example of `TimeLimit` determining the episode step
>>> env = gym.make("CartPole-v1", max_episode_steps=3)
>>> _ = env.reset(seed=123)
>>> _ = env.action_space.seed(123)
>>> _, _, terminated, truncated, _ = env.step(env.action_space.sample())
>>> terminated, truncated
(False, False)
>>> _, _, terminated, truncated, _ = env.step(env.action_space.sample())
>>> terminated, truncated
(False, False)
>>> _, _, terminated, truncated, _ = env.step(env.action_space.sample())
>>> terminated, truncated
(False, True)
Change logs:
* v0.10.6 - Initially added
* v0.25.0 - With the step API update, the termination and truncation signal is returned separately.
"""
def __init__(
self,
env: gym.Env,
max_episode_steps: int,
):
"""Initializes the :class:`TimeLimit` wrapper with an environment and the number of steps after which truncation will occur.
Args:
env: The environment to apply the wrapper
max_episode_steps: An optional max episode steps (if ``None``, ``env.spec.max_episode_steps`` is used)
"""
gym.utils.RecordConstructorArgs.__init__(
self, max_episode_steps=max_episode_steps
)
gym.Wrapper.__init__(self, env)
self._max_episode_steps = max_episode_steps
self._elapsed_steps = None
def step(
self, action: ActType
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Steps through the environment and if the number of steps elapsed exceeds ``max_episode_steps`` then truncate.
Args:
action: The environment step action
Returns:
The environment step ``(observation, reward, terminated, truncated, info)`` with `truncated=True`
if the number of steps elapsed >= max episode steps
"""
observation, reward, terminated, truncated, info = self.env.step(action)
self._elapsed_steps += 1
if self._elapsed_steps >= self._max_episode_steps:
truncated = True
return observation, reward, terminated, truncated, info
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[ObsType, dict[str, Any]]:
"""Resets the environment with :param:`**kwargs` and sets the number of steps elapsed to zero.
Args:
seed: Seed for the environment
options: Options for the environment
Returns:
The reset environment
"""
self._elapsed_steps = 0
return super().reset(seed=seed, options=options)
@property
def spec(self) -> EnvSpec | None:
"""Modifies the environment spec to include the `max_episode_steps=self._max_episode_steps`."""
if self._cached_spec is not None:
return self._cached_spec
env_spec = self.env.spec
if env_spec is not None:
env_spec = deepcopy(env_spec)
env_spec.max_episode_steps = self._max_episode_steps
self._cached_spec = env_spec
return env_spec
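Because `TimeLimit` reports truncation separately from termination, a value-based agent can bootstrap correctly when the clock, rather than the MDP, ends the episode. A minimal sketch; `gamma` and `value_fn` are placeholders for the agent's own discount factor and value estimate:

import gymnasium as gym

env = gym.make("CartPole-v1", max_episode_steps=3)
obs, info = env.reset(seed=0)
done = False
while not done:
    next_obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    # terminated: terminal MDP state -> the update target is just `reward`.
    # truncated: time limit hit -> bootstrap: reward + gamma * value_fn(next_obs)
    done = terminated or truncated
    obs = next_obs
env.close()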
class Autoreset(
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
"""The wrapped environment is automatically reset when an terminated or truncated state is reached.
When calling step causes :meth:`Env.step` to return `terminated=True` or `truncated=True`, :meth:`Env.reset` is called,
and the return format of :meth:`self.step` is as follows: ``(new_obs, final_reward, final_terminated, final_truncated, info)``
with new step API and ``(new_obs, final_reward, final_done, info)`` with the old step API.
No vector version of the wrapper exists.
- ``obs`` is the first observation after calling :meth:`self.env.reset`
- ``final_reward`` is the reward after calling :meth:`self.env.step`, prior to calling :meth:`self.env.reset`.
- ``final_terminated`` is the terminated value before calling :meth:`self.env.reset`.
- ``final_truncated`` is the truncated value before calling :meth:`self.env.reset`. ``final_terminated`` and ``final_truncated`` cannot both be False.
- ``info`` is a dict containing all the keys from the info dict returned by the call to :meth:`self.env.reset`,
with an additional key "final_observation" containing the observation returned by the last call to :meth:`self.env.step`
and "final_info" containing the info dict returned by the last call to :meth:`self.env.step`.
Warning:
When using this wrapper to collect rollouts, note that when :meth:`Env.step` returns `terminated` or `truncated`, a
new observation from after calling :meth:`Env.reset` is returned by :meth:`Env.step` alongside the
final reward, terminated and truncated state from the previous episode.
If you need the final state from the previous episode, you need to retrieve it via the
"final_observation" key in the info dict.
Make sure you know what you're doing if you use this wrapper!
Change logs:
* v0.24.0 - Initially added as `AutoResetWrapper`
* v1.0.0 - renamed to `Autoreset` and autoreset order was changed to reset on the step after the environment terminates or truncates. As a result, `"final_observation"` and `"final_info"` is removed.
"""
def __init__(self, env: gym.Env):
"""A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`.
Args:
env (gym.Env): The environment to apply the wrapper
"""
gym.utils.RecordConstructorArgs.__init__(self)
gym.Wrapper.__init__(self, env)
def step(
self, action: ActType
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Steps through the environment with action and resets the environment if a terminated or truncated signal is encountered.
Args:
action: The action to take
Returns:
The autoreset environment :meth:`step`
"""
obs, reward, terminated, truncated, info = self.env.step(action)
if terminated or truncated:
new_obs, new_info = self.env.reset()
assert (
"final_observation" not in new_info
), f'new info dict already contains "final_observation", info keys: {new_info.keys()}'
assert (
"final_info" not in new_info
), f'new info dict already contains "final_observation", info keys: {new_info.keys()}'
new_info["final_observation"] = obs
new_info["final_info"] = info
obs = new_obs
info = new_info
return obs, reward, terminated, truncated, info
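With the implementation shown above, the post-reset observation is what `step` returns, while the finished episode's last observation and info dict are tucked into `info`. A short usage sketch:

import gymnasium as gym
from gymnasium.wrappers import Autoreset

env = Autoreset(gym.make("CartPole-v1"))
obs, info = env.reset(seed=0)
for _ in range(500):
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    if terminated or truncated:
        # `obs` already belongs to the new episode; the old episode ends here.
        last_obs = info["final_observation"]
        last_info = info["final_info"]
env.close()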
class PassiveEnvChecker(
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
"""A passive wrapper that surrounds the ``step``, ``reset`` and ``render`` functions to check they follow Gymnasium's API.
This wrapper is automatically applied during make and can be disabled with `disable_env_checker`.
No vector version of the wrapper exists.
Example:
>>> import gymnasium as gym
>>> env = gym.make("CartPole-v1")
>>> env
<TimeLimit<OrderEnforcing<PassiveEnvChecker<CartPoleEnv<CartPole-v1>>>>>
>>> env = gym.make("CartPole-v1", disable_env_checker=True)
>>> env
<TimeLimit<OrderEnforcing<CartPoleEnv<CartPole-v1>>>>
Change logs:
* v0.24.1 - Initially added, however, broken in several ways
* v0.25.0 - Bugs were all fixed
* v0.29.0 - Removed warnings for infinite bounds for Box observation and action spaces and irregular bound shapes
"""
def __init__(self, env: gym.Env[ObsType, ActType]):
"""Initialises the wrapper with the environments, run the observation and action space tests."""
gym.utils.RecordConstructorArgs.__init__(self)
gym.Wrapper.__init__(self, env)
assert hasattr(
env, "action_space"
), "The environment must specify an action space. https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/"
check_action_space(env.action_space)
assert hasattr(
env, "observation_space"
), "The environment must specify an observation space. https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/"
check_observation_space(env.observation_space)
self.checked_reset: bool = False
self.checked_step: bool = False
self.checked_render: bool = False
self.close_called: bool = False
def step(
self, action: ActType
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Steps through the environment that on the first call will run the `passive_env_step_check`."""
if self.checked_step is False:
self.checked_step = True
return env_step_passive_checker(self.env, action)
else:
return self.env.step(action)
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[ObsType, dict[str, Any]]:
"""Resets the environment that on the first call will run the `passive_env_reset_check`."""
if self.checked_reset is False:
self.checked_reset = True
return env_reset_passive_checker(self.env, seed=seed, options=options)
else:
return self.env.reset(seed=seed, options=options)
def render(self) -> RenderFrame | list[RenderFrame] | None:
"""Renders the environment that on the first call will run the `passive_env_render_check`."""
if self.checked_render is False:
self.checked_render = True
return env_render_passive_checker(self.env)
else:
return self.env.render()
@property
def spec(self) -> EnvSpec | None:
"""Modifies the environment spec to such that `disable_env_checker=False`."""
if self._cached_spec is not None:
return self._cached_spec
env_spec = self.env.spec
if env_spec is not None:
env_spec = deepcopy(env_spec)
env_spec.disable_env_checker = False
self._cached_spec = env_spec
return env_spec
def close(self):
"""Warns if calling close on a closed environment fails."""
if not self.close_called:
self.close_called = True
return self.env.close()
else:
try:
return self.env.close()
except Exception as e:
logger.warn(
"Calling `env.close()` on the closed environment should be allowed, but it raised the following exception."
)
raise e
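Note that each check runs only on the first call: the first `reset`, `step`, and `render` pass through the passive checkers, and every later call is forwarded directly to the wrapped environment. Opting out entirely is a one-liner at construction time:

import gymnasium as gym

checked = gym.make("CartPole-v1")  # PassiveEnvChecker is applied by default
unchecked = gym.make("CartPole-v1", disable_env_checker=True)  # no checker in the stack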
class OrderEnforcing(
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
"""Will produce an error if ``step`` or ``render`` is called before ``reset``.
No vector version of the wrapper exists.
Example:
>>> import gymnasium as gym
>>> from gymnasium.wrappers import OrderEnforcing
>>> env = gym.make("CartPole-v1", render_mode="human")
>>> env = OrderEnforcing(env)
>>> env.step(0)
Traceback (most recent call last):
...
gymnasium.error.ResetNeeded: Cannot call env.step() before calling env.reset()
>>> env.render()
Traceback (most recent call last):
...
gymnasium.error.ResetNeeded: Cannot call `env.render()` before calling `env.reset()`, if this is an intended action, set `disable_render_order_enforcing=True` on the OrderEnforcer wrapper.
>>> _ = env.reset()
>>> env.render()
>>> _ = env.step(0)
>>> env.close()
Change logs:
* v0.22.0 - Initially added
* v0.24.0 - Added order enforcing for the render function
"""
def __init__(
self,
env: gym.Env[ObsType, ActType],
disable_render_order_enforcing: bool = False,
):
"""A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`.
Args:
env: The environment to wrap
disable_render_order_enforcing: Whether to disable render order enforcing
"""
gym.utils.RecordConstructorArgs.__init__(
self, disable_render_order_enforcing=disable_render_order_enforcing
)
gym.Wrapper.__init__(self, env)
self._has_reset: bool = False
self._disable_render_order_enforcing: bool = disable_render_order_enforcing
def step(self, action: ActType) -> tuple[ObsType, SupportsFloat, bool, bool, dict]:
"""Steps through the environment."""
if not self._has_reset:
raise ResetNeeded("Cannot call env.step() before calling env.reset()")
return super().step(action)
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[ObsType, dict[str, Any]]:
"""Resets the environment with `kwargs`."""
self._has_reset = True
return super().reset(seed=seed, options=options)
def render(self) -> RenderFrame | list[RenderFrame] | None:
"""Renders the environment with `kwargs`."""
if not self._disable_render_order_enforcing and not self._has_reset:
raise ResetNeeded(
"Cannot call `env.render()` before calling `env.reset()`, if this is an intended action, "
"set `disable_render_order_enforcing=True` on the OrderEnforcer wrapper."
)
return super().render()
@property
def has_reset(self):
"""Returns if the environment has been reset before."""
return self._has_reset
@property
def spec(self) -> EnvSpec | None:
"""Modifies the environment spec to add the `order_enforce=True`."""
if self._cached_spec is not None:
return self._cached_spec
env_spec = self.env.spec
if env_spec is not None:
env_spec = deepcopy(env_spec)
env_spec.order_enforce = True
self._cached_spec = env_spec
return env_spec
class RecordEpisodeStatistics(
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
"""This wrapper will keep track of cumulative rewards and episode lengths.
At the end of an episode, the statistics of the episode will be added to ``info``
using the key ``episode``. If using a vectorized environment, the key
``_episode`` is also used, indicating whether the env at the respective index has
the episode statistics.
A vector version of the wrapper exists, :class:`gymnasium.wrappers.vector.RecordEpisodeStatistics`.
After the completion of an episode, ``info`` will look like this::
>>> info = {
... "episode": {
... "r": "<cumulative reward>",
... "l": "<episode length>",
... "t": "<elapsed time since beginning of episode>"
... },
... }
For vectorized environments the output will be in the form of::
>>> infos = {
... "final_observation": "<array of length num-envs>",
... "_final_observation": "<boolean array of length num-envs>",
... "final_info": "<array of length num-envs>",
... "_final_info": "<boolean array of length num-envs>",
... "episode": {
... "r": "<array of cumulative reward>",
... "l": "<array of episode length>",
... "t": "<array of elapsed time since beginning of episode>"
... },
... "_episode": "<boolean array of length num-envs>"
... }
Moreover, the most recent rewards and episode lengths are stored in buffers that can be accessed via
:attr:`wrapped_env.return_queue` and :attr:`wrapped_env.length_queue` respectively.
Attributes:
* time_queue: The time length of the last ``deque_size``-many episodes
* return_queue: The cumulative rewards of the last ``deque_size``-many episodes
* length_queue: The lengths of the last ``deque_size``-many episodes
Change logs:
* v0.15.4 - Initially added
* v1.0.0 - Removed vector environment support (see `wrappers.vector.RecordEpisodeStatistics`) and added attribute ``time_queue``
"""
def __init__(
self,
env: gym.Env[ObsType, ActType],
buffer_length: int | None = 100,
stats_key: str = "episode",
):
"""This wrapper will keep track of cumulative rewards and episode lengths.
Args:
env (Env): The environment to apply the wrapper
buffer_length: The size of the buffers :attr:`return_queue` and :attr:`length_queue`
stats_key: The info key for the episode statistics
"""
gym.utils.RecordConstructorArgs.__init__(self)
gym.Wrapper.__init__(self, env)
self._stats_key = stats_key
self.episode_count = 0
self.episode_start_time: float = -1
self.episode_returns: float = 0.0
self.episode_lengths: int = 0
self.time_queue: deque[float] = deque(maxlen=buffer_length)
self.return_queue: deque[float] = deque(maxlen=buffer_length)
self.length_queue: deque[int] = deque(maxlen=buffer_length)
def step(
self, action: ActType
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Steps through the environment, recording the episode statistics."""
obs, reward, terminated, truncated, info = super().step(action)
self.episode_returns += reward
self.episode_lengths += 1
if terminated or truncated:
assert self._stats_key not in info
episode_time_length = round(
time.perf_counter() - self.episode_start_time, 6
)
info[self._stats_key] = {
"r": self.episode_returns,
"l": self.episode_lengths,
"t": episode_time_length,
}
self.time_queue.append(episode_time_length)
self.return_queue.append(self.episode_returns)
self.length_queue.append(self.episode_lengths)
self.episode_count += 1
return obs, reward, terminated, truncated, info
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[ObsType, dict[str, Any]]:
"""Resets the environment using seed and options and resets the episode rewards and lengths."""
obs, info = super().reset(seed=seed, options=options)
self.episode_start_time = time.perf_counter()
self.episode_returns = 0.0
self.episode_lengths = 0
return obs, info
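A short sketch of reading the recorded statistics, both per-episode via `info` and in aggregate via the rolling buffers:

import numpy as np
import gymnasium as gym
from gymnasium.wrappers import RecordEpisodeStatistics

env = RecordEpisodeStatistics(gym.make("CartPole-v1"), buffer_length=100)
obs, info = env.reset(seed=0)
for _ in range(1000):
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    if terminated or truncated:
        print(info["episode"])  # {'r': <return>, 'l': <length>, 't': <time>}
        obs, info = env.reset()

# Rolling means over the most recent (up to 100) episodes
print(np.mean(env.return_queue), np.mean(env.length_queue))
env.close()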

Some files were not shown because too many files have changed in this diff.