mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-01 14:10:30 +00:00
Fix conversion wrapper module registration and docs (#1380)
This commit is contained in:
@@ -74,6 +74,7 @@ title: Vector Wrappers
|
|||||||
## Implemented Data Conversion wrappers
|
## Implemented Data Conversion wrappers
|
||||||
|
|
||||||
```{eval-rst}
|
```{eval-rst}
|
||||||
|
.. autoclass:: gymnasium.wrappers.vector.ArrayConversion
|
||||||
.. autoclass:: gymnasium.wrappers.vector.JaxToNumpy
|
.. autoclass:: gymnasium.wrappers.vector.JaxToNumpy
|
||||||
.. autoclass:: gymnasium.wrappers.vector.JaxToTorch
|
.. autoclass:: gymnasium.wrappers.vector.JaxToTorch
|
||||||
.. autoclass:: gymnasium.wrappers.vector.NumpyToTorch
|
.. autoclass:: gymnasium.wrappers.vector.NumpyToTorch
|
||||||
|
@@ -12,6 +12,8 @@ wrapper in the page on the wrapper type
|
|||||||
|
|
||||||
* - Name
|
* - Name
|
||||||
- Description
|
- Description
|
||||||
|
* - :class:`ArrayConversion`
|
||||||
|
- Wraps an environment based on any Array API compatible framework, e.g. ``numpy``, ``torch``, ``jax.numpy``, such that it can be interacted with any other Array API compatible framework.
|
||||||
* - :class:`AtariPreprocessing`
|
* - :class:`AtariPreprocessing`
|
||||||
- Implements the common preprocessing techniques for Atari environments (excluding frame stacking).
|
- Implements the common preprocessing techniques for Atari environments (excluding frame stacking).
|
||||||
* - :class:`Autoreset`
|
* - :class:`Autoreset`
|
||||||
@@ -34,8 +36,6 @@ wrapper in the page on the wrapper type
|
|||||||
- Converts an image observation computed by ``reset`` and ``step`` from RGB to Grayscale.
|
- Converts an image observation computed by ``reset`` and ``step`` from RGB to Grayscale.
|
||||||
* - :class:`HumanRendering`
|
* - :class:`HumanRendering`
|
||||||
- Allows human like rendering for environments that support "rgb_array" rendering.
|
- Allows human like rendering for environments that support "rgb_array" rendering.
|
||||||
* - :class:`ArrayConversion`
|
|
||||||
- Wraps an environment based on any Array API compatible framework, e.g. torch, jax, numpy, such that it can be interacted with any other Array API compatible framework.
|
|
||||||
* - :class:`JaxToNumpy`
|
* - :class:`JaxToNumpy`
|
||||||
- Wraps a Jax-based environment such that it can be interacted with NumPy arrays.
|
- Wraps a Jax-based environment such that it can be interacted with NumPy arrays.
|
||||||
* - :class:`JaxToTorch`
|
* - :class:`JaxToTorch`
|
||||||
|
@@ -134,6 +134,7 @@ __all__ = [
|
|||||||
"RecordVideo",
|
"RecordVideo",
|
||||||
"HumanRendering",
|
"HumanRendering",
|
||||||
# --- Conversion ---
|
# --- Conversion ---
|
||||||
|
"ArrayConversion",
|
||||||
"JaxToNumpy",
|
"JaxToNumpy",
|
||||||
"JaxToTorch",
|
"JaxToTorch",
|
||||||
"NumpyToTorch",
|
"NumpyToTorch",
|
||||||
@@ -143,6 +144,7 @@ __all__ = [
|
|||||||
# to avoid `import jax` or `import torch` on `import gymnasium`.
|
# to avoid `import jax` or `import torch` on `import gymnasium`.
|
||||||
_wrapper_to_class = {
|
_wrapper_to_class = {
|
||||||
# data converters
|
# data converters
|
||||||
|
"ArrayConversion": "array_conversion",
|
||||||
"JaxToNumpy": "jax_to_numpy",
|
"JaxToNumpy": "jax_to_numpy",
|
||||||
"JaxToTorch": "jax_to_torch",
|
"JaxToTorch": "jax_to_torch",
|
||||||
"NumpyToTorch": "numpy_to_torch",
|
"NumpyToTorch": "numpy_to_torch",
|
||||||
|
@@ -137,9 +137,12 @@ def _none_array_conversion(
|
|||||||
|
|
||||||
|
|
||||||
class ArrayConversion(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
class ArrayConversion(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||||
"""Wraps an Array API compatible environment so that it can be interacted with a specific Array API framework.
|
"""Wraps an Array API compatible environment so that it can be interacted with with another Array API framework.
|
||||||
|
|
||||||
|
Popular Array API frameworks include ``numpy``, ``torch``, ``jax.numpy``, ``cupy`` etc. With this wrapper, you can convert outputs from your environment to
|
||||||
|
any of these frameworks. Conversely, actions are automatically mapped back to the environment framework, if possible without moving the
|
||||||
|
data or device transfers.
|
||||||
|
|
||||||
Actions must be provided as Array API compatible arrays and observations will be returned as Arrays of the specified xp module.
|
|
||||||
A vector version of the wrapper exists, :class:`gymnasium.wrappers.vector.ArrayConversion`.
|
A vector version of the wrapper exists, :class:`gymnasium.wrappers.vector.ArrayConversion`.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
@@ -67,6 +67,7 @@ __all__ = [
|
|||||||
# "RecordVideo",
|
# "RecordVideo",
|
||||||
"HumanRendering",
|
"HumanRendering",
|
||||||
# --- Conversion ---
|
# --- Conversion ---
|
||||||
|
"ArrayConversion",
|
||||||
"JaxToNumpy",
|
"JaxToNumpy",
|
||||||
"JaxToTorch",
|
"JaxToTorch",
|
||||||
"NumpyToTorch",
|
"NumpyToTorch",
|
||||||
@@ -77,6 +78,7 @@ __all__ = [
|
|||||||
# to avoid `import jax` or `import torch` on `import gymnasium`.
|
# to avoid `import jax` or `import torch` on `import gymnasium`.
|
||||||
_wrapper_to_class = {
|
_wrapper_to_class = {
|
||||||
# data converters
|
# data converters
|
||||||
|
"ArrayConversion": "array_conversion",
|
||||||
"JaxToNumpy": "jax_to_numpy",
|
"JaxToNumpy": "jax_to_numpy",
|
||||||
"JaxToTorch": "jax_to_torch",
|
"JaxToTorch": "jax_to_torch",
|
||||||
"NumpyToTorch": "numpy_to_torch",
|
"NumpyToTorch": "numpy_to_torch",
|
||||||
|
@@ -22,11 +22,12 @@ __all__ = ["ArrayConversion"]
|
|||||||
class ArrayConversion(VectorWrapper, gym.utils.RecordConstructorArgs):
|
class ArrayConversion(VectorWrapper, gym.utils.RecordConstructorArgs):
|
||||||
"""Wraps a vector environment returning Array API compatible arrays so that it can be interacted with through a specific framework.
|
"""Wraps a vector environment returning Array API compatible arrays so that it can be interacted with through a specific framework.
|
||||||
|
|
||||||
Notes:
|
Popular Array API frameworks include ``numpy``, ``torch``, ``jax.numpy``, ``cupy`` etc. With this wrapper, you can convert outputs from your environment to
|
||||||
A vectorized version of ``gymnasium.wrappers.ArrayConversion``
|
any of these frameworks. Conversely, actions are automatically mapped back to the environment framework, if possible without moving the
|
||||||
|
data or device transfers.
|
||||||
|
|
||||||
Actions must be provided as Array API compatible arrays and observations, rewards, terminations and truncations will be returned in the desired framework.
|
Notes:
|
||||||
xp here is a module that is compatible with the Array API standard, e.g. ``numpy``, ``jax`` etc.
|
A vectorized version of :class:`gymnasium.wrappers.ArrayConversion`
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> import gymnasium as gym # doctest: +SKIP
|
>>> import gymnasium as gym # doctest: +SKIP
|
||||||
|
@@ -17,7 +17,7 @@ class JaxToNumpy(ArrayConversion):
|
|||||||
"""Wraps a jax vector environment so that it can be interacted with through numpy arrays.
|
"""Wraps a jax vector environment so that it can be interacted with through numpy arrays.
|
||||||
|
|
||||||
Notes:
|
Notes:
|
||||||
A vectorized version of ``gymnasium.wrappers.JaxToNumpy``
|
A vectorized version of :class:`gymnasium.wrappers.JaxToNumpy`
|
||||||
|
|
||||||
Actions must be provided as numpy arrays and observations, rewards, terminations and truncations will be returned as numpy arrays.
|
Actions must be provided as numpy arrays and observations, rewards, terminations and truncations will be returned as numpy arrays.
|
||||||
|
|
||||||
|
@@ -15,8 +15,8 @@ array_api_extra = pytest.importorskip("array_api_extra")
|
|||||||
|
|
||||||
from array_api_compat import array_namespace, is_array_api_obj # noqa: E402
|
from array_api_compat import array_namespace, is_array_api_obj # noqa: E402
|
||||||
|
|
||||||
|
from gymnasium.wrappers import ArrayConversion # noqa: E402
|
||||||
from gymnasium.wrappers.array_conversion import ( # noqa: E402
|
from gymnasium.wrappers.array_conversion import ( # noqa: E402
|
||||||
ArrayConversion,
|
|
||||||
array_conversion,
|
array_conversion,
|
||||||
module_namespace,
|
module_namespace,
|
||||||
)
|
)
|
||||||
|
@@ -14,10 +14,10 @@ array_api_compat = pytest.importorskip("array_api_compat")
|
|||||||
from array_api_compat import array_namespace # noqa: E402
|
from array_api_compat import array_namespace # noqa: E402
|
||||||
|
|
||||||
from gymnasium.wrappers.array_conversion import module_namespace # noqa: E402
|
from gymnasium.wrappers.array_conversion import module_namespace # noqa: E402
|
||||||
from gymnasium.wrappers.vector.array_conversion import ArrayConversion # noqa: E402
|
from gymnasium.wrappers.vector import ArrayConversion # noqa: E402
|
||||||
from gymnasium.wrappers.vector.jax_to_numpy import JaxToNumpy # noqa: E402
|
from gymnasium.wrappers.vector import JaxToNumpy # noqa: E402
|
||||||
from gymnasium.wrappers.vector.jax_to_torch import JaxToTorch # noqa: E402
|
from gymnasium.wrappers.vector import JaxToTorch # noqa: E402
|
||||||
from gymnasium.wrappers.vector.numpy_to_torch import NumpyToTorch # noqa: E402
|
from gymnasium.wrappers.vector import NumpyToTorch # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
# Define available modules
|
# Define available modules
|
||||||
|
Reference in New Issue
Block a user