mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-07-31 22:04:31 +00:00
Merge v1.0.0 (#682)
Co-authored-by: Kallinteris Andreas <30759571+Kallinteris-Andreas@users.noreply.github.com> Co-authored-by: Jet <38184875+jjshoots@users.noreply.github.com> Co-authored-by: Omar Younis <42100908+younik@users.noreply.github.com>
This commit is contained in:
4
.github/workflows/build-docs.yml
vendored
4
.github/workflows/build-docs.yml
vendored
@@ -1,15 +1,19 @@
|
||||
name: Build main branch documentation website
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
jobs:
|
||||
docs:
|
||||
name: Generate Website
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
SPHINX_GITHUB_CHANGELOG_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
|
||||
|
2
.github/workflows/build.yml
vendored
2
.github/workflows/build.yml
vendored
@@ -32,4 +32,4 @@ jobs:
|
||||
--tag gymnasium-necessary-docker .
|
||||
- name: Run tests
|
||||
run: |
|
||||
docker run gymnasium-necessary-docker pytest tests/test_core.py tests/envs/test_compatibility.py tests/envs/test_envs.py tests/spaces
|
||||
docker run gymnasium-necessary-docker pytest tests/test_core.py tests/envs/test_envs.py tests/spaces
|
||||
|
2
.github/workflows/docs-manual-versioning.yml
vendored
2
.github/workflows/docs-manual-versioning.yml
vendored
@@ -1,4 +1,5 @@
|
||||
name: Manual Docs Versioning
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
@@ -14,6 +15,7 @@ on:
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
jobs:
|
||||
docs:
|
||||
name: Generate Website for new version
|
||||
|
3
.github/workflows/docs-versioning.yml
vendored
3
.github/workflows/docs-versioning.yml
vendored
@@ -1,10 +1,13 @@
|
||||
name: Docs Versioning
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v?*.*.*'
|
||||
-
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
jobs:
|
||||
docs:
|
||||
name: Generate Website for new version
|
||||
|
@@ -51,7 +51,7 @@ repos:
|
||||
rev: 6.3.0
|
||||
hooks:
|
||||
- id: pydocstyle
|
||||
exclude: ^(gymnasium/envs/box2d)|(gymnasium/envs/classic_control)|(gymnasium/envs/mujoco)|(gymnasium/envs/toy_text)|(tests/envs)|(tests/spaces)|(tests/utils)|(tests/vector)|(tests/wrappers)|(docs/)
|
||||
exclude: ^(gymnasium/envs/box2d)|(gymnasium/envs/classic_control)|(gymnasium/envs/mujoco)|(gymnasium/envs/toy_text)|(tests/envs)|(tests/spaces)|(tests/utils)|(tests/vector)|(docs/)
|
||||
args:
|
||||
- --source
|
||||
- --explain
|
||||
|
73
docs/_scripts/gen_wrapper_table.py
Normal file
73
docs/_scripts/gen_wrapper_table.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import os.path
|
||||
|
||||
import gymnasium as gym
|
||||
|
||||
|
||||
# Wrapper (sub)module names to omit from the flat wrapper table;
# vector wrappers are documented in their own table below.
exclude_wrappers = {"vector"}
|
||||
|
||||
|
||||
def generate_wrappers():
    """Build the body of an RST ``list-table`` with one row per wrapper.

    Iterates over ``gym.wrappers.__all__`` in sorted order, skipping the
    names in the module-level ``exclude_wrappers`` set, and emits a
    ``* - :class:`Name``` / ``- summary`` row pair per wrapper, where the
    summary is the first line of the wrapper's docstring.

    Returns:
        str: The concatenated table rows (empty string if nothing qualifies).
    """
    wrapper_table = ""
    for wrapper_name in sorted(gym.wrappers.__all__):
        if wrapper_name in exclude_wrappers:
            continue
        wrapper = getattr(gym.wrappers, wrapper_name)
        # ``__doc__`` can be None (undocumented class, or ``python -OO``);
        # guard before ``split`` so doc generation degrades to an empty
        # description instead of raising AttributeError.
        wrapper_doc = (wrapper.__doc__ or "").split("\n")[0]
        # NOTE(review): the scrape lost the exact leading whitespace of the
        # second row line — it must align with the ``-`` of ``* -`` so the
        # RST list-table parses; confirm against the rendered table.md.
        wrapper_table += f""" * - :class:`{wrapper_name}`
   - {wrapper_doc}
"""
    return wrapper_table
|
||||
|
||||
|
||||
def generate_vector_wrappers():
    """Build RST ``list-table`` rows for wrappers that exist only in vector form.

    Only names present in ``gym.wrappers.vector.__all__`` but absent from
    ``gym.wrappers.__all__`` are listed, so the main wrapper table is not
    duplicated.

    Returns:
        str: The concatenated table rows (empty string if nothing qualifies).
    """
    vector_only = set(gym.wrappers.vector.__all__) - set(gym.wrappers.__all__)

    vector_table = ""
    for vector_name in sorted(vector_only):
        # Guard against a missing docstring (None) so generation never
        # raises AttributeError; fall back to an empty description.
        vector_wrapper = getattr(gym.wrappers.vector, vector_name)
        vector_doc = (vector_wrapper.__doc__ or "").split("\n")[0]
        # NOTE(review): second-line indentation reconstructed (lost in this
        # scrape); it must align with the ``-`` of ``* -`` above.
        vector_table += f""" * - :class:`{vector_name}`
   - {vector_doc}
"""
    return vector_table
|
||||
|
||||
|
||||
if __name__ == "__main__":
    gen_wrapper_table = generate_wrappers()
    gen_vector_table = generate_vector_wrappers()

    # Markdown page with two RST list-tables: all standard wrappers, then
    # the vector-only wrappers.  ``{{eval-rst}}`` escapes the literal braces
    # inside this f-string.
    # NOTE(review): the exact leading whitespace of the ``:header-rows:`` and
    # row lines was lost in this scrape — the indents below are reconstructed
    # to be internally consistent with the generated rows; confirm against
    # the rendered docs/api/wrappers/table.md.
    page = f"""
# List of Gymnasium Wrappers

Gymnasium provides a number of commonly used wrappers listed below. More information can be found on the particular
wrapper in the page on the wrapper type

```{{eval-rst}}
.. py:currentmodule:: gymnasium.wrappers

.. list-table::
   :header-rows: 1

 * - Name
   - Description
{gen_wrapper_table}
```

## Vector only Wrappers

```{{eval-rst}}
.. py:currentmodule:: gymnasium.wrappers.vector

.. list-table::
   :header-rows: 1

 * - Name
   - Description
{gen_vector_table}
```
"""

    # Output path relative to this script: docs/api/wrappers/table.md
    filename = os.path.join(
        os.path.dirname(__file__), "..", "api", "wrappers", "table.md"
    )
    # Explicit encoding: the platform default may not be UTF-8, and wrapper
    # docstrings can contain non-ASCII characters.
    with open(filename, "w", encoding="utf-8") as file:
        file.write(page)
|
@@ -1,29 +1,26 @@
|
||||
---
|
||||
title: Utils
|
||||
title: Env
|
||||
---
|
||||
|
||||
# Env
|
||||
|
||||
## gymnasium.Env
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.Env
|
||||
```
|
||||
|
||||
### Methods
|
||||
|
||||
## Methods
|
||||
```{eval-rst}
|
||||
.. autofunction:: gymnasium.Env.step
|
||||
.. autofunction:: gymnasium.Env.reset
|
||||
.. autofunction:: gymnasium.Env.render
|
||||
.. automethod:: gymnasium.Env.step
|
||||
.. automethod:: gymnasium.Env.reset
|
||||
.. automethod:: gymnasium.Env.render
|
||||
.. automethod:: gymnasium.Env.close
|
||||
```
|
||||
|
||||
### Attributes
|
||||
|
||||
## Attributes
|
||||
```{eval-rst}
|
||||
.. autoattribute:: gymnasium.Env.action_space
|
||||
|
||||
The Space object corresponding to valid actions, all valid actions should be contained with the space. For example, if the action space is of type `Discrete` and gives the value `Discrete(2)`, this means there are two valid discrete actions: 0 & 1.
|
||||
The Space object corresponding to valid actions; all valid actions should be contained within the space. For example, if the action space is of type `Discrete` and gives the value `Discrete(2)`, this means there are two valid discrete actions: `0` & `1`.
|
||||
|
||||
.. code::
|
||||
|
||||
@@ -51,29 +48,26 @@ title: Utils
|
||||
|
||||
The render mode of the environment determined at initialisation
|
||||
|
||||
.. autoattribute:: gymnasium.Env.reward_range
|
||||
|
||||
A tuple corresponding to the minimum and maximum possible rewards for an agent over an episode. The default reward range is set to :math:`(-\infty,+\infty)`.
|
||||
|
||||
.. autoattribute:: gymnasium.Env.spec
|
||||
|
||||
The ``EnvSpec`` of the environment normally set during :py:meth:`gymnasium.make`
|
||||
```
|
||||
The :class:`EnvSpec` of the environment normally set during :py:meth:`gymnasium.make`
|
||||
|
||||
### Additional Methods
|
||||
|
||||
```{eval-rst}
|
||||
.. autofunction:: gymnasium.Env.close
|
||||
.. autoproperty:: gymnasium.Env.unwrapped
|
||||
.. autoproperty:: gymnasium.Env.np_random
|
||||
```
|
||||
|
||||
### Implementing environments
|
||||
## Implementing environments
|
||||
|
||||
```{eval-rst}
|
||||
.. py:currentmodule:: gymnasium
|
||||
|
||||
When implementing an environment, the :meth:`Env.reset` and :meth:`Env.step` functions must be created describing the
|
||||
dynamics of the environment.
|
||||
For more information see the environment creation tutorial.
|
||||
When implementing an environment, the :meth:`Env.reset` and :meth:`Env.step` functions must be created describing the dynamics of the environment. For more information see the environment creation tutorial.
|
||||
```
|
||||
|
||||
## Creating environments
|
||||
|
||||
```{eval-rst}
|
||||
.. py:currentmodule:: gymnasium
|
||||
|
||||
To create an environment, gymnasium provides :meth:`make` to initialise the environment along with several important wrappers. Furthermore, gymnasium provides :meth:`make_vec` for creating vector environments and to view all the environment that can be created use :meth:`pprint_registry`.
|
||||
```
|
||||
|
@@ -1,157 +0,0 @@
|
||||
---
|
||||
title: Experimental
|
||||
---
|
||||
|
||||
# Experimental
|
||||
|
||||
```{toctree}
|
||||
:hidden:
|
||||
experimental/functional
|
||||
experimental/wrappers
|
||||
experimental/vector
|
||||
experimental/vector_wrappers
|
||||
experimental/vector_utils
|
||||
```
|
||||
|
||||
## Functional Environments
|
||||
|
||||
```{eval-rst}
|
||||
The gymnasium ``Env`` provides high flexibility for the implementation of individual environments however this can complicate parallelism of environments. Therefore, we propose the :class:`gymnasium.experimental.FuncEnv` where each part of environment has its own function related to it.
|
||||
```
|
||||
|
||||
## Wrappers
|
||||
|
||||
Gymnasium already contains a large collection of wrappers, but we believe that the wrappers can be improved to
|
||||
|
||||
* (Work in progress) Support arbitrarily complex observation / action spaces. As RL has advanced, action and observation spaces are becoming more complex and the current wrappers were not implemented with this mind.
|
||||
* Support for Jax-based environments. With hardware accelerated environments, i.e. Brax, written in Jax and similar PyTorch based programs, NumPy is not the only game in town anymore. Therefore, these upgrades will use [Jumpy](https://github.com/farama-Foundation/jumpy), a project developed by Farama Foundation to provide automatic compatibility for NumPy, Jax and in the future PyTorch data for a large subset of the NumPy functions.
|
||||
* More wrappers. Projects like [Supersuit](https://github.com/farama-Foundation/supersuit) aimed to bring more wrappers for RL, however, many users were not aware of the wrappers, so we plan to move the wrappers into Gymnasium. If we are missing common wrappers from the list provided above, please create an issue.
|
||||
* Versioning. Like environments, the implementation details of wrappers can cause changes in agent performance. Therefore, we propose adding version numbers to all wrappers, i.e., `LambdaActionV0`. We don't expect these version numbers to change regularly similar to environment version numbers and should ensure that all users know when significant changes could affect your agent's performance. Additionally, we hope that this will improve reproducibility of RL in the future, this is critical for academia.
|
||||
* In v28, we aim to rewrite the VectorEnv to not inherit from Env, as a result new vectorized versions of the wrappers will be provided.
|
||||
|
||||
We aimed to replace the wrappers in gymnasium v0.30.0 with these experimental wrappers.
|
||||
|
||||
### Observation Wrappers
|
||||
```{eval-rst}
|
||||
.. py:currentmodule:: gymnasium
|
||||
|
||||
.. list-table::
|
||||
:header-rows: 1
|
||||
|
||||
* - Old name
|
||||
- New name
|
||||
* - :class:`wrappers.TransformObservation`
|
||||
- :class:`experimental.wrappers.LambdaObservationV0`
|
||||
* - :class:`wrappers.FilterObservation`
|
||||
- :class:`experimental.wrappers.FilterObservationV0`
|
||||
* - :class:`wrappers.FlattenObservation`
|
||||
- :class:`experimental.wrappers.FlattenObservationV0`
|
||||
* - :class:`wrappers.GrayScaleObservation`
|
||||
- :class:`experimental.wrappers.GrayscaleObservationV0`
|
||||
* - :class:`wrappers.ResizeObservation`
|
||||
- :class:`experimental.wrappers.ResizeObservationV0`
|
||||
* - `supersuit.reshape_v0 <https://github.com/Farama-Foundation/SuperSuit/blob/314831a7d18e7254f455b181694c049908f95155/supersuit/generic_wrappers/basic_wrappers.py#L40>`_
|
||||
- :class:`experimental.wrappers.ReshapeObservationV0`
|
||||
* - Not Implemented
|
||||
- :class:`experimental.wrappers.RescaleObservationV0`
|
||||
* - `supersuit.dtype_v0 <https://github.com/Farama-Foundation/SuperSuit/blob/314831a7d18e7254f455b181694c049908f95155/supersuit/generic_wrappers/basic_wrappers.py#L32>`_
|
||||
- :class:`experimental.wrappers.DtypeObservationV0`
|
||||
* - :class:`wrappers.PixelObservationWrapper`
|
||||
- :class:`experimental.wrappers.PixelObservationV0`
|
||||
* - :class:`wrappers.NormalizeObservation`
|
||||
- :class:`experimental.wrappers.NormalizeObservationV0`
|
||||
* - :class:`wrappers.TimeAwareObservation`
|
||||
- :class:`experimental.wrappers.TimeAwareObservationV0`
|
||||
* - :class:`wrappers.FrameStack`
|
||||
- :class:`experimental.wrappers.FrameStackObservationV0`
|
||||
* - `supersuit.delay_observations_v0 <https://github.com/Farama-Foundation/SuperSuit/blob/314831a7d18e7254f455b181694c049908f95155/supersuit/generic_wrappers/delay_observations.py#L6>`_
|
||||
- :class:`experimental.wrappers.DelayObservationV0`
|
||||
```
|
||||
|
||||
### Action Wrappers
|
||||
```{eval-rst}
|
||||
.. py:currentmodule:: gymnasium
|
||||
|
||||
.. list-table::
|
||||
:header-rows: 1
|
||||
|
||||
* - Old name
|
||||
- New name
|
||||
* - `supersuit.action_lambda_v1 <https://github.com/Farama-Foundation/SuperSuit/blob/314831a7d18e7254f455b181694c049908f95155/supersuit/lambda_wrappers/action_lambda.py#L73>`_
|
||||
- :class:`experimental.wrappers.LambdaActionV0`
|
||||
* - :class:`wrappers.ClipAction`
|
||||
- :class:`experimental.wrappers.ClipActionV0`
|
||||
* - :class:`wrappers.RescaleAction`
|
||||
- :class:`experimental.wrappers.RescaleActionV0`
|
||||
* - `supersuit.sticky_actions_v0 <https://github.com/Farama-Foundation/SuperSuit/blob/314831a7d18e7254f455b181694c049908f95155/supersuit/generic_wrappers/sticky_actions.py#L6>`_
|
||||
- :class:`experimental.wrappers.StickyActionV0`
|
||||
```
|
||||
|
||||
### Reward Wrappers
|
||||
```{eval-rst}
|
||||
.. py:currentmodule:: gymnasium
|
||||
|
||||
.. list-table::
|
||||
:header-rows: 1
|
||||
|
||||
* - Old name
|
||||
- New name
|
||||
* - :class:`wrappers.TransformReward`
|
||||
- :class:`experimental.wrappers.LambdaRewardV0`
|
||||
* - `supersuit.clip_reward_v0 <https://github.com/Farama-Foundation/SuperSuit/blob/314831a7d18e7254f455b181694c049908f95155/supersuit/generic_wrappers/basic_wrappers.py#L74>`_
|
||||
- :class:`experimental.wrappers.ClipRewardV0`
|
||||
* - :class:`wrappers.NormalizeReward`
|
||||
- :class:`experimental.wrappers.NormalizeRewardV1`
|
||||
```
|
||||
|
||||
### Common Wrappers
|
||||
|
||||
```{eval-rst}
|
||||
.. py:currentmodule:: gymnasium
|
||||
|
||||
.. list-table::
|
||||
:header-rows: 1
|
||||
|
||||
* - Old name
|
||||
- New name
|
||||
* - :class:`wrappers.AutoResetWrapper`
|
||||
- :class:`experimental.wrappers.AutoresetV0`
|
||||
* - :class:`wrappers.PassiveEnvChecker`
|
||||
- :class:`experimental.wrappers.PassiveEnvCheckerV0`
|
||||
* - :class:`wrappers.OrderEnforcing`
|
||||
- :class:`experimental.wrappers.OrderEnforcingV0`
|
||||
* - :class:`wrappers.EnvCompatibility`
|
||||
- Moved to `shimmy <https://github.com/Farama-Foundation/Shimmy/blob/main/shimmy/openai_gym_compatibility.py>`_
|
||||
* - :class:`wrappers.RecordEpisodeStatistics`
|
||||
- :class:`experimental.wrappers.RecordEpisodeStatisticsV0`
|
||||
* - :class:`wrappers.AtariPreprocessing`
|
||||
- :class:`experimental.wrappers.AtariPreprocessingV0`
|
||||
```
|
||||
|
||||
### Rendering Wrappers
|
||||
|
||||
```{eval-rst}
|
||||
.. py:currentmodule:: gymnasium
|
||||
|
||||
.. list-table::
|
||||
:header-rows: 1
|
||||
|
||||
* - Old name
|
||||
- New name
|
||||
* - :class:`wrappers.RecordVideo`
|
||||
- :class:`experimental.wrappers.RecordVideoV0`
|
||||
* - :class:`wrappers.HumanRendering`
|
||||
- :class:`experimental.wrappers.HumanRenderingV0`
|
||||
* - :class:`wrappers.RenderCollection`
|
||||
- :class:`experimental.wrappers.RenderCollectionV0`
|
||||
```
|
||||
|
||||
### Environment data conversion
|
||||
|
||||
```{eval-rst}
|
||||
.. py:currentmodule:: gymnasium
|
||||
|
||||
* :class:`experimental.wrappers.JaxToNumpyV0`
|
||||
* :class:`experimental.wrappers.JaxToTorchV0`
|
||||
* :class:`experimental.wrappers.NumpyToTorchV0`
|
||||
```
|
@@ -1,37 +0,0 @@
|
||||
---
|
||||
title: Functional
|
||||
---
|
||||
|
||||
# Functional Environment
|
||||
|
||||
## gymnasium.experimental.FuncEnv
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.experimental.functional.FuncEnv
|
||||
|
||||
.. autofunction:: gymnasium.experimental.functional.FuncEnv.initial
|
||||
.. autofunction:: gymnasium.experimental.functional.FuncEnv.transition
|
||||
|
||||
.. autofunction:: gymnasium.experimental.functional.FuncEnv.observation
|
||||
.. autofunction:: gymnasium.experimental.functional.FuncEnv.reward
|
||||
.. autofunction:: gymnasium.experimental.functional.FuncEnv.terminal
|
||||
|
||||
.. autofunction:: gymnasium.experimental.functional.FuncEnv.state_info
|
||||
.. autofunction:: gymnasium.experimental.functional.FuncEnv.step_info
|
||||
|
||||
.. autofunction:: gymnasium.experimental.functional.FuncEnv.transform
|
||||
|
||||
.. autofunction:: gymnasium.experimental.functional.FuncEnv.render_image
|
||||
.. autofunction:: gymnasium.experimental.functional.FuncEnv.render_init
|
||||
.. autofunction:: gymnasium.experimental.functional.FuncEnv.render_close
|
||||
```
|
||||
|
||||
## gymnasium.experimental.func2env.FunctionalJaxCompatibilityEnv
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.experimental.functional_jax_env.FunctionalJaxEnv
|
||||
|
||||
.. autofunction:: gymnasium.experimental.functional_jax_env.FunctionalJaxEnv.reset
|
||||
.. autofunction:: gymnasium.experimental.functional_jax_env.FunctionalJaxEnv.step
|
||||
.. autofunction:: gymnasium.experimental.functional_jax_env.FunctionalJaxEnv.render
|
||||
```
|
||||
|
@@ -1,39 +0,0 @@
|
||||
---
|
||||
title: Vector
|
||||
---
|
||||
|
||||
# Vectorizing Environment
|
||||
|
||||
## gymnasium.experimental.VectorEnv
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.experimental.vector.VectorEnv
|
||||
.. autofunction:: gymnasium.experimental.vector.VectorEnv.reset
|
||||
.. autofunction:: gymnasium.experimental.vector.VectorEnv.step
|
||||
.. autofunction:: gymnasium.experimental.vector.VectorEnv.close
|
||||
.. autofunction:: gymnasium.experimental.vector.VectorEnv.reset
|
||||
```
|
||||
|
||||
## gymnasium.experimental.vector.AsyncVectorEnv
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.experimental.vector.AsyncVectorEnv
|
||||
.. autofunction:: gymnasium.experimental.vector.AsyncVectorEnv.reset
|
||||
.. autofunction:: gymnasium.experimental.vector.AsyncVectorEnv.step
|
||||
.. autofunction:: gymnasium.experimental.vector.AsyncVectorEnv.close
|
||||
.. autofunction:: gymnasium.experimental.vector.AsyncVectorEnv.call
|
||||
.. autofunction:: gymnasium.experimental.vector.AsyncVectorEnv.get_attr
|
||||
.. autofunction:: gymnasium.experimental.vector.AsyncVectorEnv.set_attr
|
||||
```
|
||||
|
||||
## gymnasium.experimental.vector.SyncVectorEnv
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.experimental.vector.SyncVectorEnv
|
||||
.. autofunction:: gymnasium.experimental.vector.SyncVectorEnv.reset
|
||||
.. autofunction:: gymnasium.experimental.vector.SyncVectorEnv.step
|
||||
.. autofunction:: gymnasium.experimental.vector.SyncVectorEnv.close
|
||||
.. autofunction:: gymnasium.experimental.vector.SyncVectorEnv.call
|
||||
.. autofunction:: gymnasium.experimental.vector.SyncVectorEnv.get_attr
|
||||
.. autofunction:: gymnasium.experimental.vector.SyncVectorEnv.set_attr
|
||||
```
|
@@ -1,29 +0,0 @@
|
||||
---
|
||||
title: Vector Utility
|
||||
---
|
||||
|
||||
# Utility functions for vectorisation
|
||||
|
||||
## Spaces utility functions
|
||||
|
||||
```{eval-rst}
|
||||
.. autofunction:: gymnasium.experimental.vector.utils.batch_space
|
||||
.. autofunction:: gymnasium.experimental.vector.utils.concatenate
|
||||
.. autofunction:: gymnasium.experimental.vector.utils.iterate
|
||||
.. autofunction:: gymnasium.experimental.vector.utils.create_empty_array
|
||||
```
|
||||
|
||||
## Shared memory functions
|
||||
|
||||
```{eval-rst}
|
||||
.. autofunction:: gymnasium.experimental.vector.utils.create_shared_memory
|
||||
.. autofunction:: gymnasium.experimental.vector.utils.read_from_shared_memory
|
||||
.. autofunction:: gymnasium.experimental.vector.utils.write_to_shared_memory
|
||||
```
|
||||
|
||||
## Miscellaneous
|
||||
|
||||
```{eval-rst}
|
||||
.. autofunction:: gymnasium.experimental.vector.utils.CloudpickleWrapper
|
||||
.. autofunction:: gymnasium.experimental.vector.utils.clear_mpi_env_vars
|
||||
```
|
@@ -1,53 +0,0 @@
|
||||
---
|
||||
title: Vector Wrappers
|
||||
---
|
||||
|
||||
# Vector Environment Wrappers
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.experimental.vector.VectorWrapper
|
||||
```
|
||||
|
||||
## Vector Observation Wrappers
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.experimental.vector.VectorObservationWrapper
|
||||
.. autoclass:: gymnasium.experimental.wrappers.vector.LambdaObservationV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.vector.FilterObservationV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.vector.FlattenObservationV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.vector.GrayscaleObservationV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.vector.ResizeObservationV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.vector.ReshapeObservationV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.vector.RescaleObservationV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.vector.DtypeObservationV0
|
||||
```
|
||||
|
||||
## Vector Action Wrappers
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.experimental.vector.VectorActionWrapper
|
||||
.. autoclass:: gymnasium.experimental.wrappers.vector.LambdaActionV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.vector.ClipActionV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.vector.RescaleActionV0
|
||||
```
|
||||
|
||||
## Vector Reward Wrappers
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.experimental.vector.VectorRewardWrapper
|
||||
.. autoclass:: gymnasium.experimental.wrappers.vector.LambdaRewardV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.vector.ClipRewardV0
|
||||
```
|
||||
|
||||
## More Vector Wrappers
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.experimental.wrappers.vector.RecordEpisodeStatisticsV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.vector.DictInfoToListV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.vector.VectorizeLambdaObservationV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.vector.VectorizeLambdaActionV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.vector.VectorizeLambdaRewardV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.vector.JaxToNumpyV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.vector.JaxToTorchV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.vector.NumpyToTorchV0
|
||||
```
|
@@ -1,65 +0,0 @@
|
||||
---
|
||||
title: Wrappers
|
||||
---
|
||||
|
||||
# Wrappers
|
||||
|
||||
## Observation Wrappers
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.experimental.wrappers.LambdaObservationV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.FilterObservationV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.FlattenObservationV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.GrayscaleObservationV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.ResizeObservationV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.ReshapeObservationV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.RescaleObservationV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.DtypeObservationV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.PixelObservationV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.NormalizeObservationV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.TimeAwareObservationV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.FrameStackObservationV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.DelayObservationV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.AtariPreprocessingV0
|
||||
```
|
||||
|
||||
## Action Wrappers
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.experimental.wrappers.LambdaActionV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.ClipActionV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.RescaleActionV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.StickyActionV0
|
||||
```
|
||||
|
||||
## Reward Wrappers
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.experimental.wrappers.LambdaRewardV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.ClipRewardV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.NormalizeRewardV1
|
||||
```
|
||||
|
||||
## Other Wrappers
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.experimental.wrappers.AutoresetV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.PassiveEnvCheckerV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.OrderEnforcingV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.RecordEpisodeStatisticsV0
|
||||
```
|
||||
|
||||
## Rendering Wrappers
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.experimental.wrappers.RecordVideoV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.HumanRenderingV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.RenderCollectionV0
|
||||
```
|
||||
|
||||
## Environment data conversion
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.experimental.wrappers.JaxToNumpyV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.JaxToTorchV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.NumpyToTorchV0
|
||||
```
|
34
docs/api/functional.md
Normal file
34
docs/api/functional.md
Normal file
@@ -0,0 +1,34 @@
|
||||
---
|
||||
title: Functional
|
||||
---
|
||||
|
||||
# Functional Env
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.functional.FuncEnv
|
||||
|
||||
.. automethod:: gymnasium.functional.FuncEnv.transform
|
||||
|
||||
.. automethod:: gymnasium.functional.FuncEnv.initial
|
||||
.. automethod:: gymnasium.functional.FuncEnv.initial_info
|
||||
|
||||
.. automethod:: gymnasium.functional.FuncEnv.transition
|
||||
.. automethod:: gymnasium.functional.FuncEnv.observation
|
||||
.. automethod:: gymnasium.functional.FuncEnv.reward
|
||||
.. automethod:: gymnasium.functional.FuncEnv.terminal
|
||||
.. automethod:: gymnasium.functional.FuncEnv.transition_info
|
||||
|
||||
.. automethod:: gymnasium.functional.FuncEnv.render_image
|
||||
.. automethod:: gymnasium.functional.FuncEnv.render_initialise
|
||||
.. automethod:: gymnasium.functional.FuncEnv.render_close
|
||||
```
|
||||
|
||||
## Converting Jax-based Functional environments to standard Env
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.utils.functional_jax_env.FunctionalJaxEnv
|
||||
|
||||
.. automethod:: gymnasium.utils.functional_jax_env.FunctionalJaxEnv.reset
|
||||
.. automethod:: gymnasium.utils.functional_jax_env.FunctionalJaxEnv.step
|
||||
.. automethod:: gymnasium.utils.functional_jax_env.FunctionalJaxEnv.render
|
||||
```
|
@@ -2,14 +2,13 @@
|
||||
title: Registry
|
||||
---
|
||||
|
||||
# Register and Make
|
||||
# Make and register
|
||||
|
||||
```{eval-rst}
|
||||
Gymnasium allows users to automatically load environments, pre-wrapped with several important wrappers through the :meth:`gymnasium.make` function. To do this, the environment must be registered prior with :meth:`gymnasium.register`. To get the environment specifications for a registered environment, use :meth:`gymnasium.spec` and to print the whole registry, use :meth:`gymnasium.pprint_registry`.
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autofunction:: gymnasium.make
|
||||
.. autofunction:: gymnasium.make_vec
|
||||
.. autofunction:: gymnasium.register
|
||||
.. autofunction:: gymnasium.spec
|
||||
.. autofunction:: gymnasium.pprint_registry
|
||||
@@ -19,6 +18,7 @@ Gymnasium allows users to automatically load environments, pre-wrapped with seve
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.envs.registration.EnvSpec
|
||||
.. autoclass:: gymnasium.envs.registration.WrapperSpec
|
||||
.. attribute:: gymnasium.envs.registration.registry
|
||||
|
||||
The Global registry for gymnasium which is where environment specifications are stored by :meth:`gymnasium.register` and from which :meth:`gymnasium.make` is used to create environments.
|
||||
@@ -36,5 +36,4 @@ Gymnasium allows users to automatically load environments, pre-wrapped with seve
|
||||
.. autofunction:: gymnasium.envs.registration.find_highest_version
|
||||
.. autofunction:: gymnasium.envs.registration.namespace
|
||||
.. autofunction:: gymnasium.envs.registration.load_env_creator
|
||||
.. autofunction:: gymnasium.envs.registration.load_plugin_envs
|
||||
```
|
||||
|
@@ -9,39 +9,38 @@ title: Spaces
|
||||
spaces/fundamental
|
||||
spaces/composite
|
||||
spaces/utils
|
||||
spaces/vector_utils
|
||||
vector/utils
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. automodule:: gymnasium.spaces
|
||||
```
|
||||
|
||||
## The Base Class
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.spaces.Space
|
||||
```
|
||||
|
||||
### Attributes
|
||||
|
||||
## Attributes
|
||||
```{eval-rst}
|
||||
.. autoproperty:: gymnasium.spaces.space.Space.shape
|
||||
.. py:currentmodule:: gymnasium.spaces
|
||||
|
||||
.. autoproperty:: Space.shape
|
||||
.. property:: Space.dtype
|
||||
|
||||
Return the data type of this space.
|
||||
.. autoproperty:: gymnasium.spaces.space.Space.is_np_flattenable
|
||||
.. autoproperty:: Space.is_np_flattenable
|
||||
.. autoproperty:: Space.np_random
|
||||
```
|
||||
|
||||
### Methods
|
||||
|
||||
## Methods
|
||||
Each space implements the following functions:
|
||||
|
||||
```{eval-rst}
|
||||
.. autofunction:: gymnasium.spaces.space.Space.sample
|
||||
.. autofunction:: gymnasium.spaces.space.Space.contains
|
||||
.. autofunction:: gymnasium.spaces.space.Space.seed
|
||||
.. autofunction:: gymnasium.spaces.space.Space.to_jsonable
|
||||
.. autofunction:: gymnasium.spaces.space.Space.from_jsonable
|
||||
.. py:currentmodule:: gymnasium.spaces
|
||||
|
||||
.. automethod:: Space.sample
|
||||
.. automethod:: Space.contains
|
||||
.. automethod:: Space.seed
|
||||
.. automethod:: Space.to_jsonable
|
||||
.. automethod:: Space.from_jsonable
|
||||
```
|
||||
|
||||
## Fundamental Spaces
|
||||
@@ -49,13 +48,13 @@ Each space implements the following functions:
|
||||
Gymnasium has a number of fundamental spaces that are used as building boxes for more complex spaces.
|
||||
|
||||
```{eval-rst}
|
||||
.. currentmodule:: gymnasium.spaces
|
||||
.. py:currentmodule:: gymnasium.spaces
|
||||
|
||||
* :py:class:`Box` - Supports continuous (and discrete) vectors or matrices, used for vector observations, images, etc
|
||||
* :py:class:`Discrete` - Supports a single discrete number of values with an optional start for the values
|
||||
* :py:class:`MultiBinary` - Supports single or matrices of binary values, used for holding down a button or if an agent has an object
|
||||
* :py:class:`MultiDiscrete` - Supports multiple discrete values with multiple axes, used for controller actions
|
||||
* :py:class:`Text` - Supports strings, used for passing agent messages, mission details, etc
|
||||
* :class:`Box` - Supports continuous (and discrete) vectors or matrices, used for vector observations, images, etc
|
||||
* :class:`Discrete` - Supports a single discrete number of values with an optional start for the values
|
||||
* :class:`MultiBinary` - Supports single or matrices of binary values, used for holding down a button or if an agent has an object
|
||||
* :class:`MultiDiscrete` - Supports multiple discrete values with multiple axes, used for controller actions
|
||||
* :class:`Text` - Supports strings, used for passing agent messages, mission details, etc
|
||||
```
|
||||
|
||||
## Composite Spaces
|
||||
@@ -63,37 +62,41 @@ Gymnasium has a number of fundamental spaces that are used as building boxes for
|
||||
Often environment spaces require joining fundamental spaces together for vectorised environments, separate agents or readability of the space.
|
||||
|
||||
```{eval-rst}
|
||||
* :py:class:`Dict` - Supports a dictionary of keys and subspaces, used for a fixed number of unordered spaces
|
||||
* :py:class:`Tuple` - Supports a tuple of subspaces, used for multiple for a fixed number of ordered spaces
|
||||
* :py:class:`Sequence` - Supports a variable number of instances of a single subspace, used for entities spaces or selecting a variable number of actions
|
||||
.. py:currentmodule:: gymnasium.spaces
|
||||
|
||||
* :class:`Dict` - Supports a dictionary of keys and subspaces, used for a fixed number of unordered spaces
|
||||
* :class:`Tuple` - Supports a tuple of subspaces, used for multiple for a fixed number of ordered spaces
|
||||
* :class:`Sequence` - Supports a variable number of instances of a single subspace, used for entities spaces or selecting a variable number of actions
|
||||
* :py:class:`Graph` - Supports graph based actions or observations with discrete or continuous nodes and edge values.
|
||||
```
|
||||
|
||||
## Utils
|
||||
## Utility functions
|
||||
|
||||
Gymnasium contains a number of helpful utility functions for flattening and unflattening spaces.
|
||||
This can be important for passing information to neural networks.
|
||||
|
||||
```{eval-rst}
|
||||
* :py:class:`utils.flatdim` - The number of dimensions the flattened space will contain
|
||||
* :py:class:`utils.flatten_space` - Flattens a space for which the `flattened` space instances will contain
|
||||
* :py:class:`utils.flatten` - Flattens an instance of a space that is contained within the flattened version of the space
|
||||
* :py:class:`utils.unflatten` - The reverse of the `flatten_space` function
|
||||
.. py:currentmodule:: gymnasium.spaces
|
||||
|
||||
* :class:`utils.flatdim` - The number of dimensions the flattened space will contain
|
||||
* :class:`utils.flatten_space` - Flattens a space for which the :class:`utils.flattened` space instances will contain
|
||||
* :class:`utils.flatten` - Flattens an instance of a space that is contained within the flattened version of the space
|
||||
* :class:`utils.unflatten` - The reverse of the :class:`utils.flatten_space` function
|
||||
```
|
||||
|
||||
## Vector Utils
|
||||
## Vector Utility functions
|
||||
|
||||
When vectorizing environments, it is necessary to modify the observation and action spaces for new batched spaces sizes.
|
||||
Therefore, Gymnasium provides a number of additional functions used when using a space with a Vector environment.
|
||||
|
||||
```{eval-rst}
|
||||
.. currentmodule:: gymnasium
|
||||
.. py:currentmodule:: gymnasium
|
||||
|
||||
* :py:class:`vector.utils.batch_space`
|
||||
* :py:class:`vector.utils.concatenate`
|
||||
* :py:class:`vector.utils.iterate`
|
||||
* :py:class:`vector.utils.create_empty_array`
|
||||
* :py:class:`vector.utils.create_shared_memory`
|
||||
* :py:class:`vector.utils.read_from_shared_memory`
|
||||
* :py:class:`vector.utils.write_to_shared_memory`
|
||||
* :class:`vector.utils.batch_space` - Transforms a space into the equivalent space for ``n`` users
|
||||
* :class:`vector.utils.concatenate` - Concatenates a space's samples into a pre-generated space
|
||||
* :class:`vector.utils.iterate` - Iterate over the batched space's samples
|
||||
* :class:`vector.utils.create_empty_array` - Creates an empty sample for an space (generally used with ``concatenate``)
|
||||
* :class:`vector.utils.create_shared_memory` - Creates a shared memory for asynchronous (multiprocessing) environment
|
||||
* :class:`vector.utils.read_from_shared_memory` - Reads a shared memory for asynchronous (multiprocessing) environment
|
||||
* :class:`vector.utils.write_to_shared_memory` - Write to a shared memory for asynchronous (multiprocessing) environment
|
||||
```
|
||||
|
@@ -1,37 +1,24 @@
|
||||
# Composite Spaces
|
||||
|
||||
## Dict
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.spaces.Dict
|
||||
|
||||
.. automethod:: gymnasium.spaces.Dict.sample
|
||||
.. automethod:: gymnasium.spaces.Dict.seed
|
||||
```
|
||||
.. automethod:: gymnasium.spaces.Dict.sample
|
||||
.. automethod:: gymnasium.spaces.Dict.seed
|
||||
|
||||
## Tuple
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.spaces.Tuple
|
||||
|
||||
.. automethod:: gymnasium.spaces.Tuple.sample
|
||||
.. automethod:: gymnasium.spaces.Tuple.seed
|
||||
```
|
||||
.. automethod:: gymnasium.spaces.Tuple.sample
|
||||
.. automethod:: gymnasium.spaces.Tuple.seed
|
||||
|
||||
## Sequence
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.spaces.Sequence
|
||||
|
||||
.. automethod:: gymnasium.spaces.Sequence.sample
|
||||
.. automethod:: gymnasium.spaces.Sequence.seed
|
||||
```
|
||||
.. automethod:: gymnasium.spaces.Sequence.sample
|
||||
.. automethod:: gymnasium.spaces.Sequence.seed
|
||||
|
||||
## Graph
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.spaces.Graph
|
||||
|
||||
.. automethod:: gymnasium.spaces.Graph.sample
|
||||
.. automethod:: gymnasium.spaces.Graph.seed
|
||||
.. automethod:: gymnasium.spaces.Graph.sample
|
||||
.. automethod:: gymnasium.spaces.Graph.seed
|
||||
```
|
||||
|
@@ -4,46 +4,30 @@ title: Fundamental Spaces
|
||||
|
||||
# Fundamental Spaces
|
||||
|
||||
## Box
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.spaces.Box
|
||||
|
||||
.. automethod:: gymnasium.spaces.Box.sample
|
||||
.. automethod:: gymnasium.spaces.Box.seed
|
||||
.. automethod:: gymnasium.spaces.Box.is_bounded
|
||||
```
|
||||
.. automethod:: gymnasium.spaces.Box.sample
|
||||
.. automethod:: gymnasium.spaces.Box.seed
|
||||
.. automethod:: gymnasium.spaces.Box.is_bounded
|
||||
|
||||
## Discrete
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.spaces.Discrete
|
||||
.. automethod:: gymnasium.spaces.Discrete.sample
|
||||
.. automethod:: gymnasium.spaces.Discrete.seed
|
||||
```
|
||||
|
||||
## MultiBinary
|
||||
.. automethod:: gymnasium.spaces.Discrete.sample
|
||||
.. automethod:: gymnasium.spaces.Discrete.seed
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.spaces.MultiBinary
|
||||
.. automethod:: gymnasium.spaces.MultiBinary.sample
|
||||
.. automethod:: gymnasium.spaces.MultiBinary.seed
|
||||
```
|
||||
|
||||
## MultiDiscrete
|
||||
.. automethod:: gymnasium.spaces.MultiBinary.sample
|
||||
.. automethod:: gymnasium.spaces.MultiBinary.seed
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.spaces.MultiDiscrete
|
||||
|
||||
.. automethod:: gymnasium.spaces.MultiDiscrete.sample
|
||||
.. automethod:: gymnasium.spaces.MultiDiscrete.seed
|
||||
```
|
||||
.. automethod:: gymnasium.spaces.MultiDiscrete.sample
|
||||
.. automethod:: gymnasium.spaces.MultiDiscrete.seed
|
||||
|
||||
## Text
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.spaces.Text
|
||||
|
||||
.. automethod:: gymnasium.spaces.Text.sample
|
||||
.. automethod:: gymnasium.spaces.Text.seed
|
||||
.. automethod:: gymnasium.spaces.Text.sample
|
||||
.. automethod:: gymnasium.spaces.Text.seed
|
||||
```
|
||||
|
@@ -1,8 +1,20 @@
|
||||
---
|
||||
title: Utils
|
||||
title: Utility functions
|
||||
---
|
||||
|
||||
# Utils
|
||||
# Utility functions
|
||||
|
||||
## Seeding
|
||||
|
||||
```{eval-rst}
|
||||
.. autofunction:: gymnasium.utils.seeding.np_random
|
||||
```
|
||||
|
||||
## Environment Checking
|
||||
|
||||
```{eval-rst}
|
||||
.. autofunction:: gymnasium.utils.env_checker.check_env
|
||||
```
|
||||
|
||||
## Visualization
|
||||
|
||||
@@ -17,6 +29,12 @@ title: Utils
|
||||
.. automethod:: process_event
|
||||
```
|
||||
|
||||
## Environment pickling
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.utils.ezpickle.EzPickle
|
||||
```
|
||||
|
||||
## Save Rendering Videos
|
||||
|
||||
```{eval-rst}
|
||||
@@ -31,15 +49,3 @@ title: Utils
|
||||
.. autofunction:: gymnasium.utils.step_api_compatibility.convert_to_terminated_truncated_step_api
|
||||
.. autofunction:: gymnasium.utils.step_api_compatibility.convert_to_done_step_api
|
||||
```
|
||||
|
||||
## Seeding
|
||||
|
||||
```{eval-rst}
|
||||
.. autofunction:: gymnasium.utils.seeding.np_random
|
||||
```
|
||||
|
||||
## Environment Checking
|
||||
|
||||
```{eval-rst}
|
||||
.. autofunction:: gymnasium.utils.env_checker.check_env
|
||||
```
|
||||
|
@@ -2,7 +2,15 @@
|
||||
title: Vector
|
||||
---
|
||||
|
||||
# Vector
|
||||
# Vector environments
|
||||
|
||||
```{toctree}
|
||||
:hidden:
|
||||
vector/wrappers
|
||||
vector/async_vector_env
|
||||
vector/sync_vector_env
|
||||
vector/utils
|
||||
```
|
||||
|
||||
## Gymnasium.vector.VectorEnv
|
||||
|
||||
@@ -14,62 +22,47 @@ title: Vector
|
||||
|
||||
```{eval-rst}
|
||||
.. automethod:: gymnasium.vector.VectorEnv.reset
|
||||
|
||||
.. automethod:: gymnasium.vector.VectorEnv.step
|
||||
|
||||
.. automethod:: gymnasium.vector.VectorEnv.close
|
||||
```
|
||||
|
||||
### Attributes
|
||||
|
||||
```{eval-rst}
|
||||
.. attribute:: action_space
|
||||
.. autoattribute:: gymnasium.vector.VectorEnv.num_envs
|
||||
|
||||
The (batched) action space. The input actions of `step` must be valid elements of `action_space`.::
|
||||
The number of sub-environments in the vector environment.
|
||||
|
||||
>>> envs = gymnasium.vector.make("CartPole-v1", num_envs=3)
|
||||
>>> envs.action_space
|
||||
MultiDiscrete([2 2 2])
|
||||
.. autoattribute:: gymnasium.vector.VectorEnv.action_space
|
||||
|
||||
.. attribute:: observation_space
|
||||
The (batched) action space. The input actions of `step` must be valid elements of `action_space`.
|
||||
|
||||
The (batched) observation space. The observations returned by `reset` and `step` are valid elements of `observation_space`.::
|
||||
.. autoattribute:: gymnasium.vector.VectorEnv.observation_space
|
||||
|
||||
>>> envs = gymnasium.vector.make("CartPole-v1", num_envs=3)
|
||||
>>> envs.observation_space
|
||||
Box([[-4.8 ...]], [[4.8 ...]], (3, 4), float32)
|
||||
The (batched) observation space. The observations returned by `reset` and `step` are valid elements of `observation_space`.
|
||||
|
||||
.. attribute:: single_action_space
|
||||
.. autoattribute:: gymnasium.vector.VectorEnv.single_action_space
|
||||
|
||||
The action space of an environment copy.::
|
||||
The action space of a sub-environment.
|
||||
|
||||
>>> envs = gymnasium.vector.make("CartPole-v1", num_envs=3)
|
||||
>>> envs.single_action_space
|
||||
Discrete(2)
|
||||
.. autoattribute:: gymnasium.vector.VectorEnv.single_observation_space
|
||||
|
||||
.. attribute:: single_observation_space
|
||||
The observation space of an environment copy.
|
||||
|
||||
The observation space of an environment copy.::
|
||||
.. autoattribute:: gymnasium.vector.VectorEnv.spec
|
||||
|
||||
>>> envs = gymnasium.vector.make("CartPole-v1", num_envs=3)
|
||||
>>> envs.single_observation_space
|
||||
Box([-4.8 ...], [4.8 ...], (4,), float32)
|
||||
The ``EnvSpec`` of the environment normally set during :py:meth:`gymnasium.make_vec`
|
||||
```
|
||||
|
||||
### Additional Methods
|
||||
|
||||
```{eval-rst}
|
||||
.. autoproperty:: gymnasium.vector.VectorEnv.unwrapped
|
||||
.. autoproperty:: gymnasium.vector.VectorEnv.np_random
|
||||
```
|
||||
|
||||
## Making Vector Environments
|
||||
|
||||
```{eval-rst}
|
||||
.. autofunction:: gymnasium.vector.make
|
||||
```
|
||||
|
||||
## Async Vector Env
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.vector.AsyncVectorEnv
|
||||
```
|
||||
|
||||
## Sync Vector Env
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.vector.SyncVectorEnv
|
||||
To create vector environments, gymnasium provides :func:`gymnasium.make_vec` as an equivalent function to :func:`gymnasium.make`.
|
||||
```
|
||||
|
13
docs/api/vector/async_vector_env.md
Normal file
13
docs/api/vector/async_vector_env.md
Normal file
@@ -0,0 +1,13 @@
|
||||
# AsyncVectorEnv
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.vector.AsyncVectorEnv
|
||||
|
||||
.. automethod:: gymnasium.vector.AsyncVectorEnv.reset
|
||||
.. automethod:: gymnasium.vector.AsyncVectorEnv.step
|
||||
.. automethod:: gymnasium.vector.AsyncVectorEnv.close
|
||||
|
||||
.. automethod:: gymnasium.vector.AsyncVectorEnv.call
|
||||
.. automethod:: gymnasium.vector.AsyncVectorEnv.get_attr
|
||||
.. automethod:: gymnasium.vector.AsyncVectorEnv.set_attr
|
||||
```
|
13
docs/api/vector/sync_vector_env.md
Normal file
13
docs/api/vector/sync_vector_env.md
Normal file
@@ -0,0 +1,13 @@
|
||||
# SyncVectorEnv
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.vector.SyncVectorEnv
|
||||
|
||||
.. automethod:: gymnasium.vector.SyncVectorEnv.reset
|
||||
.. automethod:: gymnasium.vector.SyncVectorEnv.step
|
||||
.. automethod:: gymnasium.vector.SyncVectorEnv.close
|
||||
|
||||
.. automethod:: gymnasium.vector.SyncVectorEnv.call
|
||||
.. automethod:: gymnasium.vector.SyncVectorEnv.get_attr
|
||||
.. automethod:: gymnasium.vector.SyncVectorEnv.set_attr
|
||||
```
|
@@ -1,20 +1,25 @@
|
||||
---
|
||||
title: Vector Utils
|
||||
---
|
||||
# Utility functions
|
||||
|
||||
# Spaces Vector Utils
|
||||
## Vectorizing Spaces
|
||||
|
||||
```{eval-rst}
|
||||
.. autofunction:: gymnasium.vector.utils.batch_space
|
||||
.. autofunction:: gymnasium.vector.utils.concatenate
|
||||
.. autofunction:: gymnasium.vector.utils.iterate
|
||||
.. autofunction:: gymnasium.vector.utils.create_empty_array
|
||||
```
|
||||
|
||||
## Shared Memory Utils
|
||||
## Shared Memory for a Space
|
||||
|
||||
```{eval-rst}
|
||||
.. autofunction:: gymnasium.vector.utils.create_empty_array
|
||||
.. autofunction:: gymnasium.vector.utils.create_shared_memory
|
||||
.. autofunction:: gymnasium.vector.utils.read_from_shared_memory
|
||||
.. autofunction:: gymnasium.vector.utils.write_to_shared_memory
|
||||
```
|
||||
|
||||
## Miscellaneous
|
||||
|
||||
```{eval-rst}
|
||||
.. autofunction:: gymnasium.vector.utils.CloudpickleWrapper
|
||||
.. autofunction:: gymnasium.vector.utils.clear_mpi_env_vars
|
||||
```
|
26
docs/api/vector/wrappers.md
Normal file
26
docs/api/vector/wrappers.md
Normal file
@@ -0,0 +1,26 @@
|
||||
---
|
||||
title: Vector Wrappers
|
||||
---
|
||||
|
||||
# Vector Wrappers
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.vector.VectorWrapper
|
||||
|
||||
.. automethod:: gymnasium.vector.VectorWrapper.step
|
||||
.. automethod:: gymnasium.vector.VectorWrapper.reset
|
||||
.. automethod:: gymnasium.vector.VectorWrapper.close
|
||||
|
||||
.. autoclass:: gymnasium.vector.VectorObservationWrapper
|
||||
|
||||
.. automethod:: gymnasium.vector.VectorObservationWrapper.vector_observation
|
||||
.. automethod:: gymnasium.vector.VectorObservationWrapper.single_observation
|
||||
|
||||
.. autoclass:: gymnasium.vector.VectorActionWrapper
|
||||
|
||||
.. automethod:: gymnasium.vector.VectorActionWrapper.actions
|
||||
|
||||
.. autoclass:: gymnasium.vector.VectorRewardWrapper
|
||||
|
||||
.. automethod:: gymnasium.vector.VectorRewardWrapper.rewards
|
||||
```
|
@@ -6,134 +6,47 @@ title: Wrapper
|
||||
|
||||
```{toctree}
|
||||
:hidden:
|
||||
|
||||
wrappers/table
|
||||
wrappers/misc_wrappers
|
||||
wrappers/action_wrappers
|
||||
wrappers/observation_wrappers
|
||||
wrappers/reward_wrappers
|
||||
wrappers/vector_wrappers
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. automodule:: gymnasium.wrappers
|
||||
|
||||
```
|
||||
|
||||
## gymnasium.Wrapper
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.Wrapper
|
||||
```
|
||||
|
||||
### Methods
|
||||
|
||||
## Methods
|
||||
```{eval-rst}
|
||||
.. autofunction:: gymnasium.Wrapper.step
|
||||
.. autofunction:: gymnasium.Wrapper.reset
|
||||
.. autofunction:: gymnasium.Wrapper.close
|
||||
.. automethod:: gymnasium.Wrapper.step
|
||||
.. automethod:: gymnasium.Wrapper.reset
|
||||
.. automethod:: gymnasium.Wrapper.render
|
||||
.. automethod:: gymnasium.Wrapper.close
|
||||
.. automethod:: gymnasium.Wrapper.wrapper_spec
|
||||
.. automethod:: gymnasium.Wrapper.get_wrapper_attr
|
||||
.. automethod:: gymnasium.Wrapper.set_wrapper_attr
|
||||
```
|
||||
|
||||
### Attributes
|
||||
|
||||
## Attributes
|
||||
```{eval-rst}
|
||||
.. autoproperty:: gymnasium.Wrapper.action_space
|
||||
.. autoproperty:: gymnasium.Wrapper.observation_space
|
||||
.. autoproperty:: gymnasium.Wrapper.reward_range
|
||||
.. autoproperty:: gymnasium.Wrapper.spec
|
||||
.. autoproperty:: gymnasium.Wrapper.metadata
|
||||
.. autoproperty:: gymnasium.Wrapper.np_random
|
||||
.. attribute:: gymnasium.Wrapper.env
|
||||
.. autoattribute:: gymnasium.Wrapper.env
|
||||
|
||||
The environment (one level underneath) this wrapper.
|
||||
|
||||
This may itself be a wrapped environment.
|
||||
To obtain the environment underneath all layers of wrappers, use :attr:`gymnasium.Wrapper.unwrapped`.
|
||||
This may itself be a wrapped environment. To obtain the environment underneath all layers of wrappers, use :attr:`gymnasium.Wrapper.unwrapped`.
|
||||
|
||||
.. autoproperty:: gymnasium.Wrapper.action_space
|
||||
.. autoproperty:: gymnasium.Wrapper.observation_space
|
||||
.. autoproperty:: gymnasium.Wrapper.spec
|
||||
.. autoproperty:: gymnasium.Wrapper.metadata
|
||||
.. autoproperty:: gymnasium.Wrapper.np_random
|
||||
.. autoproperty:: gymnasium.Wrapper.unwrapped
|
||||
```
|
||||
|
||||
## Gymnasium Wrappers
|
||||
|
||||
Gymnasium provides a number of commonly used wrappers listed below. More information can be found on the particular
|
||||
wrapper in the page on the wrapper type
|
||||
|
||||
```{eval-rst}
|
||||
.. py:currentmodule:: gymnasium.wrappers
|
||||
|
||||
.. list-table::
|
||||
:header-rows: 1
|
||||
|
||||
* - Name
|
||||
- Type
|
||||
- Description
|
||||
* - :class:`AtariPreprocessing`
|
||||
- Misc Wrapper
|
||||
- Implements the common preprocessing applied to Atari environments
|
||||
* - :class:`AutoResetWrapper`
|
||||
- Misc Wrapper
|
||||
- The wrapped environment will automatically reset when the terminated or truncated state is reached.
|
||||
* - :class:`ClipAction`
|
||||
- Action Wrapper
|
||||
- Clip the continuous action to the valid bound specified by the environment's `action_space`
|
||||
* - :class:`EnvCompatibility`
|
||||
- Misc Wrapper
|
||||
- Provides compatibility for environments written in the OpenAI Gym v0.21 API to look like Gymnasium environments
|
||||
* - :class:`FilterObservation`
|
||||
- Observation Wrapper
|
||||
- Filters a dictionary observation spaces to only requested keys
|
||||
* - :class:`FlattenObservation`
|
||||
- Observation Wrapper
|
||||
- An Observation wrapper that flattens the observation
|
||||
* - :class:`FrameStack`
|
||||
- Observation Wrapper
|
||||
- AnObservation wrapper that stacks the observations in a rolling manner.
|
||||
* - :class:`GrayScaleObservation`
|
||||
- Observation Wrapper
|
||||
- Convert the image observation from RGB to gray scale.
|
||||
* - :class:`HumanRendering`
|
||||
- Misc Wrapper
|
||||
- Allows human like rendering for environments that support "rgb_array" rendering
|
||||
* - :class:`NormalizeObservation`
|
||||
- Observation Wrapper
|
||||
- This wrapper will normalize observations s.t. each coordinate is centered with unit variance.
|
||||
* - :class:`NormalizeReward`
|
||||
- Reward Wrapper
|
||||
- This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance.
|
||||
* - :class:`OrderEnforcing`
|
||||
- Misc Wrapper
|
||||
- This will produce an error if `step` or `render` is called before `reset`
|
||||
* - :class:`PixelObservationWrapper`
|
||||
- Observation Wrapper
|
||||
- Augment observations by pixel values obtained via `render` that can be added to or replaces the environments observation.
|
||||
* - :class:`RecordEpisodeStatistics`
|
||||
- Misc Wrapper
|
||||
- This will keep track of cumulative rewards and episode lengths returning them at the end.
|
||||
* - :class:`RecordVideo`
|
||||
- Misc Wrapper
|
||||
- This wrapper will record videos of rollouts.
|
||||
* - :class:`RenderCollection`
|
||||
- Misc Wrapper
|
||||
- Enable list versions of render modes, i.e. "rgb_array_list" for "rgb_array" such that the rendering for each step are saved in a list until `render` is called.
|
||||
* - :class:`RescaleAction`
|
||||
- Action Wrapper
|
||||
- Rescales the continuous action space of the environment to a range \[`min_action`, `max_action`], where `min_action` and `max_action` are numpy arrays or floats.
|
||||
* - :class:`ResizeObservation`
|
||||
- Observation Wrapper
|
||||
- This wrapper works on environments with image observations (or more generally observations of shape AxBxC) and resizes the observation to the shape given by the tuple `shape`.
|
||||
* - :class:`StepAPICompatibility`
|
||||
- Misc Wrapper
|
||||
- Modifies an environment step function from (old) done to the (new) termination / truncation API.
|
||||
* - :class:`TimeAwareObservation`
|
||||
- Observation Wrapper
|
||||
- Augment the observation with current time step in the trajectory (by appending it to the observation).
|
||||
* - :class:`TimeLimit`
|
||||
- Misc Wrapper
|
||||
- This wrapper will emit a truncated signal if the specified number of steps is exceeded in an episode.
|
||||
* - :class:`TransformObservation`
|
||||
- Observation Wrapper
|
||||
- This wrapper will apply function to observations
|
||||
* - :class:`TransformReward`
|
||||
- Reward Wrapper
|
||||
- This wrapper will apply function to rewards
|
||||
* - :class:`VectorListInfo`
|
||||
- Misc Wrapper
|
||||
- This wrapper will convert the info of a vectorized environment from the `dict` format to a `list` of dictionaries where the i-th dictionary contains info of the i-th environment.
|
||||
```
|
||||
|
@@ -5,12 +5,14 @@
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.ActionWrapper
|
||||
|
||||
.. automethod:: gymnasium.ActionWrapper.action
|
||||
.. automethod:: gymnasium.ActionWrapper.action
|
||||
```
|
||||
|
||||
## Available Action Wrappers
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.wrappers.TransformAction
|
||||
.. autoclass:: gymnasium.wrappers.ClipAction
|
||||
.. autoclass:: gymnasium.wrappers.RescaleAction
|
||||
.. autoclass:: gymnasium.wrappers.StickyAction
|
||||
```
|
||||
|
@@ -1,16 +1,33 @@
|
||||
---
|
||||
title: Misc Wrappers
|
||||
---
|
||||
|
||||
# Misc Wrappers
|
||||
|
||||
|
||||
## Common Wrappers
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.wrappers.TimeLimit
|
||||
.. autoclass:: gymnasium.wrappers.RecordVideo
|
||||
.. autoclass:: gymnasium.wrappers.RecordEpisodeStatistics
|
||||
.. autoclass:: gymnasium.wrappers.AtariPreprocessing
|
||||
.. autoclass:: gymnasium.wrappers.AutoResetWrapper
|
||||
.. autoclass:: gymnasium.wrappers.EnvCompatibility
|
||||
.. autoclass:: gymnasium.wrappers.StepAPICompatibility
|
||||
```
|
||||
|
||||
## Uncommon Wrappers
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.wrappers.Autoreset
|
||||
.. autoclass:: gymnasium.wrappers.PassiveEnvChecker
|
||||
.. autoclass:: gymnasium.wrappers.HumanRendering
|
||||
.. autoclass:: gymnasium.wrappers.OrderEnforcing
|
||||
.. autoclass:: gymnasium.wrappers.RecordEpisodeStatistics
|
||||
.. autoclass:: gymnasium.wrappers.RecordVideo
|
||||
.. autoclass:: gymnasium.wrappers.RenderCollection
|
||||
.. autoclass:: gymnasium.wrappers.TimeLimit
|
||||
.. autoclass:: gymnasium.wrappers.VectorListInfo
|
||||
```
|
||||
|
||||
## Data Conversion Wrappers
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.wrappers.JaxToNumpy
|
||||
.. autoclass:: gymnasium.wrappers.JaxToTorch
|
||||
.. autoclass:: gymnasium.wrappers.NumpyToTorch
|
||||
```
|
||||
|
@@ -1,23 +1,26 @@
|
||||
# Observation Wrappers
|
||||
|
||||
## Base Class
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.ObservationWrapper
|
||||
|
||||
.. automethod:: gymnasium.ObservationWrapper.observation
|
||||
```
|
||||
|
||||
## Available Observation Wrappers
|
||||
## Implemented Wrappers
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.wrappers.TransformObservation
|
||||
.. autoclass:: gymnasium.wrappers.DelayObservation
|
||||
.. autoclass:: gymnasium.wrappers.DtypeObservation
|
||||
.. autoclass:: gymnasium.wrappers.FilterObservation
|
||||
.. autoclass:: gymnasium.wrappers.FlattenObservation
|
||||
.. autoclass:: gymnasium.wrappers.FrameStack
|
||||
.. autoclass:: gymnasium.wrappers.GrayScaleObservation
|
||||
.. autoclass:: gymnasium.wrappers.FrameStackObservation
|
||||
.. autoclass:: gymnasium.wrappers.GrayscaleObservation
|
||||
.. autoclass:: gymnasium.wrappers.MaxAndSkipObservation
|
||||
.. autoclass:: gymnasium.wrappers.NormalizeObservation
|
||||
.. autoclass:: gymnasium.wrappers.PixelObservationWrapper
|
||||
.. autoclass:: gymnasium.wrappers.RenderObservation
|
||||
.. autoclass:: gymnasium.wrappers.ResizeObservation
|
||||
.. autoclass:: gymnasium.wrappers.ReshapeObservation
|
||||
.. autoclass:: gymnasium.wrappers.RescaleObservation
|
||||
.. autoclass:: gymnasium.wrappers.TimeAwareObservation
|
||||
```
|
||||
|
@@ -1,17 +1,19 @@
|
||||
---
|
||||
title: Reward Wrappers
|
||||
---
|
||||
|
||||
# Reward Wrappers
|
||||
|
||||
## Base Class
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.RewardWrapper
|
||||
|
||||
.. automethod:: gymnasium.RewardWrapper.reward
|
||||
```
|
||||
|
||||
## Available Reward Wrappers
|
||||
## Implemented Wrappers
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.wrappers.TransformReward
|
||||
.. autoclass:: gymnasium.wrappers.NormalizeReward
|
||||
.. autoclass:: gymnasium.wrappers.ClipReward
|
||||
```
|
||||
|
102
docs/api/wrappers/table.md
Normal file
102
docs/api/wrappers/table.md
Normal file
@@ -0,0 +1,102 @@
|
||||
|
||||
# List of Wrappers
|
||||
|
||||
Gymnasium provides a number of commonly used wrappers listed below. More information can be found on the particular
|
||||
wrapper in the page on the wrapper type
|
||||
|
||||
```{eval-rst}
|
||||
.. py:currentmodule:: gymnasium.wrappers
|
||||
|
||||
.. list-table::
|
||||
:header-rows: 1
|
||||
|
||||
* - Name
|
||||
- Description
|
||||
* - :class:`AtariPreprocessing`
|
||||
- Implements the common preprocessing techniques for Atari environments (excluding frame stacking).
|
||||
* - :class:`Autoreset`
|
||||
- The wrapped environment is automatically reset when an terminated or truncated state is reached.
|
||||
* - :class:`ClipAction`
|
||||
- Clips the ``action`` pass to ``step`` to be within the environment's `action_space`.
|
||||
* - :class:`ClipReward`
|
||||
- Clips the rewards for an environment between an upper and lower bound.
|
||||
* - :class:`DelayObservation`
|
||||
- Adds a delay to the returned observation from the environment.
|
||||
* - :class:`DtypeObservation`
|
||||
- Modifies the dtype of an observation array to a specified dtype.
|
||||
* - :class:`FilterObservation`
|
||||
- Filters a Dict or Tuple observation spaces by a set of keys or indexes.
|
||||
* - :class:`FlattenObservation`
|
||||
- Flattens the environment's observation space and each observation from ``reset`` and ``step`` functions.
|
||||
* - :class:`FrameStackObservation`
|
||||
- Stacks the observations from the last ``N`` time steps in a rolling manner.
|
||||
* - :class:`GrayscaleObservation`
|
||||
- Converts an image observation computed by ``reset`` and ``step`` from RGB to Grayscale.
|
||||
* - :class:`HumanRendering`
|
||||
- Allows human like rendering for environments that support "rgb_array" rendering.
|
||||
* - :class:`JaxToNumpy`
|
||||
- Wraps a Jax-based environment such that it can be interacted with NumPy arrays.
|
||||
* - :class:`JaxToTorch`
|
||||
- Wraps a Jax-based environment so that it can be interacted with PyTorch Tensors.
|
||||
* - :class:`MaxAndSkipObservation`
|
||||
- Skips the N-th frame (observation) and return the max values between the two last observations.
|
||||
* - :class:`NormalizeObservation`
|
||||
- Normalizes observations to be centered at the mean with unit variance.
|
||||
* - :class:`NormalizeReward`
|
||||
- Normalizes immediate rewards such that their exponential moving average has a fixed variance.
|
||||
* - :class:`NumpyToTorch`
|
||||
- Wraps a NumPy-based environment such that it can be interacted with PyTorch Tensors.
|
||||
* - :class:`OrderEnforcing`
|
||||
- Will produce an error if ``step`` or ``render`` is called before ``render``.
|
||||
* - :class:`PassiveEnvChecker`
|
||||
- A passive environment checker wrapper that surrounds the ``step``, ``reset`` and ``render`` functions to check they follows gymnasium's API.
|
||||
* - :class:`RecordEpisodeStatistics`
|
||||
- This wrapper will keep track of cumulative rewards and episode lengths.
|
||||
* - :class:`RecordVideo`
|
||||
- Records videos of environment episodes using the environment's render function.
|
||||
* - :class:`RenderCollection`
|
||||
- Collect rendered frames of an environment such ``render`` returns a ``list[RenderedFrame]``.
|
||||
* - :class:`RenderObservation`
|
||||
- Includes the rendered observations in the environment's observations.
|
||||
* - :class:`RescaleAction`
|
||||
- Affinely (linearly) rescales a ``Box`` action space of the environment to within the range of ``[min_action, max_action]``.
|
||||
* - :class:`RescaleObservation`
|
||||
- Affinely (linearly) rescales a ``Box`` observation space of the environment to within the range of ``[min_obs, max_obs]``.
|
||||
* - :class:`ReshapeObservation`
|
||||
- Reshapes Array based observations to a specified shape.
|
||||
* - :class:`ResizeObservation`
|
||||
- Resizes image observations using OpenCV to a specified shape.
|
||||
* - :class:`StickyAction`
|
||||
- Adds a probability that the action is repeated for the same ``step`` function.
|
||||
* - :class:`TimeAwareObservation`
|
||||
- Augment the observation with the number of time steps taken within an episode.
|
||||
* - :class:`TimeLimit`
|
||||
- Limits the number of steps for an environment through truncating the environment if a maximum number of timesteps is exceeded.
|
||||
* - :class:`TransformAction`
|
||||
- Applies a function to the ``action`` before passing the modified value to the environment ``step`` function.
|
||||
* - :class:`TransformObservation`
|
||||
- Applies a function to the ``observation`` received from the environment's ``reset`` and ``step`` that is passed back to the user.
|
||||
* - :class:`TransformReward`
|
||||
- Applies a function to the ``reward`` received from the environment's ``step``.
|
||||
|
||||
```
|
||||
|
||||
## Vector only Wrappers
|
||||
|
||||
```{eval-rst}
|
||||
.. py:currentmodule:: gymnasium.wrappers.vector
|
||||
|
||||
.. list-table::
|
||||
:header-rows: 1
|
||||
|
||||
* - Name
|
||||
- Description
|
||||
* - :class:`DictInfoToList`
|
||||
- Converts infos of vectorized environments from ``dict`` to ``List[dict]``.
|
||||
* - :class:`VectorizeTransformAction`
|
||||
- Vectorizes a single-agent transform action wrapper for vector environments.
|
||||
* - :class:`VectorizeTransformObservation`
|
||||
- Vectorizes a single-agent transform observation wrapper for vector environments.
|
||||
* - :class:`VectorizeTransformReward`
|
||||
- Vectorizes a single-agent transform reward wrapper for vector environments.
|
||||
```
|
19
docs/api/wrappers/vector_wrappers.md
Normal file
19
docs/api/wrappers/vector_wrappers.md
Normal file
@@ -0,0 +1,19 @@
|
||||
---
|
||||
title: Vector Wrappers
|
||||
---
|
||||
|
||||
# Vector wrappers
|
||||
|
||||
## Vector only wrappers
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.wrappers.vector.DictInfoToList
|
||||
```
|
||||
|
||||
## Vectorize Transform Wrappers to Vector Wrappers
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.wrappers.vector.VectorizeTransformObservation
|
||||
.. autoclass:: gymnasium.wrappers.vector.VectorizeTransformAction
|
||||
.. autoclass:: gymnasium.wrappers.vector.VectorizeTransformReward
|
||||
```
|
@@ -40,10 +40,10 @@ release = gymnasium.__version__
|
||||
# ones.
|
||||
extensions = [
|
||||
"sphinx.ext.napoleon",
|
||||
"sphinx.ext.doctest",
|
||||
"sphinx.ext.autodoc",
|
||||
"sphinx.ext.githubpages",
|
||||
"sphinx.ext.viewcode",
|
||||
"sphinx.ext.coverage",
|
||||
"myst_parser",
|
||||
"furo.gen_tutorials",
|
||||
"sphinx_gallery.gen_gallery",
|
||||
|
@@ -1,134 +0,0 @@
|
||||
---
|
||||
layout: "contents"
|
||||
title: Basic Usage
|
||||
firstpage:
|
||||
---
|
||||
|
||||
# Basic Usage
|
||||
|
||||
Gymnasium is a project that provides an API for all single agent reinforcement learning environments, and includes implementations of common environments: cartpole, pendulum, mountain-car, mujoco, atari, and more.
|
||||
|
||||
The API contains four key functions: ``make``, ``reset``, ``step`` and ``render``, that this basic usage will introduce you to. At the core of Gymnasium is ``Env``, a high-level python class representing a markov decision process (MDP) from reinforcement learning theory (this is not a perfect reconstruction, and is missing several components of MDPs). Within gymnasium, environments (MDPs) are implemented as ``Env`` classes, along with ``Wrappers``, which provide helpful utilities and can change the results passed to the user.
|
||||
|
||||
## Initializing Environments
|
||||
|
||||
Initializing environments is very easy in Gymnasium and can be done via the ``make`` function:
|
||||
|
||||
```python
|
||||
import gymnasium as gym
|
||||
env = gym.make('CartPole-v1')
|
||||
```
|
||||
|
||||
This will return an ``Env`` for users to interact with. To see all environments you can create, use ``gymnasium.envs.registry.keys()``. ``make`` includes a number of additional parameters for adding wrappers, specifying keywords to the environment, and more.
|
||||
|
||||
## Interacting with the Environment
|
||||
|
||||
The classic "agent-environment loop" pictured below is a simplified representation of the reinforcement learning setting that Gymnasium implements.
|
||||
|
||||
```{image} /_static/diagrams/AE_loop.png
|
||||
:width: 50%
|
||||
:align: center
|
||||
:class: only-light
|
||||
```
|
||||
|
||||
```{image} /_static/diagrams/AE_loop_dark.png
|
||||
:width: 50%
|
||||
:align: center
|
||||
:class: only-dark
|
||||
```
|
||||
|
||||
This loop is implemented using the following gymnasium code
|
||||
|
||||
```python
|
||||
import gymnasium as gym
|
||||
env = gym.make("LunarLander-v2", render_mode="human")
|
||||
observation, info = env.reset()
|
||||
|
||||
for _ in range(1000):
|
||||
action = env.action_space.sample() # agent policy that uses the observation and info
|
||||
observation, reward, terminated, truncated, info = env.step(action)
|
||||
|
||||
if terminated or truncated:
|
||||
observation, info = env.reset()
|
||||
|
||||
env.close()
|
||||
```
|
||||
|
||||
The output should look something like this:
|
||||
|
||||
```{figure} https://user-images.githubusercontent.com/15806078/153222406-af5ce6f0-4696-4a24-a683-46ad4939170c.gif
|
||||
:width: 50%
|
||||
:align: center
|
||||
```
|
||||
|
||||
### Explaining the code
|
||||
|
||||
First, an environment is created using ``make`` with an additional keyword `"render_mode"` that specifies how the environment should be visualised. See ``render`` for details on the default meaning of different render modes. In this example, we use the ``"LunarLander"`` environment where the agent controls a spaceship that needs to land safely.
|
||||
|
||||
After initializing the environment, we ``reset`` the environment to get the first observation of the environment. For initializing the environment with a particular random seed or options (see environment documentation for possible values) use the ``seed`` or ``options`` parameters with ``reset``.
|
||||
|
||||
Next, the agent performs an action in the environment, ``step``, this can be imagined as moving a robot or pressing a button on a games' controller that causes a change within the environment. As a result, the agent receives a new observation from the updated environment along with a reward for taking the action. This reward could be for instance positive for destroying an enemy or a negative reward for moving into lava. One such action-observation exchange is referred to as a *timestep*.
|
||||
|
||||
However, after some timesteps, the environment may end; this is called the terminal state. For instance, the robot may have crashed, or the agent may have succeeded in completing a task, in which case the environment needs to stop as the agent cannot continue. In gymnasium, if the environment has terminated, this is returned by ``step``. Similarly, we may also want the environment to end after a fixed number of timesteps, in which case the environment issues a truncated signal. If either of ``terminated`` or ``truncated`` are `true` then ``reset`` should be called next to restart the environment.
|
||||
|
||||
## Action and observation spaces
|
||||
|
||||
Every environment specifies the format of valid actions and observations with the ``env.action_space`` and ``env.observation_space`` attributes. This is helpful for knowing both the expected input and output of the environment, as all valid actions and observations should be contained within their respective spaces.
|
||||
|
||||
In the example, we sampled random actions via ``env.action_space.sample()`` instead of using an agent policy, mapping observations to actions which users will want to make. See one of the agent tutorials for an example of creating and training an agent policy.
|
||||
|
||||
Every environment should have the attributes ``action_space`` and ``observation_space``, both of which should be instances of classes that inherit from ``Space``. Gymnasium has support for a majority of possible spaces users might need:
|
||||
|
||||
- ``Box``: describes an n-dimensional continuous space. It's a bounded space where we can define the upper and lower
|
||||
limits which describe the valid values our observations can take.
|
||||
- ``Discrete``: describes a discrete space where {0, 1, ..., n-1} are the possible values our observation or action can take.
|
||||
Values can be shifted to {a, a+1, ..., a+n-1} using an optional argument.
|
||||
- ``Dict``: represents a dictionary of simple spaces.
|
||||
- ``Tuple``: represents a tuple of simple spaces.
|
||||
- ``MultiBinary``: creates an n-shape binary space. Argument n can be a number or a list of numbers.
|
||||
- ``MultiDiscrete``: consists of a series of ``Discrete`` action spaces with a different number of actions in each element.
|
||||
|
||||
For example usage of spaces, see their [documentation](/api/spaces) along with [utility functions](/api/spaces/utils). There are a couple of more niche spaces ``Graph``, ``Sequence`` and ``Text``.
|
||||
|
||||
## Modifying the environment
|
||||
|
||||
Wrappers are a convenient way to modify an existing environment without having to alter the underlying code directly. Using wrappers will allow you to avoid a lot of boilerplate code and make your environment more modular. Wrappers can also be chained to combine their effects. Most environments that are generated via ``gymnasium.make`` will already be wrapped by default using the ``TimeLimit``, ``OrderEnforcing`` and ``PassiveEnvChecker``.
|
||||
|
||||
In order to wrap an environment, you must first initialize a base environment. Then you can pass this environment along with (possibly optional) parameters to the wrapper's constructor:
|
||||
|
||||
```python
|
||||
>>> import gymnasium as gym
|
||||
>>> from gymnasium.wrappers import FlattenObservation
|
||||
>>> env = gym.make("CarRacing-v2")
|
||||
>>> env.observation_space.shape
|
||||
(96, 96, 3)
|
||||
>>> wrapped_env = FlattenObservation(env)
|
||||
>>> wrapped_env.observation_space.shape
|
||||
(27648,)
|
||||
|
||||
```
|
||||
|
||||
Gymnasium already provides many commonly used wrappers for you. Some examples:
|
||||
|
||||
- `TimeLimit`: Issue a truncated signal if a maximum number of timesteps has been exceeded (or the base environment has issued a truncated signal).
|
||||
- `ClipAction`: Clip the action such that it lies in the action space (of type `Box`).
|
||||
- `RescaleAction`: Rescale actions to lie in a specified interval
|
||||
- `TimeAwareObservation`: Add information about the index of timestep to observation. In some cases helpful to ensure that transitions are Markov.
|
||||
|
||||
For a full list of implemented wrappers in gymnasium, see [wrappers](/api/wrappers).
|
||||
|
||||
If you have a wrapped environment, and you want to get the unwrapped environment underneath all the layers of wrappers (so that you can manually call a function or change some underlying aspect of the environment), you can use the `.unwrapped` attribute. If the environment is already a base environment, the `.unwrapped` attribute will just return itself.
|
||||
|
||||
```python
|
||||
>>> wrapped_env
|
||||
<FlattenObservation<TimeLimit<OrderEnforcing<PassiveEnvChecker<CarRacing<CarRacing-v2>>>>>>
|
||||
>>> wrapped_env.unwrapped
|
||||
<gymnasium.envs.box2d.car_racing.CarRacing object at 0x7f04efcb8850>
|
||||
|
||||
```
|
||||
|
||||
## More information
|
||||
|
||||
* [Making a Custom environment using the Gymnasium API](/tutorials/gymnasium_basics/environment_creation/)
|
||||
* [Training an agent to play blackjack](/tutorials/training_agents/blackjack_tutorial)
|
||||
* [Compatibility with OpenAI Gym](/content/gym_compatibility)
|
@@ -1,122 +0,0 @@
|
||||
---
|
||||
layout: "contents"
|
||||
title: Migration Guide
|
||||
---
|
||||
|
||||
# v21 to v26 Migration Guide
|
||||
|
||||
```{eval-rst}
|
||||
.. py:currentmodule:: gymnasium.wrappers
|
||||
|
||||
Gymnasium is a fork of `OpenAI Gym v26 <https://github.com/openai/gym/releases/tag/0.26.2>`_, which introduced a large breaking change from `Gym v21 <https://github.com/openai/gym/releases/tag/v0.21.0>`_.
|
||||
In this guide, we briefly outline the API changes from Gym v21 - which a number of tutorials have been written for - to Gym v26.
|
||||
For environments still stuck in the v21 API, users can use the :class:`EnvCompatibility` wrapper to convert them to v26 compliant.
|
||||
For more information, see the `guide </content/gym_compatibility>`_
|
||||
```
|
||||
|
||||
### Example code for v21
|
||||
|
||||
```python
|
||||
import gym
|
||||
env = gym.make("LunarLander-v2", options={})
|
||||
env.seed(123)
|
||||
observation = env.reset()
|
||||
|
||||
done = False
|
||||
while not done:
|
||||
action = env.action_space.sample() # agent policy that uses the observation and info
|
||||
observation, reward, done, info = env.step(action)
|
||||
|
||||
env.render(mode="human")
|
||||
|
||||
env.close()
|
||||
```
|
||||
|
||||
### Example code for v26
|
||||
|
||||
```python
|
||||
import gym
|
||||
env = gym.make("LunarLander-v2", render_mode="human")
|
||||
observation, info = env.reset(seed=123, options={})
|
||||
|
||||
done = False
|
||||
while not done:
|
||||
action = env.action_space.sample() # agent policy that uses the observation and info
|
||||
observation, reward, terminated, truncated, info = env.step(action)
|
||||
|
||||
done = terminated or truncated
|
||||
|
||||
env.close()
|
||||
```
|
||||
|
||||
## Seed and random number generator
|
||||
|
||||
```{eval-rst}
|
||||
.. py:currentmodule:: gymnasium.Env
|
||||
|
||||
The ``Env.seed()`` has been removed from the Gym v26 environments in favour of ``Env.reset(seed=seed)``.
|
||||
This allows seeding to only be changed on environment reset.
|
||||
The decision to remove ``seed`` was because some environments use emulators that cannot change random number generators within an episode and must be done at the beginning of a new episode.
|
||||
We are aware of cases where controlling the random number generator is important, in these cases, if the environment uses the built-in random number generator, users can set the seed manually with the attribute :attr:`np_random`.
|
||||
|
||||
Gymnasium v26 changed to using ``numpy.random.Generator`` instead of a custom random number generator.
|
||||
This means that several functions such as ``randint`` were removed in favour of ``integers``.
|
||||
While some environments might use external random number generator, we recommend using the attribute :attr:`np_random` that wrappers and external users can access and utilise.
|
||||
```
|
||||
|
||||
## Environment Reset
|
||||
|
||||
```{eval-rst}
|
||||
In v26, :meth:`reset` takes two optional parameters and returns one value.
|
||||
This contrasts to v21 which takes no parameters and returns ``None``.
|
||||
The two parameters are ``seed`` for setting the random number generator and ``options`` which allows additional data to be passed to the environment on reset.
|
||||
For example, in classic control, the ``options`` parameter now allows users to modify the range of the state bound.
|
||||
See the original `PR <https://github.com/openai/gym/pull/2921>`_ for more details.
|
||||
|
||||
:meth:`reset` further returns ``info``, similar to the ``info`` returned by :meth:`step`.
|
||||
This is important because ``info`` can include metrics or valid action mask that is used or saved in the next step.
|
||||
|
||||
To update older environments, we highly recommend that ``super().reset(seed=seed)`` is called on the first line of :meth:`reset`.
|
||||
This will automatically update the :attr:`np_random` with the seed value.
|
||||
```
|
||||
|
||||
## Environment Step
|
||||
|
||||
```{eval-rst}
|
||||
In v21, the type definition of :meth:`step` is ``tuple[ObsType, SupportsFloat, bool, dict[str, Any]]`` representing the next observation, the reward from the step, if the episode is done and additional info from the step.
|
||||
Due to reproducibility issues that will be expanded on in a blog post soon, we have changed the type definition to ``tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]`` adding an extra boolean value.
|
||||
This extra bool corresponds to the older `done` now changed to `terminated` and `truncated`.
|
||||
These changes were introduced in Gym `v26 <https://github.com/openai/gym/releases/tag/0.26.0>`_ (turned off by default in `v25 <https://github.com/openai/gym/releases/tag/0.25.0>`_).
|
||||
|
||||
For users wishing to update, in most cases, replacing ``done`` with ``terminated`` and ``truncated=False`` in :meth:`step` should address most issues.
|
||||
However, environments that have reasons for episode truncation rather than termination should read through the associated `PR <https://github.com/openai/gym/pull/2752>`_.
|
||||
For users looping through an environment, they should modify ``done = terminated or truncated`` as is shown in the example code.
|
||||
For training libraries, the primary difference is to change ``done`` to ``terminated``, indicating whether bootstrapping should or shouldn't happen.
|
||||
```
|
||||
|
||||
## TimeLimit Wrapper
|
||||
```{eval-rst}
|
||||
In v21, the :class:`TimeLimit` wrapper added an extra key in the ``info`` dictionary ``TimeLimit.truncated`` whenever the agent reached the time limit without reaching a terminal state.
|
||||
|
||||
In v26, this information is instead communicated through the `truncated` return value described in the previous section, which is `True` if the agent reaches the time limit, whether or not it reaches a terminal state. The old dictionary entry is equivalent to ``truncated and not terminated``
|
||||
```
|
||||
|
||||
## Environment Render
|
||||
|
||||
```{eval-rst}
|
||||
In v26, a new render API was introduced such that the render mode is fixed at initialisation as some environments don't allow on-the-fly render mode changes. Therefore, users should now specify the :attr:`render_mode` within ``gym.make`` as shown in the v26 example code above.
|
||||
|
||||
For a more complete explanation of the changes, please refer to this `summary <https://younis.dev/blog/render-api/>`_.
|
||||
```
|
||||
|
||||
## Removed code
|
||||
|
||||
```{eval-rst}
|
||||
.. py:currentmodule:: gymnasium.wrappers
|
||||
|
||||
* GoalEnv - This was removed, users needing it should reimplement the environment or use Gymnasium Robotics which contains an implementation of this environment.
|
||||
* ``from gym.envs.classic_control import rendering`` - This was removed in favour of users implementing their own rendering systems. Gymnasium environments are coded using pygame.
|
||||
* Robotics environments - The robotics environments have been moved to the `Gymnasium Robotics <https://robotics.farama.org/>`_ project.
|
||||
* Monitor wrapper - This wrapper was replaced with two separate wrappers, :class:`RecordVideo` and :class:`RecordEpisodeStatistics`.
|
||||
|
||||
```
|
@@ -17,19 +17,27 @@ An API standard for reinforcement learning with a diverse collection of referenc
|
||||
:width: 500
|
||||
```
|
||||
|
||||
**Gymnasium is a maintained fork of OpenAI’s Gym library.** The Gymnasium interface is simple, pythonic, and capable of representing general RL problems, and has a [compatibility wrapper](content/gym_compatibility) for old Gym environments:
|
||||
**Gymnasium is a maintained fork of OpenAI’s Gym library.** The Gymnasium interface is simple, pythonic, and capable of representing general RL problems, and has a [compatibility wrapper](introduction/gym_compatibility) for old Gym environments:
|
||||
|
||||
```{code-block} python
|
||||
|
||||
import gymnasium as gym
|
||||
|
||||
# Initialise the environment
|
||||
env = gym.make("LunarLander-v2", render_mode="human")
|
||||
|
||||
# Reset the environment to generate the first observation
|
||||
observation, info = env.reset(seed=42)
|
||||
for _ in range(1000):
|
||||
action = env.action_space.sample() # this is where you would insert your policy
|
||||
observation, reward, terminated, truncated, info = env.step(action)
|
||||
# this is where you would insert your policy
|
||||
action = env.action_space.sample()
|
||||
|
||||
if terminated or truncated:
|
||||
observation, info = env.reset()
|
||||
# step (transition) through the environment with the action
|
||||
# receiving the next observation, reward and if the episode has terminated or truncated
|
||||
observation, reward, terminated, truncated, info = env.step(action)
|
||||
|
||||
# If the episode has ended then we can reset to start a new episode
|
||||
if terminated or truncated:
|
||||
observation, info = env.reset()
|
||||
|
||||
env.close()
|
||||
```
|
||||
@@ -38,9 +46,9 @@ env.close()
|
||||
:hidden:
|
||||
:caption: Introduction
|
||||
|
||||
content/basic_usage
|
||||
content/gym_compatibility
|
||||
content/migration-guide
|
||||
introduction/basic_usage
|
||||
introduction/gym_compatibility
|
||||
introduction/migration-guide
|
||||
```
|
||||
|
||||
```{toctree}
|
||||
@@ -53,7 +61,7 @@ api/spaces
|
||||
api/wrappers
|
||||
api/vector
|
||||
api/utils
|
||||
api/experimental
|
||||
api/functional
|
||||
```
|
||||
|
||||
```{toctree}
|
||||
|
172
docs/introduction/basic_usage.md
Normal file
172
docs/introduction/basic_usage.md
Normal file
@@ -0,0 +1,172 @@
|
||||
---
|
||||
layout: "contents"
|
||||
title: Basic Usage
|
||||
firstpage:
|
||||
---
|
||||
|
||||
# Basic Usage
|
||||
|
||||
```{eval-rst}
|
||||
.. py:currentmodule:: gymnasium
|
||||
|
||||
Gymnasium is a project that provides an API for all single agent reinforcement learning environments, and includes implementations of common environments: cartpole, pendulum, mountain-car, mujoco, atari, and more.
|
||||
|
||||
The API contains four key functions: :meth:`make`, :meth:`Env.reset`, :meth:`Env.step` and :meth:`Env.render`, that this basic usage will introduce you to. At the core of Gymnasium is :class:`Env`, a high-level python class representing a markov decision process (MDP) from reinforcement learning theory (this is not a perfect reconstruction, and is missing several components of MDPs). Within gymnasium, environments (MDPs) are implemented as :class:`Env` classes, along with :class:`Wrapper` classes, which provide helpful utilities to change actions passed to the environment and to modify the observations, rewards, termination or truncation conditions passed back to the user.
|
||||
```
|
||||
|
||||
## Initializing Environments
|
||||
|
||||
```{eval-rst}
|
||||
.. py:currentmodule:: gymnasium
|
||||
|
||||
Initializing environments is very easy in Gymnasium and can be done via the :meth:`make` function:
|
||||
```
|
||||
|
||||
```python
|
||||
import gymnasium as gym
|
||||
env = gym.make('CartPole-v1')
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. py:currentmodule:: gymnasium
|
||||
|
||||
This will return an :class:`Env` for users to interact with. To see all environments you can create, use :meth:`pprint_registry`. Furthermore, :meth:`make` provides a number of additional arguments for specifying keywords to the environment, adding more or less wrappers, etc.
|
||||
```
|
||||
|
||||
## Interacting with the Environment
|
||||
|
||||
The classic "agent-environment loop" pictured below is a simplified representation of the reinforcement learning setting that Gymnasium implements.
|
||||
|
||||
```{image} /_static/diagrams/AE_loop.png
|
||||
:width: 50%
|
||||
:align: center
|
||||
:class: only-light
|
||||
```
|
||||
|
||||
```{image} /_static/diagrams/AE_loop_dark.png
|
||||
:width: 50%
|
||||
:align: center
|
||||
:class: only-dark
|
||||
```
|
||||
|
||||
This loop is implemented using the following gymnasium code
|
||||
|
||||
```python
|
||||
import gymnasium as gym
|
||||
env = gym.make("LunarLander-v2", render_mode="human")
|
||||
observation, info = env.reset()
|
||||
|
||||
for _ in range(1000):
|
||||
action = env.action_space.sample() # agent policy that uses the observation and info
|
||||
observation, reward, terminated, truncated, info = env.step(action)
|
||||
|
||||
if terminated or truncated:
|
||||
observation, info = env.reset()
|
||||
|
||||
env.close()
|
||||
```
|
||||
|
||||
The output should look something like this:
|
||||
|
||||
```{figure} https://user-images.githubusercontent.com/15806078/153222406-af5ce6f0-4696-4a24-a683-46ad4939170c.gif
|
||||
:width: 50%
|
||||
:align: center
|
||||
```
|
||||
|
||||
### Explaining the code
|
||||
|
||||
```{eval-rst}
|
||||
.. py:currentmodule:: gymnasium
|
||||
|
||||
First, an environment is created using :meth:`make` with an additional keyword ``"render_mode"`` that specifies how the environment should be visualised.
|
||||
|
||||
.. py:currentmodule:: gymnasium.Env
|
||||
|
||||
See :meth:`render` for details on the default meaning of different render modes. In this example, we use the ``"LunarLander"`` environment where the agent controls a spaceship that needs to land safely.
|
||||
|
||||
After initializing the environment, we :meth:`reset` the environment to get the first observation of the environment. For initializing the environment with a particular random seed or options (see environment documentation for possible values) use the ``seed`` or ``options`` parameters with :meth:`reset`.
|
||||
|
||||
Next, the agent performs an action in the environment, :meth:`step`, this can be imagined as moving a robot or pressing a button on a games' controller that causes a change within the environment. As a result, the agent receives a new observation from the updated environment along with a reward for taking the action. This reward could be for instance positive for destroying an enemy or a negative reward for moving into lava. One such action-observation exchange is referred to as a **timestep**.
|
||||
|
||||
However, after some timesteps, the environment may end; this is called the terminal state. For instance, the robot may have crashed, or the agent may have succeeded in completing a task, in which case the environment needs to stop as the agent cannot continue. In gymnasium, if the environment has terminated, this is returned by :meth:`step`. Similarly, we may also want the environment to end after a fixed number of timesteps, in which case the environment issues a truncated signal. If either of ``terminated`` or ``truncated`` are ``True`` then :meth:`reset` should be called next to restart the environment.
|
||||
```
|
||||
|
||||
## Action and observation spaces
|
||||
|
||||
```{eval-rst}
|
||||
.. py:currentmodule:: gymnasium.Env
|
||||
|
||||
Every environment specifies the format of valid actions and observations with the :attr:`action_space` and :attr:`observation_space` attributes. This is helpful for knowing both the expected input and output of the environment, as all valid actions and observations should be contained within their respective spaces.
|
||||
|
||||
In the example, we sampled random actions via ``env.action_space.sample()`` instead of using an agent policy, mapping observations to actions which users will want to make. See one of the agent tutorials for an example of creating and training an agent policy.
|
||||
|
||||
.. py:currentmodule:: gymnasium
|
||||
|
||||
Every environment should have the attributes :attr:`Env.action_space` and :attr:`Env.observation_space`, both of which should be instances of classes that inherit from :class:`spaces.Space`. Gymnasium has support for a majority of possible spaces users might need:
|
||||
|
||||
.. py:currentmodule:: gymnasium.spaces
|
||||
|
||||
- :class:`Box`: describes an n-dimensional continuous space. It's a bounded space where we can define the upper and lower
|
||||
limits which describe the valid values our observations can take.
|
||||
- :class:`Discrete`: describes a discrete space where ``{0, 1, ..., n-1}`` are the possible values our observation or action can take.
|
||||
Values can be shifted to ``{a, a+1, ..., a+n-1}`` using an optional argument.
|
||||
- :class:`Dict`: represents a dictionary of simple spaces.
|
||||
- :class:`Tuple`: represents a tuple of simple spaces.
|
||||
- :class:`MultiBinary`: creates an n-shape binary space. Argument n can be a number or a list of numbers.
|
||||
- :class:`MultiDiscrete`: consists of a series of :class:`Discrete` action spaces with a different number of actions in each element.
|
||||
|
||||
For example usage of spaces, see their `documentation </api/spaces>`_ along with `utility functions </api/spaces/utils>`_. There are a couple of more niche spaces :class:`Graph`, :class:`Sequence` and :class:`Text`.
|
||||
```
|
||||
|
||||
## Modifying the environment
|
||||
|
||||
```{eval-rst}
|
||||
.. py:currentmodule:: gymnasium.wrappers
|
||||
|
||||
Wrappers are a convenient way to modify an existing environment without having to alter the underlying code directly. Using wrappers will allow you to avoid a lot of boilerplate code and make your environment more modular. Wrappers can also be chained to combine their effects. Most environments that are generated via ``gymnasium.make`` will already be wrapped by default using the :class:`TimeLimitV0`, :class:`OrderEnforcingV0` and :class:`PassiveEnvCheckerV0`.
|
||||
|
||||
In order to wrap an environment, you must first initialize a base environment. Then you can pass this environment along with (possibly optional) parameters to the wrapper's constructor:
|
||||
```
|
||||
|
||||
```python
|
||||
>>> import gymnasium as gym
|
||||
>>> from gymnasium.wrappers import FlattenObservation
|
||||
>>> env = gym.make("CarRacing-v2")
|
||||
>>> env.observation_space.shape
|
||||
(96, 96, 3)
|
||||
>>> wrapped_env = FlattenObservation(env)
|
||||
>>> wrapped_env.observation_space.shape
|
||||
(27648,)
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. py:currentmodule:: gymnasium.wrappers
|
||||
|
||||
Gymnasium already provides many commonly used wrappers for you. Some examples:
|
||||
|
||||
- :class:`TimeLimitV0`: Issue a truncated signal if a maximum number of timesteps has been exceeded (or the base environment has issued a truncated signal).
|
||||
- :class:`ClipActionV0`: Clip the action such that it lies in the action space (of type `Box`).
|
||||
- :class:`RescaleActionV0`: Rescale actions to lie in a specified interval
|
||||
- :class:`TimeAwareObservationV0`: Add information about the index of timestep to observation. In some cases helpful to ensure that transitions are Markov.
|
||||
```
|
||||
|
||||
For a full list of implemented wrappers in gymnasium, see [wrappers](/api/wrappers).
|
||||
|
||||
```{eval-rst}
|
||||
.. py:currentmodule:: gymnasium.Env
|
||||
|
||||
If you have a wrapped environment, and you want to get the unwrapped environment underneath all the layers of wrappers (so that you can manually call a function or change some underlying aspect of the environment), you can use the :attr:`unwrapped` attribute. If the environment is already a base environment, the :attr:`unwrapped` attribute will just return itself.
|
||||
```
|
||||
|
||||
```python
|
||||
>>> wrapped_env
|
||||
<FlattenObservation<TimeLimit<OrderEnforcing<PassiveEnvChecker<CarRacing<CarRacing-v2>>>>>>
|
||||
>>> wrapped_env.unwrapped
|
||||
<gymnasium.envs.box2d.car_racing.CarRacing object at 0x7f04efcb8850>
|
||||
```
|
||||
|
||||
## More information
|
||||
|
||||
* [Making a Custom environment using the Gymnasium API](/tutorials/gymnasium_basics/environment_creation/)
|
||||
* [Training an agent to play blackjack](/tutorials/training_agents/blackjack_tutorial)
|
||||
* [Compatibility with OpenAI Gym](/introduction/gym_compatibility)
|
@@ -12,9 +12,7 @@ Gymnasium provides a number of compatibility methods for a range of Environment
|
||||
```{eval-rst}
|
||||
.. py:currentmodule:: gymnasium.wrappers
|
||||
|
||||
For environments that are registered solely in OpenAI Gym and not in Gymnasium, Gymnasium v0.26.3 and above allows importing them through either a special environment or a wrapper.
|
||||
The ``"GymV26Environment-v0"`` environment was introduced in Gymnasium v0.26.3, and allows importing of Gym environments through the ``env_name`` argument along with other relevant kwargs environment kwargs.
|
||||
To perform conversion through a wrapper, the environment itself can be passed to the wrapper :class:`EnvCompatibility` through the ``env`` kwarg.
|
||||
For environments that are registered solely in OpenAI Gym and not in Gymnasium, Gymnasium v0.26.3 and above allows importing them through either a special environment or a wrapper. The ``"GymV26Environment-v0"`` environment was introduced in Gymnasium v0.26.3, and allows importing of Gym environments through the ``env_name`` argument along with other relevant kwargs environment kwargs. To perform conversion through a wrapper, the environment itself can be passed to the wrapper :class:`EnvCompatibility` through the ``env`` kwarg.
|
||||
```
|
||||
|
||||
An example of this is atari 0.8.0 which does not have a gymnasium implementation.
|
||||
@@ -29,9 +27,7 @@ env = gym.make("GymV26Environment-v0", env_id="ALE/Pong-v5")
|
||||
```{eval-rst}
|
||||
.. py:currentmodule:: gymnasium
|
||||
|
||||
A number of environments have not updated to the recent Gym changes, in particular since v0.21.
|
||||
This update is significant for the introduction of ``termination`` and ``truncation`` signatures in favour of the previously used ``done``.
|
||||
To allow backward compatibility, Gym and Gymnasium v0.26+ include an ``apply_api_compatibility`` kwarg when calling :meth:`make` that automatically converts a v0.21 API compliant environment to one that is compatible with v0.26+.
|
||||
A number of environments have not updated to the recent Gym changes, in particular since v0.21. This update is significant for the introduction of ``termination`` and ``truncation`` signatures in favour of the previously used ``done``. To allow backward compatibility, Gym and Gymnasium v0.26+ include an ``apply_api_compatibility`` kwarg when calling :meth:`make` that automatically converts a v0.21 API compliant environment to one that is compatible with v0.26+.
|
||||
```
|
||||
|
||||
```python
|
103
docs/introduction/migration-guide.md
Normal file
103
docs/introduction/migration-guide.md
Normal file
@@ -0,0 +1,103 @@
|
||||
---
|
||||
layout: "contents"
|
||||
title: Migration Guide
|
||||
---
|
||||
|
||||
# v0.21 to v0.26 Migration Guide
|
||||
|
||||
```{eval-rst}
|
||||
.. py:currentmodule:: gymnasium.wrappers
|
||||
|
||||
Gymnasium is a fork of `OpenAI Gym v0.26 <https://github.com/openai/gym/releases/tag/0.26.2>`_, which introduced a large breaking change from `Gym v0.21 <https://github.com/openai/gym/releases/tag/v0.21.0>`_. In this guide, we briefly outline the API changes from Gym v0.21 - which a number of tutorials have been written for - to Gym v0.26. For environments still stuck in the v0.21 API, users can use the :class:`EnvCompatibility` wrapper to convert them to v0.26 compliant.
|
||||
For more information, see the `guide </content/gym_compatibility>`_
|
||||
```
|
||||
|
||||
## Example code for v0.21
|
||||
|
||||
```python
|
||||
import gym
|
||||
env = gym.make("LunarLander-v2", options={})
|
||||
env.seed(123)
|
||||
observation = env.reset()
|
||||
|
||||
done = False
|
||||
while not done:
|
||||
action = env.action_space.sample() # agent policy that uses the observation and info
|
||||
observation, reward, done, info = env.step(action)
|
||||
|
||||
env.render(mode="human")
|
||||
|
||||
env.close()
|
||||
```
|
||||
|
||||
## Example code for v0.26
|
||||
|
||||
```python
|
||||
import gym
|
||||
env = gym.make("LunarLander-v2", render_mode="human")
|
||||
observation, info = env.reset(seed=123, options={})
|
||||
|
||||
done = False
|
||||
while not done:
|
||||
action = env.action_space.sample() # agent policy that uses the observation and info
|
||||
observation, reward, terminated, truncated, info = env.step(action)
|
||||
|
||||
done = terminated or truncated
|
||||
|
||||
env.close()
|
||||
```
|
||||
|
||||
## Seed and random number generator
|
||||
|
||||
```{eval-rst}
|
||||
.. py:currentmodule:: gymnasium.Env
|
||||
|
||||
The ``Env.seed()`` has been removed from the Gym v0.26 environments in favour of ``Env.reset(seed=seed)``. This allows seeding to only be changed on environment reset. The decision to remove ``seed`` was because some environments use emulators that cannot change random number generators within an episode and must be done at the beginning of a new episode. We are aware of cases where controlling the random number generator is important, in these cases, if the environment uses the built-in random number generator, users can set the seed manually with the attribute :attr:`np_random`.
|
||||
|
||||
Gymnasium v0.26 changed to using ``numpy.random.Generator`` instead of a custom random number generator. This means that several functions such as ``randint`` were removed in favour of ``integers``. While some environments might use external random number generator, we recommend using the attribute :attr:`np_random` that wrappers and external users can access and utilise.
|
||||
```
|
||||
|
||||
## Environment Reset
|
||||
|
||||
```{eval-rst}
|
||||
In v0.26, :meth:`reset` takes two optional parameters and returns one value. This contrasts to v0.21 which takes no parameters and returns ``None``. The two parameters are ``seed`` for setting the random number generator and ``options`` which allows additional data to be passed to the environment on reset. For example, in classic control, the ``options`` parameter now allows users to modify the range of the state bound. See the original `PR <https://github.com/openai/gym/pull/2921>`_ for more details.
|
||||
|
||||
:meth:`reset` further returns ``info``, similar to the ``info`` returned by :meth:`step`. This is important because ``info`` can include metrics or valid action mask that is used or saved in the next step.
|
||||
|
||||
To update older environments, we highly recommend that ``super().reset(seed=seed)`` is called on the first line of :meth:`reset`. This will automatically update the :attr:`np_random` with the seed value.
|
||||
```
|
||||
|
||||
## Environment Step
|
||||
|
||||
```{eval-rst}
|
||||
In v0.21, the type definition of :meth:`step` is ``tuple[ObsType, SupportsFloat, bool, dict[str, Any]`` representing the next observation, the reward from the step, if the episode is done and additional info from the step. Due to reproducibility issues that will be expanded on in a blog post soon, we have changed the type definition to ``tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]`` adding an extra boolean value. This extra bool corresponds to the older `done` now changed to `terminated` and `truncated`. These changes were introduced in Gym `v0.26 <https://github.com/openai/gym/releases/tag/0.26.0>`_ (turned off by default in `v25 <https://github.com/openai/gym/releases/tag/0.25.0>`_).
|
||||
|
||||
For users wishing to update, in most cases, replacing ``done`` with ``terminated`` and ``truncated=False`` in :meth:`step` should address most issues. However, environments that have reasons for episode truncation rather than termination should read through the associated `PR <https://github.com/openai/gym/pull/2752>`_. For users looping through an environment, they should modify ``done = terminated or truncated`` as is show in the example code. For training libraries, the primary difference is to change ``done`` to ``terminated``, indicating whether bootstrapping should or shouldn't happen.
|
||||
```
|
||||
|
||||
## TimeLimit Wrapper
|
||||
```{eval-rst}
|
||||
In v0.21, the :class:`TimeLimit` wrapper added an extra key in the ``info`` dictionary ``TimeLimit.truncated`` whenever the agent reached the time limit without reaching a terminal state.
|
||||
|
||||
In v0.26, this information is instead communicated through the `truncated` return value described in the previous section, which is `True` if the agent reaches the time limit, whether or not it reaches a terminal state. The old dictionary entry is equivalent to ``truncated and not terminated``
|
||||
```
|
||||
|
||||
## Environment Render
|
||||
|
||||
```{eval-rst}
|
||||
In v0.26, a new render API was introduced such that the render mode is fixed at initialisation as some environments don't allow on-the-fly render mode changes. Therefore, users should now specify the :attr:`render_mode` within ``gym.make`` as shown in the v0.26 example code above.
|
||||
|
||||
For a more complete explanation of the changes, please refer to this `summary <https://younis.dev/blog/render-api/>`_.
|
||||
```
|
||||
|
||||
## Removed code
|
||||
|
||||
```{eval-rst}
|
||||
.. py:currentmodule:: gymnasium.wrappers
|
||||
|
||||
* GoalEnv - This was removed, users needing it should reimplement the environment or use Gymnasium Robotics which contains an implementation of this environment.
|
||||
* ``from gym.envs.classic_control import rendering`` - This was removed in favour of users implementing their own rendering systems. Gymnasium environments are coded using pygame.
|
||||
* Robotics environments - The robotics environments have been moved to the `Gymnasium Robotics <https://robotics.farama.org/>`_ project.
|
||||
* Monitor wrapper - This wrapper was replaced with two separate wrapper, :class:`RecordVideo` and :class:`RecordEpisodeStatistics`
|
||||
|
||||
```
|
@@ -21,8 +21,7 @@ from gymnasium.envs.registration import (
|
||||
|
||||
# necessary for `envs.__init__` which registers all gymnasium environments and loads plugins
|
||||
from gymnasium import envs
|
||||
from gymnasium import spaces, utils, vector, wrappers, error, logger
|
||||
from gymnasium import experimental
|
||||
from gymnasium import spaces, utils, vector, wrappers, error, logger, functional
|
||||
|
||||
|
||||
__all__ = [
|
||||
@@ -43,15 +42,15 @@ __all__ = [
|
||||
"register_envs",
|
||||
# module folders
|
||||
"envs",
|
||||
"experimental",
|
||||
"spaces",
|
||||
"utils",
|
||||
"vector",
|
||||
"wrappers",
|
||||
"error",
|
||||
"logger",
|
||||
"functional",
|
||||
]
|
||||
__version__ = "0.29.0"
|
||||
__version__ = "1.0.0rc1"
|
||||
|
||||
|
||||
# Initializing pygame initializes audio connections through SDL. SDL uses alsa by default on all Linux systems
|
||||
|
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any, Generic, SupportsFloat, TypeVar
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gymnasium import logger, spaces
|
||||
from gymnasium import spaces
|
||||
from gymnasium.utils import RecordConstructorArgs, seeding
|
||||
|
||||
|
||||
@@ -37,12 +37,10 @@ class Env(Generic[ObsType, ActType]):
|
||||
|
||||
- :attr:`action_space` - The Space object corresponding to valid actions, all valid actions should be contained within the space.
|
||||
- :attr:`observation_space` - The Space object corresponding to valid observations, all valid observations should be contained within the space.
|
||||
- :attr:`reward_range` - A tuple corresponding to the minimum and maximum possible rewards for an agent over an episode.
|
||||
The default reward range is set to :math:`(-\infty,+\infty)`.
|
||||
- :attr:`spec` - An environment spec that contains the information used to initialize the environment from :meth:`gymnasium.make`
|
||||
- :attr:`metadata` - The metadata of the environment, i.e. render modes, render fps
|
||||
- :attr:`np_random` - The random number generator for the environment. This is automatically assigned during
|
||||
``super().reset(seed=seed)`` and when assessing ``self.np_random``.
|
||||
``super().reset(seed=seed)`` and when assessing :attr:`np_random`.
|
||||
|
||||
.. seealso:: For modifying or extending environments use the :py:class:`gymnasium.Wrapper` class
|
||||
|
||||
@@ -54,7 +52,6 @@ class Env(Generic[ObsType, ActType]):
|
||||
metadata: dict[str, Any] = {"render_modes": []}
|
||||
# define render_mode if your environment supports rendering
|
||||
render_mode: str | None = None
|
||||
reward_range = (-float("inf"), float("inf"))
|
||||
spec: EnvSpec | None = None
|
||||
|
||||
# Set these in ALL subclasses
|
||||
@@ -238,6 +235,10 @@ class Env(Generic[ObsType, ActType]):
|
||||
"""Gets the attribute `name` from the environment."""
|
||||
return getattr(self, name)
|
||||
|
||||
def set_wrapper_attr(self, name: str, value: Any):
|
||||
"""Sets the attribute `name` on the environment with `value`."""
|
||||
setattr(self, name, value)
|
||||
|
||||
|
||||
WrapperObsType = TypeVar("WrapperObsType")
|
||||
WrapperActType = TypeVar("WrapperActType")
|
||||
@@ -268,56 +269,41 @@ class Wrapper(
|
||||
env: The environment to wrap
|
||||
"""
|
||||
self.env = env
|
||||
assert isinstance(env, Env)
|
||||
|
||||
self._action_space: spaces.Space[WrapperActType] | None = None
|
||||
self._observation_space: spaces.Space[WrapperObsType] | None = None
|
||||
self._reward_range: tuple[SupportsFloat, SupportsFloat] | None = None
|
||||
self._metadata: dict[str, Any] | None = None
|
||||
|
||||
self._cached_spec: EnvSpec | None = None
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
"""Returns an attribute with ``name``, unless ``name`` starts with an underscore.
|
||||
def step(
|
||||
self, action: WrapperActType
|
||||
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
|
||||
"""Uses the :meth:`step` of the :attr:`env` that can be overwritten to change the returned data."""
|
||||
return self.env.step(action)
|
||||
|
||||
Args:
|
||||
name: The variable name
|
||||
def reset(
|
||||
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
||||
) -> tuple[WrapperObsType, dict[str, Any]]:
|
||||
"""Uses the :meth:`reset` of the :attr:`env` that can be overwritten to change the returned data."""
|
||||
return self.env.reset(seed=seed, options=options)
|
||||
|
||||
Returns:
|
||||
The value of the variable in the wrapper stack
|
||||
def render(self) -> RenderFrame | list[RenderFrame] | None:
|
||||
"""Uses the :meth:`render` of the :attr:`env` that can be overwritten to change the returned data."""
|
||||
return self.env.render()
|
||||
|
||||
Warnings:
|
||||
This feature is deprecated and removed in v1.0 and replaced with `env.get_attr(name})`
|
||||
def close(self):
|
||||
"""Closes the wrapper and :attr:`env`."""
|
||||
return self.env.close()
|
||||
|
||||
@property
|
||||
def unwrapped(self) -> Env[ObsType, ActType]:
|
||||
"""Returns the base environment of the wrapper.
|
||||
|
||||
This will be the bare :class:`gymnasium.Env` environment, underneath all layers of wrappers.
|
||||
"""
|
||||
if name == "_np_random":
|
||||
raise AttributeError(
|
||||
"Can't access `_np_random` of a wrapper, use `self.unwrapped._np_random` or `self.np_random`."
|
||||
)
|
||||
elif name.startswith("_"):
|
||||
raise AttributeError(f"accessing private attribute '{name}' is prohibited")
|
||||
logger.warn(
|
||||
f"env.{name} to get variables from other wrappers is deprecated and will be removed in v1.0, "
|
||||
f"to get this variable you can do `env.unwrapped.{name}` for environment variables or `env.get_attr('{name}')` that will search the remaining wrappers."
|
||||
)
|
||||
return getattr(self.env, name)
|
||||
|
||||
def get_wrapper_attr(self, name: str) -> Any:
|
||||
"""Gets an attribute from the wrapper and lower environments if `name` doesn't exist in this object.
|
||||
|
||||
Args:
|
||||
name: The variable name to get
|
||||
|
||||
Returns:
|
||||
The variable with name in wrapper or lower environments
|
||||
"""
|
||||
if hasattr(self, name):
|
||||
return getattr(self, name)
|
||||
else:
|
||||
try:
|
||||
return self.env.get_wrapper_attr(name)
|
||||
except AttributeError as e:
|
||||
raise AttributeError(
|
||||
f"wrapper {self.class_name()} has no attribute {name!r}"
|
||||
) from e
|
||||
return self.env.unwrapped
|
||||
|
||||
@property
|
||||
def spec(self) -> EnvSpec | None:
|
||||
@@ -362,6 +348,53 @@ class Wrapper(
|
||||
kwargs=kwargs,
|
||||
)
|
||||
|
||||
def get_wrapper_attr(self, name: str) -> Any:
|
||||
"""Gets an attribute from the wrapper and lower environments if `name` doesn't exist in this object.
|
||||
|
||||
Args:
|
||||
name: The variable name to get
|
||||
|
||||
Returns:
|
||||
The variable with name in wrapper or lower environments
|
||||
"""
|
||||
if hasattr(self, name):
|
||||
return getattr(self, name)
|
||||
else:
|
||||
try:
|
||||
return self.env.get_wrapper_attr(name)
|
||||
except AttributeError as e:
|
||||
raise AttributeError(
|
||||
f"wrapper {self.class_name()} has no attribute {name!r}"
|
||||
) from e
|
||||
|
||||
def set_wrapper_attr(self, name: str, value: Any):
|
||||
"""Sets an attribute on this wrapper or lower environment if `name` is already defined.
|
||||
|
||||
Args:
|
||||
name: The variable name
|
||||
value: The new variable value
|
||||
"""
|
||||
sub_env = self.env
|
||||
attr_set = False
|
||||
|
||||
while attr_set is False and isinstance(sub_env, Wrapper):
|
||||
if hasattr(sub_env, name):
|
||||
setattr(sub_env, name, value)
|
||||
attr_set = True
|
||||
else:
|
||||
sub_env = sub_env.env
|
||||
|
||||
if attr_set is False:
|
||||
setattr(sub_env, name, value)
|
||||
|
||||
def __str__(self):
|
||||
"""Returns the wrapper name and the :attr:`env` representation string."""
|
||||
return f"<{type(self).__name__}{self.env}>"
|
||||
|
||||
def __repr__(self):
|
||||
"""Returns the string representation of the wrapper."""
|
||||
return str(self)
|
||||
|
||||
@classmethod
|
||||
def class_name(cls) -> str:
|
||||
"""Returns the class name of the wrapper."""
|
||||
@@ -393,18 +426,6 @@ class Wrapper(
|
||||
def observation_space(self, space: spaces.Space[WrapperObsType]):
|
||||
self._observation_space = space
|
||||
|
||||
@property
|
||||
def reward_range(self) -> tuple[SupportsFloat, SupportsFloat]:
|
||||
"""Return the :attr:`Env` :attr:`reward_range` unless overwritten then the wrapper :attr:`reward_range` is used."""
|
||||
if self._reward_range is None:
|
||||
return self.env.reward_range
|
||||
logger.warn("The `reward_range` is deprecated and will be removed in v1.0")
|
||||
return self._reward_range
|
||||
|
||||
@reward_range.setter
|
||||
def reward_range(self, value: tuple[SupportsFloat, SupportsFloat]):
|
||||
self._reward_range = value
|
||||
|
||||
@property
|
||||
def metadata(self) -> dict[str, Any]:
|
||||
"""Returns the :attr:`Env` :attr:`metadata`."""
|
||||
@@ -440,54 +461,15 @@ class Wrapper(
|
||||
"Can't access `_np_random` of a wrapper, use `.unwrapped._np_random` or `.np_random`."
|
||||
)
|
||||
|
||||
def step(
|
||||
self, action: WrapperActType
|
||||
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
|
||||
"""Uses the :meth:`step` of the :attr:`env` that can be overwritten to change the returned data."""
|
||||
return self.env.step(action)
|
||||
|
||||
def reset(
|
||||
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
||||
) -> tuple[WrapperObsType, dict[str, Any]]:
|
||||
"""Uses the :meth:`reset` of the :attr:`env` that can be overwritten to change the returned data."""
|
||||
return self.env.reset(seed=seed, options=options)
|
||||
|
||||
def render(self) -> RenderFrame | list[RenderFrame] | None:
|
||||
"""Uses the :meth:`render` of the :attr:`env` that can be overwritten to change the returned data."""
|
||||
return self.env.render()
|
||||
|
||||
def close(self):
|
||||
"""Closes the wrapper and :attr:`env`."""
|
||||
return self.env.close()
|
||||
|
||||
def __str__(self):
|
||||
"""Returns the wrapper name and the :attr:`env` representation string."""
|
||||
return f"<{type(self).__name__}{self.env}>"
|
||||
|
||||
def __repr__(self):
|
||||
"""Returns the string representation of the wrapper."""
|
||||
return str(self)
|
||||
|
||||
@property
|
||||
def unwrapped(self) -> Env[ObsType, ActType]:
|
||||
"""Returns the base environment of the wrapper.
|
||||
|
||||
This will be the bare :class:`gymnasium.Env` environment, underneath all layers of wrappers.
|
||||
"""
|
||||
return self.env.unwrapped
|
||||
|
||||
|
||||
class ObservationWrapper(Wrapper[WrapperObsType, ActType, ObsType, ActType]):
|
||||
"""Superclass of wrappers that can modify observations using :meth:`observation` for :meth:`reset` and :meth:`step`.
|
||||
"""Modify observations from :meth:`Env.reset` and :meth:`Env.step` using :meth:`observation` function.
|
||||
|
||||
If you would like to apply a function to only the observation before
|
||||
passing it to the learning code, you can simply inherit from :class:`ObservationWrapper` and overwrite the method
|
||||
:meth:`observation` to implement that transformation. The transformation defined in that method must be
|
||||
reflected by the :attr:`env` observation space. Otherwise, you need to specify the new observation space of the
|
||||
wrapper by setting :attr:`self.observation_space` in the :meth:`__init__` method of your wrapper.
|
||||
|
||||
Among others, Gymnasium provides the observation wrapper :class:`TimeAwareObservation`, which adds information about the
|
||||
index of the timestep to the observation.
|
||||
"""
|
||||
|
||||
def __init__(self, env: Env[ObsType, ActType]):
|
||||
|
@@ -1,14 +1,7 @@
|
||||
"""Registers the internal gym envs then loads the env plugins for module using the entry point."""
|
||||
from typing import Any
|
||||
|
||||
from gymnasium.envs.registration import (
|
||||
load_plugin_envs,
|
||||
make,
|
||||
pprint_registry,
|
||||
register,
|
||||
registry,
|
||||
spec,
|
||||
)
|
||||
from gymnasium.envs.registration import make, pprint_registry, register, registry, spec
|
||||
|
||||
|
||||
# Classic
|
||||
@@ -459,7 +452,3 @@ def _raise_shimmy_error(*args: Any, **kwargs: Any):
|
||||
# When installed, shimmy will re-register these environments with the correct entry_point
|
||||
register(id="GymV21Environment-v0", entry_point=_raise_shimmy_error)
|
||||
register(id="GymV26Environment-v0", entry_point=_raise_shimmy_error)
|
||||
|
||||
|
||||
# Hook to load plugins from entry points
|
||||
load_plugin_envs()
|
||||
|
@@ -840,7 +840,7 @@ def heuristic(env, s):
|
||||
-(s[3]) * 0.5
|
||||
) # override to reduce fall speed, that's all we need after contact
|
||||
|
||||
if env.continuous:
|
||||
if env.unwrapped.continuous:
|
||||
a = np.array([hover_todo * 20 - 1, -angle_todo * 20])
|
||||
a = np.clip(a, -1, +1)
|
||||
else:
|
||||
|
@@ -12,7 +12,7 @@ import gymnasium as gym
|
||||
from gymnasium import logger, spaces
|
||||
from gymnasium.envs.classic_control import utils
|
||||
from gymnasium.error import DependencyNotInstalled
|
||||
from gymnasium.experimental.vector import VectorEnv
|
||||
from gymnasium.vector import VectorEnv
|
||||
from gymnasium.vector.utils import batch_space
|
||||
|
||||
|
||||
@@ -74,13 +74,29 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
|
||||
|
||||
## Arguments
|
||||
|
||||
```python
|
||||
import gymnasium as gym
|
||||
gym.make('CartPole-v1')
|
||||
```
|
||||
Cartpole only has ``render_mode`` as a keyword for ``gymnasium.make``.
|
||||
On reset, the `options` parameter allows the user to change the bounds used to determine the new random state.
|
||||
|
||||
On reset, the `options` parameter allows the user to change the bounds used to determine
|
||||
the new random state.
|
||||
Examples:
|
||||
>>> import gymnasium as gym
|
||||
>>> env = gym.make("CartPole-v1", render_mode="rgb_array")
|
||||
>>> env
|
||||
<TimeLimit<OrderEnforcing<PassiveEnvChecker<CartPoleEnv<CartPole-v1>>>>>
|
||||
>>> env.reset(seed=123, options={"low": 0, "high": 1})
|
||||
(array([0.6823519 , 0.05382102, 0.22035988, 0.18437181], dtype=float32), {})
|
||||
|
||||
## Vectorized environment
|
||||
|
||||
To increase steps per seconds, users can use a custom vector environment or with an environment vectorizor.
|
||||
|
||||
Examples:
|
||||
>>> import gymnasium as gym
|
||||
>>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="vector_entry_point")
|
||||
>>> envs
|
||||
CartPoleVectorEnv(CartPole-v1, num_envs=3)
|
||||
>>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
|
||||
>>> envs
|
||||
SyncVectorEnv(CartPole-v1, num_envs=3)
|
||||
"""
|
||||
|
||||
metadata = {
|
||||
@@ -328,8 +344,10 @@ class CartPoleVectorEnv(VectorEnv):
|
||||
max_episode_steps: int = 500,
|
||||
render_mode: Optional[str] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_envs = num_envs
|
||||
self.max_episode_steps = max_episode_steps
|
||||
self.render_mode = render_mode
|
||||
|
||||
self.gravity = 9.8
|
||||
self.masscart = 1.0
|
||||
self.masspole = 0.1
|
||||
@@ -339,7 +357,6 @@ class CartPoleVectorEnv(VectorEnv):
|
||||
self.force_mag = 10.0
|
||||
self.tau = 0.02 # seconds between state updates
|
||||
self.kinematics_integrator = "euler"
|
||||
self.max_episode_steps = max_episode_steps
|
||||
|
||||
self.steps = np.zeros(num_envs, dtype=np.int32)
|
||||
|
||||
@@ -367,8 +384,6 @@ class CartPoleVectorEnv(VectorEnv):
|
||||
self.single_observation_space = spaces.Box(-high, high, dtype=np.float32)
|
||||
self.observation_space = batch_space(self.single_observation_space, num_envs)
|
||||
|
||||
self.render_mode = render_mode
|
||||
|
||||
self.screen_width = 600
|
||||
self.screen_height = 400
|
||||
self.screens = None
|
||||
@@ -464,6 +479,7 @@ class CartPoleVectorEnv(VectorEnv):
|
||||
|
||||
def render(self):
|
||||
if self.render_mode is None:
|
||||
assert self.spec is not None
|
||||
gym.logger.warn(
|
||||
"You are calling render method without specifying any render mode. "
|
||||
"You can specify the render_mode at initialization, "
|
||||
|
@@ -10,10 +10,10 @@ import numpy as np
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.envs.registration import EnvSpec
|
||||
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
|
||||
from gymnasium.experimental.wrappers.jax_to_numpy import jax_to_numpy
|
||||
from gymnasium.functional import ActType, FuncEnv, StateType
|
||||
from gymnasium.utils import seeding
|
||||
from gymnasium.vector.utils import batch_space
|
||||
from gymnasium.wrappers.jax_to_numpy import jax_to_numpy
|
||||
|
||||
|
||||
class FunctionalJaxEnv(gym.Env):
|
||||
@@ -89,7 +89,7 @@ class FunctionalJaxEnv(gym.Env):
|
||||
observation = self.func_env.observation(next_state)
|
||||
reward = self.func_env.reward(self.state, action, next_state)
|
||||
terminated = self.func_env.terminal(next_state)
|
||||
info = self.func_env.step_info(self.state, action, next_state)
|
||||
info = self.func_env.transition_info(self.state, action, next_state)
|
||||
self.state = next_state
|
||||
|
||||
observation = jax_to_numpy(observation)
|
||||
@@ -113,7 +113,7 @@ class FunctionalJaxEnv(gym.Env):
|
||||
self.render_state = None
|
||||
|
||||
|
||||
class FunctionalJaxVectorEnv(gym.experimental.vector.VectorEnv):
|
||||
class FunctionalJaxVectorEnv(gym.vector.VectorEnv):
|
||||
"""A vector env implementation for functional Jax envs."""
|
||||
|
||||
state: StateType
|
||||
@@ -211,7 +211,7 @@ class FunctionalJaxVectorEnv(gym.experimental.vector.VectorEnv):
|
||||
else jnp.zeros_like(terminated)
|
||||
)
|
||||
|
||||
info = self.func_env.step_info(self.state, action, next_state)
|
||||
info = self.func_env.transition_info(self.state, action, next_state)
|
||||
|
||||
done = jnp.logical_or(terminated, truncated)
|
||||
if jnp.any(done):
|
@@ -9,12 +9,9 @@ import numpy as np
|
||||
from jax.random import PRNGKey
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.envs.functional_jax_env import FunctionalJaxEnv, FunctionalJaxVectorEnv
|
||||
from gymnasium.error import DependencyNotInstalled
|
||||
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
|
||||
from gymnasium.experimental.functional_jax_env import (
|
||||
FunctionalJaxEnv,
|
||||
FunctionalJaxVectorEnv,
|
||||
)
|
||||
from gymnasium.functional import ActType, FuncEnv, StateType
|
||||
from gymnasium.utils import EzPickle
|
||||
|
||||
|
||||
|
@@ -10,12 +10,9 @@ import numpy as np
|
||||
from jax.random import PRNGKey
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.envs.functional_jax_env import FunctionalJaxEnv, FunctionalJaxVectorEnv
|
||||
from gymnasium.error import DependencyNotInstalled
|
||||
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
|
||||
from gymnasium.experimental.functional_jax_env import (
|
||||
FunctionalJaxEnv,
|
||||
FunctionalJaxVectorEnv,
|
||||
)
|
||||
from gymnasium.functional import ActType, FuncEnv, StateType
|
||||
from gymnasium.utils import EzPickle
|
||||
|
||||
|
||||
|
@@ -10,7 +10,6 @@ import importlib.util
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
import traceback
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from types import ModuleType
|
||||
@@ -86,11 +85,13 @@ class EnvSpec:
|
||||
* **nondeterministic**: If the observation of an environment cannot be repeated with the same initial state, random number generator state and actions.
|
||||
* **max_episode_steps**: The max number of steps that the environment can take before truncation
|
||||
* **order_enforce**: If to enforce the order of :meth:`gymnasium.Env.reset` before :meth:`gymnasium.Env.step` and :meth:`gymnasium.Env.render` functions
|
||||
* **autoreset**: If to automatically reset the environment on episode end
|
||||
* **disable_env_checker**: If to disable the environment checker wrapper in :meth:`gymnasium.make`, by default False (runs the environment checker)
|
||||
* **kwargs**: Additional keyword arguments passed to the environment during initialisation
|
||||
* **additional_wrappers**: A tuple of additional wrappers applied to the environment (WrapperSpec)
|
||||
* **vector_entry_point**: The location of the vectorized environment to create from
|
||||
|
||||
Changelogs:
|
||||
v1.0.0 - Autoreset attribute removed
|
||||
"""
|
||||
|
||||
id: str
|
||||
@@ -103,9 +104,7 @@ class EnvSpec:
|
||||
# Wrappers
|
||||
max_episode_steps: int | None = field(default=None)
|
||||
order_enforce: bool = field(default=True)
|
||||
autoreset: bool = field(default=False)
|
||||
disable_env_checker: bool = field(default=False)
|
||||
apply_api_compatibility: bool = field(default=False)
|
||||
|
||||
# Environment arguments
|
||||
kwargs: dict = field(default_factory=dict)
|
||||
@@ -224,12 +223,8 @@ class EnvSpec:
|
||||
output += f"\nmax_episode_steps={self.max_episode_steps}"
|
||||
if print_all or self.order_enforce is not True:
|
||||
output += f"\norder_enforce={self.order_enforce}"
|
||||
if print_all or self.autoreset is not False:
|
||||
output += f"\nautoreset={self.autoreset}"
|
||||
if print_all or self.disable_env_checker is not False:
|
||||
output += f"\ndisable_env_checker={self.disable_env_checker}"
|
||||
if print_all or self.apply_api_compatibility is not False:
|
||||
output += f"\napplied_api_compatibility={self.apply_api_compatibility}"
|
||||
|
||||
if print_all or self.additional_wrappers:
|
||||
wrapper_output: list[str] = []
|
||||
@@ -547,55 +542,6 @@ def load_env_creator(name: str) -> EnvCreator | VectorEnvCreator:
|
||||
return fn
|
||||
|
||||
|
||||
def load_plugin_envs(entry_point: str = "gymnasium.envs"):
|
||||
"""Load modules (plugins) using the gymnasium entry points in order to register external module's environments on ``import gymnasium``.
|
||||
|
||||
Args:
|
||||
entry_point: The string for the entry point.
|
||||
"""
|
||||
# Load third-party environments
|
||||
for plugin in metadata.entry_points(group=entry_point):
|
||||
# Python 3.8 doesn't support plugin.module, plugin.attr
|
||||
# So we'll have to try and parse this ourselves
|
||||
module, attr = None, None
|
||||
try:
|
||||
module, attr = plugin.module, plugin.attr # type: ignore ## error: Cannot access member "attr" for type "EntryPoint"
|
||||
except AttributeError:
|
||||
if ":" in plugin.value:
|
||||
module, attr = plugin.value.split(":", maxsplit=1)
|
||||
else:
|
||||
module, attr = plugin.value, None
|
||||
except Exception as e:
|
||||
logger.warn(
|
||||
f"While trying to load plugin `{plugin}` from {entry_point}, an exception occurred: {e}"
|
||||
)
|
||||
module, attr = None, None
|
||||
finally:
|
||||
if attr is None:
|
||||
raise error.Error(
|
||||
f"Gymnasium environment plugin `{module}` must specify a function to execute, not a root module"
|
||||
)
|
||||
|
||||
context = namespace(plugin.name)
|
||||
if plugin.name.startswith("__") and plugin.name.endswith("__"):
|
||||
# `__internal__` is an artifact of the plugin system when the root namespace had an allow-list.
|
||||
# The allow-list is now removed and plugins can register environments in the root namespace with the `__root__` magic key.
|
||||
if plugin.name == "__root__" or plugin.name == "__internal__":
|
||||
context = contextlib.nullcontext()
|
||||
else:
|
||||
logger.warn(
|
||||
f"The environment namespace magic key `{plugin.name}` is unsupported. "
|
||||
"To register an environment at the root namespace you should specify the `__root__` namespace."
|
||||
)
|
||||
|
||||
with context:
|
||||
fn = plugin.load()
|
||||
try:
|
||||
fn()
|
||||
except Exception:
|
||||
logger.warn(f"plugin: {plugin.value} raised {traceback.format_exc()}")
|
||||
|
||||
|
||||
def register_envs(env_module: ModuleType):
|
||||
"""A No-op function such that it can appear to IDEs that a module is used."""
|
||||
pass
|
||||
@@ -618,9 +564,7 @@ def register(
|
||||
nondeterministic: bool = False,
|
||||
max_episode_steps: int | None = None,
|
||||
order_enforce: bool = True,
|
||||
autoreset: bool = False,
|
||||
disable_env_checker: bool = False,
|
||||
apply_api_compatibility: bool = False,
|
||||
additional_wrappers: tuple[WrapperSpec, ...] = (),
|
||||
vector_entry_point: VectorEnvCreator | str | None = None,
|
||||
**kwargs: Any,
|
||||
@@ -640,13 +584,13 @@ def register(
|
||||
max_episode_steps: The maximum number of episodes steps before truncation. Used by the :class:`gymnasium.wrappers.TimeLimit` wrapper if not ``None``.
|
||||
order_enforce: If to enable the order enforcer wrapper to ensure users run functions in the correct order.
|
||||
If ``True``, then the :class:`gymnasium.wrappers.OrderEnforcing` is applied to the environment.
|
||||
autoreset: If to add the :class:`gymnasium.wrappers.AutoResetWrapper` such that on ``(terminated or truncated) is True``, :meth:`gymnasium.Env.reset` is called.
|
||||
disable_env_checker: If to disable the :class:`gymnasium.wrappers.PassiveEnvChecker` to the environment.
|
||||
apply_api_compatibility: If to apply the :class:`gymnasium.wrappers.StepAPICompatibility` wrapper to the environment.
|
||||
Use if the environment is implemented in the gym v0.21 environment API.
|
||||
additional_wrappers: Additional wrappers to apply the environment.
|
||||
vector_entry_point: The entry point for creating the vector environment
|
||||
**kwargs: arbitrary keyword arguments which are passed to the environment constructor on initialisation.
|
||||
|
||||
Changelogs:
|
||||
v1.0.0 - `autoreset` and `apply_api_compatibility` parameter was removed
|
||||
"""
|
||||
assert (
|
||||
entry_point is not None or vector_entry_point is not None
|
||||
@@ -669,11 +613,6 @@ def register(
|
||||
ns_id = ns
|
||||
full_env_id = get_env_id(ns_id, name, version)
|
||||
|
||||
if autoreset is True:
|
||||
logger.warn(
|
||||
"`gymnasium.register(..., autoreset=True)` is deprecated and will be removed in v1.0. If users wish to use it then add the auto reset wrapper in the `addition_wrappers` argument."
|
||||
)
|
||||
|
||||
new_spec = EnvSpec(
|
||||
id=full_env_id,
|
||||
entry_point=entry_point,
|
||||
@@ -681,9 +620,7 @@ def register(
|
||||
nondeterministic=nondeterministic,
|
||||
max_episode_steps=max_episode_steps,
|
||||
order_enforce=order_enforce,
|
||||
autoreset=autoreset,
|
||||
disable_env_checker=disable_env_checker,
|
||||
apply_api_compatibility=apply_api_compatibility,
|
||||
**kwargs,
|
||||
additional_wrappers=additional_wrappers,
|
||||
vector_entry_point=vector_entry_point,
|
||||
@@ -698,8 +635,6 @@ def register(
|
||||
def make(
|
||||
id: str | EnvSpec,
|
||||
max_episode_steps: int | None = None,
|
||||
autoreset: bool | None = None,
|
||||
apply_api_compatibility: bool | None = None,
|
||||
disable_env_checker: bool | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Env:
|
||||
@@ -710,12 +645,9 @@ def make(
|
||||
Args:
|
||||
id: A string for the environment id or a :class:`EnvSpec`. Optionally if using a string, a module to import can be included, e.g. ``'module:Env-v0'``.
|
||||
This is equivalent to importing the module first to register the environment followed by making the environment.
|
||||
max_episode_steps: Maximum length of an episode, can override the registered :class:`EnvSpec` ``max_episode_steps``.
|
||||
The value is used by :class:`gymnasium.wrappers.TimeLimit`.
|
||||
autoreset: Whether to automatically reset the environment after each episode (:class:`gymnasium.wrappers.AutoResetWrapper`).
|
||||
apply_api_compatibility: Whether to wrap the environment with the :class:`gymnasium.wrappers.StepAPICompatibility` wrapper that
|
||||
converts the environment step from a done bool to return termination and truncation bools.
|
||||
By default, the argument is None in which the :class:`EnvSpec` ``apply_api_compatibility`` is used, otherwise this variable is used in favor.
|
||||
max_episode_steps: Maximum length of an episode, can override the registered :class:`EnvSpec` ``max_episode_steps``
|
||||
with the value being passed to :class:`gymnasium.wrappers.TimeLimit`.
|
||||
Using ``max_episode_steps=-1`` will not apply the wrapper to the environment.
|
||||
disable_env_checker: If to add :class:`gymnasium.wrappers.PassiveEnvChecker`, ``None`` will default to the
|
||||
:class:`EnvSpec` ``disable_env_checker`` value otherwise use this value will be used.
|
||||
kwargs: Additional arguments to pass to the environment constructor.
|
||||
@@ -725,6 +657,9 @@ def make(
|
||||
|
||||
Raises:
|
||||
Error: If the ``id`` doesn't exist in the :attr:`registry`
|
||||
|
||||
Changelogs:
|
||||
v1.0.0 - `autoreset` and `apply_api_compatibility` was removed
|
||||
"""
|
||||
if isinstance(id, EnvSpec):
|
||||
env_spec = id
|
||||
@@ -790,14 +725,6 @@ def make(
|
||||
f"that is not in the possible render_modes ({render_modes})."
|
||||
)
|
||||
|
||||
if apply_api_compatibility or (
|
||||
apply_api_compatibility is None and env_spec.apply_api_compatibility
|
||||
):
|
||||
# If we use the compatibility layer, we treat the render mode explicitly and don't pass it to the env creator
|
||||
render_mode = env_spec_kwargs.pop("render_mode", None)
|
||||
else:
|
||||
render_mode = None
|
||||
|
||||
try:
|
||||
env = env_creator(**env_spec_kwargs)
|
||||
except TypeError as e:
|
||||
@@ -823,9 +750,7 @@ def make(
|
||||
nondeterministic=env_spec.nondeterministic,
|
||||
max_episode_steps=None,
|
||||
order_enforce=False,
|
||||
autoreset=False,
|
||||
disable_env_checker=True,
|
||||
apply_api_compatibility=False,
|
||||
kwargs=env_spec_kwargs,
|
||||
additional_wrappers=(),
|
||||
vector_entry_point=env_spec.vector_entry_point,
|
||||
@@ -845,15 +770,6 @@ def make(
|
||||
f"The environment's wrapper spec {recreated_wrapper_spec} is different from the saved `EnvSpec` additional wrapper {env_spec_wrapper_spec}"
|
||||
)
|
||||
|
||||
# Add step API wrapper
|
||||
if apply_api_compatibility is True or (
|
||||
apply_api_compatibility is None and env_spec.apply_api_compatibility is True
|
||||
):
|
||||
logger.warn(
|
||||
"`gymnasium.make(..., apply_api_compatibility=True)` and `env_spec.apply_api_compatibility` is deprecated and will be removed in v1.0"
|
||||
)
|
||||
env = gym.wrappers.EnvCompatibility(env, render_mode)
|
||||
|
||||
# Run the environment checker as the lowest level wrapper
|
||||
if disable_env_checker is False or (
|
||||
disable_env_checker is None and env_spec.disable_env_checker is False
|
||||
@@ -865,18 +781,11 @@ def make(
|
||||
env = gym.wrappers.OrderEnforcing(env)
|
||||
|
||||
# Add the time limit wrapper
|
||||
if max_episode_steps is not None:
|
||||
env = gym.wrappers.TimeLimit(env, max_episode_steps)
|
||||
elif env_spec.max_episode_steps is not None:
|
||||
env = gym.wrappers.TimeLimit(env, env_spec.max_episode_steps)
|
||||
|
||||
# Add the auto-reset wrapper
|
||||
if autoreset is True or (autoreset is None and env_spec.autoreset is True):
|
||||
env = gym.wrappers.AutoResetWrapper(env)
|
||||
|
||||
logger.warn(
|
||||
"`gymnasium.make(..., autoreset=True)` is deprecated and will be removed in v1.0"
|
||||
)
|
||||
if max_episode_steps != -1:
|
||||
if max_episode_steps is not None:
|
||||
env = gym.wrappers.TimeLimit(env, max_episode_steps)
|
||||
elif env_spec.max_episode_steps is not None:
|
||||
env = gym.wrappers.TimeLimit(env, env_spec.max_episode_steps)
|
||||
|
||||
for wrapper_spec in env_spec.additional_wrappers[num_prior_wrappers:]:
|
||||
if wrapper_spec.kwargs is None:
|
||||
@@ -898,25 +807,25 @@ def make(
|
||||
def make_vec(
|
||||
id: str | EnvSpec,
|
||||
num_envs: int = 1,
|
||||
vectorization_mode: str = "async",
|
||||
vectorization_mode: str | None = None,
|
||||
vector_kwargs: dict[str, Any] | None = None,
|
||||
wrappers: Sequence[Callable[[Env], Wrapper]] | None = None,
|
||||
**kwargs,
|
||||
) -> gym.experimental.vector.VectorEnv:
|
||||
) -> gym.vector.VectorEnv:
|
||||
"""Create a vector environment according to the given ID.
|
||||
|
||||
Note:
|
||||
This feature is experimental, and is likely to change in future releases.
|
||||
|
||||
To find all available environments use `gymnasium.envs.registry.keys()` for all valid ids.
|
||||
To find all available environments use :func:`gymnasium.pprint_registry` or ``gymnasium.registry.keys()`` for all valid ids.
|
||||
We refer to the Vector environment as the vectorizor while the environment being vectorized is the base or vectorized environment (``vectorizor(vectorized env)``).
|
||||
|
||||
Args:
|
||||
id: Name of the environment. Optionally, a module to import can be included, eg. 'module:Env-v0'
|
||||
num_envs: Number of environments to create
|
||||
vectorization_mode: How to vectorize the environment. Can be either "async", "sync" or "custom"
|
||||
vector_kwargs: Additional arguments to pass to the vectorized environment constructor.
|
||||
wrappers: A sequence of wrapper functions to apply to the environment. Can only be used in "sync" or "async" mode.
|
||||
**kwargs: Additional arguments to pass to the environment constructor.
|
||||
vectorization_mode: The vectorization method used, defaults to ``None`` such that if a ``vector_entry_point`` exists,
|
||||
this is first used otherwise defaults to ``sync`` to use the :class:`gymnasium.vector.SyncVectorEnv`.
|
||||
Valid modes are ``"async"``, ``"sync"`` or ``"vector_entry_point"``.
|
||||
vector_kwargs: Additional arguments to pass to the vectorizor environment constructor, i.e., ``SyncVectorEnv(..., **vector_kwargs)``.
|
||||
wrappers: A sequence of wrapper functions to apply to the base environment. Can only be used in ``"sync"`` or ``"async"`` mode.
|
||||
**kwargs: Additional arguments passed to the base environment constructor.
|
||||
|
||||
Returns:
|
||||
An instance of the environment.
|
||||
@@ -926,87 +835,93 @@ def make_vec(
|
||||
"""
|
||||
if vector_kwargs is None:
|
||||
vector_kwargs = {}
|
||||
|
||||
if wrappers is None:
|
||||
wrappers = []
|
||||
|
||||
if isinstance(id, EnvSpec):
|
||||
spec_ = id
|
||||
id_env_spec = id
|
||||
env_spec_kwargs = id_env_spec.kwargs.copy()
|
||||
|
||||
num_envs = env_spec_kwargs.pop("num_envs", num_envs)
|
||||
vectorization_mode = env_spec_kwargs.pop(
|
||||
"vectorization_mode", vectorization_mode
|
||||
)
|
||||
vector_kwargs = env_spec_kwargs.pop("vector_kwargs", vector_kwargs)
|
||||
wrappers = env_spec_kwargs.pop("wrappers", wrappers)
|
||||
else:
|
||||
spec_ = _find_spec(id)
|
||||
id_env_spec = _find_spec(id)
|
||||
env_spec_kwargs = id_env_spec.kwargs.copy()
|
||||
|
||||
_kwargs = spec_.kwargs.copy()
|
||||
_kwargs.update(kwargs)
|
||||
env_spec_kwargs.update(kwargs)
|
||||
|
||||
# Check if we have the necessary entry point
|
||||
if vectorization_mode in ("sync", "async"):
|
||||
if spec_.entry_point is None:
|
||||
raise error.Error(
|
||||
f"Cannot create vectorized environment for {id} because it doesn't have an entry point defined."
|
||||
)
|
||||
entry_point = spec_.entry_point
|
||||
elif vectorization_mode in ("custom",):
|
||||
if spec_.vector_entry_point is None:
|
||||
raise error.Error(
|
||||
f"Cannot create vectorized environment for {id} because it doesn't have a vector entry point defined."
|
||||
)
|
||||
entry_point = spec_.vector_entry_point
|
||||
else:
|
||||
raise error.Error(f"Invalid vectorization mode: {vectorization_mode}")
|
||||
# Update the vectorization_mode if None
|
||||
if vectorization_mode is None:
|
||||
if id_env_spec.vector_entry_point is not None:
|
||||
vectorization_mode = "vector_entry_point"
|
||||
else:
|
||||
vectorization_mode = "sync"
|
||||
|
||||
if callable(entry_point):
|
||||
env_creator = entry_point
|
||||
else:
|
||||
# Assume it's a string
|
||||
env_creator = load_env_creator(entry_point)
|
||||
|
||||
def _create_env():
|
||||
# Env creator for use with sync and async modes
|
||||
_kwargs_copy = _kwargs.copy()
|
||||
|
||||
render_mode = _kwargs.get("render_mode", None)
|
||||
if render_mode is not None:
|
||||
inner_render_mode = (
|
||||
render_mode[: -len("_list")]
|
||||
if render_mode.endswith("_list")
|
||||
else render_mode
|
||||
)
|
||||
_kwargs_copy["render_mode"] = inner_render_mode
|
||||
|
||||
_env = env_creator(**_kwargs_copy)
|
||||
_env.spec = spec_
|
||||
if spec_.max_episode_steps is not None:
|
||||
_env = gym.wrappers.TimeLimit(_env, spec_.max_episode_steps)
|
||||
|
||||
if render_mode is not None and render_mode.endswith("_list"):
|
||||
_env = gym.wrappers.RenderCollection(_env)
|
||||
def create_single_env() -> Env:
|
||||
single_env = make(id_env_spec.id, **env_spec_kwargs.copy())
|
||||
|
||||
for wrapper in wrappers:
|
||||
_env = wrapper(_env)
|
||||
return _env
|
||||
single_env = wrapper(single_env)
|
||||
return single_env
|
||||
|
||||
if vectorization_mode == "sync":
|
||||
env = gym.experimental.vector.SyncVectorEnv(
|
||||
env_fns=[_create_env for _ in range(num_envs)],
|
||||
if id_env_spec.entry_point is None:
|
||||
raise error.Error(
|
||||
f"Cannot create vectorized environment for {id_env_spec.id} because it doesn't have an entry point defined."
|
||||
)
|
||||
|
||||
env = gym.vector.SyncVectorEnv(
|
||||
env_fns=(create_single_env for _ in range(num_envs)),
|
||||
**vector_kwargs,
|
||||
)
|
||||
elif vectorization_mode == "async":
|
||||
env = gym.experimental.vector.AsyncVectorEnv(
|
||||
env_fns=[_create_env for _ in range(num_envs)],
|
||||
if id_env_spec.entry_point is None:
|
||||
raise error.Error(
|
||||
f"Cannot create vectorized environment for {id_env_spec.id} because it doesn't have an entry point defined."
|
||||
)
|
||||
|
||||
env = gym.vector.AsyncVectorEnv(
|
||||
env_fns=[create_single_env for _ in range(num_envs)],
|
||||
**vector_kwargs,
|
||||
)
|
||||
elif vectorization_mode == "custom":
|
||||
elif vectorization_mode == "vector_entry_point":
|
||||
entry_point = id_env_spec.vector_entry_point
|
||||
if entry_point is None:
|
||||
raise error.Error(
|
||||
f"Cannot create vectorized environment for {id} because it doesn't have a vector entry point defined."
|
||||
)
|
||||
elif callable(entry_point):
|
||||
env_creator = entry_point
|
||||
else: # Assume it's a string
|
||||
env_creator = load_env_creator(entry_point)
|
||||
|
||||
if len(wrappers) > 0:
|
||||
raise error.Error("Cannot use custom vectorization mode with wrappers.")
|
||||
vector_kwargs["max_episode_steps"] = spec_.max_episode_steps
|
||||
raise error.Error(
|
||||
"Cannot use `vector_entry_point` vectorization mode with the wrappers argument."
|
||||
)
|
||||
if "max_episode_steps" not in vector_kwargs:
|
||||
vector_kwargs["max_episode_steps"] = id_env_spec.max_episode_steps
|
||||
|
||||
env = env_creator(num_envs=num_envs, **vector_kwargs)
|
||||
else:
|
||||
raise error.Error(f"Invalid vectorization mode: {vectorization_mode}")
|
||||
|
||||
# Copies the environment creation specification and kwargs to add to the environment specification details
|
||||
spec_ = copy.deepcopy(spec_)
|
||||
spec_.kwargs = _kwargs
|
||||
env.unwrapped.spec = spec_
|
||||
copied_id_spec = copy.deepcopy(id_env_spec)
|
||||
copied_id_spec.kwargs = env_spec_kwargs
|
||||
if num_envs != 1:
|
||||
copied_id_spec.kwargs["num_envs"] = num_envs
|
||||
if vectorization_mode != "async":
|
||||
copied_id_spec.kwargs["vectorization_mode"] = vectorization_mode
|
||||
if vector_kwargs is not None:
|
||||
copied_id_spec.kwargs["vector_kwargs"] = vector_kwargs
|
||||
if wrappers is not None:
|
||||
copied_id_spec.kwargs["wrappers"] = wrappers
|
||||
env.unwrapped.spec = copied_id_spec
|
||||
|
||||
return env
|
||||
|
||||
|
@@ -12,9 +12,9 @@ from jax import random
|
||||
from jax.random import PRNGKey
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.envs.functional_jax_env import FunctionalJaxEnv
|
||||
from gymnasium.error import DependencyNotInstalled
|
||||
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
|
||||
from gymnasium.experimental.functional_jax_env import FunctionalJaxEnv
|
||||
from gymnasium.functional import ActType, FuncEnv, StateType
|
||||
from gymnasium.utils import EzPickle, seeding
|
||||
from gymnasium.wrappers import HumanRendering
|
||||
|
||||
|
@@ -12,9 +12,9 @@ import numpy as np
|
||||
from jax.random import PRNGKey
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.envs.functional_jax_env import FunctionalJaxEnv
|
||||
from gymnasium.error import DependencyNotInstalled
|
||||
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
|
||||
from gymnasium.experimental.functional_jax_env import FunctionalJaxEnv
|
||||
from gymnasium.functional import ActType, FuncEnv, StateType
|
||||
from gymnasium.utils import EzPickle
|
||||
from gymnasium.wrappers import HumanRendering
|
||||
|
||||
|
@@ -1,25 +0,0 @@
|
||||
"""Root __init__ of the gym experimental wrappers."""
|
||||
|
||||
|
||||
from gymnasium.experimental import functional, vector, wrappers
|
||||
|
||||
|
||||
# from gymnasium.experimental.functional import FuncEnv
|
||||
# from gymnasium.experimental.vector.async_vector_env import AsyncVectorEnv
|
||||
# from gymnasium.experimental.vector.sync_vector_env import SyncVectorEnv
|
||||
# from gymnasium.experimental.vector.vector_env import VectorEnv, VectorWrapper
|
||||
|
||||
|
||||
__all__ = [
|
||||
# Functional
|
||||
# "FuncEnv",
|
||||
"functional",
|
||||
# Vector
|
||||
# "VectorEnv",
|
||||
# "VectorWrapper",
|
||||
# "SyncVectorEnv",
|
||||
# "AsyncVectorEnv",
|
||||
# wrappers
|
||||
"wrappers",
|
||||
"vector",
|
||||
]
|
@@ -1,23 +0,0 @@
|
||||
"""Experimental vector env API."""
|
||||
from gymnasium.experimental.vector import utils
|
||||
from gymnasium.experimental.vector.async_vector_env import AsyncVectorEnv
|
||||
from gymnasium.experimental.vector.sync_vector_env import SyncVectorEnv
|
||||
from gymnasium.experimental.vector.vector_env import (
|
||||
VectorActionWrapper,
|
||||
VectorEnv,
|
||||
VectorObservationWrapper,
|
||||
VectorRewardWrapper,
|
||||
VectorWrapper,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"VectorEnv",
|
||||
"VectorWrapper",
|
||||
"VectorObservationWrapper",
|
||||
"VectorActionWrapper",
|
||||
"VectorRewardWrapper",
|
||||
"SyncVectorEnv",
|
||||
"AsyncVectorEnv",
|
||||
"utils",
|
||||
]
|
@@ -1,685 +0,0 @@
|
||||
"""An async vector environment."""
|
||||
from __future__ import annotations
|
||||
|
||||
import multiprocessing
|
||||
import sys
|
||||
import time
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from multiprocessing import Queue
|
||||
from multiprocessing.connection import Connection
|
||||
from typing import Any, Callable, Sequence
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gymnasium import logger
|
||||
from gymnasium.core import Env, ObsType
|
||||
from gymnasium.error import (
|
||||
AlreadyPendingCallError,
|
||||
ClosedEnvironmentError,
|
||||
CustomSpaceError,
|
||||
NoAsyncCallError,
|
||||
)
|
||||
from gymnasium.experimental.vector.utils import (
|
||||
CloudpickleWrapper,
|
||||
batch_space,
|
||||
clear_mpi_env_vars,
|
||||
concatenate,
|
||||
create_empty_array,
|
||||
create_shared_memory,
|
||||
iterate,
|
||||
read_from_shared_memory,
|
||||
write_to_shared_memory,
|
||||
)
|
||||
from gymnasium.experimental.vector.vector_env import VectorEnv
|
||||
|
||||
|
||||
__all__ = ["AsyncVectorEnv"]
|
||||
|
||||
|
||||
class AsyncState(Enum):
|
||||
DEFAULT = "default"
|
||||
WAITING_RESET = "reset"
|
||||
WAITING_STEP = "step"
|
||||
WAITING_CALL = "call"
|
||||
|
||||
|
||||
class AsyncVectorEnv(VectorEnv):
|
||||
"""Vectorized environment that runs multiple environments in parallel.
|
||||
|
||||
It uses ``multiprocessing`` processes, and pipes for communication.
|
||||
|
||||
Example:
|
||||
>>> import gymnasium as gym
|
||||
>>> env = gym.vector.AsyncVectorEnv([
|
||||
... lambda: gym.make("Pendulum-v1", g=9.81),
|
||||
... lambda: gym.make("Pendulum-v1", g=1.62)
|
||||
... ])
|
||||
>>> env.reset(seed=42)
|
||||
(array([[-0.14995256, 0.9886932 , -0.12224312],
|
||||
[ 0.5760367 , 0.8174238 , -0.91244936]], dtype=float32), {})
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env_fns: Sequence[Callable[[], Env]],
|
||||
shared_memory: bool = True,
|
||||
copy: bool = True,
|
||||
context: str | None = None,
|
||||
daemon: bool = True,
|
||||
worker: callable | None = None,
|
||||
):
|
||||
"""Vectorized environment that runs multiple environments in parallel.
|
||||
|
||||
Args:
|
||||
env_fns: Functions that create the environments.
|
||||
shared_memory: If ``True``, then the observations from the worker processes are communicated back through
|
||||
shared variables. This can improve the efficiency if the observations are large (e.g. images).
|
||||
copy: If ``True``, then the :meth:`~AsyncVectorEnv.reset` and :meth:`~AsyncVectorEnv.step` methods
|
||||
return a copy of the observations.
|
||||
context: Context for `multiprocessing`_. If ``None``, then the default context is used.
|
||||
daemon: If ``True``, then subprocesses have ``daemon`` flag turned on; that is, they will quit if
|
||||
the head process quits. However, ``daemon=True`` prevents subprocesses to spawn children,
|
||||
so for some environments you may want to have it set to ``False``.
|
||||
worker: If set, then use that worker in a subprocess instead of a default one.
|
||||
Can be useful to override some inner vector env logic, for instance, how resets on termination or truncation are handled.
|
||||
|
||||
Warnings:
|
||||
worker is an advanced mode option. It provides a high degree of flexibility and a high chance
|
||||
to shoot yourself in the foot; thus, if you are writing your own worker, it is recommended to start
|
||||
from the code for ``_worker`` (or ``_worker_shared_memory``) method, and add changes.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the observation space of some sub-environment does not match observation_space
|
||||
(or, by default, the observation space of the first sub-environment).
|
||||
ValueError: If observation_space is a custom space (i.e. not a default space in Gym,
|
||||
such as gymnasium.spaces.Box, gymnasium.spaces.Discrete, or gymnasium.spaces.Dict) and shared_memory is True.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
ctx = multiprocessing.get_context(context)
|
||||
self.env_fns = env_fns
|
||||
self.num_envs = len(env_fns)
|
||||
self.shared_memory = shared_memory
|
||||
self.copy = copy
|
||||
|
||||
# This would be nice to get rid of, but without it there's a deadlock between shared memory and pipes
|
||||
dummy_env = env_fns[0]()
|
||||
self.metadata = dummy_env.metadata
|
||||
|
||||
self.single_observation_space = dummy_env.observation_space
|
||||
self.single_action_space = dummy_env.action_space
|
||||
|
||||
self.observation_space = batch_space(
|
||||
self.single_observation_space, self.num_envs
|
||||
)
|
||||
self.action_space = batch_space(self.single_action_space, self.num_envs)
|
||||
|
||||
dummy_env.close()
|
||||
del dummy_env
|
||||
|
||||
if self.shared_memory:
|
||||
try:
|
||||
_obs_buffer = create_shared_memory(
|
||||
self.single_observation_space, n=self.num_envs, ctx=ctx
|
||||
)
|
||||
self.observations = read_from_shared_memory(
|
||||
self.single_observation_space, _obs_buffer, n=self.num_envs
|
||||
)
|
||||
except CustomSpaceError as e:
|
||||
raise ValueError(
|
||||
"Using `shared_memory=True` in `AsyncVectorEnv` "
|
||||
"is incompatible with non-standard Gymnasium observation spaces "
|
||||
"(i.e. custom spaces inheriting from `gymnasium.Space`), and is "
|
||||
"only compatible with default Gymnasium spaces (e.g. `Box`, "
|
||||
"`Tuple`, `Dict`) for batching. Set `shared_memory=False` "
|
||||
"if you use custom observation spaces."
|
||||
) from e
|
||||
else:
|
||||
_obs_buffer = None
|
||||
self.observations = create_empty_array(
|
||||
self.single_observation_space, n=self.num_envs, fn=np.zeros
|
||||
)
|
||||
|
||||
self.parent_pipes, self.processes = [], []
|
||||
self.error_queue = ctx.Queue()
|
||||
target = worker or _worker
|
||||
with clear_mpi_env_vars():
|
||||
for idx, env_fn in enumerate(self.env_fns):
|
||||
parent_pipe, child_pipe = ctx.Pipe()
|
||||
process = ctx.Process(
|
||||
target=target,
|
||||
name=f"Worker<{type(self).__name__}>-{idx}",
|
||||
args=(
|
||||
idx,
|
||||
CloudpickleWrapper(env_fn),
|
||||
child_pipe,
|
||||
parent_pipe,
|
||||
_obs_buffer,
|
||||
self.error_queue,
|
||||
),
|
||||
)
|
||||
|
||||
self.parent_pipes.append(parent_pipe)
|
||||
self.processes.append(process)
|
||||
|
||||
process.daemon = daemon
|
||||
process.start()
|
||||
child_pipe.close()
|
||||
|
||||
self._state = AsyncState.DEFAULT
|
||||
self._check_spaces()
|
||||
|
||||
def reset_async(
|
||||
self,
|
||||
seed: int | list[int] | None = None,
|
||||
options: dict | None = None,
|
||||
):
|
||||
"""Send calls to the :obj:`reset` methods of the sub-environments.
|
||||
|
||||
To get the results of these calls, you may invoke :meth:`reset_wait`.
|
||||
|
||||
Args:
|
||||
seed: List of seeds for each environment
|
||||
options: The reset option
|
||||
|
||||
Raises:
|
||||
ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
|
||||
AlreadyPendingCallError: If the environment is already waiting for a pending call to another
|
||||
method (e.g. :meth:`step_async`). This can be caused by two consecutive
|
||||
calls to :meth:`reset_async`, with no call to :meth:`reset_wait` in between.
|
||||
"""
|
||||
self._assert_is_running()
|
||||
|
||||
if seed is None:
|
||||
seed = [None for _ in range(self.num_envs)]
|
||||
if isinstance(seed, int):
|
||||
seed = [seed + i for i in range(self.num_envs)]
|
||||
assert len(seed) == self.num_envs
|
||||
|
||||
if self._state != AsyncState.DEFAULT:
|
||||
raise AlreadyPendingCallError(
|
||||
f"Calling `reset_async` while waiting for a pending call to `{self._state.value}` to complete",
|
||||
str(self._state.value),
|
||||
)
|
||||
|
||||
for pipe, single_seed in zip(self.parent_pipes, seed):
|
||||
single_kwargs = {}
|
||||
if single_seed is not None:
|
||||
single_kwargs["seed"] = single_seed
|
||||
if options is not None:
|
||||
single_kwargs["options"] = options
|
||||
|
||||
pipe.send(("reset", single_kwargs))
|
||||
self._state = AsyncState.WAITING_RESET
|
||||
|
||||
def reset_wait(
|
||||
self,
|
||||
timeout: int | float | None = None,
|
||||
) -> tuple[ObsType, list[dict]]:
|
||||
"""Waits for the calls triggered by :meth:`reset_async` to finish and returns the results.
|
||||
|
||||
Args:
|
||||
timeout: Number of seconds before the call to `reset_wait` times out. If `None`, the call to `reset_wait` never times out.
|
||||
|
||||
Returns:
|
||||
A tuple of batched observations and list of dictionaries
|
||||
|
||||
Raises:
|
||||
ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
|
||||
NoAsyncCallError: If :meth:`reset_wait` was called without any prior call to :meth:`reset_async`.
|
||||
TimeoutError: If :meth:`reset_wait` timed out.
|
||||
"""
|
||||
self._assert_is_running()
|
||||
if self._state != AsyncState.WAITING_RESET:
|
||||
raise NoAsyncCallError(
|
||||
"Calling `reset_wait` without any prior " "call to `reset_async`.",
|
||||
AsyncState.WAITING_RESET.value,
|
||||
)
|
||||
|
||||
if not self._poll(timeout):
|
||||
self._state = AsyncState.DEFAULT
|
||||
raise multiprocessing.TimeoutError(
|
||||
f"The call to `reset_wait` has timed out after {timeout} second(s)."
|
||||
)
|
||||
|
||||
results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
|
||||
self._raise_if_errors(successes)
|
||||
self._state = AsyncState.DEFAULT
|
||||
|
||||
infos = {}
|
||||
results, info_data = zip(*results)
|
||||
for i, info in enumerate(info_data):
|
||||
infos = self._add_info(infos, info, i)
|
||||
|
||||
if not self.shared_memory:
|
||||
self.observations = concatenate(
|
||||
self.single_observation_space, results, self.observations
|
||||
)
|
||||
|
||||
return (deepcopy(self.observations) if self.copy else self.observations), infos
|
||||
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: int | list[int] | None = None,
|
||||
options: dict | None = None,
|
||||
):
|
||||
"""Reset all parallel environments and return a batch of initial observations and info.
|
||||
|
||||
Args:
|
||||
seed: The environment reset seeds
|
||||
options: If to return the options
|
||||
|
||||
Returns:
|
||||
A batch of observations and info from the vectorized environment.
|
||||
"""
|
||||
self.reset_async(seed=seed, options=options)
|
||||
return self.reset_wait()
|
||||
|
||||
def step_async(self, actions: np.ndarray):
|
||||
"""Send the calls to :obj:`step` to each sub-environment.
|
||||
|
||||
Args:
|
||||
actions: Batch of actions. element of :attr:`~VectorEnv.action_space`
|
||||
|
||||
Raises:
|
||||
ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
|
||||
AlreadyPendingCallError: If the environment is already waiting for a pending call to another
|
||||
method (e.g. :meth:`reset_async`). This can be caused by two consecutive
|
||||
calls to :meth:`step_async`, with no call to :meth:`step_wait` in
|
||||
between.
|
||||
"""
|
||||
self._assert_is_running()
|
||||
if self._state != AsyncState.DEFAULT:
|
||||
raise AlreadyPendingCallError(
|
||||
f"Calling `step_async` while waiting for a pending call to `{self._state.value}` to complete.",
|
||||
str(self._state.value),
|
||||
)
|
||||
|
||||
actions = iterate(self.action_space, actions)
|
||||
for pipe, action in zip(self.parent_pipes, actions):
|
||||
pipe.send(("step", action))
|
||||
self._state = AsyncState.WAITING_STEP
|
||||
|
||||
def step_wait(
|
||||
self, timeout: int | float | None = None
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, dict]:
|
||||
"""Wait for the calls to :obj:`step` in each sub-environment to finish.
|
||||
|
||||
Args:
|
||||
timeout: Number of seconds before the call to :meth:`step_wait` times out. If ``None``, the call to :meth:`step_wait` never times out.
|
||||
|
||||
Returns:
|
||||
The batched environment step information, (obs, reward, terminated, truncated, info)
|
||||
|
||||
Raises:
|
||||
ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
|
||||
NoAsyncCallError: If :meth:`step_wait` was called without any prior call to :meth:`step_async`.
|
||||
TimeoutError: If :meth:`step_wait` timed out.
|
||||
"""
|
||||
self._assert_is_running()
|
||||
if self._state != AsyncState.WAITING_STEP:
|
||||
raise NoAsyncCallError(
|
||||
"Calling `step_wait` without any prior call " "to `step_async`.",
|
||||
AsyncState.WAITING_STEP.value,
|
||||
)
|
||||
|
||||
if not self._poll(timeout):
|
||||
self._state = AsyncState.DEFAULT
|
||||
raise multiprocessing.TimeoutError(
|
||||
f"The call to `step_wait` has timed out after {timeout} second(s)."
|
||||
)
|
||||
|
||||
observations_list, rewards, terminateds, truncateds, infos = [], [], [], [], {}
|
||||
successes = []
|
||||
for i, pipe in enumerate(self.parent_pipes):
|
||||
result, success = pipe.recv()
|
||||
obs, rew, terminated, truncated, info = result
|
||||
|
||||
successes.append(success)
|
||||
if success:
|
||||
observations_list.append(obs)
|
||||
rewards.append(rew)
|
||||
terminateds.append(terminated)
|
||||
truncateds.append(truncated)
|
||||
infos = self._add_info(infos, info, i)
|
||||
|
||||
self._raise_if_errors(successes)
|
||||
self._state = AsyncState.DEFAULT
|
||||
|
||||
if not self.shared_memory:
|
||||
self.observations = concatenate(
|
||||
self.single_observation_space,
|
||||
observations_list,
|
||||
self.observations,
|
||||
)
|
||||
|
||||
return (
|
||||
deepcopy(self.observations) if self.copy else self.observations,
|
||||
np.array(rewards),
|
||||
np.array(terminateds, dtype=np.bool_),
|
||||
np.array(truncateds, dtype=np.bool_),
|
||||
infos,
|
||||
)
|
||||
|
||||
def step(self, actions):
|
||||
"""Take an action for each parallel environment.
|
||||
|
||||
Args:
|
||||
actions: element of :attr:`action_space` Batch of actions.
|
||||
|
||||
Returns:
|
||||
Batch of (observations, rewards, terminations, truncations, infos)
|
||||
"""
|
||||
self.step_async(actions)
|
||||
return self.step_wait()
|
||||
|
||||
def call_async(self, name: str, *args, **kwargs):
|
||||
"""Calls the method with name asynchronously and apply args and kwargs to the method.
|
||||
|
||||
Args:
|
||||
name: Name of the method or property to call.
|
||||
*args: Arguments to apply to the method call.
|
||||
**kwargs: Keyword arguments to apply to the method call.
|
||||
|
||||
Raises:
|
||||
ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
|
||||
AlreadyPendingCallError: Calling `call_async` while waiting for a pending call to complete
|
||||
"""
|
||||
self._assert_is_running()
|
||||
if self._state != AsyncState.DEFAULT:
|
||||
raise AlreadyPendingCallError(
|
||||
"Calling `call_async` while waiting "
|
||||
f"for a pending call to `{self._state.value}` to complete.",
|
||||
str(self._state.value),
|
||||
)
|
||||
|
||||
for pipe in self.parent_pipes:
|
||||
pipe.send(("_call", (name, args, kwargs)))
|
||||
self._state = AsyncState.WAITING_CALL
|
||||
|
||||
def call_wait(self, timeout: int | float | None = None) -> list:
|
||||
"""Calls all parent pipes and waits for the results.
|
||||
|
||||
Args:
|
||||
timeout: Number of seconds before the call to `step_wait` times out.
|
||||
If `None` (default), the call to `step_wait` never times out.
|
||||
|
||||
Returns:
|
||||
List of the results of the individual calls to the method or property for each environment.
|
||||
|
||||
Raises:
|
||||
NoAsyncCallError: Calling `call_wait` without any prior call to `call_async`.
|
||||
TimeoutError: The call to `call_wait` has timed out after timeout second(s).
|
||||
"""
|
||||
self._assert_is_running()
|
||||
if self._state != AsyncState.WAITING_CALL:
|
||||
raise NoAsyncCallError(
|
||||
"Calling `call_wait` without any prior call to `call_async`.",
|
||||
AsyncState.WAITING_CALL.value,
|
||||
)
|
||||
|
||||
if not self._poll(timeout):
|
||||
self._state = AsyncState.DEFAULT
|
||||
raise multiprocessing.TimeoutError(
|
||||
f"The call to `call_wait` has timed out after {timeout} second(s)."
|
||||
)
|
||||
|
||||
results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
|
||||
self._raise_if_errors(successes)
|
||||
self._state = AsyncState.DEFAULT
|
||||
|
||||
return results
|
||||
|
||||
def call(self, name: str, *args, **kwargs) -> list[Any]:
|
||||
"""Call a method, or get a property, from each parallel environment.
|
||||
|
||||
Args:
|
||||
name (str): Name of the method or property to call.
|
||||
*args: Arguments to apply to the method call.
|
||||
**kwargs: Keyword arguments to apply to the method call.
|
||||
|
||||
Returns:
|
||||
List of the results of the individual calls to the method or property for each environment.
|
||||
"""
|
||||
self.call_async(name, *args, **kwargs)
|
||||
return self.call_wait()
|
||||
|
||||
def get_attr(self, name: str):
|
||||
"""Get a property from each parallel environment.
|
||||
|
||||
Args:
|
||||
name (str): Name of the property to be get from each individual environment.
|
||||
|
||||
Returns:
|
||||
The property with name
|
||||
"""
|
||||
return self.call(name)
|
||||
|
||||
def set_attr(self, name: str, values: list[Any] | tuple[Any] | object):
|
||||
"""Sets an attribute of the sub-environments.
|
||||
|
||||
Args:
|
||||
name: Name of the property to be set in each individual environment.
|
||||
values: Values of the property to be set to. If ``values`` is a list or
|
||||
tuple, then it corresponds to the values for each individual
|
||||
environment, otherwise a single value is set for all environments.
|
||||
|
||||
Raises:
|
||||
ValueError: Values must be a list or tuple with length equal to the number of environments.
|
||||
AlreadyPendingCallError: Calling `set_attr` while waiting for a pending call to complete.
|
||||
"""
|
||||
self._assert_is_running()
|
||||
if not isinstance(values, (list, tuple)):
|
||||
values = [values for _ in range(self.num_envs)]
|
||||
if len(values) != self.num_envs:
|
||||
raise ValueError(
|
||||
"Values must be a list or tuple with length equal to the "
|
||||
f"number of environments. Got `{len(values)}` values for "
|
||||
f"{self.num_envs} environments."
|
||||
)
|
||||
|
||||
if self._state != AsyncState.DEFAULT:
|
||||
raise AlreadyPendingCallError(
|
||||
"Calling `set_attr` while waiting "
|
||||
f"for a pending call to `{self._state.value}` to complete.",
|
||||
str(self._state.value),
|
||||
)
|
||||
|
||||
for pipe, value in zip(self.parent_pipes, values):
|
||||
pipe.send(("_setattr", (name, value)))
|
||||
_, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
|
||||
self._raise_if_errors(successes)
|
||||
|
||||
    def close_extras(self, timeout: int | float | None = None, terminate: bool = False):
        """Close the environments & clean up the extra resources (processes and pipes).

        Args:
            timeout: Number of seconds before the call to :meth:`close` times out. If ``None``,
                the call to :meth:`close` never times out. If the call to :meth:`close`
                times out, then all processes are terminated.
            terminate: If ``True``, then the :meth:`close` operation is forced and all processes are terminated.

        Raises:
            TimeoutError: If :meth:`close` timed out.
        """
        # A forced close waits zero seconds for any pending operation.
        timeout = 0 if terminate else timeout
        try:
            if self._state != AsyncState.DEFAULT:
                logger.warn(
                    f"Calling `close` while waiting for a pending call to `{self._state.value}` to complete."
                )
                # Drain the outstanding `step`/`reset`/`call` so pipes are in a
                # known state before sending the `close` command.
                function = getattr(self, f"{self._state.value}_wait")
                function(timeout)
        except multiprocessing.TimeoutError:
            # Could not drain in time: fall back to hard termination below.
            terminate = True

        if terminate:
            for process in self.processes:
                if process.is_alive():
                    process.terminate()
        else:
            # Graceful shutdown: ask each worker to exit, then wait for its ack.
            for pipe in self.parent_pipes:
                if (pipe is not None) and (not pipe.closed):
                    pipe.send(("close", None))
            for pipe in self.parent_pipes:
                if (pipe is not None) and (not pipe.closed):
                    pipe.recv()

        # Release the OS resources regardless of how the workers exited.
        for pipe in self.parent_pipes:
            if pipe is not None:
                pipe.close()
        for process in self.processes:
            process.join()
|
||||
|
||||
def _poll(self, timeout=None):
|
||||
self._assert_is_running()
|
||||
if timeout is None:
|
||||
return True
|
||||
end_time = time.perf_counter() + timeout
|
||||
delta = None
|
||||
for pipe in self.parent_pipes:
|
||||
delta = max(end_time - time.perf_counter(), 0)
|
||||
if pipe is None:
|
||||
return False
|
||||
if pipe.closed or (not pipe.poll(delta)):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _check_spaces(self):
|
||||
self._assert_is_running()
|
||||
spaces = (self.single_observation_space, self.single_action_space)
|
||||
for pipe in self.parent_pipes:
|
||||
pipe.send(("_check_spaces", spaces))
|
||||
results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
|
||||
self._raise_if_errors(successes)
|
||||
same_observation_spaces, same_action_spaces = zip(*results)
|
||||
if not all(same_observation_spaces):
|
||||
raise RuntimeError(
|
||||
f"Some environments have an observation space different from `{self.single_observation_space}`. "
|
||||
"In order to batch observations, the observation spaces from all environments must be equal."
|
||||
)
|
||||
if not all(same_action_spaces):
|
||||
raise RuntimeError(
|
||||
f"Some environments have an action space different from `{self.single_action_space}`. "
|
||||
"In order to batch actions, the action spaces from all environments must be equal."
|
||||
)
|
||||
|
||||
def _assert_is_running(self):
|
||||
if self.closed:
|
||||
raise ClosedEnvironmentError(
|
||||
f"Trying to operate on `{type(self).__name__}`, after a call to `close()`."
|
||||
)
|
||||
|
||||
def _raise_if_errors(self, successes: list[bool]):
|
||||
if all(successes):
|
||||
return
|
||||
|
||||
num_errors = self.num_envs - sum(successes)
|
||||
assert num_errors > 0
|
||||
for i in range(num_errors):
|
||||
index, exctype, value = self.error_queue.get()
|
||||
logger.error(
|
||||
f"Received the following error from Worker-{index}: {exctype.__name__}: {value}"
|
||||
)
|
||||
logger.error(f"Shutting down Worker-{index}.")
|
||||
self.parent_pipes[index].close()
|
||||
self.parent_pipes[index] = None
|
||||
|
||||
if i == num_errors - 1:
|
||||
logger.error("Raising the last exception back to the main process.")
|
||||
raise exctype(value)
|
||||
|
||||
def __del__(self):
|
||||
"""On deleting the object, checks that the vector environment is closed."""
|
||||
if not getattr(self, "closed", True) and hasattr(self, "_state"):
|
||||
self.close(terminate=True)
|
||||
|
||||
|
||||
def _worker(
    index: int,
    env_fn: callable,
    pipe: Connection,
    parent_pipe: Connection,
    shared_memory: bool,
    error_queue: Queue,
):
    """Event loop run inside each subprocess of an async vector environment.

    Receives ``(command, data)`` messages over ``pipe`` and replies with
    ``(result, success)`` tuples until a ``"close"`` command (or an error)
    ends the loop.

    Args:
        index: Position of this worker's sub-environment in the vector env.
        env_fn: Zero-argument callable constructing the sub-environment.
        pipe: Child end of the pipe used to talk to the main process.
        parent_pipe: Parent end of the same pipe; closed in this process.
        shared_memory: Shared-memory buffer for observations, or a falsy value
            to send observations back through the pipe instead.
        error_queue: Queue used to report exceptions to the main process.
    """
    env = env_fn()
    observation_space = env.observation_space
    action_space = env.action_space
    # Close the parent's pipe end in this process so only the parent holds it.
    parent_pipe.close()
    try:
        while True:
            command, data = pipe.recv()

            if command == "reset":
                observation, info = env.reset(**data)
                if shared_memory:
                    # Observation travels via shared memory; send None over the pipe.
                    write_to_shared_memory(
                        observation_space, index, observation, shared_memory
                    )
                    observation = None
                pipe.send(((observation, info), True))

            elif command == "step":
                (
                    observation,
                    reward,
                    terminated,
                    truncated,
                    info,
                ) = env.step(data)
                if terminated or truncated:
                    # Autoreset: keep the final transition's observation/info
                    # reachable through the info dict.
                    old_observation, old_info = observation, info
                    observation, info = env.reset()
                    info["final_observation"] = old_observation
                    info["final_info"] = old_info

                if shared_memory:
                    write_to_shared_memory(
                        observation_space, index, observation, shared_memory
                    )
                    observation = None

                pipe.send(((observation, reward, terminated, truncated, info), True))
            elif command == "seed":
                env.seed(data)
                pipe.send((None, True))
            elif command == "close":
                # Acknowledge, then leave the loop; `finally` closes the env.
                pipe.send((None, True))
                break
            elif command == "_call":
                name, args, kwargs = data
                if name in ["reset", "step", "seed", "close"]:
                    raise ValueError(
                        f"Trying to call function `{name}` with "
                        f"`_call`. Use `{name}` directly instead."
                    )
                function = getattr(env, name)
                if callable(function):
                    pipe.send((function(*args, **kwargs), True))
                else:
                    # Plain attribute: return its value directly.
                    pipe.send((function, True))
            elif command == "_setattr":
                name, value = data
                setattr(env, name, value)
                pipe.send((None, True))
            elif command == "_check_spaces":
                pipe.send(
                    (
                        (data[0] == observation_space, data[1] == action_space),
                        True,
                    )
                )
            else:
                raise RuntimeError(
                    f"Received unknown command `{command}`. Must "
                    "be one of {`reset`, `step`, `seed`, `close`, `_call`, "
                    "`_setattr`, `_check_spaces`}."
                )
    except (KeyboardInterrupt, Exception):
        # Report the failure, then reply unsuccessfully so the parent's
        # pending recv() does not hang.
        error_queue.put((index,) + sys.exc_info()[:2])
        pipe.send((None, False))
    finally:
        env.close()
|
@@ -1,229 +0,0 @@
|
||||
"""A synchronous vector environment."""
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, Iterator
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gymnasium import Env
|
||||
from gymnasium.experimental.vector.utils import (
|
||||
batch_space,
|
||||
concatenate,
|
||||
create_empty_array,
|
||||
iterate,
|
||||
)
|
||||
from gymnasium.experimental.vector.vector_env import VectorEnv
|
||||
|
||||
|
||||
__all__ = ["SyncVectorEnv"]
|
||||
|
||||
|
||||
class SyncVectorEnv(VectorEnv):
    """Vectorized environment that serially runs multiple environments.

    Example:
        >>> import gymnasium as gym
        >>> env = gym.vector.SyncVectorEnv([
        ...     lambda: gym.make("Pendulum-v1", g=9.81),
        ...     lambda: gym.make("Pendulum-v1", g=1.62)
        ... ])
        >>> env.reset(seed=42)
        (array([[-0.14995256,  0.9886932 , -0.12224312],
               [ 0.5760367 ,  0.8174238 , -0.91244936]], dtype=float32), {})
    """

    def __init__(
        self,
        env_fns: Iterator[Callable[[], Env]],
        copy: bool = True,
    ):
        """Vectorized environment that serially runs multiple environments.

        Args:
            env_fns: iterable of callable functions that create the environments.
            copy: If ``True``, then the :meth:`reset` and :meth:`step` methods return a copy of the observations.

        Raises:
            RuntimeError: If the observation space of some sub-environment does not match observation_space
                (or, by default, the observation space of the first sub-environment).
        """
        super().__init__()
        self.env_fns = env_fns
        self.envs = [env_fn() for env_fn in env_fns]
        self.num_envs = len(self.envs)
        self.copy = copy
        # Metadata/spec are taken from the first sub-environment; all envs are
        # assumed homogeneous (spaces are checked below).
        self.metadata = self.envs[0].metadata

        self.spec = self.envs[0].spec

        self.single_observation_space = self.envs[0].observation_space
        self.single_action_space = self.envs[0].action_space

        # Batched spaces over the single-env spaces.
        self.observation_space = batch_space(
            self.single_observation_space, self.num_envs
        )
        self.action_space = batch_space(self.single_action_space, self.num_envs)

        self._check_spaces()
        # Pre-allocated buffers reused on every reset/step.
        self.observations = create_empty_array(
            self.single_observation_space, n=self.num_envs, fn=np.zeros
        )
        self._rewards = np.zeros((self.num_envs,), dtype=np.float64)
        self._terminateds = np.zeros((self.num_envs,), dtype=np.bool_)
        self._truncateds = np.zeros((self.num_envs,), dtype=np.bool_)

    def reset(
        self,
        seed: int | list[int] | None = None,
        options: dict | None = None,
    ):
        """Waits for the calls triggered by :meth:`reset_async` to finish and returns the results.

        Args:
            seed: The reset environment seed
            options: Option information for the environment reset

        Returns:
            The reset observation of the environment and reset information
        """
        # An int seed is expanded to per-env seeds `seed, seed+1, ...`.
        if seed is None:
            seed = [None for _ in range(self.num_envs)]
        if isinstance(seed, int):
            seed = [seed + i for i in range(self.num_envs)]
        assert len(seed) == self.num_envs

        self._terminateds[:] = False
        self._truncateds[:] = False
        observations = []
        infos = {}
        for i, (env, single_seed) in enumerate(zip(self.envs, seed)):
            kwargs = {}
            if single_seed is not None:
                kwargs["seed"] = single_seed
            if options is not None:
                kwargs["options"] = options

            observation, info = env.reset(**kwargs)
            observations.append(observation)
            infos = self._add_info(infos, info, i)

        self.observations = concatenate(
            self.single_observation_space, observations, self.observations
        )
        return (deepcopy(self.observations) if self.copy else self.observations), infos

    def step(self, actions):
        """Steps through each of the environments returning the batched results.

        Returns:
            The batched environment step results
        """
        actions = iterate(self.action_space, actions)

        observations, infos = [], {}
        for i, (env, action) in enumerate(zip(self.envs, actions)):
            (
                observation,
                self._rewards[i],
                self._terminateds[i],
                self._truncateds[i],
                info,
            ) = env.step(action)

            if self._terminateds[i] or self._truncateds[i]:
                # Autoreset: the final transition's observation/info are kept
                # under `final_observation` / `final_info` in the info dict.
                old_observation, old_info = observation, info
                observation, info = env.reset()
                info["final_observation"] = old_observation
                info["final_info"] = old_info
            observations.append(observation)
            infos = self._add_info(infos, info, i)
        self.observations = concatenate(
            self.single_observation_space, observations, self.observations
        )

        # Copies protect internal buffers from mutation by the caller.
        return (
            deepcopy(self.observations) if self.copy else self.observations,
            np.copy(self._rewards),
            np.copy(self._terminateds),
            np.copy(self._truncateds),
            infos,
        )

    def call(self, name, *args, **kwargs) -> tuple:
        """Calls the method with name and applies args and kwargs.

        Args:
            name: The method name
            *args: The method args
            **kwargs: The method kwargs

        Returns:
            Tuple of results
        """
        results = []
        for env in self.envs:
            function = getattr(env, name)
            if callable(function):
                results.append(function(*args, **kwargs))
            else:
                # Plain attribute: return its value directly.
                results.append(function)

        return tuple(results)

    def get_attr(self, name: str):
        """Get a property from each parallel environment.

        Args:
            name (str): Name of the property to be get from each individual environment.

        Returns:
            The property with name
        """
        return self.call(name)

    def set_attr(self, name: str, values: list | tuple | Any):
        """Sets an attribute of the sub-environments.

        Args:
            name: The property name to change
            values: Values of the property to be set to. If ``values`` is a list or
                tuple, then it corresponds to the values for each individual
                environment, otherwise, a single value is set for all environments.

        Raises:
            ValueError: Values must be a list or tuple with length equal to the number of environments.
        """
        # Broadcast a scalar value to every sub-environment.
        if not isinstance(values, (list, tuple)):
            values = [values for _ in range(self.num_envs)]
        if len(values) != self.num_envs:
            raise ValueError(
                "Values must be a list or tuple with length equal to the "
                f"number of environments. Got `{len(values)}` values for "
                f"{self.num_envs} environments."
            )

        for env, value in zip(self.envs, values):
            setattr(env, name, value)

    def close_extras(self, **kwargs):
        """Close the environments."""
        [env.close() for env in self.envs]

    def _check_spaces(self) -> bool:
        # Every sub-environment must match the first env's spaces so
        # observations/actions can be batched.
        for env in self.envs:
            if not (env.observation_space == self.single_observation_space):
                raise RuntimeError(
                    "Some environments have an observation space different from "
                    f"`{self.single_observation_space}`. In order to batch observations, "
                    "the observation spaces from all environments must be equal."
                )

            if not (env.action_space == self.single_action_space):
                raise RuntimeError(
                    "Some environments have an action space different from "
                    f"`{self.single_action_space}`. In order to batch actions, the "
                    "action spaces from all environments must be equal."
                )

        return True
|
@@ -1,30 +0,0 @@
|
||||
"""Module for gymnasium experimental vector utility functions."""
|
||||
|
||||
from gymnasium.experimental.vector.utils.misc import (
|
||||
CloudpickleWrapper,
|
||||
clear_mpi_env_vars,
|
||||
)
|
||||
from gymnasium.experimental.vector.utils.shared_memory import (
|
||||
create_shared_memory,
|
||||
read_from_shared_memory,
|
||||
write_to_shared_memory,
|
||||
)
|
||||
from gymnasium.experimental.vector.utils.space_utils import (
|
||||
batch_space,
|
||||
concatenate,
|
||||
create_empty_array,
|
||||
iterate,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"batch_space",
|
||||
"iterate",
|
||||
"concatenate",
|
||||
"create_empty_array",
|
||||
"create_shared_memory",
|
||||
"read_from_shared_memory",
|
||||
"write_to_shared_memory",
|
||||
"CloudpickleWrapper",
|
||||
"clear_mpi_env_vars",
|
||||
]
|
@@ -1,61 +0,0 @@
|
||||
"""Miscellaneous utilities."""
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
|
||||
from gymnasium.core import Env
|
||||
|
||||
|
||||
__all__ = ["CloudpickleWrapper", "clear_mpi_env_vars"]
|
||||
|
||||
|
||||
class CloudpickleWrapper:
    """Pickles a callable via cloudpickle so it survives multiprocessing spawn."""

    def __init__(self, fn: Callable[[], Env]):
        """Store the zero-argument environment factory to wrap."""
        self.fn = fn

    def __getstate__(self):
        """Serialize the wrapped callable with `cloudpickle.dumps`."""
        import cloudpickle

        return cloudpickle.dumps(self.fn)

    def __setstate__(self, ob):
        """Restore the wrapped callable from its pickled bytes."""
        import pickle

        self.fn = pickle.loads(ob)

    def __call__(self):
        """Invoke the wrapped callable with no arguments and return its result."""
        return self.fn()
|
||||
|
||||
|
||||
@contextlib.contextmanager
def clear_mpi_env_vars():
    """Temporarily remove MPI-related environment variables.

    `from mpi4py import MPI` calls `MPI_Init` by default; if a child process
    inherits MPI environment variables, MPI may treat it as an MPI process
    like the parent and misbehave (e.g. hang). Stripping `OMPI_*`/`PMI_*`
    while spawning multiprocessing workers avoids that; the variables are
    restored on exit.

    Yields:
        Yields for the context manager
    """
    stashed: dict[str, str] = {}
    mpi_prefixes = ("OMPI_", "PMI_")
    for key in list(os.environ):
        if key.startswith(mpi_prefixes):
            stashed[key] = os.environ.pop(key)
    try:
        yield
    finally:
        # Restore everything we removed, even if the body raised.
        os.environ.update(stashed)
|
@@ -1,255 +0,0 @@
|
||||
"""Utility functions for vector environments to share memory between processes."""
|
||||
from __future__ import annotations
|
||||
|
||||
import multiprocessing as mp
|
||||
from collections import OrderedDict
|
||||
from ctypes import c_bool
|
||||
from functools import singledispatch
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gymnasium.error import CustomSpaceError
|
||||
from gymnasium.spaces import (
|
||||
Box,
|
||||
Dict,
|
||||
Discrete,
|
||||
Graph,
|
||||
MultiBinary,
|
||||
MultiDiscrete,
|
||||
Sequence,
|
||||
Space,
|
||||
Text,
|
||||
Tuple,
|
||||
flatten,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["create_shared_memory", "read_from_shared_memory", "write_to_shared_memory"]
|
||||
|
||||
|
||||
@singledispatch
def create_shared_memory(
    space: Space[Any], n: int = 1, ctx=mp
) -> dict[str, Any] | tuple[Any, ...] | mp.Array:
    """Create a shared memory object, to be shared across processes.

    This eventually contains the observations from the vectorized environment.

    Args:
        space: Observation space of a single environment in the vectorized environment.
        n: Number of environments in the vectorized environment (i.e. the number of processes).
        ctx: The multiprocess module

    Returns:
        shared_memory for the shared object across processes.

    Raises:
        CustomSpaceError: Space is not a valid :class:`gymnasium.Space` instance
    """
    # This base case is only reached for Space subclasses with no registered
    # handler; known space types dispatch to their registered implementations.
    if isinstance(space, Space):
        # Fix: grammatical error in the original message ("an registered").
        raise CustomSpaceError(
            f"Space of type `{type(space)}` doesn't have a registered `create_shared_memory` function. Register `{type(space)}` for `create_shared_memory` to support it."
        )
    else:
        raise TypeError(
            f"The space provided to `create_shared_memory` is not a gymnasium Space instance, type: {type(space)}, {space}"
        )
|
||||
|
||||
|
||||
@create_shared_memory.register(Box)
@create_shared_memory.register(Discrete)
@create_shared_memory.register(MultiDiscrete)
@create_shared_memory.register(MultiBinary)
def _create_base_shared_memory(
    space: Box | Discrete | MultiDiscrete | MultiBinary, n: int = 1, ctx=mp
):
    """Allocate a flat ctypes array sized for ``n`` stacked samples of ``space``."""
    assert space.dtype is not None
    # ctypes arrays have no native bool typecode; map numpy's '?' to c_bool.
    type_code = space.dtype.char
    if type_code in "?":
        type_code = c_bool
    element_count = n * int(np.prod(space.shape))
    return ctx.Array(type_code, element_count)
|
||||
|
||||
|
||||
@create_shared_memory.register(Tuple)
def _create_tuple_shared_memory(space: Tuple, n: int = 1, ctx=mp):
    """Allocate one shared-memory block per subspace of the Tuple space."""
    blocks = [create_shared_memory(subspace, n=n, ctx=ctx) for subspace in space.spaces]
    return tuple(blocks)
|
||||
|
||||
|
||||
@create_shared_memory.register(Dict)
def _create_dict_shared_memory(space: Dict, n: int = 1, ctx=mp):
    """Allocate a keyed shared-memory block per subspace of the Dict space."""
    result = OrderedDict()
    for key, subspace in space.spaces.items():
        result[key] = create_shared_memory(subspace, n=n, ctx=ctx)
    return result
|
||||
|
||||
|
||||
@create_shared_memory.register(Text)
def _create_text_shared_memory(space: Text, n: int = 1, ctx=mp):
    """Allocate an int32 array of ``n * max_length`` character codes for the Text space."""
    total = n * space.max_length
    return ctx.Array(np.dtype(np.int32).char, total)
|
||||
|
||||
|
||||
@create_shared_memory.register(Graph)
@create_shared_memory.register(Sequence)
def _create_dynamic_shared_memory(space: Graph | Sequence, n: int = 1, ctx=mp):
    """Reject dynamically-shaped spaces: a fixed-size buffer cannot hold them."""
    raise TypeError(
        f"As {space} has a dynamic shape then it is not possible to make a static shared memory."
    )
|
||||
|
||||
|
||||
@singledispatch
def read_from_shared_memory(
    space: Space, shared_memory: dict | tuple | mp.Array, n: int = 1
) -> dict[str, Any] | tuple[Any, ...] | np.ndarray:
    """Read the batch of observations from shared memory as a numpy array.

    ..notes::
        The numpy array objects returned by `read_from_shared_memory` shares the
        memory of `shared_memory`. Any changes to `shared_memory` are forwarded
        to `observations`, and vice-versa. To avoid any side-effect, use `np.copy`.

    Args:
        space: Observation space of a single environment in the vectorized environment.
        shared_memory: Shared object across processes. This contains the observations from the vectorized environment.
            This object is created with `create_shared_memory`.
        n: Number of environments in the vectorized environment (i.e. the number of processes).

    Returns:
        Batch of observations as a (possibly nested) numpy array.

    Raises:
        CustomSpaceError: Space is not a valid :class:`gymnasium.Space` instance
    """
    # Base case for Space subclasses with no registered reader; known space
    # types dispatch to their registered implementations.
    if isinstance(space, Space):
        # Fix: grammatical error in the original message ("an registered").
        raise CustomSpaceError(
            f"Space of type `{type(space)}` doesn't have a registered `read_from_shared_memory` function. Register `{type(space)}` for `read_from_shared_memory` to support it."
        )
    else:
        raise TypeError(
            f"The space provided to `read_from_shared_memory` is not a gymnasium Space instance, type: {type(space)}, {space}"
        )
|
||||
|
||||
|
||||
@read_from_shared_memory.register(Box)
@read_from_shared_memory.register(Discrete)
@read_from_shared_memory.register(MultiDiscrete)
@read_from_shared_memory.register(MultiBinary)
def _read_base_from_shared_memory(
    space: Box | Discrete | MultiDiscrete | MultiBinary, shared_memory, n: int = 1
):
    """View the shared buffer as an ``(n, *space.shape)`` array (zero-copy)."""
    flat_view = np.frombuffer(shared_memory.get_obj(), dtype=space.dtype)
    batched_shape = (n,) + space.shape
    return flat_view.reshape(batched_shape)
|
||||
|
||||
|
||||
@read_from_shared_memory.register(Tuple)
def _read_tuple_from_shared_memory(space: Tuple, shared_memory, n: int = 1):
    """Read each subspace's batch from its paired shared-memory block."""
    pairs = zip(shared_memory, space.spaces)
    return tuple(read_from_shared_memory(subspace, memory, n=n) for memory, subspace in pairs)
|
||||
|
||||
|
||||
@read_from_shared_memory.register(Dict)
def _read_dict_from_shared_memory(space: Dict, shared_memory, n: int = 1):
    """Read each keyed subspace's batch from its shared-memory block."""
    result = OrderedDict()
    for key, subspace in space.spaces.items():
        result[key] = read_from_shared_memory(subspace, shared_memory[key], n=n)
    return result
|
||||
|
||||
|
||||
@read_from_shared_memory.register(Text)
def _read_text_from_shared_memory(space: Text, shared_memory, n: int = 1) -> tuple[str]:
    """Decode ``n`` strings from the int32 character-code buffer."""
    flat = np.frombuffer(shared_memory.get_obj(), dtype=np.int32)
    data = flat.reshape((n, space.max_length))

    # Codes >= len(character_set) mark unused (padding) slots and are skipped.
    charset_size = len(space.character_set)
    decoded = []
    for row in data:
        characters = [space.character_list[val] for val in row if val < charset_size]
        decoded.append("".join(characters))
    return tuple(decoded)
|
||||
|
||||
|
||||
@singledispatch
def write_to_shared_memory(
    space: Space,
    index: int,
    value: np.ndarray,
    shared_memory: dict[str, Any] | tuple[Any, ...] | mp.Array,
):
    """Write the observation of a single environment into shared memory.

    Args:
        space: Observation space of a single environment in the vectorized environment.
        index: Index of the environment (must be in `[0, num_envs)`).
        value: Observation of the single environment to write to shared memory.
        shared_memory: Shared object across processes. This contains the observations from the vectorized environment.
            This object is created with `create_shared_memory`.

    Raises:
        CustomSpaceError: Space is not a valid :class:`gymnasium.Space` instance
    """
    # Base case for Space subclasses with no registered writer; known space
    # types dispatch to their registered implementations.
    if isinstance(space, Space):
        # Fix: grammatical error in the original message ("an registered").
        raise CustomSpaceError(
            f"Space of type `{type(space)}` doesn't have a registered `write_to_shared_memory` function. Register `{type(space)}` for `write_to_shared_memory` to support it."
        )
    else:
        raise TypeError(
            f"The space provided to `write_to_shared_memory` is not a gymnasium Space instance, type: {type(space)}, {space}"
        )
|
||||
|
||||
|
||||
@write_to_shared_memory.register(Box)
@write_to_shared_memory.register(Discrete)
@write_to_shared_memory.register(MultiDiscrete)
@write_to_shared_memory.register(MultiBinary)
def _write_base_to_shared_memory(
    space: Box | Discrete | MultiDiscrete | MultiBinary,
    index: int,
    value,
    shared_memory,
):
    """Copy one flattened observation into its slot of the shared buffer."""
    length = int(np.prod(space.shape))
    flat_view = np.frombuffer(shared_memory.get_obj(), dtype=space.dtype)
    slot = flat_view[index * length : (index + 1) * length]
    np.copyto(slot, np.asarray(value, dtype=space.dtype).flatten())
|
||||
|
||||
|
||||
@write_to_shared_memory.register(Tuple)
def _write_tuple_to_shared_memory(
    space: Tuple, index: int, values: tuple[Any, ...], shared_memory
):
    """Write each element of the tuple observation into its paired block."""
    for element, block, subspace in zip(values, shared_memory, space.spaces):
        write_to_shared_memory(subspace, index, element, block)
|
||||
|
||||
|
||||
@write_to_shared_memory.register(Dict)
def _write_dict_to_shared_memory(
    space: Dict, index: int, values: dict[str, Any], shared_memory
):
    """Write each keyed element of the dict observation into its keyed block."""
    for key in space.spaces:
        write_to_shared_memory(space.spaces[key], index, values[key], shared_memory[key])
|
||||
|
||||
|
||||
@write_to_shared_memory.register(Text)
def _write_text_to_shared_memory(space: Text, index: int, values: str, shared_memory):
    """Encode one string (via the space's flatten) into its slot of the int32 buffer."""
    slot_length = space.max_length
    flat_view = np.frombuffer(shared_memory.get_obj(), dtype=np.int32)
    slot = flat_view[index * slot_length : (index + 1) * slot_length]
    np.copyto(slot, flatten(space, values))
|
@@ -1,486 +0,0 @@
|
||||
"""Base class for vectorized environments."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Generic, TypeVar
|
||||
|
||||
import numpy as np
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from gymnasium.utils import seeding
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gymnasium.envs.registration import EnvSpec
|
||||
|
||||
ArrayType = TypeVar("ArrayType")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"VectorEnv",
|
||||
"VectorWrapper",
|
||||
"VectorObservationWrapper",
|
||||
"VectorActionWrapper",
|
||||
"VectorRewardWrapper",
|
||||
"ArrayType",
|
||||
]
|
||||
|
||||
|
||||
class VectorEnv(Generic[ObsType, ActType, ArrayType]):
    """Base class for vectorized environments to run multiple independent copies of the same environment in parallel.

    Vector environments can provide a linear speed-up in the steps taken per second through sampling multiple
    sub-environments at the same time. To prevent terminated environments waiting until all sub-environments have
    terminated or truncated, the vector environments autoreset sub-environments after they terminate or truncated.
    As a result, the final step's observation and info are overwritten by the reset's observation and info.
    Therefore, the observation and info for the final step of a sub-environment is stored in the info parameter,
    using `"final_observation"` and `"final_info"` respectively. See :meth:`step` for more information.

    The vector environments batch `observations`, `rewards`, `terminations`, `truncations` and `info` for each
    parallel environment. In addition, :meth:`step` expects to receive a batch of actions for each parallel environment.

    Gymnasium contains two types of Vector environments: :class:`AsyncVectorEnv` and :class:`SyncVectorEnv`.

    The Vector Environments have the additional attributes for users to understand the implementation

    - :attr:`num_envs` - The number of sub-environment in the vector environment
    - :attr:`observation_space` - The batched observation space of the vector environment
    - :attr:`single_observation_space` - The observation space of a single sub-environment
    - :attr:`action_space` - The batched action space of the vector environment
    - :attr:`single_action_space` - The action space of a single sub-environment

    Note:
        The info parameter of :meth:`reset` and :meth:`step` was originally implemented before OpenAI Gym v25 was a list
        of dictionary for each sub-environment. However, this was modified in OpenAI Gym v25+ and in Gymnasium to a
        dictionary with a NumPy array for each key. To use the old info style using the :class:`VectorListInfo`.

    Note:
        To render the sub-environments, use :meth:`call` with "render" arguments. Remember to set the `render_modes`
        for all the sub-environments during initialization.

    Note:
        All parallel environments should share the identical observation and action spaces.
        In other words, a vector of multiple different environments is not supported.
    """

    # Environment specification; set by concrete implementations/registration.
    spec: EnvSpec

    # Batched and per-sub-environment spaces; populated by subclasses.
    observation_space: gym.Space
    action_space: gym.Space
    single_observation_space: gym.Space
    single_action_space: gym.Space

    # Number of parallel sub-environments.
    num_envs: int

    # Set to True by close(); guards against double-closing in __del__.
    closed: bool = False

    # Lazily-created RNG; see the np_random property below.
    _np_random: np.random.Generator | None = None

    def reset(
        self,
        *,
        seed: int | list[int] | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[ObsType, dict[str, Any]]:  # type: ignore
        """Reset all parallel environments and return a batch of initial observations and info.

        The base implementation only re-seeds :attr:`_np_random`; subclasses are
        expected to override this method and perform the actual sub-environment resets.

        Args:
            seed: The environment reset seeds
            options: If to return the options

        Returns:
            A batch of observations and info from the vectorized environment.

        Example:
            >>> import gymnasium as gym
            >>> envs = gym.vector.make("CartPole-v1", num_envs=3)
            >>> envs.reset(seed=42)
            (array([[ 0.0273956 , -0.00611216,  0.03585979,  0.0197368 ],
                   [ 0.01522993, -0.04562247, -0.04799704,  0.03392126],
                   [-0.03774345, -0.02418869, -0.00942293,  0.0469184 ]],
                  dtype=float32), {})
        """
        # NOTE(review): seeding.np_random appears to expect an int seed; list-of-int
        # seeds look like they are handled by subclasses — confirm against callers.
        if seed is not None:
            self._np_random, seed = seeding.np_random(seed)

    def step(
        self, actions: ActType
    ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
        """Take an action for each parallel environment.

        Args:
            actions: element of :attr:`action_space` Batch of actions.

        Returns:
            Batch of (observations, rewards, terminations, truncations, infos)

        Note:
            As the vector environments autoreset for a terminating and truncating sub-environments,
            the returned observation and info is not the final step's observation or info which is instead stored in
            info as `"final_observation"` and `"final_info"`.

        Example:
            >>> import gymnasium as gym
            >>> import numpy as np
            >>> envs = gym.vector.make("CartPole-v1", num_envs=3)
            >>> _ = envs.reset(seed=42)
            >>> actions = np.array([1, 0, 1])
            >>> observations, rewards, termination, truncation, infos = envs.step(actions)
            >>> observations
            array([[ 0.02727336,  0.18847767,  0.03625453, -0.26141977],
                   [ 0.01431748, -0.24002443, -0.04731862,  0.3110827 ],
                   [-0.03822722,  0.1710671 , -0.00848456, -0.2487226 ]],
                  dtype=float32)
            >>> rewards
            array([1., 1., 1.])
            >>> termination
            array([False, False, False])
            >>> termination
            array([False, False, False])
            >>> infos
            {}
        """
        # Abstract: concrete vector environments implement the batched stepping.
        pass

    def close_extras(self, **kwargs):
        """Clean up the extra resources e.g. beyond what's in this base class."""
        pass

    def close(self, **kwargs):
        """Close all parallel environments and release resources.

        It also closes all the existing image viewers, then calls :meth:`close_extras` and set
        :attr:`closed` as ``True``.

        Warnings:
            This function itself does not close the environments, it should be handled
            in :meth:`close_extras`. This is generic for both synchronous and asynchronous
            vectorized environments.

        Note:
            This will be automatically called when garbage collected or program exited.

        Args:
            **kwargs: Keyword arguments passed to :meth:`close_extras`
        """
        # Idempotent: a second call is a no-op.
        if self.closed:
            return

        self.close_extras(**kwargs)
        self.closed = True

    @property
    def np_random(self) -> np.random.Generator:
        """Returns the environment's internal :attr:`_np_random` that if not set will initialise with a random seed.

        Returns:
            Instances of `np.random.Generator`
        """
        # Lazy initialisation with an unspecified (random) seed.
        if self._np_random is None:
            self._np_random, seed = seeding.np_random()
        return self._np_random

    @np_random.setter
    def np_random(self, value: np.random.Generator):
        self._np_random = value

    @property
    def unwrapped(self):
        """Return the base environment."""
        return self

    def _add_info(self, infos: dict, info: dict, env_num: int) -> dict:
        """Add env info to the info dictionary of the vectorized environment.

        Given the `info` of a single environment add it to the `infos` dictionary
        which represents all the infos of the vectorized environment.
        Every `key` of `info` is paired with a boolean mask `_key` representing
        whether or not the i-indexed environment has this `info`.

        Args:
            infos (dict): the infos of the vectorized environment
            info (dict): the info coming from the single environment
            env_num (int): the index of the single environment

        Returns:
            infos (dict): the (updated) infos of the vectorized environment

        """
        for k in info.keys():
            # First time this key appears: allocate a value array and boolean mask.
            if k not in infos:
                info_array, array_mask = self._init_info_arrays(type(info[k]))
            else:
                info_array, array_mask = infos[k], infos[f"_{k}"]

            # Record this env's value and flag its slot as populated.
            info_array[env_num], array_mask[env_num] = info[k], True
            infos[k], infos[f"_{k}"] = info_array, array_mask
        return infos

    def _init_info_arrays(self, dtype: type) -> tuple[np.ndarray, np.ndarray]:
        """Initialize the info array.

        Initialize the info array. If the dtype is numeric
        the info array will have the same dtype, otherwise
        will be an array of `None`. Also, a boolean array
        of the same length is returned. It will be used for
        assessing which environment has info data.

        Args:
            dtype (type): data type of the info coming from the env.

        Returns:
            array (np.ndarray): the initialized info array.
            array_mask (np.ndarray): the initialized boolean array.

        """
        if dtype in [int, float, bool] or issubclass(dtype, np.number):
            array = np.zeros(self.num_envs, dtype=dtype)
        else:
            # Non-numeric payloads are stored as Python objects, defaulting to None.
            array = np.zeros(self.num_envs, dtype=object)
            array[:] = None
        array_mask = np.zeros(self.num_envs, dtype=bool)
        return array, array_mask

    def __del__(self):
        """Closes the vector environment."""
        # getattr guard: __init__ may not have run (default True skips close).
        if not getattr(self, "closed", True):
            self.close()

    def __repr__(self) -> str:
        """Returns a string representation of the vector environment.

        Returns:
            A string containing the class name, number of environments and environment spec id
        """
        if getattr(self, "spec", None) is None:
            return f"{self.__class__.__name__}({self.num_envs})"
        else:
            return f"{self.__class__.__name__}({self.spec.id}, {self.num_envs})"
|
||||
|
||||
|
||||
class VectorWrapper(VectorEnv):
    """Wraps the vectorized environment to allow a modular transformation.

    This class is the base class for all wrappers for vectorized environments. The subclass
    could override some methods to change the behavior of the original vectorized environment
    without touching the original code.

    Note:
        Don't forget to call ``super().__init__(env)`` if the subclass overrides :meth:`__init__`.
    """

    # Space overrides set by subclasses; when None, the wrapped env's space is exposed.
    _observation_space: gym.Space | None = None
    _action_space: gym.Space | None = None
    _single_observation_space: gym.Space | None = None
    _single_action_space: gym.Space | None = None

    def __init__(self, env: VectorEnv):
        """Initialize the vectorized environment wrapper.

        Args:
            env: The vector environment to wrap; must be a :class:`VectorEnv`.
        """
        super().__init__()

        assert isinstance(env, VectorEnv)
        self.env = env

    # explicitly forward the methods defined in VectorEnv
    # to self.env (instead of the base class)

    def reset(
        self,
        *,
        seed: int | list[int] | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[ObsType, dict[str, Any]]:
        """Reset all environment using seed and options."""
        return self.env.reset(seed=seed, options=options)

    def step(
        self, actions: ActType
    ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
        """Step all environments."""
        return self.env.step(actions)

    def close(self, **kwargs: Any):
        """Close all environments."""
        return self.env.close(**kwargs)

    def close_extras(self, **kwargs: Any):
        """Close all extra resources."""
        return self.env.close_extras(**kwargs)

    # implicitly forward all other methods and attributes to self.env
    def __getattr__(self, name: str) -> Any:
        """Forward all other attributes to the base environment."""
        # Private attributes are never forwarded to avoid leaking wrapper internals.
        if name.startswith("_"):
            raise AttributeError(f"attempted to get missing private attribute '{name}'")
        return getattr(self.env, name)

    @property
    def unwrapped(self):
        """Return the base non-wrapped environment."""
        return self.env.unwrapped

    def __repr__(self):
        """Return the string representation of the vectorized environment."""
        return f"<{self.__class__.__name__}, {self.env}>"

    def __del__(self):
        """Close the vectorized environment."""
        self.env.__del__()

    @property
    def spec(self) -> EnvSpec | None:
        """Gets the specification of the wrapped environment."""
        return self.env.spec

    @property
    def observation_space(self) -> gym.Space:
        """Gets the observation space of the vector environment."""
        # Fall through to the wrapped env unless a subclass has overridden the space.
        if self._observation_space is None:
            return self.env.observation_space
        return self._observation_space

    @observation_space.setter
    def observation_space(self, space: gym.Space):
        """Sets the observation space of the vector environment."""
        self._observation_space = space

    @property
    def action_space(self) -> gym.Space:
        """Gets the action space of the vector environment."""
        if self._action_space is None:
            return self.env.action_space
        return self._action_space

    @action_space.setter
    def action_space(self, space: gym.Space):
        """Sets the action space of the vector environment."""
        self._action_space = space

    @property
    def single_observation_space(self) -> gym.Space:
        """Gets the single observation space of the vector environment."""
        if self._single_observation_space is None:
            return self.env.single_observation_space
        return self._single_observation_space

    @single_observation_space.setter
    def single_observation_space(self, space: gym.Space):
        """Sets the single observation space of the vector environment."""
        self._single_observation_space = space

    @property
    def single_action_space(self) -> gym.Space:
        """Gets the single action space of the vector environment."""
        if self._single_action_space is None:
            return self.env.single_action_space
        return self._single_action_space

    @single_action_space.setter
    def single_action_space(self, space: gym.Space):
        """Sets the single action space of the vector environment."""
        self._single_action_space = space

    @property
    def num_envs(self) -> int:
        """Gets the wrapped vector environment's num of the sub-environments."""
        return self.env.num_envs
|
||||
|
||||
|
||||
class VectorObservationWrapper(VectorWrapper):
    """Wraps the vectorized environment to allow a modular transformation of the observation. Equivalent to :class:`gym.ObservationWrapper` for vectorized environments."""

    def reset(
        self,
        *,
        seed: int | list[int] | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[ObsType, dict[str, Any]]:
        """Modifies the observation returned from the environment ``reset`` using the :meth:`observation`."""
        obs, info = self.env.reset(seed=seed, options=options)
        return self.vector_observation(obs), info

    def step(
        self, actions: ActType
    ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
        """Modifies the observation returned from the environment ``step`` using the :meth:`observation`."""
        observation, reward, termination, truncation, info = self.env.step(actions)
        # The batched observation uses vector_observation; the per-env final
        # observations stored in info are patched via update_final_obs.
        return (
            self.vector_observation(observation),
            reward,
            termination,
            truncation,
            self.update_final_obs(info),
        )

    def vector_observation(self, observation: ObsType) -> ObsType:
        """Defines the vector observation transformation.

        Args:
            observation: A vector observation from the environment

        Returns:
            the transformed observation
        """
        raise NotImplementedError

    def single_observation(self, observation: ObsType) -> ObsType:
        """Defines the single observation transformation.

        Args:
            observation: A single observation from the environment

        Returns:
            The transformed observation
        """
        raise NotImplementedError

    def update_final_obs(self, info: dict[str, Any]) -> dict[str, Any]:
        """Updates the `final_obs` in the info using `single_observation`."""
        if "final_observation" in info:
            for i, obs in enumerate(info["final_observation"]):
                # None marks sub-environments that did not finish this step.
                if obs is not None:
                    info["final_observation"][i] = self.single_observation(obs)
        return info
|
||||
|
||||
|
||||
class VectorActionWrapper(VectorWrapper):
    """Wraps the vectorized environment to allow a modular transformation of the actions. Equivalent of :class:`~gym.ActionWrapper` for vectorized environments."""

    def step(
        self, actions: ActType
    ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
        """Steps through all environments using the batched actions modified by :meth:`actions`."""
        transformed_actions = self.actions(actions)
        return self.env.step(transformed_actions)

    def actions(self, actions: ActType) -> ActType:
        """Transform the batched actions before sending them to the wrapped environment.

        Args:
            actions (ActType): the actions to transform

        Returns:
            ActType: the transformed actions
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
class VectorRewardWrapper(VectorWrapper):
    """Wraps the vectorized environment to allow a modular transformation of the reward. Equivalent of :class:`~gym.RewardWrapper` for vectorized environments."""

    def step(
        self, actions: ActType
    ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
        """Steps through all environments, applying :meth:`reward` to the batched rewards."""
        obs, rewards, terminations, truncations, infos = self.env.step(actions)
        transformed_rewards = self.reward(rewards)
        return obs, transformed_rewards, terminations, truncations, infos

    def reward(self, reward: ArrayType) -> ArrayType:
        """Transform the batched reward before returning it to the user.

        Args:
            reward (array): the reward to transform

        Returns:
            array: the transformed reward
        """
        raise NotImplementedError
|
@@ -1,164 +0,0 @@
|
||||
"""`__init__` for experimental wrappers, to avoid loading the wrappers if unnecessary, we can hack python."""
|
||||
# pyright: reportUnsupportedDunderAll=false
|
||||
import importlib
|
||||
import re
|
||||
|
||||
from gymnasium.error import DeprecatedWrapper
|
||||
from gymnasium.experimental.wrappers import vector
|
||||
from gymnasium.experimental.wrappers.atari_preprocessing import AtariPreprocessingV0
|
||||
from gymnasium.experimental.wrappers.common import (
|
||||
AutoresetV0,
|
||||
OrderEnforcingV0,
|
||||
PassiveEnvCheckerV0,
|
||||
RecordEpisodeStatisticsV0,
|
||||
)
|
||||
from gymnasium.experimental.wrappers.lambda_action import (
|
||||
ClipActionV0,
|
||||
LambdaActionV0,
|
||||
RescaleActionV0,
|
||||
)
|
||||
from gymnasium.experimental.wrappers.lambda_observation import (
|
||||
DtypeObservationV0,
|
||||
FilterObservationV0,
|
||||
FlattenObservationV0,
|
||||
GrayscaleObservationV0,
|
||||
LambdaObservationV0,
|
||||
PixelObservationV0,
|
||||
RescaleObservationV0,
|
||||
ReshapeObservationV0,
|
||||
ResizeObservationV0,
|
||||
)
|
||||
from gymnasium.experimental.wrappers.lambda_reward import ClipRewardV0, LambdaRewardV0
|
||||
from gymnasium.experimental.wrappers.rendering import (
|
||||
HumanRenderingV0,
|
||||
RecordVideoV0,
|
||||
RenderCollectionV0,
|
||||
)
|
||||
from gymnasium.experimental.wrappers.stateful_action import StickyActionV0
|
||||
from gymnasium.experimental.wrappers.stateful_observation import (
|
||||
DelayObservationV0,
|
||||
FrameStackObservationV0,
|
||||
MaxAndSkipObservationV0,
|
||||
NormalizeObservationV0,
|
||||
TimeAwareObservationV0,
|
||||
)
|
||||
from gymnasium.experimental.wrappers.stateful_reward import NormalizeRewardV1
|
||||
|
||||
|
||||
# Todo - Add legacy wrapper to new wrapper error for users when merged into gymnasium.wrappers
|
||||
|
||||
|
||||
__all__ = [
|
||||
"vector",
|
||||
# --- Observation wrappers ---
|
||||
"AtariPreprocessingV0",
|
||||
"DelayObservationV0",
|
||||
"DtypeObservationV0",
|
||||
"FilterObservationV0",
|
||||
"FlattenObservationV0",
|
||||
"FrameStackObservationV0",
|
||||
"GrayscaleObservationV0",
|
||||
"LambdaObservationV0",
|
||||
"MaxAndSkipObservationV0",
|
||||
"NormalizeObservationV0",
|
||||
"PixelObservationV0",
|
||||
"ResizeObservationV0",
|
||||
"ReshapeObservationV0",
|
||||
"RescaleObservationV0",
|
||||
"TimeAwareObservationV0",
|
||||
# --- Action Wrappers ---
|
||||
"ClipActionV0",
|
||||
"LambdaActionV0",
|
||||
"RescaleActionV0",
|
||||
# "NanAction",
|
||||
"StickyActionV0",
|
||||
# --- Reward wrappers ---
|
||||
"ClipRewardV0",
|
||||
"LambdaRewardV0",
|
||||
"NormalizeRewardV1",
|
||||
# --- Common ---
|
||||
"AutoresetV0",
|
||||
"PassiveEnvCheckerV0",
|
||||
"OrderEnforcingV0",
|
||||
"RecordEpisodeStatisticsV0",
|
||||
# --- Rendering ---
|
||||
"RenderCollectionV0",
|
||||
"RecordVideoV0",
|
||||
"HumanRenderingV0",
|
||||
# --- Conversion ---
|
||||
"JaxToNumpyV0",
|
||||
"JaxToTorchV0",
|
||||
"NumpyToTorchV0",
|
||||
]
|
||||
|
||||
# As these wrappers requires `jax` or `torch`, they are loaded by runtime for users trying to access them
|
||||
# to avoid `import jax` or `import torch` on `import gymnasium`.
|
||||
_wrapper_to_class = {
|
||||
# data converters
|
||||
"JaxToNumpyV0": "jax_to_numpy",
|
||||
"JaxToTorchV0": "jax_to_torch",
|
||||
"NumpyToTorchV0": "numpy_to_torch",
|
||||
}
|
||||
|
||||
|
||||
def __getattr__(wrapper_name: str):
    """Load a wrapper by name.

    This optimizes the loading of gymnasium wrappers by only loading the wrapper if it is used.
    Errors will be raised if the wrapper does not exist or if the version is not the latest.

    Args:
        wrapper_name: The name of a wrapper to load.

    Returns:
        The specified wrapper.

    Raises:
        AttributeError: If the wrapper does not exist.
        DeprecatedWrapper: If the version is not the latest.
    """
    # Check if the requested wrapper is in the _wrapper_to_class dictionary
    # (these require `jax` or `torch`, so they are imported on first access).
    if wrapper_name in _wrapper_to_class:
        import_stmt = (
            f"gymnasium.experimental.wrappers.{_wrapper_to_class[wrapper_name]}"
        )
        module = importlib.import_module(import_stmt)
        return getattr(module, wrapper_name)

    # Define a regex pattern to match the integer suffix (version number) of the wrapper
    int_suffix_pattern = r"(\d+)$"
    version_match = re.search(int_suffix_pattern, wrapper_name)

    # If a version number is found, extract it and the base wrapper name
    if version_match:
        version = int(version_match.group())
        base_name = wrapper_name[: -len(version_match.group())]
    else:
        # No version suffix: treat the request as "any version" and strip the
        # (assumed two-character) version tail to recover a candidate base name.
        version = float("inf")
        base_name = wrapper_name[:-2]

    # Collect all versioned wrappers sharing the base name, paired with their
    # version numbers.  Unversioned entries of __all__ (e.g. "vector") are
    # excluded here; previously they made the `max(...)` below raise an
    # IndexError instead of the documented AttributeError.
    versioned_wrappers = [
        (name, int(re.findall(int_suffix_pattern, name)[0]))
        for name in __all__
        if name.startswith(base_name) and re.search(int_suffix_pattern, name)
    ]

    # If no matching wrappers are found, raise an AttributeError
    if not versioned_wrappers:
        raise AttributeError(f"module {__name__!r} has no attribute {wrapper_name!r}")

    # Find the latest version of the matching wrappers
    latest_wrapper, latest_version = max(versioned_wrappers, key=lambda nv: nv[1])

    # If the requested wrapper is an older version, raise a DeprecatedWrapper exception
    if version < latest_version:
        raise DeprecatedWrapper(
            f"{wrapper_name!r} is now deprecated, use {latest_wrapper!r} instead.\n"
            f"To see the changes made, go to "
            f"https://gymnasium.farama.org/api/experimental/wrappers/#gymnasium.experimental.wrappers.{latest_wrapper}"
        )
    # If the requested version is invalid, raise an AttributeError
    else:
        raise AttributeError(
            f"module {__name__!r} has no attribute {wrapper_name!r}, did you mean {latest_wrapper!r}"
        )
|
@@ -1,206 +0,0 @@
|
||||
"""Implementation of Atari 2600 Preprocessing following the guidelines of Machado et al., 2018."""
|
||||
import numpy as np
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.spaces import Box
|
||||
|
||||
|
||||
__all__ = ["AtariPreprocessingV0"]
|
||||
|
||||
|
||||
class AtariPreprocessingV0(gym.Wrapper, gym.utils.RecordConstructorArgs):
    """Atari 2600 preprocessing wrapper.

    This class follows the guidelines in Machado et al. (2018),
    "Revisiting the Arcade Learning Environment: Evaluation Protocols and Open Problems for General Agents".

    Specifically, the following preprocess stages applies to the atari environment:

    - Noop Reset: Obtains the initial state by taking a random number of no-ops on reset, default max 30 no-ops.
    - Frame skipping: The number of frames skipped between steps, 4 by default
    - Max-pooling: Pools over the most recent two observations from the frame skips
    - Termination signal when a life is lost: When the agent losses a life during the environment, then the environment is terminated.
        Turned off by default. Not recommended by Machado et al. (2018).
    - Resize to a square image: Resizes the atari environment original observation shape from 210x180 to 84x84 by default
    - Grayscale observation: If the observation is colour or greyscale, by default, greyscale.
    - Scale observation: If to scale the observation between [0, 1) or [0, 255), by default, not scaled.
    """

    def __init__(
        self,
        env: gym.Env,
        noop_max: int = 30,
        frame_skip: int = 4,
        screen_size: int = 84,
        terminal_on_life_loss: bool = False,
        grayscale_obs: bool = True,
        grayscale_newaxis: bool = False,
        scale_obs: bool = False,
    ):
        """Wrapper for Atari 2600 preprocessing.

        Args:
            env (Env): The environment to apply the preprocessing
            noop_max (int): For No-op reset, the max number no-ops actions are taken at reset, to turn off, set to 0.
            frame_skip (int): The number of frames between new observation the agents observations effecting the frequency at which the agent experiences the game.
            screen_size (int): resize Atari frame
            terminal_on_life_loss (bool): `if True`, then :meth:`step()` returns `terminated=True` whenever a
                life is lost.
            grayscale_obs (bool): if True, then gray scale observation is returned, otherwise, RGB observation
                is returned.
            grayscale_newaxis (bool): `if True and grayscale_obs=True`, then a channel axis is added to
                grayscale observations to make them 3-dimensional.
            scale_obs (bool): if True, then observation normalized in range [0,1) is returned. It also limits memory
                optimization benefits of FrameStack Wrapper.

        Raises:
            DependencyNotInstalled: opencv-python package not installed
            ValueError: Disable frame-skipping in the original env
        """
        gym.utils.RecordConstructorArgs.__init__(
            self,
            noop_max=noop_max,
            frame_skip=frame_skip,
            screen_size=screen_size,
            terminal_on_life_loss=terminal_on_life_loss,
            grayscale_obs=grayscale_obs,
            grayscale_newaxis=grayscale_newaxis,
            scale_obs=scale_obs,
        )
        gym.Wrapper.__init__(self, env)

        # Fail fast if OpenCV (needed for resizing in _get_obs) is missing.
        try:
            import cv2  # noqa: F401
        except ImportError:
            raise gym.error.DependencyNotInstalled(
                "opencv-python package not installed, run `pip install gymnasium[other]` to get dependencies for atari"
            )

        assert frame_skip > 0
        assert screen_size > 0
        assert noop_max >= 0
        # Refuse to stack this wrapper's frame skip on top of an env-internal one.
        if frame_skip > 1:
            if (
                env.spec is not None
                and "NoFrameskip" not in env.spec.id
                and getattr(env.unwrapped, "_frameskip", None) != 1
            ):
                raise ValueError(
                    "Disable frame-skipping in the original env. Otherwise, more than one "
                    "frame-skip will happen as through this wrapper"
                )
        self.noop_max = noop_max
        # Action 0 must be NOOP for the no-op reset scheme to be meaningful.
        assert env.unwrapped.get_action_meanings()[0] == "NOOP"

        self.frame_skip = frame_skip
        self.screen_size = screen_size
        self.terminal_on_life_loss = terminal_on_life_loss
        self.grayscale_obs = grayscale_obs
        self.grayscale_newaxis = grayscale_newaxis
        self.scale_obs = scale_obs

        # buffer of most recent two observations for max pooling
        assert isinstance(env.observation_space, Box)
        if grayscale_obs:
            self.obs_buffer = [
                np.empty(env.observation_space.shape[:2], dtype=np.uint8),
                np.empty(env.observation_space.shape[:2], dtype=np.uint8),
            ]
        else:
            self.obs_buffer = [
                np.empty(env.observation_space.shape, dtype=np.uint8),
                np.empty(env.observation_space.shape, dtype=np.uint8),
            ]

        self.lives = 0
        self.game_over = False

        # Raw uint8 pixels by default; float32 in [0, 1) when scale_obs is set.
        _low, _high, _obs_dtype = (
            (0, 255, np.uint8) if not scale_obs else (0, 1, np.float32)
        )
        _shape = (screen_size, screen_size, 1 if grayscale_obs else 3)
        if grayscale_obs and not grayscale_newaxis:
            _shape = _shape[:-1]  # Remove channel axis
        self.observation_space = Box(
            low=_low, high=_high, shape=_shape, dtype=_obs_dtype
        )

    @property
    def ale(self):
        """Make ale as a class property to avoid serialization error."""
        return self.env.unwrapped.ale

    def step(self, action):
        """Applies the preprocessing for an :meth:`env.step`.

        Repeats ``action`` for :attr:`frame_skip` frames, summing the rewards,
        and max-pools the final two frames into the returned observation.
        """
        total_reward, terminated, truncated, info = 0.0, False, False, {}

        for t in range(self.frame_skip):
            _, reward, terminated, truncated, info = self.env.step(action)
            total_reward += reward
            self.game_over = terminated

            if self.terminal_on_life_loss:
                new_lives = self.ale.lives()
                terminated = terminated or new_lives < self.lives
                self.game_over = terminated
                self.lives = new_lives

            if terminated or truncated:
                break
            # Grab only the last two frames of the skip window for max pooling.
            if t == self.frame_skip - 2:
                if self.grayscale_obs:
                    self.ale.getScreenGrayscale(self.obs_buffer[1])
                else:
                    self.ale.getScreenRGB(self.obs_buffer[1])
            elif t == self.frame_skip - 1:
                if self.grayscale_obs:
                    self.ale.getScreenGrayscale(self.obs_buffer[0])
                else:
                    self.ale.getScreenRGB(self.obs_buffer[0])
        return self._get_obs(), total_reward, terminated, truncated, info

    def reset(self, **kwargs):
        """Resets the environment using preprocessing.

        Performs a no-op reset (a random number of NOOP actions up to
        :attr:`noop_max`), re-resetting the env if an episode ends mid-noop.
        """
        # NoopReset
        _, reset_info = self.env.reset(**kwargs)

        noops = (
            self.env.unwrapped.np_random.integers(1, self.noop_max + 1)
            if self.noop_max > 0
            else 0
        )
        for _ in range(noops):
            _, _, terminated, truncated, step_info = self.env.step(0)
            reset_info.update(step_info)
            if terminated or truncated:
                _, reset_info = self.env.reset(**kwargs)

        self.lives = self.ale.lives()
        if self.grayscale_obs:
            self.ale.getScreenGrayscale(self.obs_buffer[0])
        else:
            self.ale.getScreenRGB(self.obs_buffer[0])
        # Zero the second buffer so the first max-pool is just the first frame.
        self.obs_buffer[1].fill(0)

        return self._get_obs(), reset_info

    def _get_obs(self):
        """Max-pool the frame buffers, resize, and apply dtype/scale/axis options."""
        if self.frame_skip > 1:  # more efficient in-place pooling
            np.maximum(self.obs_buffer[0], self.obs_buffer[1], out=self.obs_buffer[0])

        import cv2

        obs = cv2.resize(
            self.obs_buffer[0],
            (self.screen_size, self.screen_size),
            interpolation=cv2.INTER_AREA,
        )

        if self.scale_obs:
            obs = np.asarray(obs, dtype=np.float32) / 255.0
        else:
            obs = np.asarray(obs, dtype=np.uint8)

        if self.grayscale_obs and self.grayscale_newaxis:
            obs = np.expand_dims(obs, axis=-1)  # Add a channel axis
        return obs
|
@@ -1,315 +0,0 @@
|
||||
"""A collection of common wrappers.
|
||||
|
||||
* ``AutoresetV0`` - Auto-resets the environment
|
||||
* ``PassiveEnvCheckerV0`` - Passive environment checker that does not modify any environment data
|
||||
* ``OrderEnforcingV0`` - Enforces the order of function calls to environments
|
||||
* ``RecordEpisodeStatisticsV0`` - Records the episode statistics
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from collections import deque
|
||||
from typing import Any, SupportsFloat
|
||||
|
||||
import numpy as np
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.core import ActType, ObsType, RenderFrame
|
||||
from gymnasium.error import ResetNeeded
|
||||
from gymnasium.utils.passive_env_checker import (
|
||||
check_action_space,
|
||||
check_observation_space,
|
||||
env_render_passive_checker,
|
||||
env_reset_passive_checker,
|
||||
env_step_passive_checker,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AutoresetV0",
|
||||
"PassiveEnvCheckerV0",
|
||||
"OrderEnforcingV0",
|
||||
"RecordEpisodeStatisticsV0",
|
||||
]
|
||||
|
||||
|
||||
class AutoresetV0(
|
||||
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
|
||||
):
|
||||
"""A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`."""
|
||||
|
||||
def __init__(self, env: gym.Env[ObsType, ActType]):
|
||||
"""A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`.
|
||||
|
||||
Args:
|
||||
env (gym.Env): The environment to apply the wrapper
|
||||
"""
|
||||
gym.utils.RecordConstructorArgs.__init__(self)
|
||||
gym.Wrapper.__init__(self, env)
|
||||
|
||||
self._episode_ended: bool = False
|
||||
self._reset_options: dict[str, Any] | None = None
|
||||
|
||||
def step(
|
||||
self, action: ActType
|
||||
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
|
||||
"""Steps through the environment with action and resets the environment if a terminated or truncated signal is encountered in the previous step.
|
||||
|
||||
Args:
|
||||
action: The action to take
|
||||
|
||||
Returns:
|
||||
The autoreset environment :meth:`step`
|
||||
"""
|
||||
if self._episode_ended:
|
||||
obs, info = self.env.reset(options=self._reset_options)
|
||||
self._episode_ended = True
|
||||
return obs, 0, False, False, info
|
||||
else:
|
||||
obs, reward, terminated, truncated, info = super().step(action)
|
||||
self._episode_ended = terminated or truncated
|
||||
return obs, reward, terminated, truncated, info
|
||||
|
||||
def reset(
|
||||
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
||||
) -> tuple[ObsType, dict[str, Any]]:
|
||||
"""Resets the environment, saving the options used."""
|
||||
self._episode_ended = False
|
||||
self._reset_options = options
|
||||
return super().reset(seed=seed, options=self._reset_options)
|
||||
|
||||
|
||||
class PassiveEnvCheckerV0(
|
||||
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
|
||||
):
|
||||
"""A passive environment checker wrapper that surrounds the step, reset and render functions to check they follow the gymnasium API."""
|
||||
|
||||
def __init__(self, env: gym.Env[ObsType, ActType]):
|
||||
"""Initialises the wrapper with the environments, run the observation and action space tests."""
|
||||
gym.utils.RecordConstructorArgs.__init__(self)
|
||||
gym.Wrapper.__init__(self, env)
|
||||
|
||||
assert hasattr(
|
||||
env, "action_space"
|
||||
), "The environment must specify an action space. https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/"
|
||||
check_action_space(env.action_space)
|
||||
assert hasattr(
|
||||
env, "observation_space"
|
||||
), "The environment must specify an observation space. https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/"
|
||||
check_observation_space(env.observation_space)
|
||||
|
||||
self._checked_reset: bool = False
|
||||
self._checked_step: bool = False
|
||||
self._checked_render: bool = False
|
||||
|
||||
def step(
|
||||
self, action: ActType
|
||||
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
|
||||
"""Steps through the environment that on the first call will run the `passive_env_step_check`."""
|
||||
if self._checked_step is False:
|
||||
self._checked_step = True
|
||||
return env_step_passive_checker(self.env, action)
|
||||
else:
|
||||
return self.env.step(action)
|
||||
|
||||
def reset(
|
||||
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
||||
) -> tuple[ObsType, dict[str, Any]]:
|
||||
"""Resets the environment that on the first call will run the `passive_env_reset_check`."""
|
||||
if self._checked_reset is False:
|
||||
self._checked_reset = True
|
||||
return env_reset_passive_checker(self.env, seed=seed, options=options)
|
||||
else:
|
||||
return self.env.reset(seed=seed, options=options)
|
||||
|
||||
def render(self) -> RenderFrame | list[RenderFrame] | None:
|
||||
"""Renders the environment that on the first call will run the `passive_env_render_check`."""
|
||||
if self._checked_render is False:
|
||||
self._checked_render = True
|
||||
return env_render_passive_checker(self.env)
|
||||
else:
|
||||
return self.env.render()
|
||||
|
||||
|
||||
class OrderEnforcingV0(
|
||||
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
|
||||
):
|
||||
"""A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`.
|
||||
|
||||
Example:
|
||||
>>> import gymnasium as gym
|
||||
>>> from gymnasium.experimental.wrappers import OrderEnforcingV0
|
||||
>>> env = gym.make("CartPole-v1", render_mode="human")
|
||||
>>> env = OrderEnforcingV0(env)
|
||||
>>> env.step(0)
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
gymnasium.error.ResetNeeded: Cannot call env.step() before calling env.reset()
|
||||
>>> env.render()
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
gymnasium.error.ResetNeeded: Cannot call `env.render()` before calling `env.reset()`, if this is a intended action, set `disable_render_order_enforcing=True` on the OrderEnforcer wrapper.
|
||||
>>> _ = env.reset()
|
||||
>>> env.render()
|
||||
>>> _ = env.step(0)
|
||||
>>> env.close()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: gym.Env[ObsType, ActType],
|
||||
disable_render_order_enforcing: bool = False,
|
||||
):
|
||||
"""A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`.
|
||||
|
||||
Args:
|
||||
env: The environment to wrap
|
||||
disable_render_order_enforcing: If to disable render order enforcing
|
||||
"""
|
||||
gym.utils.RecordConstructorArgs.__init__(
|
||||
self, disable_render_order_enforcing=disable_render_order_enforcing
|
||||
)
|
||||
gym.Wrapper.__init__(self, env)
|
||||
|
||||
self._has_reset: bool = False
|
||||
self._disable_render_order_enforcing: bool = disable_render_order_enforcing
|
||||
|
||||
def step(self, action: ActType) -> tuple[ObsType, SupportsFloat, bool, bool, dict]:
|
||||
"""Steps through the environment."""
|
||||
if not self._has_reset:
|
||||
raise ResetNeeded("Cannot call env.step() before calling env.reset()")
|
||||
return super().step(action)
|
||||
|
||||
def reset(
|
||||
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
||||
) -> tuple[ObsType, dict[str, Any]]:
|
||||
"""Resets the environment with `kwargs`."""
|
||||
self._has_reset = True
|
||||
return super().reset(seed=seed, options=options)
|
||||
|
||||
def render(self) -> RenderFrame | list[RenderFrame] | None:
|
||||
"""Renders the environment with `kwargs`."""
|
||||
if not self._disable_render_order_enforcing and not self._has_reset:
|
||||
raise ResetNeeded(
|
||||
"Cannot call `env.render()` before calling `env.reset()`, if this is a intended action, "
|
||||
"set `disable_render_order_enforcing=True` on the OrderEnforcer wrapper."
|
||||
)
|
||||
return super().render()
|
||||
|
||||
@property
|
||||
def has_reset(self):
|
||||
"""Returns if the environment has been reset before."""
|
||||
return self._has_reset
|
||||
|
||||
|
||||
class RecordEpisodeStatisticsV0(
|
||||
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
|
||||
):
|
||||
"""This wrapper will keep track of cumulative rewards and episode lengths.
|
||||
|
||||
At the end of an episode, the statistics of the episode will be added to ``info``
|
||||
using the key ``episode``. If using a vectorized environment also the key
|
||||
``_episode`` is used which indicates whether the env at the respective index has
|
||||
the episode statistics.
|
||||
|
||||
After the completion of an episode, ``info`` will look like this::
|
||||
|
||||
>>> info = {
|
||||
... "episode": {
|
||||
... "r": "<cumulative reward>",
|
||||
... "l": "<episode length>",
|
||||
... "t": "<elapsed time since beginning of episode>"
|
||||
... },
|
||||
... }
|
||||
|
||||
For a vectorized environments the output will be in the form of::
|
||||
|
||||
>>> infos = {
|
||||
... "final_observation": "<array of length num-envs>",
|
||||
... "_final_observation": "<boolean array of length num-envs>",
|
||||
... "final_info": "<array of length num-envs>",
|
||||
... "_final_info": "<boolean array of length num-envs>",
|
||||
... "episode": {
|
||||
... "r": "<array of cumulative reward>",
|
||||
... "l": "<array of episode length>",
|
||||
... "t": "<array of elapsed time since beginning of episode>"
|
||||
... },
|
||||
... "_episode": "<boolean array of length num-envs>"
|
||||
... }
|
||||
|
||||
|
||||
Moreover, the most recent rewards and episode lengths are stored in buffers that can be accessed via
|
||||
:attr:`wrapped_env.return_queue` and :attr:`wrapped_env.length_queue` respectively.
|
||||
|
||||
Attributes:
|
||||
episode_reward_buffer: The cumulative rewards of the last ``deque_size``-many episodes
|
||||
episode_length_buffer: The lengths of the last ``deque_size``-many episodes
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: gym.Env[ObsType, ActType],
|
||||
buffer_length: int | None = 100,
|
||||
stats_key: str = "episode",
|
||||
):
|
||||
"""This wrapper will keep track of cumulative rewards and episode lengths.
|
||||
|
||||
Args:
|
||||
env (Env): The environment to apply the wrapper
|
||||
buffer_length: The size of the buffers :attr:`return_queue` and :attr:`length_queue`
|
||||
stats_key: The info key for the episode statistics
|
||||
"""
|
||||
gym.utils.RecordConstructorArgs.__init__(self)
|
||||
gym.Wrapper.__init__(self, env)
|
||||
|
||||
self._stats_key = stats_key
|
||||
|
||||
self.episode_count = 0
|
||||
self.episode_start_time: float = -1
|
||||
self.episode_reward: float = -1
|
||||
self.episode_length: int = -1
|
||||
|
||||
self.episode_time_length_buffer: deque[int] = deque(maxlen=buffer_length)
|
||||
self.episode_reward_buffer: deque[float] = deque(maxlen=buffer_length)
|
||||
self.episode_length_buffer: deque[int] = deque(maxlen=buffer_length)
|
||||
|
||||
def step(
|
||||
self, action: ActType
|
||||
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
|
||||
"""Steps through the environment, recording the episode statistics."""
|
||||
obs, reward, terminated, truncated, info = super().step(action)
|
||||
|
||||
self.episode_reward += reward
|
||||
self.episode_length += 1
|
||||
|
||||
if terminated or truncated:
|
||||
assert self._stats_key not in info
|
||||
|
||||
episode_time_length = np.round(
|
||||
time.perf_counter() - self.episode_start_time, 6
|
||||
)
|
||||
info[self._stats_key] = {
|
||||
"r": self.episode_reward,
|
||||
"l": self.episode_length,
|
||||
"t": episode_time_length,
|
||||
}
|
||||
|
||||
self.episode_time_length_buffer.append(episode_time_length)
|
||||
self.episode_reward_buffer.append(self.episode_reward)
|
||||
self.episode_length_buffer.append(self.episode_length)
|
||||
|
||||
self.episode_count += 1
|
||||
|
||||
return obs, reward, terminated, truncated, info
|
||||
|
||||
def reset(
|
||||
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
||||
) -> tuple[ObsType, dict[str, Any]]:
|
||||
"""Resets the environment using seed and options and resets the episode rewards and lengths."""
|
||||
obs, info = super().reset(seed=seed, options=options)
|
||||
|
||||
self.episode_start_time = time.perf_counter()
|
||||
self.episode_reward = 0
|
||||
self.episode_length = 0
|
||||
|
||||
return obs, info
|
@@ -1,620 +0,0 @@
|
||||
"""A collection of observation wrappers using a lambda function.
|
||||
|
||||
* ``LambdaObservationV0`` - Transforms the observation with a function
|
||||
* ``FilterObservationV0`` - Filters a ``Tuple`` or ``Dict`` to only include certain keys
|
||||
* ``FlattenObservationV0`` - Flattens the observations
|
||||
* ``GrayscaleObservationV0`` - Converts a RGB observation to a grayscale observation
|
||||
* ``ResizeObservationV0`` - Resizes an array-based observation (normally a RGB observation)
|
||||
* ``ReshapeObservationV0`` - Reshapes an array-based observation
|
||||
* ``RescaleObservationV0`` - Rescales an observation to between a minimum and maximum value
|
||||
* ``DtypeObservationV0`` - Convert an observation to a dtype
|
||||
* ``PixelObservationV0`` - Allows the observation to the rendered frame
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, Final, Sequence
|
||||
|
||||
import numpy as np
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ActType, ObsType, WrapperObsType
|
||||
from gymnasium.error import DependencyNotInstalled
|
||||
|
||||
|
||||
__all__ = [
|
||||
"LambdaObservationV0",
|
||||
"FilterObservationV0",
|
||||
"FlattenObservationV0",
|
||||
"GrayscaleObservationV0",
|
||||
"ResizeObservationV0",
|
||||
"ReshapeObservationV0",
|
||||
"RescaleObservationV0",
|
||||
"DtypeObservationV0",
|
||||
"PixelObservationV0",
|
||||
]
|
||||
|
||||
|
||||
class LambdaObservationV0(
|
||||
gym.ObservationWrapper[WrapperObsType, ActType, ObsType],
|
||||
gym.utils.RecordConstructorArgs,
|
||||
):
|
||||
"""Transforms an observation via a function provided to the wrapper.
|
||||
|
||||
The function :attr:`func` will be applied to all observations.
|
||||
If the observations from :attr:`func` are outside the bounds of the ``env``'s observation space, provide an :attr:`observation_space`.
|
||||
|
||||
Example:
|
||||
>>> import gymnasium as gym
|
||||
>>> from gymnasium.experimental.wrappers import LambdaObservationV0
|
||||
>>> import numpy as np
|
||||
>>> np.random.seed(0)
|
||||
>>> env = gym.make("CartPole-v1")
|
||||
>>> env = LambdaObservationV0(env, lambda obs: obs + 0.1 * np.random.random(obs.shape), env.observation_space)
|
||||
>>> env.reset(seed=42)
|
||||
(array([0.08227695, 0.06540678, 0.09613613, 0.07422512]), {})
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: gym.Env[ObsType, ActType],
|
||||
func: Callable[[ObsType], Any],
|
||||
observation_space: gym.Space[WrapperObsType] | None,
|
||||
):
|
||||
"""Constructor for the lambda observation wrapper.
|
||||
|
||||
Args:
|
||||
env: The environment to wrap
|
||||
func: A function that will transform an observation. If this transformed observation is outside the observation space of ``env.observation_space`` then provide an `observation_space`.
|
||||
observation_space: The observation spaces of the wrapper, if None, then it is assumed the same as ``env.observation_space``.
|
||||
"""
|
||||
gym.utils.RecordConstructorArgs.__init__(
|
||||
self, func=func, observation_space=observation_space
|
||||
)
|
||||
gym.ObservationWrapper.__init__(self, env)
|
||||
|
||||
if observation_space is not None:
|
||||
self.observation_space = observation_space
|
||||
|
||||
self.func = func
|
||||
|
||||
def observation(self, observation: ObsType) -> Any:
|
||||
"""Apply function to the observation."""
|
||||
return self.func(observation)
|
||||
|
||||
|
||||
class FilterObservationV0(
|
||||
LambdaObservationV0[WrapperObsType, ActType, ObsType],
|
||||
gym.utils.RecordConstructorArgs,
|
||||
):
|
||||
"""Filters Dict or Tuple observation space by the keys or indexes.
|
||||
|
||||
Example:
|
||||
>>> import gymnasium as gym
|
||||
>>> from gymnasium.wrappers import TransformObservation
|
||||
>>> from gymnasium.experimental.wrappers import FilterObservationV0
|
||||
>>> env = gym.make("CartPole-v1")
|
||||
>>> env = gym.wrappers.TransformObservation(env, lambda obs: {'obs': obs, 'time': 0})
|
||||
>>> env.observation_space = gym.spaces.Dict(obs=env.observation_space, time=gym.spaces.Discrete(1))
|
||||
>>> env.reset(seed=42)
|
||||
({'obs': array([ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 ], dtype=float32), 'time': 0}, {})
|
||||
>>> env = FilterObservationV0(env, filter_keys=['time'])
|
||||
>>> env.reset(seed=42)
|
||||
({'time': 0}, {})
|
||||
>>> env.step(0)
|
||||
({'time': 0}, 1.0, False, False, {})
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, env: gym.Env[ObsType, ActType], filter_keys: Sequence[str | int]
|
||||
):
|
||||
"""Constructor for the filter observation wrapper.
|
||||
|
||||
Args:
|
||||
env: The environment to wrap
|
||||
filter_keys: The subspaces to be included, use a list of strings or integers for ``Dict`` and ``Tuple`` spaces respectivesly
|
||||
"""
|
||||
assert isinstance(filter_keys, Sequence)
|
||||
gym.utils.RecordConstructorArgs.__init__(self, filter_keys=filter_keys)
|
||||
|
||||
# Filters for dictionary space
|
||||
if isinstance(env.observation_space, spaces.Dict):
|
||||
assert all(isinstance(key, str) for key in filter_keys)
|
||||
|
||||
if any(
|
||||
key not in env.observation_space.spaces.keys() for key in filter_keys
|
||||
):
|
||||
missing_keys = [
|
||||
key
|
||||
for key in filter_keys
|
||||
if key not in env.observation_space.spaces.keys()
|
||||
]
|
||||
raise ValueError(
|
||||
"All the `filter_keys` must be included in the observation space.\n"
|
||||
f"Filter keys: {filter_keys}\n"
|
||||
f"Observation keys: {list(env.observation_space.spaces.keys())}\n"
|
||||
f"Missing keys: {missing_keys}"
|
||||
)
|
||||
|
||||
new_observation_space = spaces.Dict(
|
||||
{key: env.observation_space[key] for key in filter_keys}
|
||||
)
|
||||
if len(new_observation_space) == 0:
|
||||
raise ValueError(
|
||||
"The observation space is empty due to filtering all keys."
|
||||
)
|
||||
|
||||
LambdaObservationV0.__init__(
|
||||
self,
|
||||
env=env,
|
||||
func=lambda obs: {key: obs[key] for key in filter_keys},
|
||||
observation_space=new_observation_space,
|
||||
)
|
||||
# Filter for tuple observation
|
||||
elif isinstance(env.observation_space, spaces.Tuple):
|
||||
assert all(isinstance(key, int) for key in filter_keys)
|
||||
assert len(set(filter_keys)) == len(
|
||||
filter_keys
|
||||
), f"Duplicate keys exist, filter_keys: {filter_keys}"
|
||||
|
||||
if any(
|
||||
0 < key and key >= len(env.observation_space) for key in filter_keys
|
||||
):
|
||||
missing_index = [
|
||||
key
|
||||
for key in filter_keys
|
||||
if 0 < key and key >= len(env.observation_space)
|
||||
]
|
||||
raise ValueError(
|
||||
"All the `filter_keys` must be included in the length of the observation space.\n"
|
||||
f"Filter keys: {filter_keys}, length of observation: {len(env.observation_space)}, "
|
||||
f"missing indexes: {missing_index}"
|
||||
)
|
||||
|
||||
new_observation_spaces = spaces.Tuple(
|
||||
env.observation_space[key] for key in filter_keys
|
||||
)
|
||||
if len(new_observation_spaces) == 0:
|
||||
raise ValueError(
|
||||
"The observation space is empty due to filtering all keys."
|
||||
)
|
||||
|
||||
LambdaObservationV0.__init__(
|
||||
self,
|
||||
env=env,
|
||||
func=lambda obs: tuple(obs[key] for key in filter_keys),
|
||||
observation_space=new_observation_spaces,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"FilterObservation wrapper is only usable with `Dict` and `Tuple` observations, actual type: {type(env.observation_space)}"
|
||||
)
|
||||
|
||||
self.filter_keys: Final[Sequence[str | int]] = filter_keys
|
||||
|
||||
|
||||
class FlattenObservationV0(
|
||||
LambdaObservationV0[WrapperObsType, ActType, ObsType],
|
||||
gym.utils.RecordConstructorArgs,
|
||||
):
|
||||
"""Observation wrapper that flattens the observation.
|
||||
|
||||
Example:
|
||||
>>> import gymnasium as gym
|
||||
>>> from gymnasium.experimental.wrappers import FlattenObservationV0
|
||||
>>> env = gym.make("CarRacing-v2")
|
||||
>>> env.observation_space.shape
|
||||
(96, 96, 3)
|
||||
>>> env = FlattenObservationV0(env)
|
||||
>>> env.observation_space.shape
|
||||
(27648,)
|
||||
>>> obs, _ = env.reset()
|
||||
>>> obs.shape
|
||||
(27648,)
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env[ObsType, ActType]):
|
||||
"""Constructor for any environment's observation space that implements ``spaces.utils.flatten_space`` and ``spaces.utils.flatten``.
|
||||
|
||||
Args:
|
||||
env: The environment to wrap
|
||||
"""
|
||||
gym.utils.RecordConstructorArgs.__init__(self)
|
||||
LambdaObservationV0.__init__(
|
||||
self,
|
||||
env=env,
|
||||
func=lambda obs: spaces.utils.flatten(env.observation_space, obs),
|
||||
observation_space=spaces.utils.flatten_space(env.observation_space),
|
||||
)
|
||||
|
||||
|
||||
class GrayscaleObservationV0(
|
||||
LambdaObservationV0[WrapperObsType, ActType, ObsType],
|
||||
gym.utils.RecordConstructorArgs,
|
||||
):
|
||||
"""Observation wrapper that converts an RGB image to grayscale.
|
||||
|
||||
The :attr:`keep_dim` will keep the channel dimension
|
||||
|
||||
Example:
|
||||
>>> import gymnasium as gym
|
||||
>>> from gymnasium.experimental.wrappers import GrayscaleObservationV0
|
||||
>>> env = gym.make("CarRacing-v2")
|
||||
>>> env.observation_space.shape
|
||||
(96, 96, 3)
|
||||
>>> grayscale_env = GrayscaleObservationV0(env)
|
||||
>>> grayscale_env.observation_space.shape
|
||||
(96, 96)
|
||||
>>> grayscale_env = GrayscaleObservationV0(env, keep_dim=True)
|
||||
>>> grayscale_env.observation_space.shape
|
||||
(96, 96, 1)
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env[ObsType, ActType], keep_dim: bool = False):
|
||||
"""Constructor for an RGB image based environments to make the image grayscale.
|
||||
|
||||
Args:
|
||||
env: The environment to wrap
|
||||
keep_dim: If to keep the channel in the observation, if ``True``, ``obs.shape == 3`` else ``obs.shape == 2``
|
||||
"""
|
||||
assert isinstance(env.observation_space, spaces.Box)
|
||||
assert (
|
||||
len(env.observation_space.shape) == 3
|
||||
and env.observation_space.shape[-1] == 3
|
||||
)
|
||||
assert (
|
||||
np.all(env.observation_space.low == 0)
|
||||
and np.all(env.observation_space.high == 255)
|
||||
and env.observation_space.dtype == np.uint8
|
||||
)
|
||||
gym.utils.RecordConstructorArgs.__init__(self, keep_dim=keep_dim)
|
||||
|
||||
self.keep_dim: Final[bool] = keep_dim
|
||||
if keep_dim:
|
||||
new_observation_space = spaces.Box(
|
||||
low=0,
|
||||
high=255,
|
||||
shape=env.observation_space.shape[:2] + (1,),
|
||||
dtype=np.uint8,
|
||||
)
|
||||
LambdaObservationV0.__init__(
|
||||
self,
|
||||
env=env,
|
||||
func=lambda obs: np.expand_dims(
|
||||
np.sum(
|
||||
np.multiply(obs, np.array([0.2125, 0.7154, 0.0721])), axis=-1
|
||||
).astype(np.uint8),
|
||||
axis=-1,
|
||||
),
|
||||
observation_space=new_observation_space,
|
||||
)
|
||||
else:
|
||||
new_observation_space = spaces.Box(
|
||||
low=0, high=255, shape=env.observation_space.shape[:2], dtype=np.uint8
|
||||
)
|
||||
LambdaObservationV0.__init__(
|
||||
self,
|
||||
env=env,
|
||||
func=lambda obs: np.sum(
|
||||
np.multiply(obs, np.array([0.2125, 0.7154, 0.0721])), axis=-1
|
||||
).astype(np.uint8),
|
||||
observation_space=new_observation_space,
|
||||
)
|
||||
|
||||
|
||||
class ResizeObservationV0(
|
||||
LambdaObservationV0[WrapperObsType, ActType, ObsType],
|
||||
gym.utils.RecordConstructorArgs,
|
||||
):
|
||||
"""Resizes image observations using OpenCV to shape.
|
||||
|
||||
Example:
|
||||
>>> import gymnasium as gym
|
||||
>>> from gymnasium.experimental.wrappers import ResizeObservationV0
|
||||
>>> env = gym.make("CarRacing-v2")
|
||||
>>> env.observation_space.shape
|
||||
(96, 96, 3)
|
||||
>>> resized_env = ResizeObservationV0(env, (32, 32))
|
||||
>>> resized_env.observation_space.shape
|
||||
(32, 32, 3)
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env[ObsType, ActType], shape: tuple[int, ...]):
|
||||
"""Constructor that requires an image environment observation space with a shape.
|
||||
|
||||
Args:
|
||||
env: The environment to wrap
|
||||
shape: The resized observation shape
|
||||
"""
|
||||
assert isinstance(env.observation_space, spaces.Box)
|
||||
assert len(env.observation_space.shape) in [2, 3]
|
||||
assert np.all(env.observation_space.low == 0) and np.all(
|
||||
env.observation_space.high == 255
|
||||
)
|
||||
assert env.observation_space.dtype == np.uint8
|
||||
|
||||
assert isinstance(shape, tuple)
|
||||
assert all(np.issubdtype(type(elem), np.integer) for elem in shape)
|
||||
assert all(x > 0 for x in shape)
|
||||
|
||||
try:
|
||||
import cv2
|
||||
except ImportError as e:
|
||||
raise DependencyNotInstalled(
|
||||
"opencv (cv2) is not installed, run `pip install gymnasium[other]`"
|
||||
) from e
|
||||
|
||||
self.shape: Final[tuple[int, ...]] = tuple(shape)
|
||||
|
||||
new_observation_space = spaces.Box(
|
||||
low=0,
|
||||
high=255,
|
||||
shape=self.shape + env.observation_space.shape[2:],
|
||||
dtype=np.uint8,
|
||||
)
|
||||
|
||||
gym.utils.RecordConstructorArgs.__init__(self, shape=shape)
|
||||
LambdaObservationV0.__init__(
|
||||
self,
|
||||
env=env,
|
||||
func=lambda obs: cv2.resize(obs, self.shape, interpolation=cv2.INTER_AREA),
|
||||
observation_space=new_observation_space,
|
||||
)
|
||||
|
||||
|
||||
class ReshapeObservationV0(
|
||||
LambdaObservationV0[WrapperObsType, ActType, ObsType],
|
||||
gym.utils.RecordConstructorArgs,
|
||||
):
|
||||
"""Reshapes array based observations to shapes.
|
||||
|
||||
Example:
|
||||
>>> import gymnasium as gym
|
||||
>>> from gymnasium.experimental.wrappers import ReshapeObservationV0
|
||||
>>> env = gym.make("CarRacing-v2")
|
||||
>>> env.observation_space.shape
|
||||
(96, 96, 3)
|
||||
>>> reshape_env = ReshapeObservationV0(env, (24, 4, 96, 1, 3))
|
||||
>>> reshape_env.observation_space.shape
|
||||
(24, 4, 96, 1, 3)
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env[ObsType, ActType], shape: int | tuple[int, ...]):
|
||||
"""Constructor for env with ``Box`` observation space that has a shape product equal to the new shape product.
|
||||
|
||||
Args:
|
||||
env: The environment to wrap
|
||||
shape: The reshaped observation space
|
||||
"""
|
||||
assert isinstance(env.observation_space, spaces.Box)
|
||||
assert np.product(shape) == np.product(env.observation_space.shape)
|
||||
|
||||
assert isinstance(shape, tuple)
|
||||
assert all(np.issubdtype(type(elem), np.integer) for elem in shape)
|
||||
assert all(x > 0 or x == -1 for x in shape)
|
||||
|
||||
new_observation_space = spaces.Box(
|
||||
low=np.reshape(np.ravel(env.observation_space.low), shape),
|
||||
high=np.reshape(np.ravel(env.observation_space.high), shape),
|
||||
shape=shape,
|
||||
dtype=env.observation_space.dtype,
|
||||
)
|
||||
self.shape = shape
|
||||
|
||||
gym.utils.RecordConstructorArgs.__init__(self, shape=shape)
|
||||
LambdaObservationV0.__init__(
|
||||
self,
|
||||
env=env,
|
||||
func=lambda obs: np.reshape(obs, shape),
|
||||
observation_space=new_observation_space,
|
||||
)
|
||||
|
||||
|
||||
class RescaleObservationV0(
|
||||
LambdaObservationV0[WrapperObsType, ActType, ObsType],
|
||||
gym.utils.RecordConstructorArgs,
|
||||
):
|
||||
"""Linearly rescales observation to between a minimum and maximum value.
|
||||
|
||||
Example:
|
||||
>>> import gymnasium as gym
|
||||
>>> from gymnasium.experimental.wrappers import RescaleObservationV0
|
||||
>>> env = gym.make("Pendulum-v1")
|
||||
>>> env.observation_space
|
||||
Box([-1. -1. -8.], [1. 1. 8.], (3,), float32)
|
||||
>>> env = RescaleObservationV0(env, np.array([-2, -1, -10], dtype=np.float32), np.array([1, 0, 1], dtype=np.float32))
|
||||
>>> env.observation_space
|
||||
Box([ -2. -1. -10.], [1. 0. 1.], (3,), float32)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: gym.Env[ObsType, ActType],
|
||||
min_obs: np.floating | np.integer | np.ndarray,
|
||||
max_obs: np.floating | np.integer | np.ndarray,
|
||||
):
|
||||
"""Constructor that requires the env observation spaces to be a :class:`Box`.
|
||||
|
||||
Args:
|
||||
env: The environment to wrap
|
||||
min_obs: The new minimum observation bound
|
||||
max_obs: The new maximum observation bound
|
||||
"""
|
||||
assert isinstance(env.observation_space, spaces.Box)
|
||||
assert not np.any(env.observation_space.low == np.inf) and not np.any(
|
||||
env.observation_space.high == np.inf
|
||||
)
|
||||
|
||||
if not isinstance(min_obs, np.ndarray):
|
||||
assert np.issubdtype(type(min_obs), np.integer) or np.issubdtype(
|
||||
type(max_obs), np.floating
|
||||
)
|
||||
min_obs = np.full(env.observation_space.shape, min_obs)
|
||||
assert (
|
||||
min_obs.shape == env.observation_space.shape
|
||||
), f"{min_obs.shape}, {env.observation_space.shape}, {min_obs}, {env.observation_space.low}"
|
||||
assert not np.any(min_obs == np.inf)
|
||||
|
||||
if not isinstance(max_obs, np.ndarray):
|
||||
assert np.issubdtype(type(max_obs), np.integer) or np.issubdtype(
|
||||
type(max_obs), np.floating
|
||||
)
|
||||
max_obs = np.full(env.observation_space.shape, max_obs)
|
||||
assert max_obs.shape == env.observation_space.shape
|
||||
assert not np.any(max_obs == np.inf)
|
||||
|
||||
self.min_obs = min_obs
|
||||
self.max_obs = max_obs
|
||||
|
||||
# Imagine the x-axis between the old Box and the y-axis being the new Box
|
||||
gradient = (max_obs - min_obs) / (
|
||||
env.observation_space.high - env.observation_space.low
|
||||
)
|
||||
intercept = gradient * -env.observation_space.low + min_obs
|
||||
|
||||
gym.utils.RecordConstructorArgs.__init__(self, min_obs=min_obs, max_obs=max_obs)
|
||||
LambdaObservationV0.__init__(
|
||||
self,
|
||||
env=env,
|
||||
func=lambda obs: gradient * obs + intercept,
|
||||
observation_space=spaces.Box(
|
||||
low=min_obs,
|
||||
high=max_obs,
|
||||
shape=env.observation_space.shape,
|
||||
dtype=env.observation_space.dtype,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class DtypeObservationV0(
|
||||
LambdaObservationV0[WrapperObsType, ActType, ObsType],
|
||||
gym.utils.RecordConstructorArgs,
|
||||
):
|
||||
"""Observation wrapper for transforming the dtype of an observation.
|
||||
|
||||
Note:
|
||||
This is only compatible with :class:`Box`, :class:`Discrete`, :class:`MultiDiscrete` and :class:`MultiBinary` observation spaces
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env[ObsType, ActType], dtype: Any):
|
||||
"""Constructor for Dtype observation wrapper.
|
||||
|
||||
Args:
|
||||
env: The environment to wrap
|
||||
dtype: The new dtype of the observation
|
||||
"""
|
||||
assert isinstance(
|
||||
env.observation_space,
|
||||
(spaces.Box, spaces.Discrete, spaces.MultiDiscrete, spaces.MultiBinary),
|
||||
)
|
||||
|
||||
self.dtype = dtype
|
||||
if isinstance(env.observation_space, spaces.Box):
|
||||
new_observation_space = spaces.Box(
|
||||
low=env.observation_space.low,
|
||||
high=env.observation_space.high,
|
||||
shape=env.observation_space.shape,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
elif isinstance(env.observation_space, spaces.Discrete):
|
||||
new_observation_space = spaces.Box(
|
||||
low=env.observation_space.start,
|
||||
high=env.observation_space.start + env.observation_space.n,
|
||||
shape=(),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
elif isinstance(env.observation_space, spaces.MultiDiscrete):
|
||||
new_observation_space = spaces.MultiDiscrete(
|
||||
env.observation_space.nvec, dtype=dtype
|
||||
)
|
||||
elif isinstance(env.observation_space, spaces.MultiBinary):
|
||||
new_observation_space = spaces.Box(
|
||||
low=0,
|
||||
high=1,
|
||||
shape=env.observation_space.shape,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
else:
|
||||
raise TypeError(
|
||||
"DtypeObservation is only compatible with value / array-based observations."
|
||||
)
|
||||
|
||||
gym.utils.RecordConstructorArgs.__init__(self, dtype=dtype)
|
||||
LambdaObservationV0.__init__(
|
||||
self,
|
||||
env=env,
|
||||
func=lambda obs: dtype(obs),
|
||||
observation_space=new_observation_space,
|
||||
)
|
||||
|
||||
|
||||
class PixelObservationV0(
    LambdaObservationV0[WrapperObsType, ActType, ObsType],
    gym.utils.RecordConstructorArgs,
):
    """Includes the rendered observations to the environment's observations.

    Observations of this wrapper will be dictionaries of images.
    You can also choose to add the observation of the base environment to this dictionary.
    In that case, if the base environment has an observation space of type :class:`Dict`, the dictionary
    of rendered images will be updated with the base environment's observation. If, however, the observation
    space is of type :class:`Box`, the base environment's observation (which will be an element of the :class:`Box`
    space) will be added to the dictionary under the key "state".
    """

    def __init__(
        self,
        env: gym.Env[ObsType, ActType],
        pixels_only: bool = True,
        pixels_key: str = "pixels",
        obs_key: str = "state",
    ):
        """Constructor of the pixel observation wrapper.

        Args:
            env: The environment to wrap.
            pixels_only (bool): If ``True`` (default), the original observation returned
                by the wrapped environment will be discarded, and a dictionary
                observation will only include pixels. If ``False``, the
                observation dictionary will contain both the original
                observations and the pixel observations.
            pixels_key: Optional custom string specifying the pixel key. Defaults to "pixels"
            obs_key: Optional custom string specifying the obs key. Defaults to "state"
        """
        gym.utils.RecordConstructorArgs.__init__(
            self, pixels_only=pixels_only, pixels_key=pixels_key, obs_key=obs_key
        )

        # The wrapper relies on `render()` returning an ndarray, so the env must
        # have a non-"human" render mode set.
        assert env.render_mode is not None and env.render_mode != "human"
        # Reset once so a frame can be rendered to discover the pixel shape.
        env.reset()
        pixels = env.render()
        assert pixels is not None and isinstance(pixels, np.ndarray)
        pixel_space = spaces.Box(low=0, high=255, shape=pixels.shape, dtype=np.uint8)

        if pixels_only:
            # The observation is replaced entirely by the rendered frame.
            obs_space = pixel_space
            LambdaObservationV0.__init__(
                self, env=env, func=lambda _: self.render(), observation_space=obs_space
            )
        elif isinstance(env.observation_space, spaces.Dict):
            # The pixel key must not collide with an existing sub-space key.
            assert pixels_key not in env.observation_space.spaces.keys()

            obs_space = spaces.Dict(
                {pixels_key: pixel_space, **env.observation_space.spaces}
            )
            # Bug fix: merge the actual observation (``**obs``) into the returned
            # dict — previously ``**obs_space`` spread the *space* object, so the
            # emitted observations contained sub-spaces instead of values.
            LambdaObservationV0.__init__(
                self,
                env=env,
                func=lambda obs: {pixels_key: self.render(), **obs},
                observation_space=obs_space,
            )
        else:
            # Non-dict observation: nest it under ``obs_key`` next to the pixels.
            obs_space = spaces.Dict(
                {obs_key: env.observation_space, pixels_key: pixel_space}
            )
            LambdaObservationV0.__init__(
                self,
                env=env,
                func=lambda obs: {obs_key: obs, pixels_key: self.render()},
                observation_space=obs_space,
            )
|
@@ -1,102 +0,0 @@
|
||||
"""A collection of wrappers for modifying the reward.
|
||||
|
||||
* ``LambdaRewardV0`` - Transforms the reward by a function
|
||||
* ``ClipRewardV0`` - Clips the reward between a minimum and maximum value
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable, SupportsFloat
|
||||
|
||||
import numpy as np
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from gymnasium.error import InvalidBound
|
||||
|
||||
|
||||
__all__ = ["LambdaRewardV0", "ClipRewardV0"]
|
||||
|
||||
|
||||
class LambdaRewardV0(
    gym.RewardWrapper[ObsType, ActType], gym.utils.RecordConstructorArgs
):
    """Applies an arbitrary user-supplied function to every step reward.

    Example:
        >>> import gymnasium as gym
        >>> from gymnasium.experimental.wrappers import LambdaRewardV0
        >>> env = gym.make("CartPole-v1")
        >>> env = LambdaRewardV0(env, lambda r: 2 * r + 1)
        >>> _ = env.reset()
        >>> _, rew, _, _, _ = env.step(0)
        >>> rew
        3.0
    """

    def __init__(
        self,
        env: gym.Env[ObsType, ActType],
        func: Callable[[SupportsFloat], SupportsFloat],
    ):
        """Initialize LambdaRewardV0 wrapper.

        Args:
            env (Env): The environment to wrap
            func: (Callable): The function applied to each step reward
        """
        gym.utils.RecordConstructorArgs.__init__(self, func=func)
        gym.RewardWrapper.__init__(self, env)
        self.func = func

    def reward(self, reward: SupportsFloat) -> SupportsFloat:
        """Return ``reward`` transformed by the stored callable.

        Args:
            reward (Union[float, int, np.ndarray]): environment's reward
        """
        transformed = self.func(reward)
        return transformed
|
||||
|
||||
|
||||
class ClipRewardV0(LambdaRewardV0[ObsType, ActType], gym.utils.RecordConstructorArgs):
    """Clips each step reward into ``[min_reward, max_reward]``.

    Example:
        >>> import gymnasium as gym
        >>> from gymnasium.experimental.wrappers import ClipRewardV0
        >>> env = gym.make("CartPole-v1")
        >>> env = ClipRewardV0(env, 0, 0.5)
        >>> _ = env.reset()
        >>> _, rew, _, _, _ = env.step(1)
        >>> rew
        0.5
    """

    def __init__(
        self,
        env: gym.Env[ObsType, ActType],
        min_reward: float | np.ndarray | None = None,
        max_reward: float | np.ndarray | None = None,
    ):
        """Initialize ClipRewardsV0 wrapper.

        Args:
            env (Env): The environment to wrap
            min_reward (Union[float, np.ndarray]): lower bound to apply
            max_reward (Union[float, np.ndarray]): higher bound to apply

        Raises:
            InvalidBound: If both bounds are ``None`` or the interval is empty.
        """
        # Guard clauses: at least one bound must be supplied, and when both are
        # supplied the interval must be non-empty (element-wise for arrays).
        if min_reward is None and max_reward is None:
            raise InvalidBound("Both `min_reward` and `max_reward` cannot be None")
        if (
            min_reward is not None
            and max_reward is not None
            and np.any(max_reward - min_reward < 0)
        ):
            raise InvalidBound(
                f"Min reward ({min_reward}) must be smaller than max reward ({max_reward})"
            )

        gym.utils.RecordConstructorArgs.__init__(
            self, min_reward=min_reward, max_reward=max_reward
        )
        LambdaRewardV0.__init__(
            self, env=env, func=lambda x: np.clip(x, a_min=min_reward, a_max=max_reward)
        )
|
@@ -1,146 +0,0 @@
|
||||
"""Wrappers for vector environments."""
|
||||
# pyright: reportUnsupportedDunderAll=false
|
||||
import importlib
|
||||
import re
|
||||
|
||||
from gymnasium.error import DeprecatedWrapper
|
||||
from gymnasium.experimental.wrappers.vector.dict_info_to_list import DictInfoToListV0
|
||||
from gymnasium.experimental.wrappers.vector.record_episode_statistics import (
|
||||
RecordEpisodeStatisticsV0,
|
||||
)
|
||||
from gymnasium.experimental.wrappers.vector.vectorize_action import (
|
||||
ClipActionV0,
|
||||
LambdaActionV0,
|
||||
RescaleActionV0,
|
||||
VectorizeLambdaActionV0,
|
||||
)
|
||||
from gymnasium.experimental.wrappers.vector.vectorize_observation import (
|
||||
DtypeObservationV0,
|
||||
FilterObservationV0,
|
||||
FlattenObservationV0,
|
||||
GrayscaleObservationV0,
|
||||
LambdaObservationV0,
|
||||
RescaleObservationV0,
|
||||
ReshapeObservationV0,
|
||||
ResizeObservationV0,
|
||||
VectorizeLambdaObservationV0,
|
||||
)
|
||||
from gymnasium.experimental.wrappers.vector.vectorize_reward import (
|
||||
ClipRewardV0,
|
||||
LambdaRewardV0,
|
||||
VectorizeLambdaRewardV0,
|
||||
)
|
||||
|
||||
|
||||
# Public API of this module; commented-out names are single-env wrappers that
# do not yet have a vector implementation.
__all__ = [
    # --- Vector only wrappers
    "VectorizeLambdaObservationV0",
    "VectorizeLambdaActionV0",
    "VectorizeLambdaRewardV0",
    "DictInfoToListV0",
    # --- Observation wrappers ---
    "LambdaObservationV0",
    "FilterObservationV0",
    "FlattenObservationV0",
    "GrayscaleObservationV0",
    "ResizeObservationV0",
    "ReshapeObservationV0",
    "RescaleObservationV0",
    "DtypeObservationV0",
    # "PixelObservationV0",
    # "NormalizeObservationV0",
    # "TimeAwareObservationV0",
    # "FrameStackObservationV0",
    # "DelayObservationV0",
    # --- Action Wrappers ---
    "LambdaActionV0",
    "ClipActionV0",
    "RescaleActionV0",
    # --- Reward wrappers ---
    "LambdaRewardV0",
    "ClipRewardV0",
    # "NormalizeRewardV1",
    # --- Common ---
    "RecordEpisodeStatisticsV0",
    # --- Rendering ---
    # "RenderCollectionV0",
    # "RecordVideoV0",
    # "HumanRenderingV0",
    # --- Conversion ---
    "JaxToNumpyV0",
    "JaxToTorchV0",
    "NumpyToTorchV0",
]


# As these wrappers requires `jax` or `torch`, they are loaded by runtime on users trying to access them
# to avoid `import jax` or `import torch` on `import gymnasium`.
# Maps wrapper name -> submodule name used by the lazy `__getattr__` below.
_wrapper_to_class = {
    # data converters
    "JaxToNumpyV0": "jax_to_numpy",
    "JaxToTorchV0": "jax_to_torch",
    "NumpyToTorchV0": "numpy_to_torch",
}
|
||||
|
||||
|
||||
def __getattr__(wrapper_name: str):
    """Load a wrapper by name.

    This optimizes the loading of gymnasium wrappers by only loading the wrapper if it is used.
    Errors will be raised if the wrapper does not exist or if the version is not the latest.

    Args:
        wrapper_name: The name of a wrapper to load.

    Returns:
        The specified wrapper.

    Raises:
        AttributeError: If the wrapper does not exist.
        DeprecatedWrapper: If the version is not the latest.
    """
    # Lazily import jax/torch based wrappers so `import gymnasium` stays light.
    if wrapper_name in _wrapper_to_class:
        submodule = _wrapper_to_class[wrapper_name]
        module = importlib.import_module(
            f"gymnasium.experimental.wrappers.vector.{submodule}"
        )
        return getattr(module, wrapper_name)

    # Split the requested name into a stem and a trailing version number.
    version_pattern = r"(\d+)$"
    suffix_match = re.search(version_pattern, wrapper_name)
    if suffix_match is None:
        # No explicit version: treat the request as newer than anything so the
        # lookup below never reports it as deprecated.
        requested_version = float("inf")
        stem = wrapper_name[:-2]
    else:
        requested_version = int(suffix_match.group())
        stem = wrapper_name[: -len(suffix_match.group())]

    # Public wrappers that share the requested stem.
    candidates = [public_name for public_name in __all__ if public_name.startswith(stem)]
    if not candidates:
        raise AttributeError(f"module {__name__!r} has no attribute {wrapper_name!r}")

    # Determine the newest available version among the candidates.
    latest_wrapper = max(
        candidates, key=lambda name: int(re.findall(version_pattern, name)[0])
    )
    latest_version = int(re.findall(version_pattern, latest_wrapper)[0])

    if requested_version < latest_version:
        # Requesting an outdated version: point the user at the newest one.
        raise DeprecatedWrapper(
            f"{wrapper_name!r} is now deprecated, use {latest_wrapper!r} instead.\n"
            f"To see the changes made, go to "
            f"https://gymnasium.farama.org/api/experimental/vector-wrappers/#gymnasium.experimental.wrappers.vector.{latest_wrapper}"
        )
    # The name resembles a known wrapper but is not an exact match.
    raise AttributeError(
        f"module {__name__!r} has no attribute {wrapper_name!r}, did you mean {latest_wrapper!r}"
    )
|
@@ -1,86 +0,0 @@
|
||||
"""Wrapper that converts the info format for vec envs into the list format."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from gymnasium.experimental.vector.vector_env import ArrayType, VectorEnv, VectorWrapper
|
||||
|
||||
|
||||
__all__ = ["DictInfoToListV0"]
|
||||
|
||||
|
||||
class DictInfoToListV0(VectorWrapper):
    """Converts infos of vectorized environments from dict to List[dict].

    This wrapper converts the info format of a
    vector environment from a dictionary to a list of dictionaries.
    This wrapper is intended to be used around vectorized
    environments. If using other wrappers that perform
    operation on info like `RecordEpisodeStatistics` this
    need to be the outermost wrapper.

    i.e. ``DictInfoToListV0(RecordEpisodeStatisticsV0(vector_env))``

    Example::

        >>> import numpy as np
        >>> dict_info = {
        ...     "k": np.array([0., 0., 0.5, 0.3]),
        ...     "_k": np.array([False, False, True, True])
        ... }
        >>> list_info = [{}, {}, {"k": 0.5}, {"k": 0.3}]
    """

    def __init__(self, env: VectorEnv):
        """This wrapper will convert the info into the list format.

        Args:
            env (Env): The environment to apply the wrapper
        """
        super().__init__(env)

    def step(
        self, actions: ActType
    ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, list[dict[str, Any]]]:
        """Steps through the environment, convert dict info to list."""
        observation, reward, terminated, truncated, infos = self.env.step(actions)
        list_info = self._convert_info_to_list(infos)

        return observation, reward, terminated, truncated, list_info

    def reset(
        self,
        *,
        seed: int | list[int] | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[ObsType, list[dict[str, Any]]]:
        """Resets the environment using kwargs."""
        obs, infos = self.env.reset(seed=seed, options=options)
        list_info = self._convert_info_to_list(infos)

        return obs, list_info

    def _convert_info_to_list(self, infos: dict) -> list[dict[str, Any]]:
        """Convert the dict info to list.

        Convert the dict info of the vectorized environment
        into a list of dictionaries where the i-th dictionary
        has the info of the i-th environment.

        Args:
            infos (dict): info dict coming from the env.

        Returns:
            list_info (list): converted info.

        """
        list_info = [{} for _ in range(self.num_envs)]
        # NOTE(review): `_process_episode_statistics` is not defined on this
        # class (nor visible in this file) — confirm a parent class provides it,
        # otherwise this line raises AttributeError on first use.
        list_info = self._process_episode_statistics(infos, list_info)
        for k in infos:
            # Keys starting with "_" are boolean masks marking which
            # sub-environments carry a value for the matching un-prefixed key.
            if k.startswith("_"):
                continue
            for i, has_info in enumerate(infos[f"_{k}"]):
                if has_info:
                    list_info[i][k] = infos[k][i]
        return list_info
|
@@ -1,143 +0,0 @@
|
||||
"""Vectorizes action wrappers to work for `VectorEnv`."""
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gymnasium import Space
|
||||
from gymnasium.core import ActType, Env
|
||||
from gymnasium.experimental.vector import VectorActionWrapper, VectorEnv
|
||||
from gymnasium.experimental.wrappers import lambda_action
|
||||
from gymnasium.vector.utils import batch_space, concatenate, create_empty_array, iterate
|
||||
|
||||
|
||||
class LambdaActionV0(VectorActionWrapper):
    """Transforms an action via a function provided to the wrapper.

    The function :attr:`func` will be applied to all vector actions.
    If the actions produced by :attr:`func` fall outside the bounds of the ``env``'s action space, provide an :attr:`action_space`.
    """

    def __init__(
        self,
        env: VectorEnv,
        func: Callable[[ActType], Any],
        action_space: Space | None = None,
    ):
        """Constructor for the lambda action wrapper.

        Args:
            env: The vector environment to wrap
            func: A function that will transform an action. If this transformed action is outside the action space of ``env.action_space`` then provide an ``action_space``.
            action_space: The action spaces of the wrapper, if None, then it is assumed the same as ``env.action_space``.
        """
        super().__init__(env)

        self.func = func
        # Only override the inherited space when the caller supplies one.
        if action_space is not None:
            self.action_space = action_space

    def actions(self, actions: ActType) -> ActType:
        """Transform the batched actions with :attr:`func` and return the result."""
        return self.func(actions)
|
||||
|
||||
|
||||
class VectorizeLambdaActionV0(VectorActionWrapper):
    """Vectorizes a single-agent lambda action wrapper for vector environments."""

    class VectorizedEnv(Env):
        """Fake single-agent environment uses for the single-agent wrapper."""

        def __init__(self, action_space: Space):
            """Constructor for the fake environment."""
            self.action_space = action_space

    def __init__(
        self, env: VectorEnv, wrapper: type[lambda_action.LambdaActionV0], **kwargs: Any
    ):
        """Constructor for the vectorized lambda action wrapper.

        Args:
            env: The vector environment to wrap
            wrapper: The wrapper to vectorize
            **kwargs: Arguments for the LambdaActionV0 wrapper
        """
        super().__init__(env)

        # Instantiate the single-agent wrapper around a stub env exposing only
        # the single action space; its `func` is applied per sub-environment.
        self.wrapper = wrapper(
            self.VectorizedEnv(self.env.single_action_space), **kwargs
        )
        self.single_action_space = self.wrapper.action_space
        self.action_space = batch_space(self.single_action_space, self.num_envs)

        # When the wrapper leaves the action space unchanged, the incoming
        # batched `actions` object is reused as the concatenate output buffer;
        # otherwise `self.out` is a pre-allocated empty array.
        self.same_out = self.action_space == self.env.action_space
        self.out = create_empty_array(self.single_action_space, self.num_envs)

    def actions(self, actions: ActType) -> ActType:
        """Applies the wrapper to each of the action.

        Args:
            actions: The actions to apply the function to

        Returns:
            The updated actions using the wrapper func
        """
        if self.same_out:
            # `concatenate` presumably writes into the third argument (here the
            # incoming `actions` buffer) and returns it — TODO confirm against
            # gymnasium.vector.utils.concatenate.
            return concatenate(
                self.single_action_space,
                tuple(
                    self.wrapper.func(action)
                    for action in iterate(self.action_space, actions)
                ),
                actions,
            )
        else:
            # `self.out` is reused across calls, so the result is deep-copied to
            # avoid handing callers an aliased buffer.
            return deepcopy(
                concatenate(
                    self.single_action_space,
                    tuple(
                        self.wrapper.func(action)
                        for action in iterate(self.action_space, actions)
                    ),
                    self.out,
                )
            )
|
||||
|
||||
|
||||
class ClipActionV0(VectorizeLambdaActionV0):
    """Clip the continuous action within the valid :class:`Box` observation space bound."""

    def __init__(self, env: VectorEnv):
        """Vectorizes the single-agent ``ClipActionV0`` wrapper.

        Args:
            env: The vector environment to wrap
        """
        super().__init__(env=env, wrapper=lambda_action.ClipActionV0)
|
||||
|
||||
|
||||
class RescaleActionV0(VectorizeLambdaActionV0):
    """Affinely rescales the continuous action space of the environment to the range [min_action, max_action]."""

    def __init__(
        self,
        env: VectorEnv,
        min_action: float | int | np.ndarray,
        max_action: float | int | np.ndarray,
    ):
        """Vectorizes the single-agent ``RescaleActionV0`` wrapper.

        Args:
            env (Env): The vector environment to wrap
            min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar.
            max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar.
        """
        super().__init__(
            env=env,
            wrapper=lambda_action.RescaleActionV0,
            min_action=min_action,
            max_action=max_action,
        )
|
@@ -1,222 +0,0 @@
|
||||
"""Vectorizes observation wrappers to works for `VectorEnv`."""
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, Sequence
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gymnasium import Space
|
||||
from gymnasium.core import Env, ObsType
|
||||
from gymnasium.experimental.vector import VectorEnv, VectorObservationWrapper
|
||||
from gymnasium.experimental.vector.utils import batch_space, concatenate, iterate
|
||||
from gymnasium.experimental.wrappers import lambda_observation
|
||||
from gymnasium.vector.utils import create_empty_array
|
||||
|
||||
|
||||
class LambdaObservationV0(VectorObservationWrapper):
    """Transforms an observation via a function provided to the wrapper.

    The function :attr:`vector_func` is applied to every batched observation.
    If the transformed observations fall outside the bounds of the ``env``'s observation space, provide an :attr:`observation_space`.
    """

    def __init__(
        self,
        env: VectorEnv,
        vector_func: Callable[[ObsType], Any],
        single_func: Callable[[ObsType], Any],
        observation_space: Space | None = None,
    ):
        """Constructor for the lambda observation wrapper.

        Args:
            env: The vector environment to wrap
            vector_func: A function that will transform the vector observation. If this transformed observation is outside the observation space of ``env.observation_space`` then provide an ``observation_space``.
            single_func: A function that will transform an individual observation.
            observation_space: The observation spaces of the wrapper, if None, then it is assumed the same as ``env.observation_space``.
        """
        super().__init__(env)

        self.vector_func = vector_func
        self.single_func = single_func
        # Only replace the inherited space when the caller supplies one.
        if observation_space is not None:
            self.observation_space = observation_space

    def vector_observation(self, observation: ObsType) -> ObsType:
        """Transform the batched observation with the vector function."""
        return self.vector_func(observation)

    def single_observation(self, observation: ObsType) -> ObsType:
        """Transform an individual observation with the single function."""
        return self.single_func(observation)
|
||||
|
||||
|
||||
class VectorizeLambdaObservationV0(VectorObservationWrapper):
    """Vectorizes a single-agent lambda observation wrapper for vector environments."""

    class VectorizedEnv(Env):
        """Fake single-agent environment uses for the single-agent wrapper."""

        def __init__(self, observation_space: Space):
            """Constructor for the fake environment."""
            self.observation_space = observation_space

    def __init__(
        self,
        env: VectorEnv,
        wrapper: type[lambda_observation.LambdaObservationV0],
        **kwargs: Any,
    ):
        """Constructor for the vectorized lambda observation wrapper.

        Args:
            env: The vector environment to wrap.
            wrapper: The wrapper to vectorize
            **kwargs: Keyword argument for the wrapper
        """
        super().__init__(env)

        # Instantiate the single-agent wrapper around a stub env exposing only
        # the single observation space; its `func` is applied per sub-env.
        self.wrapper = wrapper(
            self.VectorizedEnv(self.env.single_observation_space), **kwargs
        )
        self.single_observation_space = self.wrapper.observation_space
        self.observation_space = batch_space(
            self.single_observation_space, self.num_envs
        )

        # When the wrapper leaves the observation space unchanged, the incoming
        # batched `observation` is reused as the concatenate output buffer;
        # otherwise `self.out` is a pre-allocated empty array.
        self.same_out = self.observation_space == self.env.observation_space
        self.out = create_empty_array(self.single_observation_space, self.num_envs)

    def vector_observation(self, observation: ObsType) -> ObsType:
        """Iterates over the vector observations applying the single-agent wrapper ``observation`` then concatenates the observations together again."""
        if self.same_out:
            # `concatenate` presumably writes into the third argument (the
            # incoming `observation` buffer) and returns it — TODO confirm
            # against gymnasium.vector.utils.concatenate.
            return concatenate(
                self.single_observation_space,
                tuple(
                    self.wrapper.func(obs)
                    for obs in iterate(self.observation_space, observation)
                ),
                observation,
            )
        else:
            # `self.out` is reused across calls, so the result is deep-copied
            # to avoid handing callers an aliased buffer.
            return deepcopy(
                concatenate(
                    self.single_observation_space,
                    tuple(
                        self.wrapper.func(obs)
                        for obs in iterate(self.observation_space, observation)
                    ),
                    self.out,
                )
            )

    def single_observation(self, observation: ObsType) -> ObsType:
        """Transforms a single observation using the wrapper transformation function."""
        return self.wrapper.func(observation)
|
||||
|
||||
|
||||
class FilterObservationV0(VectorizeLambdaObservationV0):
    """Vector wrapper for filtering dict or tuple observation spaces."""

    def __init__(self, env: VectorEnv, filter_keys: Sequence[str | int]):
        """Vectorizes the single-agent ``FilterObservationV0`` wrapper.

        Args:
            env: The vector environment to wrap
            filter_keys: The subspaces to be included, use a list of strings for ``Dict`` spaces or integers for ``Tuple`` spaces
        """
        super().__init__(
            env=env,
            wrapper=lambda_observation.FilterObservationV0,
            filter_keys=filter_keys,
        )
|
||||
|
||||
|
||||
class FlattenObservationV0(VectorizeLambdaObservationV0):
    """Observation wrapper that flattens the observation."""

    def __init__(self, env: VectorEnv):
        """Vectorizes the single-agent ``FlattenObservationV0`` wrapper for spaces that implement ``spaces.utils.flatten_space`` and ``spaces.utils.flatten``.

        Args:
            env: The vector environment to wrap
        """
        super().__init__(env=env, wrapper=lambda_observation.FlattenObservationV0)
|
||||
|
||||
|
||||
class GrayscaleObservationV0(VectorizeLambdaObservationV0):
    """Observation wrapper that converts an RGB image to grayscale."""

    def __init__(self, env: VectorEnv, keep_dim: bool = False):
        """Vectorizes the single-agent ``GrayscaleObservationV0`` wrapper for RGB image environments.

        Args:
            env: The vector environment to wrap
            keep_dim: If to keep the channel in the observation, if ``True``, ``obs.shape == 3`` else ``obs.shape == 2``
        """
        super().__init__(
            env=env,
            wrapper=lambda_observation.GrayscaleObservationV0,
            keep_dim=keep_dim,
        )
|
||||
|
||||
|
||||
class ResizeObservationV0(VectorizeLambdaObservationV0):
    """Resizes image observations using OpenCV to shape."""

    def __init__(self, env: VectorEnv, shape: tuple[int, ...]):
        """Vectorizes the single-agent ``ResizeObservationV0`` wrapper for image observation spaces.

        Args:
            env: The vector environment to wrap
            shape: The resized observation shape
        """
        super().__init__(
            env=env, wrapper=lambda_observation.ResizeObservationV0, shape=shape
        )
|
||||
|
||||
|
||||
class ReshapeObservationV0(VectorizeLambdaObservationV0):
    """Reshapes array based observations to shapes."""

    def __init__(self, env: VectorEnv, shape: int | tuple[int, ...]):
        """Vectorizes the single-agent ``ReshapeObservationV0`` wrapper; the new shape product must equal the old shape product.

        Args:
            env: The vector environment to wrap
            shape: The reshaped observation space
        """
        super().__init__(
            env=env, wrapper=lambda_observation.ReshapeObservationV0, shape=shape
        )
|
||||
|
||||
|
||||
class RescaleObservationV0(VectorizeLambdaObservationV0):
    """Linearly rescales observation to between a minimum and maximum value."""

    def __init__(
        self,
        env: VectorEnv,
        min_obs: np.floating | np.integer | np.ndarray,
        max_obs: np.floating | np.integer | np.ndarray,
    ):
        """Vectorizes the single-agent ``RescaleObservationV0`` wrapper; requires a :class:`Box` observation space.

        Args:
            env: The vector environment to wrap
            min_obs: The new minimum observation bound
            max_obs: The new maximum observation bound
        """
        super().__init__(
            env=env,
            wrapper=lambda_observation.RescaleObservationV0,
            min_obs=min_obs,
            max_obs=max_obs,
        )
|
||||
|
||||
|
||||
class DtypeObservationV0(VectorizeLambdaObservationV0):
    """Observation wrapper for transforming the dtype of an observation."""

    def __init__(self, env: VectorEnv, dtype: Any):
        """Vectorizes the single-agent ``DtypeObservationV0`` wrapper.

        Args:
            env: The vector environment to wrap
            dtype: The new dtype of the observation
        """
        super().__init__(
            env=env, wrapper=lambda_observation.DtypeObservationV0, dtype=dtype
        )
|
@@ -1,78 +0,0 @@
|
||||
"""Vectorizes reward function to work with `VectorEnv`."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gymnasium import Env
|
||||
from gymnasium.experimental.vector import VectorEnv, VectorRewardWrapper
|
||||
from gymnasium.experimental.vector.vector_env import ArrayType
|
||||
from gymnasium.experimental.wrappers import lambda_reward
|
||||
|
||||
|
||||
class LambdaRewardV0(VectorRewardWrapper):
    """A reward wrapper that allows a custom function to modify the step reward."""

    def __init__(self, env: VectorEnv, func: Callable[[ArrayType], ArrayType]):
        """Initialize LambdaRewardV0 wrapper.

        Args:
            env (Env): The vector environment to wrap
            func: (Callable): The function applied to the batched reward
        """
        super().__init__(env)
        self.func = func

    def reward(self, reward: ArrayType) -> ArrayType:
        """Return the batched reward transformed by the stored function."""
        transformed = self.func(reward)
        return transformed
|
||||
|
||||
|
||||
class VectorizeLambdaRewardV0(VectorRewardWrapper):
    """Vectorizes a single-agent lambda reward wrapper for vector environments."""

    def __init__(
        self, env: VectorEnv, wrapper: type[lambda_reward.LambdaRewardV0], **kwargs: Any
    ):
        """Constructor for the vectorized lambda reward wrapper.

        Args:
            env: The vector environment to wrap.
            wrapper: The wrapper to vectorize
            **kwargs: Keyword argument for the wrapper
        """
        super().__init__(env)

        # A bare `Env()` is passed to the single-agent wrapper — presumably the
        # reward wrappers only use `func` and never touch the env; TODO confirm.
        self.wrapper = wrapper(Env(), **kwargs)

    def reward(self, reward: ArrayType) -> ArrayType:
        """Iterates over the reward updating each with the wrapper func."""
        # NOTE: mutates the incoming `reward` array in place and returns it.
        for i, r in enumerate(reward):
            reward[i] = self.wrapper.func(r)
        return reward
|
||||
|
||||
|
||||
class ClipRewardV0(VectorizeLambdaRewardV0):
    """A wrapper that clips the rewards for an environment between an upper and lower bound."""

    def __init__(
        self,
        env: VectorEnv,
        min_reward: float | np.ndarray | None = None,
        max_reward: float | np.ndarray | None = None,
    ):
        """Vectorizes the single-agent ``ClipRewardV0`` wrapper.

        Args:
            env: The vector environment to wrap
            min_reward: The min reward for each step
            max_reward: the max reward for each step
        """
        super().__init__(
            env=env,
            wrapper=lambda_reward.ClipRewardV0,
            min_reward=min_reward,
            max_reward=max_reward,
        )
|
@@ -24,13 +24,14 @@ class FuncEnv(
|
||||
This API is meant to be used in a stateless manner, with the environment state being passed around explicitly.
|
||||
That being said, nothing here prevents users from using the environment statefully, it's just not recommended.
|
||||
A functional env consists of the following functions (in this case, instance methods):
|
||||
- initial: returns the initial state of the POMDP
|
||||
- observation: returns the observation in a given state
|
||||
- transition: returns the next state after taking an action in a given state
|
||||
- reward: returns the reward for a given (state, action, next_state) tuple
|
||||
- terminal: returns whether a given state is terminal
|
||||
- state_info: optional, returns a dict of info about a given state
|
||||
- step_info: optional, returns a dict of info about a given (state, action, next_state) tuple
|
||||
|
||||
* initial: returns the initial state of the POMDP
|
||||
* observation: returns the observation in a given state
|
||||
* transition: returns the next state after taking an action in a given state
|
||||
* reward: returns the reward for a given (state, action, next_state) tuple
|
||||
* terminal: returns whether a given state is terminal
|
||||
* state_info: optional, returns a dict of info about a given state
|
||||
* step_info: optional, returns a dict of info about a given (state, action, next_state) tuple
|
||||
|
||||
The class-based structure serves the purpose of allowing environment constants to be defined in the class,
|
||||
and then using them by name in the code itself.
|
||||
@@ -47,32 +48,32 @@ class FuncEnv(
|
||||
self.__dict__.update(options or {})
|
||||
|
||||
def initial(self, rng: Any) -> StateType:
|
||||
"""Initial state."""
|
||||
"""Generates the initial state of the environment with a random number generator."""
|
||||
raise NotImplementedError
|
||||
|
||||
def transition(self, state: StateType, action: ActType, rng: Any) -> StateType:
|
||||
"""Transition."""
|
||||
"""Updates (transitions) the state with an action and random number generator."""
|
||||
raise NotImplementedError
|
||||
|
||||
def observation(self, state: StateType) -> ObsType:
|
||||
"""Observation."""
|
||||
"""Generates an observation for a given state of an environment."""
|
||||
raise NotImplementedError
|
||||
|
||||
def reward(
|
||||
self, state: StateType, action: ActType, next_state: StateType
|
||||
) -> RewardType:
|
||||
"""Reward."""
|
||||
"""Computes the reward for a given transition between `state`, `action` to `next_state`."""
|
||||
raise NotImplementedError
|
||||
|
||||
def terminal(self, state: StateType) -> TerminalType:
|
||||
"""Terminal state."""
|
||||
"""Returns if the state is a final terminal state."""
|
||||
raise NotImplementedError
|
||||
|
||||
def state_info(self, state: StateType) -> dict:
|
||||
"""Info dict about a single state."""
|
||||
return {}
|
||||
|
||||
def step_info(
|
||||
def transition_info(
|
||||
self, state: StateType, action: ActType, next_state: StateType
|
||||
) -> dict:
|
||||
"""Info dict about a full transition."""
|
||||
@@ -82,11 +83,13 @@ class FuncEnv(
|
||||
"""Functional transformations."""
|
||||
self.initial = func(self.initial)
|
||||
self.transition = func(self.transition)
|
||||
|
||||
self.observation = func(self.observation)
|
||||
self.reward = func(self.reward)
|
||||
self.terminal = func(self.terminal)
|
||||
|
||||
self.state_info = func(self.state_info)
|
||||
self.step_info = func(self.step_info)
|
||||
self.transition_info = func(self.transition_info)
|
||||
|
||||
def render_image(
|
||||
self, state: StateType, render_state: RenderStateType
|
@@ -274,7 +274,7 @@ class Box(Space[NDArray[Any]]):
|
||||
return (
|
||||
isinstance(other, Box)
|
||||
and (self.shape == other.shape)
|
||||
# and (self.dtype == other.dtype)
|
||||
and (self.dtype == other.dtype)
|
||||
and np.allclose(self.low, other.low)
|
||||
and np.allclose(self.high, other.high)
|
||||
)
|
||||
|
@@ -45,7 +45,7 @@ class Dict(Space[typing.Dict[str, Any]], typing.Mapping[str, Space[Any]]):
|
||||
|
||||
It can be convenient to use :class:`Dict` spaces if you want to make complex observations or actions more human-readable.
|
||||
Usually, it will not be possible to use elements of this space directly in learning code. However, you can easily
|
||||
convert `Dict` observations to flat arrays by using a :class:`gymnasium.wrappers.FlattenObservation` wrapper.
|
||||
convert :class:`Dict` observations to flat arrays by using a :class:`gymnasium.wrappers.FlattenObservation` wrapper.
|
||||
Similar wrappers can be implemented to deal with :class:`Dict` actions.
|
||||
"""
|
||||
|
||||
|
@@ -62,8 +62,8 @@ class Discrete(Space[np.int64]):
|
||||
|
||||
Args:
|
||||
mask: An optional mask for if an action can be selected.
|
||||
Expected `np.ndarray` of shape `(n,)` and dtype `np.int8` where `1` represents valid actions and `0` invalid / infeasible actions.
|
||||
If there are no possible actions (i.e. `np.all(mask == 0)`) then `space.start` will be returned.
|
||||
Expected `np.ndarray` of shape ``(n,)`` and dtype ``np.int8`` where ``1`` represents valid actions and ``0`` invalid / infeasible actions.
|
||||
If there are no possible actions (i.e. ``np.all(mask == 0)``) then ``space.start`` will be returned.
|
||||
|
||||
Returns:
|
||||
A sampled integer from the space
|
||||
|
@@ -27,7 +27,7 @@ class GraphInstance(NamedTuple):
|
||||
|
||||
|
||||
class Graph(Space[GraphInstance]):
|
||||
r"""A space representing graph information as a series of `nodes` connected with `edges` according to an adjacency matrix represented as a series of `edge_links`.
|
||||
r"""A space representing graph information as a series of ``nodes`` connected with ``edges`` according to an adjacency matrix represented as a series of ``edge_links``.
|
||||
|
||||
Example:
|
||||
>>> from gymnasium.spaces import Graph, Box, Discrete
|
||||
@@ -122,14 +122,14 @@ class Graph(Space[GraphInstance]):
|
||||
num_nodes: int = 10,
|
||||
num_edges: int | None = None,
|
||||
) -> GraphInstance:
|
||||
"""Generates a single sample graph with num_nodes between 1 and 10 sampled from the Graph.
|
||||
"""Generates a single sample graph with num_nodes between ``1`` and ``10`` sampled from the Graph.
|
||||
|
||||
Args:
|
||||
mask: An optional tuple of optional node and edge mask that is only possible with Discrete spaces
|
||||
(Box spaces don't support sample masks).
|
||||
If no `num_edges` is provided then the `edge_mask` is multiplied by the number of edges
|
||||
num_nodes: The number of nodes that will be sampled, the default is 10 nodes
|
||||
num_edges: An optional number of edges, otherwise, a random number between 0 and `num_nodes` ^ 2
|
||||
If no ``num_edges`` is provided then the ``edge_mask`` is multiplied by the number of edges
|
||||
num_nodes: The number of nodes that will be sampled, the default is `10` nodes
|
||||
num_edges: An optional number of edges, otherwise, a random number between `0` and :math:`num_nodes^2`
|
||||
|
||||
Returns:
|
||||
A :class:`GraphInstance` with attributes `.nodes`, `.edges`, and `.edge_links`.
|
||||
@@ -212,7 +212,7 @@ class Graph(Space[GraphInstance]):
|
||||
def __repr__(self) -> str:
|
||||
"""A string representation of this space.
|
||||
|
||||
The representation will include node_space and edge_space
|
||||
The representation will include ``node_space`` and ``edge_space``
|
||||
|
||||
Returns:
|
||||
A representation of the space
|
||||
|
@@ -65,8 +65,8 @@ class MultiBinary(Space[NDArray[np.int8]]):
|
||||
|
||||
Args:
|
||||
mask: An optional np.ndarray to mask samples with expected shape of ``space.shape``.
|
||||
For mask == 0 then the samples will be 0 and mask == 1 then random samples will be generated.
|
||||
The expected mask shape is the space shape and mask dtype is `np.int8`.
|
||||
For ``mask == 0`` then the samples will be ``0`` and ``mask == 1` then random samples will be generated.
|
||||
The expected mask shape is the space shape and mask dtype is ``np.int8``.
|
||||
|
||||
Returns:
|
||||
Sampled values from space
|
||||
|
@@ -87,12 +87,12 @@ class MultiDiscrete(Space[NDArray[np.integer]]):
|
||||
"""Generates a single random sample this space.
|
||||
|
||||
Args:
|
||||
mask: An optional mask for multi-discrete, expects tuples with a `np.ndarray` mask in the position of each
|
||||
action with shape `(n,)` where `n` is the number of actions and `dtype=np.int8`.
|
||||
Only mask values == 1 are possible to sample unless all mask values for an action are 0 then the default action `self.start` (the smallest element) is sampled.
|
||||
mask: An optional mask for multi-discrete, expects tuples with a ``np.ndarray`` mask in the position of each
|
||||
action with shape ``(n,)`` where ``n`` is the number of actions and ``dtype=np.int8``.
|
||||
Only ``mask values == 1`` are possible to sample unless all mask values for an action are ``0`` then the default action ``self.start`` (the smallest element) is sampled.
|
||||
|
||||
Returns:
|
||||
An `np.ndarray` of shape `space.shape`
|
||||
An ``np.ndarray`` of :meth:`Space.shape`
|
||||
"""
|
||||
if mask is not None:
|
||||
|
||||
@@ -206,6 +206,7 @@ class MultiDiscrete(Space[NDArray[np.integer]]):
|
||||
"""Check whether ``other`` is equivalent to this instance."""
|
||||
return bool(
|
||||
isinstance(other, MultiDiscrete)
|
||||
and self.dtype == other.dtype
|
||||
and np.all(self.nvec == other.nvec)
|
||||
and np.all(self.start == other.start)
|
||||
)
|
||||
|
@@ -38,7 +38,7 @@ class Sequence(Space[Union[typing.Tuple[Any, ...], Any]]):
|
||||
Args:
|
||||
space: Elements in the sequences this space represent must belong to this space.
|
||||
seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.
|
||||
stack: If `True` then the resulting samples would be stacked.
|
||||
stack: If ``True`` then the resulting samples would be stacked.
|
||||
"""
|
||||
assert isinstance(
|
||||
space, Space
|
||||
@@ -78,14 +78,13 @@ class Sequence(Space[Union[typing.Tuple[Any, ...], Any]]):
|
||||
|
||||
Args:
|
||||
mask: An optional mask for (optionally) the length of the sequence and (optionally) the values in the sequence.
|
||||
If you specify `mask`, it is expected to be a tuple of the form `(length_mask, sample_mask)` where `length_mask`
|
||||
is
|
||||
If you specify ``mask``, it is expected to be a tuple of the form ``(length_mask, sample_mask)`` where ``length_mask`` is
|
||||
|
||||
* ``None`` The length will be randomly drawn from a geometric distribution
|
||||
* ``np.ndarray`` of integers, in which case the length of the sampled sequence is randomly drawn from this array.
|
||||
* ``int`` for a fixed length sample
|
||||
|
||||
The second element of the mask tuple `sample` mask specifies a mask that is applied when
|
||||
The second element of the mask tuple ``sample`` mask specifies a mask that is applied when
|
||||
sampling elements from the base space. The mask is applied for each feature space sample.
|
||||
|
||||
Returns:
|
||||
|
@@ -78,13 +78,13 @@ class Text(Space[str]):
|
||||
self,
|
||||
mask: None | (tuple[int | None, NDArray[np.int8] | None]) = None,
|
||||
) -> str:
|
||||
"""Generates a single random sample from this space with by default a random length between `min_length` and `max_length` and sampled from the `charset`.
|
||||
"""Generates a single random sample from this space with by default a random length between ``min_length`` and ``max_length`` and sampled from the ``charset``.
|
||||
|
||||
Args:
|
||||
mask: An optional tuples of length and mask for the text.
|
||||
The length is expected to be between the `min_length` and `max_length` otherwise a random integer between `min_length` and `max_length` is selected.
|
||||
For the mask, we expect a numpy array of length of the charset passed with `dtype == np.int8`.
|
||||
If the charlist mask is all zero then an empty string is returned no matter the `min_length`
|
||||
The length is expected to be between the ``min_length`` and ``max_length`` otherwise a random integer between ``min_length`` and ``max_length`` is selected.
|
||||
For the mask, we expect a numpy array of length of the charset passed with ``dtype == np.int8``.
|
||||
If the charlist mask is all zero then an empty string is returned no matter the ``min_length``
|
||||
|
||||
Returns:
|
||||
A sampled string from the space
|
||||
|
@@ -53,8 +53,8 @@ class Tuple(Space[typing.Tuple[Any, ...]], typing.Sequence[Any]):
|
||||
Depending on the type of seed, the subspaces will be seeded differently
|
||||
|
||||
* ``None`` - All the subspaces will use a random initial seed
|
||||
* ``Int`` - The integer is used to seed the `Tuple` space that is used to generate seed values for each of the subspaces. Warning, this does not guarantee unique seeds for all of the subspaces.
|
||||
* ``List`` - Values used to seed the subspaces. This allows the seeding of multiple composite subspaces (``List(42, 54, ...``).
|
||||
* ``Int`` - The integer is used to seed the :class:`Tuple` space that is used to generate seed values for each of the subspaces. Warning, this does not guarantee unique seeds for all the subspaces.
|
||||
* ``List`` - Values used to seed the subspaces. This allows the seeding of multiple composite subspaces ``[42, 54, ...]``.
|
||||
|
||||
Args:
|
||||
seed: An optional list of ints or int to seed the (sub-)spaces.
|
||||
|
@@ -428,9 +428,7 @@ def flatten_space(space: Space[Any]) -> Box | Dict | Sequence | Tuple | Graph:
|
||||
Raises:
|
||||
NotImplementedError: if the space is not defined in :mod:`gymnasium.spaces`.
|
||||
|
||||
Example:
|
||||
Flatten spaces.Box:
|
||||
|
||||
Example - Flatten spaces.Box:
|
||||
>>> from gymnasium.spaces import Box
|
||||
>>> box = Box(0.0, 1.0, shape=(3, 4, 5))
|
||||
>>> box
|
||||
@@ -440,8 +438,7 @@ def flatten_space(space: Space[Any]) -> Box | Dict | Sequence | Tuple | Graph:
|
||||
>>> flatten(box, box.sample()) in flatten_space(box)
|
||||
True
|
||||
|
||||
Flatten spaces.Discrete:
|
||||
|
||||
Example - Flatten spaces.Discrete:
|
||||
>>> from gymnasium.spaces import Discrete
|
||||
>>> discrete = Discrete(5)
|
||||
>>> flatten_space(discrete)
|
||||
@@ -449,8 +446,7 @@ def flatten_space(space: Space[Any]) -> Box | Dict | Sequence | Tuple | Graph:
|
||||
>>> flatten(discrete, discrete.sample()) in flatten_space(discrete)
|
||||
True
|
||||
|
||||
Flatten spaces.Dict:
|
||||
|
||||
Example - Flatten spaces.Dict:
|
||||
>>> from gymnasium.spaces import Dict, Discrete, Box
|
||||
>>> space = Dict({"position": Discrete(2), "velocity": Box(0, 1, shape=(2, 2))})
|
||||
>>> flatten_space(space)
|
||||
@@ -458,8 +454,7 @@ def flatten_space(space: Space[Any]) -> Box | Dict | Sequence | Tuple | Graph:
|
||||
>>> flatten(space, space.sample()) in flatten_space(space)
|
||||
True
|
||||
|
||||
Flatten spaces.Graph:
|
||||
|
||||
Example - Flatten spaces.Graph:
|
||||
>>> from gymnasium.spaces import Graph, Discrete, Box
|
||||
>>> space = Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5))
|
||||
>>> flatten_space(space)
|
||||
|
@@ -1,4 +1,4 @@
|
||||
"""A set of functions for checking an environment details.
|
||||
"""A set of functions for checking an environment implementation.
|
||||
|
||||
This file is originally from the Stable Baselines3 repository hosted on GitHub
|
||||
(https://github.com/DLR-RM/stable-baselines3/)
|
||||
@@ -63,7 +63,7 @@ def data_equivalence(data_1, data_2) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def check_reset_seed(env: gym.Env) -> None:
|
||||
def check_reset_seed(env: gym.Env):
|
||||
"""Check that the environment can be reset with a seed.
|
||||
|
||||
Args:
|
||||
@@ -132,7 +132,7 @@ def check_reset_seed(env: gym.Env) -> None:
|
||||
)
|
||||
|
||||
|
||||
def check_reset_options(env: gym.Env) -> None:
|
||||
def check_reset_options(env: gym.Env):
|
||||
"""Check that the environment can be reset with options.
|
||||
|
||||
Args:
|
||||
@@ -160,7 +160,7 @@ def check_reset_options(env: gym.Env) -> None:
|
||||
)
|
||||
|
||||
|
||||
def check_reset_return_info_deprecation(env: gym.Env) -> None:
|
||||
def check_reset_return_info_deprecation(env: gym.Env):
|
||||
"""Makes sure support for deprecated `return_info` argument is dropped.
|
||||
|
||||
Args:
|
||||
@@ -177,7 +177,7 @@ def check_reset_return_info_deprecation(env: gym.Env) -> None:
|
||||
)
|
||||
|
||||
|
||||
def check_seed_deprecation(env: gym.Env) -> None:
|
||||
def check_seed_deprecation(env: gym.Env):
|
||||
"""Makes sure support for deprecated function `seed` is dropped.
|
||||
|
||||
Args:
|
||||
@@ -193,7 +193,7 @@ def check_seed_deprecation(env: gym.Env) -> None:
|
||||
)
|
||||
|
||||
|
||||
def check_reset_return_type(env: gym.Env) -> None:
|
||||
def check_reset_return_type(env: gym.Env):
|
||||
"""Checks that :meth:`reset` correctly returns a tuple of the form `(obs , info)`.
|
||||
|
||||
Args:
|
||||
@@ -218,7 +218,7 @@ def check_reset_return_type(env: gym.Env) -> None:
|
||||
), f"The second element returned by `env.reset()` was not a dictionary, actual type: {type(info)}"
|
||||
|
||||
|
||||
def check_space_limit(space: spaces.Space, space_type: str) -> None:
|
||||
def check_space_limit(space, space_type: str):
|
||||
"""Check the space limit for only the Box space as a test that only runs as part of `check_env`."""
|
||||
if isinstance(space, spaces.Box):
|
||||
if np.any(np.equal(space.low, -np.inf)):
|
||||
@@ -256,18 +256,19 @@ def check_space_limit(space: spaces.Space, space_type: str) -> None:
|
||||
check_space_limit(subspace, space_type)
|
||||
|
||||
|
||||
def check_env(env: gym.Env, warn: bool = None, skip_render_check: bool = False) -> None:
|
||||
"""Check that an environment follows Gym API.
|
||||
def check_env(env: gym.Env, warn: bool = None, skip_render_check: bool = False):
|
||||
"""Check that an environment follows Gymnasium's API.
|
||||
|
||||
This is an invasive function that calls the environment's reset and step.
|
||||
.. py:currentmodule:: gymnasium.Env
|
||||
|
||||
This is particularly useful when using a custom environment.
|
||||
Please take a look at https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/
|
||||
for more information about the API.
|
||||
To ensure that an environment is implemented "correctly", ``check_env`` checks that the :attr:`observation_space` and :attr:`action_space` are correct.
|
||||
Furthermore, the function will call the :meth:`reset`, :meth:`step` and :meth:`render` functions with a variety of values.
|
||||
|
||||
We highly recommend users calling this function after an environment is constructed and within a projects continuous integration to keep an environment update with Gymnasium's API.
|
||||
|
||||
Args:
|
||||
env: The Gym environment that will be checked
|
||||
warn: Ignored
|
||||
warn: Ignored, previously silenced particular warnings
|
||||
skip_render_check: Whether to skip the checks for the render method. True by default (useful for the CI)
|
||||
"""
|
||||
if warn is not None:
|
||||
|
@@ -12,6 +12,8 @@ __all__ = [
|
||||
"env_render_passive_checker",
|
||||
"env_reset_passive_checker",
|
||||
"env_step_passive_checker",
|
||||
"check_action_space",
|
||||
"check_observation_space",
|
||||
]
|
||||
|
||||
|
||||
|
@@ -1,6 +1,8 @@
|
||||
"""Utilities of visualising an environment."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Callable, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -40,8 +42,8 @@ class PlayableGame:
|
||||
def __init__(
|
||||
self,
|
||||
env: Env,
|
||||
keys_to_action: Optional[Dict[Tuple[int, ...], int]] = None,
|
||||
zoom: Optional[float] = None,
|
||||
keys_to_action: dict[tuple[int, ...], int] | None = None,
|
||||
zoom: float | None = None,
|
||||
):
|
||||
"""Wraps an environment with a dictionary of keyboard buttons to action and if to zoom in on the environment.
|
||||
|
||||
@@ -66,7 +68,7 @@ class PlayableGame:
|
||||
self.running = True
|
||||
|
||||
def _get_relevant_keys(
|
||||
self, keys_to_action: Optional[Dict[Tuple[int], int]] = None
|
||||
self, keys_to_action: dict[tuple[int], int] | None = None
|
||||
) -> set:
|
||||
if keys_to_action is None:
|
||||
if hasattr(self.env, "get_keys_to_action"):
|
||||
@@ -83,7 +85,7 @@ class PlayableGame:
|
||||
relevant_keys = set(sum((list(k) for k in keys_to_action.keys()), []))
|
||||
return relevant_keys
|
||||
|
||||
def _get_video_size(self, zoom: Optional[float] = None) -> Tuple[int, int]:
|
||||
def _get_video_size(self, zoom: float | None = None) -> tuple[int, int]:
|
||||
rendered = self.env.render()
|
||||
if isinstance(rendered, List):
|
||||
rendered = rendered[-1]
|
||||
@@ -123,7 +125,7 @@ class PlayableGame:
|
||||
|
||||
|
||||
def display_arr(
|
||||
screen: Surface, arr: np.ndarray, video_size: Tuple[int, int], transpose: bool
|
||||
screen: Surface, arr: np.ndarray, video_size: tuple[int, int], transpose: bool
|
||||
):
|
||||
"""Displays a numpy array on screen.
|
||||
|
||||
@@ -147,15 +149,15 @@ def display_arr(
|
||||
|
||||
def play(
|
||||
env: Env,
|
||||
transpose: Optional[bool] = True,
|
||||
fps: Optional[int] = None,
|
||||
zoom: Optional[float] = None,
|
||||
callback: Optional[Callable] = None,
|
||||
keys_to_action: Optional[Dict[Union[Tuple[Union[str, int]], str], ActType]] = None,
|
||||
seed: Optional[int] = None,
|
||||
transpose: bool | None = True,
|
||||
fps: int | None = None,
|
||||
zoom: float | None = None,
|
||||
callback: Callable | None = None,
|
||||
keys_to_action: dict[tuple[str | int] | str, ActType] | None = None,
|
||||
seed: int | None = None,
|
||||
noop: ActType = 0,
|
||||
):
|
||||
"""Allows one to play the game using keyboard.
|
||||
"""Allows the user to play the environment using a keyboard.
|
||||
|
||||
Args:
|
||||
env: Environment to use for playing.
|
||||
@@ -164,13 +166,14 @@ def play(
|
||||
``env.metadata["render_fps""]`` (or 30, if the environment does not specify "render_fps") is used.
|
||||
zoom: Zoom the observation in, ``zoom`` amount, should be positive float
|
||||
callback: If a callback is provided, it will be executed after every step. It takes the following input:
|
||||
obs_t: observation before performing action
|
||||
obs_tp1: observation after performing action
|
||||
action: action that was executed
|
||||
rew: reward that was received
|
||||
terminated: whether the environment is terminated or not
|
||||
truncated: whether the environment is truncated or not
|
||||
info: debug info
|
||||
|
||||
* obs_t: observation before performing action
|
||||
* obs_tp1: observation after performing action
|
||||
* action: action that was executed
|
||||
* rew: reward that was received
|
||||
* terminated: whether the environment is terminated or not
|
||||
* truncated: whether the environment is truncated or not
|
||||
* info: debug info
|
||||
keys_to_action: Mapping from keys pressed to action performed.
|
||||
Different formats are supported: Key combinations can either be expressed as a tuple of unicode code
|
||||
points of the keys, as a tuple of characters, or as a string where each character of the string represents
|
||||
@@ -205,28 +208,29 @@ def play(
|
||||
noop: The action used when no key input has been entered, or the entered key combination is unknown.
|
||||
|
||||
Example:
|
||||
>>> import gymnasium as gym
|
||||
>>> from gymnasium.utils.play import play
|
||||
>>> play(gym.make("CarRacing-v2", render_mode="rgb_array"), keys_to_action={ # doctest: +SKIP
|
||||
... "w": np.array([0, 0.7, 0]),
|
||||
... "a": np.array([-1, 0, 0]),
|
||||
... "s": np.array([0, 0, 1]),
|
||||
... "d": np.array([1, 0, 0]),
|
||||
... "wa": np.array([-1, 0.7, 0]),
|
||||
... "dw": np.array([1, 0.7, 0]),
|
||||
... "ds": np.array([1, 0, 1]),
|
||||
... "as": np.array([-1, 0, 1]),
|
||||
... }, noop=np.array([0,0,0]))
|
||||
>>> play(gym.make("CarRacing-v2", render_mode="rgb_array"), # doctest: +SKIP
|
||||
... keys_to_action={
|
||||
... "w": np.array([0, 0.7, 0]),
|
||||
... "a": np.array([-1, 0, 0]),
|
||||
... "s": np.array([0, 0, 1]),
|
||||
... "d": np.array([1, 0, 0]),
|
||||
... "wa": np.array([-1, 0.7, 0]),
|
||||
... "dw": np.array([1, 0.7, 0]),
|
||||
... "ds": np.array([1, 0, 1]),
|
||||
... "as": np.array([-1, 0, 1]),
|
||||
... },
|
||||
... noop=np.array([0, 0, 0])
|
||||
... )
|
||||
|
||||
Above code works also if the environment is wrapped, so it's particularly useful in
|
||||
verifying that the frame-level preprocessing does not render the game
|
||||
unplayable.
|
||||
|
||||
If you wish to plot real time statistics as you play, you can use
|
||||
:class:`gym.utils.play.PlayPlot`. Here's a sample code for plotting the reward
|
||||
:class:`PlayPlot`. Here's a sample code for plotting the reward
|
||||
for last 150 steps.
|
||||
|
||||
>>> import gymnasium as gym
|
||||
>>> from gymnasium.utils.play import PlayPlot, play
|
||||
>>> def callback(obs_t, obs_tp1, action, rew, terminated, truncated, info):
|
||||
... return [rew,]
|
||||
@@ -321,7 +325,7 @@ class PlayPlot:
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, callback: Callable, horizon_timesteps: int, plot_names: List[str]
|
||||
self, callback: Callable, horizon_timesteps: int, plot_names: list[str]
|
||||
):
|
||||
"""Constructor of :class:`PlayPlot`.
|
||||
|
||||
@@ -355,7 +359,7 @@ class PlayPlot:
|
||||
for axis, name in zip(self.ax, plot_names):
|
||||
axis.set_title(name)
|
||||
self.t = 0
|
||||
self.cur_plot: List[Optional[plt.Axes]] = [None for _ in range(num_plots)]
|
||||
self.cur_plot: list[plt.Axes | None] = [None for _ in range(num_plots)]
|
||||
self.data = [deque(maxlen=horizon_timesteps) for _ in range(num_plots)]
|
||||
|
||||
def callback(
|
||||
|
@@ -15,9 +15,9 @@ except ImportError as e:
|
||||
|
||||
|
||||
def capped_cubic_video_schedule(episode_id: int) -> bool:
|
||||
"""The default episode trigger.
|
||||
r"""The default episode trigger.
|
||||
|
||||
This function will trigger recordings at the episode indices 0, 1, 4, 8, 27, ..., :math:`k^3`, ..., 729, 1000, 2000, 3000, ...
|
||||
This function will trigger recordings at the episode indices :math:`\{0, 1, 4, 8, 27, ..., k^3, ..., 729, 1000, 2000, 3000, ...\}`
|
||||
|
||||
Args:
|
||||
episode_id: The episode number
|
||||
|
@@ -1,22 +1,29 @@
|
||||
"""Set of random number generator functions: seeding, generator, hashing seeds."""
|
||||
from typing import Any, Optional, Tuple
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gymnasium import error
|
||||
|
||||
|
||||
def np_random(seed: Optional[int] = None) -> Tuple[np.random.Generator, Any]:
|
||||
"""Generates a random number generator from the seed and returns the Generator and seed.
|
||||
def np_random(seed: int | None = None) -> tuple[np.random.Generator, int]:
|
||||
"""Returns a NumPy random number generator (RNG) along with seed value from the inputted seed.
|
||||
|
||||
If ``seed`` is ``None`` then a **random** seed will be generated as the RNG's initial seed.
|
||||
This randomly selected seed is returned as the second value of the tuple.
|
||||
|
||||
.. py:currentmodule:: gymnasium.Env
|
||||
|
||||
This function is called in :meth:`reset` to reset an environment's initial RNG.
|
||||
|
||||
Args:
|
||||
seed: The seed used to create the generator
|
||||
|
||||
Returns:
|
||||
The generator and resulting seed
|
||||
A NumPy-based Random Number Generator and generator seed
|
||||
|
||||
Raises:
|
||||
Error: Seed must be a non-negative integer or omitted
|
||||
Error: Seed must be a non-negative integer
|
||||
"""
|
||||
if seed is not None and not (isinstance(seed, int) and 0 <= seed):
|
||||
if isinstance(seed, int) is False:
|
||||
|
@@ -1,4 +1,6 @@
|
||||
"""Contains methods for step compatibility, from old-to-new and new-to-old API."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import SupportsFloat, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -23,13 +25,15 @@ TerminatedTruncatedStepType = Tuple[
|
||||
|
||||
|
||||
def convert_to_terminated_truncated_step_api(
|
||||
step_returns: Union[DoneStepType, TerminatedTruncatedStepType], is_vector_env=False
|
||||
step_returns: DoneStepType | TerminatedTruncatedStepType, is_vector_env=False
|
||||
) -> TerminatedTruncatedStepType:
|
||||
"""Function to transform step returns to new step API irrespective of input API.
|
||||
|
||||
.. py:currentmodule:: gymnasium.Env
|
||||
|
||||
Args:
|
||||
step_returns (tuple): Items returned by step(). Can be (obs, rew, done, info) or (obs, rew, terminated, truncated, info)
|
||||
is_vector_env (bool): Whether the step_returns are from a vector environment
|
||||
step_returns (tuple): Items returned by :meth:`step`. Can be ``(obs, rew, done, info)`` or ``(obs, rew, terminated, truncated, info)``
|
||||
is_vector_env (bool): Whether the ``step_returns`` are from a vector environment
|
||||
"""
|
||||
if len(step_returns) == 5:
|
||||
return step_returns
|
||||
@@ -75,14 +79,16 @@ def convert_to_terminated_truncated_step_api(
|
||||
|
||||
|
||||
def convert_to_done_step_api(
|
||||
step_returns: Union[TerminatedTruncatedStepType, DoneStepType],
|
||||
step_returns: TerminatedTruncatedStepType | DoneStepType,
|
||||
is_vector_env: bool = False,
|
||||
) -> DoneStepType:
|
||||
"""Function to transform step returns to old step API irrespective of input API.
|
||||
|
||||
.. py:currentmodule:: gymnasium.Env
|
||||
|
||||
Args:
|
||||
step_returns (tuple): Items returned by step(). Can be (obs, rew, done, info) or (obs, rew, terminated, truncated, info)
|
||||
is_vector_env (bool): Whether the step_returns are from a vector environment
|
||||
step_returns (tuple): Items returned by :meth:`step`. Can be ``(obs, rew, done, info)`` or ``(obs, rew, terminated, truncated, info)``
|
||||
is_vector_env (bool): Whether the ``step_returns`` are from a vector environment
|
||||
"""
|
||||
if len(step_returns) == 4:
|
||||
return step_returns
|
||||
@@ -130,38 +136,41 @@ def convert_to_done_step_api(
|
||||
|
||||
|
||||
def step_api_compatibility(
|
||||
step_returns: Union[TerminatedTruncatedStepType, DoneStepType],
|
||||
step_returns: TerminatedTruncatedStepType | DoneStepType,
|
||||
output_truncation_bool: bool = True,
|
||||
is_vector_env: bool = False,
|
||||
) -> Union[TerminatedTruncatedStepType, DoneStepType]:
|
||||
"""Function to transform step returns to the API specified by `output_truncation_bool` bool.
|
||||
) -> TerminatedTruncatedStepType | DoneStepType:
|
||||
"""Function to transform step returns to the API specified by ``output_truncation_bool``.
|
||||
|
||||
Done (old) step API refers to step() method returning (observation, reward, done, info)
|
||||
Terminated Truncated (new) step API refers to step() method returning (observation, reward, terminated, truncated, info)
|
||||
.. py:currentmodule:: gymnasium.Env
|
||||
|
||||
Done (old) step API refers to :meth:`step` method returning ``(observation, reward, done, info)``
|
||||
Terminated Truncated (new) step API refers to :meth:`step` method returning ``(observation, reward, terminated, truncated, info)``
|
||||
(Refer to docs for details on the API change)
|
||||
|
||||
Args:
|
||||
step_returns (tuple): Items returned by step(). Can be (obs, rew, done, info) or (obs, rew, terminated, truncated, info)
|
||||
output_truncation_bool (bool): Whether the output should return two booleans (new API) or one (old) (True by default)
|
||||
is_vector_env (bool): Whether the step_returns are from a vector environment
|
||||
step_returns (tuple): Items returned by :meth:`step`. Can be ``(obs, rew, done, info)`` or ``(obs, rew, terminated, truncated, info)``
|
||||
output_truncation_bool (bool): Whether the output should return two booleans (new API) or one (old) (``True`` by default)
|
||||
is_vector_env (bool): Whether the ``step_returns`` are from a vector environment
|
||||
|
||||
Returns:
|
||||
step_returns (tuple): Depending on `output_truncation_bool` bool, it can return `(obs, rew, done, info)` or `(obs, rew, terminated, truncated, info)`
|
||||
step_returns (tuple): Depending on ``output_truncation_bool``, it can return ``(obs, rew, done, info)`` or ``(obs, rew, terminated, truncated, info)``
|
||||
|
||||
Example:
|
||||
This function can be used to ensure compatibility in step interfaces with conflicting API. Eg. if env is written in old API,
|
||||
wrapper is written in new API, and the final step output is desired to be in old API.
|
||||
This function can be used to ensure compatibility in step interfaces with conflicting API. E.g. if env is written in old API,
|
||||
wrapper is written in new API, and the final step output is desired to be in old API.
|
||||
|
||||
>>> import gymnasium as gym
|
||||
>>> env = gym.make("CartPole-v0")
|
||||
>>> _ = env.reset()
|
||||
>>> obs, rewards, done, info = step_api_compatibility(env.step(0), output_truncation_bool=False)
|
||||
>>> obs, rewards, terminated, truncated, info = step_api_compatibility(env.step(0), output_truncation_bool=True)
|
||||
>>> _, _ = env.reset()
|
||||
>>> obs, reward, done, info = step_api_compatibility(env.step(0), output_truncation_bool=False)
|
||||
>>> obs, reward, terminated, truncated, info = step_api_compatibility(env.step(0), output_truncation_bool=True)
|
||||
|
||||
>>> vec_env = gym.vector.make("CartPole-v0")
|
||||
>>> _ = vec_env.reset()
|
||||
>>> vec_env = gym.make_vec("CartPole-v0", vectorization_mode="sync")
|
||||
>>> _, _ = vec_env.reset()
|
||||
>>> obs, rewards, dones, infos = step_api_compatibility(vec_env.step([0]), is_vector_env=True, output_truncation_bool=False)
|
||||
>>> obs, rewards, terminated, truncated, info = step_api_compatibility(vec_env.step([0]), is_vector_env=True, output_truncation_bool=True)
|
||||
>>> obs, rewards, terminations, truncations, infos = step_api_compatibility(vec_env.step([0]), is_vector_env=True, output_truncation_bool=True)
|
||||
|
||||
"""
|
||||
if output_truncation_bool:
|
||||
return convert_to_terminated_truncated_step_api(step_returns, is_vector_env)
|
||||
|
@@ -1,85 +1,23 @@
|
||||
"""Module for vector environments."""
|
||||
from typing import Callable, Iterable, List, Optional, Union
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.core import Env
|
||||
"""Experimental vector env API."""
|
||||
from gymnasium.vector import utils
|
||||
from gymnasium.vector.async_vector_env import AsyncVectorEnv
|
||||
from gymnasium.vector.sync_vector_env import SyncVectorEnv
|
||||
from gymnasium.vector.vector_env import VectorEnv, VectorEnvWrapper
|
||||
from gymnasium.vector.vector_env import (
|
||||
VectorActionWrapper,
|
||||
VectorEnv,
|
||||
VectorObservationWrapper,
|
||||
VectorRewardWrapper,
|
||||
VectorWrapper,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AsyncVectorEnv",
|
||||
"SyncVectorEnv",
|
||||
"VectorEnv",
|
||||
"VectorEnvWrapper",
|
||||
"make",
|
||||
"VectorWrapper",
|
||||
"VectorObservationWrapper",
|
||||
"VectorActionWrapper",
|
||||
"VectorRewardWrapper",
|
||||
"SyncVectorEnv",
|
||||
"AsyncVectorEnv",
|
||||
"utils",
|
||||
]
|
||||
|
||||
|
||||
def make(
|
||||
id: str,
|
||||
num_envs: int = 1,
|
||||
asynchronous: bool = True,
|
||||
wrappers: Optional[Union[Callable[[Env], Env], List[Callable[[Env], Env]]]] = None,
|
||||
disable_env_checker: Optional[bool] = None,
|
||||
**kwargs,
|
||||
) -> VectorEnv:
|
||||
"""Create a vectorized environment from multiple copies of an environment, from its id.
|
||||
|
||||
Args:
|
||||
id: The environment ID. This must be a valid ID from the registry.
|
||||
num_envs: Number of copies of the environment.
|
||||
asynchronous: If `True`, wraps the environments in an :class:`AsyncVectorEnv` (which uses `multiprocessing` to run the environments in parallel). If ``False``, wraps the environments in a :class:`SyncVectorEnv`.
|
||||
wrappers: If not ``None``, then apply the wrappers to each internal environment during creation.
|
||||
disable_env_checker: If to run the env checker for the first environment only. None will default to the environment spec `disable_env_checker` parameter
|
||||
(that is by default False), otherwise will run according to this argument (True = not run, False = run)
|
||||
**kwargs: Keywords arguments applied during `gym.make`
|
||||
|
||||
Returns:
|
||||
The vectorized environment.
|
||||
|
||||
Example:
|
||||
>>> import gymnasium as gym
|
||||
>>> env = gym.vector.make('CartPole-v1', num_envs=3)
|
||||
>>> env.reset(seed=42)
|
||||
(array([[ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 ],
|
||||
[ 0.01522993, -0.04562247, -0.04799704, 0.03392126],
|
||||
[-0.03774345, -0.02418869, -0.00942293, 0.0469184 ]],
|
||||
dtype=float32), {})
|
||||
"""
|
||||
gym.logger.warn(
|
||||
"`gymnasium.vector.make(...)` is deprecated and will be replaced by `gymnasium.make_vec(...)` in v1.0"
|
||||
)
|
||||
|
||||
def create_env(env_num: int) -> Callable[[], Env]:
|
||||
"""Creates an environment that can enable or disable the environment checker."""
|
||||
# If the env_num > 0 then disable the environment checker otherwise use the parameter
|
||||
_disable_env_checker = True if env_num > 0 else disable_env_checker
|
||||
|
||||
def _make_env() -> Env:
|
||||
env = gym.envs.registration.make(
|
||||
id,
|
||||
disable_env_checker=_disable_env_checker,
|
||||
**kwargs,
|
||||
)
|
||||
if wrappers is not None:
|
||||
if callable(wrappers):
|
||||
env = wrappers(env)
|
||||
elif isinstance(wrappers, Iterable) and all(
|
||||
[callable(w) for w in wrappers]
|
||||
):
|
||||
for wrapper in wrappers:
|
||||
env = wrapper(env)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return env
|
||||
|
||||
return _make_env
|
||||
|
||||
env_fns = [
|
||||
create_env(disable_env_checker or env_num > 0) for env_num in range(num_envs)
|
||||
]
|
||||
return AsyncVectorEnv(env_fns) if asynchronous else SyncVectorEnv(env_fns)
|
||||
|
@@ -1,17 +1,19 @@
|
||||
"""An async vector environment."""
|
||||
import multiprocessing as mp
|
||||
from __future__ import annotations
|
||||
|
||||
import multiprocessing
|
||||
import sys
|
||||
import time
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
|
||||
from multiprocessing import Queue
|
||||
from multiprocessing.connection import Connection
|
||||
from typing import Any, Callable, Sequence
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium import logger
|
||||
from gymnasium.core import Env, ObsType
|
||||
from gymnasium.core import ActType, Env, ObsType, RenderFrame
|
||||
from gymnasium.error import (
|
||||
AlreadyPendingCallError,
|
||||
ClosedEnvironmentError,
|
||||
@@ -20,6 +22,7 @@ from gymnasium.error import (
|
||||
)
|
||||
from gymnasium.vector.utils import (
|
||||
CloudpickleWrapper,
|
||||
batch_space,
|
||||
clear_mpi_env_vars,
|
||||
concatenate,
|
||||
create_empty_array,
|
||||
@@ -28,13 +31,15 @@ from gymnasium.vector.utils import (
|
||||
read_from_shared_memory,
|
||||
write_to_shared_memory,
|
||||
)
|
||||
from gymnasium.vector.vector_env import VectorEnv
|
||||
from gymnasium.vector.vector_env import ArrayType, VectorEnv
|
||||
|
||||
|
||||
__all__ = ["AsyncVectorEnv"]
|
||||
__all__ = ["AsyncVectorEnv", "AsyncState"]
|
||||
|
||||
|
||||
class AsyncState(Enum):
|
||||
"""The AsyncVectorEnv possible states given the different actions."""
|
||||
|
||||
DEFAULT = "default"
|
||||
WAITING_RESET = "reset"
|
||||
WAITING_STEP = "step"
|
||||
@@ -48,39 +53,57 @@ class AsyncVectorEnv(VectorEnv):
|
||||
|
||||
Example:
|
||||
>>> import gymnasium as gym
|
||||
>>> env = gym.vector.AsyncVectorEnv([
|
||||
>>> envs = gym.make_vec("Pendulum-v1", num_envs=2, vectorization_mode="async")
|
||||
>>> envs
|
||||
AsyncVectorEnv(Pendulum-v1, num_envs=2)
|
||||
>>> envs = gym.vector.AsyncVectorEnv([
|
||||
... lambda: gym.make("Pendulum-v1", g=9.81),
|
||||
... lambda: gym.make("Pendulum-v1", g=1.62)
|
||||
... ])
|
||||
>>> env.reset(seed=42)
|
||||
(array([[-0.14995256, 0.9886932 , -0.12224312],
|
||||
[ 0.5760367 , 0.8174238 , -0.91244936]], dtype=float32), {})
|
||||
>>> envs
|
||||
AsyncVectorEnv(num_envs=2)
|
||||
>>> observations, infos = envs.reset(seed=42)
|
||||
>>> observations
|
||||
array([[-0.14995256, 0.9886932 , -0.12224312],
|
||||
[ 0.5760367 , 0.8174238 , -0.91244936]], dtype=float32)
|
||||
>>> infos
|
||||
{}
|
||||
>>> _ = envs.action_space.seed(123)
|
||||
>>> observations, rewards, terminations, truncations, infos = envs.step(envs.action_space.sample())
|
||||
>>> observations
|
||||
array([[-0.1851753 , 0.98270553, 0.714599 ],
|
||||
[ 0.6193494 , 0.7851154 , -1.0808398 ]], dtype=float32)
|
||||
>>> rewards
|
||||
array([-2.96495728, -1.00214607])
|
||||
>>> terminations
|
||||
array([False, False])
|
||||
>>> truncations
|
||||
array([False, False])
|
||||
>>> infos
|
||||
{}
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env_fns: Sequence[Callable[[], Env]],
|
||||
observation_space: Optional[gym.Space] = None,
|
||||
action_space: Optional[gym.Space] = None,
|
||||
shared_memory: bool = True,
|
||||
copy: bool = True,
|
||||
context: Optional[str] = None,
|
||||
context: str | None = None,
|
||||
daemon: bool = True,
|
||||
worker: Optional[Callable] = None,
|
||||
worker: Callable[
|
||||
[int, Callable[[], Env], Connection, Connection, bool, Queue], None
|
||||
]
|
||||
| None = None,
|
||||
):
|
||||
"""Vectorized environment that runs multiple environments in parallel.
|
||||
|
||||
Args:
|
||||
env_fns: Functions that create the environments.
|
||||
observation_space: Observation space of a single environment. If ``None``,
|
||||
then the observation space of the first environment is taken.
|
||||
action_space: Action space of a single environment. If ``None``,
|
||||
then the action space of the first environment is taken.
|
||||
shared_memory: If ``True``, then the observations from the worker processes are communicated back through
|
||||
shared variables. This can improve the efficiency if the observations are large (e.g. images).
|
||||
copy: If ``True``, then the :meth:`~AsyncVectorEnv.reset` and :meth:`~AsyncVectorEnv.step` methods
|
||||
copy: If ``True``, then the :meth:`AsyncVectorEnv.reset` and :meth:`AsyncVectorEnv.step` methods
|
||||
return a copy of the observations.
|
||||
context: Context for `multiprocessing`_. If ``None``, then the default context is used.
|
||||
context: Context for `multiprocessing`. If ``None``, then the default context is used.
|
||||
daemon: If ``True``, then subprocesses have ``daemon`` flag turned on; that is, they will quit if
|
||||
the head process quits. However, ``daemon=True`` prevents subprocesses to spawn children,
|
||||
so for some environments you may want to have it set to ``False``.
|
||||
@@ -98,24 +121,33 @@ class AsyncVectorEnv(VectorEnv):
|
||||
ValueError: If observation_space is a custom space (i.e. not a default space in Gym,
|
||||
such as gymnasium.spaces.Box, gymnasium.spaces.Discrete, or gymnasium.spaces.Dict) and shared_memory is True.
|
||||
"""
|
||||
ctx = mp.get_context(context)
|
||||
self.env_fns = env_fns
|
||||
self.shared_memory = shared_memory
|
||||
self.copy = copy
|
||||
dummy_env = env_fns[0]()
|
||||
self.metadata = dummy_env.metadata
|
||||
|
||||
if (observation_space is None) or (action_space is None):
|
||||
observation_space = observation_space or dummy_env.observation_space
|
||||
action_space = action_space or dummy_env.action_space
|
||||
self.num_envs = len(env_fns)
|
||||
|
||||
# This would be nice to get rid of, but without it there's a deadlock between shared memory and pipes
|
||||
# Create a dummy environment to gather the metadata and observation / action space of the environment
|
||||
dummy_env = env_fns[0]()
|
||||
|
||||
# As we support `make_vec(spec)` then we can't include a `spec = dummy_env.spec` as this doesn't guarantee we can actual recreate the vector env.
|
||||
self.metadata = dummy_env.metadata
|
||||
self.render_mode = dummy_env.render_mode
|
||||
|
||||
self.single_observation_space = dummy_env.observation_space
|
||||
self.single_action_space = dummy_env.action_space
|
||||
|
||||
self.observation_space = batch_space(
|
||||
self.single_observation_space, self.num_envs
|
||||
)
|
||||
self.action_space = batch_space(self.single_action_space, self.num_envs)
|
||||
|
||||
dummy_env.close()
|
||||
del dummy_env
|
||||
super().__init__(
|
||||
num_envs=len(env_fns),
|
||||
observation_space=observation_space,
|
||||
action_space=action_space,
|
||||
)
|
||||
|
||||
# Generate the multiprocessing context for the observation buffer
|
||||
ctx = multiprocessing.get_context(context)
|
||||
if self.shared_memory:
|
||||
try:
|
||||
_obs_buffer = create_shared_memory(
|
||||
@@ -126,12 +158,9 @@ class AsyncVectorEnv(VectorEnv):
|
||||
)
|
||||
except CustomSpaceError as e:
|
||||
raise ValueError(
|
||||
"Using `shared_memory=True` in `AsyncVectorEnv` "
|
||||
"is incompatible with non-standard Gymnasium observation spaces "
|
||||
"(i.e. custom spaces inheriting from `gymnasium.Space`), and is "
|
||||
"only compatible with default Gymnasium spaces (e.g. `Box`, "
|
||||
"`Tuple`, `Dict`) for batching. Set `shared_memory=False` "
|
||||
"if you use custom observation spaces."
|
||||
"Using `shared_memory=True` in `AsyncVectorEnv` is incompatible with non-standard Gymnasium observation spaces (i.e. custom spaces inheriting from `gymnasium.Space`), "
|
||||
"and is only compatible with default Gymnasium spaces (e.g. `Box`, `Tuple`, `Dict`) for batching. "
|
||||
"Set `shared_memory=False` if you use custom observation spaces."
|
||||
) from e
|
||||
else:
|
||||
_obs_buffer = None
|
||||
@@ -141,8 +170,7 @@ class AsyncVectorEnv(VectorEnv):
|
||||
|
||||
self.parent_pipes, self.processes = [], []
|
||||
self.error_queue = ctx.Queue()
|
||||
target = _worker_shared_memory if self.shared_memory else _worker
|
||||
target = worker or target
|
||||
target = worker or _async_worker
|
||||
with clear_mpi_env_vars():
|
||||
for idx, env_fn in enumerate(self.env_fns):
|
||||
parent_pipe, child_pipe = ctx.Pipe()
|
||||
@@ -169,10 +197,28 @@ class AsyncVectorEnv(VectorEnv):
|
||||
self._state = AsyncState.DEFAULT
|
||||
self._check_spaces()
|
||||
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: int | list[int] | None = None,
|
||||
options: dict[str, Any] | None = None,
|
||||
) -> tuple[ObsType, dict[str, Any]]:
|
||||
"""Resets all sub-environments in parallel and return a batch of concatenated observations and info.
|
||||
|
||||
Args:
|
||||
seed: The environment reset seeds
|
||||
options: If to return the options
|
||||
|
||||
Returns:
|
||||
A batch of observations and info from the vectorized environment.
|
||||
"""
|
||||
self.reset_async(seed=seed, options=options)
|
||||
return self.reset_wait()
|
||||
|
||||
def reset_async(
|
||||
self,
|
||||
seed: Optional[Union[int, List[int]]] = None,
|
||||
options: Optional[dict] = None,
|
||||
seed: int | list[int] | None = None,
|
||||
options: dict | None = None,
|
||||
):
|
||||
"""Send calls to the :obj:`reset` methods of the sub-environments.
|
||||
|
||||
@@ -192,38 +238,29 @@ class AsyncVectorEnv(VectorEnv):
|
||||
|
||||
if seed is None:
|
||||
seed = [None for _ in range(self.num_envs)]
|
||||
if isinstance(seed, int):
|
||||
elif isinstance(seed, int):
|
||||
seed = [seed + i for i in range(self.num_envs)]
|
||||
assert len(seed) == self.num_envs
|
||||
|
||||
if self._state != AsyncState.DEFAULT:
|
||||
raise AlreadyPendingCallError(
|
||||
f"Calling `reset_async` while waiting for a pending call to `{self._state.value}` to complete",
|
||||
self._state.value,
|
||||
str(self._state.value),
|
||||
)
|
||||
|
||||
for pipe, single_seed in zip(self.parent_pipes, seed):
|
||||
single_kwargs = {}
|
||||
if single_seed is not None:
|
||||
single_kwargs["seed"] = single_seed
|
||||
if options is not None:
|
||||
single_kwargs["options"] = options
|
||||
|
||||
pipe.send(("reset", single_kwargs))
|
||||
for pipe, env_seed in zip(self.parent_pipes, seed):
|
||||
env_kwargs = {"seed": env_seed, "options": options}
|
||||
pipe.send(("reset", env_kwargs))
|
||||
self._state = AsyncState.WAITING_RESET
|
||||
|
||||
def reset_wait(
|
||||
self,
|
||||
timeout: Optional[Union[int, float]] = None,
|
||||
seed: Optional[int] = None,
|
||||
options: Optional[dict] = None,
|
||||
) -> Union[ObsType, Tuple[ObsType, dict]]:
|
||||
timeout: int | float | None = None,
|
||||
) -> tuple[ObsType, dict[str, Any]]:
|
||||
"""Waits for the calls triggered by :meth:`reset_async` to finish and returns the results.
|
||||
|
||||
Args:
|
||||
timeout: Number of seconds before the call to `reset_wait` times out. If `None`, the call to `reset_wait` never times out.
|
||||
seed: ignored
|
||||
options: ignored
|
||||
timeout: Number of seconds before the call to ``reset_wait`` times out. If `None`, the call to ``reset_wait`` never times out.
|
||||
|
||||
Returns:
|
||||
A tuple of batched observations and list of dictionaries
|
||||
@@ -240,15 +277,14 @@ class AsyncVectorEnv(VectorEnv):
|
||||
AsyncState.WAITING_RESET.value,
|
||||
)
|
||||
|
||||
if not self._poll(timeout):
|
||||
if not self._poll_pipe_envs(timeout):
|
||||
self._state = AsyncState.DEFAULT
|
||||
raise mp.TimeoutError(
|
||||
raise multiprocessing.TimeoutError(
|
||||
f"The call to `reset_wait` has timed out after {timeout} second(s)."
|
||||
)
|
||||
|
||||
results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
|
||||
self._raise_if_errors(successes)
|
||||
self._state = AsyncState.DEFAULT
|
||||
|
||||
infos = {}
|
||||
results, info_data = zip(*results)
|
||||
@@ -260,13 +296,28 @@ class AsyncVectorEnv(VectorEnv):
|
||||
self.single_observation_space, results, self.observations
|
||||
)
|
||||
|
||||
self._state = AsyncState.DEFAULT
|
||||
return (deepcopy(self.observations) if self.copy else self.observations), infos
|
||||
|
||||
def step_async(self, actions: np.ndarray):
|
||||
"""Send the calls to :obj:`step` to each sub-environment.
|
||||
def step(
|
||||
self, actions: ActType
|
||||
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
|
||||
"""Take an action for each parallel environment.
|
||||
|
||||
Args:
|
||||
actions: Batch of actions. element of :attr:`~VectorEnv.action_space`
|
||||
actions: element of :attr:`action_space` batch of actions.
|
||||
|
||||
Returns:
|
||||
Batch of (observations, rewards, terminations, truncations, infos)
|
||||
"""
|
||||
self.step_async(actions)
|
||||
return self.step_wait()
|
||||
|
||||
def step_async(self, actions: np.ndarray):
|
||||
"""Send the calls to :meth:`Env.step` to each sub-environment.
|
||||
|
||||
Args:
|
||||
actions: Batch of actions. element of :attr:`VectorEnv.action_space`
|
||||
|
||||
Raises:
|
||||
ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
|
||||
@@ -279,17 +330,17 @@ class AsyncVectorEnv(VectorEnv):
|
||||
if self._state != AsyncState.DEFAULT:
|
||||
raise AlreadyPendingCallError(
|
||||
f"Calling `step_async` while waiting for a pending call to `{self._state.value}` to complete.",
|
||||
self._state.value,
|
||||
str(self._state.value),
|
||||
)
|
||||
|
||||
actions = iterate(self.action_space, actions)
|
||||
for pipe, action in zip(self.parent_pipes, actions):
|
||||
iter_actions = iterate(self.action_space, actions)
|
||||
for pipe, action in zip(self.parent_pipes, iter_actions):
|
||||
pipe.send(("step", action))
|
||||
self._state = AsyncState.WAITING_STEP
|
||||
|
||||
def step_wait(
|
||||
self, timeout: Optional[Union[int, float]] = None
|
||||
) -> Tuple[Any, NDArray[Any], NDArray[Any], NDArray[Any], dict]:
|
||||
self, timeout: int | float | None = None
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, dict]:
|
||||
"""Wait for the calls to :obj:`step` in each sub-environment to finish.
|
||||
|
||||
Args:
|
||||
@@ -310,44 +361,61 @@ class AsyncVectorEnv(VectorEnv):
|
||||
AsyncState.WAITING_STEP.value,
|
||||
)
|
||||
|
||||
if not self._poll(timeout):
|
||||
if not self._poll_pipe_envs(timeout):
|
||||
self._state = AsyncState.DEFAULT
|
||||
raise mp.TimeoutError(
|
||||
raise multiprocessing.TimeoutError(
|
||||
f"The call to `step_wait` has timed out after {timeout} second(s)."
|
||||
)
|
||||
|
||||
observations_list, rewards, terminateds, truncateds, infos = [], [], [], [], {}
|
||||
observations, rewards, terminations, truncations, infos = [], [], [], [], {}
|
||||
successes = []
|
||||
for i, pipe in enumerate(self.parent_pipes):
|
||||
result, success = pipe.recv()
|
||||
for env_idx, pipe in enumerate(self.parent_pipes):
|
||||
env_step_return, success = pipe.recv()
|
||||
|
||||
successes.append(success)
|
||||
if success:
|
||||
obs, rew, terminated, truncated, info = result
|
||||
|
||||
observations_list.append(obs)
|
||||
rewards.append(rew)
|
||||
terminateds.append(terminated)
|
||||
truncateds.append(truncated)
|
||||
infos = self._add_info(infos, info, i)
|
||||
observations.append(env_step_return[0])
|
||||
rewards.append(env_step_return[1])
|
||||
terminations.append(env_step_return[2])
|
||||
truncations.append(env_step_return[3])
|
||||
infos = self._add_info(infos, env_step_return[4], env_idx)
|
||||
|
||||
self._raise_if_errors(successes)
|
||||
self._state = AsyncState.DEFAULT
|
||||
|
||||
if not self.shared_memory:
|
||||
self.observations = concatenate(
|
||||
self.single_observation_space,
|
||||
observations_list,
|
||||
observations,
|
||||
self.observations,
|
||||
)
|
||||
|
||||
self._state = AsyncState.DEFAULT
|
||||
return (
|
||||
deepcopy(self.observations) if self.copy else self.observations,
|
||||
np.array(rewards),
|
||||
np.array(terminateds, dtype=np.bool_),
|
||||
np.array(truncateds, dtype=np.bool_),
|
||||
np.array(rewards, dtype=np.float64),
|
||||
np.array(terminations, dtype=np.bool_),
|
||||
np.array(truncations, dtype=np.bool_),
|
||||
infos,
|
||||
)
|
||||
|
||||
def call(self, name: str, *args: Any, **kwargs: Any) -> tuple[Any, ...]:
|
||||
"""Call a method from each parallel environment with args and kwargs.
|
||||
|
||||
Args:
|
||||
name (str): Name of the method or property to call.
|
||||
*args: Position arguments to apply to the method call.
|
||||
**kwargs: Keyword arguments to apply to the method call.
|
||||
|
||||
Returns:
|
||||
List of the results of the individual calls to the method or property for each environment.
|
||||
"""
|
||||
self.call_async(name, *args, **kwargs)
|
||||
return self.call_wait()
|
||||
|
||||
def render(self) -> tuple[RenderFrame, ...] | None:
|
||||
"""Returns a list of rendered frames from the environments."""
|
||||
return self.call("render")
|
||||
|
||||
def call_async(self, name: str, *args, **kwargs):
|
||||
"""Calls the method with name asynchronously and apply args and kwargs to the method.
|
||||
|
||||
@@ -363,28 +431,27 @@ class AsyncVectorEnv(VectorEnv):
|
||||
self._assert_is_running()
|
||||
if self._state != AsyncState.DEFAULT:
|
||||
raise AlreadyPendingCallError(
|
||||
"Calling `call_async` while waiting "
|
||||
f"for a pending call to `{self._state.value}` to complete.",
|
||||
self._state.value,
|
||||
f"Calling `call_async` while waiting for a pending call to `{self._state.value}` to complete.",
|
||||
str(self._state.value),
|
||||
)
|
||||
|
||||
for pipe in self.parent_pipes:
|
||||
pipe.send(("_call", (name, args, kwargs)))
|
||||
self._state = AsyncState.WAITING_CALL
|
||||
|
||||
def call_wait(self, timeout: Optional[Union[int, float]] = None) -> list:
|
||||
def call_wait(self, timeout: int | float | None = None) -> tuple[Any, ...]:
|
||||
"""Calls all parent pipes and waits for the results.
|
||||
|
||||
Args:
|
||||
timeout: Number of seconds before the call to `step_wait` times out.
|
||||
If `None` (default), the call to `step_wait` never times out.
|
||||
timeout: Number of seconds before the call to :meth:`step_wait` times out.
|
||||
If ``None`` (default), the call to :meth:`step_wait` never times out.
|
||||
|
||||
Returns:
|
||||
List of the results of the individual calls to the method or property for each environment.
|
||||
|
||||
Raises:
|
||||
NoAsyncCallError: Calling `call_wait` without any prior call to `call_async`.
|
||||
TimeoutError: The call to `call_wait` has timed out after timeout second(s).
|
||||
NoAsyncCallError: Calling :meth:`call_wait` without any prior call to :meth:`call_async`.
|
||||
TimeoutError: The call to :meth:`call_wait` has timed out after timeout second(s).
|
||||
"""
|
||||
self._assert_is_running()
|
||||
if self._state != AsyncState.WAITING_CALL:
|
||||
@@ -393,9 +460,9 @@ class AsyncVectorEnv(VectorEnv):
|
||||
AsyncState.WAITING_CALL.value,
|
||||
)
|
||||
|
||||
if not self._poll(timeout):
|
||||
if not self._poll_pipe_envs(timeout):
|
||||
self._state = AsyncState.DEFAULT
|
||||
raise mp.TimeoutError(
|
||||
raise multiprocessing.TimeoutError(
|
||||
f"The call to `call_wait` has timed out after {timeout} second(s)."
|
||||
)
|
||||
|
||||
@@ -405,7 +472,18 @@ class AsyncVectorEnv(VectorEnv):
|
||||
|
||||
return results
|
||||
|
||||
def set_attr(self, name: str, values: Union[list, tuple, object]):
|
||||
def get_attr(self, name: str):
|
||||
"""Get a property from each parallel environment.
|
||||
|
||||
Args:
|
||||
name (str): Name of the property to be get from each individual environment.
|
||||
|
||||
Returns:
|
||||
The property with name
|
||||
"""
|
||||
return self.call(name)
|
||||
|
||||
def set_attr(self, name: str, values: list[Any] | tuple[Any] | object):
|
||||
"""Sets an attribute of the sub-environments.
|
||||
|
||||
Args:
|
||||
@@ -416,23 +494,21 @@ class AsyncVectorEnv(VectorEnv):
|
||||
|
||||
Raises:
|
||||
ValueError: Values must be a list or tuple with length equal to the number of environments.
|
||||
AlreadyPendingCallError: Calling `set_attr` while waiting for a pending call to complete.
|
||||
AlreadyPendingCallError: Calling :meth:`set_attr` while waiting for a pending call to complete.
|
||||
"""
|
||||
self._assert_is_running()
|
||||
if not isinstance(values, (list, tuple)):
|
||||
values = [values for _ in range(self.num_envs)]
|
||||
if len(values) != self.num_envs:
|
||||
raise ValueError(
|
||||
"Values must be a list or tuple with length equal to the "
|
||||
f"number of environments. Got `{len(values)}` values for "
|
||||
f"{self.num_envs} environments."
|
||||
"Values must be a list or tuple with length equal to the number of environments. "
|
||||
f"Got `{len(values)}` values for {self.num_envs} environments."
|
||||
)
|
||||
|
||||
if self._state != AsyncState.DEFAULT:
|
||||
raise AlreadyPendingCallError(
|
||||
"Calling `set_attr` while waiting "
|
||||
f"for a pending call to `{self._state.value}` to complete.",
|
||||
self._state.value,
|
||||
f"Calling `set_attr` while waiting for a pending call to `{self._state.value}` to complete.",
|
||||
str(self._state.value),
|
||||
)
|
||||
|
||||
for pipe, value in zip(self.parent_pipes, values):
|
||||
@@ -440,9 +516,7 @@ class AsyncVectorEnv(VectorEnv):
|
||||
_, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
|
||||
self._raise_if_errors(successes)
|
||||
|
||||
def close_extras(
|
||||
self, timeout: Optional[Union[int, float]] = None, terminate: bool = False
|
||||
):
|
||||
def close_extras(self, timeout: int | float | None = None, terminate: bool = False):
|
||||
"""Close the environments & clean up the extra resources (processes and pipes).
|
||||
|
||||
Args:
|
||||
@@ -462,7 +536,7 @@ class AsyncVectorEnv(VectorEnv):
|
||||
)
|
||||
function = getattr(self, f"{self._state.value}_wait")
|
||||
function(timeout)
|
||||
except mp.TimeoutError:
|
||||
except multiprocessing.TimeoutError:
|
||||
terminate = True
|
||||
|
||||
if terminate:
|
||||
@@ -483,14 +557,16 @@ class AsyncVectorEnv(VectorEnv):
|
||||
for process in self.processes:
|
||||
process.join()
|
||||
|
||||
def _poll(self, timeout=None):
|
||||
def _poll_pipe_envs(self, timeout: int | None = None):
|
||||
self._assert_is_running()
|
||||
|
||||
if timeout is None:
|
||||
return True
|
||||
|
||||
end_time = time.perf_counter() + timeout
|
||||
delta = None
|
||||
for pipe in self.parent_pipes:
|
||||
delta = max(end_time - time.perf_counter(), 0)
|
||||
|
||||
if pipe is None:
|
||||
return False
|
||||
if pipe.closed or (not pipe.poll(delta)):
|
||||
@@ -500,22 +576,23 @@ class AsyncVectorEnv(VectorEnv):
|
||||
def _check_spaces(self):
|
||||
self._assert_is_running()
|
||||
spaces = (self.single_observation_space, self.single_action_space)
|
||||
|
||||
for pipe in self.parent_pipes:
|
||||
pipe.send(("_check_spaces", spaces))
|
||||
|
||||
results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
|
||||
self._raise_if_errors(successes)
|
||||
same_observation_spaces, same_action_spaces = zip(*results)
|
||||
|
||||
if not all(same_observation_spaces):
|
||||
raise RuntimeError(
|
||||
"Some environments have an observation space different from "
|
||||
f"`{self.single_observation_space}`. In order to batch observations, "
|
||||
"the observation spaces from all environments must be equal."
|
||||
f"Some environments have an observation space different from `{self.single_observation_space}`. "
|
||||
"In order to batch observations, the observation spaces from all environments must be equal."
|
||||
)
|
||||
if not all(same_action_spaces):
|
||||
raise RuntimeError(
|
||||
"Some environments have an action space different from "
|
||||
f"`{self.single_action_space}`. In order to batch actions, the "
|
||||
"action spaces from all environments must be equal."
|
||||
f"Some environments have an action space different from `{self.single_action_space}`. "
|
||||
"In order to batch actions, the action spaces from all environments must be equal."
|
||||
)
|
||||
|
||||
def _assert_is_running(self):
|
||||
@@ -524,7 +601,7 @@ class AsyncVectorEnv(VectorEnv):
|
||||
f"Trying to operate on `{type(self).__name__}`, after a call to `close()`."
|
||||
)
|
||||
|
||||
def _raise_if_errors(self, successes):
|
||||
def _raise_if_errors(self, successes: list[bool]):
|
||||
if all(successes):
|
||||
return
|
||||
|
||||
@@ -532,10 +609,12 @@ class AsyncVectorEnv(VectorEnv):
|
||||
assert num_errors > 0
|
||||
for i in range(num_errors):
|
||||
index, exctype, value = self.error_queue.get()
|
||||
|
||||
logger.error(
|
||||
f"Received the following error from Worker-{index}: {exctype.__name__}: {value}"
|
||||
)
|
||||
logger.error(f"Shutting down Worker-{index}.")
|
||||
|
||||
self.parent_pipes[index].close()
|
||||
self.parent_pipes[index] = None
|
||||
|
||||
@@ -549,17 +628,32 @@ class AsyncVectorEnv(VectorEnv):
|
||||
self.close(terminate=True)
|
||||
|
||||
|
||||
def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
|
||||
assert shared_memory is None
|
||||
def _async_worker(
|
||||
index: int,
|
||||
env_fn: callable,
|
||||
pipe: Connection,
|
||||
parent_pipe: Connection,
|
||||
shared_memory: bool,
|
||||
error_queue: Queue,
|
||||
):
|
||||
env = env_fn()
|
||||
observation_space = env.observation_space
|
||||
action_space = env.action_space
|
||||
|
||||
parent_pipe.close()
|
||||
|
||||
try:
|
||||
while True:
|
||||
command, data = pipe.recv()
|
||||
|
||||
if command == "reset":
|
||||
observation, info = env.reset(**data)
|
||||
if shared_memory:
|
||||
write_to_shared_memory(
|
||||
observation_space, index, observation, shared_memory
|
||||
)
|
||||
observation = None
|
||||
pipe.send(((observation, info), True))
|
||||
|
||||
elif command == "step":
|
||||
(
|
||||
observation,
|
||||
@@ -573,112 +667,43 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
|
||||
observation, info = env.reset()
|
||||
info["final_observation"] = old_observation
|
||||
info["final_info"] = old_info
|
||||
|
||||
if shared_memory:
|
||||
write_to_shared_memory(
|
||||
observation_space, index, observation, shared_memory
|
||||
)
|
||||
observation = None
|
||||
|
||||
pipe.send(((observation, reward, terminated, truncated, info), True))
|
||||
elif command == "seed":
|
||||
env.seed(data)
|
||||
pipe.send((None, True))
|
||||
elif command == "close":
|
||||
pipe.send((None, True))
|
||||
break
|
||||
elif command == "_call":
|
||||
name, args, kwargs = data
|
||||
if name in ["reset", "step", "seed", "close"]:
|
||||
if name in ["reset", "step", "close", "set_wrapper_attr"]:
|
||||
raise ValueError(
|
||||
f"Trying to call function `{name}` with "
|
||||
f"`_call`. Use `{name}` directly instead."
|
||||
f"Trying to call function `{name}` with `call`, use `{name}` directly instead."
|
||||
)
|
||||
function = getattr(env, name)
|
||||
if callable(function):
|
||||
pipe.send((function(*args, **kwargs), True))
|
||||
|
||||
attr = env.get_wrapper_attr(name)
|
||||
if callable(attr):
|
||||
pipe.send((attr(*args, **kwargs), True))
|
||||
else:
|
||||
pipe.send((function, True))
|
||||
pipe.send((attr, True))
|
||||
elif command == "_setattr":
|
||||
name, value = data
|
||||
setattr(env, name, value)
|
||||
env.set_wrapper_attr(name, value)
|
||||
pipe.send((None, True))
|
||||
elif command == "_check_spaces":
|
||||
pipe.send(
|
||||
(
|
||||
(data[0] == env.observation_space, data[1] == env.action_space),
|
||||
(data[0] == observation_space, data[1] == action_space),
|
||||
True,
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Received unknown command `{command}`. Must "
|
||||
"be one of {`reset`, `step`, `seed`, `close`, `_call`, "
|
||||
"`_setattr`, `_check_spaces`}."
|
||||
)
|
||||
except (KeyboardInterrupt, Exception):
|
||||
error_queue.put((index,) + sys.exc_info()[:2])
|
||||
pipe.send((None, False))
|
||||
finally:
|
||||
env.close()
|
||||
|
||||
|
||||
def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
|
||||
assert shared_memory is not None
|
||||
env = env_fn()
|
||||
observation_space = env.observation_space
|
||||
parent_pipe.close()
|
||||
try:
|
||||
while True:
|
||||
command, data = pipe.recv()
|
||||
if command == "reset":
|
||||
observation, info = env.reset(**data)
|
||||
write_to_shared_memory(
|
||||
observation_space, index, observation, shared_memory
|
||||
)
|
||||
pipe.send(((None, info), True))
|
||||
|
||||
elif command == "step":
|
||||
(
|
||||
observation,
|
||||
reward,
|
||||
terminated,
|
||||
truncated,
|
||||
info,
|
||||
) = env.step(data)
|
||||
if terminated or truncated:
|
||||
old_observation, old_info = observation, info
|
||||
observation, info = env.reset()
|
||||
info["final_observation"] = old_observation
|
||||
info["final_info"] = old_info
|
||||
write_to_shared_memory(
|
||||
observation_space, index, observation, shared_memory
|
||||
)
|
||||
pipe.send(((None, reward, terminated, truncated, info), True))
|
||||
elif command == "seed":
|
||||
env.seed(data)
|
||||
pipe.send((None, True))
|
||||
elif command == "close":
|
||||
pipe.send((None, True))
|
||||
break
|
||||
elif command == "_call":
|
||||
name, args, kwargs = data
|
||||
if name in ["reset", "step", "seed", "close"]:
|
||||
raise ValueError(
|
||||
f"Trying to call function `{name}` with "
|
||||
f"`_call`. Use `{name}` directly instead."
|
||||
)
|
||||
function = getattr(env, name)
|
||||
if callable(function):
|
||||
pipe.send((function(*args, **kwargs), True))
|
||||
else:
|
||||
pipe.send((function, True))
|
||||
elif command == "_setattr":
|
||||
name, value = data
|
||||
setattr(env, name, value)
|
||||
pipe.send((None, True))
|
||||
elif command == "_check_spaces":
|
||||
pipe.send(
|
||||
((data[0] == observation_space, data[1] == env.action_space), True)
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Received unknown command `{command}`. Must "
|
||||
"be one of {`reset`, `step`, `seed`, `close`, `_call`, "
|
||||
"`_setattr`, `_check_spaces`}."
|
||||
f"Received unknown command `{command}`. Must be one of [`reset`, `step`, `close`, `_call`, `_setattr`, `_check_spaces`]."
|
||||
)
|
||||
except (KeyboardInterrupt, Exception):
|
||||
error_queue.put((index,) + sys.exc_info()[:2])
|
||||
|
@@ -1,14 +1,15 @@
|
||||
"""A synchronous vector environment."""
|
||||
"""Implementation of a synchronous (for loop) vectorization method of any environment."""
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Union
|
||||
from typing import Any, Callable, Iterator, Sequence
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from gymnasium import Env
|
||||
from gymnasium.spaces import Space
|
||||
from gymnasium.vector.utils import concatenate, create_empty_array, iterate
|
||||
from gymnasium.vector.vector_env import VectorEnv
|
||||
from gymnasium.core import ActType, ObsType, RenderFrame
|
||||
from gymnasium.vector.utils import batch_space, concatenate, create_empty_array, iterate
|
||||
from gymnasium.vector.vector_env import ArrayType, VectorEnv
|
||||
|
||||
|
||||
__all__ = ["SyncVectorEnv"]
|
||||
@@ -19,156 +20,175 @@ class SyncVectorEnv(VectorEnv):
|
||||
|
||||
Example:
|
||||
>>> import gymnasium as gym
|
||||
>>> env = gym.vector.SyncVectorEnv([
|
||||
>>> envs = gym.make_vec("Pendulum-v1", num_envs=2, vectorization_mode="sync")
|
||||
>>> envs
|
||||
SyncVectorEnv(Pendulum-v1, num_envs=2)
|
||||
>>> envs = gym.vector.SyncVectorEnv([
|
||||
... lambda: gym.make("Pendulum-v1", g=9.81),
|
||||
... lambda: gym.make("Pendulum-v1", g=1.62)
|
||||
... ])
|
||||
>>> env.reset(seed=42)
|
||||
(array([[-0.14995256, 0.9886932 , -0.12224312],
|
||||
[ 0.5760367 , 0.8174238 , -0.91244936]], dtype=float32), {})
|
||||
>>> envs
|
||||
SyncVectorEnv(num_envs=2)
|
||||
>>> obs, infos = envs.reset(seed=42)
|
||||
>>> obs
|
||||
array([[-0.14995256, 0.9886932 , -0.12224312],
|
||||
[ 0.5760367 , 0.8174238 , -0.91244936]], dtype=float32)
|
||||
>>> infos
|
||||
{}
|
||||
>>> _ = envs.action_space.seed(42)
|
||||
>>> actions = envs.action_space.sample()
|
||||
>>> obs, rewards, terminates, truncates, infos = envs.step(actions)
|
||||
>>> obs
|
||||
array([[-0.1878752 , 0.98219293, 0.7695615 ],
|
||||
[ 0.6102389 , 0.79221743, -0.8498053 ]], dtype=float32)
|
||||
>>> rewards
|
||||
array([-2.96562607, -0.99902063])
|
||||
>>> terminates
|
||||
array([False, False])
|
||||
>>> truncates
|
||||
array([False, False])
|
||||
>>> infos
|
||||
{}
|
||||
>>> envs.close()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env_fns: Iterable[Callable[[], Env]],
|
||||
observation_space: Space = None,
|
||||
action_space: Space = None,
|
||||
env_fns: Iterator[Callable[[], Env]] | Sequence[Callable[[], Env]],
|
||||
copy: bool = True,
|
||||
):
|
||||
"""Vectorized environment that serially runs multiple environments.
|
||||
|
||||
Args:
|
||||
env_fns: iterable of callable functions that create the environments.
|
||||
observation_space: Observation space of a single environment. If ``None``,
|
||||
then the observation space of the first environment is taken.
|
||||
action_space: Action space of a single environment. If ``None``,
|
||||
then the action space of the first environment is taken.
|
||||
copy: If ``True``, then the :meth:`reset` and :meth:`step` methods return a copy of the observations.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the observation space of some sub-environment does not match observation_space
|
||||
(or, by default, the observation space of the first sub-environment).
|
||||
"""
|
||||
self.env_fns = env_fns
|
||||
self.envs = [env_fn() for env_fn in env_fns]
|
||||
self.copy = copy
|
||||
self.env_fns = env_fns
|
||||
|
||||
# Initialise all sub-environments
|
||||
self.envs = [env_fn() for env_fn in env_fns]
|
||||
|
||||
# Define core attributes using the sub-environments
|
||||
# As we support `make_vec(spec)` then we can't include a `spec = self.envs[0].spec` as this doesn't guarantee we can actual recreate the vector env.
|
||||
self.num_envs = len(self.envs)
|
||||
self.metadata = self.envs[0].metadata
|
||||
self.render_mode = self.envs[0].render_mode
|
||||
|
||||
if (observation_space is None) or (action_space is None):
|
||||
observation_space = observation_space or self.envs[0].observation_space
|
||||
action_space = action_space or self.envs[0].action_space
|
||||
super().__init__(
|
||||
num_envs=len(self.envs),
|
||||
observation_space=observation_space,
|
||||
action_space=action_space,
|
||||
)
|
||||
|
||||
# Initialises the single spaces from the sub-environments
|
||||
self.single_observation_space = self.envs[0].observation_space
|
||||
self.single_action_space = self.envs[0].action_space
|
||||
self._check_spaces()
|
||||
self.observations = create_empty_array(
|
||||
|
||||
# Initialise the obs and action space based on the single versions and num of sub-environments
|
||||
self.observation_space = batch_space(
|
||||
self.single_observation_space, self.num_envs
|
||||
)
|
||||
self.action_space = batch_space(self.single_action_space, self.num_envs)
|
||||
|
||||
# Initialise attributes used in `step` and `reset`
|
||||
self._observations = create_empty_array(
|
||||
self.single_observation_space, n=self.num_envs, fn=np.zeros
|
||||
)
|
||||
self._rewards = np.zeros((self.num_envs,), dtype=np.float64)
|
||||
self._terminateds = np.zeros((self.num_envs,), dtype=np.bool_)
|
||||
self._truncateds = np.zeros((self.num_envs,), dtype=np.bool_)
|
||||
self._actions = None
|
||||
self._terminations = np.zeros((self.num_envs,), dtype=np.bool_)
|
||||
self._truncations = np.zeros((self.num_envs,), dtype=np.bool_)
|
||||
|
||||
def seed(self, seed: Optional[Union[int, Sequence[int]]] = None):
|
||||
"""Sets the seed in all sub-environments.
|
||||
|
||||
Args:
|
||||
seed: The seed
|
||||
"""
|
||||
super().seed(seed=seed)
|
||||
if seed is None:
|
||||
seed = [None for _ in range(self.num_envs)]
|
||||
if isinstance(seed, int):
|
||||
seed = [seed + i for i in range(self.num_envs)]
|
||||
assert len(seed) == self.num_envs
|
||||
|
||||
for env, single_seed in zip(self.envs, seed):
|
||||
env.seed(single_seed)
|
||||
|
||||
def reset_wait(
|
||||
def reset(
|
||||
self,
|
||||
seed: Optional[Union[int, List[int]]] = None,
|
||||
options: Optional[dict] = None,
|
||||
):
|
||||
"""Waits for the calls triggered by :meth:`reset_async` to finish and returns the results.
|
||||
*,
|
||||
seed: int | list[int] | None = None,
|
||||
options: dict[str, Any] | None = None,
|
||||
) -> tuple[ObsType, dict[str, Any]]:
|
||||
"""Resets each of the sub-environments and concatenate the results together.
|
||||
|
||||
Args:
|
||||
seed: The reset environment seed
|
||||
options: Option information for the environment reset
|
||||
seed: Seeds used to reset the sub-environments, either
|
||||
* ``None`` - random seeds for all environment
|
||||
* ``int`` - ``[seed, seed+1, ..., seed+n]``
|
||||
* List of ints - ``[1, 2, 3, ..., n]``
|
||||
options: Option information used for each sub-environment
|
||||
|
||||
Returns:
|
||||
The reset observation of the environment and reset information
|
||||
Concatenated observations and info from each sub-environment
|
||||
"""
|
||||
if seed is None:
|
||||
seed = [None for _ in range(self.num_envs)]
|
||||
if isinstance(seed, int):
|
||||
elif isinstance(seed, int):
|
||||
seed = [seed + i for i in range(self.num_envs)]
|
||||
assert len(seed) == self.num_envs
|
||||
|
||||
self._terminateds[:] = False
|
||||
self._truncateds[:] = False
|
||||
observations = []
|
||||
infos = {}
|
||||
self._terminations = np.zeros((self.num_envs,), dtype=np.bool_)
|
||||
self._truncations = np.zeros((self.num_envs,), dtype=np.bool_)
|
||||
|
||||
observations, infos = [], {}
|
||||
for i, (env, single_seed) in enumerate(zip(self.envs, seed)):
|
||||
kwargs = {}
|
||||
if single_seed is not None:
|
||||
kwargs["seed"] = single_seed
|
||||
if options is not None:
|
||||
kwargs["options"] = options
|
||||
env_obs, env_info = env.reset(seed=single_seed, options=options)
|
||||
|
||||
observation, info = env.reset(**kwargs)
|
||||
observations.append(observation)
|
||||
infos = self._add_info(infos, info, i)
|
||||
observations.append(env_obs)
|
||||
infos = self._add_info(infos, env_info, i)
|
||||
|
||||
self.observations = concatenate(
|
||||
self.single_observation_space, observations, self.observations
|
||||
# Concatenate the observations
|
||||
self._observations = concatenate(
|
||||
self.single_observation_space, observations, self._observations
|
||||
)
|
||||
return (deepcopy(self.observations) if self.copy else self.observations), infos
|
||||
|
||||
def step_async(self, actions):
|
||||
"""Sets :attr:`_actions` for use by the :meth:`step_wait` by converting the ``actions`` to an iterable version."""
|
||||
self._actions = iterate(self.action_space, actions)
|
||||
return deepcopy(self._observations) if self.copy else self._observations, infos
|
||||
|
||||
def step_wait(self) -> Tuple[Any, NDArray[Any], NDArray[Any], NDArray[Any], dict]:
|
||||
def step(
|
||||
self, actions: ActType
|
||||
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
|
||||
"""Steps through each of the environments returning the batched results.
|
||||
|
||||
Returns:
|
||||
The batched environment step results
|
||||
"""
|
||||
actions = iterate(self.action_space, actions)
|
||||
|
||||
observations, infos = [], {}
|
||||
for i, (env, action) in enumerate(zip(self.envs, self._actions)):
|
||||
for i, (env, action) in enumerate(zip(self.envs, actions)):
|
||||
(
|
||||
observation,
|
||||
env_obs,
|
||||
self._rewards[i],
|
||||
self._terminateds[i],
|
||||
self._truncateds[i],
|
||||
info,
|
||||
self._terminations[i],
|
||||
self._truncations[i],
|
||||
env_info,
|
||||
) = env.step(action)
|
||||
|
||||
if self._terminateds[i] or self._truncateds[i]:
|
||||
old_observation, old_info = observation, info
|
||||
observation, info = env.reset()
|
||||
info["final_observation"] = old_observation
|
||||
info["final_info"] = old_info
|
||||
observations.append(observation)
|
||||
infos = self._add_info(infos, info, i)
|
||||
self.observations = concatenate(
|
||||
self.single_observation_space, observations, self.observations
|
||||
# If sub-environments terminates or truncates then save the obs and info to the batched info
|
||||
if self._terminations[i] or self._truncations[i]:
|
||||
old_observation, old_info = env_obs, env_info
|
||||
env_obs, env_info = env.reset()
|
||||
|
||||
env_info["final_observation"] = old_observation
|
||||
env_info["final_info"] = old_info
|
||||
|
||||
observations.append(env_obs)
|
||||
infos = self._add_info(infos, env_info, i)
|
||||
|
||||
# Concatenate the observations
|
||||
self._observations = concatenate(
|
||||
self.single_observation_space, observations, self._observations
|
||||
)
|
||||
|
||||
return (
|
||||
deepcopy(self.observations) if self.copy else self.observations,
|
||||
deepcopy(self._observations) if self.copy else self._observations,
|
||||
np.copy(self._rewards),
|
||||
np.copy(self._terminateds),
|
||||
np.copy(self._truncateds),
|
||||
np.copy(self._terminations),
|
||||
np.copy(self._truncations),
|
||||
infos,
|
||||
)
|
||||
|
||||
def call(self, name, *args, **kwargs) -> tuple:
|
||||
"""Calls the method with name and applies args and kwargs.
|
||||
def render(self) -> tuple[RenderFrame, ...] | None:
|
||||
"""Returns the rendered frames from the environments."""
|
||||
return tuple(env.render() for env in self.envs)
|
||||
|
||||
def call(self, name: str, *args: Any, **kwargs: Any) -> tuple[Any, ...]:
|
||||
"""Calls a sub-environment method with name and applies args and kwargs.
|
||||
|
||||
Args:
|
||||
name: The method name
|
||||
@@ -180,7 +200,8 @@ class SyncVectorEnv(VectorEnv):
|
||||
"""
|
||||
results = []
|
||||
for env in self.envs:
|
||||
function = getattr(env, name)
|
||||
function = env.get_wrapper_attr(name)
|
||||
|
||||
if callable(function):
|
||||
results.append(function(*args, **kwargs))
|
||||
else:
|
||||
@@ -188,7 +209,18 @@ class SyncVectorEnv(VectorEnv):
|
||||
|
||||
return tuple(results)
|
||||
|
||||
def set_attr(self, name: str, values: Union[list, tuple, Any]):
|
||||
def get_attr(self, name: str) -> Any:
|
||||
"""Get a property from each parallel environment.
|
||||
|
||||
Args:
|
||||
name (str): Name of the property to get from each individual environment.
|
||||
|
||||
Returns:
|
||||
The property with name
|
||||
"""
|
||||
return self.call(name)
|
||||
|
||||
def set_attr(self, name: str, values: list[Any] | tuple[Any, ...] | Any):
|
||||
"""Sets an attribute of the sub-environments.
|
||||
|
||||
Args:
|
||||
@@ -202,34 +234,33 @@ class SyncVectorEnv(VectorEnv):
|
||||
"""
|
||||
if not isinstance(values, (list, tuple)):
|
||||
values = [values for _ in range(self.num_envs)]
|
||||
|
||||
if len(values) != self.num_envs:
|
||||
raise ValueError(
|
||||
"Values must be a list or tuple with length equal to the "
|
||||
f"number of environments. Got `{len(values)}` values for "
|
||||
f"{self.num_envs} environments."
|
||||
"Values must be a list or tuple with length equal to the number of environments. "
|
||||
f"Got `{len(values)}` values for {self.num_envs} environments."
|
||||
)
|
||||
|
||||
for env, value in zip(self.envs, values):
|
||||
setattr(env, name, value)
|
||||
env.set_wrapper_attr(name, value)
|
||||
|
||||
def close_extras(self, **kwargs):
|
||||
def close_extras(self, **kwargs: Any):
|
||||
"""Close the environments."""
|
||||
[env.close() for env in self.envs]
|
||||
|
||||
def _check_spaces(self) -> bool:
|
||||
"""Check that each of the environments obs and action spaces are equivalent to the single obs and action space."""
|
||||
for env in self.envs:
|
||||
if not (env.observation_space == self.single_observation_space):
|
||||
raise RuntimeError(
|
||||
"Some environments have an observation space different from "
|
||||
f"`{self.single_observation_space}`. In order to batch observations, "
|
||||
"the observation spaces from all environments must be equal."
|
||||
f"Some environments have an observation space different from `{self.single_observation_space}`. "
|
||||
"In order to batch observations, the observation spaces from all environments must be equal."
|
||||
)
|
||||
|
||||
if not (env.action_space == self.single_action_space):
|
||||
raise RuntimeError(
|
||||
"Some environments have an action space different from "
|
||||
f"`{self.single_action_space}`. In order to batch actions, the "
|
||||
"action spaces from all environments must be equal."
|
||||
f"Some environments have an action space different from `{self.single_action_space}`. "
|
||||
"In order to batch actions, the action spaces from all environments must be equal."
|
||||
)
|
||||
|
||||
return True
|
||||
|
@@ -1,26 +1,27 @@
|
||||
"""Module for gymnasium vector utils."""
|
||||
"""Module for gymnasium experimental vector utility functions."""
|
||||
|
||||
from gymnasium.vector.utils.misc import CloudpickleWrapper, clear_mpi_env_vars
|
||||
from gymnasium.vector.utils.numpy_utils import concatenate, create_empty_array
|
||||
from gymnasium.vector.utils.shared_memory import (
|
||||
create_shared_memory,
|
||||
read_from_shared_memory,
|
||||
write_to_shared_memory,
|
||||
)
|
||||
from gymnasium.vector.utils.spaces import (
|
||||
_BaseGymSpaces, # pyright: ignore[reportPrivateUsage]
|
||||
from gymnasium.vector.utils.space_utils import (
|
||||
batch_space,
|
||||
concatenate,
|
||||
create_empty_array,
|
||||
iterate,
|
||||
)
|
||||
from gymnasium.vector.utils.spaces import BaseGymSpaces, batch_space, iterate
|
||||
|
||||
|
||||
__all__ = [
|
||||
"CloudpickleWrapper",
|
||||
"clear_mpi_env_vars",
|
||||
"batch_space",
|
||||
"iterate",
|
||||
"concatenate",
|
||||
"create_empty_array",
|
||||
"create_shared_memory",
|
||||
"read_from_shared_memory",
|
||||
"write_to_shared_memory",
|
||||
"BaseGymSpaces",
|
||||
"batch_space",
|
||||
"iterate",
|
||||
"CloudpickleWrapper",
|
||||
"clear_mpi_env_vars",
|
||||
]
|
||||
|
@@ -39,7 +39,7 @@ class CloudpickleWrapper:
|
||||
def clear_mpi_env_vars():
|
||||
"""Clears the MPI of environment variables.
|
||||
|
||||
`from mpi4py import MPI` will call `MPI_Init` by default.
|
||||
``from mpi4py import MPI`` will call ``MPI_Init`` by default.
|
||||
If the child process has MPI environment variables, MPI will think that the child process
|
||||
is an MPI process just like the parent and do bad things such as hang.
|
||||
|
||||
|
@@ -1,146 +0,0 @@
|
||||
"""Numpy utility functions: concatenate space samples and create empty array."""
|
||||
from collections import OrderedDict
|
||||
from functools import singledispatch
|
||||
from typing import Callable, Iterable, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gymnasium.spaces import (
|
||||
Box,
|
||||
Dict,
|
||||
Discrete,
|
||||
MultiBinary,
|
||||
MultiDiscrete,
|
||||
Space,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["concatenate", "create_empty_array"]
|
||||
|
||||
|
||||
@singledispatch
|
||||
def concatenate(
|
||||
space: Space, items: Iterable, out: Union[tuple, dict, np.ndarray]
|
||||
) -> Union[tuple, dict, np.ndarray]:
|
||||
"""Concatenate multiple samples from space into a single object.
|
||||
|
||||
Args:
|
||||
space: Observation space of a single environment in the vectorized environment.
|
||||
items: Samples to be concatenated.
|
||||
out: The output object. This object is a (possibly nested) numpy array.
|
||||
|
||||
Returns:
|
||||
The output object. This object is a (possibly nested) numpy array.
|
||||
|
||||
Raises:
|
||||
ValueError: Space is not a valid :class:`gym.Space` instance
|
||||
|
||||
Example:
|
||||
>>> from gymnasium.spaces import Box
|
||||
>>> import numpy as np
|
||||
>>> space = Box(low=0, high=1, shape=(3,), seed=42, dtype=np.float32)
|
||||
>>> out = np.zeros((2, 3), dtype=np.float32)
|
||||
>>> items = [space.sample() for _ in range(2)]
|
||||
>>> concatenate(space, items, out)
|
||||
array([[0.77395606, 0.43887845, 0.85859793],
|
||||
[0.697368 , 0.09417735, 0.97562236]], dtype=float32)
|
||||
"""
|
||||
raise ValueError(
|
||||
f"Space of type `{type(space)}` is not a valid `gymnasium.Space` instance."
|
||||
)
|
||||
|
||||
|
||||
@concatenate.register(Box)
|
||||
@concatenate.register(Discrete)
|
||||
@concatenate.register(MultiDiscrete)
|
||||
@concatenate.register(MultiBinary)
|
||||
def _concatenate_base(space, items, out):
|
||||
return np.stack(items, axis=0, out=out)
|
||||
|
||||
|
||||
@concatenate.register(Tuple)
|
||||
def _concatenate_tuple(space, items, out):
|
||||
return tuple(
|
||||
concatenate(subspace, [item[i] for item in items], out[i])
|
||||
for (i, subspace) in enumerate(space.spaces)
|
||||
)
|
||||
|
||||
|
||||
@concatenate.register(Dict)
|
||||
def _concatenate_dict(space, items, out):
|
||||
return OrderedDict(
|
||||
[
|
||||
(key, concatenate(subspace, [item[key] for item in items], out[key]))
|
||||
for (key, subspace) in space.spaces.items()
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@concatenate.register(Space)
|
||||
def _concatenate_custom(space, items, out):
|
||||
return tuple(items)
|
||||
|
||||
|
||||
@singledispatch
|
||||
def create_empty_array(
|
||||
space: Space, n: int = 1, fn: Callable[..., np.ndarray] = np.zeros
|
||||
) -> Union[tuple, dict, np.ndarray]:
|
||||
"""Create an empty (possibly nested) numpy array.
|
||||
|
||||
Args:
|
||||
space: Observation space of a single environment in the vectorized environment.
|
||||
n: Number of environments in the vectorized environment. If `None`, creates an empty sample from `space`.
|
||||
fn: Function to apply when creating the empty numpy array. Examples of such functions are `np.empty` or `np.zeros`.
|
||||
|
||||
Returns:
|
||||
The output object. This object is a (possibly nested) numpy array.
|
||||
|
||||
Raises:
|
||||
ValueError: Space is not a valid :class:`gym.Space` instance
|
||||
|
||||
Example:
|
||||
>>> from gymnasium.spaces import Box, Dict
|
||||
>>> import numpy as np
|
||||
>>> space = Dict({
|
||||
... 'position': Box(low=0, high=1, shape=(3,), dtype=np.float32),
|
||||
... 'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32)})
|
||||
>>> create_empty_array(space, n=2, fn=np.zeros)
|
||||
OrderedDict([('position', array([[0., 0., 0.],
|
||||
[0., 0., 0.]], dtype=float32)), ('velocity', array([[0., 0.],
|
||||
[0., 0.]], dtype=float32))])
|
||||
"""
|
||||
raise ValueError(
|
||||
f"Space of type `{type(space)}` is not a valid `gymnasium.Space` instance."
|
||||
)
|
||||
|
||||
|
||||
# It is possible for the some of the Box low to be greater than 0, then array is not in space
|
||||
@create_empty_array.register(Box)
|
||||
# If the Discrete start > 0 or start + length < 0 then array is not in space
|
||||
@create_empty_array.register(Discrete)
|
||||
@create_empty_array.register(MultiDiscrete)
|
||||
@create_empty_array.register(MultiBinary)
|
||||
def _create_empty_array_base(space, n=1, fn=np.zeros):
|
||||
shape = space.shape if (n is None) else (n,) + space.shape
|
||||
return fn(shape, dtype=space.dtype)
|
||||
|
||||
|
||||
@create_empty_array.register(Tuple)
|
||||
def _create_empty_array_tuple(space, n=1, fn=np.zeros):
|
||||
return tuple(create_empty_array(subspace, n=n, fn=fn) for subspace in space.spaces)
|
||||
|
||||
|
||||
@create_empty_array.register(Dict)
|
||||
def _create_empty_array_dict(space, n=1, fn=np.zeros):
|
||||
return OrderedDict(
|
||||
[
|
||||
(key, create_empty_array(subspace, n=n, fn=fn))
|
||||
for (key, subspace) in space.spaces.items()
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@create_empty_array.register(Space)
|
||||
def _create_empty_array_custom(space, n=1, fn=np.zeros):
|
||||
return None
|
@@ -1,9 +1,11 @@
|
||||
"""Utility functions for vector environments to share memory between processes."""
|
||||
from __future__ import annotations
|
||||
|
||||
import multiprocessing as mp
|
||||
from collections import OrderedDict
|
||||
from ctypes import c_bool
|
||||
from functools import singledispatch
|
||||
from typing import Union
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -12,10 +14,14 @@ from gymnasium.spaces import (
|
||||
Box,
|
||||
Dict,
|
||||
Discrete,
|
||||
Graph,
|
||||
MultiBinary,
|
||||
MultiDiscrete,
|
||||
Sequence,
|
||||
Space,
|
||||
Text,
|
||||
Tuple,
|
||||
flatten,
|
||||
)
|
||||
|
||||
|
||||
@@ -24,8 +30,8 @@ __all__ = ["create_shared_memory", "read_from_shared_memory", "write_to_shared_m
|
||||
|
||||
@singledispatch
|
||||
def create_shared_memory(
|
||||
space: Space, n: int = 1, ctx=mp
|
||||
) -> Union[dict, tuple, mp.Array]:
|
||||
space: Space[Any], n: int = 1, ctx=mp
|
||||
) -> dict[str, Any] | tuple[Any, ...] | mp.Array:
|
||||
"""Create a shared memory object, to be shared across processes.
|
||||
|
||||
This eventually contains the observations from the vectorized environment.
|
||||
@@ -41,20 +47,24 @@ def create_shared_memory(
|
||||
Raises:
|
||||
CustomSpaceError: Space is not a valid :class:`gymnasium.Space` instance
|
||||
"""
|
||||
raise CustomSpaceError(
|
||||
"Cannot create a shared memory for space with "
|
||||
f"type `{type(space)}`. Shared memory only supports "
|
||||
"default Gymnasium spaces (e.g. `Box`, `Tuple`, "
|
||||
"`Dict`, etc...), and does not support custom "
|
||||
"Gymnasium spaces."
|
||||
)
|
||||
if isinstance(space, Space):
|
||||
raise CustomSpaceError(
|
||||
f"Space of type `{type(space)}` doesn't have an registered `create_shared_memory` function. Register `{type(space)}` for `create_shared_memory` to support it."
|
||||
)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"The space provided to `create_shared_memory` is not a gymnasium Space instance, type: {type(space)}, {space}"
|
||||
)
|
||||
|
||||
|
||||
@create_shared_memory.register(Box)
|
||||
@create_shared_memory.register(Discrete)
|
||||
@create_shared_memory.register(MultiDiscrete)
|
||||
@create_shared_memory.register(MultiBinary)
|
||||
def _create_base_shared_memory(space, n: int = 1, ctx=mp):
|
||||
def _create_base_shared_memory(
|
||||
space: Box | Discrete | MultiDiscrete | MultiBinary, n: int = 1, ctx=mp
|
||||
):
|
||||
assert space.dtype is not None
|
||||
dtype = space.dtype.char
|
||||
if dtype in "?":
|
||||
dtype = c_bool
|
||||
@@ -62,14 +72,14 @@ def _create_base_shared_memory(space, n: int = 1, ctx=mp):
|
||||
|
||||
|
||||
@create_shared_memory.register(Tuple)
|
||||
def _create_tuple_shared_memory(space, n: int = 1, ctx=mp):
|
||||
def _create_tuple_shared_memory(space: Tuple, n: int = 1, ctx=mp):
|
||||
return tuple(
|
||||
create_shared_memory(subspace, n=n, ctx=ctx) for subspace in space.spaces
|
||||
)
|
||||
|
||||
|
||||
@create_shared_memory.register(Dict)
|
||||
def _create_dict_shared_memory(space, n=1, ctx=mp):
|
||||
def _create_dict_shared_memory(space: Dict, n: int = 1, ctx=mp):
|
||||
return OrderedDict(
|
||||
[
|
||||
(key, create_shared_memory(subspace, n=n, ctx=ctx))
|
||||
@@ -78,10 +88,23 @@ def _create_dict_shared_memory(space, n=1, ctx=mp):
|
||||
)
|
||||
|
||||
|
||||
@create_shared_memory.register(Text)
|
||||
def _create_text_shared_memory(space: Text, n: int = 1, ctx=mp):
|
||||
return ctx.Array(np.dtype(np.int32).char, n * space.max_length)
|
||||
|
||||
|
||||
@create_shared_memory.register(Graph)
|
||||
@create_shared_memory.register(Sequence)
|
||||
def _create_dynamic_shared_memory(space: Graph | Sequence, n: int = 1, ctx=mp):
|
||||
raise TypeError(
|
||||
f"As {space} has a dynamic shape then it is not possible to make a static shared memory."
|
||||
)
|
||||
|
||||
|
||||
@singledispatch
|
||||
def read_from_shared_memory(
|
||||
space: Space, shared_memory: Union[dict, tuple, mp.Array], n: int = 1
|
||||
) -> Union[dict, tuple, np.ndarray]:
|
||||
space: Space, shared_memory: dict | tuple | mp.Array, n: int = 1
|
||||
) -> dict[str, Any] | tuple[Any, ...] | np.ndarray:
|
||||
"""Read the batch of observations from shared memory as a numpy array.
|
||||
|
||||
..notes::
|
||||
@@ -101,27 +124,30 @@ def read_from_shared_memory(
|
||||
Raises:
|
||||
CustomSpaceError: Space is not a valid :class:`gymnasium.Space` instance
|
||||
"""
|
||||
raise CustomSpaceError(
|
||||
"Cannot read from a shared memory for space with "
|
||||
f"type `{type(space)}`. Shared memory only supports "
|
||||
"default Gymnasium spaces (e.g. `Box`, `Tuple`, "
|
||||
"`Dict`, etc...), and does not support custom "
|
||||
"Gymnasium spaces."
|
||||
)
|
||||
if isinstance(space, Space):
|
||||
raise CustomSpaceError(
|
||||
f"Space of type `{type(space)}` doesn't have an registered `read_from_shared_memory` function. Register `{type(space)}` for `read_from_shared_memory` to support it."
|
||||
)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"The space provided to `read_from_shared_memory` is not a gymnasium Space instance, type: {type(space)}, {space}"
|
||||
)
|
||||
|
||||
|
||||
@read_from_shared_memory.register(Box)
|
||||
@read_from_shared_memory.register(Discrete)
|
||||
@read_from_shared_memory.register(MultiDiscrete)
|
||||
@read_from_shared_memory.register(MultiBinary)
|
||||
def _read_base_from_shared_memory(space, shared_memory, n: int = 1):
|
||||
def _read_base_from_shared_memory(
|
||||
space: Box | Discrete | MultiDiscrete | MultiBinary, shared_memory, n: int = 1
|
||||
):
|
||||
return np.frombuffer(shared_memory.get_obj(), dtype=space.dtype).reshape(
|
||||
(n,) + space.shape
|
||||
)
|
||||
|
||||
|
||||
@read_from_shared_memory.register(Tuple)
|
||||
def _read_tuple_from_shared_memory(space, shared_memory, n: int = 1):
|
||||
def _read_tuple_from_shared_memory(space: Tuple, shared_memory, n: int = 1):
|
||||
return tuple(
|
||||
read_from_shared_memory(subspace, memory, n=n)
|
||||
for (memory, subspace) in zip(shared_memory, space.spaces)
|
||||
@@ -129,7 +155,7 @@ def _read_tuple_from_shared_memory(space, shared_memory, n: int = 1):
|
||||
|
||||
|
||||
@read_from_shared_memory.register(Dict)
|
||||
def _read_dict_from_shared_memory(space, shared_memory, n: int = 1):
|
||||
def _read_dict_from_shared_memory(space: Dict, shared_memory, n: int = 1):
|
||||
return OrderedDict(
|
||||
[
|
||||
(key, read_from_shared_memory(subspace, shared_memory[key], n=n))
|
||||
@@ -138,12 +164,30 @@ def _read_dict_from_shared_memory(space, shared_memory, n: int = 1):
|
||||
)
|
||||
|
||||
|
||||
@read_from_shared_memory.register(Text)
|
||||
def _read_text_from_shared_memory(space: Text, shared_memory, n: int = 1) -> tuple[str]:
|
||||
data = np.frombuffer(shared_memory.get_obj(), dtype=np.int32).reshape(
|
||||
(n, space.max_length)
|
||||
)
|
||||
|
||||
return tuple(
|
||||
"".join(
|
||||
[
|
||||
space.character_list[val]
|
||||
for val in values
|
||||
if val < len(space.character_set)
|
||||
]
|
||||
)
|
||||
for values in data
|
||||
)
|
||||
|
||||
|
||||
@singledispatch
|
||||
def write_to_shared_memory(
|
||||
space: Space,
|
||||
index: int,
|
||||
value: np.ndarray,
|
||||
shared_memory: Union[dict, tuple, mp.Array],
|
||||
shared_memory: dict[str, Any] | tuple[Any, ...] | mp.Array,
|
||||
):
|
||||
"""Write the observation of a single environment into shared memory.
|
||||
|
||||
@@ -157,20 +201,26 @@ def write_to_shared_memory(
|
||||
Raises:
|
||||
CustomSpaceError: Space is not a valid :class:`gymnasium.Space` instance
|
||||
"""
|
||||
raise CustomSpaceError(
|
||||
"Cannot write to a shared memory for space with "
|
||||
f"type `{type(space)}`. Shared memory only supports "
|
||||
"default Gymnasium spaces (e.g. `Box`, `Tuple`, "
|
||||
"`Dict`, etc...), and does not support custom "
|
||||
"Gymnasium spaces."
|
||||
)
|
||||
if isinstance(space, Space):
|
||||
raise CustomSpaceError(
|
||||
f"Space of type `{type(space)}` doesn't have an registered `write_to_shared_memory` function. Register `{type(space)}` for `write_to_shared_memory` to support it."
|
||||
)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"The space provided to `write_to_shared_memory` is not a gymnasium Space instance, type: {type(space)}, {space}"
|
||||
)
|
||||
|
||||
|
||||
@write_to_shared_memory.register(Box)
|
||||
@write_to_shared_memory.register(Discrete)
|
||||
@write_to_shared_memory.register(MultiDiscrete)
|
||||
@write_to_shared_memory.register(MultiBinary)
|
||||
def _write_base_to_shared_memory(space, index, value, shared_memory):
|
||||
def _write_base_to_shared_memory(
|
||||
space: Box | Discrete | MultiDiscrete | MultiBinary,
|
||||
index: int,
|
||||
value,
|
||||
shared_memory,
|
||||
):
|
||||
size = int(np.prod(space.shape))
|
||||
destination = np.frombuffer(shared_memory.get_obj(), dtype=space.dtype)
|
||||
np.copyto(
|
||||
@@ -180,12 +230,26 @@ def _write_base_to_shared_memory(space, index, value, shared_memory):
|
||||
|
||||
|
||||
@write_to_shared_memory.register(Tuple)
|
||||
def _write_tuple_to_shared_memory(space, index, values, shared_memory):
|
||||
def _write_tuple_to_shared_memory(
|
||||
space: Tuple, index: int, values: tuple[Any, ...], shared_memory
|
||||
):
|
||||
for value, memory, subspace in zip(values, shared_memory, space.spaces):
|
||||
write_to_shared_memory(subspace, index, value, memory)
|
||||
|
||||
|
||||
@write_to_shared_memory.register(Dict)
|
||||
def _write_dict_to_shared_memory(space, index, values, shared_memory):
|
||||
def _write_dict_to_shared_memory(
|
||||
space: Dict, index: int, values: dict[str, Any], shared_memory
|
||||
):
|
||||
for key, subspace in space.spaces.items():
|
||||
write_to_shared_memory(subspace, index, values[key], shared_memory[key])
|
||||
|
||||
|
||||
@write_to_shared_memory.register(Text)
|
||||
def _write_text_to_shared_memory(space: Text, index: int, values: str, shared_memory):
|
||||
size = space.max_length
|
||||
destination = np.frombuffer(shared_memory.get_obj(), dtype=np.int32)
|
||||
np.copyto(
|
||||
destination[index * size : (index + 1) * size],
|
||||
flatten(space, values),
|
||||
)
|
||||
|
@@ -150,7 +150,7 @@ def iterate(space: Space[T_cov], items: Iterable[T_cov]) -> Iterator:
|
||||
The output object. This object is a (possibly nested) numpy array.
|
||||
|
||||
Raises:
|
||||
ValueError: Space is not an instance of :class:`gym.Space`
|
||||
ValueError: Space is not an instance of :class:`gymnasium.Space`
|
||||
|
||||
Example:
|
||||
>>> from gymnasium.spaces import Box, Dict
|
||||
@@ -311,14 +311,14 @@ def create_empty_array(
|
||||
|
||||
Args:
|
||||
space: Observation space of a single environment in the vectorized environment.
|
||||
n: Number of environments in the vectorized environment. If `None`, creates an empty sample from `space`.
|
||||
fn: Function to apply when creating the empty numpy array. Examples of such functions are `np.empty` or `np.zeros`.
|
||||
n: Number of environments in the vectorized environment. If ``None``, creates an empty sample from ``space``.
|
||||
fn: Function to apply when creating the empty numpy array. Examples of such functions are ``np.empty`` or ``np.zeros``.
|
||||
|
||||
Returns:
|
||||
The output object. This object is a (possibly nested) numpy array.
|
||||
|
||||
Raises:
|
||||
ValueError: Space is not a valid :class:`gym.Space` instance
|
||||
ValueError: Space is not a valid :class:`gymnasium.Space` instance
|
||||
|
||||
Example:
|
||||
>>> from gymnasium.spaces import Box, Dict
|
@@ -1,215 +0,0 @@
|
||||
"""Utility functions for gymnasium spaces: batch space and iterator."""
|
||||
from collections import OrderedDict
|
||||
from copy import deepcopy
|
||||
from functools import singledispatch
|
||||
from typing import Iterator
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gymnasium.error import CustomSpaceError
|
||||
from gymnasium.spaces import (
|
||||
Box,
|
||||
Dict,
|
||||
Discrete,
|
||||
MultiBinary,
|
||||
MultiDiscrete,
|
||||
Space,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
|
||||
BaseGymSpaces = (Box, Discrete, MultiDiscrete, MultiBinary)
|
||||
_BaseGymSpaces = BaseGymSpaces
|
||||
__all__ = ["BaseGymSpaces", "_BaseGymSpaces", "batch_space", "iterate"]
|
||||
|
||||
|
||||
@singledispatch
|
||||
def batch_space(space: Space, n: int = 1) -> Space:
|
||||
"""Create a (batched) space, containing multiple copies of a single space.
|
||||
|
||||
Args:
|
||||
space: Space (e.g. the observation space) for a single environment in the vectorized environment.
|
||||
n: Number of environments in the vectorized environment.
|
||||
|
||||
Returns:
|
||||
Space (e.g. the observation space) for a batch of environments in the vectorized environment.
|
||||
|
||||
Raises:
|
||||
ValueError: Cannot batch space that is not a valid :class:`gym.Space` instance
|
||||
|
||||
Example:
|
||||
>>> from gymnasium.spaces import Box, Dict
|
||||
>>> import numpy as np
|
||||
>>> space = Dict({
|
||||
... 'position': Box(low=0, high=1, shape=(3,), dtype=np.float32),
|
||||
... 'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32)
|
||||
... })
|
||||
>>> batch_space(space, n=5)
|
||||
Dict('position': Box(0.0, 1.0, (5, 3), float32), 'velocity': Box(0.0, 1.0, (5, 2), float32))
|
||||
"""
|
||||
raise ValueError(
|
||||
f"Cannot batch space with type `{type(space)}`. The space must be a valid `gymnasium.Space` instance."
|
||||
)
|
||||
|
||||
|
||||
@batch_space.register(Box)
|
||||
def _batch_space_box(space, n=1):
|
||||
repeats = tuple([n] + [1] * space.low.ndim)
|
||||
low, high = np.tile(space.low, repeats), np.tile(space.high, repeats)
|
||||
return Box(low=low, high=high, dtype=space.dtype, seed=deepcopy(space.np_random))
|
||||
|
||||
|
||||
@batch_space.register(Discrete)
|
||||
def _batch_space_discrete(space, n=1):
|
||||
return MultiDiscrete(
|
||||
np.full((n,), space.n, dtype=space.dtype),
|
||||
dtype=space.dtype,
|
||||
seed=deepcopy(space.np_random),
|
||||
start=np.full((n,), space.start, dtype=space.dtype),
|
||||
)
|
||||
|
||||
|
||||
@batch_space.register(MultiDiscrete)
|
||||
def _batch_space_multidiscrete(space, n=1):
|
||||
repeats = tuple([n] + [1] * space.nvec.ndim)
|
||||
low = np.tile(space.start, repeats)
|
||||
high = low + np.tile(space.nvec, repeats) - 1
|
||||
return Box(
|
||||
low=low,
|
||||
high=high,
|
||||
dtype=space.dtype,
|
||||
seed=deepcopy(space.np_random),
|
||||
)
|
||||
|
||||
|
||||
@batch_space.register(MultiBinary)
|
||||
def _batch_space_multibinary(space, n=1):
|
||||
return Box(
|
||||
low=0,
|
||||
high=1,
|
||||
shape=(n,) + space.shape,
|
||||
dtype=space.dtype,
|
||||
seed=deepcopy(space.np_random),
|
||||
)
|
||||
|
||||
|
||||
@batch_space.register(Tuple)
|
||||
def _batch_space_tuple(space, n=1):
|
||||
return Tuple(
|
||||
tuple(batch_space(subspace, n=n) for subspace in space.spaces),
|
||||
seed=deepcopy(space.np_random),
|
||||
)
|
||||
|
||||
|
||||
@batch_space.register(Dict)
|
||||
def _batch_space_dict(space, n=1):
|
||||
return Dict(
|
||||
OrderedDict(
|
||||
[
|
||||
(key, batch_space(subspace, n=n))
|
||||
for (key, subspace) in space.spaces.items()
|
||||
]
|
||||
),
|
||||
seed=deepcopy(space.np_random),
|
||||
)
|
||||
|
||||
|
||||
@batch_space.register(Space)
|
||||
def _batch_space_custom(space, n=1):
|
||||
# Without deepcopy, then the space.np_random is batched_space.spaces[0].np_random
|
||||
# Which is an issue if you are sampling actions of both the original space and the batched space
|
||||
batched_space = Tuple(
|
||||
tuple(deepcopy(space) for _ in range(n)), seed=deepcopy(space.np_random)
|
||||
)
|
||||
new_seeds = list(map(int, batched_space.np_random.integers(0, 1e8, n)))
|
||||
batched_space.seed(new_seeds)
|
||||
return batched_space
|
||||
|
||||
|
||||
@singledispatch
|
||||
def iterate(space: Space, items) -> Iterator:
|
||||
"""Iterate over the elements of a (batched) space.
|
||||
|
||||
Args:
|
||||
space: Space to which `items` belong to.
|
||||
items: Items to be iterated over.
|
||||
|
||||
Returns:
|
||||
Iterator over the elements in `items`.
|
||||
|
||||
Raises:
|
||||
ValueError: Space is not an instance of :class:`gym.Space`
|
||||
|
||||
Example:
|
||||
>>> from gymnasium.spaces import Box, Dict
|
||||
>>> import numpy as np
|
||||
>>> space = Dict({
|
||||
... 'position': Box(low=0, high=1, shape=(2, 3), seed=42, dtype=np.float32),
|
||||
... 'velocity': Box(low=0, high=1, shape=(2, 2), seed=42, dtype=np.float32)})
|
||||
>>> items = space.sample()
|
||||
>>> it = iterate(space, items)
|
||||
>>> next(it)
|
||||
OrderedDict([('position', array([0.77395606, 0.43887845, 0.85859793], dtype=float32)), ('velocity', array([0.77395606, 0.43887845], dtype=float32))])
|
||||
>>> next(it)
|
||||
OrderedDict([('position', array([0.697368 , 0.09417735, 0.97562236], dtype=float32)), ('velocity', array([0.85859793, 0.697368 ], dtype=float32))])
|
||||
>>> next(it)
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
StopIteration
|
||||
"""
|
||||
raise ValueError(
|
||||
f"Space of type `{type(space)}` is not a valid `gymnasium.Space` instance."
|
||||
)
|
||||
|
||||
|
||||
@iterate.register(Discrete)
|
||||
def _iterate_discrete(space, items):
|
||||
raise TypeError("Unable to iterate over a space of type `Discrete`.")
|
||||
|
||||
|
||||
@iterate.register(Box)
|
||||
@iterate.register(MultiDiscrete)
|
||||
@iterate.register(MultiBinary)
|
||||
def _iterate_base(space, items):
|
||||
try:
|
||||
return iter(items)
|
||||
except TypeError as e:
|
||||
raise TypeError(
|
||||
f"Unable to iterate over the following elements: {items}"
|
||||
) from e
|
||||
|
||||
|
||||
@iterate.register(Tuple)
|
||||
def _iterate_tuple(space, items):
|
||||
# If this is a tuple of custom subspaces only, then simply iterate over items
|
||||
if all(
|
||||
isinstance(subspace, Space)
|
||||
and (not isinstance(subspace, BaseGymSpaces + (Tuple, Dict)))
|
||||
for subspace in space.spaces
|
||||
):
|
||||
return iter(items)
|
||||
|
||||
return zip(
|
||||
*[iterate(subspace, items[i]) for i, subspace in enumerate(space.spaces)]
|
||||
)
|
||||
|
||||
|
||||
@iterate.register(Dict)
|
||||
def _iterate_dict(space, items):
|
||||
keys, values = zip(
|
||||
*[
|
||||
(key, iterate(subspace, items[key]))
|
||||
for key, subspace in space.spaces.items()
|
||||
]
|
||||
)
|
||||
for item in zip(*values):
|
||||
yield OrderedDict([(key, value) for (key, value) in zip(keys, item)])
|
||||
|
||||
|
||||
@iterate.register(Space)
|
||||
def _iterate_custom(space, items):
|
||||
raise CustomSpaceError(
|
||||
f"Unable to iterate over {items}, since {space} "
|
||||
"is a custom `gymnasium.Space` instance (i.e. not one of "
|
||||
"`Box`, `Dict`, etc...)."
|
||||
)
|
@@ -1,30 +1,46 @@
|
||||
"""Base class for vectorized environments."""
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Generic, TypeVar
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.vector.utils.spaces import batch_space
|
||||
from gymnasium.core import ActType, ObsType, RenderFrame
|
||||
from gymnasium.utils import seeding
|
||||
|
||||
|
||||
__all__ = ["VectorEnv"]
|
||||
if TYPE_CHECKING:
|
||||
from gymnasium.envs.registration import EnvSpec
|
||||
|
||||
ArrayType = TypeVar("ArrayType")
|
||||
|
||||
|
||||
class VectorEnv(gym.Env):
|
||||
__all__ = [
|
||||
"VectorEnv",
|
||||
"VectorWrapper",
|
||||
"VectorObservationWrapper",
|
||||
"VectorActionWrapper",
|
||||
"VectorRewardWrapper",
|
||||
"ArrayType",
|
||||
]
|
||||
|
||||
|
||||
class VectorEnv(Generic[ObsType, ActType, ArrayType]):
|
||||
"""Base class for vectorized environments to run multiple independent copies of the same environment in parallel.
|
||||
|
||||
Vector environments can provide a linear speed-up in the steps taken per second through sampling multiple
|
||||
sub-environments at the same time. To prevent terminated environments waiting until all sub-environments have
|
||||
terminated or truncated, the vector environments autoreset sub-environments after they terminate or truncated.
|
||||
As a result, the final step's observation and info are overwritten by the reset's observation and info.
|
||||
Therefore, the observation and info for the final step of a sub-environment is stored in the info parameter,
|
||||
terminated or truncated, the vector environments automatically reset sub-environments after they terminate or truncated (within the same step call).
|
||||
As a result, the step's observation and info are overwritten by the reset's observation and info.
|
||||
To preserve this data, the observation and info for the final step of a sub-environment is stored in the info parameter,
|
||||
using `"final_observation"` and `"final_info"` respectively. See :meth:`step` for more information.
|
||||
|
||||
The vector environments batch `observations`, `rewards`, `terminations`, `truncations` and `info` for each
|
||||
parallel environment. In addition, :meth:`step` expects to receive a batch of actions for each parallel environment.
|
||||
The vector environments batches `observations`, `rewards`, `terminations`, `truncations` and `info` for each
|
||||
sub-environment. In addition, :meth:`step` expects to receive a batch of actions for each parallel environment.
|
||||
|
||||
Gymnasium contains two types of Vector environments: :class:`AsyncVectorEnv` and :class:`SyncVectorEnv`.
|
||||
Gymnasium contains two generalised Vector environments: :class:`AsyncVectorEnv` and :class:`SyncVectorEnv` along with
|
||||
several custom vector environment implementations.
|
||||
|
||||
The Vector Environments have the additional attributes for users to understand the implementation
|
||||
|
||||
@@ -34,89 +50,67 @@ class VectorEnv(gym.Env):
|
||||
- :attr:`action_space` - The batched action space of the vector environment
|
||||
- :attr:`single_action_space` - The action space of a single sub-environment
|
||||
|
||||
Note:
|
||||
The info parameter of :meth:`reset` and :meth:`step` was originally implemented before OpenAI Gym v25 was a list
|
||||
of dictionary for each sub-environment. However, this was modified in OpenAI Gym v25+ and in Gymnasium to a
|
||||
dictionary with a NumPy array for each key. To use the old info style using the :class:`VectorListInfo`.
|
||||
Examples:
|
||||
>>> import gymnasium as gym
|
||||
>>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync", wrappers=(gym.wrappers.TimeAwareObservation,))
|
||||
>>> envs = gym.wrappers.vector.ClipReward(envs, min_reward=0.2, max_reward=0.8)
|
||||
>>> envs
|
||||
<ClipReward, SyncVectorEnv(CartPole-v1, num_envs=3)>
|
||||
>>> observations, infos = envs.reset(seed=123)
|
||||
>>> observations
|
||||
array([[ 0.01823519, -0.0446179 , -0.02796401, -0.03156282, 0. ],
|
||||
[ 0.02852531, 0.02858594, 0.0469136 , 0.02480598, 0. ],
|
||||
[ 0.03517495, -0.000635 , -0.01098382, -0.03203924, 0. ]])
|
||||
>>> infos
|
||||
{}
|
||||
>>> _ = envs.action_space.seed(123)
|
||||
>>> observations, rewards, terminations, truncations, infos = envs.step(envs.action_space.sample())
|
||||
>>> observations
|
||||
array([[ 0.01734283, 0.15089367, -0.02859527, -0.33293587, 1. ],
|
||||
[ 0.02909703, -0.16717631, 0.04740972, 0.3319138 , 1. ],
|
||||
[ 0.03516225, -0.19559774, -0.01162461, 0.25715804, 1. ]])
|
||||
>>> rewards
|
||||
array([0.8, 0.8, 0.8])
|
||||
>>> terminations
|
||||
array([False, False, False])
|
||||
>>> truncations
|
||||
array([False, False, False])
|
||||
>>> infos
|
||||
{}
|
||||
>>> envs.close()
|
||||
|
||||
Note:
|
||||
To render the sub-environments, use :meth:`call` with "render" arguments. Remember to set the `render_modes`
|
||||
for all the sub-environments during initialization.
|
||||
The info parameter of :meth:`reset` and :meth:`step` was originally implemented before v0.25 as a list
|
||||
of dictionary for each sub-environment. However, this was modified in v0.25+ to be a
|
||||
dictionary with a NumPy array for each key. To use the old info style, utilise the :class:`DictInfoToList` wrapper.
|
||||
|
||||
Note:
|
||||
All parallel environments should share the identical observation and action spaces.
|
||||
In other words, a vector of multiple different environments is not supported.
|
||||
|
||||
Note:
|
||||
:func:`make_vec` is the equivalent function to :func:`make` for vector environments.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_envs: int,
|
||||
observation_space: gym.Space,
|
||||
action_space: gym.Space,
|
||||
):
|
||||
"""Base class for vectorized environments.
|
||||
spec: EnvSpec | None = None
|
||||
render_mode: str | None = None
|
||||
closed: bool = False
|
||||
|
||||
Args:
|
||||
num_envs: Number of environments in the vectorized environment.
|
||||
observation_space: Observation space of a single environment.
|
||||
action_space: Action space of a single environment.
|
||||
"""
|
||||
self.num_envs = num_envs
|
||||
self.is_vector_env = True
|
||||
self.observation_space = batch_space(observation_space, n=num_envs)
|
||||
self.action_space = batch_space(action_space, n=num_envs)
|
||||
observation_space: gym.Space
|
||||
action_space: gym.Space
|
||||
single_observation_space: gym.Space
|
||||
single_action_space: gym.Space
|
||||
|
||||
self.closed = False
|
||||
self.viewer = None
|
||||
num_envs: int
|
||||
|
||||
# The observation and action spaces of a single environment are
|
||||
# kept in separate properties
|
||||
self.single_observation_space = observation_space
|
||||
self.single_action_space = action_space
|
||||
|
||||
def reset_async(
|
||||
self,
|
||||
seed: Optional[Union[int, List[int]]] = None,
|
||||
options: Optional[dict] = None,
|
||||
):
|
||||
"""Reset the sub-environments asynchronously.
|
||||
|
||||
This method will return ``None``. A call to :meth:`reset_async` should be followed
|
||||
by a call to :meth:`reset_wait` to retrieve the results.
|
||||
|
||||
Args:
|
||||
seed: The reset seed
|
||||
options: Reset options
|
||||
"""
|
||||
pass
|
||||
|
||||
def reset_wait(
|
||||
self,
|
||||
seed: Optional[Union[int, List[int]]] = None,
|
||||
options: Optional[dict] = None,
|
||||
):
|
||||
"""Retrieves the results of a :meth:`reset_async` call.
|
||||
|
||||
A call to this method must always be preceded by a call to :meth:`reset_async`.
|
||||
|
||||
Args:
|
||||
seed: The reset seed
|
||||
options: Reset options
|
||||
|
||||
Returns:
|
||||
The results from :meth:`reset_async`
|
||||
|
||||
Raises:
|
||||
NotImplementedError: VectorEnv does not implement function
|
||||
"""
|
||||
raise NotImplementedError("VectorEnv does not implement function")
|
||||
_np_random: np.random.Generator | None = None
|
||||
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: Optional[Union[int, List[int]]] = None,
|
||||
options: Optional[dict] = None,
|
||||
):
|
||||
seed: int | list[int] | None = None,
|
||||
options: dict[str, Any] | None = None,
|
||||
) -> tuple[ObsType, dict[str, Any]]: # type: ignore
|
||||
"""Reset all parallel environments and return a batch of initial observations and info.
|
||||
|
||||
Args:
|
||||
@@ -128,47 +122,26 @@ class VectorEnv(gym.Env):
|
||||
|
||||
Example:
|
||||
>>> import gymnasium as gym
|
||||
>>> envs = gym.vector.make("CartPole-v1", num_envs=3)
|
||||
>>> envs.reset(seed=42)
|
||||
(array([[ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 ],
|
||||
>>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
|
||||
>>> observations, infos = envs.reset(seed=42)
|
||||
>>> observations
|
||||
array([[ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 ],
|
||||
[ 0.01522993, -0.04562247, -0.04799704, 0.03392126],
|
||||
[-0.03774345, -0.02418869, -0.00942293, 0.0469184 ]],
|
||||
dtype=float32), {})
|
||||
dtype=float32)
|
||||
>>> infos
|
||||
{}
|
||||
"""
|
||||
self.reset_async(seed=seed, options=options)
|
||||
return self.reset_wait(seed=seed, options=options)
|
||||
|
||||
def step_async(self, actions):
|
||||
"""Asynchronously performs steps in the sub-environments.
|
||||
|
||||
The results can be retrieved via a call to :meth:`step_wait`.
|
||||
|
||||
Args:
|
||||
actions: The actions to take asynchronously
|
||||
"""
|
||||
|
||||
def step_wait(
|
||||
self, **kwargs
|
||||
) -> Tuple[Any, NDArray[Any], NDArray[Any], NDArray[Any], dict]:
|
||||
"""Retrieves the results of a :meth:`step_async` call.
|
||||
|
||||
A call to this method must always be preceded by a call to :meth:`step_async`.
|
||||
|
||||
Args:
|
||||
**kwargs: Additional keywords for vector implementation
|
||||
|
||||
Returns:
|
||||
The results from the :meth:`step_async` call
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
if seed is not None:
|
||||
self._np_random, seed = seeding.np_random(seed)
|
||||
|
||||
def step(
|
||||
self, actions
|
||||
) -> Tuple[Any, NDArray[Any], NDArray[Any], NDArray[Any], dict]:
|
||||
self, actions: ActType
|
||||
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
|
||||
"""Take an action for each parallel environment.
|
||||
|
||||
Args:
|
||||
actions: element of :attr:`action_space` Batch of actions.
|
||||
actions: Batch of actions with the :attr:`action_space` shape.
|
||||
|
||||
Returns:
|
||||
Batch of (observations, rewards, terminations, truncations, infos)
|
||||
@@ -181,10 +154,10 @@ class VectorEnv(gym.Env):
|
||||
Example:
|
||||
>>> import gymnasium as gym
|
||||
>>> import numpy as np
|
||||
>>> envs = gym.vector.make("CartPole-v1", num_envs=3)
|
||||
>>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
|
||||
>>> _ = envs.reset(seed=42)
|
||||
>>> actions = np.array([1, 0, 1])
|
||||
>>> observations, rewards, termination, truncation, infos = envs.step(actions)
|
||||
>>> actions = np.array([1, 0, 1], dtype=np.int32)
|
||||
>>> observations, rewards, terminations, truncations, infos = envs.step(actions)
|
||||
>>> observations
|
||||
array([[ 0.02727336, 0.18847767, 0.03625453, -0.26141977],
|
||||
[ 0.01431748, -0.24002443, -0.04731862, 0.3110827 ],
|
||||
@@ -192,62 +165,25 @@ class VectorEnv(gym.Env):
|
||||
dtype=float32)
|
||||
>>> rewards
|
||||
array([1., 1., 1.])
|
||||
>>> termination
|
||||
>>> terminations
|
||||
array([False, False, False])
|
||||
>>> truncation
|
||||
>>> terminations
|
||||
array([False, False, False])
|
||||
>>> infos
|
||||
{}
|
||||
"""
|
||||
self.step_async(actions)
|
||||
return self.step_wait()
|
||||
|
||||
def call_async(self, name, *args, **kwargs):
|
||||
"""Calls a method name for each parallel environment asynchronously."""
|
||||
|
||||
def call_wait(self, **kwargs) -> List[Any]: # type: ignore
|
||||
"""After calling a method in :meth:`call_async`, this function collects the results."""
|
||||
|
||||
def call(self, name: str, *args, **kwargs) -> List[Any]:
|
||||
"""Call a method, or get a property, from each parallel environment.
|
||||
|
||||
Args:
|
||||
name (str): Name of the method or property to call.
|
||||
*args: Arguments to apply to the method call.
|
||||
**kwargs: Keyword arguments to apply to the method call.
|
||||
def render(self) -> tuple[RenderFrame, ...] | None:
|
||||
"""Returns the rendered frames from the parallel environments.
|
||||
|
||||
Returns:
|
||||
List of the results of the individual calls to the method or property for each environment.
|
||||
A tuple of rendered frames from the parallel environments
|
||||
"""
|
||||
self.call_async(name, *args, **kwargs)
|
||||
return self.call_wait()
|
||||
raise NotImplementedError(
|
||||
f"{self.__str__()} render function is not implemented."
|
||||
)
|
||||
|
||||
def get_attr(self, name: str):
|
||||
"""Get a property from each parallel environment.
|
||||
|
||||
Args:
|
||||
name (str): Name of the property to be get from each individual environment.
|
||||
|
||||
Returns:
|
||||
The property with name
|
||||
"""
|
||||
return self.call(name)
|
||||
|
||||
def set_attr(self, name: str, values: Union[list, tuple, object]):
|
||||
"""Set a property in each sub-environment.
|
||||
|
||||
Args:
|
||||
name (str): Name of the property to be set in each individual environment.
|
||||
values (list, tuple, or object): Values of the property to be set to. If `values` is a list or
|
||||
tuple, then it corresponds to the values for each individual environment, otherwise a single value
|
||||
is set for all environments.
|
||||
"""
|
||||
|
||||
def close_extras(self, **kwargs):
|
||||
"""Clean up the extra resources e.g. beyond what's in this base class."""
|
||||
pass
|
||||
|
||||
def close(self, **kwargs):
|
||||
def close(self, **kwargs: Any):
|
||||
"""Close all parallel environments and release resources.
|
||||
|
||||
It also closes all the existing image viewers, then calls :meth:`close_extras` and set
|
||||
@@ -266,12 +202,37 @@ class VectorEnv(gym.Env):
|
||||
"""
|
||||
if self.closed:
|
||||
return
|
||||
if self.viewer is not None:
|
||||
self.viewer.close()
|
||||
|
||||
self.close_extras(**kwargs)
|
||||
self.closed = True
|
||||
|
||||
def _add_info(self, infos: dict, info: dict, env_num: int) -> dict:
|
||||
def close_extras(self, **kwargs: Any):
|
||||
"""Clean up the extra resources e.g. beyond what's in this base class."""
|
||||
pass
|
||||
|
||||
@property
|
||||
def np_random(self) -> np.random.Generator:
|
||||
"""Returns the environment's internal :attr:`_np_random` that if not set will initialise with a random seed.
|
||||
|
||||
Returns:
|
||||
Instances of `np.random.Generator`
|
||||
"""
|
||||
if self._np_random is None:
|
||||
self._np_random, seed = seeding.np_random()
|
||||
return self._np_random
|
||||
|
||||
@np_random.setter
|
||||
def np_random(self, value: np.random.Generator):
|
||||
self._np_random = value
|
||||
|
||||
@property
|
||||
def unwrapped(self):
|
||||
"""Return the base environment."""
|
||||
return self
|
||||
|
||||
def _add_info(
|
||||
self, infos: dict[str, Any], info: dict[str, Any], env_num: int
|
||||
) -> dict[str, Any]:
|
||||
"""Add env info to the info dictionary of the vectorized environment.
|
||||
|
||||
Given the `info` of a single environment add it to the `infos` dictionary
|
||||
@@ -298,7 +259,7 @@ class VectorEnv(gym.Env):
|
||||
infos[k], infos[f"_{k}"] = info_array, array_mask
|
||||
return infos
|
||||
|
||||
def _init_info_arrays(self, dtype: type) -> Tuple[np.ndarray, np.ndarray]:
|
||||
def _init_info_arrays(self, dtype: type) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""Initialize the info array.
|
||||
|
||||
Initialize the info array. If the dtype is numeric
|
||||
@@ -335,12 +296,14 @@ class VectorEnv(gym.Env):
|
||||
A string containing the class name, number of environments and environment spec id
|
||||
"""
|
||||
if self.spec is None:
|
||||
return f"{self.__class__.__name__}({self.num_envs})"
|
||||
return f"{self.__class__.__name__}(num_envs={self.num_envs})"
|
||||
else:
|
||||
return f"{self.__class__.__name__}({self.spec.id}, {self.num_envs})"
|
||||
return (
|
||||
f"{self.__class__.__name__}({self.spec.id}, num_envs={self.num_envs})"
|
||||
)
|
||||
|
||||
|
||||
class VectorEnvWrapper(VectorEnv):
|
||||
class VectorWrapper(VectorEnv):
|
||||
"""Wraps the vectorized environment to allow a modular transformation.
|
||||
|
||||
This class is the base class for all wrappers for vectorized environments. The subclass
|
||||
@@ -352,47 +315,223 @@ class VectorEnvWrapper(VectorEnv):
|
||||
"""
|
||||
|
||||
def __init__(self, env: VectorEnv):
|
||||
assert isinstance(env, VectorEnv)
|
||||
"""Initialize the vectorized environment wrapper.
|
||||
|
||||
Args:
|
||||
env: The environment to wrap
|
||||
"""
|
||||
self.env = env
|
||||
assert isinstance(env, VectorEnv)
|
||||
|
||||
# explicitly forward the methods defined in VectorEnv
|
||||
# to self.env (instead of the base class)
|
||||
def reset_async(self, **kwargs):
|
||||
return self.env.reset_async(**kwargs)
|
||||
self._observation_space: gym.Space | None = None
|
||||
self._action_space: gym.Space | None = None
|
||||
self._single_observation_space: gym.Space | None = None
|
||||
self._single_action_space: gym.Space | None = None
|
||||
|
||||
def reset_wait(self, **kwargs):
|
||||
return self.env.reset_wait(**kwargs)
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: int | list[int] | None = None,
|
||||
options: dict[str, Any] | None = None,
|
||||
) -> tuple[ObsType, dict[str, Any]]:
|
||||
"""Reset all environment using seed and options."""
|
||||
return self.env.reset(seed=seed, options=options)
|
||||
|
||||
def step_async(self, actions):
|
||||
return self.env.step_async(actions)
|
||||
def step(
|
||||
self, actions: ActType
|
||||
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
|
||||
"""Step through all environments using the actions returning the batched data."""
|
||||
return self.env.step(actions)
|
||||
|
||||
def step_wait(self):
|
||||
return self.env.step_wait()
|
||||
def render(self) -> tuple[RenderFrame, ...] | None:
|
||||
"""Returns the render mode from the base vector environment."""
|
||||
return self.env.render()
|
||||
|
||||
def close(self, **kwargs):
|
||||
def close(self, **kwargs: Any):
|
||||
"""Close all environments."""
|
||||
return self.env.close(**kwargs)
|
||||
|
||||
def close_extras(self, **kwargs):
|
||||
def close_extras(self, **kwargs: Any):
|
||||
"""Close all extra resources."""
|
||||
return self.env.close_extras(**kwargs)
|
||||
|
||||
def call(self, name, *args, **kwargs):
|
||||
return self.env.call(name, *args, **kwargs)
|
||||
|
||||
def set_attr(self, name, values):
|
||||
return self.env.set_attr(name, values)
|
||||
|
||||
# implicitly forward all other methods and attributes to self.env
|
||||
def __getattr__(self, name):
|
||||
if name.startswith("_"):
|
||||
raise AttributeError(f"attempted to get missing private attribute '{name}'")
|
||||
return getattr(self.env, name)
|
||||
|
||||
@property
|
||||
def unwrapped(self):
|
||||
"""Return the base non-wrapped environment."""
|
||||
return self.env.unwrapped
|
||||
|
||||
def __repr__(self):
|
||||
"""Return the string representation of the vectorized environment."""
|
||||
return f"<{self.__class__.__name__}, {self.env}>"
|
||||
|
||||
def __del__(self):
|
||||
self.env.__del__()
|
||||
@property
|
||||
def spec(self) -> EnvSpec | None:
|
||||
"""Gets the specification of the wrapped environment."""
|
||||
return self.env.spec
|
||||
|
||||
@property
|
||||
def observation_space(self) -> gym.Space:
|
||||
"""Gets the observation space of the vector environment."""
|
||||
if self._observation_space is None:
|
||||
return self.env.observation_space
|
||||
return self._observation_space
|
||||
|
||||
@observation_space.setter
|
||||
def observation_space(self, space: gym.Space):
|
||||
"""Sets the observation space of the vector environment."""
|
||||
self._observation_space = space
|
||||
|
||||
@property
|
||||
def action_space(self) -> gym.Space:
|
||||
"""Gets the action space of the vector environment."""
|
||||
if self._action_space is None:
|
||||
return self.env.action_space
|
||||
return self._action_space
|
||||
|
||||
@action_space.setter
|
||||
def action_space(self, space: gym.Space):
|
||||
"""Sets the action space of the vector environment."""
|
||||
self._action_space = space
|
||||
|
||||
@property
|
||||
def single_observation_space(self) -> gym.Space:
|
||||
"""Gets the single observation space of the vector environment."""
|
||||
if self._single_observation_space is None:
|
||||
return self.env.single_observation_space
|
||||
return self._single_observation_space
|
||||
|
||||
@single_observation_space.setter
|
||||
def single_observation_space(self, space: gym.Space):
|
||||
"""Sets the single observation space of the vector environment."""
|
||||
self._single_observation_space = space
|
||||
|
||||
@property
|
||||
def single_action_space(self) -> gym.Space:
|
||||
"""Gets the single action space of the vector environment."""
|
||||
if self._single_action_space is None:
|
||||
return self.env.single_action_space
|
||||
return self._single_action_space
|
||||
|
||||
@single_action_space.setter
|
||||
def single_action_space(self, space):
|
||||
"""Sets the single action space of the vector environment."""
|
||||
self._single_action_space = space
|
||||
|
||||
@property
|
||||
def num_envs(self) -> int:
|
||||
"""Gets the wrapped vector environment's num of the sub-environments."""
|
||||
return self.env.num_envs
|
||||
|
||||
@property
|
||||
def render_mode(self) -> tuple[RenderFrame, ...] | None:
|
||||
"""Returns the `render_mode` from the base environment."""
|
||||
return self.env.render_mode
|
||||
|
||||
|
||||
class VectorObservationWrapper(VectorWrapper):
|
||||
"""Wraps the vectorized environment to allow a modular transformation of the observation.
|
||||
|
||||
Equivalent to :class:`gymnasium.ObservationWrapper` for vectorized environments.
|
||||
"""
|
||||
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: int | list[int] | None = None,
|
||||
options: dict[str, Any] | None = None,
|
||||
) -> tuple[ObsType, dict[str, Any]]:
|
||||
"""Modifies the observation returned from the environment ``reset`` using the :meth:`observation`."""
|
||||
obs, info = self.env.reset(seed=seed, options=options)
|
||||
return self.vector_observation(obs), info
|
||||
|
||||
def step(
|
||||
self, actions: ActType
|
||||
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
|
||||
"""Modifies the observation returned from the environment ``step`` using the :meth:`observation`."""
|
||||
observation, reward, termination, truncation, info = self.env.step(actions)
|
||||
return (
|
||||
self.vector_observation(observation),
|
||||
reward,
|
||||
termination,
|
||||
truncation,
|
||||
self.update_final_obs(info),
|
||||
)
|
||||
|
||||
def vector_observation(self, observation: ObsType) -> ObsType:
|
||||
"""Defines the vector observation transformation.
|
||||
|
||||
Args:
|
||||
observation: A vector observation from the environment
|
||||
|
||||
Returns:
|
||||
the transformed observation
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def single_observation(self, observation: ObsType) -> ObsType:
|
||||
"""Defines the single observation transformation.
|
||||
|
||||
Args:
|
||||
observation: A single observation from the environment
|
||||
|
||||
Returns:
|
||||
The transformed observation
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def update_final_obs(self, info: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Updates the `final_obs` in the info using `single_observation`."""
|
||||
if "final_observation" in info:
|
||||
for i, obs in enumerate(info["final_observation"]):
|
||||
if obs is not None:
|
||||
info["final_observation"][i] = self.single_observation(obs)
|
||||
return info
|
||||
|
||||
|
||||
class VectorActionWrapper(VectorWrapper):
|
||||
"""Wraps the vectorized environment to allow a modular transformation of the actions.
|
||||
|
||||
Equivalent of :class:`gymnasium.ActionWrapper` for vectorized environments.
|
||||
"""
|
||||
|
||||
def step(
|
||||
self, actions: ActType
|
||||
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
|
||||
"""Steps through the environment using a modified action by :meth:`action`."""
|
||||
return self.env.step(self.actions(actions))
|
||||
|
||||
def actions(self, actions: ActType) -> ActType:
|
||||
"""Transform the actions before sending them to the environment.
|
||||
|
||||
Args:
|
||||
actions (ActType): the actions to transform
|
||||
|
||||
Returns:
|
||||
ActType: the transformed actions
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class VectorRewardWrapper(VectorWrapper):
|
||||
"""Wraps the vectorized environment to allow a modular transformation of the reward.
|
||||
|
||||
Equivalent of :class:`gymnasium.RewardWrapper` for vectorized environments.
|
||||
"""
|
||||
|
||||
def step(
|
||||
self, actions: ActType
|
||||
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
|
||||
"""Steps through the environment returning a reward modified by :meth:`reward`."""
|
||||
observation, reward, termination, truncation, info = self.env.step(actions)
|
||||
return observation, self.rewards(reward), termination, truncation, info
|
||||
|
||||
def rewards(self, reward: ArrayType) -> ArrayType:
|
||||
"""Transform the reward before returning it.
|
||||
|
||||
Args:
|
||||
reward (array): the reward to transform
|
||||
|
||||
Returns:
|
||||
array: the transformed reward
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
@@ -1,18 +0,0 @@
|
||||
# Wrappers
|
||||
|
||||
Wrappers are used to transform an environment in a modular way:
|
||||
|
||||
```python
|
||||
import gymnasium as gym
|
||||
env = gym.make('CartPole-v1')
|
||||
env = MyWrapper(env)
|
||||
```
|
||||
|
||||
## Quick tips for writing your own wrapper
|
||||
|
||||
- Don't forget to call `super(class_name, self).__init__(env)` if you override the wrapper's `__init__` function
|
||||
- You can access the inner environment with `self.unwrapped`
|
||||
- You can access the previous wrapper using `self.env`
|
||||
- The variables `metadata`, `action_space`, `observation_space`, `reward_range`, and `spec` are copied to `self` from the previous layer
|
||||
- Create a wrapped function for at least one of the following: `__init__(self, env)`, `step`, `reset`, `render`, `close`, or `seed`
|
||||
- Your layered function should take its input from the previous layer (`self.env`) and/or the inner layer (`self.unwrapped`)
|
@@ -1,9 +1,8 @@
|
||||
"""Module of wrapper classes.
|
||||
"""Wrappers are a convenient way to modify an existing environment without having to alter the underlying code directly.
|
||||
|
||||
Wrappers are a convenient way to modify an existing environment without having to alter the underlying code directly.
|
||||
Using wrappers will allow you to avoid a lot of boilerplate code and make your environment more modular. Wrappers can
|
||||
also be chained to combine their effects.
|
||||
Most environments that are generated via :meth:`gymnasium.make` will already be wrapped by default.
|
||||
Using wrappers will allow you to avoid a lot of boilerplate code and make your environment more modular.
|
||||
Importantly wrappers can be chained to combine their effects and most environments that are generated via
|
||||
:meth:`gymnasium.make` will already be wrapped by default.
|
||||
|
||||
In order to wrap an environment, you must first initialize a base environment. Then you can pass this environment along
|
||||
with (possibly optional) parameters to the wrapper's constructor.
|
||||
@@ -46,27 +45,135 @@ If you need a wrapper to do more complicated tasks, you can inherit from the :cl
|
||||
|
||||
If you'd like to implement your own custom wrapper, check out `the corresponding tutorial <../../tutorials/implementing_custom_wrappers>`_.
|
||||
"""
|
||||
# pyright: reportUnsupportedDunderAll=false
|
||||
import importlib
|
||||
import re
|
||||
|
||||
from gymnasium.error import DeprecatedWrapper
|
||||
from gymnasium.wrappers import vector
|
||||
from gymnasium.wrappers.atari_preprocessing import AtariPreprocessing
|
||||
from gymnasium.wrappers.autoreset import AutoResetWrapper
|
||||
from gymnasium.wrappers.clip_action import ClipAction
|
||||
from gymnasium.wrappers.compatibility import EnvCompatibility
|
||||
from gymnasium.wrappers.env_checker import PassiveEnvChecker
|
||||
from gymnasium.wrappers.filter_observation import FilterObservation
|
||||
from gymnasium.wrappers.flatten_observation import FlattenObservation
|
||||
from gymnasium.wrappers.frame_stack import FrameStack, LazyFrames
|
||||
from gymnasium.wrappers.gray_scale_observation import GrayScaleObservation
|
||||
from gymnasium.wrappers.human_rendering import HumanRendering
|
||||
from gymnasium.wrappers.normalize import NormalizeObservation, NormalizeReward
|
||||
from gymnasium.wrappers.order_enforcing import OrderEnforcing
|
||||
from gymnasium.wrappers.pixel_observation import PixelObservationWrapper
|
||||
from gymnasium.wrappers.record_episode_statistics import RecordEpisodeStatistics
|
||||
from gymnasium.wrappers.record_video import RecordVideo, capped_cubic_video_schedule
|
||||
from gymnasium.wrappers.render_collection import RenderCollection
|
||||
from gymnasium.wrappers.rescale_action import RescaleAction
|
||||
from gymnasium.wrappers.resize_observation import ResizeObservation
|
||||
from gymnasium.wrappers.step_api_compatibility import StepAPICompatibility
|
||||
from gymnasium.wrappers.time_aware_observation import TimeAwareObservation
|
||||
from gymnasium.wrappers.time_limit import TimeLimit
|
||||
from gymnasium.wrappers.transform_observation import TransformObservation
|
||||
from gymnasium.wrappers.transform_reward import TransformReward
|
||||
from gymnasium.wrappers.vector_list_info import VectorListInfo
|
||||
from gymnasium.wrappers.common import (
|
||||
Autoreset,
|
||||
OrderEnforcing,
|
||||
PassiveEnvChecker,
|
||||
RecordEpisodeStatistics,
|
||||
TimeLimit,
|
||||
)
|
||||
from gymnasium.wrappers.rendering import HumanRendering, RecordVideo, RenderCollection
|
||||
from gymnasium.wrappers.stateful_action import StickyAction
|
||||
from gymnasium.wrappers.stateful_observation import (
|
||||
DelayObservation,
|
||||
FrameStackObservation,
|
||||
MaxAndSkipObservation,
|
||||
NormalizeObservation,
|
||||
TimeAwareObservation,
|
||||
)
|
||||
from gymnasium.wrappers.stateful_reward import NormalizeReward
|
||||
from gymnasium.wrappers.transform_action import (
|
||||
ClipAction,
|
||||
RescaleAction,
|
||||
TransformAction,
|
||||
)
|
||||
from gymnasium.wrappers.transform_observation import (
|
||||
DtypeObservation,
|
||||
FilterObservation,
|
||||
FlattenObservation,
|
||||
GrayscaleObservation,
|
||||
RenderObservation,
|
||||
RescaleObservation,
|
||||
ReshapeObservation,
|
||||
ResizeObservation,
|
||||
TransformObservation,
|
||||
)
|
||||
from gymnasium.wrappers.transform_reward import ClipReward, TransformReward
|
||||
|
||||
|
||||
__all__ = [
|
||||
"vector",
|
||||
# --- Observation wrappers ---
|
||||
"AtariPreprocessing",
|
||||
"DelayObservation",
|
||||
"DtypeObservation",
|
||||
"FilterObservation",
|
||||
"FlattenObservation",
|
||||
"FrameStackObservation",
|
||||
"GrayscaleObservation",
|
||||
"TransformObservation",
|
||||
"MaxAndSkipObservation",
|
||||
"NormalizeObservation",
|
||||
"RenderObservation",
|
||||
"ResizeObservation",
|
||||
"ReshapeObservation",
|
||||
"RescaleObservation",
|
||||
"TimeAwareObservation",
|
||||
# --- Action Wrappers ---
|
||||
"ClipAction",
|
||||
"TransformAction",
|
||||
"RescaleAction",
|
||||
# "NanAction",
|
||||
"StickyAction",
|
||||
# --- Reward wrappers ---
|
||||
"ClipReward",
|
||||
"TransformReward",
|
||||
"NormalizeReward",
|
||||
# --- Common ---
|
||||
"TimeLimit",
|
||||
"Autoreset",
|
||||
"PassiveEnvChecker",
|
||||
"OrderEnforcing",
|
||||
"RecordEpisodeStatistics",
|
||||
# --- Rendering ---
|
||||
"RenderCollection",
|
||||
"RecordVideo",
|
||||
"HumanRendering",
|
||||
# --- Conversion ---
|
||||
"JaxToNumpy",
|
||||
"JaxToTorch",
|
||||
"NumpyToTorch",
|
||||
]
|
||||
|
||||
# As these wrappers requires `jax` or `torch`, they are loaded by runtime for users trying to access them
|
||||
# to avoid `import jax` or `import torch` on `import gymnasium`.
|
||||
_wrapper_to_class = {
|
||||
# data converters
|
||||
"JaxToNumpy": "jax_to_numpy",
|
||||
"JaxToTorch": "jax_to_torch",
|
||||
"NumpyToTorch": "numpy_to_torch",
|
||||
}
|
||||
|
||||
_renamed_wrapper = {
|
||||
"AutoResetWrapper": "Autoreset",
|
||||
"FrameStack": "FrameStackObservation",
|
||||
"PixelObservationWrapper": "RenderObservation",
|
||||
"VectorListInfo": "vector.DictInfoToList",
|
||||
}
|
||||
|
||||
|
||||
def __getattr__(wrapper_name: str):
|
||||
"""Load a wrapper by name.
|
||||
|
||||
This optimizes the loading of gymnasium wrappers by only loading the wrapper if it is used.
|
||||
Errors will be raised if the wrapper does not exist or if the version is not the latest.
|
||||
|
||||
Args:
|
||||
wrapper_name: The name of a wrapper to load.
|
||||
|
||||
Returns:
|
||||
The specified wrapper.
|
||||
|
||||
Raises:
|
||||
AttributeError: If the wrapper does not exist.
|
||||
DeprecatedWrapper: If the version is not the latest.
|
||||
"""
|
||||
# Check if the requested wrapper is in the _wrapper_to_class dictionary
|
||||
if wrapper_name in _wrapper_to_class:
|
||||
import_stmt = f"gymnasium.wrappers.{_wrapper_to_class[wrapper_name]}"
|
||||
module = importlib.import_module(import_stmt)
|
||||
return getattr(module, wrapper_name)
|
||||
|
||||
elif wrapper_name in _renamed_wrapper:
|
||||
raise AttributeError(
|
||||
f"{wrapper_name!r} has been renamed with `wrappers.{_renamed_wrapper[wrapper_name]}`"
|
||||
)
|
||||
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {wrapper_name!r}")
|
||||
|
@@ -5,14 +5,14 @@ import gymnasium as gym
|
||||
from gymnasium.spaces import Box
|
||||
|
||||
|
||||
try:
|
||||
import cv2
|
||||
except ImportError:
|
||||
cv2 = None
|
||||
__all__ = ["AtariPreprocessing"]
|
||||
|
||||
|
||||
class AtariPreprocessing(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||
"""Atari 2600 preprocessing wrapper.
|
||||
"""Implements the common preprocessing techniques for Atari environments (excluding frame stacking).
|
||||
|
||||
For frame stacking use :class:`gymnasium.wrappers.FrameStackObservation`.
|
||||
No vector version of the wrapper exists
|
||||
|
||||
This class follows the guidelines in Machado et al. (2018),
|
||||
"Revisiting the Arcade Learning Environment: Evaluation Protocols and Open Problems for General Agents".
|
||||
@@ -20,13 +20,22 @@ class AtariPreprocessing(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||
Specifically, the following preprocess stages applies to the atari environment:
|
||||
|
||||
- Noop Reset: Obtains the initial state by taking a random number of no-ops on reset, default max 30 no-ops.
|
||||
- Frame skipping: The number of frames skipped between steps, 4 by default
|
||||
- Max-pooling: Pools over the most recent two observations from the frame skips
|
||||
- Frame skipping: The number of frames skipped between steps, 4 by default.
|
||||
- Max-pooling: Pools over the most recent two observations from the frame skips.
|
||||
- Termination signal when a life is lost: When the agent losses a life during the environment, then the environment is terminated.
|
||||
Turned off by default. Not recommended by Machado et al. (2018).
|
||||
- Resize to a square image: Resizes the atari environment original observation shape from 210x180 to 84x84 by default
|
||||
- Grayscale observation: If the observation is colour or greyscale, by default, greyscale.
|
||||
- Scale observation: If to scale the observation between [0, 1) or [0, 255), by default, not scaled.
|
||||
- Resize to a square image: Resizes the atari environment original observation shape from 210x180 to 84x84 by default.
|
||||
- Grayscale observation: Makes the observation greyscale, enabled by default.
|
||||
- Grayscale new axis: Extends the last channel of the observation such that the image is 3-dimensional, not enabled by default.
|
||||
- Scale observation: Whether to scale the observation between [0, 1) or [0, 255), not scaled by default.
|
||||
|
||||
Example:
|
||||
>>> import gymnasium as gym # doctest: +SKIP
|
||||
>>> env = gym.make("ALE/Adventure-v5") # doctest: +SKIP
|
||||
>>> env = AtariPreprocessing(env, noop_max=10, frame_skip=0, screen_size=84, terminal_on_life_loss=True, grayscale_obs=False, grayscale_newaxis=False) # doctest: +SKIP
|
||||
|
||||
Change logs:
|
||||
* Added in gym v0.12.2 (gym #1455)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -46,7 +55,7 @@ class AtariPreprocessing(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||
env (Env): The environment to apply the preprocessing
|
||||
noop_max (int): For No-op reset, the max number no-ops actions are taken at reset, to turn off, set to 0.
|
||||
frame_skip (int): The number of frames between new observation the agents observations effecting the frequency at which the agent experiences the game.
|
||||
screen_size (int): resize Atari frame
|
||||
screen_size (int): resize Atari frame.
|
||||
terminal_on_life_loss (bool): `if True`, then :meth:`step()` returns `terminated=True` whenever a
|
||||
life is lost.
|
||||
grayscale_obs (bool): if True, then gray scale observation is returned, otherwise, RGB observation
|
||||
@@ -72,10 +81,13 @@ class AtariPreprocessing(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||
)
|
||||
gym.Wrapper.__init__(self, env)
|
||||
|
||||
if cv2 is None:
|
||||
try:
|
||||
import cv2 # noqa: F401
|
||||
except ImportError:
|
||||
raise gym.error.DependencyNotInstalled(
|
||||
"opencv-python package not installed, run `pip install gymnasium[other]` to get dependencies for atari"
|
||||
)
|
||||
|
||||
assert frame_skip > 0
|
||||
assert screen_size > 0
|
||||
assert noop_max >= 0
|
||||
@@ -187,7 +199,9 @@ class AtariPreprocessing(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||
def _get_obs(self):
|
||||
if self.frame_skip > 1: # more efficient in-place pooling
|
||||
np.maximum(self.obs_buffer[0], self.obs_buffer[1], out=self.obs_buffer[0])
|
||||
assert cv2 is not None
|
||||
|
||||
import cv2
|
||||
|
||||
obs = cv2.resize(
|
||||
self.obs_buffer[0],
|
||||
(self.screen_size, self.screen_size),
|
||||
|
@@ -1,86 +0,0 @@
|
||||
"""Wrapper that autoreset environments when `terminated=True` or `truncated=True`."""
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import gymnasium as gym
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gymnasium.envs.registration import EnvSpec
|
||||
|
||||
|
||||
class AutoResetWrapper(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||
"""A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`.
|
||||
|
||||
When calling step causes :meth:`Env.step` to return `terminated=True` or `truncated=True`, :meth:`Env.reset` is called,
|
||||
and the return format of :meth:`self.step` is as follows: ``(new_obs, final_reward, final_terminated, final_truncated, info)``
|
||||
with new step API and ``(new_obs, final_reward, final_done, info)`` with the old step API.
|
||||
|
||||
- ``new_obs`` is the first observation after calling :meth:`self.env.reset`
|
||||
- ``final_reward`` is the reward after calling :meth:`self.env.step`, prior to calling :meth:`self.env.reset`.
|
||||
- ``final_terminated`` is the terminated value before calling :meth:`self.env.reset`.
|
||||
- ``final_truncated`` is the truncated value before calling :meth:`self.env.reset`. Both `final_terminated` and `final_truncated` cannot be False.
|
||||
- ``info`` is a dict containing all the keys from the info dict returned by the call to :meth:`self.env.reset`,
|
||||
with an additional key "final_observation" containing the observation returned by the last call to :meth:`self.env.step`
|
||||
and "final_info" containing the info dict returned by the last call to :meth:`self.env.step`.
|
||||
|
||||
Warning:
|
||||
When using this wrapper to collect rollouts, note that when :meth:`Env.step` returns `terminated` or `truncated`, a
|
||||
new observation from after calling :meth:`Env.reset` is returned by :meth:`Env.step` alongside the
|
||||
final reward, terminated and truncated state from the previous episode.
|
||||
If you need the final state from the previous episode, you need to retrieve it via the
|
||||
"final_observation" key in the info dict.
|
||||
Make sure you know what you're doing if you use this wrapper!
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env):
|
||||
"""A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`.
|
||||
|
||||
Args:
|
||||
env (gym.Env): The environment to apply the wrapper
|
||||
"""
|
||||
gym.utils.RecordConstructorArgs.__init__(self)
|
||||
gym.Wrapper.__init__(self, env)
|
||||
|
||||
def step(self, action):
|
||||
"""Steps through the environment with action and resets the environment if a terminated or truncated signal is encountered.
|
||||
|
||||
Args:
|
||||
action: The action to take
|
||||
|
||||
Returns:
|
||||
The autoreset environment :meth:`step`
|
||||
"""
|
||||
obs, reward, terminated, truncated, info = self.env.step(action)
|
||||
if terminated or truncated:
|
||||
new_obs, new_info = self.env.reset()
|
||||
assert (
|
||||
"final_observation" not in new_info
|
||||
), 'info dict cannot contain key "final_observation" '
|
||||
assert (
|
||||
"final_info" not in new_info
|
||||
), 'info dict cannot contain key "final_info" '
|
||||
|
||||
new_info["final_observation"] = obs
|
||||
new_info["final_info"] = info
|
||||
|
||||
obs = new_obs
|
||||
info = new_info
|
||||
|
||||
return obs, reward, terminated, truncated, info
|
||||
|
||||
@property
|
||||
def spec(self) -> EnvSpec | None:
|
||||
"""Modifies the environment spec to specify the `autoreset=True`."""
|
||||
if self._cached_spec is not None:
|
||||
return self._cached_spec
|
||||
|
||||
env_spec = self.env.spec
|
||||
if env_spec is not None:
|
||||
env_spec = deepcopy(env_spec)
|
||||
env_spec.autoreset = True
|
||||
|
||||
self._cached_spec = env_spec
|
||||
return env_spec
|
@@ -1,43 +0,0 @@
|
||||
"""Wrapper for clipping actions within a valid bound."""
|
||||
import numpy as np
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.spaces import Box
|
||||
|
||||
|
||||
class ClipAction(gym.ActionWrapper, gym.utils.RecordConstructorArgs):
|
||||
"""Clip the continuous action within the valid :class:`Box` observation space bound.
|
||||
|
||||
Example:
|
||||
>>> import gymnasium as gym
|
||||
>>> from gymnasium.wrappers import ClipAction
|
||||
>>> env = gym.make("Hopper-v4")
|
||||
>>> env = ClipAction(env)
|
||||
>>> env.action_space
|
||||
Box(-1.0, 1.0, (3,), float32)
|
||||
>>> _ = env.reset(seed=42)
|
||||
>>> _ = env.step(np.array([5.0, -2.0, 0.0]))
|
||||
... # Executes the action np.array([1.0, -1.0, 0]) in the base environment
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env):
|
||||
"""A wrapper for clipping continuous actions within the valid bound.
|
||||
|
||||
Args:
|
||||
env: The environment to apply the wrapper
|
||||
"""
|
||||
assert isinstance(env.action_space, Box)
|
||||
|
||||
gym.utils.RecordConstructorArgs.__init__(self)
|
||||
gym.ActionWrapper.__init__(self, env)
|
||||
|
||||
def action(self, action):
|
||||
"""Clips the action within the valid bounds.
|
||||
|
||||
Args:
|
||||
action: The action to clip
|
||||
|
||||
Returns:
|
||||
The clipped action
|
||||
"""
|
||||
return np.clip(action, self.action_space.low, self.action_space.high)
|
536
gymnasium/wrappers/common.py
Normal file
536
gymnasium/wrappers/common.py
Normal file
@@ -0,0 +1,536 @@
|
||||
"""A collection of common wrappers.
|
||||
|
||||
* ``TimeLimit`` - Provides a time limit on the number of steps for an environment before it truncates
|
||||
* ``Autoreset`` - Auto-resets the environment
|
||||
* ``PassiveEnvChecker`` - Passive environment checker that does not modify any environment data
|
||||
* ``OrderEnforcing`` - Enforces the order of function calls to environments
|
||||
* ``RecordEpisodeStatistics`` - Records the episode statistics
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from collections import deque
|
||||
from copy import deepcopy
|
||||
from typing import TYPE_CHECKING, Any, SupportsFloat
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium import logger
|
||||
from gymnasium.core import ActType, ObsType, RenderFrame
|
||||
from gymnasium.error import ResetNeeded
|
||||
from gymnasium.utils.passive_env_checker import (
|
||||
check_action_space,
|
||||
check_observation_space,
|
||||
env_render_passive_checker,
|
||||
env_reset_passive_checker,
|
||||
env_step_passive_checker,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gymnasium.envs.registration import EnvSpec
|
||||
|
||||
|
||||
__all__ = [
|
||||
"TimeLimit",
|
||||
"Autoreset",
|
||||
"PassiveEnvChecker",
|
||||
"OrderEnforcing",
|
||||
"RecordEpisodeStatistics",
|
||||
]
|
||||
|
||||
|
||||
class TimeLimit(
|
||||
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
|
||||
):
|
||||
"""Limits the number of steps for an environment through truncating the environment if a maximum number of timesteps is exceeded.
|
||||
|
||||
If a truncation is not defined inside the environment itself, this is the only place that the truncation signal is issued.
|
||||
Critically, this is different from the `terminated` signal that originates from the underlying environment as part of the MDP.
|
||||
No vector wrapper exists.
|
||||
|
||||
Example using the TimeLimit wrapper:
|
||||
>>> from gymnasium.wrappers import TimeLimit
|
||||
>>> from gymnasium.envs.classic_control import CartPoleEnv
|
||||
|
||||
>>> spec = gym.spec("CartPole-v1")
|
||||
>>> spec.max_episode_steps
|
||||
500
|
||||
>>> env = gym.make("CartPole-v1")
|
||||
>>> env # TimeLimit is included within the environment stack
|
||||
<TimeLimit<OrderEnforcing<PassiveEnvChecker<CartPoleEnv<CartPole-v1>>>>>
|
||||
>>> env.spec # doctest: +ELLIPSIS
|
||||
EnvSpec(id='CartPole-v1', ..., max_episode_steps=500, ...)
|
||||
>>> env = gym.make("CartPole-v1", max_episode_steps=3)
|
||||
>>> env.spec # doctest: +ELLIPSIS
|
||||
EnvSpec(id='CartPole-v1', ..., max_episode_steps=3, ...)
|
||||
>>> env = TimeLimit(CartPoleEnv(), max_episode_steps=10)
|
||||
>>> env
|
||||
<TimeLimit<CartPoleEnv instance>>
|
||||
|
||||
Example of `TimeLimit` determining the episode step
|
||||
>>> env = gym.make("CartPole-v1", max_episode_steps=3)
|
||||
>>> _ = env.reset(seed=123)
|
||||
>>> _ = env.action_space.seed(123)
|
||||
>>> _, _, terminated, truncated, _ = env.step(env.action_space.sample())
|
||||
>>> terminated, truncated
|
||||
(False, False)
|
||||
>>> _, _, terminated, truncated, _ = env.step(env.action_space.sample())
|
||||
>>> terminated, truncated
|
||||
(False, False)
|
||||
>>> _, _, terminated, truncated, _ = env.step(env.action_space.sample())
|
||||
>>> terminated, truncated
|
||||
(False, True)
|
||||
|
||||
Change logs:
|
||||
* v0.10.6 - Initially added
|
||||
* v0.25.0 - With the step API update, the termination and truncation signal is returned separately.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: gym.Env,
|
||||
max_episode_steps: int,
|
||||
):
|
||||
"""Initializes the :class:`TimeLimit` wrapper with an environment and the number of steps after which truncation will occur.
|
||||
|
||||
Args:
|
||||
env: The environment to apply the wrapper
|
||||
max_episode_steps: An optional max episode steps (if ``None``, ``env.spec.max_episode_steps`` is used)
|
||||
"""
|
||||
gym.utils.RecordConstructorArgs.__init__(
|
||||
self, max_episode_steps=max_episode_steps
|
||||
)
|
||||
gym.Wrapper.__init__(self, env)
|
||||
|
||||
self._max_episode_steps = max_episode_steps
|
||||
self._elapsed_steps = None
|
||||
|
||||
def step(
|
||||
self, action: ActType
|
||||
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
|
||||
"""Steps through the environment and if the number of steps elapsed exceeds ``max_episode_steps`` then truncate.
|
||||
|
||||
Args:
|
||||
action: The environment step action
|
||||
|
||||
Returns:
|
||||
The environment step ``(observation, reward, terminated, truncated, info)`` with `truncated=True`
|
||||
if the number of steps elapsed >= max episode steps
|
||||
|
||||
"""
|
||||
observation, reward, terminated, truncated, info = self.env.step(action)
|
||||
self._elapsed_steps += 1
|
||||
|
||||
if self._elapsed_steps >= self._max_episode_steps:
|
||||
truncated = True
|
||||
|
||||
return observation, reward, terminated, truncated, info
|
||||
|
||||
def reset(
|
||||
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
||||
) -> tuple[ObsType, dict[str, Any]]:
|
||||
"""Resets the environment with :param:`**kwargs` and sets the number of steps elapsed to zero.
|
||||
|
||||
Args:
|
||||
seed: Seed for the environment
|
||||
options: Options for the environment
|
||||
|
||||
Returns:
|
||||
The reset environment
|
||||
"""
|
||||
self._elapsed_steps = 0
|
||||
return super().reset(seed=seed, options=options)
|
||||
|
||||
@property
|
||||
def spec(self) -> EnvSpec | None:
|
||||
"""Modifies the environment spec to include the `max_episode_steps=self._max_episode_steps`."""
|
||||
if self._cached_spec is not None:
|
||||
return self._cached_spec
|
||||
|
||||
env_spec = self.env.spec
|
||||
if env_spec is not None:
|
||||
env_spec = deepcopy(env_spec)
|
||||
env_spec.max_episode_steps = self._max_episode_steps
|
||||
|
||||
self._cached_spec = env_spec
|
||||
return env_spec
|
||||
|
||||
|
||||
class Autoreset(
|
||||
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
|
||||
):
|
||||
"""The wrapped environment is automatically reset when an terminated or truncated state is reached.
|
||||
|
||||
When calling step causes :meth:`Env.step` to return `terminated=True` or `truncated=True`, :meth:`Env.reset` is called,
|
||||
and the return format of :meth:`self.step` is as follows: ``(new_obs, final_reward, final_terminated, final_truncated, info)``
|
||||
with new step API and ``(new_obs, final_reward, final_done, info)`` with the old step API.
|
||||
No vector version of the wrapper exists.
|
||||
|
||||
- ``obs`` is the first observation after calling :meth:`self.env.reset`
|
||||
- ``final_reward`` is the reward after calling :meth:`self.env.step`, prior to calling :meth:`self.env.reset`.
|
||||
- ``final_terminated`` is the terminated value before calling :meth:`self.env.reset`.
|
||||
- ``final_truncated`` is the truncated value before calling :meth:`self.env.reset`. Both `final_terminated` and `final_truncated` cannot be False.
|
||||
- ``info`` is a dict containing all the keys from the info dict returned by the call to :meth:`self.env.reset`,
|
||||
with an additional key "final_observation" containing the observation returned by the last call to :meth:`self.env.step`
|
||||
and "final_info" containing the info dict returned by the last call to :meth:`self.env.step`.
|
||||
|
||||
Warning:
|
||||
When using this wrapper to collect rollouts, note that when :meth:`Env.step` returns `terminated` or `truncated`, a
|
||||
new observation from after calling :meth:`Env.reset` is returned by :meth:`Env.step` alongside the
|
||||
final reward, terminated and truncated state from the previous episode.
|
||||
If you need the final state from the previous episode, you need to retrieve it via the
|
||||
"final_observation" key in the info dict.
|
||||
Make sure you know what you're doing if you use this wrapper!
|
||||
|
||||
Change logs:
|
||||
* v0.24.0 - Initially added as `AutoResetWrapper`
|
||||
* v1.0.0 - renamed to `Autoreset` and autoreset order was changed to reset on the step after the environment terminates or truncates. As a result, `"final_observation"` and `"final_info"` is removed.
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env):
|
||||
"""A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`.
|
||||
|
||||
Args:
|
||||
env (gym.Env): The environment to apply the wrapper
|
||||
"""
|
||||
gym.utils.RecordConstructorArgs.__init__(self)
|
||||
gym.Wrapper.__init__(self, env)
|
||||
|
||||
def step(
    self, action: "ActType"
) -> "tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]":
    """Steps through the environment, resetting it when a terminated or truncated signal is encountered.

    Args:
        action: The action to take

    Returns:
        ``(obs, reward, terminated, truncated, info)``. When the episode ends on this
        step, ``obs`` and ``info`` come from the automatic :meth:`self.env.reset` call,
        while the terminal step's observation and info dict are preserved under
        ``info["final_observation"]`` and ``info["final_info"]``.
    """
    obs, reward, terminated, truncated, info = self.env.step(action)

    if terminated or truncated:
        new_obs, new_info = self.env.reset()

        # Guard against silently clobbering keys the environment itself set on reset.
        assert (
            "final_observation" not in new_info
        ), f'new info dict already contains "final_observation", info keys: {new_info.keys()}'
        # BUGFIX: this message previously said "final_observation" (copy-paste error)
        # even though the assert checks for "final_info".
        assert (
            "final_info" not in new_info
        ), f'new info dict already contains "final_info", info keys: {new_info.keys()}'

        # Preserve the terminal transition for the caller before swapping in the
        # post-reset observation and info.
        new_info["final_observation"] = obs
        new_info["final_info"] = info

        obs = new_obs
        info = new_info

    return obs, reward, terminated, truncated, info
||||
|
||||
|
||||
class PassiveEnvChecker(
    gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
    """A passive wrapper that surrounds ``step``, ``reset`` and ``render`` to validate them against Gymnasium's API.

    This wrapper is automatically applied during make and can be disabled with `disable_env_checker`.
    No vector version of the wrapper exists.

    Example:
        >>> import gymnasium as gym
        >>> env = gym.make("CartPole-v1")
        >>> env
        <TimeLimit<OrderEnforcing<PassiveEnvChecker<CartPoleEnv<CartPole-v1>>>>>
        >>> env = gym.make("CartPole-v1", disable_env_checker=True)
        >>> env
        <TimeLimit<OrderEnforcing<CartPoleEnv<CartPole-v1>>>>

    Change logs:
     * v0.24.1 - Initially added however broken in several ways
     * v0.25.0 - Bugs were all fixed
     * v0.29.0 - Removed warnings for infinite bounds for Box observation and action spaces and irregular bound shapes
    """

    def __init__(self, env: gym.Env[ObsType, ActType]):
        """Initialises the wrapper and eagerly validates the environment's action and observation spaces."""
        gym.utils.RecordConstructorArgs.__init__(self)
        gym.Wrapper.__init__(self, env)

        _docs_url = "https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/"
        assert hasattr(env, "action_space"), (
            "The environment must specify an action space. " + _docs_url
        )
        check_action_space(env.action_space)
        assert hasattr(env, "observation_space"), (
            "The environment must specify an observation space. " + _docs_url
        )
        check_observation_space(env.observation_space)

        # Each API function is run through its full passive check only once,
        # on the first invocation; subsequent calls go straight to the env.
        self.checked_reset: bool = False
        self.checked_step: bool = False
        self.checked_render: bool = False
        self.close_called: bool = False

    def step(
        self, action: ActType
    ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
        """Steps through the environment; the first call is routed through `env_step_passive_checker`."""
        if self.checked_step:
            return self.env.step(action)
        self.checked_step = True
        return env_step_passive_checker(self.env, action)

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[ObsType, dict[str, Any]]:
        """Resets the environment; the first call is routed through `env_reset_passive_checker`."""
        if self.checked_reset:
            return self.env.reset(seed=seed, options=options)
        self.checked_reset = True
        return env_reset_passive_checker(self.env, seed=seed, options=options)

    def render(self) -> RenderFrame | list[RenderFrame] | None:
        """Renders the environment; the first call is routed through `env_render_passive_checker`."""
        if self.checked_render:
            return self.env.render()
        self.checked_render = True
        return env_render_passive_checker(self.env)

    @property
    def spec(self) -> EnvSpec | None:
        """The environment spec with `disable_env_checker` forced to ``False`` (cached after first success)."""
        if self._cached_spec is None:
            env_spec = self.env.spec
            if env_spec is not None:
                # Deep-copy so the registry's canonical spec is left untouched.
                env_spec = deepcopy(env_spec)
                env_spec.disable_env_checker = False
            self._cached_spec = env_spec
        return self._cached_spec

    def close(self):
        """Closes the environment, warning if a repeated ``close`` raises an exception."""
        if not self.close_called:
            self.close_called = True
            return self.env.close()
        # Closing an already-closed environment should be a no-op; surface a
        # warning (then re-raise) if the wrapped env does not tolerate it.
        try:
            return self.env.close()
        except Exception as e:
            logger.warn(
                "Calling `env.close()` on the closed environment should be allowed, but it raised the following exception."
            )
            raise e
|
||||
|
||||
class OrderEnforcing(
    gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
    """Produces an error if ``step`` or ``render`` is called before an initial ``reset``.

    No vector version of the wrapper exists.

    Example:
        >>> import gymnasium as gym
        >>> from gymnasium.wrappers import OrderEnforcing
        >>> env = gym.make("CartPole-v1", render_mode="human")
        >>> env = OrderEnforcing(env)
        >>> env.step(0)
        Traceback (most recent call last):
            ...
        gymnasium.error.ResetNeeded: Cannot call env.step() before calling env.reset()
        >>> env.render()
        Traceback (most recent call last):
            ...
        gymnasium.error.ResetNeeded: Cannot call `env.render()` before calling `env.reset()`, if this is an intended action, set `disable_render_order_enforcing=True` on the OrderEnforcer wrapper.
        >>> _ = env.reset()
        >>> env.render()
        >>> _ = env.step(0)
        >>> env.close()

    Change logs:
     * v0.22.0 - Initially added
     * v0.24.0 - Added order enforcing for the render function
    """

    def __init__(
        self,
        env: gym.Env[ObsType, ActType],
        disable_render_order_enforcing: bool = False,
    ):
        """A wrapper that raises :class:`ResetNeeded` when :meth:`step` precedes the first :meth:`reset`.

        Args:
            env: The environment to wrap
            disable_render_order_enforcing: If to disable render order enforcing
        """
        gym.utils.RecordConstructorArgs.__init__(
            self, disable_render_order_enforcing=disable_render_order_enforcing
        )
        gym.Wrapper.__init__(self, env)

        self._disable_render_order_enforcing: bool = disable_render_order_enforcing
        self._has_reset: bool = False

    def step(self, action: ActType) -> tuple[ObsType, SupportsFloat, bool, bool, dict]:
        """Steps through the environment, requiring that :meth:`reset` was called first."""
        if self._has_reset:
            return super().step(action)
        raise ResetNeeded("Cannot call env.step() before calling env.reset()")

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[ObsType, dict[str, Any]]:
        """Resets the environment, after which ``step`` and ``render`` become legal."""
        self._has_reset = True
        return super().reset(seed=seed, options=options)

    def render(self) -> RenderFrame | list[RenderFrame] | None:
        """Renders the environment, enforcing a prior ``reset`` unless enforcement is disabled."""
        if self._has_reset or self._disable_render_order_enforcing:
            return super().render()
        raise ResetNeeded(
            "Cannot call `env.render()` before calling `env.reset()`, if this is an intended action, "
            "set `disable_render_order_enforcing=True` on the OrderEnforcer wrapper."
        )

    @property
    def has_reset(self):
        """Whether :meth:`reset` has been called at least once."""
        return self._has_reset

    @property
    def spec(self) -> EnvSpec | None:
        """The environment spec with `order_enforce` forced to ``True`` (cached after first success)."""
        if self._cached_spec is None:
            env_spec = self.env.spec
            if env_spec is not None:
                # Deep-copy so the registry's canonical spec is left untouched.
                env_spec = deepcopy(env_spec)
                env_spec.order_enforce = True
            self._cached_spec = env_spec
        return self._cached_spec
|
||||
|
||||
class RecordEpisodeStatistics(
|
||||
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
|
||||
):
|
||||
"""This wrapper will keep track of cumulative rewards and episode lengths.
|
||||
|
||||
At the end of an episode, the statistics of the episode will be added to ``info``
|
||||
using the key ``episode``. If using a vectorized environment also the key
|
||||
``_episode`` is used which indicates whether the env at the respective index has
|
||||
the episode statistics.
|
||||
A vector version of the wrapper exists, :class:`gymnasium.wrappers.vector.RecordEpisodeStatistics`.
|
||||
|
||||
After the completion of an episode, ``info`` will look like this::
|
||||
|
||||
>>> info = {
|
||||
... "episode": {
|
||||
... "r": "<cumulative reward>",
|
||||
... "l": "<episode length>",
|
||||
... "t": "<elapsed time since beginning of episode>"
|
||||
... },
|
||||
... }
|
||||
|
||||
For a vectorized environments the output will be in the form of::
|
||||
|
||||
>>> infos = {
|
||||
... "final_observation": "<array of length num-envs>",
|
||||
... "_final_observation": "<boolean array of length num-envs>",
|
||||
... "final_info": "<array of length num-envs>",
|
||||
... "_final_info": "<boolean array of length num-envs>",
|
||||
... "episode": {
|
||||
... "r": "<array of cumulative reward>",
|
||||
... "l": "<array of episode length>",
|
||||
... "t": "<array of elapsed time since beginning of episode>"
|
||||
... },
|
||||
... "_episode": "<boolean array of length num-envs>"
|
||||
... }
|
||||
|
||||
Moreover, the most recent rewards and episode lengths are stored in buffers that can be accessed via
|
||||
:attr:`wrapped_env.return_queue` and :attr:`wrapped_env.length_queue` respectively.
|
||||
|
||||
Attributes:
|
||||
* time_queue: The time length of the last ``deque_size``-many episodes
|
||||
* return_queue: The cumulative rewards of the last ``deque_size``-many episodes
|
||||
* length_queue: The lengths of the last ``deque_size``-many episodes
|
||||
|
||||
Change logs:
|
||||
* v0.15.4 - Initially added
|
||||
* v1.0.0 - Removed vector environment support for `wrappers.vector.RecordEpisodeStatistics` and add attribute ``time_queue``
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    env: gym.Env[ObsType, ActType],
    buffer_length: int | None = 100,
    stats_key: str = "episode",
):
    """Track cumulative rewards and episode lengths of the wrapped environment.

    Args:
        env (Env): The environment to apply the wrapper
        buffer_length: The size of the buffers :attr:`return_queue` and :attr:`length_queue`
        stats_key: The info key for the episode statistics
    """
    gym.utils.RecordConstructorArgs.__init__(self)
    gym.Wrapper.__init__(self, env)

    self._stats_key = stats_key

    # Accumulators for the episode currently in progress; a start time of -1
    # marks "no episode started yet" (reset() stamps the real start time).
    self.episode_count = 0
    self.episode_start_time: float = -1
    self.episode_returns: float = 0.0
    self.episode_lengths: int = 0

    # Rolling history over the most recent ``buffer_length`` completed episodes.
    self.time_queue: deque[float] = deque(maxlen=buffer_length)
    self.return_queue: deque[float] = deque(maxlen=buffer_length)
    self.length_queue: deque[int] = deque(maxlen=buffer_length)
|
||||
def step(
    self, action: ActType
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
    """Steps through the environment, recording episode statistics when the episode ends."""
    obs, reward, terminated, truncated, info = super().step(action)

    self.episode_lengths += 1
    self.episode_returns += reward

    if terminated or truncated:
        # The wrapped env must not already publish stats under the same key.
        assert self._stats_key not in info

        elapsed = round(time.perf_counter() - self.episode_start_time, 6)
        info[self._stats_key] = {
            "r": self.episode_returns,
            "l": self.episode_lengths,
            "t": elapsed,
        }

        self.time_queue.append(elapsed)
        self.return_queue.append(self.episode_returns)
        self.length_queue.append(self.episode_lengths)

        self.episode_count += 1

    return obs, reward, terminated, truncated, info
|
||||
def reset(
    self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[ObsType, dict[str, Any]]:
    """Resets the environment and zeroes the per-episode reward/length accumulators."""
    observation, reset_info = super().reset(seed=seed, options=options)

    # Start timing and counting a fresh episode.
    self.episode_start_time = time.perf_counter()
    self.episode_returns = 0.0
    self.episode_lengths = 0

    return observation, reset_info
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user