mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-07-31 05:44:31 +00:00
Add generic conversion wrapper between Array API compatible frameworks (#1333)
This commit is contained in:
2
.github/workflows/docs-build-dev.yml
vendored
2
.github/workflows/docs-build-dev.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
|||||||
|
|
||||||
- uses: actions/setup-python@v5
|
- uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: '3.9'
|
python-version: '3.10'
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: pip install -r docs/requirements.txt
|
run: pip install -r docs/requirements.txt
|
||||||
|
2
.github/workflows/docs-build-release.yml
vendored
2
.github/workflows/docs-build-release.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
|||||||
|
|
||||||
- uses: actions/setup-python@v5
|
- uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: '3.9'
|
python-version: '3.10'
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: pip install -r docs/requirements.txt
|
run: pip install -r docs/requirements.txt
|
||||||
|
2
.github/workflows/docs-manual-build.yml
vendored
2
.github/workflows/docs-manual-build.yml
vendored
@@ -33,7 +33,7 @@ jobs:
|
|||||||
|
|
||||||
- uses: actions/setup-python@v5
|
- uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: '3.9'
|
python-version: '3.10'
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: pip install -r docs/requirements.txt
|
run: pip install -r docs/requirements.txt
|
||||||
|
5
.github/workflows/run-pytest.yml
vendored
5
.github/workflows/run-pytest.yml
vendored
@@ -10,8 +10,8 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
fail-fast: true
|
fail-fast: true
|
||||||
matrix:
|
matrix:
|
||||||
python-version: ['3.9', '3.10', '3.11', '3.12']
|
python-version: ['3.10', '3.11', '3.12']
|
||||||
numpy-version: ['>=1.21,<2.0', '>=2.0']
|
numpy-version: ['>=1.21,<2.0', '>=2.1']
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- run: |
|
- run: |
|
||||||
@@ -22,7 +22,6 @@ jobs:
|
|||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: docker run gymnasium-all-docker pytest tests/*
|
run: docker run gymnasium-all-docker pytest tests/*
|
||||||
- name: Run doctests
|
- name: Run doctests
|
||||||
if: ${{ matrix.python-version != '3.8' }}
|
|
||||||
run: docker run gymnasium-all-docker pytest --doctest-modules gymnasium/
|
run: docker run gymnasium-all-docker pytest --doctest-modules gymnasium/
|
||||||
|
|
||||||
build-necessary:
|
build-necessary:
|
||||||
|
2
.github/workflows/run-tutorial.yml
vendored
2
.github/workflows/run-tutorial.yml
vendored
@@ -20,7 +20,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
fail-fast: false # This ensures all matrix combinations run even if one fails
|
fail-fast: false # This ensures all matrix combinations run even if one fails
|
||||||
matrix:
|
matrix:
|
||||||
python-version: ["3.9"]
|
python-version: ["3.10"]
|
||||||
tutorial-group:
|
tutorial-group:
|
||||||
- gymnasium_basics
|
- gymnasium_basics
|
||||||
- training_agents
|
- training_agents
|
||||||
|
@@ -32,7 +32,7 @@ To install the base Gymnasium library, use `pip install gymnasium`
|
|||||||
|
|
||||||
This does not include dependencies for all families of environments (there's a massive number, and some can be problematic to install on certain systems). You can install these dependencies for one family like `pip install "gymnasium[atari]"` or use `pip install "gymnasium[all]"` to install all dependencies.
|
This does not include dependencies for all families of environments (there's a massive number, and some can be problematic to install on certain systems). You can install these dependencies for one family like `pip install "gymnasium[atari]"` or use `pip install "gymnasium[all]"` to install all dependencies.
|
||||||
|
|
||||||
We support and test for Python 3.8, 3.9, 3.10, 3.11 and 3.12 on Linux and macOS. We will accept PRs related to Windows, but do not officially support it.
|
We support and test for Python 3.10, 3.11 and 3.12 on Linux and macOS. We will accept PRs related to Windows, but do not officially support it.
|
||||||
|
|
||||||
## API
|
## API
|
||||||
|
|
||||||
|
@@ -27,6 +27,7 @@ title: Misc Wrappers
|
|||||||
## Data Conversion Wrappers
|
## Data Conversion Wrappers
|
||||||
|
|
||||||
```{eval-rst}
|
```{eval-rst}
|
||||||
|
.. autoclass:: gymnasium.wrappers.ArrayConversion
|
||||||
.. autoclass:: gymnasium.wrappers.JaxToNumpy
|
.. autoclass:: gymnasium.wrappers.JaxToNumpy
|
||||||
.. autoclass:: gymnasium.wrappers.JaxToTorch
|
.. autoclass:: gymnasium.wrappers.JaxToTorch
|
||||||
.. autoclass:: gymnasium.wrappers.NumpyToTorch
|
.. autoclass:: gymnasium.wrappers.NumpyToTorch
|
||||||
|
@@ -34,6 +34,8 @@ wrapper in the page on the wrapper type
|
|||||||
- Converts an image observation computed by ``reset`` and ``step`` from RGB to Grayscale.
|
- Converts an image observation computed by ``reset`` and ``step`` from RGB to Grayscale.
|
||||||
* - :class:`HumanRendering`
|
* - :class:`HumanRendering`
|
||||||
- Allows human like rendering for environments that support "rgb_array" rendering.
|
- Allows human like rendering for environments that support "rgb_array" rendering.
|
||||||
|
* - :class:`ArrayConversion`
|
||||||
|
- Wraps an environment based on any Array API compatible framework, e.g. torch, jax, numpy, such that it can be interacted with any other Array API compatible framework.
|
||||||
* - :class:`JaxToNumpy`
|
* - :class:`JaxToNumpy`
|
||||||
- Wraps a Jax-based environment such that it can be interacted with NumPy arrays.
|
- Wraps a Jax-based environment such that it can be interacted with NumPy arrays.
|
||||||
* - :class:`JaxToTorch`
|
* - :class:`JaxToTorch`
|
||||||
|
261
gymnasium/wrappers/array_conversion.py
Normal file
261
gymnasium/wrappers/array_conversion.py
Normal file
@@ -0,0 +1,261 @@
|
|||||||
|
# This wrapper will convert array inputs from an Array API compatible framework A for the actions
|
||||||
|
# to any other Array API compatible framework B for an underlying environment that is implemented
|
||||||
|
# in framework B, then convert the return observations from framework B back to framework A.
|
||||||
|
#
|
||||||
|
# More precisely, the wrapper will work for any two frameworks that can be made compatible with the
|
||||||
|
# `array-api-compat` package.
|
||||||
|
#
|
||||||
|
# See https://data-apis.org/array-api/latest/ for more information on the Array API standard, and
|
||||||
|
# https://data-apis.org/array-api-compat/ for more information on the Array API compatibility layer.
|
||||||
|
#
|
||||||
|
# General structure for converting between types originally copied from
|
||||||
|
# https://github.com/google/brax/blob/9d6b7ced2a13da0d074b5e9fbd3aad8311e26997/brax/io/torch.py
|
||||||
|
# Under the Apache 2.0 license. Copyright is held by the authors
|
||||||
|
|
||||||
|
"""Helper functions and wrapper class for converting between arbitrary Array API compatible frameworks and a target framework."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import functools
|
||||||
|
import importlib
|
||||||
|
import numbers
|
||||||
|
from collections import abc
|
||||||
|
from types import ModuleType, NoneType
|
||||||
|
from typing import Any, Iterable, Mapping, SupportsFloat
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from packaging.version import Version
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
|
||||||
|
from gymnasium.error import DependencyNotInstalled
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
from array_api_compat import array_namespace, is_array_api_obj, to_device
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
raise DependencyNotInstalled(
|
||||||
|
'Array API packages are not installed therefore cannot call `array_conversion`, run `pip install "gymnasium[array-api]"`'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if Version(np.__version__) < Version("2.1.0"):
|
||||||
|
raise DependencyNotInstalled("Array API functionality requires numpy >= 2.1.0")
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["ArrayConversion", "array_conversion"]
|
||||||
|
|
||||||
|
Array = Any # TODO: Switch to ArrayAPI type once https://github.com/data-apis/array-api/pull/589 is merged
|
||||||
|
Device = Any # TODO: Switch to ArrayAPI type if available
|
||||||
|
|
||||||
|
|
||||||
|
def module_namespace(xp: ModuleType) -> ModuleType:
|
||||||
|
"""Determine the Array API compatible namespace of the given module.
|
||||||
|
|
||||||
|
This function is closely linked to the `array_api_compat.array_namespace` function. It returns
|
||||||
|
the compatible namespace for a module directly instead of from an array object of that module.
|
||||||
|
|
||||||
|
See https://data-apis.org/array-api-compat/helper-functions.html#array_api_compat.array_namespace
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return array_namespace(xp.empty(0))
|
||||||
|
except AttributeError as e:
|
||||||
|
raise ValueError(f"Module {xp} is not an Array API compatible module.") from e
|
||||||
|
|
||||||
|
|
||||||
|
def module_name_to_namespace(name: str) -> ModuleType:
|
||||||
|
return module_namespace(importlib.import_module(name))
|
||||||
|
|
||||||
|
|
||||||
|
@functools.singledispatch
|
||||||
|
def array_conversion(value: Any, xp: ModuleType, device: Device | None = None) -> Any:
|
||||||
|
"""Convert a value into the specified xp module array type."""
|
||||||
|
raise Exception(
|
||||||
|
f"No known conversion for ({type(value)}) to xp module ({xp}) registered. Report as issue on github."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@array_conversion.register(numbers.Number)
|
||||||
|
def _number_array_conversion(
|
||||||
|
value: numbers.Number, xp: ModuleType, device: Device | None = None
|
||||||
|
) -> Array:
|
||||||
|
"""Convert a python number (int, float, complex) to an Array API framework array."""
|
||||||
|
return xp.asarray(value, device=device)
|
||||||
|
|
||||||
|
|
||||||
|
@array_conversion.register(abc.Mapping)
|
||||||
|
def _mapping_array_conversion(
|
||||||
|
value: Mapping[str, Any], xp: ModuleType, device: Device | None = None
|
||||||
|
) -> Mapping[str, Any]:
|
||||||
|
"""Convert a mapping of Arrays into a Dictionary of the specified xp module array type."""
|
||||||
|
return type(value)(**{k: array_conversion(v, xp, device) for k, v in value.items()})
|
||||||
|
|
||||||
|
|
||||||
|
@array_conversion.register(abc.Iterable)
|
||||||
|
def _iterable_array_conversion(
|
||||||
|
value: Iterable[Any], xp: ModuleType, device: Device | None = None
|
||||||
|
) -> Iterable[Any]:
|
||||||
|
"""Convert an Iterable from Arrays to an iterable of the specified xp module array type."""
|
||||||
|
# There is currently no type for ArrayAPI compatible objects, so they fall through to this
|
||||||
|
# function registered for any Iterable. If they are arrays, we can convert them directly.
|
||||||
|
# We currently cannot pass the device to the from_dlpack function, since it is not supported
|
||||||
|
# for some frameworks (see e.g. https://github.com/data-apis/array-api-compat/issues/204)
|
||||||
|
if is_array_api_obj(value):
|
||||||
|
return _array_api_array_conversion(value, xp, device)
|
||||||
|
if hasattr(value, "_make"):
|
||||||
|
# namedtuple - underline used to prevent potential name conflicts
|
||||||
|
# noinspection PyProtectedMember
|
||||||
|
return type(value)._make(array_conversion(v, xp, device) for v in value)
|
||||||
|
return type(value)(array_conversion(v, xp, device) for v in value)
|
||||||
|
|
||||||
|
|
||||||
|
def _array_api_array_conversion(
|
||||||
|
value: Array, xp: ModuleType, device: Device | None = None
|
||||||
|
) -> Array:
|
||||||
|
"""Convert an Array API compatible array to the specified xp module array type."""
|
||||||
|
try:
|
||||||
|
x = xp.from_dlpack(value)
|
||||||
|
return to_device(x, device) if device is not None else x
|
||||||
|
except (RuntimeError, BufferError):
|
||||||
|
# If dlpack fails (e.g. because the array is read-only for frameworks that do not
|
||||||
|
# support it), we create a copy of the array that we own and then convert it.
|
||||||
|
# TODO: The correct treatment of read-only arrays is currently not fully clear in the
|
||||||
|
# Array API. Once ongoing discussions are resolved, we should update this code to remove
|
||||||
|
# any fallbacks.
|
||||||
|
value_namespace = array_namespace(value)
|
||||||
|
value_copy = value_namespace.asarray(value, copy=True)
|
||||||
|
return xp.asarray(value_copy, device=device)
|
||||||
|
|
||||||
|
|
||||||
|
@array_conversion.register(NoneType)
|
||||||
|
def _none_array_conversion(
|
||||||
|
value: None, xp: ModuleType, device: Device | None = None
|
||||||
|
) -> None:
|
||||||
|
"""Passes through None values."""
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
class ArrayConversion(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||||
|
"""Wraps an Array API compatible environment so that it can be interacted with a specific Array API framework.
|
||||||
|
|
||||||
|
Actions must be provided as Array API compatible arrays and observations will be returned as Arrays of the specified xp module.
|
||||||
|
A vector version of the wrapper exists, :class:`gymnasium.wrappers.vector.ArrayConversion`.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> import torch # doctest: +SKIP
|
||||||
|
>>> import jax.numpy as jnp # doctest: +SKIP
|
||||||
|
>>> import gymnasium as gym # doctest: +SKIP
|
||||||
|
>>> env = gym.make("JaxEnv-vx") # doctest: +SKIP
|
||||||
|
>>> env = ArrayConversion(env, env_xp=jnp, target_xp=torch) # doctest: +SKIP
|
||||||
|
>>> obs, _ = env.reset(seed=123) # doctest: +SKIP
|
||||||
|
>>> type(obs) # doctest: +SKIP
|
||||||
|
<class 'torch.Tensor'>
|
||||||
|
>>> action = torch.tensor(env.action_space.sample()) # doctest: +SKIP
|
||||||
|
>>> obs, reward, terminated, truncated, info = env.step(action) # doctest: +SKIP
|
||||||
|
>>> type(obs) # doctest: +SKIP
|
||||||
|
<class 'torch.Tensor'>
|
||||||
|
>>> type(reward) # doctest: +SKIP
|
||||||
|
<class 'float'>
|
||||||
|
>>> type(terminated) # doctest: +SKIP
|
||||||
|
<class 'bool'>
|
||||||
|
>>> type(truncated) # doctest: +SKIP
|
||||||
|
<class 'bool'>
|
||||||
|
|
||||||
|
Change logs:
|
||||||
|
* v1.2.0 - Initially added
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
env: gym.Env,
|
||||||
|
env_xp: ModuleType,
|
||||||
|
target_xp: ModuleType,
|
||||||
|
env_device: Device | None = None,
|
||||||
|
target_device: Device | None = None,
|
||||||
|
):
|
||||||
|
"""Wrapper class to change inputs and outputs of environment to any Array API framework.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env: The Array API compatible environment to wrap
|
||||||
|
env_xp: The Array API framework the environment is on
|
||||||
|
target_xp: The Array API framework to convert to
|
||||||
|
env_device: The device the environment is on
|
||||||
|
target_device: The device on which Arrays should be returned
|
||||||
|
"""
|
||||||
|
gym.utils.RecordConstructorArgs.__init__(self)
|
||||||
|
gym.Wrapper.__init__(self, env)
|
||||||
|
|
||||||
|
self._env_xp = module_namespace(env_xp)
|
||||||
|
self._target_xp = module_namespace(target_xp)
|
||||||
|
self._env_device: Device | None = env_device
|
||||||
|
self._target_device: Device | None = target_device
|
||||||
|
|
||||||
|
def step(
|
||||||
|
self, action: WrapperActType
|
||||||
|
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
|
||||||
|
"""Performs the given action within the environment.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action: The action to perform as any Array API compatible array
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The next observation, reward, termination, truncation, and extra info
|
||||||
|
"""
|
||||||
|
action = array_conversion(action, xp=self._env_xp, device=self._env_device)
|
||||||
|
obs, reward, terminated, truncated, info = self.env.step(action)
|
||||||
|
|
||||||
|
return (
|
||||||
|
array_conversion(obs, xp=self._target_xp, device=self._target_device),
|
||||||
|
float(reward),
|
||||||
|
bool(terminated),
|
||||||
|
bool(truncated),
|
||||||
|
array_conversion(info, xp=self._target_xp, device=self._target_device),
|
||||||
|
)
|
||||||
|
|
||||||
|
def reset(
|
||||||
|
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
||||||
|
) -> tuple[WrapperObsType, dict[str, Any]]:
|
||||||
|
"""Resets the environment returning observation and info as Array from any Array API compatible framework.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seed: The seed for resetting the environment
|
||||||
|
options: The options for resetting the environment, these are converted to jax arrays.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
xp-based observations and info
|
||||||
|
"""
|
||||||
|
if options:
|
||||||
|
options = array_conversion(options, self._env_xp, self._env_device)
|
||||||
|
|
||||||
|
return array_conversion(
|
||||||
|
self.env.reset(seed=seed, options=options),
|
||||||
|
self._target_xp,
|
||||||
|
self._target_device,
|
||||||
|
)
|
||||||
|
|
||||||
|
def render(self) -> RenderFrame | list[RenderFrame] | None:
|
||||||
|
"""Returns the rendered frames as an xp Array."""
|
||||||
|
return array_conversion(self.env.render(), self._target_xp, self._target_device)
|
||||||
|
|
||||||
|
def __getstate__(self):
|
||||||
|
"""Returns the object pickle state with args and kwargs."""
|
||||||
|
env_xp_name = self._env_xp.__name__.replace("array_api_compat.", "")
|
||||||
|
target_xp_name = self._target_xp.__name__.replace("array_api_compat.", "")
|
||||||
|
env_device = self._env_device
|
||||||
|
target_device = self._target_device
|
||||||
|
return {
|
||||||
|
"env_xp_name": env_xp_name,
|
||||||
|
"target_xp_name": target_xp_name,
|
||||||
|
"env_device": env_device,
|
||||||
|
"target_device": target_device,
|
||||||
|
"env": self.env,
|
||||||
|
}
|
||||||
|
|
||||||
|
def __setstate__(self, d):
|
||||||
|
"""Sets the object pickle state using d."""
|
||||||
|
self.env = d["env"]
|
||||||
|
self._env_xp = module_name_to_namespace(d["env_xp_name"])
|
||||||
|
self._target_xp = module_name_to_namespace(d["target_xp_name"])
|
||||||
|
self._env_device = d["env_device"]
|
||||||
|
self._target_device = d["target_device"]
|
@@ -3,19 +3,20 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import numbers
|
|
||||||
from collections import abc
|
|
||||||
from typing import Any, Iterable, Mapping, SupportsFloat
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium.core import ActType, ObsType, RenderFrame, WrapperActType, WrapperObsType
|
from gymnasium.core import ActType, ObsType
|
||||||
from gymnasium.error import DependencyNotInstalled
|
from gymnasium.error import DependencyNotInstalled
|
||||||
|
from gymnasium.wrappers.array_conversion import (
|
||||||
|
ArrayConversion,
|
||||||
|
array_conversion,
|
||||||
|
module_namespace,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import jax
|
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise DependencyNotInstalled(
|
raise DependencyNotInstalled(
|
||||||
@@ -24,110 +25,13 @@ except ImportError:
|
|||||||
|
|
||||||
__all__ = ["JaxToNumpy", "jax_to_numpy", "numpy_to_jax"]
|
__all__ = ["JaxToNumpy", "jax_to_numpy", "numpy_to_jax"]
|
||||||
|
|
||||||
# The NoneType is not defined in Python 3.9. Remove when the minimal version is bumped to >=3.10
|
|
||||||
_NoneType = type(None)
|
jax_to_numpy = functools.partial(array_conversion, xp=module_namespace(np))
|
||||||
|
|
||||||
|
numpy_to_jax = functools.partial(array_conversion, xp=module_namespace(jnp))
|
||||||
|
|
||||||
|
|
||||||
@functools.singledispatch
|
class JaxToNumpy(ArrayConversion):
|
||||||
def numpy_to_jax(value: Any) -> Any:
|
|
||||||
"""Converts a value to a Jax Array."""
|
|
||||||
raise Exception(
|
|
||||||
f"No known conversion for Numpy type ({type(value)}) to Jax registered. Report as issue on github."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@numpy_to_jax.register(numbers.Number)
|
|
||||||
def _number_to_jax(
|
|
||||||
value: numbers.Number,
|
|
||||||
) -> jax.Array:
|
|
||||||
"""Converts a number (int, float, etc.) to a Jax Array."""
|
|
||||||
assert jnp is not None
|
|
||||||
return jnp.array(value)
|
|
||||||
|
|
||||||
|
|
||||||
@numpy_to_jax.register(np.ndarray)
|
|
||||||
def _numpy_array_to_jax(value: np.ndarray) -> jax.Array:
|
|
||||||
"""Converts a NumPy Array to a Jax Array with the same dtype (excluding float64 without being enabled)."""
|
|
||||||
assert jnp is not None
|
|
||||||
return jnp.array(value, dtype=value.dtype)
|
|
||||||
|
|
||||||
|
|
||||||
@numpy_to_jax.register(abc.Mapping)
|
|
||||||
def _mapping_numpy_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
|
|
||||||
"""Converts a dictionary of numpy arrays to a mapping of Jax Array."""
|
|
||||||
return type(value)(**{k: numpy_to_jax(v) for k, v in value.items()})
|
|
||||||
|
|
||||||
|
|
||||||
@numpy_to_jax.register(abc.Iterable)
|
|
||||||
def _iterable_numpy_to_jax(
|
|
||||||
value: Iterable[np.ndarray | Any],
|
|
||||||
) -> Iterable[jax.Array | Any]:
|
|
||||||
"""Converts an Iterable from Numpy Arrays to an iterable of Jax Array."""
|
|
||||||
if hasattr(value, "_make"):
|
|
||||||
# namedtuple - underline used to prevent potential name conflicts
|
|
||||||
# noinspection PyProtectedMember
|
|
||||||
return type(value)._make(numpy_to_jax(v) for v in value)
|
|
||||||
else:
|
|
||||||
return type(value)(numpy_to_jax(v) for v in value)
|
|
||||||
|
|
||||||
|
|
||||||
@numpy_to_jax.register(_NoneType)
|
|
||||||
def _none_numpy_to_jax(value: None) -> None:
|
|
||||||
"""Passes through None values."""
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
@functools.singledispatch
|
|
||||||
def jax_to_numpy(value: Any) -> Any:
|
|
||||||
"""Converts a value to a numpy array."""
|
|
||||||
raise Exception(
|
|
||||||
f"No known conversion for Jax type ({type(value)}) to NumPy registered. Report as issue on github."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@jax_to_numpy.register(jax.Array)
|
|
||||||
def _devicearray_jax_to_numpy(value: jax.Array) -> np.ndarray:
|
|
||||||
"""Converts a Jax Array to a numpy array."""
|
|
||||||
return np.array(value)
|
|
||||||
|
|
||||||
|
|
||||||
@jax_to_numpy.register(abc.Mapping)
|
|
||||||
def _mapping_jax_to_numpy(
|
|
||||||
value: Mapping[str, jax.Array | Any],
|
|
||||||
) -> Mapping[str, np.ndarray | Any]:
|
|
||||||
"""Converts a dictionary of Jax Array to a mapping of numpy arrays."""
|
|
||||||
return type(value)(**{k: jax_to_numpy(v) for k, v in value.items()})
|
|
||||||
|
|
||||||
|
|
||||||
@jax_to_numpy.register(abc.Iterable)
|
|
||||||
def _iterable_jax_to_numpy(
|
|
||||||
value: Iterable[np.ndarray | Any],
|
|
||||||
) -> Iterable[jax.Array | Any]:
|
|
||||||
"""Converts an Iterable from Numpy arrays to an iterable of Jax Array."""
|
|
||||||
if isinstance(value, jax.Array):
|
|
||||||
# Since the update to jax 0.6.0, calling jax_to_numpy with a <class 'jaxlib.xla_extension.ArrayImpl'>
|
|
||||||
# argument wrongly dispatches to _iterable_jax_to_numpy which fails with:
|
|
||||||
# TypeError: (): incompatible function arguments.
|
|
||||||
# See: https://github.com/Farama-Foundation/Gymnasium/issues/1360
|
|
||||||
return _devicearray_jax_to_numpy(value)
|
|
||||||
elif hasattr(value, "_make"):
|
|
||||||
# namedtuple - underline used to prevent potential name conflicts
|
|
||||||
# noinspection PyProtectedMember
|
|
||||||
return type(value)._make(jax_to_numpy(v) for v in value)
|
|
||||||
else:
|
|
||||||
return type(value)(jax_to_numpy(v) for v in value)
|
|
||||||
|
|
||||||
|
|
||||||
@jax_to_numpy.register(_NoneType)
|
|
||||||
def _none_jax_to_numpy(value: None) -> None:
|
|
||||||
"""Passes through None values."""
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
class JaxToNumpy(
|
|
||||||
gym.Wrapper[WrapperObsType, WrapperActType, ObsType, ActType],
|
|
||||||
gym.utils.RecordConstructorArgs,
|
|
||||||
):
|
|
||||||
"""Wraps a Jax-based environment such that it can be interacted with NumPy arrays.
|
"""Wraps a Jax-based environment such that it can be interacted with NumPy arrays.
|
||||||
|
|
||||||
Actions must be provided as numpy arrays and observations will be returned as numpy arrays.
|
Actions must be provided as numpy arrays and observations will be returned as numpy arrays.
|
||||||
@@ -169,48 +73,4 @@ class JaxToNumpy(
|
|||||||
raise DependencyNotInstalled(
|
raise DependencyNotInstalled(
|
||||||
'Jax is not installed, run `pip install "gymnasium[jax]"`'
|
'Jax is not installed, run `pip install "gymnasium[jax]"`'
|
||||||
)
|
)
|
||||||
gym.utils.RecordConstructorArgs.__init__(self)
|
super().__init__(env=env, env_xp=jnp, target_xp=np)
|
||||||
gym.Wrapper.__init__(self, env)
|
|
||||||
|
|
||||||
def step(
|
|
||||||
self, action: WrapperActType
|
|
||||||
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
|
|
||||||
"""Transforms the action to a jax array .
|
|
||||||
|
|
||||||
Args:
|
|
||||||
action: the action to perform as a numpy array
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A tuple containing numpy versions of the next observation, reward, termination, truncation, and extra info.
|
|
||||||
"""
|
|
||||||
jax_action = numpy_to_jax(action)
|
|
||||||
obs, reward, terminated, truncated, info = self.env.step(jax_action)
|
|
||||||
|
|
||||||
return (
|
|
||||||
jax_to_numpy(obs),
|
|
||||||
float(reward),
|
|
||||||
bool(terminated),
|
|
||||||
bool(truncated),
|
|
||||||
jax_to_numpy(info),
|
|
||||||
)
|
|
||||||
|
|
||||||
def reset(
|
|
||||||
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
|
||||||
) -> tuple[WrapperObsType, dict[str, Any]]:
|
|
||||||
"""Resets the environment returning numpy-based observation and info.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
seed: The seed for resetting the environment
|
|
||||||
options: The options for resetting the environment, these are converted to jax arrays.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Numpy-based observations and info
|
|
||||||
"""
|
|
||||||
if options:
|
|
||||||
options = numpy_to_jax(options)
|
|
||||||
|
|
||||||
return jax_to_numpy(self.env.reset(seed=seed, options=options))
|
|
||||||
|
|
||||||
def render(self) -> RenderFrame | list[RenderFrame] | None:
|
|
||||||
"""Returns the rendered frames as a numpy array."""
|
|
||||||
return jax_to_numpy(self.env.render())
|
|
||||||
|
@@ -7,22 +7,23 @@
|
|||||||
# Under the Apache 2.0 license. Copyright is held by the authors
|
# Under the Apache 2.0 license. Copyright is held by the authors
|
||||||
|
|
||||||
"""Helper functions and wrapper class for converting between PyTorch and Jax."""
|
"""Helper functions and wrapper class for converting between PyTorch and Jax."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import numbers
|
from typing import Union
|
||||||
from collections import abc
|
|
||||||
from typing import Any, Iterable, Mapping, SupportsFloat, Union
|
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
|
|
||||||
from gymnasium.error import DependencyNotInstalled
|
from gymnasium.error import DependencyNotInstalled
|
||||||
|
from gymnasium.wrappers.array_conversion import (
|
||||||
|
ArrayConversion,
|
||||||
|
array_conversion,
|
||||||
|
module_namespace,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import jax
|
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from jax import dlpack as jax_dlpack
|
|
||||||
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise DependencyNotInstalled(
|
raise DependencyNotInstalled(
|
||||||
@@ -31,7 +32,6 @@ except ImportError:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
import torch
|
import torch
|
||||||
from torch.utils import dlpack as torch_dlpack
|
|
||||||
|
|
||||||
Device = Union[str, torch.device]
|
Device = Union[str, torch.device]
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -42,109 +42,13 @@ except ImportError:
|
|||||||
|
|
||||||
__all__ = ["JaxToTorch", "jax_to_torch", "torch_to_jax", "Device"]
|
__all__ = ["JaxToTorch", "jax_to_torch", "torch_to_jax", "Device"]
|
||||||
|
|
||||||
# The NoneType is not defined in Python 3.9. Remove when the minimal version is bumped to >=3.10
|
|
||||||
_NoneType = type(None)
|
torch_to_jax = functools.partial(array_conversion, xp=module_namespace(jnp))
|
||||||
|
|
||||||
|
jax_to_torch = functools.partial(array_conversion, xp=module_namespace(torch))
|
||||||
|
|
||||||
|
|
||||||
@functools.singledispatch
|
class JaxToTorch(ArrayConversion):
|
||||||
def torch_to_jax(value: Any) -> Any:
|
|
||||||
"""Converts a PyTorch Tensor into a Jax Array."""
|
|
||||||
raise Exception(
|
|
||||||
f"No known conversion for Torch type ({type(value)}) to Jax registered. Report as issue on github."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@torch_to_jax.register(numbers.Number)
|
|
||||||
def _number_torch_to_jax(value: numbers.Number) -> Any:
|
|
||||||
"""Convert a python number (int, float, complex) to a jax array."""
|
|
||||||
return jnp.array(value)
|
|
||||||
|
|
||||||
|
|
||||||
@torch_to_jax.register(torch.Tensor)
|
|
||||||
def _tensor_torch_to_jax(value: torch.Tensor) -> jax.Array:
|
|
||||||
"""Converts a PyTorch Tensor into a Jax Array."""
|
|
||||||
return jax_dlpack.from_dlpack(value) # pyright: ignore[reportPrivateImportUsage]
|
|
||||||
|
|
||||||
|
|
||||||
@torch_to_jax.register(abc.Mapping)
|
|
||||||
def _mapping_torch_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
|
|
||||||
"""Converts a mapping of PyTorch Tensors into a Dictionary of Jax Array."""
|
|
||||||
return type(value)(**{k: torch_to_jax(v) for k, v in value.items()})
|
|
||||||
|
|
||||||
|
|
||||||
@torch_to_jax.register(abc.Iterable)
|
|
||||||
def _iterable_torch_to_jax(value: Iterable[Any]) -> Iterable[Any]:
|
|
||||||
"""Converts an Iterable from PyTorch Tensors to an iterable of Jax Array."""
|
|
||||||
if hasattr(value, "_make"):
|
|
||||||
# namedtuple - underline used to prevent potential name conflicts
|
|
||||||
# noinspection PyProtectedMember
|
|
||||||
return type(value)._make(torch_to_jax(v) for v in value)
|
|
||||||
else:
|
|
||||||
return type(value)(torch_to_jax(v) for v in value)
|
|
||||||
|
|
||||||
|
|
||||||
@torch_to_jax.register(_NoneType)
|
|
||||||
def _none_torch_to_jax(value: None) -> None:
|
|
||||||
"""Passes through None values."""
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
@functools.singledispatch
|
|
||||||
def jax_to_torch(value: Any, device: Device | None = None) -> Any:
|
|
||||||
"""Converts a Jax Array into a PyTorch Tensor."""
|
|
||||||
raise Exception(
|
|
||||||
f"No known conversion for Jax type ({type(value)}) to PyTorch registered. Report as issue on github."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@jax_to_torch.register(jax.Array)
|
|
||||||
def _devicearray_jax_to_torch(
|
|
||||||
value: jax.Array, device: Device | None = None
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""Converts a Jax Array into a PyTorch Tensor."""
|
|
||||||
assert jax_dlpack is not None and torch_dlpack is not None
|
|
||||||
tensor = torch_dlpack.from_dlpack(
|
|
||||||
value
|
|
||||||
) # pyright: ignore[reportPrivateImportUsage]
|
|
||||||
if device:
|
|
||||||
return tensor.to(device=device)
|
|
||||||
return tensor
|
|
||||||
|
|
||||||
|
|
||||||
@jax_to_torch.register(abc.Mapping)
|
|
||||||
def _jax_mapping_to_torch(
|
|
||||||
value: Mapping[str, Any], device: Device | None = None
|
|
||||||
) -> Mapping[str, Any]:
|
|
||||||
"""Converts a mapping of Jax Array into a Dictionary of PyTorch Tensors."""
|
|
||||||
return type(value)(**{k: jax_to_torch(v, device) for k, v in value.items()})
|
|
||||||
|
|
||||||
|
|
||||||
@jax_to_torch.register(abc.Iterable)
|
|
||||||
def _jax_iterable_to_torch(
|
|
||||||
value: Iterable[Any], device: Device | None = None
|
|
||||||
) -> Iterable[Any]:
|
|
||||||
"""Converts an Iterable from Jax Array to an iterable of PyTorch Tensors."""
|
|
||||||
if isinstance(value, jax.Array):
|
|
||||||
# Since the update to jax 0.6.0, calling jax_to_torch with a <class 'jaxlib.xla_extension.ArrayImpl'>
|
|
||||||
# argument wrongly dispatches to _iterable_jax_to_torch which fails with:
|
|
||||||
# TypeError: (): incompatible function arguments.
|
|
||||||
# See: https://github.com/Farama-Foundation/Gymnasium/issues/1360
|
|
||||||
return _devicearray_jax_to_torch(value)
|
|
||||||
elif hasattr(value, "_make"):
|
|
||||||
# namedtuple - underline used to prevent potential name conflicts
|
|
||||||
# noinspection PyProtectedMember
|
|
||||||
return type(value)._make(jax_to_torch(v, device) for v in value)
|
|
||||||
else:
|
|
||||||
return type(value)(jax_to_torch(v, device) for v in value)
|
|
||||||
|
|
||||||
|
|
||||||
@jax_to_torch.register(_NoneType)
|
|
||||||
def _none_jax_to_torch(value: None, device: Device | None = None) -> None:
|
|
||||||
"""Passes through None values."""
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
class JaxToTorch(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
|
||||||
"""Wraps a Jax-based environment so that it can be interacted with PyTorch Tensors.
|
"""Wraps a Jax-based environment so that it can be interacted with PyTorch Tensors.
|
||||||
|
|
||||||
Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors.
|
Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors.
|
||||||
@@ -183,50 +87,8 @@ class JaxToTorch(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
|||||||
env: The Jax-based environment to wrap
|
env: The Jax-based environment to wrap
|
||||||
device: The device the torch Tensors should be moved to
|
device: The device the torch Tensors should be moved to
|
||||||
"""
|
"""
|
||||||
gym.utils.RecordConstructorArgs.__init__(self, device=device)
|
super().__init__(env=env, env_xp=jnp, target_xp=torch, target_device=device)
|
||||||
gym.Wrapper.__init__(self, env)
|
|
||||||
|
|
||||||
|
# TODO: Device was part of the public API, but should be removed in favor of _env_device and
|
||||||
|
# _target_device.
|
||||||
self.device: Device | None = device
|
self.device: Device | None = device
|
||||||
|
|
||||||
def step(
|
|
||||||
self, action: WrapperActType
|
|
||||||
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
|
|
||||||
"""Performs the given action within the environment.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
action: The action to perform as a PyTorch Tensor
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The next observation, reward, termination, truncation, and extra info
|
|
||||||
"""
|
|
||||||
jax_action = torch_to_jax(action)
|
|
||||||
obs, reward, terminated, truncated, info = self.env.step(jax_action)
|
|
||||||
|
|
||||||
return (
|
|
||||||
jax_to_torch(obs, self.device),
|
|
||||||
float(reward),
|
|
||||||
bool(terminated),
|
|
||||||
bool(truncated),
|
|
||||||
jax_to_torch(info, self.device),
|
|
||||||
)
|
|
||||||
|
|
||||||
def reset(
|
|
||||||
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
|
||||||
) -> tuple[WrapperObsType, dict[str, Any]]:
|
|
||||||
"""Resets the environment returning PyTorch-based observation and info.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
seed: The seed for resetting the environment
|
|
||||||
options: The options for resetting the environment, these are converted to jax arrays.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
PyTorch-based observations and info
|
|
||||||
"""
|
|
||||||
if options:
|
|
||||||
options = torch_to_jax(options)
|
|
||||||
|
|
||||||
return jax_to_torch(self.env.reset(seed=seed, options=options), self.device)
|
|
||||||
|
|
||||||
def render(self) -> RenderFrame | list[RenderFrame] | None:
|
|
||||||
"""Returns the rendered frames as a torch tensor."""
|
|
||||||
return jax_to_torch(self.env.render())
|
|
||||||
|
@@ -3,15 +3,17 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import numbers
|
from typing import Union
|
||||||
from collections import abc
|
|
||||||
from typing import Any, Iterable, Mapping, SupportsFloat, Union
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
|
|
||||||
from gymnasium.error import DependencyNotInstalled
|
from gymnasium.error import DependencyNotInstalled
|
||||||
|
from gymnasium.wrappers.array_conversion import (
|
||||||
|
ArrayConversion,
|
||||||
|
array_conversion,
|
||||||
|
module_namespace,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -26,100 +28,13 @@ except ImportError:
|
|||||||
|
|
||||||
__all__ = ["NumpyToTorch", "torch_to_numpy", "numpy_to_torch", "Device"]
|
__all__ = ["NumpyToTorch", "torch_to_numpy", "numpy_to_torch", "Device"]
|
||||||
|
|
||||||
# The NoneType is not defined in Python 3.9. Remove when the minimal version is bumped to >=3.10
|
|
||||||
_NoneType = type(None)
|
torch_to_numpy = functools.partial(array_conversion, xp=module_namespace(np))
|
||||||
|
|
||||||
|
numpy_to_torch = functools.partial(array_conversion, xp=module_namespace(torch))
|
||||||
|
|
||||||
|
|
||||||
@functools.singledispatch
|
class NumpyToTorch(ArrayConversion):
|
||||||
def torch_to_numpy(value: Any) -> Any:
|
|
||||||
"""Converts a PyTorch Tensor into a NumPy Array."""
|
|
||||||
raise Exception(
|
|
||||||
f"No known conversion for Torch type ({type(value)}) to NumPy registered. Report as issue on github."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@torch_to_numpy.register(numbers.Number)
|
|
||||||
def _number_to_numpy(value: numbers.Number) -> Any:
|
|
||||||
"""Convert a python number (int, float, complex) to a NumPy array."""
|
|
||||||
return np.array(value)
|
|
||||||
|
|
||||||
|
|
||||||
@torch_to_numpy.register(torch.Tensor)
|
|
||||||
def _torch_to_numpy(value: torch.Tensor) -> Any:
|
|
||||||
"""Convert a torch.Tensor to a NumPy array."""
|
|
||||||
return value.numpy(force=True)
|
|
||||||
|
|
||||||
|
|
||||||
@torch_to_numpy.register(abc.Mapping)
|
|
||||||
def _mapping_torch_to_numpy(value: Mapping[str, Any]) -> Mapping[str, Any]:
|
|
||||||
"""Converts a mapping of PyTorch Tensors into a Dictionary of NumPy Array."""
|
|
||||||
return type(value)(**{k: torch_to_numpy(v) for k, v in value.items()})
|
|
||||||
|
|
||||||
|
|
||||||
@torch_to_numpy.register(abc.Iterable)
|
|
||||||
def _iterable_torch_to_numpy(value: Iterable[Any]) -> Iterable[Any]:
|
|
||||||
"""Converts an Iterable from PyTorch Tensors to an iterable of NumPy Array."""
|
|
||||||
if hasattr(value, "_make"):
|
|
||||||
# namedtuple - underline used to prevent potential name conflicts
|
|
||||||
# noinspection PyProtectedMember
|
|
||||||
return type(value)._make(torch_to_numpy(v) for v in value)
|
|
||||||
else:
|
|
||||||
return type(value)(torch_to_numpy(v) for v in value)
|
|
||||||
|
|
||||||
|
|
||||||
@torch_to_numpy.register(_NoneType)
|
|
||||||
def _none_torch_to_numpy(value: None) -> None:
|
|
||||||
"""Passes through None values."""
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
@functools.singledispatch
|
|
||||||
def numpy_to_torch(value: Any, device: Device | None = None) -> Any:
|
|
||||||
"""Converts a NumPy Array into a PyTorch Tensor."""
|
|
||||||
raise Exception(
|
|
||||||
f"No known conversion for NumPy type ({type(value)}) to PyTorch registered. Report as issue on github."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@numpy_to_torch.register(numbers.Number)
|
|
||||||
@numpy_to_torch.register(np.ndarray)
|
|
||||||
def _numpy_to_torch(value: np.ndarray, device: Device | None = None) -> torch.Tensor:
|
|
||||||
"""Converts a NumPy Array into a PyTorch Tensor."""
|
|
||||||
assert torch is not None
|
|
||||||
tensor = torch.tensor(value)
|
|
||||||
if device:
|
|
||||||
return tensor.to(device=device)
|
|
||||||
return tensor
|
|
||||||
|
|
||||||
|
|
||||||
@numpy_to_torch.register(abc.Mapping)
|
|
||||||
def _numpy_mapping_to_torch(
|
|
||||||
value: Mapping[str, Any], device: Device | None = None
|
|
||||||
) -> Mapping[str, Any]:
|
|
||||||
"""Converts a mapping of NumPy Array into a Dictionary of PyTorch Tensors."""
|
|
||||||
return type(value)(**{k: numpy_to_torch(v, device) for k, v in value.items()})
|
|
||||||
|
|
||||||
|
|
||||||
@numpy_to_torch.register(abc.Iterable)
|
|
||||||
def _numpy_iterable_to_torch(
|
|
||||||
value: Iterable[Any], device: Device | None = None
|
|
||||||
) -> Iterable[Any]:
|
|
||||||
"""Converts an Iterable from NumPy Array to an iterable of PyTorch Tensors."""
|
|
||||||
if hasattr(value, "_make"):
|
|
||||||
# namedtuple - underline used to prevent potential name conflicts
|
|
||||||
# noinspection PyProtectedMember
|
|
||||||
return type(value)._make(numpy_to_torch(v, device) for v in value)
|
|
||||||
else:
|
|
||||||
return type(value)(numpy_to_torch(v, device) for v in value)
|
|
||||||
|
|
||||||
|
|
||||||
@numpy_to_torch.register(_NoneType)
|
|
||||||
def _none_numpy_to_torch(value: None) -> None:
|
|
||||||
"""Passes through None values."""
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
class NumpyToTorch(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
|
||||||
"""Wraps a NumPy-based environment such that it can be interacted with PyTorch Tensors.
|
"""Wraps a NumPy-based environment such that it can be interacted with PyTorch Tensors.
|
||||||
|
|
||||||
Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors.
|
Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors.
|
||||||
@@ -158,50 +73,6 @@ class NumpyToTorch(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
|||||||
env: The NumPy-based environment to wrap
|
env: The NumPy-based environment to wrap
|
||||||
device: The device the torch Tensors should be moved to
|
device: The device the torch Tensors should be moved to
|
||||||
"""
|
"""
|
||||||
gym.utils.RecordConstructorArgs.__init__(self, device=device)
|
super().__init__(env=env, env_xp=np, target_xp=torch, target_device=device)
|
||||||
gym.Wrapper.__init__(self, env)
|
|
||||||
|
|
||||||
self.device: Device | None = device
|
self.device: Device | None = device
|
||||||
|
|
||||||
def step(
|
|
||||||
self, action: WrapperActType
|
|
||||||
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
|
|
||||||
"""Using a PyTorch based action that is converted to NumPy to be used by the environment.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
action: A PyTorch-based action
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The PyTorch-based Tensor next observation, reward, termination, truncation, and extra info
|
|
||||||
"""
|
|
||||||
jax_action = torch_to_numpy(action)
|
|
||||||
obs, reward, terminated, truncated, info = self.env.step(jax_action)
|
|
||||||
|
|
||||||
return (
|
|
||||||
numpy_to_torch(obs, self.device),
|
|
||||||
float(reward),
|
|
||||||
bool(terminated),
|
|
||||||
bool(truncated),
|
|
||||||
numpy_to_torch(info, self.device),
|
|
||||||
)
|
|
||||||
|
|
||||||
def reset(
|
|
||||||
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
|
||||||
) -> tuple[WrapperObsType, dict[str, Any]]:
|
|
||||||
"""Resets the environment returning PyTorch-based observation and info.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
seed: The seed for resetting the environment
|
|
||||||
options: The options for resetting the environment, these are converted to jax arrays.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
PyTorch-based observations and info
|
|
||||||
"""
|
|
||||||
if options:
|
|
||||||
options = torch_to_numpy(options)
|
|
||||||
|
|
||||||
return numpy_to_torch(self.env.reset(seed=seed, options=options), self.device)
|
|
||||||
|
|
||||||
def render(self) -> RenderFrame | list[RenderFrame] | None:
|
|
||||||
"""Returns the rendered frames as a torch tensor."""
|
|
||||||
return numpy_to_torch(self.env.render())
|
|
||||||
|
131
gymnasium/wrappers/vector/array_conversion.py
Normal file
131
gymnasium/wrappers/vector/array_conversion.py
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
"""Vector wrapper for converting between Array API compatible frameworks."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from types import ModuleType
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
from gymnasium.core import ActType, ObsType
|
||||||
|
from gymnasium.vector import VectorEnv, VectorWrapper
|
||||||
|
from gymnasium.vector.vector_env import ArrayType
|
||||||
|
from gymnasium.wrappers.array_conversion import (
|
||||||
|
Device,
|
||||||
|
array_conversion,
|
||||||
|
module_name_to_namespace,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["ArrayConversion"]
|
||||||
|
|
||||||
|
|
||||||
|
class ArrayConversion(VectorWrapper, gym.utils.RecordConstructorArgs):
|
||||||
|
"""Wraps a vector environment returning Array API compatible arrays so that it can be interacted with through a specific framework.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
A vectorized version of ``gymnasium.wrappers.ArrayConversion``
|
||||||
|
|
||||||
|
Actions must be provided as Array API compatible arrays and observations, rewards, terminations and truncations will be returned in the desired framework.
|
||||||
|
xp here is a module that is compatible with the Array API standard, e.g. ``numpy``, ``jax`` etc.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> import gymnasium as gym # doctest: +SKIP
|
||||||
|
>>> envs = gym.make_vec("JaxEnv-vx", 3) # doctest: +SKIP
|
||||||
|
>>> envs = ArrayConversion(envs, xp=np) # doctest: +SKIP
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
env: VectorEnv,
|
||||||
|
env_xp: ModuleType | str,
|
||||||
|
target_xp: ModuleType | str,
|
||||||
|
env_device: Device | None = None,
|
||||||
|
target_device: Device | None = None,
|
||||||
|
):
|
||||||
|
"""Wrapper class to change inputs and outputs of environment to any Array API framework.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env: The Array API compatible environment to wrap
|
||||||
|
env_xp: The Array API framework the environment is on
|
||||||
|
target_xp: The Array API framework to convert to
|
||||||
|
env_device: The device the environment is on
|
||||||
|
target_device: The device on which Arrays should be returned
|
||||||
|
"""
|
||||||
|
gym.utils.RecordConstructorArgs.__init__(self)
|
||||||
|
VectorWrapper.__init__(self, env)
|
||||||
|
self._env_xp = env_xp
|
||||||
|
self._target_xp = target_xp
|
||||||
|
self._env_device = env_device
|
||||||
|
self._target_device = target_device
|
||||||
|
|
||||||
|
def step(
|
||||||
|
self, actions: ActType
|
||||||
|
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
|
||||||
|
"""Transforms the action to the specified xp module array type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
actions: The action to perform
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple containing xp versions of the next observation, reward, termination, truncation, and extra info.
|
||||||
|
"""
|
||||||
|
actions = array_conversion(actions, xp=self._env_xp, device=self._env_device)
|
||||||
|
obs, reward, terminated, truncated, info = self.env.step(actions)
|
||||||
|
|
||||||
|
return (
|
||||||
|
array_conversion(obs, xp=self._target_xp, device=self._target_device),
|
||||||
|
array_conversion(reward, xp=self._target_xp, device=self._target_device),
|
||||||
|
array_conversion(
|
||||||
|
terminated, xp=self._target_xp, device=self._target_device
|
||||||
|
),
|
||||||
|
array_conversion(truncated, xp=self._target_xp, device=self._target_device),
|
||||||
|
array_conversion(info, xp=self._target_xp, device=self._target_device),
|
||||||
|
)
|
||||||
|
|
||||||
|
def reset(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
seed: int | list[int] | None = None,
|
||||||
|
options: dict[str, Any] | None = None,
|
||||||
|
) -> tuple[ObsType, dict[str, Any]]:
|
||||||
|
"""Resets the environment returning xp-based observation and info.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seed: The seed for resetting the environment
|
||||||
|
options: The options for resetting the environment, these are converted to xp arrays.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
xp-based observations and info
|
||||||
|
"""
|
||||||
|
if options:
|
||||||
|
options = array_conversion(
|
||||||
|
options, xp=self._env_xp, device=self._env_device
|
||||||
|
)
|
||||||
|
|
||||||
|
return array_conversion(
|
||||||
|
self.env.reset(seed=seed, options=options),
|
||||||
|
xp=self._target_xp,
|
||||||
|
device=self._target_device,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __getstate__(self):
|
||||||
|
"""Returns the object pickle state with args and kwargs."""
|
||||||
|
env_xp_name = self._env_xp.__name__.replace("array_api_compat.", "")
|
||||||
|
target_xp_name = self._target_xp.__name__.replace("array_api_compat.", "")
|
||||||
|
env_device = self._env_device
|
||||||
|
target_device = self._target_device
|
||||||
|
return {
|
||||||
|
"env_xp_name": env_xp_name,
|
||||||
|
"target_xp_name": target_xp_name,
|
||||||
|
"env_device": env_device,
|
||||||
|
"target_device": target_device,
|
||||||
|
"env": self.env,
|
||||||
|
}
|
||||||
|
|
||||||
|
def __setstate__(self, d):
|
||||||
|
"""Sets the object pickle state using d."""
|
||||||
|
self.env = d["env"]
|
||||||
|
self._env_xp = module_name_to_namespace(d["env_xp_name"])
|
||||||
|
self._target_xp = module_name_to_namespace(d["target_xp_name"])
|
||||||
|
self._env_device = d["env_device"]
|
||||||
|
self._target_device = d["target_device"]
|
@@ -2,21 +2,18 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from gymnasium.core import ActType, ObsType
|
|
||||||
from gymnasium.error import DependencyNotInstalled
|
from gymnasium.error import DependencyNotInstalled
|
||||||
from gymnasium.vector import VectorEnv, VectorWrapper
|
from gymnasium.vector import VectorEnv
|
||||||
from gymnasium.vector.vector_env import ArrayType
|
from gymnasium.wrappers.vector.array_conversion import ArrayConversion
|
||||||
from gymnasium.wrappers.jax_to_numpy import jax_to_numpy, numpy_to_jax
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["JaxToNumpy"]
|
__all__ = ["JaxToNumpy"]
|
||||||
|
|
||||||
|
|
||||||
class JaxToNumpy(VectorWrapper):
|
class JaxToNumpy(ArrayConversion):
|
||||||
"""Wraps a jax vector environment so that it can be interacted with through numpy arrays.
|
"""Wraps a jax vector environment so that it can be interacted with through numpy arrays.
|
||||||
|
|
||||||
Notes:
|
Notes:
|
||||||
@@ -40,46 +37,4 @@ class JaxToNumpy(VectorWrapper):
|
|||||||
raise DependencyNotInstalled(
|
raise DependencyNotInstalled(
|
||||||
'Jax is not installed, run `pip install "gymnasium[jax]"`'
|
'Jax is not installed, run `pip install "gymnasium[jax]"`'
|
||||||
)
|
)
|
||||||
super().__init__(env)
|
super().__init__(env, env_xp=jnp, target_xp=np)
|
||||||
|
|
||||||
def step(
|
|
||||||
self, actions: ActType
|
|
||||||
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
|
|
||||||
"""Transforms the action to a jax array .
|
|
||||||
|
|
||||||
Args:
|
|
||||||
actions: the action to perform as a numpy array
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A tuple containing numpy versions of the next observation, reward, termination, truncation, and extra info.
|
|
||||||
"""
|
|
||||||
jax_actions = numpy_to_jax(actions)
|
|
||||||
obs, reward, terminated, truncated, info = self.env.step(jax_actions)
|
|
||||||
|
|
||||||
return (
|
|
||||||
jax_to_numpy(obs),
|
|
||||||
jax_to_numpy(reward),
|
|
||||||
jax_to_numpy(terminated),
|
|
||||||
jax_to_numpy(truncated),
|
|
||||||
jax_to_numpy(info),
|
|
||||||
)
|
|
||||||
|
|
||||||
def reset(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
seed: int | list[int] | None = None,
|
|
||||||
options: dict[str, Any] | None = None,
|
|
||||||
) -> tuple[ObsType, dict[str, Any]]:
|
|
||||||
"""Resets the environment returning numpy-based observation and info.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
seed: The seed for resetting the environment
|
|
||||||
options: The options for resetting the environment, these are converted to jax arrays.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Numpy-based observations and info
|
|
||||||
"""
|
|
||||||
if options:
|
|
||||||
options = numpy_to_jax(options)
|
|
||||||
|
|
||||||
return jax_to_numpy(self.env.reset(seed=seed, options=options))
|
|
||||||
|
@@ -2,18 +2,18 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any
|
import jax.numpy as jnp
|
||||||
|
import torch
|
||||||
|
|
||||||
from gymnasium.core import ActType, ObsType
|
from gymnasium.vector import VectorEnv
|
||||||
from gymnasium.vector import VectorEnv, VectorWrapper
|
from gymnasium.wrappers.jax_to_torch import Device
|
||||||
from gymnasium.vector.vector_env import ArrayType
|
from gymnasium.wrappers.vector.array_conversion import ArrayConversion
|
||||||
from gymnasium.wrappers.jax_to_torch import Device, jax_to_torch, torch_to_jax
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["JaxToTorch"]
|
__all__ = ["JaxToTorch"]
|
||||||
|
|
||||||
|
|
||||||
class JaxToTorch(VectorWrapper):
|
class JaxToTorch(ArrayConversion):
|
||||||
"""Wraps a Jax-based vector environment so that it can be interacted with through PyTorch Tensors.
|
"""Wraps a Jax-based vector environment so that it can be interacted with through PyTorch Tensors.
|
||||||
|
|
||||||
Actions must be provided as PyTorch Tensors and observations, rewards, terminations and truncations will be returned as PyTorch Tensors.
|
Actions must be provided as PyTorch Tensors and observations, rewards, terminations and truncations will be returned as PyTorch Tensors.
|
||||||
@@ -31,48 +31,6 @@ class JaxToTorch(VectorWrapper):
|
|||||||
env: The Jax-based vector environment to wrap
|
env: The Jax-based vector environment to wrap
|
||||||
device: The device the torch Tensors should be moved to
|
device: The device the torch Tensors should be moved to
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
super().__init__(env, env_xp=jnp, target_xp=torch, target_device=device)
|
||||||
|
|
||||||
self.device: Device | None = device
|
self.device: Device | None = device
|
||||||
|
|
||||||
def step(
|
|
||||||
self, actions: ActType
|
|
||||||
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
|
|
||||||
"""Performs the given action within the environment.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
actions: The action to perform as a PyTorch Tensor
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Torch-based Tensors of the next observation, reward, termination, truncation, and extra info
|
|
||||||
"""
|
|
||||||
jax_action = torch_to_jax(actions)
|
|
||||||
obs, reward, terminated, truncated, info = self.env.step(jax_action)
|
|
||||||
|
|
||||||
return (
|
|
||||||
jax_to_torch(obs, self.device),
|
|
||||||
jax_to_torch(reward, self.device),
|
|
||||||
jax_to_torch(terminated, self.device),
|
|
||||||
jax_to_torch(truncated, self.device),
|
|
||||||
jax_to_torch(info, self.device),
|
|
||||||
)
|
|
||||||
|
|
||||||
def reset(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
seed: int | list[int] | None = None,
|
|
||||||
options: dict[str, Any] | None = None,
|
|
||||||
) -> tuple[ObsType, dict[str, Any]]:
|
|
||||||
"""Resets the environment returning PyTorch-based observation and info.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
seed: The seed for resetting the environment
|
|
||||||
options: The options for resetting the environment, these are converted to jax arrays.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
PyTorch-based observations and info
|
|
||||||
"""
|
|
||||||
if options:
|
|
||||||
options = torch_to_jax(options)
|
|
||||||
|
|
||||||
return jax_to_torch(self.env.reset(seed=seed, options=options), self.device)
|
|
||||||
|
@@ -2,18 +2,18 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
from gymnasium.core import ActType, ObsType
|
from gymnasium.vector import VectorEnv
|
||||||
from gymnasium.vector import VectorEnv, VectorWrapper
|
from gymnasium.wrappers.numpy_to_torch import Device
|
||||||
from gymnasium.vector.vector_env import ArrayType
|
from gymnasium.wrappers.vector.array_conversion import ArrayConversion
|
||||||
from gymnasium.wrappers.numpy_to_torch import Device, numpy_to_torch, torch_to_numpy
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["NumpyToTorch"]
|
__all__ = ["NumpyToTorch"]
|
||||||
|
|
||||||
|
|
||||||
class NumpyToTorch(VectorWrapper):
|
class NumpyToTorch(ArrayConversion):
|
||||||
"""Wraps a numpy-based environment so that it can be interacted with through PyTorch Tensors.
|
"""Wraps a numpy-based environment so that it can be interacted with through PyTorch Tensors.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
@@ -45,48 +45,6 @@ class NumpyToTorch(VectorWrapper):
|
|||||||
env: The NumPy-based vector environment to wrap
|
env: The NumPy-based vector environment to wrap
|
||||||
device: The device the torch Tensors should be moved to
|
device: The device the torch Tensors should be moved to
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
super().__init__(env, env_xp=np, target_xp=torch, target_device=device)
|
||||||
|
|
||||||
self.device: Device | None = device
|
self.device: Device | None = device
|
||||||
|
|
||||||
def step(
|
|
||||||
self, actions: ActType
|
|
||||||
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
|
|
||||||
"""Using a PyTorch based action that is converted to NumPy to be used by the environment.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
action: A PyTorch-based action
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The PyTorch-based Tensor next observation, reward, termination, truncation, and extra info
|
|
||||||
"""
|
|
||||||
numpy_action = torch_to_numpy(actions)
|
|
||||||
obs, reward, terminated, truncated, info = self.env.step(numpy_action)
|
|
||||||
|
|
||||||
return (
|
|
||||||
numpy_to_torch(obs, self.device),
|
|
||||||
numpy_to_torch(reward, self.device),
|
|
||||||
numpy_to_torch(terminated, self.device),
|
|
||||||
numpy_to_torch(truncated, self.device),
|
|
||||||
numpy_to_torch(info, self.device),
|
|
||||||
)
|
|
||||||
|
|
||||||
def reset(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
seed: int | list[int] | None = None,
|
|
||||||
options: dict[str, Any] | None = None,
|
|
||||||
) -> tuple[ObsType, dict[str, Any]]:
|
|
||||||
"""Resets the environment returning PyTorch-based observation and info.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
seed: The seed for resetting the environment
|
|
||||||
options: The options for resetting the environment, these are converted to NumPy arrays.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
PyTorch-based observations and info
|
|
||||||
"""
|
|
||||||
if options:
|
|
||||||
options = torch_to_numpy(options)
|
|
||||||
|
|
||||||
return numpy_to_torch(self.env.reset(seed=seed, options=options), self.device)
|
|
||||||
|
@@ -8,7 +8,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
name = "gymnasium"
|
name = "gymnasium"
|
||||||
description = "A standard API for reinforcement learning and a diverse set of reference environments (formerly Gym)."
|
description = "A standard API for reinforcement learning and a diverse set of reference environments (formerly Gym)."
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">= 3.8"
|
requires-python = ">= 3.10"
|
||||||
authors = [{ name = "Farama Foundation", email = "contact@farama.org" }]
|
authors = [{ name = "Farama Foundation", email = "contact@farama.org" }]
|
||||||
license = { text = "MIT License" }
|
license = { text = "MIT License" }
|
||||||
keywords = ["Reinforcement Learning", "game", "RL", "AI", "gymnasium"]
|
keywords = ["Reinforcement Learning", "game", "RL", "AI", "gymnasium"]
|
||||||
@@ -16,8 +16,6 @@ classifiers = [
|
|||||||
"Development Status :: 5 - Production/Stable",
|
"Development Status :: 5 - Production/Stable",
|
||||||
"License :: OSI Approved :: MIT License",
|
"License :: OSI Approved :: MIT License",
|
||||||
"Programming Language :: Python :: 3",
|
"Programming Language :: Python :: 3",
|
||||||
"Programming Language :: Python :: 3.8",
|
|
||||||
"Programming Language :: Python :: 3.9",
|
|
||||||
"Programming Language :: Python :: 3.10",
|
"Programming Language :: Python :: 3.10",
|
||||||
"Programming Language :: Python :: 3.11",
|
"Programming Language :: Python :: 3.11",
|
||||||
"Programming Language :: Python :: 3.12",
|
"Programming Language :: Python :: 3.12",
|
||||||
@@ -27,7 +25,6 @@ classifiers = [
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"numpy >=1.21.0",
|
"numpy >=1.21.0",
|
||||||
"cloudpickle >=1.2.0",
|
"cloudpickle >=1.2.0",
|
||||||
"importlib-metadata >=4.8.0; python_version < '3.10'",
|
|
||||||
"typing-extensions >=4.3.0",
|
"typing-extensions >=4.3.0",
|
||||||
"farama-notifications >=0.0.1",
|
"farama-notifications >=0.0.1",
|
||||||
]
|
]
|
||||||
@@ -38,15 +35,30 @@ dynamic = ["version"]
|
|||||||
atari = ["ale_py >=0.9"]
|
atari = ["ale_py >=0.9"]
|
||||||
box2d = ["box2d-py ==2.3.5", "pygame >=2.1.3", "swig ==4.*"]
|
box2d = ["box2d-py ==2.3.5", "pygame >=2.1.3", "swig ==4.*"]
|
||||||
classic-control = ["pygame >=2.1.3"]
|
classic-control = ["pygame >=2.1.3"]
|
||||||
classic_control = ["pygame >=2.1.3"] # kept for backward compatibility
|
classic_control = ["pygame >=2.1.3"] # kept for backward compatibility
|
||||||
mujoco-py = ["mujoco-py >=2.1,<2.2", "cython<3"]
|
mujoco-py = ["mujoco-py >=2.1,<2.2", "cython<3"]
|
||||||
mujoco_py = ["mujoco-py >=2.1,<2.2", "cython<3"] # kept for backward compatibility
|
mujoco_py = [
|
||||||
|
"mujoco-py >=2.1,<2.2",
|
||||||
|
"cython<3",
|
||||||
|
] # kept for backward compatibility
|
||||||
mujoco = ["mujoco >=2.1.5", "imageio >=2.14.1", "packaging >=23.0"]
|
mujoco = ["mujoco >=2.1.5", "imageio >=2.14.1", "packaging >=23.0"]
|
||||||
toy-text = ["pygame >=2.1.3"]
|
toy-text = ["pygame >=2.1.3"]
|
||||||
toy_text = ["pygame >=2.1.3"] # kept for backward compatibility
|
toy_text = ["pygame >=2.1.3"] # kept for backward compatibility
|
||||||
jax = ["jax >=0.4.16", "jaxlib >=0.4.16", "flax >=0.5.0"]
|
jax = [
|
||||||
torch = ["torch >=1.13.0"]
|
"jax >=0.4.16",
|
||||||
other = ["moviepy >=1.0.0", "matplotlib >=3.0", "opencv-python >=3.0", "seaborn >= 0.13"]
|
"jaxlib >=0.4.16",
|
||||||
|
"flax >=0.5.0",
|
||||||
|
"array-api-compat >=1.11.0",
|
||||||
|
"numpy>=2.1",
|
||||||
|
]
|
||||||
|
torch = ["torch >=1.13.0", "array-api-compat >=1.11.0", "numpy>=2.1"]
|
||||||
|
array-api = ["array-api-compat >=1.11.0", "numpy>=2.1"]
|
||||||
|
other = [
|
||||||
|
"moviepy >=1.0.0",
|
||||||
|
"matplotlib >=3.0",
|
||||||
|
"opencv-python >=3.0",
|
||||||
|
"seaborn >= 0.13",
|
||||||
|
]
|
||||||
all = [
|
all = [
|
||||||
# All dependencies above except accept-rom-license
|
# All dependencies above except accept-rom-license
|
||||||
# NOTE: No need to manually remove the duplicates, setuptools automatically does that.
|
# NOTE: No need to manually remove the duplicates, setuptools automatically does that.
|
||||||
@@ -71,17 +83,26 @@ all = [
|
|||||||
"jax >=0.4.16",
|
"jax >=0.4.16",
|
||||||
"jaxlib >=0.4.16",
|
"jaxlib >=0.4.16",
|
||||||
"flax >= 0.5.0",
|
"flax >= 0.5.0",
|
||||||
|
"array-api-compat >=1.11.0",
|
||||||
|
"numpy>=2.1",
|
||||||
# torch
|
# torch
|
||||||
"torch >=1.13.0",
|
"torch >=1.13.0",
|
||||||
|
"array-api-compat >=1.11.0",
|
||||||
|
"numpy>=2.1",
|
||||||
|
# array-api
|
||||||
|
"array-api-compat >=1.11.0",
|
||||||
|
"numpy>=2.1",
|
||||||
# other
|
# other
|
||||||
"opencv-python >=3.0",
|
"opencv-python >=3.0",
|
||||||
"matplotlib >=3.0",
|
"matplotlib >=3.0",
|
||||||
"moviepy >=1.0.0",
|
"moviepy >=1.0.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
testing = [
|
testing = [
|
||||||
"pytest >=7.1.3",
|
"pytest >=7.1.3",
|
||||||
"scipy >=1.7.3",
|
"scipy >=1.7.3",
|
||||||
"dill >=0.3.7",
|
"dill >=0.3.7",
|
||||||
|
"array_api_extra >=0.7.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.urls]
|
[project.urls]
|
||||||
@@ -125,7 +146,7 @@ exclude = ["tests/**", "**/node_modules", "**/__pycache__"]
|
|||||||
strict = []
|
strict = []
|
||||||
|
|
||||||
typeCheckingMode = "basic"
|
typeCheckingMode = "basic"
|
||||||
pythonVersion = "3.8"
|
pythonVersion = "3.10"
|
||||||
pythonPlatform = "All"
|
pythonPlatform = "All"
|
||||||
typeshedPath = "typeshed"
|
typeshedPath = "typeshed"
|
||||||
enableTypeIgnoreComments = true
|
enableTypeIgnoreComments = true
|
||||||
@@ -138,19 +159,19 @@ reportMissingTypeStubs = false
|
|||||||
# For warning and error, will raise an error when
|
# For warning and error, will raise an error when
|
||||||
reportInvalidTypeVarUse = "none"
|
reportInvalidTypeVarUse = "none"
|
||||||
|
|
||||||
reportGeneralTypeIssues = "none" # -> commented out raises 489 errors
|
reportGeneralTypeIssues = "none" # -> commented out raises 489 errors
|
||||||
reportAttributeAccessIssue = "none" # pyright provides false positives
|
reportAttributeAccessIssue = "none" # pyright provides false positives
|
||||||
reportArgumentType = "none" # pyright provides false positives
|
reportArgumentType = "none" # pyright provides false positives
|
||||||
|
|
||||||
reportPrivateUsage = "warning"
|
reportPrivateUsage = "warning"
|
||||||
|
|
||||||
reportIndexIssue = "none" # TODO fix one by one
|
reportIndexIssue = "none" # TODO fix one by one
|
||||||
reportReturnType = "none" # TODO fix one by one
|
reportReturnType = "none" # TODO fix one by one
|
||||||
reportCallIssue = "none" # TODO fix one by one
|
reportCallIssue = "none" # TODO fix one by one
|
||||||
reportOperatorIssue = "none" # TODO fix one by one
|
reportOperatorIssue = "none" # TODO fix one by one
|
||||||
reportInvalidTypeForm = "none" # TODO fix one by one
|
reportInvalidTypeForm = "none" # TODO fix one by one
|
||||||
reportOptionalMemberAccess = "none" # TODO fix one by one
|
reportOptionalMemberAccess = "none" # TODO fix one by one
|
||||||
reportAssignmentType = "none" # TODO fix one by one
|
reportAssignmentType = "none" # TODO fix one by one
|
||||||
|
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
|
@@ -6,10 +6,15 @@ import types
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium import spaces
|
from gymnasium import spaces
|
||||||
from gymnasium.core import ActType, ObsType
|
from gymnasium.core import ActType, ObsType
|
||||||
from gymnasium.envs.registration import EnvSpec
|
from gymnasium.envs.registration import EnvSpec
|
||||||
|
from gymnasium.vector import VectorEnv
|
||||||
|
from gymnasium.vector.utils import batch_space
|
||||||
|
from gymnasium.vector.vector_env import AutoresetMode
|
||||||
|
|
||||||
|
|
||||||
def basic_reset_func(
|
def basic_reset_func(
|
||||||
@@ -106,3 +111,112 @@ class GenericTestEnv(gym.Env):
|
|||||||
def render(self):
|
def render(self):
|
||||||
"""Renders the environment."""
|
"""Renders the environment."""
|
||||||
raise NotImplementedError("testingEnv render_fn is not set.")
|
raise NotImplementedError("testingEnv render_fn is not set.")
|
||||||
|
|
||||||
|
|
||||||
|
def basic_vector_reset_func(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
seed: int | None = None,
|
||||||
|
options: dict | None = None,
|
||||||
|
) -> tuple[ObsType, dict]:
|
||||||
|
"""A basic reset function that will pass the environment check using random actions from the observation space."""
|
||||||
|
super(GenericTestVectorEnv, self).reset(seed=seed)
|
||||||
|
self.observation_space.seed(self.np_random_seed)
|
||||||
|
return self.observation_space.sample(), {"options": options}
|
||||||
|
|
||||||
|
|
||||||
|
def basic_vector_step_func(
|
||||||
|
self, action: ActType
|
||||||
|
) -> tuple[ObsType, np.ndarray, np.ndarray, np.ndarray, dict]:
|
||||||
|
"""A step function that follows the basic step api that will pass the environment check using random actions from the observation space."""
|
||||||
|
obs = self.observation_space.sample()
|
||||||
|
rewards = np.zeros(self.num_envs, dtype=np.float64)
|
||||||
|
terminations = np.zeros(self.num_envs, dtype=np.bool_)
|
||||||
|
truncations = np.zeros(self.num_envs, dtype=np.bool_)
|
||||||
|
return obs, rewards, terminations, truncations, {}
|
||||||
|
|
||||||
|
|
||||||
|
def basic_vector_render_func(self):
|
||||||
|
"""Basic render fn that does nothing."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class GenericTestVectorEnv(VectorEnv):
|
||||||
|
"""A generic testing vector environment similar to GenericTestEnv.
|
||||||
|
|
||||||
|
Some tests cannot use SyncVectorEnv, e.g. when returning non-numpy arrays in the observations.
|
||||||
|
In these cases, GenericTestVectorEnv can be used to simulate a vector environment.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_envs: int = 1,
|
||||||
|
action_space: spaces.Space = spaces.Box(0, 1, (1,)),
|
||||||
|
observation_space: spaces.Space = spaces.Box(0, 1, (1,)),
|
||||||
|
reset_func: Callable = basic_vector_reset_func,
|
||||||
|
step_func: Callable = basic_vector_step_func,
|
||||||
|
render_func: Callable = basic_vector_render_func,
|
||||||
|
metadata: dict[str, Any] = {
|
||||||
|
"render_modes": [],
|
||||||
|
"autoreset_mode": AutoresetMode.NEXT_STEP,
|
||||||
|
},
|
||||||
|
render_mode: str | None = None,
|
||||||
|
spec: EnvSpec = EnvSpec(
|
||||||
|
"TestingVectorEnv-v0",
|
||||||
|
"tests.testing_env:GenericTestVectorEnv",
|
||||||
|
max_episode_steps=100,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
"""Generic testing vector environment constructor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_envs: The number of environments to create
|
||||||
|
action_space: The environment action space
|
||||||
|
observation_space: The environment observation space
|
||||||
|
reset_func: The environment reset function
|
||||||
|
step_func: The environment step function
|
||||||
|
render_func: The environment render function
|
||||||
|
metadata: The environment metadata
|
||||||
|
render_mode: The render mode of the environment
|
||||||
|
spec: The environment spec
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.num_envs = num_envs
|
||||||
|
self.metadata = metadata
|
||||||
|
self.render_mode = render_mode
|
||||||
|
self.spec = spec
|
||||||
|
|
||||||
|
# Set the single spaces and create batched spaces
|
||||||
|
self.single_observation_space = observation_space
|
||||||
|
self.single_action_space = action_space
|
||||||
|
self.observation_space = batch_space(observation_space, num_envs)
|
||||||
|
self.action_space = batch_space(action_space, num_envs)
|
||||||
|
|
||||||
|
# Bind the functions to the instance
|
||||||
|
if reset_func is not None:
|
||||||
|
self.reset = types.MethodType(reset_func, self)
|
||||||
|
if step_func is not None:
|
||||||
|
self.step = types.MethodType(step_func, self)
|
||||||
|
if render_func is not None:
|
||||||
|
self.render = types.MethodType(render_func, self)
|
||||||
|
|
||||||
|
def reset(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
seed: int | None = None,
|
||||||
|
options: dict | None = None,
|
||||||
|
) -> tuple[ObsType, dict]:
|
||||||
|
"""Resets the environment."""
|
||||||
|
# If you need a default working reset function, use `basic_vector_reset_fn` above
|
||||||
|
raise NotImplementedError("TestingVectorEnv reset_fn is not set.")
|
||||||
|
|
||||||
|
def step(
|
||||||
|
self, action: ActType
|
||||||
|
) -> tuple[ObsType, np.ndarray, np.ndarray, np.ndarray, dict]:
|
||||||
|
"""Steps through the environment."""
|
||||||
|
raise NotImplementedError("TestingVectorEnv step_fn is not set.")
|
||||||
|
|
||||||
|
def render(self) -> tuple[Any, ...] | None:
|
||||||
|
"""Renders the environment."""
|
||||||
|
raise NotImplementedError("TestingVectorEnv render_fn is not set.")
|
||||||
|
250
tests/wrappers/test_array_conversion.py
Normal file
250
tests/wrappers/test_array_conversion.py
Normal file
@@ -0,0 +1,250 @@
|
|||||||
|
"""Test suite for ArrayConversion wrapper."""
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import itertools
|
||||||
|
import pickle
|
||||||
|
from typing import Any, NamedTuple
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import gymnasium
|
||||||
|
|
||||||
|
|
||||||
|
array_api_compat = pytest.importorskip("array_api_compat")
|
||||||
|
array_api_extra = pytest.importorskip("array_api_extra")
|
||||||
|
|
||||||
|
from array_api_compat import array_namespace, is_array_api_obj # noqa: E402
|
||||||
|
|
||||||
|
from gymnasium.wrappers.array_conversion import ( # noqa: E402
|
||||||
|
ArrayConversion,
|
||||||
|
array_conversion,
|
||||||
|
module_namespace,
|
||||||
|
)
|
||||||
|
from tests.testing_env import GenericTestEnv # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
# Define available modules
|
||||||
|
installed_modules = []
|
||||||
|
array_api_modules = [
|
||||||
|
"numpy",
|
||||||
|
"jax.numpy",
|
||||||
|
"torch",
|
||||||
|
"cupy",
|
||||||
|
"dask.array",
|
||||||
|
"sparse",
|
||||||
|
"array_api_strict",
|
||||||
|
]
|
||||||
|
for module in array_api_modules:
|
||||||
|
try:
|
||||||
|
installed_modules.append(importlib.import_module(module))
|
||||||
|
except ImportError:
|
||||||
|
pass # Modules that are not installed are skipped
|
||||||
|
|
||||||
|
installed_modules_combinations = list(itertools.permutations(installed_modules, 2))
|
||||||
|
|
||||||
|
|
||||||
|
def xp_data_equivalence(data_1, data_2) -> bool:
|
||||||
|
"""Return if two variables are equivalent that might contain ``torch.Tensor``."""
|
||||||
|
if type(data_1) is type(data_2):
|
||||||
|
if isinstance(data_1, dict):
|
||||||
|
return data_1.keys() == data_2.keys() and all(
|
||||||
|
xp_data_equivalence(data_1[k], data_2[k]) for k in data_1.keys()
|
||||||
|
)
|
||||||
|
elif isinstance(data_1, (tuple, list)):
|
||||||
|
return len(data_1) == len(data_2) and all(
|
||||||
|
xp_data_equivalence(o_1, o_2) for o_1, o_2 in zip(data_1, data_2)
|
||||||
|
)
|
||||||
|
elif is_array_api_obj(data_1):
|
||||||
|
return array_api_extra.isclose(data_1, data_2, atol=0.00001).all()
|
||||||
|
else:
|
||||||
|
return data_1 == data_2
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class ExampleNamedTuple(NamedTuple):
|
||||||
|
a: Any # Array API compatible object. Does not have proper typing support yet.
|
||||||
|
b: Any # Same as a
|
||||||
|
|
||||||
|
|
||||||
|
def _supports_higher_precision(xp, low_type, high_type):
|
||||||
|
"""Check if an array module supports higher precision type."""
|
||||||
|
return xp.result_type(low_type, high_type) == high_type
|
||||||
|
|
||||||
|
|
||||||
|
# When converting between array modules (source → target → source), we need to ensure that the
|
||||||
|
# precision used is supported by both modules. If either module only supports 32-bit types, we must
|
||||||
|
# use the lower precision to account for the conversion during the roundtrip.
|
||||||
|
def atleast_float32(source_xp, target_xp):
|
||||||
|
"""Return source_xp.float64 if both modules support it, otherwise source_xp.float32."""
|
||||||
|
source_supports_64 = _supports_higher_precision(
|
||||||
|
source_xp, source_xp.float32, source_xp.float64
|
||||||
|
)
|
||||||
|
target_supports_64 = _supports_higher_precision(
|
||||||
|
target_xp, target_xp.float32, target_xp.float64
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
source_xp.float64
|
||||||
|
if (source_supports_64 and target_supports_64)
|
||||||
|
else source_xp.float32
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def atleast_int32(source_xp, target_xp):
|
||||||
|
"""Return source_xp.int64 if both modules support it, otherwise source_xp.int32."""
|
||||||
|
source_supports_64 = _supports_higher_precision(
|
||||||
|
source_xp, source_xp.int32, source_xp.int64
|
||||||
|
)
|
||||||
|
target_supports_64 = _supports_higher_precision(
|
||||||
|
target_xp, target_xp.int32, target_xp.int64
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
source_xp.int64
|
||||||
|
if (source_supports_64 and target_supports_64)
|
||||||
|
else source_xp.int32
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def value_parametrization():
|
||||||
|
for source_xp, target_xp in installed_modules_combinations:
|
||||||
|
xp = module_namespace(source_xp)
|
||||||
|
source_xp = module_namespace(source_xp)
|
||||||
|
target_xp = module_namespace(target_xp)
|
||||||
|
for value, expected_value in [
|
||||||
|
(2, xp.asarray(2, dtype=atleast_int32(source_xp, target_xp))),
|
||||||
|
(
|
||||||
|
(3.0, 4),
|
||||||
|
(
|
||||||
|
xp.asarray(3.0, dtype=atleast_float32(source_xp, target_xp)),
|
||||||
|
xp.asarray(4, dtype=atleast_int32(source_xp, target_xp)),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
[3.0, 4],
|
||||||
|
[
|
||||||
|
xp.asarray(3.0, dtype=atleast_float32(source_xp, target_xp)),
|
||||||
|
xp.asarray(4, dtype=atleast_int32(source_xp, target_xp)),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{
|
||||||
|
"a": 6.0,
|
||||||
|
"b": 7,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"a": xp.asarray(6.0, dtype=atleast_float32(source_xp, target_xp)),
|
||||||
|
"b": xp.asarray(7, dtype=atleast_int32(source_xp, target_xp)),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
(xp.asarray(1.0, dtype=xp.float32), xp.asarray(1.0, dtype=xp.float32)),
|
||||||
|
(xp.asarray(1.0, dtype=xp.uint8), xp.asarray(1.0, dtype=xp.uint8)),
|
||||||
|
(xp.asarray([1, 2], dtype=xp.int32), xp.asarray([1, 2], dtype=xp.int32)),
|
||||||
|
(
|
||||||
|
xp.asarray([[1.0], [2.0]], dtype=xp.int32),
|
||||||
|
xp.asarray([[1.0], [2.0]], dtype=xp.int32),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{
|
||||||
|
"a": (
|
||||||
|
1,
|
||||||
|
xp.asarray(2.0, dtype=xp.float32),
|
||||||
|
xp.asarray([3, 4], dtype=xp.int32),
|
||||||
|
),
|
||||||
|
"b": {"c": 5},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"a": (
|
||||||
|
xp.asarray(1, dtype=atleast_int32(source_xp, target_xp)),
|
||||||
|
xp.asarray(2.0, dtype=xp.float32),
|
||||||
|
xp.asarray([3, 4], dtype=xp.int32),
|
||||||
|
),
|
||||||
|
"b": {
|
||||||
|
"c": xp.asarray(5, dtype=atleast_int32(source_xp, target_xp))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
ExampleNamedTuple(
|
||||||
|
a=xp.asarray([1, 2], dtype=xp.int32),
|
||||||
|
b=xp.asarray([1.0, 2.0], dtype=xp.float32),
|
||||||
|
),
|
||||||
|
ExampleNamedTuple(
|
||||||
|
a=xp.asarray([1, 2], dtype=xp.int32),
|
||||||
|
b=xp.asarray([1.0, 2.0], dtype=xp.float32),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(None, None),
|
||||||
|
]:
|
||||||
|
yield (source_xp, target_xp, value, expected_value)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"source_xp,target_xp,value,expected_value", value_parametrization()
|
||||||
|
)
|
||||||
|
def test_roundtripping(source_xp, target_xp, value, expected_value):
|
||||||
|
"""Test roundtripping between different Array API compatible frameworks."""
|
||||||
|
roundtripped_value = array_conversion(
|
||||||
|
array_conversion(value, xp=target_xp), xp=source_xp
|
||||||
|
)
|
||||||
|
assert xp_data_equivalence(roundtripped_value, expected_value)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("env_xp, target_xp", installed_modules_combinations)
|
||||||
|
def test_array_conversion_wrapper(env_xp, target_xp):
|
||||||
|
# Define reset and step functions without partial to avoid pickling issues
|
||||||
|
|
||||||
|
def reset_func(self, seed=None, options=None):
|
||||||
|
"""A generic array API reset function."""
|
||||||
|
return env_xp.asarray([1.0, 2.0, 3.0]), {"data": env_xp.asarray([1, 2, 3])}
|
||||||
|
|
||||||
|
def step_func(self, action):
|
||||||
|
"""A generic array API step function."""
|
||||||
|
assert isinstance(action, type(env_xp.zeros(1)))
|
||||||
|
return (
|
||||||
|
env_xp.asarray([1, 2, 3]),
|
||||||
|
env_xp.asarray(5.0),
|
||||||
|
env_xp.asarray(True),
|
||||||
|
env_xp.asarray(False),
|
||||||
|
{"data": env_xp.asarray([1.0, 2.0])},
|
||||||
|
)
|
||||||
|
|
||||||
|
env = GenericTestEnv(reset_func=reset_func, step_func=step_func)
|
||||||
|
|
||||||
|
# Check that the reset and step for env_xp environment are as expected
|
||||||
|
obs, info = env.reset()
|
||||||
|
# env_xp is automatically converted to the compatible namespace by array_namespace, so we need
|
||||||
|
# to check against the compatible namespace of env_xp in array_api_compat
|
||||||
|
env_xp_compat = module_namespace(env_xp)
|
||||||
|
assert array_namespace(obs) is env_xp_compat
|
||||||
|
assert isinstance(info, dict) and array_namespace(info["data"]) is env_xp_compat
|
||||||
|
|
||||||
|
obs, reward, terminated, truncated, info = env.step(env_xp_compat.asarray([1, 2]))
|
||||||
|
assert array_namespace(obs) is env_xp_compat
|
||||||
|
assert array_namespace(reward) is env_xp_compat
|
||||||
|
assert array_namespace(terminated) is env_xp_compat
|
||||||
|
assert array_namespace(truncated) is env_xp_compat
|
||||||
|
assert isinstance(info, dict) and array_namespace(info["data"]) is env_xp_compat
|
||||||
|
|
||||||
|
# Check that the wrapped version is correct.
|
||||||
|
target_xp_compat = module_namespace(target_xp)
|
||||||
|
wrapped_env = ArrayConversion(env, env_xp=env_xp, target_xp=target_xp)
|
||||||
|
obs, info = wrapped_env.reset()
|
||||||
|
assert array_namespace(obs) is target_xp_compat
|
||||||
|
assert isinstance(info, dict) and array_namespace(info["data"]) is target_xp_compat
|
||||||
|
|
||||||
|
action = target_xp.asarray([1, 2], dtype=target_xp.int32)
|
||||||
|
obs, reward, terminated, truncated, info = wrapped_env.step(action)
|
||||||
|
assert array_namespace(obs) is target_xp_compat
|
||||||
|
assert isinstance(reward, float)
|
||||||
|
assert isinstance(terminated, bool) and isinstance(truncated, bool)
|
||||||
|
assert isinstance(info, dict) and array_namespace(info["data"]) is target_xp_compat
|
||||||
|
|
||||||
|
# Check that the wrapped environment can render. This implicitly returns None and requires a
|
||||||
|
# None -> None conversion
|
||||||
|
wrapped_env.render()
|
||||||
|
|
||||||
|
# Test that the wrapped environment can be pickled
|
||||||
|
env = gymnasium.make("CartPole-v1", disable_env_checker=True)
|
||||||
|
wrapped_env = ArrayConversion(env, env_xp=env_xp, target_xp=target_xp)
|
||||||
|
pkl = pickle.dumps(wrapped_env)
|
||||||
|
pickle.loads(pkl)
|
@@ -1,10 +1,13 @@
|
|||||||
"""Test suite for JaxToNumpy wrapper."""
|
"""Test suite for JaxToNumpy wrapper."""
|
||||||
|
|
||||||
|
import pickle
|
||||||
from typing import NamedTuple
|
from typing import NamedTuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
import gymnasium
|
||||||
|
|
||||||
|
|
||||||
jax = pytest.importorskip("jax")
|
jax = pytest.importorskip("jax")
|
||||||
jnp = pytest.importorskip("jax.numpy")
|
jnp = pytest.importorskip("jax.numpy")
|
||||||
@@ -133,3 +136,9 @@ def test_jax_to_numpy_wrapper():
|
|||||||
# Check that the wrapped environment can render. This implicitly returns None and requires a
|
# Check that the wrapped environment can render. This implicitly returns None and requires a
|
||||||
# None -> None conversion
|
# None -> None conversion
|
||||||
numpy_env.render()
|
numpy_env.render()
|
||||||
|
|
||||||
|
# Test that the wrapped environment can be pickled
|
||||||
|
env = gymnasium.make("CartPole-v1", disable_env_checker=True)
|
||||||
|
wrapped_env = JaxToNumpy(env)
|
||||||
|
pkl = pickle.dumps(wrapped_env)
|
||||||
|
pickle.loads(pkl)
|
||||||
|
@@ -1,9 +1,12 @@
|
|||||||
"""Test suite for TorchToJax wrapper."""
|
"""Test suite for TorchToJax wrapper."""
|
||||||
|
|
||||||
|
import pickle
|
||||||
from typing import NamedTuple
|
from typing import NamedTuple
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
import gymnasium
|
||||||
|
|
||||||
|
|
||||||
jax = pytest.importorskip("jax")
|
jax = pytest.importorskip("jax")
|
||||||
jnp = pytest.importorskip("jax.numpy")
|
jnp = pytest.importorskip("jax.numpy")
|
||||||
@@ -148,3 +151,9 @@ def test_jax_to_torch_wrapper():
|
|||||||
# Check that the wrapped environment can render. This implicitly returns None and requires a
|
# Check that the wrapped environment can render. This implicitly returns None and requires a
|
||||||
# None -> None conversion
|
# None -> None conversion
|
||||||
wrapped_env.render()
|
wrapped_env.render()
|
||||||
|
|
||||||
|
# Test that the wrapped environment can be pickled
|
||||||
|
env = gymnasium.make("CartPole-v1", disable_env_checker=True)
|
||||||
|
wrapped_env = JaxToTorch(env)
|
||||||
|
pkl = pickle.dumps(wrapped_env)
|
||||||
|
pickle.loads(pkl)
|
||||||
|
@@ -1,10 +1,13 @@
|
|||||||
"""Test suite for NumPyToTorch wrapper."""
|
"""Test suite for NumPyToTorch wrapper."""
|
||||||
|
|
||||||
|
import pickle
|
||||||
from typing import NamedTuple
|
from typing import NamedTuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
import gymnasium
|
||||||
|
|
||||||
|
|
||||||
torch = pytest.importorskip("torch")
|
torch = pytest.importorskip("torch")
|
||||||
|
|
||||||
@@ -128,3 +131,9 @@ def test_numpy_to_torch():
|
|||||||
# Check that the wrapped environment can render. This implicitly returns None and requires a
|
# Check that the wrapped environment can render. This implicitly returns None and requires a
|
||||||
# None -> None conversion
|
# None -> None conversion
|
||||||
torch_env.render()
|
torch_env.render()
|
||||||
|
|
||||||
|
# Test that the wrapped environment can be pickled
|
||||||
|
env = gymnasium.make("CartPole-v1", disable_env_checker=True)
|
||||||
|
wrapped_env = NumpyToTorch(env)
|
||||||
|
pkl = pickle.dumps(wrapped_env)
|
||||||
|
pickle.loads(pkl)
|
||||||
|
146
tests/wrappers/vector/test_array_conversion.py
Normal file
146
tests/wrappers/vector/test_array_conversion.py
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
"""Test suite for vector ArrayConversion wrapper."""
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import itertools
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tests.testing_env import GenericTestVectorEnv
|
||||||
|
|
||||||
|
|
||||||
|
array_api_compat = pytest.importorskip("array_api_compat")
|
||||||
|
from array_api_compat import array_namespace # noqa: E402
|
||||||
|
|
||||||
|
from gymnasium.wrappers.array_conversion import module_namespace # noqa: E402
|
||||||
|
from gymnasium.wrappers.vector.array_conversion import ArrayConversion # noqa: E402
|
||||||
|
from gymnasium.wrappers.vector.jax_to_numpy import JaxToNumpy # noqa: E402
|
||||||
|
from gymnasium.wrappers.vector.jax_to_torch import JaxToTorch # noqa: E402
|
||||||
|
from gymnasium.wrappers.vector.numpy_to_torch import NumpyToTorch # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
# Define available modules
|
||||||
|
installed_modules = []
|
||||||
|
array_api_modules = [
|
||||||
|
"numpy",
|
||||||
|
"jax.numpy",
|
||||||
|
"torch",
|
||||||
|
"cupy",
|
||||||
|
"dask.array",
|
||||||
|
"sparse",
|
||||||
|
"array_api_strict",
|
||||||
|
]
|
||||||
|
for module in array_api_modules:
|
||||||
|
try:
|
||||||
|
installed_modules.append(importlib.import_module(module))
|
||||||
|
except ImportError:
|
||||||
|
pass # Modules that are not installed are skipped
|
||||||
|
|
||||||
|
installed_modules_combinations = list(itertools.permutations(installed_modules, 2))
|
||||||
|
|
||||||
|
|
||||||
|
def create_vector_env(env_xp):
|
||||||
|
_reset_func = partial(reset_func, num_envs=3, xp=env_xp)
|
||||||
|
_step_func = partial(step_func, num_envs=3, xp=env_xp)
|
||||||
|
return GenericTestVectorEnv(reset_func=_reset_func, step_func=_step_func)
|
||||||
|
|
||||||
|
|
||||||
|
def reset_func(self, seed=None, options=None, num_envs: int = 1, xp=np):
|
||||||
|
return xp.asarray([[1.0, 2.0, 3.0] * num_envs]), {
|
||||||
|
"data": xp.asarray([[1, 2, 3] * num_envs])
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def step_func(self, action, num_envs: int = 1, xp=np):
|
||||||
|
assert isinstance(action, type(xp.zeros(1)))
|
||||||
|
return (
|
||||||
|
xp.asarray([[1, 2, 3] * num_envs]),
|
||||||
|
xp.asarray([5.0] * num_envs),
|
||||||
|
xp.asarray([False] * num_envs),
|
||||||
|
xp.asarray([False] * num_envs),
|
||||||
|
{"data": xp.asarray([[1.0, 2.0] * num_envs])},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("env_xp, target_xp", installed_modules_combinations)
|
||||||
|
def test_array_conversion_wrapper(env_xp, target_xp):
|
||||||
|
env_xp_compat = module_namespace(env_xp)
|
||||||
|
env = create_vector_env(env_xp_compat)
|
||||||
|
|
||||||
|
# Check that the reset and step for env_xp environment are as expected
|
||||||
|
obs, info = env.reset()
|
||||||
|
# env_xp is automatically converted to the compatible namespace by array_namespace, so we need
|
||||||
|
# to check against the compatible namespace of env_xp in array_api_compat
|
||||||
|
assert array_namespace(obs) is env_xp_compat
|
||||||
|
assert isinstance(info, dict) and array_namespace(info["data"]) is env_xp_compat
|
||||||
|
|
||||||
|
obs, reward, terminated, truncated, info = env.step(env_xp_compat.asarray([1, 2]))
|
||||||
|
assert array_namespace(obs) is env_xp_compat
|
||||||
|
assert array_namespace(reward) is env_xp_compat
|
||||||
|
assert array_namespace(terminated) is env_xp_compat
|
||||||
|
assert array_namespace(truncated) is env_xp_compat
|
||||||
|
assert isinstance(info, dict) and array_namespace(info["data"]) is env_xp_compat
|
||||||
|
|
||||||
|
# Check that the wrapped version is correct.
|
||||||
|
target_xp_compat = module_namespace(target_xp)
|
||||||
|
wrapped_env = ArrayConversion(env, env_xp=env_xp, target_xp=target_xp)
|
||||||
|
obs, info = wrapped_env.reset()
|
||||||
|
assert array_namespace(obs) is target_xp_compat
|
||||||
|
assert isinstance(info, dict) and array_namespace(info["data"]) is target_xp_compat
|
||||||
|
|
||||||
|
action = target_xp.asarray([1, 2], dtype=target_xp.int32)
|
||||||
|
obs, reward, terminated, truncated, info = wrapped_env.step(action)
|
||||||
|
assert array_namespace(obs) is target_xp_compat
|
||||||
|
assert array_namespace(reward) is target_xp_compat
|
||||||
|
assert array_namespace(terminated) is target_xp_compat
|
||||||
|
assert terminated.dtype == target_xp.bool
|
||||||
|
assert array_namespace(truncated) is target_xp_compat
|
||||||
|
assert truncated.dtype == target_xp.bool
|
||||||
|
assert isinstance(info, dict) and array_namespace(info["data"]) is target_xp_compat
|
||||||
|
|
||||||
|
# Check that the wrapped environment can render. This implicitly returns None and requires a
|
||||||
|
# None -> None conversion
|
||||||
|
wrapped_env.render()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("wrapper", [JaxToNumpy, JaxToTorch, NumpyToTorch])
|
||||||
|
def test_specialized_wrappers(wrapper: type[JaxToNumpy | JaxToTorch | NumpyToTorch]):
|
||||||
|
if wrapper is JaxToNumpy:
|
||||||
|
jax = pytest.importorskip("jax")
|
||||||
|
env_xp, target_xp = jax.numpy, np
|
||||||
|
elif wrapper is JaxToTorch:
|
||||||
|
jax = pytest.importorskip("jax")
|
||||||
|
torch = pytest.importorskip("torch")
|
||||||
|
env_xp, target_xp = jax.numpy, torch
|
||||||
|
elif wrapper is NumpyToTorch:
|
||||||
|
torch = pytest.importorskip("torch")
|
||||||
|
env_xp, target_xp = np, torch
|
||||||
|
else:
|
||||||
|
raise TypeError(f"Unknown specialized conversion wrapper {type(wrapper)}")
|
||||||
|
env_xp_compat = module_namespace(env_xp)
|
||||||
|
target_xp_compat = module_namespace(target_xp)
|
||||||
|
|
||||||
|
# The unwrapped test env sanity check is already covered by test_array_conversion_wrapper for
|
||||||
|
# all known frameworks, including the specialized ones.
|
||||||
|
env = create_vector_env(env_xp_compat)
|
||||||
|
|
||||||
|
# Check that the wrapped version is correct.
|
||||||
|
wrapped_env = wrapper(env)
|
||||||
|
obs, info = wrapped_env.reset()
|
||||||
|
assert array_namespace(obs) is target_xp_compat
|
||||||
|
assert isinstance(info, dict) and array_namespace(info["data"]) is target_xp_compat
|
||||||
|
|
||||||
|
action = target_xp.asarray([1, 2], dtype=target_xp.int32)
|
||||||
|
obs, reward, terminated, truncated, info = wrapped_env.step(action)
|
||||||
|
assert array_namespace(obs) is target_xp_compat
|
||||||
|
assert array_namespace(reward) is target_xp_compat
|
||||||
|
assert array_namespace(terminated) is target_xp_compat
|
||||||
|
assert terminated.dtype == target_xp.bool
|
||||||
|
assert array_namespace(truncated) is target_xp_compat
|
||||||
|
assert truncated.dtype == target_xp.bool
|
||||||
|
assert isinstance(info, dict) and array_namespace(info["data"]) is target_xp_compat
|
||||||
|
|
||||||
|
# Check that the wrapped environment can render. This implicitly returns None and requires a
|
||||||
|
# None -> None conversion
|
||||||
|
wrapped_env.render()
|
Reference in New Issue
Block a user