diff --git a/docs/api/vector/wrappers.md b/docs/api/vector/wrappers.md index 183cd7e86..cf7bc2b6f 100644 --- a/docs/api/vector/wrappers.md +++ b/docs/api/vector/wrappers.md @@ -74,6 +74,7 @@ title: Vector Wrappers ## Implemented Data Conversion wrappers ```{eval-rst} +.. autoclass:: gymnasium.wrappers.vector.ArrayConversion .. autoclass:: gymnasium.wrappers.vector.JaxToNumpy .. autoclass:: gymnasium.wrappers.vector.JaxToTorch .. autoclass:: gymnasium.wrappers.vector.NumpyToTorch diff --git a/docs/api/wrappers/table.md b/docs/api/wrappers/table.md index 89cb9acb9..49404a4da 100644 --- a/docs/api/wrappers/table.md +++ b/docs/api/wrappers/table.md @@ -12,6 +12,8 @@ wrapper in the page on the wrapper type * - Name - Description + * - :class:`ArrayConversion` + - Wraps an environment based on any Array API compatible framework, e.g. ``numpy``, ``torch``, ``jax.numpy``, such that it can be interacted with any other Array API compatible framework. * - :class:`AtariPreprocessing` - Implements the common preprocessing techniques for Atari environments (excluding frame stacking). * - :class:`Autoreset` @@ -34,8 +36,6 @@ wrapper in the page on the wrapper type - Converts an image observation computed by ``reset`` and ``step`` from RGB to Grayscale. * - :class:`HumanRendering` - Allows human like rendering for environments that support "rgb_array" rendering. - * - :class:`ArrayConversion` - - Wraps an environment based on any Array API compatible framework, e.g. torch, jax, numpy, such that it can be interacted with any other Array API compatible framework. * - :class:`JaxToNumpy` - Wraps a Jax-based environment such that it can be interacted with NumPy arrays. * - :class:`JaxToTorch` diff --git a/gymnasium/wrappers/__init__.py b/gymnasium/wrappers/__init__.py index 9876a3f73..1244e183e 100644 --- a/gymnasium/wrappers/__init__.py +++ b/gymnasium/wrappers/__init__.py @@ -134,6 +134,7 @@ __all__ = [ "RecordVideo", "HumanRendering", # --- Conversion --- + "ArrayConversion", "JaxToNumpy", "JaxToTorch", "NumpyToTorch", @@ -143,6 +144,7 @@ __all__ = [ # to avoid `import jax` or `import torch` on `import gymnasium`. _wrapper_to_class = { # data converters + "ArrayConversion": "array_conversion", "JaxToNumpy": "jax_to_numpy", "JaxToTorch": "jax_to_torch", "NumpyToTorch": "numpy_to_torch", diff --git a/gymnasium/wrappers/array_conversion.py b/gymnasium/wrappers/array_conversion.py index f24842d8e..b974423aa 100644 --- a/gymnasium/wrappers/array_conversion.py +++ b/gymnasium/wrappers/array_conversion.py @@ -137,9 +137,12 @@ def _none_array_conversion( class ArrayConversion(gym.Wrapper, gym.utils.RecordConstructorArgs): - """Wraps an Array API compatible environment so that it can be interacted with a specific Array API framework. + """Wraps an Array API compatible environment so that it can be interacted with with another Array API framework. + + Popular Array API frameworks include ``numpy``, ``torch``, ``jax.numpy``, ``cupy`` etc. With this wrapper, you can convert outputs from your environment to + any of these frameworks. Conversely, actions are automatically mapped back to the environment framework, if possible without moving the + data or device transfers. - Actions must be provided as Array API compatible arrays and observations will be returned as Arrays of the specified xp module. A vector version of the wrapper exists, :class:`gymnasium.wrappers.vector.ArrayConversion`. Example: diff --git a/gymnasium/wrappers/vector/__init__.py b/gymnasium/wrappers/vector/__init__.py index 9ae3ab24f..da51dcfcc 100644 --- a/gymnasium/wrappers/vector/__init__.py +++ b/gymnasium/wrappers/vector/__init__.py @@ -67,6 +67,7 @@ __all__ = [ # "RecordVideo", "HumanRendering", # --- Conversion --- + "ArrayConversion", "JaxToNumpy", "JaxToTorch", "NumpyToTorch", @@ -77,6 +78,7 @@ __all__ = [ # to avoid `import jax` or `import torch` on `import gymnasium`. _wrapper_to_class = { # data converters + "ArrayConversion": "array_conversion", "JaxToNumpy": "jax_to_numpy", "JaxToTorch": "jax_to_torch", "NumpyToTorch": "numpy_to_torch", diff --git a/gymnasium/wrappers/vector/array_conversion.py b/gymnasium/wrappers/vector/array_conversion.py index a48c8a107..f03515897 100644 --- a/gymnasium/wrappers/vector/array_conversion.py +++ b/gymnasium/wrappers/vector/array_conversion.py @@ -22,11 +22,12 @@ __all__ = ["ArrayConversion"] class ArrayConversion(VectorWrapper, gym.utils.RecordConstructorArgs): """Wraps a vector environment returning Array API compatible arrays so that it can be interacted with through a specific framework. - Notes: - A vectorized version of ``gymnasium.wrappers.ArrayConversion`` + Popular Array API frameworks include ``numpy``, ``torch``, ``jax.numpy``, ``cupy`` etc. With this wrapper, you can convert outputs from your environment to + any of these frameworks. Conversely, actions are automatically mapped back to the environment framework, if possible without moving the + data or device transfers. - Actions must be provided as Array API compatible arrays and observations, rewards, terminations and truncations will be returned in the desired framework. - xp here is a module that is compatible with the Array API standard, e.g. ``numpy``, ``jax`` etc. + Notes: + A vectorized version of :class:`gymnasium.wrappers.ArrayConversion` Example: >>> import gymnasium as gym # doctest: +SKIP diff --git a/gymnasium/wrappers/vector/jax_to_numpy.py b/gymnasium/wrappers/vector/jax_to_numpy.py index 4f1bc92d7..fc3a5f38e 100644 --- a/gymnasium/wrappers/vector/jax_to_numpy.py +++ b/gymnasium/wrappers/vector/jax_to_numpy.py @@ -17,7 +17,7 @@ class JaxToNumpy(ArrayConversion): """Wraps a jax vector environment so that it can be interacted with through numpy arrays. Notes: - A vectorized version of ``gymnasium.wrappers.JaxToNumpy`` + A vectorized version of :class:`gymnasium.wrappers.JaxToNumpy` Actions must be provided as numpy arrays and observations, rewards, terminations and truncations will be returned as numpy arrays. diff --git a/tests/wrappers/test_array_conversion.py b/tests/wrappers/test_array_conversion.py index a2091de49..7af347c2f 100644 --- a/tests/wrappers/test_array_conversion.py +++ b/tests/wrappers/test_array_conversion.py @@ -15,8 +15,8 @@ array_api_extra = pytest.importorskip("array_api_extra") from array_api_compat import array_namespace, is_array_api_obj # noqa: E402 +from gymnasium.wrappers import ArrayConversion # noqa: E402 from gymnasium.wrappers.array_conversion import ( # noqa: E402 - ArrayConversion, array_conversion, module_namespace, ) diff --git a/tests/wrappers/vector/test_array_conversion.py b/tests/wrappers/vector/test_array_conversion.py index 6288946cd..a23faa2d1 100644 --- a/tests/wrappers/vector/test_array_conversion.py +++ b/tests/wrappers/vector/test_array_conversion.py @@ -14,10 +14,10 @@ array_api_compat = pytest.importorskip("array_api_compat") from array_api_compat import array_namespace # noqa: E402 from gymnasium.wrappers.array_conversion import module_namespace # noqa: E402 -from gymnasium.wrappers.vector.array_conversion import ArrayConversion # noqa: E402 -from gymnasium.wrappers.vector.jax_to_numpy import JaxToNumpy # noqa: E402 -from gymnasium.wrappers.vector.jax_to_torch import JaxToTorch # noqa: E402 -from gymnasium.wrappers.vector.numpy_to_torch import NumpyToTorch # noqa: E402 +from gymnasium.wrappers.vector import ArrayConversion # noqa: E402 +from gymnasium.wrappers.vector import JaxToNumpy # noqa: E402 +from gymnasium.wrappers.vector import JaxToTorch # noqa: E402 +from gymnasium.wrappers.vector import NumpyToTorch # noqa: E402 # Define available modules