mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-07-31 13:54:31 +00:00
Adds wrappers for jax environments converting to numpy and pytorch (#168)
Co-authored-by: Justin Deutsch <djustin8@vt.edu> Co-authored-by: Gianluca De Cola <42657588+gianlucadecola@users.noreply.github.com>
This commit is contained in:
@@ -175,21 +175,21 @@ Gymnasium already contains a large collection of wrappers, but we believe that t
|
|||||||
* - :class:`wrappers.EnvCompatibility`
|
* - :class:`wrappers.EnvCompatibility`
|
||||||
- Moved to `shimmy <https://github.com/Farama-Foundation/Shimmy/blob/main/shimmy/openai_gym_compatibility.py>`_
|
- Moved to `shimmy <https://github.com/Farama-Foundation/Shimmy/blob/main/shimmy/openai_gym_compatibility.py>`_
|
||||||
- Not Implemented
|
- Not Implemented
|
||||||
* - :class:`RecordEpisodeStatistics`
|
* - :class:`wrappers.RecordEpisodeStatistics`
|
||||||
- RecordEpisodeStatistics
|
- RecordEpisodeStatistics
|
||||||
- VectorRecordEpisodeStatistics
|
- VectorRecordEpisodeStatistics
|
||||||
* - :class:`RenderCollection`
|
* - :class:`wrappers.RenderCollection`
|
||||||
- RenderCollection
|
- RenderCollection
|
||||||
- VectorRenderCollection
|
- VectorRenderCollection
|
||||||
* - :class:`HumanRendering`
|
* - :class:`wrappers.HumanRendering`
|
||||||
- HumanRendering
|
- HumanRendering
|
||||||
- Not Implemented
|
- Not Implemented
|
||||||
* - Not Implemented
|
* - Not Implemented
|
||||||
- JaxToNumpy
|
- :class:`experimental.wrappers.JaxToNumpy`
|
||||||
- VectorJaxToNumpy
|
- VectorJaxToNumpy (*)
|
||||||
* - Not Implemented
|
* - Not Implemented
|
||||||
- JaxToTorch
|
- :class:`experimental.wrappers.JaxToTorch`
|
||||||
- VectorJaxToTorch
|
- VectorJaxToTorch (*)
|
||||||
```
|
```
|
||||||
|
|
||||||
### Vector Only Wrappers
|
### Vector Only Wrappers
|
||||||
|
@@ -12,6 +12,8 @@ from gymnasium.experimental.wrappers.lambda_action import (
|
|||||||
)
|
)
|
||||||
from gymnasium.experimental.wrappers.lambda_observations import LambdaObservationV0
|
from gymnasium.experimental.wrappers.lambda_observations import LambdaObservationV0
|
||||||
from gymnasium.experimental.wrappers.lambda_reward import ClipRewardV0, LambdaRewardV0
|
from gymnasium.experimental.wrappers.lambda_reward import ClipRewardV0, LambdaRewardV0
|
||||||
|
from gymnasium.experimental.wrappers.numpy_to_jax import JaxToNumpyV0
|
||||||
|
from gymnasium.experimental.wrappers.torch_to_jax import JaxToTorchV0
|
||||||
from gymnasium.experimental.wrappers.sticky_action import StickyActionV0
|
from gymnasium.experimental.wrappers.sticky_action import StickyActionV0
|
||||||
from gymnasium.experimental.wrappers.time_aware_observation import (
|
from gymnasium.experimental.wrappers.time_aware_observation import (
|
||||||
TimeAwareObservationV0,
|
TimeAwareObservationV0,
|
||||||
@@ -32,4 +34,7 @@ __all__ = [
|
|||||||
# Lambda Reward
|
# Lambda Reward
|
||||||
"LambdaRewardV0",
|
"LambdaRewardV0",
|
||||||
"ClipRewardV0",
|
"ClipRewardV0",
|
||||||
|
# Jax conversion wrappers
|
||||||
|
"JaxToNumpyV0",
|
||||||
|
"JaxToTorchV0",
|
||||||
]
|
]
|
||||||
|
134
gymnasium/experimental/wrappers/numpy_to_jax.py
Normal file
134
gymnasium/experimental/wrappers/numpy_to_jax.py
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
"""Helper functions and wrapper class for converting between numpy and Jax."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import functools
|
||||||
|
import numbers
|
||||||
|
from collections import abc
|
||||||
|
from typing import Any, Iterable, Mapping, SupportsFloat
|
||||||
|
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from gymnasium import Env, Wrapper
|
||||||
|
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
|
||||||
|
|
||||||
|
|
||||||
|
@functools.singledispatch
|
||||||
|
def numpy_to_jax(value: Any) -> Any:
|
||||||
|
"""Converts a value to a Jax DeviceArray."""
|
||||||
|
raise Exception(
|
||||||
|
f"No conversion for Numpy to Jax registered for type: {type(value)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@numpy_to_jax.register(numbers.Number)
|
||||||
|
@numpy_to_jax.register(np.ndarray)
|
||||||
|
def _number_ndarray_numpy_to_jax(value: np.ndarray | numbers.Number) -> jnp.DeviceArray:
|
||||||
|
"""Converts a numpy array or number (int, float, etc.) to a Jax DeviceArray."""
|
||||||
|
return jnp.array(value)
|
||||||
|
|
||||||
|
|
||||||
|
@numpy_to_jax.register(abc.Mapping)
|
||||||
|
def _mapping_numpy_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||||
|
"""Converts a dictionary of numpy arrays to a mapping of Jax DeviceArrays."""
|
||||||
|
return type(value)(**{k: numpy_to_jax(v) for k, v in value.items()})
|
||||||
|
|
||||||
|
|
||||||
|
@numpy_to_jax.register(abc.Iterable)
|
||||||
|
def _iterable_numpy_to_jax(
|
||||||
|
value: Iterable[np.ndarray | Any],
|
||||||
|
) -> Iterable[jnp.DeviceArray | Any]:
|
||||||
|
"""Converts an Iterable from Numpy Arrays to an iterable of Jax DeviceArrays."""
|
||||||
|
return type(value)(numpy_to_jax(v) for v in value)
|
||||||
|
|
||||||
|
|
||||||
|
@functools.singledispatch
|
||||||
|
def jax_to_numpy(value: Any) -> Any:
|
||||||
|
"""Converts a value to a numpy array."""
|
||||||
|
raise Exception(
|
||||||
|
f"No conversion for Jax to Numpy registered for type: {type(value)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@jax_to_numpy.register(jnp.DeviceArray)
|
||||||
|
def _devicearray_jax_to_numpy(value: jnp.DeviceArray) -> np.ndarray:
|
||||||
|
"""Converts a Jax DeviceArray to a numpy array."""
|
||||||
|
return np.array(value)
|
||||||
|
|
||||||
|
|
||||||
|
@jax_to_numpy.register(abc.Mapping)
|
||||||
|
def _mapping_jax_to_numpy(
|
||||||
|
value: Mapping[str, jnp.DeviceArray | Any]
|
||||||
|
) -> Mapping[str, np.ndarray | Any]:
|
||||||
|
"""Converts a dictionary of Jax DeviceArrays to a mapping of numpy arrays."""
|
||||||
|
return type(value)(**{k: jax_to_numpy(v) for k, v in value.items()})
|
||||||
|
|
||||||
|
|
||||||
|
@jax_to_numpy.register(abc.Iterable)
|
||||||
|
def _iterable_jax_to_numpy(
|
||||||
|
value: Iterable[np.ndarray | Any],
|
||||||
|
) -> Iterable[jnp.DeviceArray | Any]:
|
||||||
|
"""Converts an Iterable from Numpy arrays to an iterable of Jax DeviceArrays."""
|
||||||
|
return type(value)(jax_to_numpy(v) for v in value)
|
||||||
|
|
||||||
|
|
||||||
|
class JaxToNumpyV0(Wrapper):
|
||||||
|
"""Wraps a jax environment so that it can be interacted with through numpy arrays.
|
||||||
|
|
||||||
|
Actions must be provided as numpy arrays and observations will be returned as numpy arrays.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
The Jax To Numpy and Numpy to Jax conversion does not guarantee a roundtrip (jax -> numpy -> jax) and vice versa.
|
||||||
|
The reason for this is jax does not support non-array values, therefore numpy ``int_32(5) -> DeviceArray([5], dtype=jnp.int23)``
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, env: Env):
|
||||||
|
"""Wraps an environment such that the input and outputs are numpy arrays.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env: the environment to wrap
|
||||||
|
"""
|
||||||
|
super().__init__(env)
|
||||||
|
|
||||||
|
def step(
|
||||||
|
self, action: WrapperActType
|
||||||
|
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
|
||||||
|
"""Transforms the action to a jax array .
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action: the action to perform as a numpy array
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple containing the next observation, reward, termination, truncation, and extra info.
|
||||||
|
"""
|
||||||
|
jax_action = numpy_to_jax(action)
|
||||||
|
obs, reward, terminated, truncated, info = self.env.step(jax_action)
|
||||||
|
|
||||||
|
return (
|
||||||
|
jax_to_numpy(obs),
|
||||||
|
float(reward),
|
||||||
|
bool(terminated),
|
||||||
|
bool(truncated),
|
||||||
|
jax_to_numpy(info),
|
||||||
|
)
|
||||||
|
|
||||||
|
def reset(
|
||||||
|
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
||||||
|
) -> tuple[WrapperObsType, dict[str, Any]]:
|
||||||
|
"""Resets the environment returning numpy-based observation and info.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seed: The seed for resetting the environment
|
||||||
|
options: The options for resetting the environment, these are converted to jax arrays.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Numpy-based observations and info
|
||||||
|
"""
|
||||||
|
if options:
|
||||||
|
options = numpy_to_jax(options)
|
||||||
|
|
||||||
|
return jax_to_numpy(self.env.reset(seed=seed, options=options))
|
||||||
|
|
||||||
|
def render(self) -> RenderFrame | list[RenderFrame] | None:
|
||||||
|
"""Returns the rendered frames as a numpy array."""
|
||||||
|
return jax_to_numpy(self.env.render())
|
163
gymnasium/experimental/wrappers/torch_to_jax.py
Normal file
163
gymnasium/experimental/wrappers/torch_to_jax.py
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
# This wrapper will convert torch inputs for the actions and observations to Jax arrays
|
||||||
|
# for an underlying Jax environment then convert the return observations from Jax arrays
|
||||||
|
# back to torch tensors.
|
||||||
|
#
|
||||||
|
# Functionality for converting between torch and jax types originally copied from
|
||||||
|
# https://github.com/google/brax/blob/9d6b7ced2a13da0d074b5e9fbd3aad8311e26997/brax/io/torch.py
|
||||||
|
# Under the Apache 2.0 license. Copyright is held by the authors
|
||||||
|
|
||||||
|
"""Helper functions and wrapper class for converting between PyTorch and Jax."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import functools
|
||||||
|
import numbers
|
||||||
|
from collections import abc
|
||||||
|
from typing import Any, Iterable, Mapping, SupportsFloat, Union
|
||||||
|
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from jax import dlpack as jax_dlpack
|
||||||
|
|
||||||
|
from gymnasium import Env, Wrapper
|
||||||
|
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
|
||||||
|
from gymnasium.error import DependencyNotInstalled
|
||||||
|
from gymnasium.experimental.wrappers.numpy_to_jax import jax_to_numpy
|
||||||
|
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
from torch.utils import dlpack as torch_dlpack
|
||||||
|
except ImportError:
|
||||||
|
raise DependencyNotInstalled("torch is not installed, run `pip install torch`")
|
||||||
|
|
||||||
|
|
||||||
|
Device = Union[str, torch.device]
|
||||||
|
|
||||||
|
|
||||||
|
@functools.singledispatch
|
||||||
|
def torch_to_jax(value: Any) -> Any:
|
||||||
|
"""Converts a PyTorch Tensor into a Jax DeviceArray."""
|
||||||
|
raise Exception(
|
||||||
|
f"No conversion for PyTorch to Jax registered for type: {type(value)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@torch_to_jax.register(numbers.Number)
|
||||||
|
def _number_torch_to_jax(value: numbers.Number) -> Any:
|
||||||
|
return jnp.array(value)
|
||||||
|
|
||||||
|
|
||||||
|
@torch_to_jax.register(torch.Tensor)
|
||||||
|
def _tensor_torch_to_jax(value: torch.Tensor) -> jnp.DeviceArray:
|
||||||
|
"""Converts a PyTorch Tensor into a Jax DeviceArray."""
|
||||||
|
tensor = torch_dlpack.to_dlpack(value) # pyright: ignore[reportPrivateImportUsage]
|
||||||
|
tensor = jax_dlpack.from_dlpack(tensor) # pyright: ignore[reportPrivateImportUsage]
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
@torch_to_jax.register(abc.Mapping)
|
||||||
|
def _mapping_torch_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||||
|
"""Converts a mapping of PyTorch Tensors into a Dictionary of Jax DeviceArrays."""
|
||||||
|
return type(value)(**{k: torch_to_jax(v) for k, v in value.items()})
|
||||||
|
|
||||||
|
|
||||||
|
@torch_to_jax.register(abc.Iterable)
|
||||||
|
def _iterable_torch_to_jax(value: Iterable[Any]) -> Iterable[Any]:
|
||||||
|
"""Converts an Iterable from PyTorch Tensors to an iterable of Jax DeviceArrays."""
|
||||||
|
return type(value)(torch_to_jax(v) for v in value)
|
||||||
|
|
||||||
|
|
||||||
|
@functools.singledispatch
|
||||||
|
def jax_to_torch(value: Any, device: Device | None = None) -> Any:
|
||||||
|
"""Converts a Jax DeviceArray into a PyTorch Tensor."""
|
||||||
|
raise Exception(
|
||||||
|
f"No conversion for Jax to PyTorch registered for type={type(value)} and device: {device}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@jax_to_torch.register(jnp.DeviceArray)
|
||||||
|
def _devicearray_jax_to_torch(
|
||||||
|
value: jnp.DeviceArray, device: Device | None = None
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Converts a Jax DeviceArray into a PyTorch Tensor."""
|
||||||
|
dlpack = jax_dlpack.to_dlpack(value) # pyright: ignore[reportPrivateImportUsage]
|
||||||
|
tensor = torch_dlpack.from_dlpack(dlpack)
|
||||||
|
if device:
|
||||||
|
return tensor.to(device=device)
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
@jax_to_torch.register(abc.Mapping)
|
||||||
|
def _jax_mapping_to_torch(
|
||||||
|
value: Mapping[str, Any], device: Device | None = None
|
||||||
|
) -> Mapping[str, Any]:
|
||||||
|
"""Converts a mapping of Jax DeviceArrays into a Dictionary of PyTorch Tensors."""
|
||||||
|
return type(value)(**{k: jax_to_torch(v, device) for k, v in value.items()})
|
||||||
|
|
||||||
|
|
||||||
|
@jax_to_torch.register(abc.Iterable)
|
||||||
|
def _jax_iterable_to_torch(
|
||||||
|
value: Iterable[Any], device: Device | None = None
|
||||||
|
) -> Iterable[Any]:
|
||||||
|
"""Converts an Iterable from Jax DeviceArrays to an iterable of PyTorch Tensors."""
|
||||||
|
return type(value)(jax_to_torch(v, device) for v in value)
|
||||||
|
|
||||||
|
|
||||||
|
class JaxToTorchV0(Wrapper):
|
||||||
|
"""Wraps a jax-based environment so that it can be interacted with through PyTorch Tensors.
|
||||||
|
|
||||||
|
Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors.
|
||||||
|
|
||||||
|
For ``rendered`` this is returned as a NumPy array not a pytorch Tensor.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, env: Env, device: Device | None = None):
|
||||||
|
"""Wrapper class to change inputs and outputs of environment to PyTorch tensors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env: The Jax-based environment to wrap
|
||||||
|
device: The device the torch Tensors should be moved to
|
||||||
|
"""
|
||||||
|
super().__init__(env)
|
||||||
|
self.device: Device | None = device
|
||||||
|
|
||||||
|
def step(
|
||||||
|
self, action: WrapperActType
|
||||||
|
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
|
||||||
|
"""Performs the given action within the environment.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action: The action to perform as a PyTorch Tensor
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The next observation, reward, termination, truncation, and extra info
|
||||||
|
"""
|
||||||
|
jax_action = torch_to_jax(action)
|
||||||
|
obs, reward, terminated, truncated, info = self.env.step(jax_action)
|
||||||
|
|
||||||
|
return (
|
||||||
|
jax_to_torch(obs, self.device),
|
||||||
|
float(reward),
|
||||||
|
bool(terminated),
|
||||||
|
bool(truncated),
|
||||||
|
jax_to_torch(info, self.device),
|
||||||
|
)
|
||||||
|
|
||||||
|
def reset(
|
||||||
|
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
||||||
|
) -> tuple[WrapperObsType, dict[str, Any]]:
|
||||||
|
"""Resets the environment returning PyTorch-based observation and info.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seed: The seed for resetting the environment
|
||||||
|
options: The options for resetting the environment, these are converted to jax arrays.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PyTorch-based observations and info
|
||||||
|
"""
|
||||||
|
if options:
|
||||||
|
options = torch_to_jax(options)
|
||||||
|
|
||||||
|
return jax_to_torch(self.env.reset(seed=seed, options=options), self.device)
|
||||||
|
|
||||||
|
def render(self) -> RenderFrame | list[RenderFrame] | None:
|
||||||
|
"""Returns the rendered frames as a NumPy array."""
|
||||||
|
return jax_to_numpy(self.env.render())
|
9
setup.py
9
setup.py
@@ -42,7 +42,14 @@ extras: Dict[str, List[str]] = {
|
|||||||
"mujoco": ["mujoco>=2.3.0", "imageio>=2.14.1"],
|
"mujoco": ["mujoco>=2.3.0", "imageio>=2.14.1"],
|
||||||
"toy_text": ["pygame==2.1.0"],
|
"toy_text": ["pygame==2.1.0"],
|
||||||
"jax": ["jax==0.3.20", "jaxlib==0.3.20"],
|
"jax": ["jax==0.3.20", "jaxlib==0.3.20"],
|
||||||
"other": ["lz4>=3.1.0", "opencv-python>=3.0", "matplotlib>=3.0", "moviepy>=1.0.0"],
|
"other": [
|
||||||
|
"lz4>=3.1.0",
|
||||||
|
"opencv-python>=3.0",
|
||||||
|
"matplotlib>=3.0",
|
||||||
|
"moviepy>=1.0.0",
|
||||||
|
"tensorflow>=2.1.0",
|
||||||
|
"torch>=1.0.0",
|
||||||
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
# All dependency groups - accept rom license as requires user to run
|
# All dependency groups - accept rom license as requires user to run
|
||||||
|
83
tests/experimental/wrappers/test_numpy_to_jax.py
Normal file
83
tests/experimental/wrappers/test_numpy_to_jax.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
import jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from gymnasium.experimental.wrappers import JaxToNumpyV0
|
||||||
|
from gymnasium.experimental.wrappers.numpy_to_jax import jax_to_numpy, numpy_to_jax
|
||||||
|
from gymnasium.utils.env_checker import data_equivalence
|
||||||
|
from tests.testing_env import GenericTestEnv
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"value, expected_value",
|
||||||
|
[
|
||||||
|
(1.0, np.array(1.0)),
|
||||||
|
(2, np.array(2)),
|
||||||
|
((3.0, 4), (np.array(3.0), np.array(4))),
|
||||||
|
([3.0, 4], [np.array(3.0), np.array(4)]),
|
||||||
|
(
|
||||||
|
{
|
||||||
|
"a": 6.0,
|
||||||
|
"b": 7,
|
||||||
|
},
|
||||||
|
{"a": np.array(6.0), "b": np.array(7)},
|
||||||
|
),
|
||||||
|
(np.array(1.0), np.array(1.0)),
|
||||||
|
(np.array([1, 2]), np.array([1, 2])),
|
||||||
|
(np.array([[1.0], [2.0]]), np.array([[1.0], [2.0]])),
|
||||||
|
(
|
||||||
|
{"a": (1, np.array(2.0), np.array([3, 4])), "b": {"c": 5}},
|
||||||
|
{
|
||||||
|
"a": (np.array(1), np.array(2.0), np.array([3, 4])),
|
||||||
|
"b": {"c": np.array(5)},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_roundtripping(value, expected_value):
|
||||||
|
"""We test numpy -> jax -> numpy as this is direction in the NumpyToJax wrapper."""
|
||||||
|
assert data_equivalence(jax_to_numpy(numpy_to_jax(value)), expected_value)
|
||||||
|
|
||||||
|
|
||||||
|
def jax_reset_func(self, seed=None, options=None):
|
||||||
|
return jnp.array([1.0, 2.0, 3.0]), {"data": jnp.array([1, 2, 3])}
|
||||||
|
|
||||||
|
|
||||||
|
def jax_step_func(self, action):
|
||||||
|
assert isinstance(action, jnp.DeviceArray), type(action)
|
||||||
|
return (
|
||||||
|
jnp.array([1, 2, 3]),
|
||||||
|
jnp.array(5.0),
|
||||||
|
jnp.array(True),
|
||||||
|
jnp.array(False),
|
||||||
|
{"data": jnp.array([1.0, 2.0])},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_jax_to_numpy():
|
||||||
|
jax_env = GenericTestEnv(reset_fn=jax_reset_func, step_fn=jax_step_func)
|
||||||
|
|
||||||
|
# Check that the reset and step for jax environment are as expected
|
||||||
|
obs, info = jax_env.reset()
|
||||||
|
assert isinstance(obs, jnp.DeviceArray)
|
||||||
|
assert isinstance(info, dict) and isinstance(info["data"], jnp.DeviceArray)
|
||||||
|
|
||||||
|
obs, reward, terminated, truncated, info = jax_env.step(jnp.array([1, 2]))
|
||||||
|
assert isinstance(obs, jnp.DeviceArray)
|
||||||
|
assert isinstance(reward, jnp.DeviceArray)
|
||||||
|
assert isinstance(terminated, jnp.DeviceArray) and isinstance(
|
||||||
|
truncated, jnp.DeviceArray
|
||||||
|
)
|
||||||
|
assert isinstance(info, dict) and isinstance(info["data"], jnp.DeviceArray)
|
||||||
|
|
||||||
|
# Check that the wrapped version is correct.
|
||||||
|
numpy_env = JaxToNumpyV0(jax_env)
|
||||||
|
obs, info = numpy_env.reset()
|
||||||
|
assert isinstance(obs, np.ndarray)
|
||||||
|
assert isinstance(info, dict) and isinstance(info["data"], np.ndarray)
|
||||||
|
|
||||||
|
obs, reward, terminated, truncated, info = numpy_env.step(np.array([1, 2]))
|
||||||
|
assert isinstance(obs, np.ndarray)
|
||||||
|
assert isinstance(reward, float)
|
||||||
|
assert isinstance(terminated, bool) and isinstance(truncated, bool)
|
||||||
|
assert isinstance(info, dict) and isinstance(info["data"], np.ndarray)
|
103
tests/experimental/wrappers/test_torch_to_jax.py
Normal file
103
tests/experimental/wrappers/test_torch_to_jax.py
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
import jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from gymnasium.experimental.wrappers import JaxToTorchV0
|
||||||
|
from gymnasium.experimental.wrappers.torch_to_jax import jax_to_torch, torch_to_jax
|
||||||
|
from tests.testing_env import GenericTestEnv
|
||||||
|
|
||||||
|
|
||||||
|
def torch_data_equivalence(data_1, data_2) -> bool:
|
||||||
|
if type(data_1) == type(data_2):
|
||||||
|
if isinstance(data_1, dict):
|
||||||
|
return data_1.keys() == data_2.keys() and all(
|
||||||
|
torch_data_equivalence(data_1[k], data_2[k]) for k in data_1.keys()
|
||||||
|
)
|
||||||
|
elif isinstance(data_1, (tuple, list)):
|
||||||
|
return len(data_1) == len(data_2) and all(
|
||||||
|
torch_data_equivalence(o_1, o_2) for o_1, o_2 in zip(data_1, data_2)
|
||||||
|
)
|
||||||
|
elif isinstance(data_1, torch.Tensor):
|
||||||
|
return data_1.shape == data_2.shape and np.allclose(
|
||||||
|
data_1, data_2, atol=0.00001
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return data_1 == data_2
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"value, expected_value",
|
||||||
|
[
|
||||||
|
(1.0, torch.tensor(1.0)),
|
||||||
|
(2, torch.tensor(2)),
|
||||||
|
((3.0, 4), (torch.tensor(3.0), torch.tensor(4))),
|
||||||
|
([3.0, 4], [torch.tensor(3.0), torch.tensor(4)]),
|
||||||
|
(
|
||||||
|
{
|
||||||
|
"a": 6.0,
|
||||||
|
"b": 7,
|
||||||
|
},
|
||||||
|
{"a": torch.tensor(6.0), "b": torch.tensor(7)},
|
||||||
|
),
|
||||||
|
(torch.tensor(1.0), torch.tensor(1.0)),
|
||||||
|
(torch.tensor([1, 2]), torch.tensor([1, 2])),
|
||||||
|
(torch.tensor([[1.0], [2.0]]), torch.tensor([[1.0], [2.0]])),
|
||||||
|
(
|
||||||
|
{"a": (1, torch.tensor(2.0), torch.tensor([3, 4])), "b": {"c": 5}},
|
||||||
|
{
|
||||||
|
"a": (torch.tensor(1), torch.tensor(2.0), torch.tensor([3, 4])),
|
||||||
|
"b": {"c": torch.tensor(5)},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_roundtripping(value, expected_value):
|
||||||
|
"""We test numpy -> jax -> numpy as this is direction in the NumpyToJax wrapper."""
|
||||||
|
assert torch_data_equivalence(jax_to_torch(torch_to_jax(value)), expected_value)
|
||||||
|
|
||||||
|
|
||||||
|
def jax_reset_func(self, seed=None, options=None):
|
||||||
|
return jnp.array([1.0, 2.0, 3.0]), {"data": jnp.array([1, 2, 3])}
|
||||||
|
|
||||||
|
|
||||||
|
def jax_step_func(self, action):
|
||||||
|
assert isinstance(action, jnp.DeviceArray), type(action)
|
||||||
|
return (
|
||||||
|
jnp.array([1, 2, 3]),
|
||||||
|
jnp.array(5.0),
|
||||||
|
jnp.array(True),
|
||||||
|
jnp.array(False),
|
||||||
|
{"data": jnp.array([1.0, 2.0])},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_jax_to_torch():
|
||||||
|
env = GenericTestEnv(reset_fn=jax_reset_func, step_fn=jax_step_func)
|
||||||
|
|
||||||
|
# Check that the reset and step for jax environment are as expected
|
||||||
|
obs, info = env.reset()
|
||||||
|
assert isinstance(obs, jnp.DeviceArray)
|
||||||
|
assert isinstance(info, dict) and isinstance(info["data"], jnp.DeviceArray)
|
||||||
|
|
||||||
|
obs, reward, terminated, truncated, info = env.step(jnp.array([1, 2]))
|
||||||
|
assert isinstance(obs, jnp.DeviceArray)
|
||||||
|
assert isinstance(reward, jnp.DeviceArray)
|
||||||
|
assert isinstance(terminated, jnp.DeviceArray) and isinstance(
|
||||||
|
truncated, jnp.DeviceArray
|
||||||
|
)
|
||||||
|
assert isinstance(info, dict) and isinstance(info["data"], jnp.DeviceArray)
|
||||||
|
|
||||||
|
# Check that the wrapped version is correct.
|
||||||
|
wrapped_env = JaxToTorchV0(env)
|
||||||
|
obs, info = wrapped_env.reset()
|
||||||
|
assert isinstance(obs, torch.Tensor)
|
||||||
|
assert isinstance(info, dict) and isinstance(info["data"], torch.Tensor)
|
||||||
|
|
||||||
|
obs, reward, terminated, truncated, info = wrapped_env.step(torch.tensor([1, 2]))
|
||||||
|
assert isinstance(obs, torch.Tensor)
|
||||||
|
assert isinstance(reward, float)
|
||||||
|
assert isinstance(terminated, bool) and isinstance(truncated, bool)
|
||||||
|
assert isinstance(info, dict) and isinstance(info["data"], torch.Tensor)
|
Reference in New Issue
Block a user