mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-07-30 13:28:50 +00:00
Add generic conversion wrapper between Array API compatible frameworks (#1333)
This commit is contained in:
2
.github/workflows/docs-build-dev.yml
vendored
2
.github/workflows/docs-build-dev.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.9'
|
||||
python-version: '3.10'
|
||||
|
||||
- name: Install dependencies
|
||||
run: pip install -r docs/requirements.txt
|
||||
|
2
.github/workflows/docs-build-release.yml
vendored
2
.github/workflows/docs-build-release.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.9'
|
||||
python-version: '3.10'
|
||||
|
||||
- name: Install dependencies
|
||||
run: pip install -r docs/requirements.txt
|
||||
|
2
.github/workflows/docs-manual-build.yml
vendored
2
.github/workflows/docs-manual-build.yml
vendored
@@ -33,7 +33,7 @@ jobs:
|
||||
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.9'
|
||||
python-version: '3.10'
|
||||
|
||||
- name: Install dependencies
|
||||
run: pip install -r docs/requirements.txt
|
||||
|
5
.github/workflows/run-pytest.yml
vendored
5
.github/workflows/run-pytest.yml
vendored
@@ -10,8 +10,8 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: true
|
||||
matrix:
|
||||
python-version: ['3.9', '3.10', '3.11', '3.12']
|
||||
numpy-version: ['>=1.21,<2.0', '>=2.0']
|
||||
python-version: ['3.10', '3.11', '3.12']
|
||||
numpy-version: ['>=1.21,<2.0', '>=2.1']
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- run: |
|
||||
@@ -22,7 +22,6 @@ jobs:
|
||||
- name: Run tests
|
||||
run: docker run gymnasium-all-docker pytest tests/*
|
||||
- name: Run doctests
|
||||
if: ${{ matrix.python-version != '3.8' }}
|
||||
run: docker run gymnasium-all-docker pytest --doctest-modules gymnasium/
|
||||
|
||||
build-necessary:
|
||||
|
2
.github/workflows/run-tutorial.yml
vendored
2
.github/workflows/run-tutorial.yml
vendored
@@ -20,7 +20,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false # This ensures all matrix combinations run even if one fails
|
||||
matrix:
|
||||
python-version: ["3.9"]
|
||||
python-version: ["3.10"]
|
||||
tutorial-group:
|
||||
- gymnasium_basics
|
||||
- training_agents
|
||||
|
@@ -32,7 +32,7 @@ To install the base Gymnasium library, use `pip install gymnasium`
|
||||
|
||||
This does not include dependencies for all families of environments (there's a massive number, and some can be problematic to install on certain systems). You can install these dependencies for one family like `pip install "gymnasium[atari]"` or use `pip install "gymnasium[all]"` to install all dependencies.
|
||||
|
||||
We support and test for Python 3.8, 3.9, 3.10, 3.11 and 3.12 on Linux and macOS. We will accept PRs related to Windows, but do not officially support it.
|
||||
We support and test for Python 3.10, 3.11 and 3.12 on Linux and macOS. We will accept PRs related to Windows, but do not officially support it.
|
||||
|
||||
## API
|
||||
|
||||
|
@@ -27,6 +27,7 @@ title: Misc Wrappers
|
||||
## Data Conversion Wrappers
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.wrappers.ArrayConversion
|
||||
.. autoclass:: gymnasium.wrappers.JaxToNumpy
|
||||
.. autoclass:: gymnasium.wrappers.JaxToTorch
|
||||
.. autoclass:: gymnasium.wrappers.NumpyToTorch
|
||||
|
@@ -34,6 +34,8 @@ wrapper in the page on the wrapper type
|
||||
- Converts an image observation computed by ``reset`` and ``step`` from RGB to Grayscale.
|
||||
* - :class:`HumanRendering`
|
||||
- Allows human-like rendering for environments that support "rgb_array" rendering.
|
||||
* - :class:`ArrayConversion`
|
||||
- Wraps an environment based on any Array API compatible framework (e.g. torch, jax, numpy) such that it can be interacted with using any other Array API compatible framework.
|
||||
* - :class:`JaxToNumpy`
|
||||
- Wraps a Jax-based environment such that it can be interacted with NumPy arrays.
|
||||
* - :class:`JaxToTorch`
|
||||
|
261
gymnasium/wrappers/array_conversion.py
Normal file
261
gymnasium/wrappers/array_conversion.py
Normal file
@@ -0,0 +1,261 @@
|
||||
# This wrapper will convert array inputs from an Array API compatible framework A for the actions
|
||||
# to any other Array API compatible framework B for an underlying environment that is implemented
|
||||
# in framework B, then convert the return observations from framework B back to framework A.
|
||||
#
|
||||
# More precisely, the wrapper will work for any two frameworks that can be made compatible with the
|
||||
# `array-api-compat` package.
|
||||
#
|
||||
# See https://data-apis.org/array-api/latest/ for more information on the Array API standard, and
|
||||
# https://data-apis.org/array-api-compat/ for more information on the Array API compatibility layer.
|
||||
#
|
||||
# General structure for converting between types originally copied from
|
||||
# https://github.com/google/brax/blob/9d6b7ced2a13da0d074b5e9fbd3aad8311e26997/brax/io/torch.py
|
||||
# Under the Apache 2.0 license. Copyright is held by the authors
|
||||
|
||||
"""Helper functions and wrapper class for converting between arbitrary Array API compatible frameworks and a target framework."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import importlib
|
||||
import numbers
|
||||
from collections import abc
|
||||
from types import ModuleType, NoneType
|
||||
from typing import Any, Iterable, Mapping, SupportsFloat
|
||||
|
||||
import numpy as np
|
||||
from packaging.version import Version
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
|
||||
from gymnasium.error import DependencyNotInstalled
|
||||
|
||||
|
||||
try:
|
||||
from array_api_compat import array_namespace, is_array_api_obj, to_device
|
||||
|
||||
except ImportError:
|
||||
raise DependencyNotInstalled(
|
||||
'Array API packages are not installed therefore cannot call `array_conversion`, run `pip install "gymnasium[array-api]"`'
|
||||
)
|
||||
|
||||
|
||||
if Version(np.__version__) < Version("2.1.0"):
|
||||
raise DependencyNotInstalled("Array API functionality requires numpy >= 2.1.0")
|
||||
|
||||
|
||||
__all__ = ["ArrayConversion", "array_conversion"]
|
||||
|
||||
Array = Any # TODO: Switch to ArrayAPI type once https://github.com/data-apis/array-api/pull/589 is merged
|
||||
Device = Any # TODO: Switch to ArrayAPI type if available
|
||||
|
||||
|
||||
def module_namespace(xp: ModuleType) -> ModuleType:
    """Return the Array API compatible namespace that belongs to a module.

    This mirrors ``array_api_compat.array_namespace``, except that the lookup starts
    from the module itself instead of from an array: an empty array is created with
    ``xp.empty(0)`` and the namespace reported for that array is returned.

    See https://data-apis.org/array-api-compat/helper-functions.html#array_api_compat.array_namespace

    Args:
        xp: The array framework module to resolve

    Returns:
        The namespace ``array_namespace`` reports for an array created by ``xp``

    Raises:
        ValueError: If an ``AttributeError`` occurs during the lookup, e.g. because
            ``xp`` does not provide ``empty``
    """
    try:
        # Both steps stay inside the try so that an AttributeError from either of them
        # is reported as "not an Array API compatible module".
        probe = xp.empty(0)
        namespace = array_namespace(probe)
    except AttributeError as err:
        raise ValueError(f"Module {xp} is not an Array API compatible module.") from err
    return namespace
|
||||
|
||||
|
||||
def module_name_to_namespace(name: str) -> ModuleType:
    """Import the module called ``name`` and return its Array API compatible namespace.

    Args:
        name: Importable module name, handed to ``importlib.import_module``

    Returns:
        The result of ``module_namespace`` for the imported module
    """
    module = importlib.import_module(name)
    return module_namespace(module)
|
||||
|
||||
|
||||
@functools.singledispatch
def array_conversion(value: Any, xp: ModuleType, device: Device | None = None) -> Any:
    """Convert a value into the specified xp module array type.

    This is the fallback of the single-dispatch function: it only runs when no
    conversion has been registered for the type of ``value``.

    Args:
        value: The value to convert
        xp: The Array API namespace to convert to
        device: Optional device for the converted arrays

    Raises:
        Exception: Always, since reaching the fallback means the type is unsupported
    """
    message = f"No known conversion for ({type(value)}) to xp module ({xp}) registered. Report as issue on github."
    raise Exception(message)
|
||||
|
||||
|
||||
@array_conversion.register(numbers.Number)
def _number_array_conversion(
    value: numbers.Number, xp: ModuleType, device: Device | None = None
) -> Array:
    """Turn a python number (int, float, complex) into an array of the xp framework.

    Args:
        value: The number to convert
        xp: The Array API namespace used to create the array
        device: Device forwarded to ``xp.asarray``

    Returns:
        The result of ``xp.asarray(value, device=device)``
    """
    array = xp.asarray(value, device=device)
    return array
|
||||
|
||||
|
||||
@array_conversion.register(abc.Mapping)
def _mapping_array_conversion(
    value: Mapping[str, Any], xp: ModuleType, device: Device | None = None
) -> Mapping[str, Any]:
    """Convert a mapping of Arrays into a mapping of the specified xp module array type.

    Every value is converted recursively with ``array_conversion``; the keys are kept
    as they are and the result is rebuilt with the type of the input mapping.

    Args:
        value: The mapping whose values should be converted
        xp: The Array API namespace to convert to
        device: Optional device for the converted arrays

    Returns:
        A new ``type(value)`` instance holding the converted values
    """
    converted = {
        key: array_conversion(item, xp, device) for key, item in value.items()
    }
    if all(isinstance(key, str) for key in converted):
        # Keyword construction is kept for string keys so that mapping types whose
        # constructor only accepts keyword arguments behave exactly as before.
        return type(value)(**converted)
    # ``**`` unpacking requires string keys and raises a TypeError otherwise (e.g. for
    # a dict keyed by integers), so such mappings are rebuilt from the dict instead.
    return type(value)(converted)
|
||||
|
||||
|
||||
@array_conversion.register(abc.Iterable)
def _iterable_array_conversion(
    value: Iterable[Any], xp: ModuleType, device: Device | None = None
) -> Iterable[Any]:
    """Convert an Iterable from Arrays to an iterable of the specified xp module array type.

    Args:
        value: An Array API array, a string / bytes object, a namedtuple or any other
            iterable container
        xp: The Array API namespace to convert to
        device: Optional device for the converted arrays

    Returns:
        The converted array for Array API objects, the unchanged value for ``str`` and
        ``bytes``, otherwise a new container of the same type with converted elements
    """
    # There is currently no type for ArrayAPI compatible objects, so they fall through to this
    # function registered for any Iterable. If they are arrays, we can convert them directly.
    # We currently cannot pass the device to the from_dlpack function, since it is not supported
    # for some frameworks (see e.g. https://github.com/data-apis/array-api-compat/issues/204)
    if is_array_api_obj(value):
        return _array_api_array_conversion(value, xp, device)
    if isinstance(value, (str, bytes)):
        # Text is iterable but not a container of arrays. Rebuilding a ``str`` from a
        # generator would yield the generator's repr instead of the original text, so
        # strings (and bytes) are passed through untouched.
        return value
    if hasattr(value, "_make"):
        # namedtuple - underline used to prevent potential name conflicts
        # noinspection PyProtectedMember
        return type(value)._make(array_conversion(v, xp, device) for v in value)
    return type(value)(array_conversion(v, xp, device) for v in value)
|
||||
|
||||
|
||||
def _array_api_array_conversion(
    value: Array, xp: ModuleType, device: Device | None = None
) -> Array:
    """Convert an Array API compatible array into an array of the xp framework.

    The array is first handed over through DLPack. If that raises a ``RuntimeError`` or
    ``BufferError``, the array is copied in its own framework and the copy is converted
    with ``xp.asarray`` instead.

    Args:
        value: The Array API compatible array to convert
        xp: The Array API namespace to convert to
        device: Optional device the resulting array is moved to / created on

    Returns:
        The array as an ``xp`` array
    """
    try:
        converted = xp.from_dlpack(value)
        if device is None:
            return converted
        return to_device(converted, device)
    except (RuntimeError, BufferError):
        # If dlpack fails (e.g. because the array is read-only for frameworks that do not
        # support it), we create a copy of the array that we own and then convert it.
        # TODO: The correct treatment of read-only arrays is currently not fully clear in the
        # Array API. Once ongoing discussions are resolved, we should update this code to remove
        # any fallbacks.
        source_xp = array_namespace(value)
        owned_copy = source_xp.asarray(value, copy=True)
        return xp.asarray(owned_copy, device=device)
|
||||
|
||||
|
||||
@array_conversion.register(NoneType)
def _none_array_conversion(
    value: None, xp: ModuleType, device: Device | None = None
) -> None:
    """Leave ``None`` untouched.

    Args:
        value: The ``None`` value that was dispatched to this function
        xp: Unused, present to match the ``array_conversion`` signature
        device: Unused, present to match the ``array_conversion`` signature

    Returns:
        The value that was passed in
    """
    # Nothing to convert: the input is handed back as is.
    return value
|
||||
|
||||
|
||||
class ArrayConversion(gym.Wrapper, gym.utils.RecordConstructorArgs):
    """Wraps an Array API compatible environment so that it can be interacted with a specific Array API framework.

    Actions must be provided as Array API compatible arrays and observations will be returned as Arrays of the specified xp module.
    A vector version of the wrapper exists, :class:`gymnasium.wrappers.vector.ArrayConversion`.

    Example:
        >>> import torch                                                # doctest: +SKIP
        >>> import jax.numpy as jnp                                     # doctest: +SKIP
        >>> import gymnasium as gym                                     # doctest: +SKIP
        >>> env = gym.make("JaxEnv-vx")                                 # doctest: +SKIP
        >>> env = ArrayConversion(env, env_xp=jnp, target_xp=torch)     # doctest: +SKIP
        >>> obs, _ = env.reset(seed=123)                                # doctest: +SKIP
        >>> type(obs)                                                   # doctest: +SKIP
        <class 'torch.Tensor'>
        >>> action = torch.tensor(env.action_space.sample())            # doctest: +SKIP
        >>> obs, reward, terminated, truncated, info = env.step(action) # doctest: +SKIP
        >>> type(obs)                                                   # doctest: +SKIP
        <class 'torch.Tensor'>
        >>> type(reward)                                                # doctest: +SKIP
        <class 'float'>
        >>> type(terminated)                                            # doctest: +SKIP
        <class 'bool'>
        >>> type(truncated)                                             # doctest: +SKIP
        <class 'bool'>

    Change logs:
     * v1.2.0 - Initially added
    """

    def __init__(
        self,
        env: gym.Env,
        env_xp: ModuleType,
        target_xp: ModuleType,
        env_device: Device | None = None,
        target_device: Device | None = None,
    ):
        """Wrapper class to change inputs and outputs of environment to any Array API framework.

        Args:
            env: The Array API compatible environment to wrap
            env_xp: The Array API framework the environment is on
            target_xp: The Array API framework to convert to
            env_device: The device the environment is on
            target_device: The device on which Arrays should be returned
        """
        gym.utils.RecordConstructorArgs.__init__(self)
        gym.Wrapper.__init__(self, env)

        # Both modules are resolved to their Array API compatible namespaces once, so
        # that every later conversion can use them directly.
        self._env_xp = module_namespace(env_xp)
        self._target_xp = module_namespace(target_xp)
        self._env_device: Device | None = env_device
        self._target_device: Device | None = target_device

    def step(
        self, action: WrapperActType
    ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
        """Performs the given action within the environment.

        The action is converted to the environment framework before it is passed on.
        Observation and info are converted to the target framework, while reward,
        terminated and truncated are returned as python ``float`` / ``bool``.

        Args:
            action: The action to perform as any Array API compatible array

        Returns:
            The next observation, reward, termination, truncation, and extra info
        """
        action = array_conversion(action, xp=self._env_xp, device=self._env_device)
        obs, reward, terminated, truncated, info = self.env.step(action)

        return (
            array_conversion(obs, xp=self._target_xp, device=self._target_device),
            float(reward),
            bool(terminated),
            bool(truncated),
            array_conversion(info, xp=self._target_xp, device=self._target_device),
        )

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[WrapperObsType, dict[str, Any]]:
        """Resets the environment returning observation and info as Array from any Array API compatible framework.

        Args:
            seed: The seed for resetting the environment
            options: The options for resetting the environment, these are converted to
                arrays of the environment framework (``env_xp``) on the environment device.

        Returns:
            Observation and info converted to the target framework (``target_xp``)
        """
        if options:
            options = array_conversion(options, self._env_xp, self._env_device)

        return array_conversion(
            self.env.reset(seed=seed, options=options),
            self._target_xp,
            self._target_device,
        )

    def render(self) -> RenderFrame | list[RenderFrame] | None:
        """Returns the rendered frames as an xp Array."""
        return array_conversion(self.env.render(), self._target_xp, self._target_device)

    def __getstate__(self):
        """Returns the object pickle state with args and kwargs."""
        # The namespaces are stored by name instead of as module objects; the
        # ``array_api_compat.`` part is removed so that the stored name is the name of
        # the underlying framework module, which ``__setstate__`` imports again.
        env_xp_name = self._env_xp.__name__.replace("array_api_compat.", "")
        target_xp_name = self._target_xp.__name__.replace("array_api_compat.", "")
        env_device = self._env_device
        target_device = self._target_device
        return {
            "env_xp_name": env_xp_name,
            "target_xp_name": target_xp_name,
            "env_device": env_device,
            "target_device": target_device,
            "env": self.env,
        }

    def __setstate__(self, d):
        """Sets the object pickle state using d."""
        # Unpickling bypasses ``__init__``. The base initialisers are therefore run
        # here in the same way ``__init__`` runs them, instead of only assigning
        # ``self.env``, so that attributes they create also exist on the restored
        # wrapper.
        # NOTE(review): assumes ``gym.Wrapper.__init__`` stores ``env`` and initialises
        # the wrapper's internal attributes without further side effects - confirm.
        gym.utils.RecordConstructorArgs.__init__(self)
        gym.Wrapper.__init__(self, d["env"])
        self._env_xp = module_name_to_namespace(d["env_xp_name"])
        self._target_xp = module_name_to_namespace(d["target_xp_name"])
        self._env_device = d["env_device"]
        self._target_device = d["target_device"]
@@ -3,19 +3,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import numbers
|
||||
from collections import abc
|
||||
from typing import Any, Iterable, Mapping, SupportsFloat
|
||||
|
||||
import numpy as np
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.core import ActType, ObsType, RenderFrame, WrapperActType, WrapperObsType
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from gymnasium.error import DependencyNotInstalled
|
||||
from gymnasium.wrappers.array_conversion import (
|
||||
ArrayConversion,
|
||||
array_conversion,
|
||||
module_namespace,
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
except ImportError:
|
||||
raise DependencyNotInstalled(
|
||||
@@ -24,110 +25,13 @@ except ImportError:
|
||||
|
||||
__all__ = ["JaxToNumpy", "jax_to_numpy", "numpy_to_jax"]
|
||||
|
||||
# The NoneType is not defined in Python 3.9. Remove when the minimal version is bumped to >=3.10
|
||||
_NoneType = type(None)
|
||||
|
||||
jax_to_numpy = functools.partial(array_conversion, xp=module_namespace(np))
|
||||
|
||||
numpy_to_jax = functools.partial(array_conversion, xp=module_namespace(jnp))
|
||||
|
||||
|
||||
@functools.singledispatch
|
||||
def numpy_to_jax(value: Any) -> Any:
|
||||
"""Converts a value to a Jax Array."""
|
||||
raise Exception(
|
||||
f"No known conversion for Numpy type ({type(value)}) to Jax registered. Report as issue on github."
|
||||
)
|
||||
|
||||
|
||||
@numpy_to_jax.register(numbers.Number)
|
||||
def _number_to_jax(
|
||||
value: numbers.Number,
|
||||
) -> jax.Array:
|
||||
"""Converts a number (int, float, etc.) to a Jax Array."""
|
||||
assert jnp is not None
|
||||
return jnp.array(value)
|
||||
|
||||
|
||||
@numpy_to_jax.register(np.ndarray)
|
||||
def _numpy_array_to_jax(value: np.ndarray) -> jax.Array:
|
||||
"""Converts a NumPy Array to a Jax Array with the same dtype (excluding float64 without being enabled)."""
|
||||
assert jnp is not None
|
||||
return jnp.array(value, dtype=value.dtype)
|
||||
|
||||
|
||||
@numpy_to_jax.register(abc.Mapping)
|
||||
def _mapping_numpy_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
"""Converts a dictionary of numpy arrays to a mapping of Jax Array."""
|
||||
return type(value)(**{k: numpy_to_jax(v) for k, v in value.items()})
|
||||
|
||||
|
||||
@numpy_to_jax.register(abc.Iterable)
|
||||
def _iterable_numpy_to_jax(
|
||||
value: Iterable[np.ndarray | Any],
|
||||
) -> Iterable[jax.Array | Any]:
|
||||
"""Converts an Iterable from Numpy Arrays to an iterable of Jax Array."""
|
||||
if hasattr(value, "_make"):
|
||||
# namedtuple - underline used to prevent potential name conflicts
|
||||
# noinspection PyProtectedMember
|
||||
return type(value)._make(numpy_to_jax(v) for v in value)
|
||||
else:
|
||||
return type(value)(numpy_to_jax(v) for v in value)
|
||||
|
||||
|
||||
@numpy_to_jax.register(_NoneType)
|
||||
def _none_numpy_to_jax(value: None) -> None:
|
||||
"""Passes through None values."""
|
||||
return value
|
||||
|
||||
|
||||
@functools.singledispatch
|
||||
def jax_to_numpy(value: Any) -> Any:
|
||||
"""Converts a value to a numpy array."""
|
||||
raise Exception(
|
||||
f"No known conversion for Jax type ({type(value)}) to NumPy registered. Report as issue on github."
|
||||
)
|
||||
|
||||
|
||||
@jax_to_numpy.register(jax.Array)
|
||||
def _devicearray_jax_to_numpy(value: jax.Array) -> np.ndarray:
|
||||
"""Converts a Jax Array to a numpy array."""
|
||||
return np.array(value)
|
||||
|
||||
|
||||
@jax_to_numpy.register(abc.Mapping)
|
||||
def _mapping_jax_to_numpy(
|
||||
value: Mapping[str, jax.Array | Any],
|
||||
) -> Mapping[str, np.ndarray | Any]:
|
||||
"""Converts a dictionary of Jax Array to a mapping of numpy arrays."""
|
||||
return type(value)(**{k: jax_to_numpy(v) for k, v in value.items()})
|
||||
|
||||
|
||||
@jax_to_numpy.register(abc.Iterable)
|
||||
def _iterable_jax_to_numpy(
|
||||
value: Iterable[np.ndarray | Any],
|
||||
) -> Iterable[jax.Array | Any]:
|
||||
"""Converts an Iterable from Numpy arrays to an iterable of Jax Array."""
|
||||
if isinstance(value, jax.Array):
|
||||
# Since the update to jax 0.6.0, calling jax_to_numpy with a <class 'jaxlib.xla_extension.ArrayImpl'>
|
||||
# argument wrongly dispatches to _iterable_jax_to_numpy which fails with:
|
||||
# TypeError: (): incompatible function arguments.
|
||||
# See: https://github.com/Farama-Foundation/Gymnasium/issues/1360
|
||||
return _devicearray_jax_to_numpy(value)
|
||||
elif hasattr(value, "_make"):
|
||||
# namedtuple - underline used to prevent potential name conflicts
|
||||
# noinspection PyProtectedMember
|
||||
return type(value)._make(jax_to_numpy(v) for v in value)
|
||||
else:
|
||||
return type(value)(jax_to_numpy(v) for v in value)
|
||||
|
||||
|
||||
@jax_to_numpy.register(_NoneType)
|
||||
def _none_jax_to_numpy(value: None) -> None:
|
||||
"""Passes through None values."""
|
||||
return value
|
||||
|
||||
|
||||
class JaxToNumpy(
|
||||
gym.Wrapper[WrapperObsType, WrapperActType, ObsType, ActType],
|
||||
gym.utils.RecordConstructorArgs,
|
||||
):
|
||||
class JaxToNumpy(ArrayConversion):
|
||||
"""Wraps a Jax-based environment such that it can be interacted with NumPy arrays.
|
||||
|
||||
Actions must be provided as numpy arrays and observations will be returned as numpy arrays.
|
||||
@@ -169,48 +73,4 @@ class JaxToNumpy(
|
||||
raise DependencyNotInstalled(
|
||||
'Jax is not installed, run `pip install "gymnasium[jax]"`'
|
||||
)
|
||||
gym.utils.RecordConstructorArgs.__init__(self)
|
||||
gym.Wrapper.__init__(self, env)
|
||||
|
||||
def step(
|
||||
self, action: WrapperActType
|
||||
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
|
||||
"""Transforms the action to a jax array .
|
||||
|
||||
Args:
|
||||
action: the action to perform as a numpy array
|
||||
|
||||
Returns:
|
||||
A tuple containing numpy versions of the next observation, reward, termination, truncation, and extra info.
|
||||
"""
|
||||
jax_action = numpy_to_jax(action)
|
||||
obs, reward, terminated, truncated, info = self.env.step(jax_action)
|
||||
|
||||
return (
|
||||
jax_to_numpy(obs),
|
||||
float(reward),
|
||||
bool(terminated),
|
||||
bool(truncated),
|
||||
jax_to_numpy(info),
|
||||
)
|
||||
|
||||
def reset(
|
||||
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
||||
) -> tuple[WrapperObsType, dict[str, Any]]:
|
||||
"""Resets the environment returning numpy-based observation and info.
|
||||
|
||||
Args:
|
||||
seed: The seed for resetting the environment
|
||||
options: The options for resetting the environment, these are converted to jax arrays.
|
||||
|
||||
Returns:
|
||||
Numpy-based observations and info
|
||||
"""
|
||||
if options:
|
||||
options = numpy_to_jax(options)
|
||||
|
||||
return jax_to_numpy(self.env.reset(seed=seed, options=options))
|
||||
|
||||
def render(self) -> RenderFrame | list[RenderFrame] | None:
|
||||
"""Returns the rendered frames as a numpy array."""
|
||||
return jax_to_numpy(self.env.render())
|
||||
super().__init__(env=env, env_xp=jnp, target_xp=np)
|
||||
|
@@ -7,22 +7,23 @@
|
||||
# Under the Apache 2.0 license. Copyright is held by the authors
|
||||
|
||||
"""Helper functions and wrapper class for converting between PyTorch and Jax."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import numbers
|
||||
from collections import abc
|
||||
from typing import Any, Iterable, Mapping, SupportsFloat, Union
|
||||
from typing import Union
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
|
||||
from gymnasium.error import DependencyNotInstalled
|
||||
from gymnasium.wrappers.array_conversion import (
|
||||
ArrayConversion,
|
||||
array_conversion,
|
||||
module_namespace,
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import dlpack as jax_dlpack
|
||||
|
||||
except ImportError:
|
||||
raise DependencyNotInstalled(
|
||||
@@ -31,7 +32,6 @@ except ImportError:
|
||||
|
||||
try:
|
||||
import torch
|
||||
from torch.utils import dlpack as torch_dlpack
|
||||
|
||||
Device = Union[str, torch.device]
|
||||
except ImportError:
|
||||
@@ -42,109 +42,13 @@ except ImportError:
|
||||
|
||||
__all__ = ["JaxToTorch", "jax_to_torch", "torch_to_jax", "Device"]
|
||||
|
||||
# The NoneType is not defined in Python 3.9. Remove when the minimal version is bumped to >=3.10
|
||||
_NoneType = type(None)
|
||||
|
||||
torch_to_jax = functools.partial(array_conversion, xp=module_namespace(jnp))
|
||||
|
||||
jax_to_torch = functools.partial(array_conversion, xp=module_namespace(torch))
|
||||
|
||||
|
||||
@functools.singledispatch
|
||||
def torch_to_jax(value: Any) -> Any:
|
||||
"""Converts a PyTorch Tensor into a Jax Array."""
|
||||
raise Exception(
|
||||
f"No known conversion for Torch type ({type(value)}) to Jax registered. Report as issue on github."
|
||||
)
|
||||
|
||||
|
||||
@torch_to_jax.register(numbers.Number)
|
||||
def _number_torch_to_jax(value: numbers.Number) -> Any:
|
||||
"""Convert a python number (int, float, complex) to a jax array."""
|
||||
return jnp.array(value)
|
||||
|
||||
|
||||
@torch_to_jax.register(torch.Tensor)
|
||||
def _tensor_torch_to_jax(value: torch.Tensor) -> jax.Array:
|
||||
"""Converts a PyTorch Tensor into a Jax Array."""
|
||||
return jax_dlpack.from_dlpack(value) # pyright: ignore[reportPrivateImportUsage]
|
||||
|
||||
|
||||
@torch_to_jax.register(abc.Mapping)
|
||||
def _mapping_torch_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
"""Converts a mapping of PyTorch Tensors into a Dictionary of Jax Array."""
|
||||
return type(value)(**{k: torch_to_jax(v) for k, v in value.items()})
|
||||
|
||||
|
||||
@torch_to_jax.register(abc.Iterable)
|
||||
def _iterable_torch_to_jax(value: Iterable[Any]) -> Iterable[Any]:
|
||||
"""Converts an Iterable from PyTorch Tensors to an iterable of Jax Array."""
|
||||
if hasattr(value, "_make"):
|
||||
# namedtuple - underline used to prevent potential name conflicts
|
||||
# noinspection PyProtectedMember
|
||||
return type(value)._make(torch_to_jax(v) for v in value)
|
||||
else:
|
||||
return type(value)(torch_to_jax(v) for v in value)
|
||||
|
||||
|
||||
@torch_to_jax.register(_NoneType)
|
||||
def _none_torch_to_jax(value: None) -> None:
|
||||
"""Passes through None values."""
|
||||
return value
|
||||
|
||||
|
||||
@functools.singledispatch
|
||||
def jax_to_torch(value: Any, device: Device | None = None) -> Any:
|
||||
"""Converts a Jax Array into a PyTorch Tensor."""
|
||||
raise Exception(
|
||||
f"No known conversion for Jax type ({type(value)}) to PyTorch registered. Report as issue on github."
|
||||
)
|
||||
|
||||
|
||||
@jax_to_torch.register(jax.Array)
|
||||
def _devicearray_jax_to_torch(
|
||||
value: jax.Array, device: Device | None = None
|
||||
) -> torch.Tensor:
|
||||
"""Converts a Jax Array into a PyTorch Tensor."""
|
||||
assert jax_dlpack is not None and torch_dlpack is not None
|
||||
tensor = torch_dlpack.from_dlpack(
|
||||
value
|
||||
) # pyright: ignore[reportPrivateImportUsage]
|
||||
if device:
|
||||
return tensor.to(device=device)
|
||||
return tensor
|
||||
|
||||
|
||||
@jax_to_torch.register(abc.Mapping)
|
||||
def _jax_mapping_to_torch(
|
||||
value: Mapping[str, Any], device: Device | None = None
|
||||
) -> Mapping[str, Any]:
|
||||
"""Converts a mapping of Jax Array into a Dictionary of PyTorch Tensors."""
|
||||
return type(value)(**{k: jax_to_torch(v, device) for k, v in value.items()})
|
||||
|
||||
|
||||
@jax_to_torch.register(abc.Iterable)
|
||||
def _jax_iterable_to_torch(
|
||||
value: Iterable[Any], device: Device | None = None
|
||||
) -> Iterable[Any]:
|
||||
"""Converts an Iterable from Jax Array to an iterable of PyTorch Tensors."""
|
||||
if isinstance(value, jax.Array):
|
||||
# Since the update to jax 0.6.0, calling jax_to_torch with a <class 'jaxlib.xla_extension.ArrayImpl'>
|
||||
# argument wrongly dispatches to _iterable_jax_to_torch which fails with:
|
||||
# TypeError: (): incompatible function arguments.
|
||||
# See: https://github.com/Farama-Foundation/Gymnasium/issues/1360
|
||||
return _devicearray_jax_to_torch(value)
|
||||
elif hasattr(value, "_make"):
|
||||
# namedtuple - underline used to prevent potential name conflicts
|
||||
# noinspection PyProtectedMember
|
||||
return type(value)._make(jax_to_torch(v, device) for v in value)
|
||||
else:
|
||||
return type(value)(jax_to_torch(v, device) for v in value)
|
||||
|
||||
|
||||
@jax_to_torch.register(_NoneType)
|
||||
def _none_jax_to_torch(value: None, device: Device | None = None) -> None:
|
||||
"""Passes through None values."""
|
||||
return value
|
||||
|
||||
|
||||
class JaxToTorch(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||
class JaxToTorch(ArrayConversion):
|
||||
"""Wraps a Jax-based environment so that it can be interacted with PyTorch Tensors.
|
||||
|
||||
Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors.
|
||||
@@ -183,50 +87,8 @@ class JaxToTorch(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||
env: The Jax-based environment to wrap
|
||||
device: The device the torch Tensors should be moved to
|
||||
"""
|
||||
gym.utils.RecordConstructorArgs.__init__(self, device=device)
|
||||
gym.Wrapper.__init__(self, env)
|
||||
super().__init__(env=env, env_xp=jnp, target_xp=torch, target_device=device)
|
||||
|
||||
# TODO: Device was part of the public API, but should be removed in favor of _env_device and
|
||||
# _target_device.
|
||||
self.device: Device | None = device
|
||||
|
||||
def step(
|
||||
self, action: WrapperActType
|
||||
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
|
||||
"""Performs the given action within the environment.
|
||||
|
||||
Args:
|
||||
action: The action to perform as a PyTorch Tensor
|
||||
|
||||
Returns:
|
||||
The next observation, reward, termination, truncation, and extra info
|
||||
"""
|
||||
jax_action = torch_to_jax(action)
|
||||
obs, reward, terminated, truncated, info = self.env.step(jax_action)
|
||||
|
||||
return (
|
||||
jax_to_torch(obs, self.device),
|
||||
float(reward),
|
||||
bool(terminated),
|
||||
bool(truncated),
|
||||
jax_to_torch(info, self.device),
|
||||
)
|
||||
|
||||
def reset(
|
||||
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
||||
) -> tuple[WrapperObsType, dict[str, Any]]:
|
||||
"""Resets the environment returning PyTorch-based observation and info.
|
||||
|
||||
Args:
|
||||
seed: The seed for resetting the environment
|
||||
options: The options for resetting the environment, these are converted to jax arrays.
|
||||
|
||||
Returns:
|
||||
PyTorch-based observations and info
|
||||
"""
|
||||
if options:
|
||||
options = torch_to_jax(options)
|
||||
|
||||
return jax_to_torch(self.env.reset(seed=seed, options=options), self.device)
|
||||
|
||||
def render(self) -> RenderFrame | list[RenderFrame] | None:
|
||||
"""Returns the rendered frames as a torch tensor."""
|
||||
return jax_to_torch(self.env.render())
|
||||
|
@@ -3,15 +3,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import numbers
|
||||
from collections import abc
|
||||
from typing import Any, Iterable, Mapping, SupportsFloat, Union
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
|
||||
from gymnasium.error import DependencyNotInstalled
|
||||
from gymnasium.wrappers.array_conversion import (
|
||||
ArrayConversion,
|
||||
array_conversion,
|
||||
module_namespace,
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
@@ -26,100 +28,13 @@ except ImportError:
|
||||
|
||||
__all__ = ["NumpyToTorch", "torch_to_numpy", "numpy_to_torch", "Device"]
|
||||
|
||||
# The NoneType is not defined in Python 3.9. Remove when the minimal version is bumped to >=3.10
|
||||
_NoneType = type(None)
|
||||
|
||||
torch_to_numpy = functools.partial(array_conversion, xp=module_namespace(np))
|
||||
|
||||
numpy_to_torch = functools.partial(array_conversion, xp=module_namespace(torch))
|
||||
|
||||
|
||||
@functools.singledispatch
|
||||
def torch_to_numpy(value: Any) -> Any:
|
||||
"""Converts a PyTorch Tensor into a NumPy Array."""
|
||||
raise Exception(
|
||||
f"No known conversion for Torch type ({type(value)}) to NumPy registered. Report as issue on github."
|
||||
)
|
||||
|
||||
|
||||
@torch_to_numpy.register(numbers.Number)
|
||||
def _number_to_numpy(value: numbers.Number) -> Any:
|
||||
"""Convert a python number (int, float, complex) to a NumPy array."""
|
||||
return np.array(value)
|
||||
|
||||
|
||||
@torch_to_numpy.register(torch.Tensor)
|
||||
def _torch_to_numpy(value: torch.Tensor) -> Any:
|
||||
"""Convert a torch.Tensor to a NumPy array."""
|
||||
return value.numpy(force=True)
|
||||
|
||||
|
||||
@torch_to_numpy.register(abc.Mapping)
|
||||
def _mapping_torch_to_numpy(value: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
"""Converts a mapping of PyTorch Tensors into a Dictionary of NumPy Array."""
|
||||
return type(value)(**{k: torch_to_numpy(v) for k, v in value.items()})
|
||||
|
||||
|
||||
@torch_to_numpy.register(abc.Iterable)
|
||||
def _iterable_torch_to_numpy(value: Iterable[Any]) -> Iterable[Any]:
|
||||
"""Converts an Iterable from PyTorch Tensors to an iterable of NumPy Array."""
|
||||
if hasattr(value, "_make"):
|
||||
# namedtuple - underline used to prevent potential name conflicts
|
||||
# noinspection PyProtectedMember
|
||||
return type(value)._make(torch_to_numpy(v) for v in value)
|
||||
else:
|
||||
return type(value)(torch_to_numpy(v) for v in value)
|
||||
|
||||
|
||||
@torch_to_numpy.register(_NoneType)
|
||||
def _none_torch_to_numpy(value: None) -> None:
|
||||
"""Passes through None values."""
|
||||
return value
|
||||
|
||||
|
||||
@functools.singledispatch
|
||||
def numpy_to_torch(value: Any, device: Device | None = None) -> Any:
|
||||
"""Converts a NumPy Array into a PyTorch Tensor."""
|
||||
raise Exception(
|
||||
f"No known conversion for NumPy type ({type(value)}) to PyTorch registered. Report as issue on github."
|
||||
)
|
||||
|
||||
|
||||
@numpy_to_torch.register(numbers.Number)
|
||||
@numpy_to_torch.register(np.ndarray)
|
||||
def _numpy_to_torch(value: np.ndarray, device: Device | None = None) -> torch.Tensor:
|
||||
"""Converts a NumPy Array into a PyTorch Tensor."""
|
||||
assert torch is not None
|
||||
tensor = torch.tensor(value)
|
||||
if device:
|
||||
return tensor.to(device=device)
|
||||
return tensor
|
||||
|
||||
|
||||
@numpy_to_torch.register(abc.Mapping)
|
||||
def _numpy_mapping_to_torch(
|
||||
value: Mapping[str, Any], device: Device | None = None
|
||||
) -> Mapping[str, Any]:
|
||||
"""Converts a mapping of NumPy Array into a Dictionary of PyTorch Tensors."""
|
||||
return type(value)(**{k: numpy_to_torch(v, device) for k, v in value.items()})
|
||||
|
||||
|
||||
@numpy_to_torch.register(abc.Iterable)
|
||||
def _numpy_iterable_to_torch(
|
||||
value: Iterable[Any], device: Device | None = None
|
||||
) -> Iterable[Any]:
|
||||
"""Converts an Iterable from NumPy Array to an iterable of PyTorch Tensors."""
|
||||
if hasattr(value, "_make"):
|
||||
# namedtuple - underline used to prevent potential name conflicts
|
||||
# noinspection PyProtectedMember
|
||||
return type(value)._make(numpy_to_torch(v, device) for v in value)
|
||||
else:
|
||||
return type(value)(numpy_to_torch(v, device) for v in value)
|
||||
|
||||
|
||||
@numpy_to_torch.register(_NoneType)
|
||||
def _none_numpy_to_torch(value: None) -> None:
|
||||
"""Passes through None values."""
|
||||
return value
|
||||
|
||||
|
||||
class NumpyToTorch(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||
class NumpyToTorch(ArrayConversion):
|
||||
"""Wraps a NumPy-based environment such that it can be interacted with PyTorch Tensors.
|
||||
|
||||
Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors.
|
||||
@@ -158,50 +73,6 @@ class NumpyToTorch(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||
env: The NumPy-based environment to wrap
|
||||
device: The device the torch Tensors should be moved to
|
||||
"""
|
||||
gym.utils.RecordConstructorArgs.__init__(self, device=device)
|
||||
gym.Wrapper.__init__(self, env)
|
||||
super().__init__(env=env, env_xp=np, target_xp=torch, target_device=device)
|
||||
|
||||
self.device: Device | None = device
|
||||
|
||||
def step(
|
||||
self, action: WrapperActType
|
||||
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
|
||||
"""Using a PyTorch based action that is converted to NumPy to be used by the environment.
|
||||
|
||||
Args:
|
||||
action: A PyTorch-based action
|
||||
|
||||
Returns:
|
||||
The PyTorch-based Tensor next observation, reward, termination, truncation, and extra info
|
||||
"""
|
||||
jax_action = torch_to_numpy(action)
|
||||
obs, reward, terminated, truncated, info = self.env.step(jax_action)
|
||||
|
||||
return (
|
||||
numpy_to_torch(obs, self.device),
|
||||
float(reward),
|
||||
bool(terminated),
|
||||
bool(truncated),
|
||||
numpy_to_torch(info, self.device),
|
||||
)
|
||||
|
||||
def reset(
|
||||
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
||||
) -> tuple[WrapperObsType, dict[str, Any]]:
|
||||
"""Resets the environment returning PyTorch-based observation and info.
|
||||
|
||||
Args:
|
||||
seed: The seed for resetting the environment
|
||||
options: The options for resetting the environment, these are converted to jax arrays.
|
||||
|
||||
Returns:
|
||||
PyTorch-based observations and info
|
||||
"""
|
||||
if options:
|
||||
options = torch_to_numpy(options)
|
||||
|
||||
return numpy_to_torch(self.env.reset(seed=seed, options=options), self.device)
|
||||
|
||||
def render(self) -> RenderFrame | list[RenderFrame] | None:
|
||||
"""Returns the rendered frames as a torch tensor."""
|
||||
return numpy_to_torch(self.env.render())
|
||||
|
131
gymnasium/wrappers/vector/array_conversion.py
Normal file
131
gymnasium/wrappers/vector/array_conversion.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""Vector wrapper for converting between Array API compatible frameworks."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import ModuleType
|
||||
from typing import Any
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from gymnasium.vector import VectorEnv, VectorWrapper
|
||||
from gymnasium.vector.vector_env import ArrayType
|
||||
from gymnasium.wrappers.array_conversion import (
|
||||
Device,
|
||||
array_conversion,
|
||||
module_name_to_namespace,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["ArrayConversion"]
|
||||
|
||||
|
||||
class ArrayConversion(VectorWrapper, gym.utils.RecordConstructorArgs):
    """Wraps a vector environment returning Array API compatible arrays so that it can be interacted with through a specific framework.

    Notes:
        A vectorized version of ``gymnasium.wrappers.ArrayConversion``

    Actions must be provided as Array API compatible arrays and observations, rewards, terminations and truncations will be returned in the desired framework.

    ``env_xp`` and ``target_xp`` are modules that are compatible with the Array API standard, e.g. ``numpy``, ``jax.numpy`` etc.

    Example:
        >>> import gymnasium as gym                                 # doctest: +SKIP
        >>> import jax.numpy as jnp                                 # doctest: +SKIP
        >>> import numpy as np                                      # doctest: +SKIP
        >>> envs = gym.make_vec("JaxEnv-vx", 3)                     # doctest: +SKIP
        >>> envs = ArrayConversion(envs, env_xp=jnp, target_xp=np)  # doctest: +SKIP
    """

    def __init__(
        self,
        env: VectorEnv,
        env_xp: ModuleType | str,
        target_xp: ModuleType | str,
        env_device: Device | None = None,
        target_device: Device | None = None,
    ):
        """Wrapper class to change inputs and outputs of environment to any Array API framework.

        Args:
            env: The Array API compatible environment to wrap
            env_xp: The Array API framework the environment is on
            target_xp: The Array API framework to convert to
            env_device: The device the environment is on
            target_device: The device on which Arrays should be returned
        """
        # No constructor arguments are passed on to ``RecordConstructorArgs`` here.
        gym.utils.RecordConstructorArgs.__init__(self)
        VectorWrapper.__init__(self, env)
        self._env_xp = env_xp
        self._target_xp = target_xp
        self._env_device = env_device
        self._target_device = target_device

    def step(
        self, actions: ActType
    ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
        """Transforms the action to the specified xp module array type.

        Args:
            actions: The action to perform

        Returns:
            A tuple containing xp versions of the next observation, reward, termination, truncation, and extra info.
        """
        # Actions go to the environment's framework/device ...
        actions = array_conversion(actions, xp=self._env_xp, device=self._env_device)
        obs, reward, terminated, truncated, info = self.env.step(actions)

        # ... and every element of the step result comes back in the target framework/device.
        return (
            array_conversion(obs, xp=self._target_xp, device=self._target_device),
            array_conversion(reward, xp=self._target_xp, device=self._target_device),
            array_conversion(
                terminated, xp=self._target_xp, device=self._target_device
            ),
            array_conversion(truncated, xp=self._target_xp, device=self._target_device),
            array_conversion(info, xp=self._target_xp, device=self._target_device),
        )

    def reset(
        self,
        *,
        seed: int | list[int] | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[ObsType, dict[str, Any]]:
        """Resets the environment returning xp-based observation and info.

        Args:
            seed: The seed for resetting the environment
            options: The options for resetting the environment, these are converted to xp arrays.

        Returns:
            xp-based observations and info
        """
        # ``None`` and empty options are forwarded untouched.
        if options:
            options = array_conversion(
                options, xp=self._env_xp, device=self._env_device
            )

        return array_conversion(
            self.env.reset(seed=seed, options=options),
            xp=self._target_xp,
            device=self._target_device,
        )

    @staticmethod
    def _xp_name(xp: ModuleType | str) -> str:
        """Return the name of ``xp`` with any ``array_api_compat.`` prefix removed.

        ``xp`` may already be a plain module name (the constructor accepts ``str``),
        in which case it has no ``__name__`` attribute and is used as is.
        """
        name = xp if isinstance(xp, str) else xp.__name__
        return name.replace("array_api_compat.", "")

    def __getstate__(self):
        """Returns the object pickle state.

        The two namespaces are stored by name rather than as module objects
        (presumably because modules cannot be pickled); ``__setstate__`` resolves
        the names again.
        """
        return {
            "env_xp_name": self._xp_name(self._env_xp),
            "target_xp_name": self._xp_name(self._target_xp),
            "env_device": self._env_device,
            "target_device": self._target_device,
            "env": self.env,
        }

    def __setstate__(self, d):
        """Sets the object pickle state using d."""
        self.env = d["env"]
        self._env_xp = module_name_to_namespace(d["env_xp_name"])
        self._target_xp = module_name_to_namespace(d["target_xp_name"])
        self._env_device = d["env_device"]
        self._target_device = d["target_device"]
|
@@ -2,21 +2,18 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from gymnasium.error import DependencyNotInstalled
|
||||
from gymnasium.vector import VectorEnv, VectorWrapper
|
||||
from gymnasium.vector.vector_env import ArrayType
|
||||
from gymnasium.wrappers.jax_to_numpy import jax_to_numpy, numpy_to_jax
|
||||
from gymnasium.vector import VectorEnv
|
||||
from gymnasium.wrappers.vector.array_conversion import ArrayConversion
|
||||
|
||||
|
||||
__all__ = ["JaxToNumpy"]
|
||||
|
||||
|
||||
class JaxToNumpy(VectorWrapper):
|
||||
class JaxToNumpy(ArrayConversion):
|
||||
"""Wraps a jax vector environment so that it can be interacted with through numpy arrays.
|
||||
|
||||
Notes:
|
||||
@@ -40,46 +37,4 @@ class JaxToNumpy(VectorWrapper):
|
||||
raise DependencyNotInstalled(
|
||||
'Jax is not installed, run `pip install "gymnasium[jax]"`'
|
||||
)
|
||||
super().__init__(env)
|
||||
|
||||
def step(
|
||||
self, actions: ActType
|
||||
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
|
||||
"""Transforms the action to a jax array .
|
||||
|
||||
Args:
|
||||
actions: the action to perform as a numpy array
|
||||
|
||||
Returns:
|
||||
A tuple containing numpy versions of the next observation, reward, termination, truncation, and extra info.
|
||||
"""
|
||||
jax_actions = numpy_to_jax(actions)
|
||||
obs, reward, terminated, truncated, info = self.env.step(jax_actions)
|
||||
|
||||
return (
|
||||
jax_to_numpy(obs),
|
||||
jax_to_numpy(reward),
|
||||
jax_to_numpy(terminated),
|
||||
jax_to_numpy(truncated),
|
||||
jax_to_numpy(info),
|
||||
)
|
||||
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: int | list[int] | None = None,
|
||||
options: dict[str, Any] | None = None,
|
||||
) -> tuple[ObsType, dict[str, Any]]:
|
||||
"""Resets the environment returning numpy-based observation and info.
|
||||
|
||||
Args:
|
||||
seed: The seed for resetting the environment
|
||||
options: The options for resetting the environment, these are converted to jax arrays.
|
||||
|
||||
Returns:
|
||||
Numpy-based observations and info
|
||||
"""
|
||||
if options:
|
||||
options = numpy_to_jax(options)
|
||||
|
||||
return jax_to_numpy(self.env.reset(seed=seed, options=options))
|
||||
super().__init__(env, env_xp=jnp, target_xp=np)
|
||||
|
@@ -2,18 +2,18 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
import jax.numpy as jnp
|
||||
import torch
|
||||
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from gymnasium.vector import VectorEnv, VectorWrapper
|
||||
from gymnasium.vector.vector_env import ArrayType
|
||||
from gymnasium.wrappers.jax_to_torch import Device, jax_to_torch, torch_to_jax
|
||||
from gymnasium.vector import VectorEnv
|
||||
from gymnasium.wrappers.jax_to_torch import Device
|
||||
from gymnasium.wrappers.vector.array_conversion import ArrayConversion
|
||||
|
||||
|
||||
__all__ = ["JaxToTorch"]
|
||||
|
||||
|
||||
class JaxToTorch(VectorWrapper):
|
||||
class JaxToTorch(ArrayConversion):
|
||||
"""Wraps a Jax-based vector environment so that it can be interacted with through PyTorch Tensors.
|
||||
|
||||
Actions must be provided as PyTorch Tensors and observations, rewards, terminations and truncations will be returned as PyTorch Tensors.
|
||||
@@ -31,48 +31,6 @@ class JaxToTorch(VectorWrapper):
|
||||
env: The Jax-based vector environment to wrap
|
||||
device: The device the torch Tensors should be moved to
|
||||
"""
|
||||
super().__init__(env)
|
||||
super().__init__(env, env_xp=jnp, target_xp=torch, target_device=device)
|
||||
|
||||
self.device: Device | None = device
|
||||
|
||||
def step(
|
||||
self, actions: ActType
|
||||
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
|
||||
"""Performs the given action within the environment.
|
||||
|
||||
Args:
|
||||
actions: The action to perform as a PyTorch Tensor
|
||||
|
||||
Returns:
|
||||
Torch-based Tensors of the next observation, reward, termination, truncation, and extra info
|
||||
"""
|
||||
jax_action = torch_to_jax(actions)
|
||||
obs, reward, terminated, truncated, info = self.env.step(jax_action)
|
||||
|
||||
return (
|
||||
jax_to_torch(obs, self.device),
|
||||
jax_to_torch(reward, self.device),
|
||||
jax_to_torch(terminated, self.device),
|
||||
jax_to_torch(truncated, self.device),
|
||||
jax_to_torch(info, self.device),
|
||||
)
|
||||
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: int | list[int] | None = None,
|
||||
options: dict[str, Any] | None = None,
|
||||
) -> tuple[ObsType, dict[str, Any]]:
|
||||
"""Resets the environment returning PyTorch-based observation and info.
|
||||
|
||||
Args:
|
||||
seed: The seed for resetting the environment
|
||||
options: The options for resetting the environment, these are converted to jax arrays.
|
||||
|
||||
Returns:
|
||||
PyTorch-based observations and info
|
||||
"""
|
||||
if options:
|
||||
options = torch_to_jax(options)
|
||||
|
||||
return jax_to_torch(self.env.reset(seed=seed, options=options), self.device)
|
||||
|
@@ -2,18 +2,18 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from gymnasium.vector import VectorEnv, VectorWrapper
|
||||
from gymnasium.vector.vector_env import ArrayType
|
||||
from gymnasium.wrappers.numpy_to_torch import Device, numpy_to_torch, torch_to_numpy
|
||||
from gymnasium.vector import VectorEnv
|
||||
from gymnasium.wrappers.numpy_to_torch import Device
|
||||
from gymnasium.wrappers.vector.array_conversion import ArrayConversion
|
||||
|
||||
|
||||
__all__ = ["NumpyToTorch"]
|
||||
|
||||
|
||||
class NumpyToTorch(VectorWrapper):
|
||||
class NumpyToTorch(ArrayConversion):
|
||||
"""Wraps a numpy-based environment so that it can be interacted with through PyTorch Tensors.
|
||||
|
||||
Example:
|
||||
@@ -45,48 +45,6 @@ class NumpyToTorch(VectorWrapper):
|
||||
env: The NumPy-based vector environment to wrap
|
||||
device: The device the torch Tensors should be moved to
|
||||
"""
|
||||
super().__init__(env)
|
||||
super().__init__(env, env_xp=np, target_xp=torch, target_device=device)
|
||||
|
||||
self.device: Device | None = device
|
||||
|
||||
def step(
|
||||
self, actions: ActType
|
||||
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
|
||||
"""Using a PyTorch based action that is converted to NumPy to be used by the environment.
|
||||
|
||||
Args:
|
||||
action: A PyTorch-based action
|
||||
|
||||
Returns:
|
||||
The PyTorch-based Tensor next observation, reward, termination, truncation, and extra info
|
||||
"""
|
||||
numpy_action = torch_to_numpy(actions)
|
||||
obs, reward, terminated, truncated, info = self.env.step(numpy_action)
|
||||
|
||||
return (
|
||||
numpy_to_torch(obs, self.device),
|
||||
numpy_to_torch(reward, self.device),
|
||||
numpy_to_torch(terminated, self.device),
|
||||
numpy_to_torch(truncated, self.device),
|
||||
numpy_to_torch(info, self.device),
|
||||
)
|
||||
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: int | list[int] | None = None,
|
||||
options: dict[str, Any] | None = None,
|
||||
) -> tuple[ObsType, dict[str, Any]]:
|
||||
"""Resets the environment returning PyTorch-based observation and info.
|
||||
|
||||
Args:
|
||||
seed: The seed for resetting the environment
|
||||
options: The options for resetting the environment, these are converted to NumPy arrays.
|
||||
|
||||
Returns:
|
||||
PyTorch-based observations and info
|
||||
"""
|
||||
if options:
|
||||
options = torch_to_numpy(options)
|
||||
|
||||
return numpy_to_torch(self.env.reset(seed=seed, options=options), self.device)
|
||||
|
@@ -8,7 +8,7 @@ build-backend = "setuptools.build_meta"
|
||||
name = "gymnasium"
|
||||
description = "A standard API for reinforcement learning and a diverse set of reference environments (formerly Gym)."
|
||||
readme = "README.md"
|
||||
requires-python = ">= 3.8"
|
||||
requires-python = ">= 3.10"
|
||||
authors = [{ name = "Farama Foundation", email = "contact@farama.org" }]
|
||||
license = { text = "MIT License" }
|
||||
keywords = ["Reinforcement Learning", "game", "RL", "AI", "gymnasium"]
|
||||
@@ -16,8 +16,6 @@ classifiers = [
|
||||
"Development Status :: 5 - Production/Stable",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
@@ -27,7 +25,6 @@ classifiers = [
|
||||
dependencies = [
|
||||
"numpy >=1.21.0",
|
||||
"cloudpickle >=1.2.0",
|
||||
"importlib-metadata >=4.8.0; python_version < '3.10'",
|
||||
"typing-extensions >=4.3.0",
|
||||
"farama-notifications >=0.0.1",
|
||||
]
|
||||
@@ -38,15 +35,30 @@ dynamic = ["version"]
|
||||
atari = ["ale_py >=0.9"]
|
||||
box2d = ["box2d-py ==2.3.5", "pygame >=2.1.3", "swig ==4.*"]
|
||||
classic-control = ["pygame >=2.1.3"]
|
||||
classic_control = ["pygame >=2.1.3"] # kept for backward compatibility
|
||||
classic_control = ["pygame >=2.1.3"] # kept for backward compatibility
|
||||
mujoco-py = ["mujoco-py >=2.1,<2.2", "cython<3"]
|
||||
mujoco_py = ["mujoco-py >=2.1,<2.2", "cython<3"] # kept for backward compatibility
|
||||
mujoco_py = [
|
||||
"mujoco-py >=2.1,<2.2",
|
||||
"cython<3",
|
||||
] # kept for backward compatibility
|
||||
mujoco = ["mujoco >=2.1.5", "imageio >=2.14.1", "packaging >=23.0"]
|
||||
toy-text = ["pygame >=2.1.3"]
|
||||
toy_text = ["pygame >=2.1.3"] # kept for backward compatibility
|
||||
jax = ["jax >=0.4.16", "jaxlib >=0.4.16", "flax >=0.5.0"]
|
||||
torch = ["torch >=1.13.0"]
|
||||
other = ["moviepy >=1.0.0", "matplotlib >=3.0", "opencv-python >=3.0", "seaborn >= 0.13"]
|
||||
toy_text = ["pygame >=2.1.3"] # kept for backward compatibility
|
||||
jax = [
|
||||
"jax >=0.4.16",
|
||||
"jaxlib >=0.4.16",
|
||||
"flax >=0.5.0",
|
||||
"array-api-compat >=1.11.0",
|
||||
"numpy>=2.1",
|
||||
]
|
||||
torch = ["torch >=1.13.0", "array-api-compat >=1.11.0", "numpy>=2.1"]
|
||||
array-api = ["array-api-compat >=1.11.0", "numpy>=2.1"]
|
||||
other = [
|
||||
"moviepy >=1.0.0",
|
||||
"matplotlib >=3.0",
|
||||
"opencv-python >=3.0",
|
||||
"seaborn >= 0.13",
|
||||
]
|
||||
all = [
|
||||
# All dependencies above except accept-rom-license
|
||||
# NOTE: No need to manually remove the duplicates, setuptools automatically does that.
|
||||
@@ -71,17 +83,26 @@ all = [
|
||||
"jax >=0.4.16",
|
||||
"jaxlib >=0.4.16",
|
||||
"flax >= 0.5.0",
|
||||
"array-api-compat >=1.11.0",
|
||||
"numpy>=2.1",
|
||||
# torch
|
||||
"torch >=1.13.0",
|
||||
"array-api-compat >=1.11.0",
|
||||
"numpy>=2.1",
|
||||
# array-api
|
||||
"array-api-compat >=1.11.0",
|
||||
"numpy>=2.1",
|
||||
# other
|
||||
"opencv-python >=3.0",
|
||||
"matplotlib >=3.0",
|
||||
"moviepy >=1.0.0",
|
||||
]
|
||||
|
||||
testing = [
|
||||
"pytest >=7.1.3",
|
||||
"scipy >=1.7.3",
|
||||
"dill >=0.3.7",
|
||||
"array_api_extra >=0.7.0",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
@@ -125,7 +146,7 @@ exclude = ["tests/**", "**/node_modules", "**/__pycache__"]
|
||||
strict = []
|
||||
|
||||
typeCheckingMode = "basic"
|
||||
pythonVersion = "3.8"
|
||||
pythonVersion = "3.10"
|
||||
pythonPlatform = "All"
|
||||
typeshedPath = "typeshed"
|
||||
enableTypeIgnoreComments = true
|
||||
@@ -138,19 +159,19 @@ reportMissingTypeStubs = false
|
||||
# For warning and error, will raise an error when
|
||||
reportInvalidTypeVarUse = "none"
|
||||
|
||||
reportGeneralTypeIssues = "none" # -> commented out raises 489 errors
|
||||
reportAttributeAccessIssue = "none" # pyright provides false positives
|
||||
reportArgumentType = "none" # pyright provides false positives
|
||||
reportGeneralTypeIssues = "none" # -> commented out raises 489 errors
|
||||
reportAttributeAccessIssue = "none" # pyright provides false positives
|
||||
reportArgumentType = "none" # pyright provides false positives
|
||||
|
||||
reportPrivateUsage = "warning"
|
||||
|
||||
reportIndexIssue = "none" # TODO fix one by one
|
||||
reportReturnType = "none" # TODO fix one by one
|
||||
reportCallIssue = "none" # TODO fix one by one
|
||||
reportOperatorIssue = "none" # TODO fix one by one
|
||||
reportInvalidTypeForm = "none" # TODO fix one by one
|
||||
reportOptionalMemberAccess = "none" # TODO fix one by one
|
||||
reportAssignmentType = "none" # TODO fix one by one
|
||||
reportIndexIssue = "none" # TODO fix one by one
|
||||
reportReturnType = "none" # TODO fix one by one
|
||||
reportCallIssue = "none" # TODO fix one by one
|
||||
reportOperatorIssue = "none" # TODO fix one by one
|
||||
reportInvalidTypeForm = "none" # TODO fix one by one
|
||||
reportOptionalMemberAccess = "none" # TODO fix one by one
|
||||
reportAssignmentType = "none" # TODO fix one by one
|
||||
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
|
@@ -6,10 +6,15 @@ import types
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from gymnasium.envs.registration import EnvSpec
|
||||
from gymnasium.vector import VectorEnv
|
||||
from gymnasium.vector.utils import batch_space
|
||||
from gymnasium.vector.vector_env import AutoresetMode
|
||||
|
||||
|
||||
def basic_reset_func(
|
||||
@@ -106,3 +111,112 @@ class GenericTestEnv(gym.Env):
|
||||
def render(self):
|
||||
"""Renders the environment."""
|
||||
raise NotImplementedError("testingEnv render_fn is not set.")
|
||||
|
||||
|
||||
def basic_vector_reset_func(
    self,
    *,
    seed: int | None = None,
    options: dict | None = None,
) -> tuple[ObsType, dict]:
    """Default reset for ``GenericTestVectorEnv``: reseed, then sample an observation.

    The parent ``reset`` is called with ``seed``, the observation space is seeded with
    ``self.np_random_seed``, and a sample from that space is returned together with an
    info dict that echoes ``options`` under the ``"options"`` key.
    """
    super(GenericTestVectorEnv, self).reset(seed=seed)

    obs_space = self.observation_space
    obs_space.seed(self.np_random_seed)
    initial_obs = obs_space.sample()
    return initial_obs, {"options": options}
|
||||
|
||||
|
||||
def basic_vector_step_func(
    self, action: ActType
) -> tuple[ObsType, np.ndarray, np.ndarray, np.ndarray, dict]:
    """Default step for ``GenericTestVectorEnv``; the action is ignored.

    Returns a fresh sample from ``self.observation_space``, all-zero ``float64`` rewards,
    all-``False`` termination and truncation flags (one entry per sub-environment, i.e.
    ``self.num_envs`` entries each) and an empty info dict.
    """
    next_obs = self.observation_space.sample()
    batch = self.num_envs
    zero_rewards = np.zeros(batch, dtype=np.float64)
    # Two independent arrays, so callers may mutate one without touching the other.
    not_terminated = np.zeros(batch, dtype=np.bool_)
    not_truncated = np.zeros(batch, dtype=np.bool_)
    return next_obs, zero_rewards, not_terminated, not_truncated, {}
|
||||
|
||||
|
||||
def basic_vector_render_func(self):
    """No-op render function: does nothing and returns ``None``."""
    return None
|
||||
|
||||
|
||||
class GenericTestVectorEnv(VectorEnv):
    """A generic testing vector environment similar to GenericTestEnv.

    Some tests cannot use SyncVectorEnv, e.g. when returning non-numpy arrays in the observations.
    In these cases, GenericTestVectorEnv can be used to simulate a vector environment.
    """

    def __init__(
        self,
        num_envs: int = 1,
        action_space: spaces.Space = spaces.Box(0, 1, (1,)),
        observation_space: spaces.Space = spaces.Box(0, 1, (1,)),
        reset_func: Callable = basic_vector_reset_func,
        step_func: Callable = basic_vector_step_func,
        render_func: Callable = basic_vector_render_func,
        metadata: dict[str, Any] = {
            "render_modes": [],
            "autoreset_mode": AutoresetMode.NEXT_STEP,
        },
        render_mode: str | None = None,
        spec: EnvSpec = EnvSpec(
            "TestingVectorEnv-v0",
            "tests.testing_env:GenericTestVectorEnv",
            max_episode_steps=100,
        ),
    ):
        """Generic testing vector environment constructor.

        Args:
            num_envs: The number of environments to create
            action_space: The environment action space
            observation_space: The environment observation space
            reset_func: The environment reset function, ``None`` keeps the class ``reset`` (which raises)
            step_func: The environment step function, ``None`` keeps the class ``step`` (which raises)
            render_func: The environment render function, ``None`` keeps the class ``render`` (which raises)
            metadata: The environment metadata, stored as a shallow copy
            render_mode: The render mode of the environment
            spec: The environment spec
        """
        super().__init__()

        self.num_envs = num_envs
        # Default arguments are evaluated once, so storing the dict itself would let a
        # per-instance edit (``env.metadata[key] = ...``) leak into every environment
        # created afterwards with the default. Keep a private top-level copy instead.
        self.metadata = dict(metadata)
        self.render_mode = render_mode
        self.spec = spec

        # Set the single spaces and create batched spaces
        self.single_observation_space = observation_space
        self.single_action_space = action_space
        self.observation_space = batch_space(observation_space, num_envs)
        self.action_space = batch_space(action_space, num_envs)

        # Bind the functions to the instance so they receive it as ``self``
        if reset_func is not None:
            self.reset = types.MethodType(reset_func, self)
        if step_func is not None:
            self.step = types.MethodType(step_func, self)
        if render_func is not None:
            self.render = types.MethodType(render_func, self)

    def reset(
        self,
        *,
        seed: int | None = None,
        options: dict | None = None,
    ) -> tuple[ObsType, dict]:
        """Resets the environment.

        Raises:
            NotImplementedError: Always; only reached when no ``reset_func`` was bound.
        """
        # If you need a default working reset function, use `basic_vector_reset_func` above
        raise NotImplementedError("TestingVectorEnv reset_fn is not set.")

    def step(
        self, action: ActType
    ) -> tuple[ObsType, np.ndarray, np.ndarray, np.ndarray, dict]:
        """Steps through the environment.

        Raises:
            NotImplementedError: Always; only reached when no ``step_func`` was bound.
        """
        raise NotImplementedError("TestingVectorEnv step_fn is not set.")

    def render(self) -> tuple[Any, ...] | None:
        """Renders the environment.

        Raises:
            NotImplementedError: Always; only reached when no ``render_func`` was bound.
        """
        raise NotImplementedError("TestingVectorEnv render_fn is not set.")
|
||||
|
250
tests/wrappers/test_array_conversion.py
Normal file
250
tests/wrappers/test_array_conversion.py
Normal file
@@ -0,0 +1,250 @@
|
||||
"""Test suite for ArrayConversion wrapper."""
|
||||
|
||||
import importlib
|
||||
import itertools
|
||||
import pickle
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
import pytest
|
||||
|
||||
import gymnasium
|
||||
|
||||
|
||||
array_api_compat = pytest.importorskip("array_api_compat")
|
||||
array_api_extra = pytest.importorskip("array_api_extra")
|
||||
|
||||
from array_api_compat import array_namespace, is_array_api_obj # noqa: E402
|
||||
|
||||
from gymnasium.wrappers.array_conversion import ( # noqa: E402
|
||||
ArrayConversion,
|
||||
array_conversion,
|
||||
module_namespace,
|
||||
)
|
||||
from tests.testing_env import GenericTestEnv # noqa: E402
|
||||
|
||||
|
||||
# Define available modules: every candidate below is imported if possible, and the
# successfully imported module objects are collected in ``installed_modules``.
installed_modules = []
array_api_modules = [
    "numpy",
    "jax.numpy",
    "torch",
    "cupy",
    "dask.array",
    "sparse",
    "array_api_strict",
]
for module in array_api_modules:
    try:
        installed_modules.append(importlib.import_module(module))
    except ImportError:
        pass  # Modules that are not installed are skipped

# All ordered pairs of two distinct installed modules (so both A -> B and B -> A appear).
installed_modules_combinations = list(itertools.permutations(installed_modules, 2))
|
||||
|
||||
|
||||
def xp_data_equivalence(data_1, data_2) -> bool:
    """Return if two variables are equivalent that might contain ``torch.Tensor``.

    Values of different types are never equivalent. Dicts must have the same keys and
    equivalent values, tuples/lists the same length and equivalent items. Array API
    objects must have the same shape and be element-wise close (``atol=0.00001``).
    Anything else is compared with ``==``.

    Args:
        data_1: The first value, possibly a nested dict/tuple/list of arrays
        data_2: The second value

    Returns:
        ``True`` if both values are equivalent, otherwise ``False``
    """
    if type(data_1) is not type(data_2):
        return False

    if isinstance(data_1, dict):
        return data_1.keys() == data_2.keys() and all(
            xp_data_equivalence(data_1[k], data_2[k]) for k in data_1
        )

    if isinstance(data_1, (tuple, list)):
        return len(data_1) == len(data_2) and all(
            xp_data_equivalence(o_1, o_2) for o_1, o_2 in zip(data_1, data_2)
        )

    if is_array_api_obj(data_1):
        # Without this check, arrays of different but broadcastable shapes
        # (e.g. ``(1,)`` against ``(3,)``) could be reported as equivalent.
        if data_1.shape != data_2.shape:
            return False
        # ``.all()`` yields a 0-d array; convert it so the annotated ``bool`` is returned.
        return bool(array_api_extra.isclose(data_1, data_2, atol=0.00001).all())

    return data_1 == data_2
|
||||
|
||||
|
||||
class ExampleNamedTuple(NamedTuple):
    """Named tuple with the two fields ``a`` and ``b``."""

    # Both fields are meant to hold Array API compatible objects. Those do not have
    # proper typing support yet, hence ``Any``.
    a: Any
    b: Any
|
||||
|
||||
|
||||
def _supports_higher_precision(xp, low_type, high_type):
|
||||
"""Check if an array module supports higher precision type."""
|
||||
return xp.result_type(low_type, high_type) == high_type
|
||||
|
||||
|
||||
# When converting between array modules (source → target → source), we need to ensure that the
|
||||
# precision used is supported by both modules. If either module only supports 32-bit types, we must
|
||||
# use the lower precision to account for the conversion during the roundtrip.
|
||||
def atleast_float32(source_xp, target_xp):
    """Pick the float dtype of ``source_xp`` to expect after a roundtrip through ``target_xp``.

    Args:
        source_xp: Namespace the value starts in and returns to.
        target_xp: Namespace the value is converted to in between.

    Returns:
        ``source_xp.float64`` if both namespaces support promoting float32 to
        float64, otherwise ``source_xp.float32``.
    """
    float64_in_source = _supports_higher_precision(
        source_xp, source_xp.float32, source_xp.float64
    )
    float64_in_target = _supports_higher_precision(
        target_xp, target_xp.float32, target_xp.float64
    )
    if not (float64_in_source and float64_in_target):
        return source_xp.float32
    return source_xp.float64
|
||||
|
||||
|
||||
def atleast_int32(source_xp, target_xp):
    """Pick the integer dtype of ``source_xp`` to expect after a roundtrip through ``target_xp``.

    Args:
        source_xp: Namespace the value starts in and returns to.
        target_xp: Namespace the value is converted to in between.

    Returns:
        ``source_xp.int64`` if both namespaces support promoting int32 to int64,
        otherwise ``source_xp.int32``.
    """
    int64_in_source = _supports_higher_precision(
        source_xp, source_xp.int32, source_xp.int64
    )
    int64_in_target = _supports_higher_precision(
        target_xp, target_xp.int32, target_xp.int64
    )
    if not (int64_in_source and int64_in_target):
        return source_xp.int32
    return source_xp.int64
|
||||
|
||||
|
||||
def value_parametrization():
    """Yield ``(source_xp, target_xp, value, expected_value)`` roundtrip test cases.

    Cases are produced for every ordered pair of installed frameworks. ``value`` is the
    input and ``expected_value`` is what it is compared against in the source namespace.
    Python scalars are expected as arrays whose dtype is chosen by ``atleast_int32`` /
    ``atleast_float32``; arrays are expected with the dtype they were created with.
    """
    for source_module, target_module in installed_modules_combinations:
        source_xp = module_namespace(source_module)
        target_xp = module_namespace(target_module)
        xp = source_xp  # short alias: every array below lives in the source namespace
        int_dtype = atleast_int32(source_xp, target_xp)
        float_dtype = atleast_float32(source_xp, target_xp)

        def as_int(scalar):
            return xp.asarray(scalar, dtype=int_dtype)

        def as_float(scalar):
            return xp.asarray(scalar, dtype=float_dtype)

        cases = [
            # Python scalars, on their own and inside containers
            (2, as_int(2)),
            ((3.0, 4), (as_float(3.0), as_int(4))),
            ([3.0, 4], [as_float(3.0), as_int(4)]),
            ({"a": 6.0, "b": 7}, {"a": as_float(6.0), "b": as_int(7)}),
            # Arrays with an explicit dtype
            (xp.asarray(1.0, dtype=xp.float32), xp.asarray(1.0, dtype=xp.float32)),
            (xp.asarray(1.0, dtype=xp.uint8), xp.asarray(1.0, dtype=xp.uint8)),
            (xp.asarray([1, 2], dtype=xp.int32), xp.asarray([1, 2], dtype=xp.int32)),
            (
                xp.asarray([[1.0], [2.0]], dtype=xp.int32),
                xp.asarray([[1.0], [2.0]], dtype=xp.int32),
            ),
            # Nested containers mixing scalars and arrays
            (
                {
                    "a": (
                        1,
                        xp.asarray(2.0, dtype=xp.float32),
                        xp.asarray([3, 4], dtype=xp.int32),
                    ),
                    "b": {"c": 5},
                },
                {
                    "a": (
                        as_int(1),
                        xp.asarray(2.0, dtype=xp.float32),
                        xp.asarray([3, 4], dtype=xp.int32),
                    ),
                    "b": {"c": as_int(5)},
                },
            ),
            # Named tuples
            (
                ExampleNamedTuple(
                    a=xp.asarray([1, 2], dtype=xp.int32),
                    b=xp.asarray([1.0, 2.0], dtype=xp.float32),
                ),
                ExampleNamedTuple(
                    a=xp.asarray([1, 2], dtype=xp.int32),
                    b=xp.asarray([1.0, 2.0], dtype=xp.float32),
                ),
            ),
            (None, None),
        ]
        for value, expected_value in cases:
            yield (source_xp, target_xp, value, expected_value)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "source_xp,target_xp,value,expected_value", value_parametrization()
)
def test_roundtripping(source_xp, target_xp, value, expected_value):
    """Check that a source -> target -> source conversion yields the expected value."""
    value_in_target = array_conversion(value, xp=target_xp)
    value_back_in_source = array_conversion(value_in_target, xp=source_xp)
    assert xp_data_equivalence(value_back_in_source, expected_value)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("env_xp, target_xp", installed_modules_combinations)
def test_array_conversion_wrapper(env_xp, target_xp):
    """Check ``ArrayConversion`` for an environment that produces ``env_xp`` arrays."""

    # Reset and step are plain functions (no partial) to avoid pickling issues
    def reset_func(self, seed=None, options=None):
        """A generic array API reset function."""
        observation = env_xp.asarray([1.0, 2.0, 3.0])
        return observation, {"data": env_xp.asarray([1, 2, 3])}

    def step_func(self, action):
        """A generic array API step function."""
        assert isinstance(action, type(env_xp.zeros(1)))
        return (
            env_xp.asarray([1, 2, 3]),
            env_xp.asarray(5.0),
            env_xp.asarray(True),
            env_xp.asarray(False),
            {"data": env_xp.asarray([1.0, 2.0])},
        )

    env = GenericTestEnv(reset_func=reset_func, step_func=step_func)

    # Sanity check of the unwrapped environment. array_namespace reports the compatible
    # namespace from array_api_compat, so compare against the compat version of env_xp.
    env_xp_compat = module_namespace(env_xp)
    obs, info = env.reset()
    assert array_namespace(obs) is env_xp_compat
    assert isinstance(info, dict)
    assert array_namespace(info["data"]) is env_xp_compat

    obs, reward, terminated, truncated, info = env.step(env_xp_compat.asarray([1, 2]))
    for output in (obs, reward, terminated, truncated):
        assert array_namespace(output) is env_xp_compat
    assert isinstance(info, dict)
    assert array_namespace(info["data"]) is env_xp_compat

    # The wrapped environment must hand out data in the target framework
    target_xp_compat = module_namespace(target_xp)
    wrapped_env = ArrayConversion(env, env_xp=env_xp, target_xp=target_xp)
    obs, info = wrapped_env.reset()
    assert array_namespace(obs) is target_xp_compat
    assert isinstance(info, dict)
    assert array_namespace(info["data"]) is target_xp_compat

    action = target_xp.asarray([1, 2], dtype=target_xp.int32)
    obs, reward, terminated, truncated, info = wrapped_env.step(action)
    assert array_namespace(obs) is target_xp_compat
    # Reward and episode flags are expected as plain Python values
    assert isinstance(reward, float)
    assert isinstance(terminated, bool)
    assert isinstance(truncated, bool)
    assert isinstance(info, dict)
    assert array_namespace(info["data"]) is target_xp_compat

    # Rendering implicitly returns None and therefore requires a None -> None conversion
    wrapped_env.render()

    # A wrapped environment must survive a pickle roundtrip
    env = gymnasium.make("CartPole-v1", disable_env_checker=True)
    wrapped_env = ArrayConversion(env, env_xp=env_xp, target_xp=target_xp)
    pickle.loads(pickle.dumps(wrapped_env))
|
@@ -1,10 +1,13 @@
|
||||
"""Test suite for JaxToNumpy wrapper."""
|
||||
|
||||
import pickle
|
||||
from typing import NamedTuple
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import gymnasium
|
||||
|
||||
|
||||
jax = pytest.importorskip("jax")
|
||||
jnp = pytest.importorskip("jax.numpy")
|
||||
@@ -133,3 +136,9 @@ def test_jax_to_numpy_wrapper():
|
||||
# Check that the wrapped environment can render. This implicitly returns None and requires a
|
||||
# None -> None conversion
|
||||
numpy_env.render()
|
||||
|
||||
# Test that the wrapped environment can be pickled
|
||||
env = gymnasium.make("CartPole-v1", disable_env_checker=True)
|
||||
wrapped_env = JaxToNumpy(env)
|
||||
pkl = pickle.dumps(wrapped_env)
|
||||
pickle.loads(pkl)
|
||||
|
@@ -1,9 +1,12 @@
|
||||
"""Test suite for TorchToJax wrapper."""
|
||||
|
||||
import pickle
|
||||
from typing import NamedTuple
|
||||
|
||||
import pytest
|
||||
|
||||
import gymnasium
|
||||
|
||||
|
||||
jax = pytest.importorskip("jax")
|
||||
jnp = pytest.importorskip("jax.numpy")
|
||||
@@ -148,3 +151,9 @@ def test_jax_to_torch_wrapper():
|
||||
# Check that the wrapped environment can render. This implicitly returns None and requires a
|
||||
# None -> None conversion
|
||||
wrapped_env.render()
|
||||
|
||||
# Test that the wrapped environment can be pickled
|
||||
env = gymnasium.make("CartPole-v1", disable_env_checker=True)
|
||||
wrapped_env = JaxToTorch(env)
|
||||
pkl = pickle.dumps(wrapped_env)
|
||||
pickle.loads(pkl)
|
||||
|
@@ -1,10 +1,13 @@
|
||||
"""Test suite for NumPyToTorch wrapper."""
|
||||
|
||||
import pickle
|
||||
from typing import NamedTuple
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import gymnasium
|
||||
|
||||
|
||||
torch = pytest.importorskip("torch")
|
||||
|
||||
@@ -128,3 +131,9 @@ def test_numpy_to_torch():
|
||||
# Check that the wrapped environment can render. This implicitly returns None and requires a
|
||||
# None -> None conversion
|
||||
torch_env.render()
|
||||
|
||||
# Test that the wrapped environment can be pickled
|
||||
env = gymnasium.make("CartPole-v1", disable_env_checker=True)
|
||||
wrapped_env = NumpyToTorch(env)
|
||||
pkl = pickle.dumps(wrapped_env)
|
||||
pickle.loads(pkl)
|
||||
|
146
tests/wrappers/vector/test_array_conversion.py
Normal file
146
tests/wrappers/vector/test_array_conversion.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""Test suite for vector ArrayConversion wrapper."""
|
||||
|
||||
import importlib
|
||||
import itertools
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from tests.testing_env import GenericTestVectorEnv
|
||||
|
||||
|
||||
array_api_compat = pytest.importorskip("array_api_compat")
|
||||
from array_api_compat import array_namespace # noqa: E402
|
||||
|
||||
from gymnasium.wrappers.array_conversion import module_namespace # noqa: E402
|
||||
from gymnasium.wrappers.vector.array_conversion import ArrayConversion # noqa: E402
|
||||
from gymnasium.wrappers.vector.jax_to_numpy import JaxToNumpy # noqa: E402
|
||||
from gymnasium.wrappers.vector.jax_to_torch import JaxToTorch # noqa: E402
|
||||
from gymnasium.wrappers.vector.numpy_to_torch import NumpyToTorch # noqa: E402
|
||||
|
||||
|
||||
# Candidate Array API frameworks; only the ones importable in this environment are tested
array_api_modules = [
    "numpy",
    "jax.numpy",
    "torch",
    "cupy",
    "dask.array",
    "sparse",
    "array_api_strict",
]
installed_modules = []
for module in array_api_modules:
    try:
        _imported = importlib.import_module(module)
    except ImportError:
        continue  # Frameworks that are not installed are left out
    installed_modules.append(_imported)

# Every ordered (first, second) pair of two distinct installed frameworks
installed_modules_combinations = list(itertools.permutations(installed_modules, 2))
|
||||
|
||||
|
||||
def create_vector_env(env_xp):
    """Build a ``GenericTestVectorEnv`` whose reset/step functions are bound to ``env_xp``.

    Both functions are bound with ``num_envs=3`` and ``xp=env_xp``.
    """
    bound_kwargs = {"num_envs": 3, "xp": env_xp}
    return GenericTestVectorEnv(
        reset_func=partial(reset_func, **bound_kwargs),
        step_func=partial(step_func, **bound_kwargs),
    )
|
||||
|
||||
|
||||
def reset_func(self, seed=None, options=None, num_envs: int = 1, xp=np):
    """Return a fixed batched observation and info dict created with ``xp``.

    Args:
        self: The environment the function is attached to (unused).
        seed: Accepted for signature compatibility (unused).
        options: Accepted for signature compatibility (unused).
        num_envs: Size of the leading batch axis of the returned arrays.
        xp: Array namespace used to create the returned arrays.

    Returns:
        A tuple ``(obs, info)``: a float observation of shape ``(num_envs, 3)`` and an
        info dict whose ``"data"`` entry is an integer array of the same shape.
    """
    # Repeat the row, not its elements, so the result has one row per sub-environment
    # instead of a single row of length ``3 * num_envs``.
    obs = xp.asarray([[1.0, 2.0, 3.0]] * num_envs)
    info = {"data": xp.asarray([[1, 2, 3]] * num_envs)}
    return obs, info
|
||||
|
||||
|
||||
def step_func(self, action, num_envs: int = 1, xp=np):
    """Return fixed batched step results created with ``xp``.

    Args:
        self: The environment the function is attached to (unused).
        action: Action passed to the environment; must have the array type of ``xp``.
        num_envs: Size of the leading batch axis of the returned arrays.
        xp: Array namespace used to create the returned arrays.

    Returns:
        A tuple ``(obs, reward, terminated, truncated, info)`` with an observation of
        shape ``(num_envs, 3)``, reward and both flags of shape ``(num_envs,)`` and an
        info dict whose ``"data"`` entry has shape ``(num_envs, 2)``.

    Raises:
        AssertionError: If ``action`` is not of the array type produced by ``xp``.
    """
    # The environment must receive actions in its own framework
    assert isinstance(action, type(xp.zeros(1)))
    # Repeat whole rows so that 2-d outputs get one row per sub-environment, matching
    # the per-environment layout of the reward and flag arrays.
    return (
        xp.asarray([[1, 2, 3]] * num_envs),
        xp.asarray([5.0] * num_envs),
        xp.asarray([False] * num_envs),
        xp.asarray([False] * num_envs),
        {"data": xp.asarray([[1.0, 2.0]] * num_envs)},
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("env_xp, target_xp", installed_modules_combinations)
def test_array_conversion_wrapper(env_xp, target_xp):
    """Check the vector ``ArrayConversion`` wrapper for one ``(env_xp, target_xp)`` pair."""
    env_xp_compat = module_namespace(env_xp)
    env = create_vector_env(env_xp_compat)

    # Sanity check of the unwrapped environment. array_namespace reports the compatible
    # namespace from array_api_compat, so compare against the compat version of env_xp.
    obs, info = env.reset()
    assert array_namespace(obs) is env_xp_compat
    assert isinstance(info, dict)
    assert array_namespace(info["data"]) is env_xp_compat

    obs, reward, terminated, truncated, info = env.step(env_xp_compat.asarray([1, 2]))
    for output in (obs, reward, terminated, truncated):
        assert array_namespace(output) is env_xp_compat
    assert isinstance(info, dict)
    assert array_namespace(info["data"]) is env_xp_compat

    # The wrapped environment must hand out data in the target framework
    target_xp_compat = module_namespace(target_xp)
    wrapped_env = ArrayConversion(env, env_xp=env_xp, target_xp=target_xp)
    obs, info = wrapped_env.reset()
    assert array_namespace(obs) is target_xp_compat
    assert isinstance(info, dict)
    assert array_namespace(info["data"]) is target_xp_compat

    action = target_xp.asarray([1, 2], dtype=target_xp.int32)
    obs, reward, terminated, truncated, info = wrapped_env.step(action)
    for output in (obs, reward, terminated, truncated):
        assert array_namespace(output) is target_xp_compat
    # The episode flags must remain boolean arrays after the conversion
    for flag in (terminated, truncated):
        assert flag.dtype == target_xp.bool
    assert isinstance(info, dict)
    assert array_namespace(info["data"]) is target_xp_compat

    # Rendering implicitly returns None and therefore requires a None -> None conversion
    wrapped_env.render()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("wrapper", [JaxToNumpy, JaxToTorch, NumpyToTorch])
def test_specialized_wrappers(wrapper: type[JaxToNumpy | JaxToTorch | NumpyToTorch]):
    """Check that each specialized vector wrapper converts to its target framework.

    The test is skipped when a framework required by ``wrapper`` is not installed.
    """
    # Select the environment and target frameworks that belong to the wrapper
    if wrapper is JaxToNumpy:
        jax = pytest.importorskip("jax")
        env_xp, target_xp = jax.numpy, np
    elif wrapper is JaxToTorch:
        jax = pytest.importorskip("jax")
        torch = pytest.importorskip("torch")
        env_xp, target_xp = jax.numpy, torch
    elif wrapper is NumpyToTorch:
        torch = pytest.importorskip("torch")
        env_xp, target_xp = np, torch
    else:
        # ``wrapper`` is a class, so report the class itself: ``type(wrapper)`` would
        # always render as ``<class 'type'>``.
        raise TypeError(f"Unknown specialized conversion wrapper {wrapper}")
    env_xp_compat = module_namespace(env_xp)
    target_xp_compat = module_namespace(target_xp)

    # The unwrapped test env sanity check is already covered by test_array_conversion_wrapper
    # for all known frameworks, including the specialized ones.
    env = create_vector_env(env_xp_compat)

    # Check that the wrapped version is correct.
    wrapped_env = wrapper(env)
    obs, info = wrapped_env.reset()
    assert array_namespace(obs) is target_xp_compat
    assert isinstance(info, dict) and array_namespace(info["data"]) is target_xp_compat

    action = target_xp.asarray([1, 2], dtype=target_xp.int32)
    obs, reward, terminated, truncated, info = wrapped_env.step(action)
    assert array_namespace(obs) is target_xp_compat
    assert array_namespace(reward) is target_xp_compat
    assert array_namespace(terminated) is target_xp_compat
    assert terminated.dtype == target_xp.bool
    assert array_namespace(truncated) is target_xp_compat
    assert truncated.dtype == target_xp.bool
    assert isinstance(info, dict) and array_namespace(info["data"]) is target_xp_compat

    # Check that the wrapped environment can render. This implicitly returns None and
    # requires a None -> None conversion
    wrapped_env.render()
|
Reference in New Issue
Block a user