Fix conversion wrapper module registration and docs (#1380)

This commit is contained in:
Martin Schuck
2025-05-12 20:23:44 +02:00
committed by GitHub
parent 4ce718b615
commit f5ea97f863
9 changed files with 23 additions and 14 deletions

View File

@@ -74,6 +74,7 @@ title: Vector Wrappers
## Implemented Data Conversion wrappers
```{eval-rst}
.. autoclass:: gymnasium.wrappers.vector.ArrayConversion
.. autoclass:: gymnasium.wrappers.vector.JaxToNumpy
.. autoclass:: gymnasium.wrappers.vector.JaxToTorch
.. autoclass:: gymnasium.wrappers.vector.NumpyToTorch

View File

@@ -12,6 +12,8 @@ wrapper in the page on the wrapper type
* - Name
- Description
* - :class:`ArrayConversion`
- Wraps an environment based on any Array API compatible framework, e.g. ``numpy``, ``torch``, ``jax.numpy``, such that it can be interacted with any other Array API compatible framework.
* - :class:`AtariPreprocessing`
- Implements the common preprocessing techniques for Atari environments (excluding frame stacking).
* - :class:`Autoreset`
@@ -34,8 +36,6 @@ wrapper in the page on the wrapper type
- Converts an image observation computed by ``reset`` and ``step`` from RGB to Grayscale.
* - :class:`HumanRendering`
- Allows human like rendering for environments that support "rgb_array" rendering.
* - :class:`ArrayConversion`
- Wraps an environment based on any Array API compatible framework, e.g. torch, jax, numpy, such that it can be interacted with any other Array API compatible framework.
* - :class:`JaxToNumpy`
- Wraps a Jax-based environment such that it can be interacted with NumPy arrays.
* - :class:`JaxToTorch`

View File

@@ -134,6 +134,7 @@ __all__ = [
"RecordVideo",
"HumanRendering",
# --- Conversion ---
"ArrayConversion",
"JaxToNumpy",
"JaxToTorch",
"NumpyToTorch",
@@ -143,6 +144,7 @@ __all__ = [
# to avoid `import jax` or `import torch` on `import gymnasium`.
_wrapper_to_class = {
# data converters
"ArrayConversion": "array_conversion",
"JaxToNumpy": "jax_to_numpy",
"JaxToTorch": "jax_to_torch",
"NumpyToTorch": "numpy_to_torch",

View File

@@ -137,9 +137,12 @@ def _none_array_conversion(
class ArrayConversion(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""Wraps an Array API compatible environment so that it can be interacted with a specific Array API framework.
"""Wraps an Array API compatible environment so that it can be interacted with with another Array API framework.
Popular Array API frameworks include ``numpy``, ``torch``, ``jax.numpy``, ``cupy`` etc. With this wrapper, you can convert outputs from your environment to
any of these frameworks. Conversely, actions are automatically mapped back to the environment framework, if possible without moving the
data or device transfers.
Actions must be provided as Array API compatible arrays and observations will be returned as Arrays of the specified xp module.
A vector version of the wrapper exists, :class:`gymnasium.wrappers.vector.ArrayConversion`.
Example:

View File

@@ -67,6 +67,7 @@ __all__ = [
# "RecordVideo",
"HumanRendering",
# --- Conversion ---
"ArrayConversion",
"JaxToNumpy",
"JaxToTorch",
"NumpyToTorch",
@@ -77,6 +78,7 @@ __all__ = [
# to avoid `import jax` or `import torch` on `import gymnasium`.
_wrapper_to_class = {
# data converters
"ArrayConversion": "array_conversion",
"JaxToNumpy": "jax_to_numpy",
"JaxToTorch": "jax_to_torch",
"NumpyToTorch": "numpy_to_torch",

View File

@@ -22,11 +22,12 @@ __all__ = ["ArrayConversion"]
class ArrayConversion(VectorWrapper, gym.utils.RecordConstructorArgs):
"""Wraps a vector environment returning Array API compatible arrays so that it can be interacted with through a specific framework.
Notes:
A vectorized version of ``gymnasium.wrappers.ArrayConversion``
Popular Array API frameworks include ``numpy``, ``torch``, ``jax.numpy``, ``cupy`` etc. With this wrapper, you can convert outputs from your environment to
any of these frameworks. Conversely, actions are automatically mapped back to the environment framework, if possible without moving the
data or device transfers.
Actions must be provided as Array API compatible arrays and observations, rewards, terminations and truncations will be returned in the desired framework.
xp here is a module that is compatible with the Array API standard, e.g. ``numpy``, ``jax`` etc.
Notes:
A vectorized version of :class:`gymnasium.wrappers.ArrayConversion`
Example:
>>> import gymnasium as gym # doctest: +SKIP

View File

@@ -17,7 +17,7 @@ class JaxToNumpy(ArrayConversion):
"""Wraps a jax vector environment so that it can be interacted with through numpy arrays.
Notes:
A vectorized version of ``gymnasium.wrappers.JaxToNumpy``
A vectorized version of :class:`gymnasium.wrappers.JaxToNumpy`
Actions must be provided as numpy arrays and observations, rewards, terminations and truncations will be returned as numpy arrays.

View File

@@ -15,8 +15,8 @@ array_api_extra = pytest.importorskip("array_api_extra")
from array_api_compat import array_namespace, is_array_api_obj # noqa: E402
from gymnasium.wrappers import ArrayConversion # noqa: E402
from gymnasium.wrappers.array_conversion import ( # noqa: E402
ArrayConversion,
array_conversion,
module_namespace,
)

View File

@@ -14,10 +14,10 @@ array_api_compat = pytest.importorskip("array_api_compat")
from array_api_compat import array_namespace # noqa: E402
from gymnasium.wrappers.array_conversion import module_namespace # noqa: E402
from gymnasium.wrappers.vector.array_conversion import ArrayConversion # noqa: E402
from gymnasium.wrappers.vector.jax_to_numpy import JaxToNumpy # noqa: E402
from gymnasium.wrappers.vector.jax_to_torch import JaxToTorch # noqa: E402
from gymnasium.wrappers.vector.numpy_to_torch import NumpyToTorch # noqa: E402
from gymnasium.wrappers.vector import ArrayConversion # noqa: E402
from gymnasium.wrappers.vector import JaxToNumpy # noqa: E402
from gymnasium.wrappers.vector import JaxToTorch # noqa: E402
from gymnasium.wrappers.vector import NumpyToTorch # noqa: E402
# Define available modules