mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-07-30 21:34:30 +00:00
Fix conversion wrapper module registration and docs (#1380)
This commit is contained in:
@@ -74,6 +74,7 @@ title: Vector Wrappers
|
||||
## Implemented Data Conversion wrappers
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.wrappers.vector.ArrayConversion
|
||||
.. autoclass:: gymnasium.wrappers.vector.JaxToNumpy
|
||||
.. autoclass:: gymnasium.wrappers.vector.JaxToTorch
|
||||
.. autoclass:: gymnasium.wrappers.vector.NumpyToTorch
|
||||
|
@@ -12,6 +12,8 @@ wrapper in the page on the wrapper type
|
||||
|
||||
* - Name
|
||||
- Description
|
||||
* - :class:`ArrayConversion`
|
||||
- Wraps an environment based on any Array API compatible framework, e.g. ``numpy``, ``torch``, ``jax.numpy``, such that it can be interacted with any other Array API compatible framework.
|
||||
* - :class:`AtariPreprocessing`
|
||||
- Implements the common preprocessing techniques for Atari environments (excluding frame stacking).
|
||||
* - :class:`Autoreset`
|
||||
@@ -34,8 +36,6 @@ wrapper in the page on the wrapper type
|
||||
- Converts an image observation computed by ``reset`` and ``step`` from RGB to Grayscale.
|
||||
* - :class:`HumanRendering`
|
||||
- Allows human like rendering for environments that support "rgb_array" rendering.
|
||||
* - :class:`ArrayConversion`
|
||||
- Wraps an environment based on any Array API compatible framework, e.g. torch, jax, numpy, such that it can be interacted with any other Array API compatible framework.
|
||||
* - :class:`JaxToNumpy`
|
||||
- Wraps a Jax-based environment such that it can be interacted with NumPy arrays.
|
||||
* - :class:`JaxToTorch`
|
||||
|
@@ -134,6 +134,7 @@ __all__ = [
|
||||
"RecordVideo",
|
||||
"HumanRendering",
|
||||
# --- Conversion ---
|
||||
"ArrayConversion",
|
||||
"JaxToNumpy",
|
||||
"JaxToTorch",
|
||||
"NumpyToTorch",
|
||||
@@ -143,6 +144,7 @@ __all__ = [
|
||||
# to avoid `import jax` or `import torch` on `import gymnasium`.
|
||||
_wrapper_to_class = {
|
||||
# data converters
|
||||
"ArrayConversion": "array_conversion",
|
||||
"JaxToNumpy": "jax_to_numpy",
|
||||
"JaxToTorch": "jax_to_torch",
|
||||
"NumpyToTorch": "numpy_to_torch",
|
||||
|
@@ -137,9 +137,12 @@ def _none_array_conversion(
|
||||
|
||||
|
||||
class ArrayConversion(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||
"""Wraps an Array API compatible environment so that it can be interacted with a specific Array API framework.
|
||||
"""Wraps an Array API compatible environment so that it can be interacted with with another Array API framework.
|
||||
|
||||
Popular Array API frameworks include ``numpy``, ``torch``, ``jax.numpy``, ``cupy`` etc. With this wrapper, you can convert outputs from your environment to
|
||||
any of these frameworks. Conversely, actions are automatically mapped back to the environment framework, if possible without moving the
|
||||
data or device transfers.
|
||||
|
||||
Actions must be provided as Array API compatible arrays and observations will be returned as Arrays of the specified xp module.
|
||||
A vector version of the wrapper exists, :class:`gymnasium.wrappers.vector.ArrayConversion`.
|
||||
|
||||
Example:
|
||||
|
@@ -67,6 +67,7 @@ __all__ = [
|
||||
# "RecordVideo",
|
||||
"HumanRendering",
|
||||
# --- Conversion ---
|
||||
"ArrayConversion",
|
||||
"JaxToNumpy",
|
||||
"JaxToTorch",
|
||||
"NumpyToTorch",
|
||||
@@ -77,6 +78,7 @@ __all__ = [
|
||||
# to avoid `import jax` or `import torch` on `import gymnasium`.
|
||||
_wrapper_to_class = {
|
||||
# data converters
|
||||
"ArrayConversion": "array_conversion",
|
||||
"JaxToNumpy": "jax_to_numpy",
|
||||
"JaxToTorch": "jax_to_torch",
|
||||
"NumpyToTorch": "numpy_to_torch",
|
||||
|
@@ -22,11 +22,12 @@ __all__ = ["ArrayConversion"]
|
||||
class ArrayConversion(VectorWrapper, gym.utils.RecordConstructorArgs):
|
||||
"""Wraps a vector environment returning Array API compatible arrays so that it can be interacted with through a specific framework.
|
||||
|
||||
Notes:
|
||||
A vectorized version of ``gymnasium.wrappers.ArrayConversion``
|
||||
Popular Array API frameworks include ``numpy``, ``torch``, ``jax.numpy``, ``cupy`` etc. With this wrapper, you can convert outputs from your environment to
|
||||
any of these frameworks. Conversely, actions are automatically mapped back to the environment framework, if possible without moving the
|
||||
data or device transfers.
|
||||
|
||||
Actions must be provided as Array API compatible arrays and observations, rewards, terminations and truncations will be returned in the desired framework.
|
||||
xp here is a module that is compatible with the Array API standard, e.g. ``numpy``, ``jax`` etc.
|
||||
Notes:
|
||||
A vectorized version of :class:`gymnasium.wrappers.ArrayConversion`
|
||||
|
||||
Example:
|
||||
>>> import gymnasium as gym # doctest: +SKIP
|
||||
|
@@ -17,7 +17,7 @@ class JaxToNumpy(ArrayConversion):
|
||||
"""Wraps a jax vector environment so that it can be interacted with through numpy arrays.
|
||||
|
||||
Notes:
|
||||
A vectorized version of ``gymnasium.wrappers.JaxToNumpy``
|
||||
A vectorized version of :class:`gymnasium.wrappers.JaxToNumpy`
|
||||
|
||||
Actions must be provided as numpy arrays and observations, rewards, terminations and truncations will be returned as numpy arrays.
|
||||
|
||||
|
@@ -15,8 +15,8 @@ array_api_extra = pytest.importorskip("array_api_extra")
|
||||
|
||||
from array_api_compat import array_namespace, is_array_api_obj # noqa: E402
|
||||
|
||||
from gymnasium.wrappers import ArrayConversion # noqa: E402
|
||||
from gymnasium.wrappers.array_conversion import ( # noqa: E402
|
||||
ArrayConversion,
|
||||
array_conversion,
|
||||
module_namespace,
|
||||
)
|
||||
|
@@ -14,10 +14,10 @@ array_api_compat = pytest.importorskip("array_api_compat")
|
||||
from array_api_compat import array_namespace # noqa: E402
|
||||
|
||||
from gymnasium.wrappers.array_conversion import module_namespace # noqa: E402
|
||||
from gymnasium.wrappers.vector.array_conversion import ArrayConversion # noqa: E402
|
||||
from gymnasium.wrappers.vector.jax_to_numpy import JaxToNumpy # noqa: E402
|
||||
from gymnasium.wrappers.vector.jax_to_torch import JaxToTorch # noqa: E402
|
||||
from gymnasium.wrappers.vector.numpy_to_torch import NumpyToTorch # noqa: E402
|
||||
from gymnasium.wrappers.vector import ArrayConversion # noqa: E402
|
||||
from gymnasium.wrappers.vector import JaxToNumpy # noqa: E402
|
||||
from gymnasium.wrappers.vector import JaxToTorch # noqa: E402
|
||||
from gymnasium.wrappers.vector import NumpyToTorch # noqa: E402
|
||||
|
||||
|
||||
# Define available modules
|
||||
|
Reference in New Issue
Block a user