diff --git a/.github/workflows/docs-build-dev.yml b/.github/workflows/docs-build-dev.yml index e574547da..43621d991 100644 --- a/.github/workflows/docs-build-dev.yml +++ b/.github/workflows/docs-build-dev.yml @@ -19,7 +19,7 @@ jobs: - uses: actions/setup-python@v5 with: - python-version: '3.9' + python-version: '3.10' - name: Install dependencies run: pip install -r docs/requirements.txt diff --git a/.github/workflows/docs-build-release.yml b/.github/workflows/docs-build-release.yml index aff91478d..6d674da67 100644 --- a/.github/workflows/docs-build-release.yml +++ b/.github/workflows/docs-build-release.yml @@ -19,7 +19,7 @@ jobs: - uses: actions/setup-python@v5 with: - python-version: '3.9' + python-version: '3.10' - name: Install dependencies run: pip install -r docs/requirements.txt diff --git a/.github/workflows/docs-manual-build.yml b/.github/workflows/docs-manual-build.yml index a02b71077..923d85979 100644 --- a/.github/workflows/docs-manual-build.yml +++ b/.github/workflows/docs-manual-build.yml @@ -33,7 +33,7 @@ jobs: - uses: actions/setup-python@v5 with: - python-version: '3.9' + python-version: '3.10' - name: Install dependencies run: pip install -r docs/requirements.txt diff --git a/.github/workflows/run-pytest.yml b/.github/workflows/run-pytest.yml index 8d228776c..8d3535b75 100644 --- a/.github/workflows/run-pytest.yml +++ b/.github/workflows/run-pytest.yml @@ -10,8 +10,8 @@ jobs: strategy: fail-fast: true matrix: - python-version: ['3.9', '3.10', '3.11', '3.12'] - numpy-version: ['>=1.21,<2.0', '>=2.0'] + python-version: ['3.10', '3.11', '3.12'] + numpy-version: ['>=1.21,<2.0', '>=2.1'] steps: - uses: actions/checkout@v4 - run: | @@ -22,7 +22,6 @@ jobs: - name: Run tests run: docker run gymnasium-all-docker pytest tests/* - name: Run doctests - if: ${{ matrix.python-version != '3.8' }} run: docker run gymnasium-all-docker pytest --doctest-modules gymnasium/ build-necessary: diff --git a/.github/workflows/run-tutorial.yml 
b/.github/workflows/run-tutorial.yml index 953f2d104..4f5c1712c 100644 --- a/.github/workflows/run-tutorial.yml +++ b/.github/workflows/run-tutorial.yml @@ -20,7 +20,7 @@ jobs: strategy: fail-fast: false # This ensures all matrix combinations run even if one fails matrix: - python-version: ["3.9"] + python-version: ["3.10"] tutorial-group: - gymnasium_basics - training_agents diff --git a/README.md b/README.md index 1d914c1c5..89a4090db 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@ To install the base Gymnasium library, use `pip install gymnasium` This does not include dependencies for all families of environments (there's a massive number, and some can be problematic to install on certain systems). You can install these dependencies for one family like `pip install "gymnasium[atari]"` or use `pip install "gymnasium[all]"` to install all dependencies. -We support and test for Python 3.8, 3.9, 3.10, 3.11 and 3.12 on Linux and macOS. We will accept PRs related to Windows, but do not officially support it. +We support and test for Python 3.10, 3.11 and 3.12 on Linux and macOS. We will accept PRs related to Windows, but do not officially support it. ## API diff --git a/docs/api/wrappers/misc_wrappers.md b/docs/api/wrappers/misc_wrappers.md index 25deaafda..88ac581bd 100644 --- a/docs/api/wrappers/misc_wrappers.md +++ b/docs/api/wrappers/misc_wrappers.md @@ -27,6 +27,7 @@ title: Misc Wrappers ## Data Conversion Wrappers ```{eval-rst} +.. autoclass:: gymnasium.wrappers.ArrayConversion .. autoclass:: gymnasium.wrappers.JaxToNumpy .. autoclass:: gymnasium.wrappers.JaxToTorch .. autoclass:: gymnasium.wrappers.NumpyToTorch diff --git a/docs/api/wrappers/table.md b/docs/api/wrappers/table.md index f78a6f9ea..89cb9acb9 100644 --- a/docs/api/wrappers/table.md +++ b/docs/api/wrappers/table.md @@ -34,6 +34,8 @@ wrapper in the page on the wrapper type - Converts an image observation computed by ``reset`` and ``step`` from RGB to Grayscale. 
* - :class:`HumanRendering` - Allows human like rendering for environments that support "rgb_array" rendering. + * - :class:`ArrayConversion` + - Wraps an environment based on any Array API compatible framework, e.g. torch, jax, numpy, such that it can be interacted with any other Array API compatible framework. * - :class:`JaxToNumpy` - Wraps a Jax-based environment such that it can be interacted with NumPy arrays. * - :class:`JaxToTorch` diff --git a/gymnasium/wrappers/array_conversion.py b/gymnasium/wrappers/array_conversion.py new file mode 100644 index 000000000..f24842d8e --- /dev/null +++ b/gymnasium/wrappers/array_conversion.py @@ -0,0 +1,261 @@ +# This wrapper will convert array inputs from an Array API compatible framework A for the actions +# to any other Array API compatible framework B for an underlying environment that is implemented +# in framework B, then convert the return observations from framework B back to framework A. +# +# More precisely, the wrapper will work for any two frameworks that can be made compatible with the +# `array-api-compat` package. +# +# See https://data-apis.org/array-api/latest/ for more information on the Array API standard, and +# https://data-apis.org/array-api-compat/ for more information on the Array API compatibility layer. +# +# General structure for converting between types originally copied from +# https://github.com/google/brax/blob/9d6b7ced2a13da0d074b5e9fbd3aad8311e26997/brax/io/torch.py +# Under the Apache 2.0 license. 
Copyright is held by the authors + +"""Helper functions and wrapper class for converting between arbitrary Array API compatible frameworks and a target framework.""" + +from __future__ import annotations + +import functools +import importlib +import numbers +from collections import abc +from types import ModuleType, NoneType +from typing import Any, Iterable, Mapping, SupportsFloat + +import numpy as np +from packaging.version import Version + +import gymnasium as gym +from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType +from gymnasium.error import DependencyNotInstalled + + +try: + from array_api_compat import array_namespace, is_array_api_obj, to_device + +except ImportError: + raise DependencyNotInstalled( + 'Array API packages are not installed therefore cannot call `array_conversion`, run `pip install "gymnasium[array-api]"`' + ) + + +if Version(np.__version__) < Version("2.1.0"): + raise DependencyNotInstalled("Array API functionality requires numpy >= 2.1.0") + + +__all__ = ["ArrayConversion", "array_conversion"] + +Array = Any # TODO: Switch to ArrayAPI type once https://github.com/data-apis/array-api/pull/589 is merged +Device = Any # TODO: Switch to ArrayAPI type if available + + +def module_namespace(xp: ModuleType) -> ModuleType: + """Determine the Array API compatible namespace of the given module. + + This function is closely linked to the `array_api_compat.array_namespace` function. It returns + the compatible namespace for a module directly instead of from an array object of that module. 
+ + See https://data-apis.org/array-api-compat/helper-functions.html#array_api_compat.array_namespace + """ + try: + return array_namespace(xp.empty(0)) + except AttributeError as e: + raise ValueError(f"Module {xp} is not an Array API compatible module.") from e + + +def module_name_to_namespace(name: str) -> ModuleType: + return module_namespace(importlib.import_module(name)) + + +@functools.singledispatch +def array_conversion(value: Any, xp: ModuleType, device: Device | None = None) -> Any: + """Convert a value into the specified xp module array type.""" + raise Exception( + f"No known conversion for ({type(value)}) to xp module ({xp}) registered. Report as issue on github." + ) + + +@array_conversion.register(numbers.Number) +def _number_array_conversion( + value: numbers.Number, xp: ModuleType, device: Device | None = None +) -> Array: + """Convert a python number (int, float, complex) to an Array API framework array.""" + return xp.asarray(value, device=device) + + +@array_conversion.register(abc.Mapping) +def _mapping_array_conversion( + value: Mapping[str, Any], xp: ModuleType, device: Device | None = None +) -> Mapping[str, Any]: + """Convert a mapping of Arrays into a Dictionary of the specified xp module array type.""" + return type(value)(**{k: array_conversion(v, xp, device) for k, v in value.items()}) + + +@array_conversion.register(abc.Iterable) +def _iterable_array_conversion( + value: Iterable[Any], xp: ModuleType, device: Device | None = None +) -> Iterable[Any]: + """Convert an Iterable from Arrays to an iterable of the specified xp module array type.""" + # There is currently no type for ArrayAPI compatible objects, so they fall through to this + # function registered for any Iterable. If they are arrays, we can convert them directly. + # We currently cannot pass the device to the from_dlpack function, since it is not supported + # for some frameworks (see e.g. 
https://github.com/data-apis/array-api-compat/issues/204) + if is_array_api_obj(value): + return _array_api_array_conversion(value, xp, device) + if hasattr(value, "_make"): + # namedtuple - underline used to prevent potential name conflicts + # noinspection PyProtectedMember + return type(value)._make(array_conversion(v, xp, device) for v in value) + return type(value)(array_conversion(v, xp, device) for v in value) + + +def _array_api_array_conversion( + value: Array, xp: ModuleType, device: Device | None = None +) -> Array: + """Convert an Array API compatible array to the specified xp module array type.""" + try: + x = xp.from_dlpack(value) + return to_device(x, device) if device is not None else x + except (RuntimeError, BufferError): + # If dlpack fails (e.g. because the array is read-only for frameworks that do not + # support it), we create a copy of the array that we own and then convert it. + # TODO: The correct treatment of read-only arrays is currently not fully clear in the + # Array API. Once ongoing discussions are resolved, we should update this code to remove + # any fallbacks. + value_namespace = array_namespace(value) + value_copy = value_namespace.asarray(value, copy=True) + return xp.asarray(value_copy, device=device) + + +@array_conversion.register(NoneType) +def _none_array_conversion( + value: None, xp: ModuleType, device: Device | None = None +) -> None: + """Passes through None values.""" + return value + + +class ArrayConversion(gym.Wrapper, gym.utils.RecordConstructorArgs): + """Wraps an Array API compatible environment so that it can be interacted with a specific Array API framework. + + Actions must be provided as Array API compatible arrays and observations will be returned as Arrays of the specified xp module. + A vector version of the wrapper exists, :class:`gymnasium.wrappers.vector.ArrayConversion`. 
+ + Example: + >>> import torch # doctest: +SKIP + >>> import jax.numpy as jnp # doctest: +SKIP + >>> import gymnasium as gym # doctest: +SKIP + >>> env = gym.make("JaxEnv-vx") # doctest: +SKIP + >>> env = ArrayConversion(env, env_xp=jnp, target_xp=torch) # doctest: +SKIP + >>> obs, _ = env.reset(seed=123) # doctest: +SKIP + >>> type(obs) # doctest: +SKIP + <class 'torch.Tensor'> + >>> action = torch.tensor(env.action_space.sample()) # doctest: +SKIP + >>> obs, reward, terminated, truncated, info = env.step(action) # doctest: +SKIP + >>> type(obs) # doctest: +SKIP + <class 'torch.Tensor'> + >>> type(reward) # doctest: +SKIP + <class 'float'> + >>> type(terminated) # doctest: +SKIP + <class 'bool'> + >>> type(truncated) # doctest: +SKIP + <class 'bool'> + + Change logs: + * v1.2.0 - Initially added + """ + + def __init__( + self, + env: gym.Env, + env_xp: ModuleType, + target_xp: ModuleType, + env_device: Device | None = None, + target_device: Device | None = None, + ): + """Wrapper class to change inputs and outputs of environment to any Array API framework. + + Args: + env: The Array API compatible environment to wrap + env_xp: The Array API framework the environment is on + target_xp: The Array API framework to convert to + env_device: The device the environment is on + target_device: The device on which Arrays should be returned + """ + gym.utils.RecordConstructorArgs.__init__(self) + gym.Wrapper.__init__(self, env) + + self._env_xp = module_namespace(env_xp) + self._target_xp = module_namespace(target_xp) + self._env_device: Device | None = env_device + self._target_device: Device | None = target_device + + def step( + self, action: WrapperActType + ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]: + """Performs the given action within the environment. 
+ + Args: + action: The action to perform as any Array API compatible array + + Returns: + The next observation, reward, termination, truncation, and extra info + """ + action = array_conversion(action, xp=self._env_xp, device=self._env_device) + obs, reward, terminated, truncated, info = self.env.step(action) + + return ( + array_conversion(obs, xp=self._target_xp, device=self._target_device), + float(reward), + bool(terminated), + bool(truncated), + array_conversion(info, xp=self._target_xp, device=self._target_device), + ) + + def reset( + self, *, seed: int | None = None, options: dict[str, Any] | None = None + ) -> tuple[WrapperObsType, dict[str, Any]]: + """Resets the environment returning observation and info as Array from any Array API compatible framework. + + Args: + seed: The seed for resetting the environment + options: The options for resetting the environment, these are converted to arrays of the environment's framework (``env_xp``) on ``env_device``. + + Returns: + xp-based observations and info + """ + if options: + options = array_conversion(options, self._env_xp, self._env_device) + + return array_conversion( + self.env.reset(seed=seed, options=options), + self._target_xp, + self._target_device, + ) + + def render(self) -> RenderFrame | list[RenderFrame] | None: + """Returns the rendered frames as an xp Array.""" + return array_conversion(self.env.render(), self._target_xp, self._target_device) + + def __getstate__(self): + """Returns the object pickle state with args and kwargs.""" + env_xp_name = self._env_xp.__name__.replace("array_api_compat.", "") + target_xp_name = self._target_xp.__name__.replace("array_api_compat.", "") + env_device = self._env_device + target_device = self._target_device + return { + "env_xp_name": env_xp_name, + "target_xp_name": target_xp_name, + "env_device": env_device, + "target_device": target_device, + "env": self.env, + } + + def __setstate__(self, d): + """Sets the object pickle state using d.""" + self.env = d["env"] + self._env_xp = 
module_name_to_namespace(d["env_xp_name"]) + self._target_xp = module_name_to_namespace(d["target_xp_name"]) + self._env_device = d["env_device"] + self._target_device = d["target_device"] diff --git a/gymnasium/wrappers/jax_to_numpy.py b/gymnasium/wrappers/jax_to_numpy.py index a331498b6..aa942f6e8 100644 --- a/gymnasium/wrappers/jax_to_numpy.py +++ b/gymnasium/wrappers/jax_to_numpy.py @@ -3,19 +3,20 @@ from __future__ import annotations import functools -import numbers -from collections import abc -from typing import Any, Iterable, Mapping, SupportsFloat import numpy as np import gymnasium as gym -from gymnasium.core import ActType, ObsType, RenderFrame, WrapperActType, WrapperObsType +from gymnasium.core import ActType, ObsType from gymnasium.error import DependencyNotInstalled +from gymnasium.wrappers.array_conversion import ( + ArrayConversion, + array_conversion, + module_namespace, +) try: - import jax import jax.numpy as jnp except ImportError: raise DependencyNotInstalled( @@ -24,110 +25,13 @@ except ImportError: __all__ = ["JaxToNumpy", "jax_to_numpy", "numpy_to_jax"] -# The NoneType is not defined in Python 3.9. Remove when the minimal version is bumped to >=3.10 -_NoneType = type(None) + +jax_to_numpy = functools.partial(array_conversion, xp=module_namespace(np)) + +numpy_to_jax = functools.partial(array_conversion, xp=module_namespace(jnp)) -@functools.singledispatch -def numpy_to_jax(value: Any) -> Any: - """Converts a value to a Jax Array.""" - raise Exception( - f"No known conversion for Numpy type ({type(value)}) to Jax registered. Report as issue on github." - ) - - -@numpy_to_jax.register(numbers.Number) -def _number_to_jax( - value: numbers.Number, -) -> jax.Array: - """Converts a number (int, float, etc.) 
to a Jax Array.""" - assert jnp is not None - return jnp.array(value) - - -@numpy_to_jax.register(np.ndarray) -def _numpy_array_to_jax(value: np.ndarray) -> jax.Array: - """Converts a NumPy Array to a Jax Array with the same dtype (excluding float64 without being enabled).""" - assert jnp is not None - return jnp.array(value, dtype=value.dtype) - - -@numpy_to_jax.register(abc.Mapping) -def _mapping_numpy_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]: - """Converts a dictionary of numpy arrays to a mapping of Jax Array.""" - return type(value)(**{k: numpy_to_jax(v) for k, v in value.items()}) - - -@numpy_to_jax.register(abc.Iterable) -def _iterable_numpy_to_jax( - value: Iterable[np.ndarray | Any], -) -> Iterable[jax.Array | Any]: - """Converts an Iterable from Numpy Arrays to an iterable of Jax Array.""" - if hasattr(value, "_make"): - # namedtuple - underline used to prevent potential name conflicts - # noinspection PyProtectedMember - return type(value)._make(numpy_to_jax(v) for v in value) - else: - return type(value)(numpy_to_jax(v) for v in value) - - -@numpy_to_jax.register(_NoneType) -def _none_numpy_to_jax(value: None) -> None: - """Passes through None values.""" - return value - - -@functools.singledispatch -def jax_to_numpy(value: Any) -> Any: - """Converts a value to a numpy array.""" - raise Exception( - f"No known conversion for Jax type ({type(value)}) to NumPy registered. Report as issue on github." 
- ) - - -@jax_to_numpy.register(jax.Array) -def _devicearray_jax_to_numpy(value: jax.Array) -> np.ndarray: - """Converts a Jax Array to a numpy array.""" - return np.array(value) - - -@jax_to_numpy.register(abc.Mapping) -def _mapping_jax_to_numpy( - value: Mapping[str, jax.Array | Any], -) -> Mapping[str, np.ndarray | Any]: - """Converts a dictionary of Jax Array to a mapping of numpy arrays.""" - return type(value)(**{k: jax_to_numpy(v) for k, v in value.items()}) - - -@jax_to_numpy.register(abc.Iterable) -def _iterable_jax_to_numpy( - value: Iterable[np.ndarray | Any], -) -> Iterable[jax.Array | Any]: - """Converts an Iterable from Numpy arrays to an iterable of Jax Array.""" - if isinstance(value, jax.Array): - # Since the update to jax 0.6.0, calling jax_to_numpy with a - # argument wrongly dispatches to _iterable_jax_to_numpy which fails with: - # TypeError: (): incompatible function arguments. - # See: https://github.com/Farama-Foundation/Gymnasium/issues/1360 - return _devicearray_jax_to_numpy(value) - elif hasattr(value, "_make"): - # namedtuple - underline used to prevent potential name conflicts - # noinspection PyProtectedMember - return type(value)._make(jax_to_numpy(v) for v in value) - else: - return type(value)(jax_to_numpy(v) for v in value) - - -@jax_to_numpy.register(_NoneType) -def _none_jax_to_numpy(value: None) -> None: - """Passes through None values.""" - return value - - -class JaxToNumpy( - gym.Wrapper[WrapperObsType, WrapperActType, ObsType, ActType], - gym.utils.RecordConstructorArgs, -): +class JaxToNumpy(ArrayConversion): """Wraps a Jax-based environment such that it can be interacted with NumPy arrays. Actions must be provided as numpy arrays and observations will be returned as numpy arrays. 
@@ -169,48 +73,4 @@ class JaxToNumpy( raise DependencyNotInstalled( 'Jax is not installed, run `pip install "gymnasium[jax]"`' ) - gym.utils.RecordConstructorArgs.__init__(self) - gym.Wrapper.__init__(self, env) - - def step( - self, action: WrapperActType - ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]: - """Transforms the action to a jax array . - - Args: - action: the action to perform as a numpy array - - Returns: - A tuple containing numpy versions of the next observation, reward, termination, truncation, and extra info. - """ - jax_action = numpy_to_jax(action) - obs, reward, terminated, truncated, info = self.env.step(jax_action) - - return ( - jax_to_numpy(obs), - float(reward), - bool(terminated), - bool(truncated), - jax_to_numpy(info), - ) - - def reset( - self, *, seed: int | None = None, options: dict[str, Any] | None = None - ) -> tuple[WrapperObsType, dict[str, Any]]: - """Resets the environment returning numpy-based observation and info. - - Args: - seed: The seed for resetting the environment - options: The options for resetting the environment, these are converted to jax arrays. - - Returns: - Numpy-based observations and info - """ - if options: - options = numpy_to_jax(options) - - return jax_to_numpy(self.env.reset(seed=seed, options=options)) - - def render(self) -> RenderFrame | list[RenderFrame] | None: - """Returns the rendered frames as a numpy array.""" - return jax_to_numpy(self.env.render()) + super().__init__(env=env, env_xp=jnp, target_xp=np) diff --git a/gymnasium/wrappers/jax_to_torch.py b/gymnasium/wrappers/jax_to_torch.py index 63e906689..45aa33673 100644 --- a/gymnasium/wrappers/jax_to_torch.py +++ b/gymnasium/wrappers/jax_to_torch.py @@ -7,22 +7,23 @@ # Under the Apache 2.0 license. 
Copyright is held by the authors """Helper functions and wrapper class for converting between PyTorch and Jax.""" + from __future__ import annotations import functools -import numbers -from collections import abc -from typing import Any, Iterable, Mapping, SupportsFloat, Union +from typing import Union import gymnasium as gym -from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType from gymnasium.error import DependencyNotInstalled +from gymnasium.wrappers.array_conversion import ( + ArrayConversion, + array_conversion, + module_namespace, +) try: - import jax import jax.numpy as jnp - from jax import dlpack as jax_dlpack except ImportError: raise DependencyNotInstalled( @@ -31,7 +32,6 @@ except ImportError: try: import torch - from torch.utils import dlpack as torch_dlpack Device = Union[str, torch.device] except ImportError: @@ -42,109 +42,13 @@ except ImportError: __all__ = ["JaxToTorch", "jax_to_torch", "torch_to_jax", "Device"] -# The NoneType is not defined in Python 3.9. Remove when the minimal version is bumped to >=3.10 -_NoneType = type(None) + +torch_to_jax = functools.partial(array_conversion, xp=module_namespace(jnp)) + +jax_to_torch = functools.partial(array_conversion, xp=module_namespace(torch)) -@functools.singledispatch -def torch_to_jax(value: Any) -> Any: - """Converts a PyTorch Tensor into a Jax Array.""" - raise Exception( - f"No known conversion for Torch type ({type(value)}) to Jax registered. Report as issue on github." 
- ) - - -@torch_to_jax.register(numbers.Number) -def _number_torch_to_jax(value: numbers.Number) -> Any: - """Convert a python number (int, float, complex) to a jax array.""" - return jnp.array(value) - - -@torch_to_jax.register(torch.Tensor) -def _tensor_torch_to_jax(value: torch.Tensor) -> jax.Array: - """Converts a PyTorch Tensor into a Jax Array.""" - return jax_dlpack.from_dlpack(value) # pyright: ignore[reportPrivateImportUsage] - - -@torch_to_jax.register(abc.Mapping) -def _mapping_torch_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]: - """Converts a mapping of PyTorch Tensors into a Dictionary of Jax Array.""" - return type(value)(**{k: torch_to_jax(v) for k, v in value.items()}) - - -@torch_to_jax.register(abc.Iterable) -def _iterable_torch_to_jax(value: Iterable[Any]) -> Iterable[Any]: - """Converts an Iterable from PyTorch Tensors to an iterable of Jax Array.""" - if hasattr(value, "_make"): - # namedtuple - underline used to prevent potential name conflicts - # noinspection PyProtectedMember - return type(value)._make(torch_to_jax(v) for v in value) - else: - return type(value)(torch_to_jax(v) for v in value) - - -@torch_to_jax.register(_NoneType) -def _none_torch_to_jax(value: None) -> None: - """Passes through None values.""" - return value - - -@functools.singledispatch -def jax_to_torch(value: Any, device: Device | None = None) -> Any: - """Converts a Jax Array into a PyTorch Tensor.""" - raise Exception( - f"No known conversion for Jax type ({type(value)}) to PyTorch registered. Report as issue on github." 
- ) - - -@jax_to_torch.register(jax.Array) -def _devicearray_jax_to_torch( - value: jax.Array, device: Device | None = None -) -> torch.Tensor: - """Converts a Jax Array into a PyTorch Tensor.""" - assert jax_dlpack is not None and torch_dlpack is not None - tensor = torch_dlpack.from_dlpack( - value - ) # pyright: ignore[reportPrivateImportUsage] - if device: - return tensor.to(device=device) - return tensor - - -@jax_to_torch.register(abc.Mapping) -def _jax_mapping_to_torch( - value: Mapping[str, Any], device: Device | None = None -) -> Mapping[str, Any]: - """Converts a mapping of Jax Array into a Dictionary of PyTorch Tensors.""" - return type(value)(**{k: jax_to_torch(v, device) for k, v in value.items()}) - - -@jax_to_torch.register(abc.Iterable) -def _jax_iterable_to_torch( - value: Iterable[Any], device: Device | None = None -) -> Iterable[Any]: - """Converts an Iterable from Jax Array to an iterable of PyTorch Tensors.""" - if isinstance(value, jax.Array): - # Since the update to jax 0.6.0, calling jax_to_torch with a - # argument wrongly dispatches to _iterable_jax_to_torch which fails with: - # TypeError: (): incompatible function arguments. - # See: https://github.com/Farama-Foundation/Gymnasium/issues/1360 - return _devicearray_jax_to_torch(value) - elif hasattr(value, "_make"): - # namedtuple - underline used to prevent potential name conflicts - # noinspection PyProtectedMember - return type(value)._make(jax_to_torch(v, device) for v in value) - else: - return type(value)(jax_to_torch(v, device) for v in value) - - -@jax_to_torch.register(_NoneType) -def _none_jax_to_torch(value: None, device: Device | None = None) -> None: - """Passes through None values.""" - return value - - -class JaxToTorch(gym.Wrapper, gym.utils.RecordConstructorArgs): +class JaxToTorch(ArrayConversion): """Wraps a Jax-based environment so that it can be interacted with PyTorch Tensors. 
Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors. @@ -183,50 +87,8 @@ class JaxToTorch(gym.Wrapper, gym.utils.RecordConstructorArgs): env: The Jax-based environment to wrap device: The device the torch Tensors should be moved to """ - gym.utils.RecordConstructorArgs.__init__(self, device=device) - gym.Wrapper.__init__(self, env) + super().__init__(env=env, env_xp=jnp, target_xp=torch, target_device=device) + # TODO: Device was part of the public API, but should be removed in favor of _env_device and + # _target_device. self.device: Device | None = device - - def step( - self, action: WrapperActType - ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]: - """Performs the given action within the environment. - - Args: - action: The action to perform as a PyTorch Tensor - - Returns: - The next observation, reward, termination, truncation, and extra info - """ - jax_action = torch_to_jax(action) - obs, reward, terminated, truncated, info = self.env.step(jax_action) - - return ( - jax_to_torch(obs, self.device), - float(reward), - bool(terminated), - bool(truncated), - jax_to_torch(info, self.device), - ) - - def reset( - self, *, seed: int | None = None, options: dict[str, Any] | None = None - ) -> tuple[WrapperObsType, dict[str, Any]]: - """Resets the environment returning PyTorch-based observation and info. - - Args: - seed: The seed for resetting the environment - options: The options for resetting the environment, these are converted to jax arrays. 
- - Returns: - PyTorch-based observations and info - """ - if options: - options = torch_to_jax(options) - - return jax_to_torch(self.env.reset(seed=seed, options=options), self.device) - - def render(self) -> RenderFrame | list[RenderFrame] | None: - """Returns the rendered frames as a torch tensor.""" - return jax_to_torch(self.env.render()) diff --git a/gymnasium/wrappers/numpy_to_torch.py b/gymnasium/wrappers/numpy_to_torch.py index db3870611..4fb95037e 100644 --- a/gymnasium/wrappers/numpy_to_torch.py +++ b/gymnasium/wrappers/numpy_to_torch.py @@ -3,15 +3,17 @@ from __future__ import annotations import functools -import numbers -from collections import abc -from typing import Any, Iterable, Mapping, SupportsFloat, Union +from typing import Union import numpy as np import gymnasium as gym -from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType from gymnasium.error import DependencyNotInstalled +from gymnasium.wrappers.array_conversion import ( + ArrayConversion, + array_conversion, + module_namespace, +) try: @@ -26,100 +28,13 @@ except ImportError: __all__ = ["NumpyToTorch", "torch_to_numpy", "numpy_to_torch", "Device"] -# The NoneType is not defined in Python 3.9. Remove when the minimal version is bumped to >=3.10 -_NoneType = type(None) + +torch_to_numpy = functools.partial(array_conversion, xp=module_namespace(np)) + +numpy_to_torch = functools.partial(array_conversion, xp=module_namespace(torch)) -@functools.singledispatch -def torch_to_numpy(value: Any) -> Any: - """Converts a PyTorch Tensor into a NumPy Array.""" - raise Exception( - f"No known conversion for Torch type ({type(value)}) to NumPy registered. Report as issue on github." 
- ) - - -@torch_to_numpy.register(numbers.Number) -def _number_to_numpy(value: numbers.Number) -> Any: - """Convert a python number (int, float, complex) to a NumPy array.""" - return np.array(value) - - -@torch_to_numpy.register(torch.Tensor) -def _torch_to_numpy(value: torch.Tensor) -> Any: - """Convert a torch.Tensor to a NumPy array.""" - return value.numpy(force=True) - - -@torch_to_numpy.register(abc.Mapping) -def _mapping_torch_to_numpy(value: Mapping[str, Any]) -> Mapping[str, Any]: - """Converts a mapping of PyTorch Tensors into a Dictionary of NumPy Array.""" - return type(value)(**{k: torch_to_numpy(v) for k, v in value.items()}) - - -@torch_to_numpy.register(abc.Iterable) -def _iterable_torch_to_numpy(value: Iterable[Any]) -> Iterable[Any]: - """Converts an Iterable from PyTorch Tensors to an iterable of NumPy Array.""" - if hasattr(value, "_make"): - # namedtuple - underline used to prevent potential name conflicts - # noinspection PyProtectedMember - return type(value)._make(torch_to_numpy(v) for v in value) - else: - return type(value)(torch_to_numpy(v) for v in value) - - -@torch_to_numpy.register(_NoneType) -def _none_torch_to_numpy(value: None) -> None: - """Passes through None values.""" - return value - - -@functools.singledispatch -def numpy_to_torch(value: Any, device: Device | None = None) -> Any: - """Converts a NumPy Array into a PyTorch Tensor.""" - raise Exception( - f"No known conversion for NumPy type ({type(value)}) to PyTorch registered. Report as issue on github." 
- ) - - -@numpy_to_torch.register(numbers.Number) -@numpy_to_torch.register(np.ndarray) -def _numpy_to_torch(value: np.ndarray, device: Device | None = None) -> torch.Tensor: - """Converts a NumPy Array into a PyTorch Tensor.""" - assert torch is not None - tensor = torch.tensor(value) - if device: - return tensor.to(device=device) - return tensor - - -@numpy_to_torch.register(abc.Mapping) -def _numpy_mapping_to_torch( - value: Mapping[str, Any], device: Device | None = None -) -> Mapping[str, Any]: - """Converts a mapping of NumPy Array into a Dictionary of PyTorch Tensors.""" - return type(value)(**{k: numpy_to_torch(v, device) for k, v in value.items()}) - - -@numpy_to_torch.register(abc.Iterable) -def _numpy_iterable_to_torch( - value: Iterable[Any], device: Device | None = None -) -> Iterable[Any]: - """Converts an Iterable from NumPy Array to an iterable of PyTorch Tensors.""" - if hasattr(value, "_make"): - # namedtuple - underline used to prevent potential name conflicts - # noinspection PyProtectedMember - return type(value)._make(numpy_to_torch(v, device) for v in value) - else: - return type(value)(numpy_to_torch(v, device) for v in value) - - -@numpy_to_torch.register(_NoneType) -def _none_numpy_to_torch(value: None) -> None: - """Passes through None values.""" - return value - - -class NumpyToTorch(gym.Wrapper, gym.utils.RecordConstructorArgs): +class NumpyToTorch(ArrayConversion): """Wraps a NumPy-based environment such that it can be interacted with PyTorch Tensors. Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors. 
@@ -158,50 +73,6 @@ class NumpyToTorch(gym.Wrapper, gym.utils.RecordConstructorArgs): env: The NumPy-based environment to wrap device: The device the torch Tensors should be moved to """ - gym.utils.RecordConstructorArgs.__init__(self, device=device) - gym.Wrapper.__init__(self, env) + super().__init__(env=env, env_xp=np, target_xp=torch, target_device=device) self.device: Device | None = device - - def step( - self, action: WrapperActType - ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]: - """Using a PyTorch based action that is converted to NumPy to be used by the environment. - - Args: - action: A PyTorch-based action - - Returns: - The PyTorch-based Tensor next observation, reward, termination, truncation, and extra info - """ - jax_action = torch_to_numpy(action) - obs, reward, terminated, truncated, info = self.env.step(jax_action) - - return ( - numpy_to_torch(obs, self.device), - float(reward), - bool(terminated), - bool(truncated), - numpy_to_torch(info, self.device), - ) - - def reset( - self, *, seed: int | None = None, options: dict[str, Any] | None = None - ) -> tuple[WrapperObsType, dict[str, Any]]: - """Resets the environment returning PyTorch-based observation and info. - - Args: - seed: The seed for resetting the environment - options: The options for resetting the environment, these are converted to jax arrays. 
- - Returns: - PyTorch-based observations and info - """ - if options: - options = torch_to_numpy(options) - - return numpy_to_torch(self.env.reset(seed=seed, options=options), self.device) - - def render(self) -> RenderFrame | list[RenderFrame] | None: - """Returns the rendered frames as a torch tensor.""" - return numpy_to_torch(self.env.render()) diff --git a/gymnasium/wrappers/vector/array_conversion.py b/gymnasium/wrappers/vector/array_conversion.py new file mode 100644 index 000000000..a48c8a107 --- /dev/null +++ b/gymnasium/wrappers/vector/array_conversion.py @@ -0,0 +1,131 @@ +"""Vector wrapper for converting between Array API compatible frameworks.""" + +from __future__ import annotations + +from types import ModuleType +from typing import Any + +import gymnasium as gym +from gymnasium.core import ActType, ObsType +from gymnasium.vector import VectorEnv, VectorWrapper +from gymnasium.vector.vector_env import ArrayType +from gymnasium.wrappers.array_conversion import ( + Device, + array_conversion, + module_name_to_namespace, +) + + +__all__ = ["ArrayConversion"] + + +class ArrayConversion(VectorWrapper, gym.utils.RecordConstructorArgs): + """Wraps a vector environment returning Array API compatible arrays so that it can be interacted with through a specific framework. + + Notes: + A vectorized version of ``gymnasium.wrappers.ArrayConversion`` + + Actions must be provided as Array API compatible arrays and observations, rewards, terminations and truncations will be returned in the desired framework. + xp here is a module that is compatible with the Array API standard, e.g. ``numpy``, ``jax`` etc. 
+ + Example: + >>> import gymnasium as gym # doctest: +SKIP + >>> envs = gym.make_vec("JaxEnv-vx", 3) # doctest: +SKIP + >>> envs = ArrayConversion(envs, xp=np) # doctest: +SKIP + """ + + def __init__( + self, + env: VectorEnv, + env_xp: ModuleType | str, + target_xp: ModuleType | str, + env_device: Device | None = None, + target_device: Device | None = None, + ): + """Wrapper class to change inputs and outputs of environment to any Array API framework. + + Args: + env: The Array API compatible environment to wrap + env_xp: The Array API framework the environment is on + target_xp: The Array API framework to convert to + env_device: The device the environment is on + target_device: The device on which Arrays should be returned + """ + gym.utils.RecordConstructorArgs.__init__(self) + VectorWrapper.__init__(self, env) + self._env_xp = env_xp + self._target_xp = target_xp + self._env_device = env_device + self._target_device = target_device + + def step( + self, actions: ActType + ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]: + """Transforms the action to the specified xp module array type. + + Args: + actions: The action to perform + + Returns: + A tuple containing xp versions of the next observation, reward, termination, truncation, and extra info. 
+ """ + actions = array_conversion(actions, xp=self._env_xp, device=self._env_device) + obs, reward, terminated, truncated, info = self.env.step(actions) + + return ( + array_conversion(obs, xp=self._target_xp, device=self._target_device), + array_conversion(reward, xp=self._target_xp, device=self._target_device), + array_conversion( + terminated, xp=self._target_xp, device=self._target_device + ), + array_conversion(truncated, xp=self._target_xp, device=self._target_device), + array_conversion(info, xp=self._target_xp, device=self._target_device), + ) + + def reset( + self, + *, + seed: int | list[int] | None = None, + options: dict[str, Any] | None = None, + ) -> tuple[ObsType, dict[str, Any]]: + """Resets the environment returning xp-based observation and info. + + Args: + seed: The seed for resetting the environment + options: The options for resetting the environment, these are converted to xp arrays. + + Returns: + xp-based observations and info + """ + if options: + options = array_conversion( + options, xp=self._env_xp, device=self._env_device + ) + + return array_conversion( + self.env.reset(seed=seed, options=options), + xp=self._target_xp, + device=self._target_device, + ) + + def __getstate__(self): + """Returns the object pickle state with args and kwargs.""" + env_xp_name = self._env_xp.__name__.replace("array_api_compat.", "") + target_xp_name = self._target_xp.__name__.replace("array_api_compat.", "") + env_device = self._env_device + target_device = self._target_device + return { + "env_xp_name": env_xp_name, + "target_xp_name": target_xp_name, + "env_device": env_device, + "target_device": target_device, + "env": self.env, + } + + def __setstate__(self, d): + """Sets the object pickle state using d.""" + self.env = d["env"] + self._env_xp = module_name_to_namespace(d["env_xp_name"]) + self._target_xp = module_name_to_namespace(d["target_xp_name"]) + self._env_device = d["env_device"] + self._target_device = d["target_device"] diff --git 
a/gymnasium/wrappers/vector/jax_to_numpy.py b/gymnasium/wrappers/vector/jax_to_numpy.py index 94432a789..4f1bc92d7 100644 --- a/gymnasium/wrappers/vector/jax_to_numpy.py +++ b/gymnasium/wrappers/vector/jax_to_numpy.py @@ -2,21 +2,18 @@ from __future__ import annotations -from typing import Any - import jax.numpy as jnp +import numpy as np -from gymnasium.core import ActType, ObsType from gymnasium.error import DependencyNotInstalled -from gymnasium.vector import VectorEnv, VectorWrapper -from gymnasium.vector.vector_env import ArrayType -from gymnasium.wrappers.jax_to_numpy import jax_to_numpy, numpy_to_jax +from gymnasium.vector import VectorEnv +from gymnasium.wrappers.vector.array_conversion import ArrayConversion __all__ = ["JaxToNumpy"] -class JaxToNumpy(VectorWrapper): +class JaxToNumpy(ArrayConversion): """Wraps a jax vector environment so that it can be interacted with through numpy arrays. Notes: @@ -40,46 +37,4 @@ class JaxToNumpy(VectorWrapper): raise DependencyNotInstalled( 'Jax is not installed, run `pip install "gymnasium[jax]"`' ) - super().__init__(env) - - def step( - self, actions: ActType - ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]: - """Transforms the action to a jax array . - - Args: - actions: the action to perform as a numpy array - - Returns: - A tuple containing numpy versions of the next observation, reward, termination, truncation, and extra info. - """ - jax_actions = numpy_to_jax(actions) - obs, reward, terminated, truncated, info = self.env.step(jax_actions) - - return ( - jax_to_numpy(obs), - jax_to_numpy(reward), - jax_to_numpy(terminated), - jax_to_numpy(truncated), - jax_to_numpy(info), - ) - - def reset( - self, - *, - seed: int | list[int] | None = None, - options: dict[str, Any] | None = None, - ) -> tuple[ObsType, dict[str, Any]]: - """Resets the environment returning numpy-based observation and info. 
- - Args: - seed: The seed for resetting the environment - options: The options for resetting the environment, these are converted to jax arrays. - - Returns: - Numpy-based observations and info - """ - if options: - options = numpy_to_jax(options) - - return jax_to_numpy(self.env.reset(seed=seed, options=options)) + super().__init__(env, env_xp=jnp, target_xp=np) diff --git a/gymnasium/wrappers/vector/jax_to_torch.py b/gymnasium/wrappers/vector/jax_to_torch.py index 32a468124..8cdac3ac0 100644 --- a/gymnasium/wrappers/vector/jax_to_torch.py +++ b/gymnasium/wrappers/vector/jax_to_torch.py @@ -2,18 +2,18 @@ from __future__ import annotations -from typing import Any +import jax.numpy as jnp +import torch -from gymnasium.core import ActType, ObsType -from gymnasium.vector import VectorEnv, VectorWrapper -from gymnasium.vector.vector_env import ArrayType -from gymnasium.wrappers.jax_to_torch import Device, jax_to_torch, torch_to_jax +from gymnasium.vector import VectorEnv +from gymnasium.wrappers.jax_to_torch import Device +from gymnasium.wrappers.vector.array_conversion import ArrayConversion __all__ = ["JaxToTorch"] -class JaxToTorch(VectorWrapper): +class JaxToTorch(ArrayConversion): """Wraps a Jax-based vector environment so that it can be interacted with through PyTorch Tensors. Actions must be provided as PyTorch Tensors and observations, rewards, terminations and truncations will be returned as PyTorch Tensors. @@ -31,48 +31,6 @@ class JaxToTorch(VectorWrapper): env: The Jax-based vector environment to wrap device: The device the torch Tensors should be moved to """ - super().__init__(env) + super().__init__(env, env_xp=jnp, target_xp=torch, target_device=device) self.device: Device | None = device - - def step( - self, actions: ActType - ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]: - """Performs the given action within the environment. 
- - Args: - actions: The action to perform as a PyTorch Tensor - - Returns: - Torch-based Tensors of the next observation, reward, termination, truncation, and extra info - """ - jax_action = torch_to_jax(actions) - obs, reward, terminated, truncated, info = self.env.step(jax_action) - - return ( - jax_to_torch(obs, self.device), - jax_to_torch(reward, self.device), - jax_to_torch(terminated, self.device), - jax_to_torch(truncated, self.device), - jax_to_torch(info, self.device), - ) - - def reset( - self, - *, - seed: int | list[int] | None = None, - options: dict[str, Any] | None = None, - ) -> tuple[ObsType, dict[str, Any]]: - """Resets the environment returning PyTorch-based observation and info. - - Args: - seed: The seed for resetting the environment - options: The options for resetting the environment, these are converted to jax arrays. - - Returns: - PyTorch-based observations and info - """ - if options: - options = torch_to_jax(options) - - return jax_to_torch(self.env.reset(seed=seed, options=options), self.device) diff --git a/gymnasium/wrappers/vector/numpy_to_torch.py b/gymnasium/wrappers/vector/numpy_to_torch.py index 8ae7e2728..ba230a304 100644 --- a/gymnasium/wrappers/vector/numpy_to_torch.py +++ b/gymnasium/wrappers/vector/numpy_to_torch.py @@ -2,18 +2,18 @@ from __future__ import annotations -from typing import Any +import numpy as np +import torch -from gymnasium.core import ActType, ObsType -from gymnasium.vector import VectorEnv, VectorWrapper -from gymnasium.vector.vector_env import ArrayType -from gymnasium.wrappers.numpy_to_torch import Device, numpy_to_torch, torch_to_numpy +from gymnasium.vector import VectorEnv +from gymnasium.wrappers.numpy_to_torch import Device +from gymnasium.wrappers.vector.array_conversion import ArrayConversion __all__ = ["NumpyToTorch"] -class NumpyToTorch(VectorWrapper): +class NumpyToTorch(ArrayConversion): """Wraps a numpy-based environment so that it can be interacted with through PyTorch Tensors. 
Example: @@ -45,48 +45,6 @@ class NumpyToTorch(VectorWrapper): env: The NumPy-based vector environment to wrap device: The device the torch Tensors should be moved to """ - super().__init__(env) + super().__init__(env, env_xp=np, target_xp=torch, target_device=device) self.device: Device | None = device - - def step( - self, actions: ActType - ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]: - """Using a PyTorch based action that is converted to NumPy to be used by the environment. - - Args: - action: A PyTorch-based action - - Returns: - The PyTorch-based Tensor next observation, reward, termination, truncation, and extra info - """ - numpy_action = torch_to_numpy(actions) - obs, reward, terminated, truncated, info = self.env.step(numpy_action) - - return ( - numpy_to_torch(obs, self.device), - numpy_to_torch(reward, self.device), - numpy_to_torch(terminated, self.device), - numpy_to_torch(truncated, self.device), - numpy_to_torch(info, self.device), - ) - - def reset( - self, - *, - seed: int | list[int] | None = None, - options: dict[str, Any] | None = None, - ) -> tuple[ObsType, dict[str, Any]]: - """Resets the environment returning PyTorch-based observation and info. - - Args: - seed: The seed for resetting the environment - options: The options for resetting the environment, these are converted to NumPy arrays. - - Returns: - PyTorch-based observations and info - """ - if options: - options = torch_to_numpy(options) - - return numpy_to_torch(self.env.reset(seed=seed, options=options), self.device) diff --git a/pyproject.toml b/pyproject.toml index 6fa14fef7..1ed572b6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ build-backend = "setuptools.build_meta" name = "gymnasium" description = "A standard API for reinforcement learning and a diverse set of reference environments (formerly Gym)." 
readme = "README.md" -requires-python = ">= 3.8" +requires-python = ">= 3.10" authors = [{ name = "Farama Foundation", email = "contact@farama.org" }] license = { text = "MIT License" } keywords = ["Reinforcement Learning", "game", "RL", "AI", "gymnasium"] @@ -16,8 +16,6 @@ classifiers = [ "Development Status :: 5 - Production/Stable", "License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", @@ -27,7 +25,6 @@ classifiers = [ dependencies = [ "numpy >=1.21.0", "cloudpickle >=1.2.0", - "importlib-metadata >=4.8.0; python_version < '3.10'", "typing-extensions >=4.3.0", "farama-notifications >=0.0.1", ] @@ -38,15 +35,30 @@ dynamic = ["version"] atari = ["ale_py >=0.9"] box2d = ["box2d-py ==2.3.5", "pygame >=2.1.3", "swig ==4.*"] classic-control = ["pygame >=2.1.3"] -classic_control = ["pygame >=2.1.3"] # kept for backward compatibility +classic_control = ["pygame >=2.1.3"] # kept for backward compatibility mujoco-py = ["mujoco-py >=2.1,<2.2", "cython<3"] -mujoco_py = ["mujoco-py >=2.1,<2.2", "cython<3"] # kept for backward compatibility +mujoco_py = [ + "mujoco-py >=2.1,<2.2", + "cython<3", +] # kept for backward compatibility mujoco = ["mujoco >=2.1.5", "imageio >=2.14.1", "packaging >=23.0"] toy-text = ["pygame >=2.1.3"] -toy_text = ["pygame >=2.1.3"] # kept for backward compatibility -jax = ["jax >=0.4.16", "jaxlib >=0.4.16", "flax >=0.5.0"] -torch = ["torch >=1.13.0"] -other = ["moviepy >=1.0.0", "matplotlib >=3.0", "opencv-python >=3.0", "seaborn >= 0.13"] +toy_text = ["pygame >=2.1.3"] # kept for backward compatibility +jax = [ + "jax >=0.4.16", + "jaxlib >=0.4.16", + "flax >=0.5.0", + "array-api-compat >=1.11.0", + "numpy>=2.1", +] +torch = ["torch >=1.13.0", "array-api-compat >=1.11.0", "numpy>=2.1"] +array-api = 
["array-api-compat >=1.11.0", "numpy>=2.1"] +other = [ + "moviepy >=1.0.0", + "matplotlib >=3.0", + "opencv-python >=3.0", + "seaborn >= 0.13", +] all = [ # All dependencies above except accept-rom-license # NOTE: No need to manually remove the duplicates, setuptools automatically does that. @@ -71,17 +83,26 @@ all = [ "jax >=0.4.16", "jaxlib >=0.4.16", "flax >= 0.5.0", + "array-api-compat >=1.11.0", + "numpy>=2.1", # torch "torch >=1.13.0", + "array-api-compat >=1.11.0", + "numpy>=2.1", + # array-api + "array-api-compat >=1.11.0", + "numpy>=2.1", # other "opencv-python >=3.0", "matplotlib >=3.0", "moviepy >=1.0.0", ] + testing = [ "pytest >=7.1.3", "scipy >=1.7.3", "dill >=0.3.7", + "array_api_extra >=0.7.0", ] [project.urls] @@ -125,7 +146,7 @@ exclude = ["tests/**", "**/node_modules", "**/__pycache__"] strict = [] typeCheckingMode = "basic" -pythonVersion = "3.8" +pythonVersion = "3.10" pythonPlatform = "All" typeshedPath = "typeshed" enableTypeIgnoreComments = true @@ -138,19 +159,19 @@ reportMissingTypeStubs = false # For warning and error, will raise an error when reportInvalidTypeVarUse = "none" -reportGeneralTypeIssues = "none" # -> commented out raises 489 errors -reportAttributeAccessIssue = "none" # pyright provides false positives -reportArgumentType = "none" # pyright provides false positives +reportGeneralTypeIssues = "none" # -> commented out raises 489 errors +reportAttributeAccessIssue = "none" # pyright provides false positives +reportArgumentType = "none" # pyright provides false positives reportPrivateUsage = "warning" -reportIndexIssue = "none" # TODO fix one by one -reportReturnType = "none" # TODO fix one by one -reportCallIssue = "none" # TODO fix one by one -reportOperatorIssue = "none" # TODO fix one by one -reportInvalidTypeForm = "none" # TODO fix one by one -reportOptionalMemberAccess = "none" # TODO fix one by one -reportAssignmentType = "none" # TODO fix one by one +reportIndexIssue = "none" # TODO fix one by one +reportReturnType = 
"none" # TODO fix one by one +reportCallIssue = "none" # TODO fix one by one +reportOperatorIssue = "none" # TODO fix one by one +reportInvalidTypeForm = "none" # TODO fix one by one +reportOptionalMemberAccess = "none" # TODO fix one by one +reportAssignmentType = "none" # TODO fix one by one [tool.pytest.ini_options] diff --git a/tests/testing_env.py b/tests/testing_env.py index 41572a62d..82d23f06c 100644 --- a/tests/testing_env.py +++ b/tests/testing_env.py @@ -6,10 +6,15 @@ import types from collections.abc import Callable from typing import Any +import numpy as np + import gymnasium as gym from gymnasium import spaces from gymnasium.core import ActType, ObsType from gymnasium.envs.registration import EnvSpec +from gymnasium.vector import VectorEnv +from gymnasium.vector.utils import batch_space +from gymnasium.vector.vector_env import AutoresetMode def basic_reset_func( @@ -106,3 +111,112 @@ class GenericTestEnv(gym.Env): def render(self): """Renders the environment.""" raise NotImplementedError("testingEnv render_fn is not set.") + + +def basic_vector_reset_func( + self, + *, + seed: int | None = None, + options: dict | None = None, +) -> tuple[ObsType, dict]: + """A basic reset function that will pass the environment check using random actions from the observation space.""" + super(GenericTestVectorEnv, self).reset(seed=seed) + self.observation_space.seed(self.np_random_seed) + return self.observation_space.sample(), {"options": options} + + +def basic_vector_step_func( + self, action: ActType +) -> tuple[ObsType, np.ndarray, np.ndarray, np.ndarray, dict]: + """A step function that follows the basic step api that will pass the environment check using random actions from the observation space.""" + obs = self.observation_space.sample() + rewards = np.zeros(self.num_envs, dtype=np.float64) + terminations = np.zeros(self.num_envs, dtype=np.bool_) + truncations = np.zeros(self.num_envs, dtype=np.bool_) + return obs, rewards, terminations, truncations, {} + + 
+def basic_vector_render_func(self): + """Basic render fn that does nothing.""" + pass + + +class GenericTestVectorEnv(VectorEnv): + """A generic testing vector environment similar to GenericTestEnv. + + Some tests cannot use SyncVectorEnv, e.g. when returning non-numpy arrays in the observations. + In these cases, GenericTestVectorEnv can be used to simulate a vector environment. + """ + + def __init__( + self, + num_envs: int = 1, + action_space: spaces.Space = spaces.Box(0, 1, (1,)), + observation_space: spaces.Space = spaces.Box(0, 1, (1,)), + reset_func: Callable = basic_vector_reset_func, + step_func: Callable = basic_vector_step_func, + render_func: Callable = basic_vector_render_func, + metadata: dict[str, Any] = { + "render_modes": [], + "autoreset_mode": AutoresetMode.NEXT_STEP, + }, + render_mode: str | None = None, + spec: EnvSpec = EnvSpec( + "TestingVectorEnv-v0", + "tests.testing_env:GenericTestVectorEnv", + max_episode_steps=100, + ), + ): + """Generic testing vector environment constructor. 
+ + Args: + num_envs: The number of environments to create + action_space: The environment action space + observation_space: The environment observation space + reset_func: The environment reset function + step_func: The environment step function + render_func: The environment render function + metadata: The environment metadata + render_mode: The render mode of the environment + spec: The environment spec + """ + super().__init__() + + self.num_envs = num_envs + self.metadata = metadata + self.render_mode = render_mode + self.spec = spec + + # Set the single spaces and create batched spaces + self.single_observation_space = observation_space + self.single_action_space = action_space + self.observation_space = batch_space(observation_space, num_envs) + self.action_space = batch_space(action_space, num_envs) + + # Bind the functions to the instance + if reset_func is not None: + self.reset = types.MethodType(reset_func, self) + if step_func is not None: + self.step = types.MethodType(step_func, self) + if render_func is not None: + self.render = types.MethodType(render_func, self) + + def reset( + self, + *, + seed: int | None = None, + options: dict | None = None, + ) -> tuple[ObsType, dict]: + """Resets the environment.""" + # If you need a default working reset function, use `basic_vector_reset_fn` above + raise NotImplementedError("TestingVectorEnv reset_fn is not set.") + + def step( + self, action: ActType + ) -> tuple[ObsType, np.ndarray, np.ndarray, np.ndarray, dict]: + """Steps through the environment.""" + raise NotImplementedError("TestingVectorEnv step_fn is not set.") + + def render(self) -> tuple[Any, ...] 
| None: + """Renders the environment.""" + raise NotImplementedError("TestingVectorEnv render_fn is not set.") diff --git a/tests/wrappers/test_array_conversion.py b/tests/wrappers/test_array_conversion.py new file mode 100644 index 000000000..a2091de49 --- /dev/null +++ b/tests/wrappers/test_array_conversion.py @@ -0,0 +1,250 @@ +"""Test suite for ArrayConversion wrapper.""" + +import importlib +import itertools +import pickle +from typing import Any, NamedTuple + +import pytest + +import gymnasium + + +array_api_compat = pytest.importorskip("array_api_compat") +array_api_extra = pytest.importorskip("array_api_extra") + +from array_api_compat import array_namespace, is_array_api_obj # noqa: E402 + +from gymnasium.wrappers.array_conversion import ( # noqa: E402 + ArrayConversion, + array_conversion, + module_namespace, +) +from tests.testing_env import GenericTestEnv # noqa: E402 + + +# Define available modules +installed_modules = [] +array_api_modules = [ + "numpy", + "jax.numpy", + "torch", + "cupy", + "dask.array", + "sparse", + "array_api_strict", +] +for module in array_api_modules: + try: + installed_modules.append(importlib.import_module(module)) + except ImportError: + pass # Modules that are not installed are skipped + +installed_modules_combinations = list(itertools.permutations(installed_modules, 2)) + + +def xp_data_equivalence(data_1, data_2) -> bool: + """Return if two variables are equivalent that might contain ``torch.Tensor``.""" + if type(data_1) is type(data_2): + if isinstance(data_1, dict): + return data_1.keys() == data_2.keys() and all( + xp_data_equivalence(data_1[k], data_2[k]) for k in data_1.keys() + ) + elif isinstance(data_1, (tuple, list)): + return len(data_1) == len(data_2) and all( + xp_data_equivalence(o_1, o_2) for o_1, o_2 in zip(data_1, data_2) + ) + elif is_array_api_obj(data_1): + return array_api_extra.isclose(data_1, data_2, atol=0.00001).all() + else: + return data_1 == data_2 + else: + return False + + +class 
ExampleNamedTuple(NamedTuple): + a: Any # Array API compatible object. Does not have proper typing support yet. + b: Any # Same as a + + +def _supports_higher_precision(xp, low_type, high_type): + """Check if an array module supports higher precision type.""" + return xp.result_type(low_type, high_type) == high_type + + +# When converting between array modules (source → target → source), we need to ensure that the +# precision used is supported by both modules. If either module only supports 32-bit types, we must +# use the lower precision to account for the conversion during the roundtrip. +def atleast_float32(source_xp, target_xp): + """Return source_xp.float64 if both modules support it, otherwise source_xp.float32.""" + source_supports_64 = _supports_higher_precision( + source_xp, source_xp.float32, source_xp.float64 + ) + target_supports_64 = _supports_higher_precision( + target_xp, target_xp.float32, target_xp.float64 + ) + return ( + source_xp.float64 + if (source_supports_64 and target_supports_64) + else source_xp.float32 + ) + + +def atleast_int32(source_xp, target_xp): + """Return source_xp.int64 if both modules support it, otherwise source_xp.int32.""" + source_supports_64 = _supports_higher_precision( + source_xp, source_xp.int32, source_xp.int64 + ) + target_supports_64 = _supports_higher_precision( + target_xp, target_xp.int32, target_xp.int64 + ) + return ( + source_xp.int64 + if (source_supports_64 and target_supports_64) + else source_xp.int32 + ) + + +def value_parametrization(): + for source_xp, target_xp in installed_modules_combinations: + xp = module_namespace(source_xp) + source_xp = module_namespace(source_xp) + target_xp = module_namespace(target_xp) + for value, expected_value in [ + (2, xp.asarray(2, dtype=atleast_int32(source_xp, target_xp))), + ( + (3.0, 4), + ( + xp.asarray(3.0, dtype=atleast_float32(source_xp, target_xp)), + xp.asarray(4, dtype=atleast_int32(source_xp, target_xp)), + ), + ), + ( + [3.0, 4], + [ + xp.asarray(3.0, 
dtype=atleast_float32(source_xp, target_xp)), + xp.asarray(4, dtype=atleast_int32(source_xp, target_xp)), + ], + ), + ( + { + "a": 6.0, + "b": 7, + }, + { + "a": xp.asarray(6.0, dtype=atleast_float32(source_xp, target_xp)), + "b": xp.asarray(7, dtype=atleast_int32(source_xp, target_xp)), + }, + ), + (xp.asarray(1.0, dtype=xp.float32), xp.asarray(1.0, dtype=xp.float32)), + (xp.asarray(1.0, dtype=xp.uint8), xp.asarray(1.0, dtype=xp.uint8)), + (xp.asarray([1, 2], dtype=xp.int32), xp.asarray([1, 2], dtype=xp.int32)), + ( + xp.asarray([[1.0], [2.0]], dtype=xp.int32), + xp.asarray([[1.0], [2.0]], dtype=xp.int32), + ), + ( + { + "a": ( + 1, + xp.asarray(2.0, dtype=xp.float32), + xp.asarray([3, 4], dtype=xp.int32), + ), + "b": {"c": 5}, + }, + { + "a": ( + xp.asarray(1, dtype=atleast_int32(source_xp, target_xp)), + xp.asarray(2.0, dtype=xp.float32), + xp.asarray([3, 4], dtype=xp.int32), + ), + "b": { + "c": xp.asarray(5, dtype=atleast_int32(source_xp, target_xp)) + }, + }, + ), + ( + ExampleNamedTuple( + a=xp.asarray([1, 2], dtype=xp.int32), + b=xp.asarray([1.0, 2.0], dtype=xp.float32), + ), + ExampleNamedTuple( + a=xp.asarray([1, 2], dtype=xp.int32), + b=xp.asarray([1.0, 2.0], dtype=xp.float32), + ), + ), + (None, None), + ]: + yield (source_xp, target_xp, value, expected_value) + + +@pytest.mark.parametrize( + "source_xp,target_xp,value,expected_value", value_parametrization() +) +def test_roundtripping(source_xp, target_xp, value, expected_value): + """Test roundtripping between different Array API compatible frameworks.""" + roundtripped_value = array_conversion( + array_conversion(value, xp=target_xp), xp=source_xp + ) + assert xp_data_equivalence(roundtripped_value, expected_value) + + +@pytest.mark.parametrize("env_xp, target_xp", installed_modules_combinations) +def test_array_conversion_wrapper(env_xp, target_xp): + # Define reset and step functions without partial to avoid pickling issues + + def reset_func(self, seed=None, options=None): + """A generic array API 
reset function.""" + return env_xp.asarray([1.0, 2.0, 3.0]), {"data": env_xp.asarray([1, 2, 3])} + + def step_func(self, action): + """A generic array API step function.""" + assert isinstance(action, type(env_xp.zeros(1))) + return ( + env_xp.asarray([1, 2, 3]), + env_xp.asarray(5.0), + env_xp.asarray(True), + env_xp.asarray(False), + {"data": env_xp.asarray([1.0, 2.0])}, + ) + + env = GenericTestEnv(reset_func=reset_func, step_func=step_func) + + # Check that the reset and step for env_xp environment are as expected + obs, info = env.reset() + # env_xp is automatically converted to the compatible namespace by array_namespace, so we need + # to check against the compatible namespace of env_xp in array_api_compat + env_xp_compat = module_namespace(env_xp) + assert array_namespace(obs) is env_xp_compat + assert isinstance(info, dict) and array_namespace(info["data"]) is env_xp_compat + + obs, reward, terminated, truncated, info = env.step(env_xp_compat.asarray([1, 2])) + assert array_namespace(obs) is env_xp_compat + assert array_namespace(reward) is env_xp_compat + assert array_namespace(terminated) is env_xp_compat + assert array_namespace(truncated) is env_xp_compat + assert isinstance(info, dict) and array_namespace(info["data"]) is env_xp_compat + + # Check that the wrapped version is correct. 
+ target_xp_compat = module_namespace(target_xp) + wrapped_env = ArrayConversion(env, env_xp=env_xp, target_xp=target_xp) + obs, info = wrapped_env.reset() + assert array_namespace(obs) is target_xp_compat + assert isinstance(info, dict) and array_namespace(info["data"]) is target_xp_compat + + action = target_xp.asarray([1, 2], dtype=target_xp.int32) + obs, reward, terminated, truncated, info = wrapped_env.step(action) + assert array_namespace(obs) is target_xp_compat + assert isinstance(reward, float) + assert isinstance(terminated, bool) and isinstance(truncated, bool) + assert isinstance(info, dict) and array_namespace(info["data"]) is target_xp_compat + + # Check that the wrapped environment can render. This implicitly returns None and requires a + # None -> None conversion + wrapped_env.render() + + # Test that the wrapped environment can be pickled + env = gymnasium.make("CartPole-v1", disable_env_checker=True) + wrapped_env = ArrayConversion(env, env_xp=env_xp, target_xp=target_xp) + pkl = pickle.dumps(wrapped_env) + pickle.loads(pkl) diff --git a/tests/wrappers/test_jax_to_numpy.py b/tests/wrappers/test_jax_to_numpy.py index cbfd036e3..235ac25ca 100644 --- a/tests/wrappers/test_jax_to_numpy.py +++ b/tests/wrappers/test_jax_to_numpy.py @@ -1,10 +1,13 @@ """Test suite for JaxToNumpy wrapper.""" +import pickle from typing import NamedTuple import numpy as np import pytest +import gymnasium + jax = pytest.importorskip("jax") jnp = pytest.importorskip("jax.numpy") @@ -133,3 +136,9 @@ def test_jax_to_numpy_wrapper(): # Check that the wrapped environment can render. 
This implicitly returns None and requires a # None -> None conversion numpy_env.render() + + # Test that the wrapped environment can be pickled + env = gymnasium.make("CartPole-v1", disable_env_checker=True) + wrapped_env = JaxToNumpy(env) + pkl = pickle.dumps(wrapped_env) + pickle.loads(pkl) diff --git a/tests/wrappers/test_jax_to_torch.py b/tests/wrappers/test_jax_to_torch.py index 11b9ae113..e68d3070b 100644 --- a/tests/wrappers/test_jax_to_torch.py +++ b/tests/wrappers/test_jax_to_torch.py @@ -1,9 +1,12 @@ """Test suite for TorchToJax wrapper.""" +import pickle from typing import NamedTuple import pytest +import gymnasium + jax = pytest.importorskip("jax") jnp = pytest.importorskip("jax.numpy") @@ -148,3 +151,9 @@ def test_jax_to_torch_wrapper(): # Check that the wrapped environment can render. This implicitly returns None and requires a # None -> None conversion wrapped_env.render() + + # Test that the wrapped environment can be pickled + env = gymnasium.make("CartPole-v1", disable_env_checker=True) + wrapped_env = JaxToTorch(env) + pkl = pickle.dumps(wrapped_env) + pickle.loads(pkl) diff --git a/tests/wrappers/test_numpy_to_torch.py b/tests/wrappers/test_numpy_to_torch.py index 323fc83e2..29173c8b6 100644 --- a/tests/wrappers/test_numpy_to_torch.py +++ b/tests/wrappers/test_numpy_to_torch.py @@ -1,10 +1,13 @@ """Test suite for NumPyToTorch wrapper.""" +import pickle from typing import NamedTuple import numpy as np import pytest +import gymnasium + torch = pytest.importorskip("torch") @@ -128,3 +131,9 @@ def test_numpy_to_torch(): # Check that the wrapped environment can render. 
This implicitly returns None and requires a
     # None -> None conversion
     torch_env.render()
+
+    # Test that the wrapped environment can be pickled
+    env = gymnasium.make("CartPole-v1", disable_env_checker=True)
+    wrapped_env = NumpyToTorch(env)
+    pkl = pickle.dumps(wrapped_env)
+    pickle.loads(pkl)
diff --git a/tests/wrappers/vector/test_array_conversion.py b/tests/wrappers/vector/test_array_conversion.py
new file mode 100644
index 000000000..6288946cd
--- /dev/null
+++ b/tests/wrappers/vector/test_array_conversion.py
@@ -0,0 +1,158 @@
+"""Test suite for vector ArrayConversion wrapper."""
+
+import importlib
+import itertools
+from functools import partial
+
+import numpy as np
+import pytest
+
+from tests.testing_env import GenericTestVectorEnv
+
+
+array_api_compat = pytest.importorskip("array_api_compat")
+from array_api_compat import array_namespace  # noqa: E402
+
+from gymnasium.wrappers.array_conversion import module_namespace  # noqa: E402
+from gymnasium.wrappers.vector.array_conversion import ArrayConversion  # noqa: E402
+from gymnasium.wrappers.vector.jax_to_numpy import JaxToNumpy  # noqa: E402
+from gymnasium.wrappers.vector.jax_to_torch import JaxToTorch  # noqa: E402
+from gymnasium.wrappers.vector.numpy_to_torch import NumpyToTorch  # noqa: E402
+
+
+# Define available modules
+installed_modules = []
+array_api_modules = [
+    "numpy",
+    "jax.numpy",
+    "torch",
+    "cupy",
+    "dask.array",
+    "sparse",
+    "array_api_strict",
+]
+for module in array_api_modules:
+    try:
+        installed_modules.append(importlib.import_module(module))
+    except ImportError:
+        pass  # Modules that are not installed are skipped
+
+# Every ordered (env framework, target framework) pair of distinct installed modules
+installed_modules_combinations = list(itertools.permutations(installed_modules, 2))
+
+
+def create_vector_env(env_xp):
+    """Create a generic test vector env whose reset/step emit ``env_xp`` arrays for 3 envs."""
+    _reset_func = partial(reset_func, num_envs=3, xp=env_xp)
+    _step_func = partial(step_func, num_envs=3, xp=env_xp)
+    return GenericTestVectorEnv(reset_func=_reset_func, step_func=_step_func)
+
+
+def reset_func(self, seed=None, options=None, num_envs: int = 1, xp=np):
+    """Return a fixed ``(obs, info)`` pair built from ``xp`` arrays."""
+    # NOTE(review): ``[...] * num_envs`` sits inside the outer list, so the arrays have shape
+    # (1, 3 * num_envs) rather than (num_envs, 3) - confirm this is intended.
+    return xp.asarray([[1.0, 2.0, 3.0] * num_envs]), {
+        "data": xp.asarray([[1, 2, 3] * num_envs])
+    }
+
+
+def step_func(self, action, num_envs: int = 1, xp=np):
+    """Return a fixed step tuple built from ``xp`` arrays after checking the action's type."""
+    # The wrapper under test must hand the env an array of the env's own framework
+    assert isinstance(action, type(xp.zeros(1)))
+    return (
+        xp.asarray([[1, 2, 3] * num_envs]),
+        xp.asarray([5.0] * num_envs),
+        xp.asarray([False] * num_envs),
+        xp.asarray([False] * num_envs),
+        {"data": xp.asarray([[1.0, 2.0] * num_envs])},
+    )
+
+
+@pytest.mark.parametrize("env_xp, target_xp", installed_modules_combinations)
+def test_array_conversion_wrapper(env_xp, target_xp):
+    """Test that ArrayConversion converts reset/step outputs from ``env_xp`` to ``target_xp``."""
+    env_xp_compat = module_namespace(env_xp)
+    env = create_vector_env(env_xp_compat)
+
+    # Check that the reset and step for env_xp environment are as expected
+    obs, info = env.reset()
+    # env_xp is automatically converted to the compatible namespace by array_namespace, so we need
+    # to check against the compatible namespace of env_xp in array_api_compat
+    assert array_namespace(obs) is env_xp_compat
+    assert isinstance(info, dict) and array_namespace(info["data"]) is env_xp_compat
+
+    obs, reward, terminated, truncated, info = env.step(env_xp_compat.asarray([1, 2]))
+    assert array_namespace(obs) is env_xp_compat
+    assert array_namespace(reward) is env_xp_compat
+    assert array_namespace(terminated) is env_xp_compat
+    assert array_namespace(truncated) is env_xp_compat
+    assert isinstance(info, dict) and array_namespace(info["data"]) is env_xp_compat
+
+    # Check that the wrapped version is correct.
+    target_xp_compat = module_namespace(target_xp)
+    wrapped_env = ArrayConversion(env, env_xp=env_xp, target_xp=target_xp)
+    obs, info = wrapped_env.reset()
+    assert array_namespace(obs) is target_xp_compat
+    assert isinstance(info, dict) and array_namespace(info["data"]) is target_xp_compat
+
+    action = target_xp.asarray([1, 2], dtype=target_xp.int32)
+    obs, reward, terminated, truncated, info = wrapped_env.step(action)
+    assert array_namespace(obs) is target_xp_compat
+    assert array_namespace(reward) is target_xp_compat
+    assert array_namespace(terminated) is target_xp_compat
+    # Read ``bool`` from the compat namespace: the raw module may lack it (e.g. NumPy 1.24-1.26)
+    assert terminated.dtype == target_xp_compat.bool
+    assert array_namespace(truncated) is target_xp_compat
+    assert truncated.dtype == target_xp_compat.bool
+    assert isinstance(info, dict) and array_namespace(info["data"]) is target_xp_compat
+
+    # Check that the wrapped environment can render. This implicitly returns None and requires a
+    # None -> None conversion
+    wrapped_env.render()
+
+
+@pytest.mark.parametrize("wrapper", [JaxToNumpy, JaxToTorch, NumpyToTorch])
+def test_specialized_wrappers(wrapper: type[JaxToNumpy | JaxToTorch | NumpyToTorch]):
+    """Test that each specialized wrapper converts outputs to its target framework."""
+    if wrapper is JaxToNumpy:
+        jax = pytest.importorskip("jax")
+        env_xp, target_xp = jax.numpy, np
+    elif wrapper is JaxToTorch:
+        jax = pytest.importorskip("jax")
+        torch = pytest.importorskip("torch")
+        env_xp, target_xp = jax.numpy, torch
+    elif wrapper is NumpyToTorch:
+        torch = pytest.importorskip("torch")
+        env_xp, target_xp = np, torch
+    else:
+        # ``wrapper`` is already a class, so report it directly rather than ``type(wrapper)``
+        raise TypeError(f"Unknown specialized conversion wrapper {wrapper}")
+    env_xp_compat = module_namespace(env_xp)
+    target_xp_compat = module_namespace(target_xp)
+
+    # The unwrapped test env sanity check is already covered by test_array_conversion_wrapper for
+    # all known frameworks, including the specialized ones.
+    env = create_vector_env(env_xp_compat)
+
+    # Check that the wrapped version is correct.
+    wrapped_env = wrapper(env)
+    obs, info = wrapped_env.reset()
+    assert array_namespace(obs) is target_xp_compat
+    assert isinstance(info, dict) and array_namespace(info["data"]) is target_xp_compat
+
+    action = target_xp.asarray([1, 2], dtype=target_xp.int32)
+    obs, reward, terminated, truncated, info = wrapped_env.step(action)
+    assert array_namespace(obs) is target_xp_compat
+    assert array_namespace(reward) is target_xp_compat
+    assert array_namespace(terminated) is target_xp_compat
+    # Read ``bool`` from the compat namespace: the raw module may lack it (e.g. NumPy 1.24-1.26)
+    assert terminated.dtype == target_xp_compat.bool
+    assert array_namespace(truncated) is target_xp_compat
+    assert truncated.dtype == target_xp_compat.bool
+    assert isinstance(info, dict) and array_namespace(info["data"]) is target_xp_compat
+
+    # Check that the wrapped environment can render. This implicitly returns None and requires a
+    # None -> None conversion
+    wrapped_env.render()