Add generic conversion wrapper between Array API compatible frameworks (#1333)

This commit is contained in:
Martin Schuck
2025-05-12 00:10:06 +02:00
committed by GitHub
parent 95637ebc7f
commit 5dde9a79be
23 changed files with 1039 additions and 623 deletions

View File

@@ -19,7 +19,7 @@ jobs:
- uses: actions/setup-python@v5
with:
python-version: '3.9'
python-version: '3.10'
- name: Install dependencies
run: pip install -r docs/requirements.txt

View File

@@ -19,7 +19,7 @@ jobs:
- uses: actions/setup-python@v5
with:
python-version: '3.9'
python-version: '3.10'
- name: Install dependencies
run: pip install -r docs/requirements.txt

View File

@@ -33,7 +33,7 @@ jobs:
- uses: actions/setup-python@v5
with:
python-version: '3.9'
python-version: '3.10'
- name: Install dependencies
run: pip install -r docs/requirements.txt

View File

@@ -10,8 +10,8 @@ jobs:
strategy:
fail-fast: true
matrix:
python-version: ['3.9', '3.10', '3.11', '3.12']
numpy-version: ['>=1.21,<2.0', '>=2.0']
python-version: ['3.10', '3.11', '3.12']
numpy-version: ['>=1.21,<2.0', '>=2.1']
steps:
- uses: actions/checkout@v4
- run: |
@@ -22,7 +22,6 @@ jobs:
- name: Run tests
run: docker run gymnasium-all-docker pytest tests/*
- name: Run doctests
if: ${{ matrix.python-version != '3.8' }}
run: docker run gymnasium-all-docker pytest --doctest-modules gymnasium/
build-necessary:

View File

@@ -20,7 +20,7 @@ jobs:
strategy:
fail-fast: false # This ensures all matrix combinations run even if one fails
matrix:
python-version: ["3.9"]
python-version: ["3.10"]
tutorial-group:
- gymnasium_basics
- training_agents

View File

@@ -32,7 +32,7 @@ To install the base Gymnasium library, use `pip install gymnasium`
This does not include dependencies for all families of environments (there's a massive number, and some can be problematic to install on certain systems). You can install these dependencies for one family like `pip install "gymnasium[atari]"` or use `pip install "gymnasium[all]"` to install all dependencies.
We support and test for Python 3.8, 3.9, 3.10, 3.11 and 3.12 on Linux and macOS. We will accept PRs related to Windows, but do not officially support it.
We support and test for Python 3.10, 3.11 and 3.12 on Linux and macOS. We will accept PRs related to Windows, but do not officially support it.
## API

View File

@@ -27,6 +27,7 @@ title: Misc Wrappers
## Data Conversion Wrappers
```{eval-rst}
.. autoclass:: gymnasium.wrappers.ArrayConversion
.. autoclass:: gymnasium.wrappers.JaxToNumpy
.. autoclass:: gymnasium.wrappers.JaxToTorch
.. autoclass:: gymnasium.wrappers.NumpyToTorch

View File

@@ -34,6 +34,8 @@ wrapper in the page on the wrapper type
- Converts an image observation computed by ``reset`` and ``step`` from RGB to Grayscale.
* - :class:`HumanRendering`
- Allows human like rendering for environments that support "rgb_array" rendering.
* - :class:`ArrayConversion`
- Wraps an environment based on any Array API compatible framework, e.g. torch, jax, numpy, such that it can be interacted with any other Array API compatible framework.
* - :class:`JaxToNumpy`
- Wraps a Jax-based environment such that it can be interacted with NumPy arrays.
* - :class:`JaxToTorch`

View File

@@ -0,0 +1,261 @@
# This wrapper will convert array inputs from an Array API compatible framework A for the actions
# to any other Array API compatible framework B for an underlying environment that is implemented
# in framework B, then convert the return observations from framework B back to framework A.
#
# More precisely, the wrapper will work for any two frameworks that can be made compatible with the
# `array-api-compat` package.
#
# See https://data-apis.org/array-api/latest/ for more information on the Array API standard, and
# https://data-apis.org/array-api-compat/ for more information on the Array API compatibility layer.
#
# General structure for converting between types originally copied from
# https://github.com/google/brax/blob/9d6b7ced2a13da0d074b5e9fbd3aad8311e26997/brax/io/torch.py
# Under the Apache 2.0 license. Copyright is held by the authors
"""Helper functions and wrapper class for converting between arbitrary Array API compatible frameworks and a target framework."""
from __future__ import annotations
import functools
import importlib
import numbers
from collections import abc
from types import ModuleType, NoneType
from typing import Any, Iterable, Mapping, SupportsFloat
import numpy as np
from packaging.version import Version
import gymnasium as gym
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
from gymnasium.error import DependencyNotInstalled
try:
from array_api_compat import array_namespace, is_array_api_obj, to_device
except ImportError:
raise DependencyNotInstalled(
'Array API packages are not installed therefore cannot call `array_conversion`, run `pip install "gymnasium[array-api]"`'
)
if Version(np.__version__) < Version("2.1.0"):
raise DependencyNotInstalled("Array API functionality requires numpy >= 2.1.0")
__all__ = ["ArrayConversion", "array_conversion"]
Array = Any # TODO: Switch to ArrayAPI type once https://github.com/data-apis/array-api/pull/589 is merged
Device = Any # TODO: Switch to ArrayAPI type if available
def module_namespace(xp: ModuleType) -> ModuleType:
"""Determine the Array API compatible namespace of the given module.
This function is closely linked to the `array_api_compat.array_namespace` function. It returns
the compatible namespace for a module directly instead of from an array object of that module.
See https://data-apis.org/array-api-compat/helper-functions.html#array_api_compat.array_namespace
"""
try:
return array_namespace(xp.empty(0))
except AttributeError as e:
raise ValueError(f"Module {xp} is not an Array API compatible module.") from e
def module_name_to_namespace(name: str) -> ModuleType:
return module_namespace(importlib.import_module(name))
@functools.singledispatch
def array_conversion(value: Any, xp: ModuleType, device: Device | None = None) -> Any:
"""Convert a value into the specified xp module array type."""
raise Exception(
f"No known conversion for ({type(value)}) to xp module ({xp}) registered. Report as issue on github."
)
@array_conversion.register(numbers.Number)
def _number_array_conversion(
value: numbers.Number, xp: ModuleType, device: Device | None = None
) -> Array:
"""Convert a python number (int, float, complex) to an Array API framework array."""
return xp.asarray(value, device=device)
@array_conversion.register(abc.Mapping)
def _mapping_array_conversion(
value: Mapping[str, Any], xp: ModuleType, device: Device | None = None
) -> Mapping[str, Any]:
"""Convert a mapping of Arrays into a Dictionary of the specified xp module array type."""
return type(value)(**{k: array_conversion(v, xp, device) for k, v in value.items()})
@array_conversion.register(abc.Iterable)
def _iterable_array_conversion(
value: Iterable[Any], xp: ModuleType, device: Device | None = None
) -> Iterable[Any]:
"""Convert an Iterable from Arrays to an iterable of the specified xp module array type."""
# There is currently no type for ArrayAPI compatible objects, so they fall through to this
# function registered for any Iterable. If they are arrays, we can convert them directly.
# We currently cannot pass the device to the from_dlpack function, since it is not supported
# for some frameworks (see e.g. https://github.com/data-apis/array-api-compat/issues/204)
if is_array_api_obj(value):
return _array_api_array_conversion(value, xp, device)
if hasattr(value, "_make"):
# namedtuple - underline used to prevent potential name conflicts
# noinspection PyProtectedMember
return type(value)._make(array_conversion(v, xp, device) for v in value)
return type(value)(array_conversion(v, xp, device) for v in value)
def _array_api_array_conversion(
value: Array, xp: ModuleType, device: Device | None = None
) -> Array:
"""Convert an Array API compatible array to the specified xp module array type."""
try:
x = xp.from_dlpack(value)
return to_device(x, device) if device is not None else x
except (RuntimeError, BufferError):
# If dlpack fails (e.g. because the array is read-only for frameworks that do not
# support it), we create a copy of the array that we own and then convert it.
# TODO: The correct treatment of read-only arrays is currently not fully clear in the
# Array API. Once ongoing discussions are resolved, we should update this code to remove
# any fallbacks.
value_namespace = array_namespace(value)
value_copy = value_namespace.asarray(value, copy=True)
return xp.asarray(value_copy, device=device)
@array_conversion.register(NoneType)
def _none_array_conversion(
value: None, xp: ModuleType, device: Device | None = None
) -> None:
"""Passes through None values."""
return value
class ArrayConversion(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""Wraps an Array API compatible environment so that it can be interacted with a specific Array API framework.
Actions must be provided as Array API compatible arrays and observations will be returned as Arrays of the specified xp module.
A vector version of the wrapper exists, :class:`gymnasium.wrappers.vector.ArrayConversion`.
Example:
>>> import torch # doctest: +SKIP
>>> import jax.numpy as jnp # doctest: +SKIP
>>> import gymnasium as gym # doctest: +SKIP
>>> env = gym.make("JaxEnv-vx") # doctest: +SKIP
>>> env = ArrayConversion(env, env_xp=jnp, target_xp=torch) # doctest: +SKIP
>>> obs, _ = env.reset(seed=123) # doctest: +SKIP
>>> type(obs) # doctest: +SKIP
<class 'torch.Tensor'>
>>> action = torch.tensor(env.action_space.sample()) # doctest: +SKIP
>>> obs, reward, terminated, truncated, info = env.step(action) # doctest: +SKIP
>>> type(obs) # doctest: +SKIP
<class 'torch.Tensor'>
>>> type(reward) # doctest: +SKIP
<class 'float'>
>>> type(terminated) # doctest: +SKIP
<class 'bool'>
>>> type(truncated) # doctest: +SKIP
<class 'bool'>
Change logs:
* v1.2.0 - Initially added
"""
def __init__(
self,
env: gym.Env,
env_xp: ModuleType,
target_xp: ModuleType,
env_device: Device | None = None,
target_device: Device | None = None,
):
"""Wrapper class to change inputs and outputs of environment to any Array API framework.
Args:
env: The Array API compatible environment to wrap
env_xp: The Array API framework the environment is on
target_xp: The Array API framework to convert to
env_device: The device the environment is on
target_device: The device on which Arrays should be returned
"""
gym.utils.RecordConstructorArgs.__init__(self)
gym.Wrapper.__init__(self, env)
self._env_xp = module_namespace(env_xp)
self._target_xp = module_namespace(target_xp)
self._env_device: Device | None = env_device
self._target_device: Device | None = target_device
def step(
self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
"""Performs the given action within the environment.
Args:
action: The action to perform as any Array API compatible array
Returns:
The next observation, reward, termination, truncation, and extra info
"""
action = array_conversion(action, xp=self._env_xp, device=self._env_device)
obs, reward, terminated, truncated, info = self.env.step(action)
return (
array_conversion(obs, xp=self._target_xp, device=self._target_device),
float(reward),
bool(terminated),
bool(truncated),
array_conversion(info, xp=self._target_xp, device=self._target_device),
)
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
"""Resets the environment returning observation and info as Array from any Array API compatible framework.
Args:
seed: The seed for resetting the environment
options: The options for resetting the environment, these are converted to jax arrays.
Returns:
xp-based observations and info
"""
if options:
options = array_conversion(options, self._env_xp, self._env_device)
return array_conversion(
self.env.reset(seed=seed, options=options),
self._target_xp,
self._target_device,
)
def render(self) -> RenderFrame | list[RenderFrame] | None:
"""Returns the rendered frames as an xp Array."""
return array_conversion(self.env.render(), self._target_xp, self._target_device)
def __getstate__(self):
"""Returns the object pickle state with args and kwargs."""
env_xp_name = self._env_xp.__name__.replace("array_api_compat.", "")
target_xp_name = self._target_xp.__name__.replace("array_api_compat.", "")
env_device = self._env_device
target_device = self._target_device
return {
"env_xp_name": env_xp_name,
"target_xp_name": target_xp_name,
"env_device": env_device,
"target_device": target_device,
"env": self.env,
}
def __setstate__(self, d):
"""Sets the object pickle state using d."""
self.env = d["env"]
self._env_xp = module_name_to_namespace(d["env_xp_name"])
self._target_xp = module_name_to_namespace(d["target_xp_name"])
self._env_device = d["env_device"]
self._target_device = d["target_device"]

View File

@@ -3,19 +3,20 @@
from __future__ import annotations
import functools
import numbers
from collections import abc
from typing import Any, Iterable, Mapping, SupportsFloat
import numpy as np
import gymnasium as gym
from gymnasium.core import ActType, ObsType, RenderFrame, WrapperActType, WrapperObsType
from gymnasium.core import ActType, ObsType
from gymnasium.error import DependencyNotInstalled
from gymnasium.wrappers.array_conversion import (
ArrayConversion,
array_conversion,
module_namespace,
)
try:
import jax
import jax.numpy as jnp
except ImportError:
raise DependencyNotInstalled(
@@ -24,110 +25,13 @@ except ImportError:
__all__ = ["JaxToNumpy", "jax_to_numpy", "numpy_to_jax"]
# The NoneType is not defined in Python 3.9. Remove when the minimal version is bumped to >=3.10
_NoneType = type(None)
jax_to_numpy = functools.partial(array_conversion, xp=module_namespace(np))
numpy_to_jax = functools.partial(array_conversion, xp=module_namespace(jnp))
@functools.singledispatch
def numpy_to_jax(value: Any) -> Any:
"""Converts a value to a Jax Array."""
raise Exception(
f"No known conversion for Numpy type ({type(value)}) to Jax registered. Report as issue on github."
)
@numpy_to_jax.register(numbers.Number)
def _number_to_jax(
value: numbers.Number,
) -> jax.Array:
"""Converts a number (int, float, etc.) to a Jax Array."""
assert jnp is not None
return jnp.array(value)
@numpy_to_jax.register(np.ndarray)
def _numpy_array_to_jax(value: np.ndarray) -> jax.Array:
"""Converts a NumPy Array to a Jax Array with the same dtype (excluding float64 without being enabled)."""
assert jnp is not None
return jnp.array(value, dtype=value.dtype)
@numpy_to_jax.register(abc.Mapping)
def _mapping_numpy_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
"""Converts a dictionary of numpy arrays to a mapping of Jax Array."""
return type(value)(**{k: numpy_to_jax(v) for k, v in value.items()})
@numpy_to_jax.register(abc.Iterable)
def _iterable_numpy_to_jax(
value: Iterable[np.ndarray | Any],
) -> Iterable[jax.Array | Any]:
"""Converts an Iterable from Numpy Arrays to an iterable of Jax Array."""
if hasattr(value, "_make"):
# namedtuple - underline used to prevent potential name conflicts
# noinspection PyProtectedMember
return type(value)._make(numpy_to_jax(v) for v in value)
else:
return type(value)(numpy_to_jax(v) for v in value)
@numpy_to_jax.register(_NoneType)
def _none_numpy_to_jax(value: None) -> None:
"""Passes through None values."""
return value
@functools.singledispatch
def jax_to_numpy(value: Any) -> Any:
"""Converts a value to a numpy array."""
raise Exception(
f"No known conversion for Jax type ({type(value)}) to NumPy registered. Report as issue on github."
)
@jax_to_numpy.register(jax.Array)
def _devicearray_jax_to_numpy(value: jax.Array) -> np.ndarray:
"""Converts a Jax Array to a numpy array."""
return np.array(value)
@jax_to_numpy.register(abc.Mapping)
def _mapping_jax_to_numpy(
value: Mapping[str, jax.Array | Any],
) -> Mapping[str, np.ndarray | Any]:
"""Converts a dictionary of Jax Array to a mapping of numpy arrays."""
return type(value)(**{k: jax_to_numpy(v) for k, v in value.items()})
@jax_to_numpy.register(abc.Iterable)
def _iterable_jax_to_numpy(
value: Iterable[np.ndarray | Any],
) -> Iterable[jax.Array | Any]:
"""Converts an Iterable from Numpy arrays to an iterable of Jax Array."""
if isinstance(value, jax.Array):
# Since the update to jax 0.6.0, calling jax_to_numpy with a <class 'jaxlib.xla_extension.ArrayImpl'>
# argument wrongly dispatches to _iterable_jax_to_numpy which fails with:
# TypeError: (): incompatible function arguments.
# See: https://github.com/Farama-Foundation/Gymnasium/issues/1360
return _devicearray_jax_to_numpy(value)
elif hasattr(value, "_make"):
# namedtuple - underline used to prevent potential name conflicts
# noinspection PyProtectedMember
return type(value)._make(jax_to_numpy(v) for v in value)
else:
return type(value)(jax_to_numpy(v) for v in value)
@jax_to_numpy.register(_NoneType)
def _none_jax_to_numpy(value: None) -> None:
"""Passes through None values."""
return value
class JaxToNumpy(
gym.Wrapper[WrapperObsType, WrapperActType, ObsType, ActType],
gym.utils.RecordConstructorArgs,
):
class JaxToNumpy(ArrayConversion):
"""Wraps a Jax-based environment such that it can be interacted with NumPy arrays.
Actions must be provided as numpy arrays and observations will be returned as numpy arrays.
@@ -169,48 +73,4 @@ class JaxToNumpy(
raise DependencyNotInstalled(
'Jax is not installed, run `pip install "gymnasium[jax]"`'
)
gym.utils.RecordConstructorArgs.__init__(self)
gym.Wrapper.__init__(self, env)
def step(
self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
"""Transforms the action to a jax array .
Args:
action: the action to perform as a numpy array
Returns:
A tuple containing numpy versions of the next observation, reward, termination, truncation, and extra info.
"""
jax_action = numpy_to_jax(action)
obs, reward, terminated, truncated, info = self.env.step(jax_action)
return (
jax_to_numpy(obs),
float(reward),
bool(terminated),
bool(truncated),
jax_to_numpy(info),
)
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
"""Resets the environment returning numpy-based observation and info.
Args:
seed: The seed for resetting the environment
options: The options for resetting the environment, these are converted to jax arrays.
Returns:
Numpy-based observations and info
"""
if options:
options = numpy_to_jax(options)
return jax_to_numpy(self.env.reset(seed=seed, options=options))
def render(self) -> RenderFrame | list[RenderFrame] | None:
"""Returns the rendered frames as a numpy array."""
return jax_to_numpy(self.env.render())
super().__init__(env=env, env_xp=jnp, target_xp=np)

View File

@@ -7,22 +7,23 @@
# Under the Apache 2.0 license. Copyright is held by the authors
"""Helper functions and wrapper class for converting between PyTorch and Jax."""
from __future__ import annotations
import functools
import numbers
from collections import abc
from typing import Any, Iterable, Mapping, SupportsFloat, Union
from typing import Union
import gymnasium as gym
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
from gymnasium.error import DependencyNotInstalled
from gymnasium.wrappers.array_conversion import (
ArrayConversion,
array_conversion,
module_namespace,
)
try:
import jax
import jax.numpy as jnp
from jax import dlpack as jax_dlpack
except ImportError:
raise DependencyNotInstalled(
@@ -31,7 +32,6 @@ except ImportError:
try:
import torch
from torch.utils import dlpack as torch_dlpack
Device = Union[str, torch.device]
except ImportError:
@@ -42,109 +42,13 @@ except ImportError:
__all__ = ["JaxToTorch", "jax_to_torch", "torch_to_jax", "Device"]
# The NoneType is not defined in Python 3.9. Remove when the minimal version is bumped to >=3.10
_NoneType = type(None)
torch_to_jax = functools.partial(array_conversion, xp=module_namespace(jnp))
jax_to_torch = functools.partial(array_conversion, xp=module_namespace(torch))
@functools.singledispatch
def torch_to_jax(value: Any) -> Any:
"""Converts a PyTorch Tensor into a Jax Array."""
raise Exception(
f"No known conversion for Torch type ({type(value)}) to Jax registered. Report as issue on github."
)
@torch_to_jax.register(numbers.Number)
def _number_torch_to_jax(value: numbers.Number) -> Any:
"""Convert a python number (int, float, complex) to a jax array."""
return jnp.array(value)
@torch_to_jax.register(torch.Tensor)
def _tensor_torch_to_jax(value: torch.Tensor) -> jax.Array:
"""Converts a PyTorch Tensor into a Jax Array."""
return jax_dlpack.from_dlpack(value) # pyright: ignore[reportPrivateImportUsage]
@torch_to_jax.register(abc.Mapping)
def _mapping_torch_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
"""Converts a mapping of PyTorch Tensors into a Dictionary of Jax Array."""
return type(value)(**{k: torch_to_jax(v) for k, v in value.items()})
@torch_to_jax.register(abc.Iterable)
def _iterable_torch_to_jax(value: Iterable[Any]) -> Iterable[Any]:
"""Converts an Iterable from PyTorch Tensors to an iterable of Jax Array."""
if hasattr(value, "_make"):
# namedtuple - underline used to prevent potential name conflicts
# noinspection PyProtectedMember
return type(value)._make(torch_to_jax(v) for v in value)
else:
return type(value)(torch_to_jax(v) for v in value)
@torch_to_jax.register(_NoneType)
def _none_torch_to_jax(value: None) -> None:
"""Passes through None values."""
return value
@functools.singledispatch
def jax_to_torch(value: Any, device: Device | None = None) -> Any:
"""Converts a Jax Array into a PyTorch Tensor."""
raise Exception(
f"No known conversion for Jax type ({type(value)}) to PyTorch registered. Report as issue on github."
)
@jax_to_torch.register(jax.Array)
def _devicearray_jax_to_torch(
value: jax.Array, device: Device | None = None
) -> torch.Tensor:
"""Converts a Jax Array into a PyTorch Tensor."""
assert jax_dlpack is not None and torch_dlpack is not None
tensor = torch_dlpack.from_dlpack(
value
) # pyright: ignore[reportPrivateImportUsage]
if device:
return tensor.to(device=device)
return tensor
@jax_to_torch.register(abc.Mapping)
def _jax_mapping_to_torch(
value: Mapping[str, Any], device: Device | None = None
) -> Mapping[str, Any]:
"""Converts a mapping of Jax Array into a Dictionary of PyTorch Tensors."""
return type(value)(**{k: jax_to_torch(v, device) for k, v in value.items()})
@jax_to_torch.register(abc.Iterable)
def _jax_iterable_to_torch(
value: Iterable[Any], device: Device | None = None
) -> Iterable[Any]:
"""Converts an Iterable from Jax Array to an iterable of PyTorch Tensors."""
if isinstance(value, jax.Array):
# Since the update to jax 0.6.0, calling jax_to_torch with a <class 'jaxlib.xla_extension.ArrayImpl'>
# argument wrongly dispatches to _iterable_jax_to_torch which fails with:
# TypeError: (): incompatible function arguments.
# See: https://github.com/Farama-Foundation/Gymnasium/issues/1360
return _devicearray_jax_to_torch(value)
elif hasattr(value, "_make"):
# namedtuple - underline used to prevent potential name conflicts
# noinspection PyProtectedMember
return type(value)._make(jax_to_torch(v, device) for v in value)
else:
return type(value)(jax_to_torch(v, device) for v in value)
@jax_to_torch.register(_NoneType)
def _none_jax_to_torch(value: None, device: Device | None = None) -> None:
"""Passes through None values."""
return value
class JaxToTorch(gym.Wrapper, gym.utils.RecordConstructorArgs):
class JaxToTorch(ArrayConversion):
"""Wraps a Jax-based environment so that it can be interacted with PyTorch Tensors.
Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors.
@@ -183,50 +87,8 @@ class JaxToTorch(gym.Wrapper, gym.utils.RecordConstructorArgs):
env: The Jax-based environment to wrap
device: The device the torch Tensors should be moved to
"""
gym.utils.RecordConstructorArgs.__init__(self, device=device)
gym.Wrapper.__init__(self, env)
super().__init__(env=env, env_xp=jnp, target_xp=torch, target_device=device)
# TODO: Device was part of the public API, but should be removed in favor of _env_device and
# _target_device.
self.device: Device | None = device
def step(
self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
"""Performs the given action within the environment.
Args:
action: The action to perform as a PyTorch Tensor
Returns:
The next observation, reward, termination, truncation, and extra info
"""
jax_action = torch_to_jax(action)
obs, reward, terminated, truncated, info = self.env.step(jax_action)
return (
jax_to_torch(obs, self.device),
float(reward),
bool(terminated),
bool(truncated),
jax_to_torch(info, self.device),
)
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
"""Resets the environment returning PyTorch-based observation and info.
Args:
seed: The seed for resetting the environment
options: The options for resetting the environment, these are converted to jax arrays.
Returns:
PyTorch-based observations and info
"""
if options:
options = torch_to_jax(options)
return jax_to_torch(self.env.reset(seed=seed, options=options), self.device)
def render(self) -> RenderFrame | list[RenderFrame] | None:
"""Returns the rendered frames as a torch tensor."""
return jax_to_torch(self.env.render())

View File

@@ -3,15 +3,17 @@
from __future__ import annotations
import functools
import numbers
from collections import abc
from typing import Any, Iterable, Mapping, SupportsFloat, Union
from typing import Union
import numpy as np
import gymnasium as gym
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
from gymnasium.error import DependencyNotInstalled
from gymnasium.wrappers.array_conversion import (
ArrayConversion,
array_conversion,
module_namespace,
)
try:
@@ -26,100 +28,13 @@ except ImportError:
__all__ = ["NumpyToTorch", "torch_to_numpy", "numpy_to_torch", "Device"]
# The NoneType is not defined in Python 3.9. Remove when the minimal version is bumped to >=3.10
_NoneType = type(None)
torch_to_numpy = functools.partial(array_conversion, xp=module_namespace(np))
numpy_to_torch = functools.partial(array_conversion, xp=module_namespace(torch))
@functools.singledispatch
def torch_to_numpy(value: Any) -> Any:
"""Converts a PyTorch Tensor into a NumPy Array."""
raise Exception(
f"No known conversion for Torch type ({type(value)}) to NumPy registered. Report as issue on github."
)
@torch_to_numpy.register(numbers.Number)
def _number_to_numpy(value: numbers.Number) -> Any:
"""Convert a python number (int, float, complex) to a NumPy array."""
return np.array(value)
@torch_to_numpy.register(torch.Tensor)
def _torch_to_numpy(value: torch.Tensor) -> Any:
"""Convert a torch.Tensor to a NumPy array."""
return value.numpy(force=True)
@torch_to_numpy.register(abc.Mapping)
def _mapping_torch_to_numpy(value: Mapping[str, Any]) -> Mapping[str, Any]:
"""Converts a mapping of PyTorch Tensors into a Dictionary of NumPy Array."""
return type(value)(**{k: torch_to_numpy(v) for k, v in value.items()})
@torch_to_numpy.register(abc.Iterable)
def _iterable_torch_to_numpy(value: Iterable[Any]) -> Iterable[Any]:
"""Converts an Iterable from PyTorch Tensors to an iterable of NumPy Array."""
if hasattr(value, "_make"):
# namedtuple - underline used to prevent potential name conflicts
# noinspection PyProtectedMember
return type(value)._make(torch_to_numpy(v) for v in value)
else:
return type(value)(torch_to_numpy(v) for v in value)
@torch_to_numpy.register(_NoneType)
def _none_torch_to_numpy(value: None) -> None:
"""Passes through None values."""
return value
@functools.singledispatch
def numpy_to_torch(value: Any, device: Device | None = None) -> Any:
"""Converts a NumPy Array into a PyTorch Tensor."""
raise Exception(
f"No known conversion for NumPy type ({type(value)}) to PyTorch registered. Report as issue on github."
)
@numpy_to_torch.register(numbers.Number)
@numpy_to_torch.register(np.ndarray)
def _numpy_to_torch(value: np.ndarray, device: Device | None = None) -> torch.Tensor:
"""Converts a NumPy Array into a PyTorch Tensor."""
assert torch is not None
tensor = torch.tensor(value)
if device:
return tensor.to(device=device)
return tensor
@numpy_to_torch.register(abc.Mapping)
def _numpy_mapping_to_torch(
value: Mapping[str, Any], device: Device | None = None
) -> Mapping[str, Any]:
"""Converts a mapping of NumPy Array into a Dictionary of PyTorch Tensors."""
return type(value)(**{k: numpy_to_torch(v, device) for k, v in value.items()})
@numpy_to_torch.register(abc.Iterable)
def _numpy_iterable_to_torch(
value: Iterable[Any], device: Device | None = None
) -> Iterable[Any]:
"""Converts an Iterable from NumPy Array to an iterable of PyTorch Tensors."""
if hasattr(value, "_make"):
# namedtuple - underline used to prevent potential name conflicts
# noinspection PyProtectedMember
return type(value)._make(numpy_to_torch(v, device) for v in value)
else:
return type(value)(numpy_to_torch(v, device) for v in value)
@numpy_to_torch.register(_NoneType)
def _none_numpy_to_torch(value: None) -> None:
"""Passes through None values."""
return value
class NumpyToTorch(gym.Wrapper, gym.utils.RecordConstructorArgs):
class NumpyToTorch(ArrayConversion):
"""Wraps a NumPy-based environment such that it can be interacted with PyTorch Tensors.
Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors.
@@ -158,50 +73,6 @@ class NumpyToTorch(gym.Wrapper, gym.utils.RecordConstructorArgs):
env: The NumPy-based environment to wrap
device: The device the torch Tensors should be moved to
"""
gym.utils.RecordConstructorArgs.__init__(self, device=device)
gym.Wrapper.__init__(self, env)
super().__init__(env=env, env_xp=np, target_xp=torch, target_device=device)
self.device: Device | None = device
def step(
self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
"""Using a PyTorch based action that is converted to NumPy to be used by the environment.
Args:
action: A PyTorch-based action
Returns:
The PyTorch-based Tensor next observation, reward, termination, truncation, and extra info
"""
jax_action = torch_to_numpy(action)
obs, reward, terminated, truncated, info = self.env.step(jax_action)
return (
numpy_to_torch(obs, self.device),
float(reward),
bool(terminated),
bool(truncated),
numpy_to_torch(info, self.device),
)
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
"""Resets the environment returning PyTorch-based observation and info.
Args:
seed: The seed for resetting the environment
options: The options for resetting the environment, these are converted to jax arrays.
Returns:
PyTorch-based observations and info
"""
if options:
options = torch_to_numpy(options)
return numpy_to_torch(self.env.reset(seed=seed, options=options), self.device)
def render(self) -> RenderFrame | list[RenderFrame] | None:
"""Returns the rendered frames as a torch tensor."""
return numpy_to_torch(self.env.render())

View File

@@ -0,0 +1,131 @@
"""Vector wrapper for converting between Array API compatible frameworks."""
from __future__ import annotations
from types import ModuleType
from typing import Any
import gymnasium as gym
from gymnasium.core import ActType, ObsType
from gymnasium.vector import VectorEnv, VectorWrapper
from gymnasium.vector.vector_env import ArrayType
from gymnasium.wrappers.array_conversion import (
Device,
array_conversion,
module_name_to_namespace,
)
__all__ = ["ArrayConversion"]
class ArrayConversion(VectorWrapper, gym.utils.RecordConstructorArgs):
"""Wraps a vector environment returning Array API compatible arrays so that it can be interacted with through a specific framework.
Notes:
A vectorized version of ``gymnasium.wrappers.ArrayConversion``
Actions must be provided as Array API compatible arrays and observations, rewards, terminations and truncations will be returned in the desired framework.
xp here is a module that is compatible with the Array API standard, e.g. ``numpy``, ``jax`` etc.
Example:
>>> import gymnasium as gym # doctest: +SKIP
>>> envs = gym.make_vec("JaxEnv-vx", 3) # doctest: +SKIP
>>> envs = ArrayConversion(envs, xp=np) # doctest: +SKIP
"""
def __init__(
self,
env: VectorEnv,
env_xp: ModuleType | str,
target_xp: ModuleType | str,
env_device: Device | None = None,
target_device: Device | None = None,
):
"""Wrapper class to change inputs and outputs of environment to any Array API framework.
Args:
env: The Array API compatible environment to wrap
env_xp: The Array API framework the environment is on
target_xp: The Array API framework to convert to
env_device: The device the environment is on
target_device: The device on which Arrays should be returned
"""
gym.utils.RecordConstructorArgs.__init__(self)
VectorWrapper.__init__(self, env)
self._env_xp = env_xp
self._target_xp = target_xp
self._env_device = env_device
self._target_device = target_device
def step(
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
"""Transforms the action to the specified xp module array type.
Args:
actions: The action to perform
Returns:
A tuple containing xp versions of the next observation, reward, termination, truncation, and extra info.
"""
actions = array_conversion(actions, xp=self._env_xp, device=self._env_device)
obs, reward, terminated, truncated, info = self.env.step(actions)
return (
array_conversion(obs, xp=self._target_xp, device=self._target_device),
array_conversion(reward, xp=self._target_xp, device=self._target_device),
array_conversion(
terminated, xp=self._target_xp, device=self._target_device
),
array_conversion(truncated, xp=self._target_xp, device=self._target_device),
array_conversion(info, xp=self._target_xp, device=self._target_device),
)
def reset(
self,
*,
seed: int | list[int] | None = None,
options: dict[str, Any] | None = None,
) -> tuple[ObsType, dict[str, Any]]:
"""Resets the environment returning xp-based observation and info.
Args:
seed: The seed for resetting the environment
options: The options for resetting the environment, these are converted to xp arrays.
Returns:
xp-based observations and info
"""
if options:
options = array_conversion(
options, xp=self._env_xp, device=self._env_device
)
return array_conversion(
self.env.reset(seed=seed, options=options),
xp=self._target_xp,
device=self._target_device,
)
def __getstate__(self):
"""Returns the object pickle state with args and kwargs."""
env_xp_name = self._env_xp.__name__.replace("array_api_compat.", "")
target_xp_name = self._target_xp.__name__.replace("array_api_compat.", "")
env_device = self._env_device
target_device = self._target_device
return {
"env_xp_name": env_xp_name,
"target_xp_name": target_xp_name,
"env_device": env_device,
"target_device": target_device,
"env": self.env,
}
def __setstate__(self, d):
"""Sets the object pickle state using d."""
self.env = d["env"]
self._env_xp = module_name_to_namespace(d["env_xp_name"])
self._target_xp = module_name_to_namespace(d["target_xp_name"])
self._env_device = d["env_device"]
self._target_device = d["target_device"]

View File

@@ -2,21 +2,18 @@
from __future__ import annotations
from typing import Any
import jax.numpy as jnp
import numpy as np
from gymnasium.core import ActType, ObsType
from gymnasium.error import DependencyNotInstalled
from gymnasium.vector import VectorEnv, VectorWrapper
from gymnasium.vector.vector_env import ArrayType
from gymnasium.wrappers.jax_to_numpy import jax_to_numpy, numpy_to_jax
from gymnasium.vector import VectorEnv
from gymnasium.wrappers.vector.array_conversion import ArrayConversion
__all__ = ["JaxToNumpy"]
class JaxToNumpy(VectorWrapper):
class JaxToNumpy(ArrayConversion):
"""Wraps a jax vector environment so that it can be interacted with through numpy arrays.
Notes:
@@ -40,46 +37,4 @@ class JaxToNumpy(VectorWrapper):
raise DependencyNotInstalled(
'Jax is not installed, run `pip install "gymnasium[jax]"`'
)
super().__init__(env)
def step(
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
"""Transforms the action to a jax array .
Args:
actions: the action to perform as a numpy array
Returns:
A tuple containing numpy versions of the next observation, reward, termination, truncation, and extra info.
"""
jax_actions = numpy_to_jax(actions)
obs, reward, terminated, truncated, info = self.env.step(jax_actions)
return (
jax_to_numpy(obs),
jax_to_numpy(reward),
jax_to_numpy(terminated),
jax_to_numpy(truncated),
jax_to_numpy(info),
)
def reset(
self,
*,
seed: int | list[int] | None = None,
options: dict[str, Any] | None = None,
) -> tuple[ObsType, dict[str, Any]]:
"""Resets the environment returning numpy-based observation and info.
Args:
seed: The seed for resetting the environment
options: The options for resetting the environment, these are converted to jax arrays.
Returns:
Numpy-based observations and info
"""
if options:
options = numpy_to_jax(options)
return jax_to_numpy(self.env.reset(seed=seed, options=options))
super().__init__(env, env_xp=jnp, target_xp=np)

View File

@@ -2,18 +2,18 @@
from __future__ import annotations
from typing import Any
import jax.numpy as jnp
import torch
from gymnasium.core import ActType, ObsType
from gymnasium.vector import VectorEnv, VectorWrapper
from gymnasium.vector.vector_env import ArrayType
from gymnasium.wrappers.jax_to_torch import Device, jax_to_torch, torch_to_jax
from gymnasium.vector import VectorEnv
from gymnasium.wrappers.jax_to_torch import Device
from gymnasium.wrappers.vector.array_conversion import ArrayConversion
__all__ = ["JaxToTorch"]
class JaxToTorch(VectorWrapper):
class JaxToTorch(ArrayConversion):
"""Wraps a Jax-based vector environment so that it can be interacted with through PyTorch Tensors.
Actions must be provided as PyTorch Tensors and observations, rewards, terminations and truncations will be returned as PyTorch Tensors.
@@ -31,48 +31,6 @@ class JaxToTorch(VectorWrapper):
env: The Jax-based vector environment to wrap
device: The device the torch Tensors should be moved to
"""
super().__init__(env)
super().__init__(env, env_xp=jnp, target_xp=torch, target_device=device)
self.device: Device | None = device
def step(
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
"""Performs the given action within the environment.
Args:
actions: The action to perform as a PyTorch Tensor
Returns:
Torch-based Tensors of the next observation, reward, termination, truncation, and extra info
"""
jax_action = torch_to_jax(actions)
obs, reward, terminated, truncated, info = self.env.step(jax_action)
return (
jax_to_torch(obs, self.device),
jax_to_torch(reward, self.device),
jax_to_torch(terminated, self.device),
jax_to_torch(truncated, self.device),
jax_to_torch(info, self.device),
)
def reset(
self,
*,
seed: int | list[int] | None = None,
options: dict[str, Any] | None = None,
) -> tuple[ObsType, dict[str, Any]]:
"""Resets the environment returning PyTorch-based observation and info.
Args:
seed: The seed for resetting the environment
options: The options for resetting the environment, these are converted to jax arrays.
Returns:
PyTorch-based observations and info
"""
if options:
options = torch_to_jax(options)
return jax_to_torch(self.env.reset(seed=seed, options=options), self.device)

View File

@@ -2,18 +2,18 @@
from __future__ import annotations
from typing import Any
import numpy as np
import torch
from gymnasium.core import ActType, ObsType
from gymnasium.vector import VectorEnv, VectorWrapper
from gymnasium.vector.vector_env import ArrayType
from gymnasium.wrappers.numpy_to_torch import Device, numpy_to_torch, torch_to_numpy
from gymnasium.vector import VectorEnv
from gymnasium.wrappers.numpy_to_torch import Device
from gymnasium.wrappers.vector.array_conversion import ArrayConversion
__all__ = ["NumpyToTorch"]
class NumpyToTorch(VectorWrapper):
class NumpyToTorch(ArrayConversion):
"""Wraps a numpy-based environment so that it can be interacted with through PyTorch Tensors.
Example:
@@ -45,48 +45,6 @@ class NumpyToTorch(VectorWrapper):
env: The NumPy-based vector environment to wrap
device: The device the torch Tensors should be moved to
"""
super().__init__(env)
super().__init__(env, env_xp=np, target_xp=torch, target_device=device)
self.device: Device | None = device
def step(
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
"""Using a PyTorch based action that is converted to NumPy to be used by the environment.
Args:
action: A PyTorch-based action
Returns:
The PyTorch-based Tensor next observation, reward, termination, truncation, and extra info
"""
numpy_action = torch_to_numpy(actions)
obs, reward, terminated, truncated, info = self.env.step(numpy_action)
return (
numpy_to_torch(obs, self.device),
numpy_to_torch(reward, self.device),
numpy_to_torch(terminated, self.device),
numpy_to_torch(truncated, self.device),
numpy_to_torch(info, self.device),
)
def reset(
self,
*,
seed: int | list[int] | None = None,
options: dict[str, Any] | None = None,
) -> tuple[ObsType, dict[str, Any]]:
"""Resets the environment returning PyTorch-based observation and info.
Args:
seed: The seed for resetting the environment
options: The options for resetting the environment, these are converted to NumPy arrays.
Returns:
PyTorch-based observations and info
"""
if options:
options = torch_to_numpy(options)
return numpy_to_torch(self.env.reset(seed=seed, options=options), self.device)

View File

@@ -8,7 +8,7 @@ build-backend = "setuptools.build_meta"
name = "gymnasium"
description = "A standard API for reinforcement learning and a diverse set of reference environments (formerly Gym)."
readme = "README.md"
requires-python = ">= 3.8"
requires-python = ">= 3.10"
authors = [{ name = "Farama Foundation", email = "contact@farama.org" }]
license = { text = "MIT License" }
keywords = ["Reinforcement Learning", "game", "RL", "AI", "gymnasium"]
@@ -16,8 +16,6 @@ classifiers = [
"Development Status :: 5 - Production/Stable",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
@@ -27,7 +25,6 @@ classifiers = [
dependencies = [
"numpy >=1.21.0",
"cloudpickle >=1.2.0",
"importlib-metadata >=4.8.0; python_version < '3.10'",
"typing-extensions >=4.3.0",
"farama-notifications >=0.0.1",
]
@@ -38,15 +35,30 @@ dynamic = ["version"]
atari = ["ale_py >=0.9"]
box2d = ["box2d-py ==2.3.5", "pygame >=2.1.3", "swig ==4.*"]
classic-control = ["pygame >=2.1.3"]
classic_control = ["pygame >=2.1.3"] # kept for backward compatibility
classic_control = ["pygame >=2.1.3"] # kept for backward compatibility
mujoco-py = ["mujoco-py >=2.1,<2.2", "cython<3"]
mujoco_py = ["mujoco-py >=2.1,<2.2", "cython<3"] # kept for backward compatibility
mujoco_py = [
"mujoco-py >=2.1,<2.2",
"cython<3",
] # kept for backward compatibility
mujoco = ["mujoco >=2.1.5", "imageio >=2.14.1", "packaging >=23.0"]
toy-text = ["pygame >=2.1.3"]
toy_text = ["pygame >=2.1.3"] # kept for backward compatibility
jax = ["jax >=0.4.16", "jaxlib >=0.4.16", "flax >=0.5.0"]
torch = ["torch >=1.13.0"]
other = ["moviepy >=1.0.0", "matplotlib >=3.0", "opencv-python >=3.0", "seaborn >= 0.13"]
toy_text = ["pygame >=2.1.3"] # kept for backward compatibility
jax = [
"jax >=0.4.16",
"jaxlib >=0.4.16",
"flax >=0.5.0",
"array-api-compat >=1.11.0",
"numpy>=2.1",
]
torch = ["torch >=1.13.0", "array-api-compat >=1.11.0", "numpy>=2.1"]
array-api = ["array-api-compat >=1.11.0", "numpy>=2.1"]
other = [
"moviepy >=1.0.0",
"matplotlib >=3.0",
"opencv-python >=3.0",
"seaborn >= 0.13",
]
all = [
# All dependencies above except accept-rom-license
# NOTE: No need to manually remove the duplicates, setuptools automatically does that.
@@ -71,17 +83,26 @@ all = [
"jax >=0.4.16",
"jaxlib >=0.4.16",
"flax >= 0.5.0",
"array-api-compat >=1.11.0",
"numpy>=2.1",
# torch
"torch >=1.13.0",
"array-api-compat >=1.11.0",
"numpy>=2.1",
# array-api
"array-api-compat >=1.11.0",
"numpy>=2.1",
# other
"opencv-python >=3.0",
"matplotlib >=3.0",
"moviepy >=1.0.0",
]
testing = [
"pytest >=7.1.3",
"scipy >=1.7.3",
"dill >=0.3.7",
"array_api_extra >=0.7.0",
]
[project.urls]
@@ -125,7 +146,7 @@ exclude = ["tests/**", "**/node_modules", "**/__pycache__"]
strict = []
typeCheckingMode = "basic"
pythonVersion = "3.8"
pythonVersion = "3.10"
pythonPlatform = "All"
typeshedPath = "typeshed"
enableTypeIgnoreComments = true
@@ -138,19 +159,19 @@ reportMissingTypeStubs = false
# For warning and error, will raise an error when
reportInvalidTypeVarUse = "none"
reportGeneralTypeIssues = "none" # -> commented out raises 489 errors
reportAttributeAccessIssue = "none" # pyright provides false positives
reportArgumentType = "none" # pyright provides false positives
reportGeneralTypeIssues = "none" # -> commented out raises 489 errors
reportAttributeAccessIssue = "none" # pyright provides false positives
reportArgumentType = "none" # pyright provides false positives
reportPrivateUsage = "warning"
reportIndexIssue = "none" # TODO fix one by one
reportReturnType = "none" # TODO fix one by one
reportCallIssue = "none" # TODO fix one by one
reportOperatorIssue = "none" # TODO fix one by one
reportInvalidTypeForm = "none" # TODO fix one by one
reportOptionalMemberAccess = "none" # TODO fix one by one
reportAssignmentType = "none" # TODO fix one by one
reportIndexIssue = "none" # TODO fix one by one
reportReturnType = "none" # TODO fix one by one
reportCallIssue = "none" # TODO fix one by one
reportOperatorIssue = "none" # TODO fix one by one
reportInvalidTypeForm = "none" # TODO fix one by one
reportOptionalMemberAccess = "none" # TODO fix one by one
reportAssignmentType = "none" # TODO fix one by one
[tool.pytest.ini_options]

View File

@@ -6,10 +6,15 @@ import types
from collections.abc import Callable
from typing import Any
import numpy as np
import gymnasium as gym
from gymnasium import spaces
from gymnasium.core import ActType, ObsType
from gymnasium.envs.registration import EnvSpec
from gymnasium.vector import VectorEnv
from gymnasium.vector.utils import batch_space
from gymnasium.vector.vector_env import AutoresetMode
def basic_reset_func(
@@ -106,3 +111,112 @@ class GenericTestEnv(gym.Env):
def render(self):
"""Renders the environment."""
raise NotImplementedError("testingEnv render_fn is not set.")
def basic_vector_reset_func(
self,
*,
seed: int | None = None,
options: dict | None = None,
) -> tuple[ObsType, dict]:
"""A basic reset function that will pass the environment check using random actions from the observation space."""
super(GenericTestVectorEnv, self).reset(seed=seed)
self.observation_space.seed(self.np_random_seed)
return self.observation_space.sample(), {"options": options}
def basic_vector_step_func(
self, action: ActType
) -> tuple[ObsType, np.ndarray, np.ndarray, np.ndarray, dict]:
"""A step function that follows the basic step api that will pass the environment check using random actions from the observation space."""
obs = self.observation_space.sample()
rewards = np.zeros(self.num_envs, dtype=np.float64)
terminations = np.zeros(self.num_envs, dtype=np.bool_)
truncations = np.zeros(self.num_envs, dtype=np.bool_)
return obs, rewards, terminations, truncations, {}
def basic_vector_render_func(self):
"""Basic render fn that does nothing."""
pass
class GenericTestVectorEnv(VectorEnv):
"""A generic testing vector environment similar to GenericTestEnv.
Some tests cannot use SyncVectorEnv, e.g. when returning non-numpy arrays in the observations.
In these cases, GenericTestVectorEnv can be used to simulate a vector environment.
"""
def __init__(
self,
num_envs: int = 1,
action_space: spaces.Space = spaces.Box(0, 1, (1,)),
observation_space: spaces.Space = spaces.Box(0, 1, (1,)),
reset_func: Callable = basic_vector_reset_func,
step_func: Callable = basic_vector_step_func,
render_func: Callable = basic_vector_render_func,
metadata: dict[str, Any] = {
"render_modes": [],
"autoreset_mode": AutoresetMode.NEXT_STEP,
},
render_mode: str | None = None,
spec: EnvSpec = EnvSpec(
"TestingVectorEnv-v0",
"tests.testing_env:GenericTestVectorEnv",
max_episode_steps=100,
),
):
"""Generic testing vector environment constructor.
Args:
num_envs: The number of environments to create
action_space: The environment action space
observation_space: The environment observation space
reset_func: The environment reset function
step_func: The environment step function
render_func: The environment render function
metadata: The environment metadata
render_mode: The render mode of the environment
spec: The environment spec
"""
super().__init__()
self.num_envs = num_envs
self.metadata = metadata
self.render_mode = render_mode
self.spec = spec
# Set the single spaces and create batched spaces
self.single_observation_space = observation_space
self.single_action_space = action_space
self.observation_space = batch_space(observation_space, num_envs)
self.action_space = batch_space(action_space, num_envs)
# Bind the functions to the instance
if reset_func is not None:
self.reset = types.MethodType(reset_func, self)
if step_func is not None:
self.step = types.MethodType(step_func, self)
if render_func is not None:
self.render = types.MethodType(render_func, self)
def reset(
self,
*,
seed: int | None = None,
options: dict | None = None,
) -> tuple[ObsType, dict]:
"""Resets the environment."""
# If you need a default working reset function, use `basic_vector_reset_fn` above
raise NotImplementedError("TestingVectorEnv reset_fn is not set.")
def step(
self, action: ActType
) -> tuple[ObsType, np.ndarray, np.ndarray, np.ndarray, dict]:
"""Steps through the environment."""
raise NotImplementedError("TestingVectorEnv step_fn is not set.")
def render(self) -> tuple[Any, ...] | None:
"""Renders the environment."""
raise NotImplementedError("TestingVectorEnv render_fn is not set.")

View File

@@ -0,0 +1,250 @@
"""Test suite for ArrayConversion wrapper."""
import importlib
import itertools
import pickle
from typing import Any, NamedTuple
import pytest
import gymnasium
array_api_compat = pytest.importorskip("array_api_compat")
array_api_extra = pytest.importorskip("array_api_extra")
from array_api_compat import array_namespace, is_array_api_obj # noqa: E402
from gymnasium.wrappers.array_conversion import ( # noqa: E402
ArrayConversion,
array_conversion,
module_namespace,
)
from tests.testing_env import GenericTestEnv # noqa: E402
# Define available modules
installed_modules = []
array_api_modules = [
"numpy",
"jax.numpy",
"torch",
"cupy",
"dask.array",
"sparse",
"array_api_strict",
]
for module in array_api_modules:
try:
installed_modules.append(importlib.import_module(module))
except ImportError:
pass # Modules that are not installed are skipped
installed_modules_combinations = list(itertools.permutations(installed_modules, 2))
def xp_data_equivalence(data_1, data_2) -> bool:
"""Return if two variables are equivalent that might contain ``torch.Tensor``."""
if type(data_1) is type(data_2):
if isinstance(data_1, dict):
return data_1.keys() == data_2.keys() and all(
xp_data_equivalence(data_1[k], data_2[k]) for k in data_1.keys()
)
elif isinstance(data_1, (tuple, list)):
return len(data_1) == len(data_2) and all(
xp_data_equivalence(o_1, o_2) for o_1, o_2 in zip(data_1, data_2)
)
elif is_array_api_obj(data_1):
return array_api_extra.isclose(data_1, data_2, atol=0.00001).all()
else:
return data_1 == data_2
else:
return False
class ExampleNamedTuple(NamedTuple):
a: Any # Array API compatible object. Does not have proper typing support yet.
b: Any # Same as a
def _supports_higher_precision(xp, low_type, high_type):
"""Check if an array module supports higher precision type."""
return xp.result_type(low_type, high_type) == high_type
# When converting between array modules (source → target → source), we need to ensure that the
# precision used is supported by both modules. If either module only supports 32-bit types, we must
# use the lower precision to account for the conversion during the roundtrip.
def atleast_float32(source_xp, target_xp):
"""Return source_xp.float64 if both modules support it, otherwise source_xp.float32."""
source_supports_64 = _supports_higher_precision(
source_xp, source_xp.float32, source_xp.float64
)
target_supports_64 = _supports_higher_precision(
target_xp, target_xp.float32, target_xp.float64
)
return (
source_xp.float64
if (source_supports_64 and target_supports_64)
else source_xp.float32
)
def atleast_int32(source_xp, target_xp):
"""Return source_xp.int64 if both modules support it, otherwise source_xp.int32."""
source_supports_64 = _supports_higher_precision(
source_xp, source_xp.int32, source_xp.int64
)
target_supports_64 = _supports_higher_precision(
target_xp, target_xp.int32, target_xp.int64
)
return (
source_xp.int64
if (source_supports_64 and target_supports_64)
else source_xp.int32
)
def value_parametrization():
for source_xp, target_xp in installed_modules_combinations:
xp = module_namespace(source_xp)
source_xp = module_namespace(source_xp)
target_xp = module_namespace(target_xp)
for value, expected_value in [
(2, xp.asarray(2, dtype=atleast_int32(source_xp, target_xp))),
(
(3.0, 4),
(
xp.asarray(3.0, dtype=atleast_float32(source_xp, target_xp)),
xp.asarray(4, dtype=atleast_int32(source_xp, target_xp)),
),
),
(
[3.0, 4],
[
xp.asarray(3.0, dtype=atleast_float32(source_xp, target_xp)),
xp.asarray(4, dtype=atleast_int32(source_xp, target_xp)),
],
),
(
{
"a": 6.0,
"b": 7,
},
{
"a": xp.asarray(6.0, dtype=atleast_float32(source_xp, target_xp)),
"b": xp.asarray(7, dtype=atleast_int32(source_xp, target_xp)),
},
),
(xp.asarray(1.0, dtype=xp.float32), xp.asarray(1.0, dtype=xp.float32)),
(xp.asarray(1.0, dtype=xp.uint8), xp.asarray(1.0, dtype=xp.uint8)),
(xp.asarray([1, 2], dtype=xp.int32), xp.asarray([1, 2], dtype=xp.int32)),
(
xp.asarray([[1.0], [2.0]], dtype=xp.int32),
xp.asarray([[1.0], [2.0]], dtype=xp.int32),
),
(
{
"a": (
1,
xp.asarray(2.0, dtype=xp.float32),
xp.asarray([3, 4], dtype=xp.int32),
),
"b": {"c": 5},
},
{
"a": (
xp.asarray(1, dtype=atleast_int32(source_xp, target_xp)),
xp.asarray(2.0, dtype=xp.float32),
xp.asarray([3, 4], dtype=xp.int32),
),
"b": {
"c": xp.asarray(5, dtype=atleast_int32(source_xp, target_xp))
},
},
),
(
ExampleNamedTuple(
a=xp.asarray([1, 2], dtype=xp.int32),
b=xp.asarray([1.0, 2.0], dtype=xp.float32),
),
ExampleNamedTuple(
a=xp.asarray([1, 2], dtype=xp.int32),
b=xp.asarray([1.0, 2.0], dtype=xp.float32),
),
),
(None, None),
]:
yield (source_xp, target_xp, value, expected_value)
@pytest.mark.parametrize(
"source_xp,target_xp,value,expected_value", value_parametrization()
)
def test_roundtripping(source_xp, target_xp, value, expected_value):
"""Test roundtripping between different Array API compatible frameworks."""
roundtripped_value = array_conversion(
array_conversion(value, xp=target_xp), xp=source_xp
)
assert xp_data_equivalence(roundtripped_value, expected_value)
@pytest.mark.parametrize("env_xp, target_xp", installed_modules_combinations)
def test_array_conversion_wrapper(env_xp, target_xp):
# Define reset and step functions without partial to avoid pickling issues
def reset_func(self, seed=None, options=None):
"""A generic array API reset function."""
return env_xp.asarray([1.0, 2.0, 3.0]), {"data": env_xp.asarray([1, 2, 3])}
def step_func(self, action):
"""A generic array API step function."""
assert isinstance(action, type(env_xp.zeros(1)))
return (
env_xp.asarray([1, 2, 3]),
env_xp.asarray(5.0),
env_xp.asarray(True),
env_xp.asarray(False),
{"data": env_xp.asarray([1.0, 2.0])},
)
env = GenericTestEnv(reset_func=reset_func, step_func=step_func)
# Check that the reset and step for env_xp environment are as expected
obs, info = env.reset()
# env_xp is automatically converted to the compatible namespace by array_namespace, so we need
# to check against the compatible namespace of env_xp in array_api_compat
env_xp_compat = module_namespace(env_xp)
assert array_namespace(obs) is env_xp_compat
assert isinstance(info, dict) and array_namespace(info["data"]) is env_xp_compat
obs, reward, terminated, truncated, info = env.step(env_xp_compat.asarray([1, 2]))
assert array_namespace(obs) is env_xp_compat
assert array_namespace(reward) is env_xp_compat
assert array_namespace(terminated) is env_xp_compat
assert array_namespace(truncated) is env_xp_compat
assert isinstance(info, dict) and array_namespace(info["data"]) is env_xp_compat
# Check that the wrapped version is correct.
target_xp_compat = module_namespace(target_xp)
wrapped_env = ArrayConversion(env, env_xp=env_xp, target_xp=target_xp)
obs, info = wrapped_env.reset()
assert array_namespace(obs) is target_xp_compat
assert isinstance(info, dict) and array_namespace(info["data"]) is target_xp_compat
action = target_xp.asarray([1, 2], dtype=target_xp.int32)
obs, reward, terminated, truncated, info = wrapped_env.step(action)
assert array_namespace(obs) is target_xp_compat
assert isinstance(reward, float)
assert isinstance(terminated, bool) and isinstance(truncated, bool)
assert isinstance(info, dict) and array_namespace(info["data"]) is target_xp_compat
# Check that the wrapped environment can render. This implicitly returns None and requires a
# None -> None conversion
wrapped_env.render()
# Test that the wrapped environment can be pickled
env = gymnasium.make("CartPole-v1", disable_env_checker=True)
wrapped_env = ArrayConversion(env, env_xp=env_xp, target_xp=target_xp)
pkl = pickle.dumps(wrapped_env)
pickle.loads(pkl)

View File

@@ -1,10 +1,13 @@
"""Test suite for JaxToNumpy wrapper."""
import pickle
from typing import NamedTuple
import numpy as np
import pytest
import gymnasium
jax = pytest.importorskip("jax")
jnp = pytest.importorskip("jax.numpy")
@@ -133,3 +136,9 @@ def test_jax_to_numpy_wrapper():
# Check that the wrapped environment can render. This implicitly returns None and requires a
# None -> None conversion
numpy_env.render()
# Test that the wrapped environment can be pickled
env = gymnasium.make("CartPole-v1", disable_env_checker=True)
wrapped_env = JaxToNumpy(env)
pkl = pickle.dumps(wrapped_env)
pickle.loads(pkl)

View File

@@ -1,9 +1,12 @@
"""Test suite for TorchToJax wrapper."""
import pickle
from typing import NamedTuple
import pytest
import gymnasium
jax = pytest.importorskip("jax")
jnp = pytest.importorskip("jax.numpy")
@@ -148,3 +151,9 @@ def test_jax_to_torch_wrapper():
# Check that the wrapped environment can render. This implicitly returns None and requires a
# None -> None conversion
wrapped_env.render()
# Test that the wrapped environment can be pickled
env = gymnasium.make("CartPole-v1", disable_env_checker=True)
wrapped_env = JaxToTorch(env)
pkl = pickle.dumps(wrapped_env)
pickle.loads(pkl)

View File

@@ -1,10 +1,13 @@
"""Test suite for NumPyToTorch wrapper."""
import pickle
from typing import NamedTuple
import numpy as np
import pytest
import gymnasium
torch = pytest.importorskip("torch")
@@ -128,3 +131,9 @@ def test_numpy_to_torch():
# Check that the wrapped environment can render. This implicitly returns None and requires a
# None -> None conversion
torch_env.render()
# Test that the wrapped environment can be pickled
env = gymnasium.make("CartPole-v1", disable_env_checker=True)
wrapped_env = NumpyToTorch(env)
pkl = pickle.dumps(wrapped_env)
pickle.loads(pkl)

View File

@@ -0,0 +1,146 @@
"""Test suite for vector ArrayConversion wrapper."""
import importlib
import itertools
from functools import partial
import numpy as np
import pytest
from tests.testing_env import GenericTestVectorEnv
array_api_compat = pytest.importorskip("array_api_compat")
from array_api_compat import array_namespace # noqa: E402
from gymnasium.wrappers.array_conversion import module_namespace # noqa: E402
from gymnasium.wrappers.vector.array_conversion import ArrayConversion # noqa: E402
from gymnasium.wrappers.vector.jax_to_numpy import JaxToNumpy # noqa: E402
from gymnasium.wrappers.vector.jax_to_torch import JaxToTorch # noqa: E402
from gymnasium.wrappers.vector.numpy_to_torch import NumpyToTorch # noqa: E402
# Define available modules
installed_modules = []
array_api_modules = [
"numpy",
"jax.numpy",
"torch",
"cupy",
"dask.array",
"sparse",
"array_api_strict",
]
for module in array_api_modules:
try:
installed_modules.append(importlib.import_module(module))
except ImportError:
pass # Modules that are not installed are skipped
installed_modules_combinations = list(itertools.permutations(installed_modules, 2))
def create_vector_env(env_xp):
_reset_func = partial(reset_func, num_envs=3, xp=env_xp)
_step_func = partial(step_func, num_envs=3, xp=env_xp)
return GenericTestVectorEnv(reset_func=_reset_func, step_func=_step_func)
def reset_func(self, seed=None, options=None, num_envs: int = 1, xp=np):
return xp.asarray([[1.0, 2.0, 3.0] * num_envs]), {
"data": xp.asarray([[1, 2, 3] * num_envs])
}
def step_func(self, action, num_envs: int = 1, xp=np):
assert isinstance(action, type(xp.zeros(1)))
return (
xp.asarray([[1, 2, 3] * num_envs]),
xp.asarray([5.0] * num_envs),
xp.asarray([False] * num_envs),
xp.asarray([False] * num_envs),
{"data": xp.asarray([[1.0, 2.0] * num_envs])},
)
@pytest.mark.parametrize("env_xp, target_xp", installed_modules_combinations)
def test_array_conversion_wrapper(env_xp, target_xp):
env_xp_compat = module_namespace(env_xp)
env = create_vector_env(env_xp_compat)
# Check that the reset and step for env_xp environment are as expected
obs, info = env.reset()
# env_xp is automatically converted to the compatible namespace by array_namespace, so we need
# to check against the compatible namespace of env_xp in array_api_compat
assert array_namespace(obs) is env_xp_compat
assert isinstance(info, dict) and array_namespace(info["data"]) is env_xp_compat
obs, reward, terminated, truncated, info = env.step(env_xp_compat.asarray([1, 2]))
assert array_namespace(obs) is env_xp_compat
assert array_namespace(reward) is env_xp_compat
assert array_namespace(terminated) is env_xp_compat
assert array_namespace(truncated) is env_xp_compat
assert isinstance(info, dict) and array_namespace(info["data"]) is env_xp_compat
# Check that the wrapped version is correct.
target_xp_compat = module_namespace(target_xp)
wrapped_env = ArrayConversion(env, env_xp=env_xp, target_xp=target_xp)
obs, info = wrapped_env.reset()
assert array_namespace(obs) is target_xp_compat
assert isinstance(info, dict) and array_namespace(info["data"]) is target_xp_compat
action = target_xp.asarray([1, 2], dtype=target_xp.int32)
obs, reward, terminated, truncated, info = wrapped_env.step(action)
assert array_namespace(obs) is target_xp_compat
assert array_namespace(reward) is target_xp_compat
assert array_namespace(terminated) is target_xp_compat
assert terminated.dtype == target_xp.bool
assert array_namespace(truncated) is target_xp_compat
assert truncated.dtype == target_xp.bool
assert isinstance(info, dict) and array_namespace(info["data"]) is target_xp_compat
# Check that the wrapped environment can render. This implicitly returns None and requires a
# None -> None conversion
wrapped_env.render()
@pytest.mark.parametrize("wrapper", [JaxToNumpy, JaxToTorch, NumpyToTorch])
def test_specialized_wrappers(wrapper: type[JaxToNumpy | JaxToTorch | NumpyToTorch]):
if wrapper is JaxToNumpy:
jax = pytest.importorskip("jax")
env_xp, target_xp = jax.numpy, np
elif wrapper is JaxToTorch:
jax = pytest.importorskip("jax")
torch = pytest.importorskip("torch")
env_xp, target_xp = jax.numpy, torch
elif wrapper is NumpyToTorch:
torch = pytest.importorskip("torch")
env_xp, target_xp = np, torch
else:
raise TypeError(f"Unknown specialized conversion wrapper {type(wrapper)}")
env_xp_compat = module_namespace(env_xp)
target_xp_compat = module_namespace(target_xp)
# The unwrapped test env sanity check is already covered by test_array_conversion_wrapper for
# all known frameworks, including the specialized ones.
env = create_vector_env(env_xp_compat)
# Check that the wrapped version is correct.
wrapped_env = wrapper(env)
obs, info = wrapped_env.reset()
assert array_namespace(obs) is target_xp_compat
assert isinstance(info, dict) and array_namespace(info["data"]) is target_xp_compat
action = target_xp.asarray([1, 2], dtype=target_xp.int32)
obs, reward, terminated, truncated, info = wrapped_env.step(action)
assert array_namespace(obs) is target_xp_compat
assert array_namespace(reward) is target_xp_compat
assert array_namespace(terminated) is target_xp_compat
assert terminated.dtype == target_xp.bool
assert array_namespace(truncated) is target_xp_compat
assert truncated.dtype == target_xp.bool
assert isinstance(info, dict) and array_namespace(info["data"]) is target_xp_compat
# Check that the wrapped environment can render. This implicitly returns None and requires a
# None -> None conversion
wrapped_env.render()