Mirror of https://github.com/Farama-Foundation/Gymnasium.git (synced 2025-07-31 13:54:31 +00:00)
Adds wrappers for jax environments converting to numpy and pytorch (#168)
Co-authored-by: Justin Deutsch <djustin8@vt.edu>
Co-authored-by: Gianluca De Cola <42657588+gianlucadecola@users.noreply.github.com>
@@ -175,21 +175,21 @@ Gymnasium already contains a large collection of wrappers, but we believe that t
    * - :class:`wrappers.EnvCompatibility`
      - Moved to `shimmy <https://github.com/Farama-Foundation/Shimmy/blob/main/shimmy/openai_gym_compatibility.py>`_
      - Not Implemented
-   * - :class:`RecordEpisodeStatistics`
+   * - :class:`wrappers.RecordEpisodeStatistics`
      - RecordEpisodeStatistics
      - VectorRecordEpisodeStatistics
-   * - :class:`RenderCollection`
+   * - :class:`wrappers.RenderCollection`
      - RenderCollection
      - VectorRenderCollection
-   * - :class:`HumanRendering`
+   * - :class:`wrappers.HumanRendering`
      - HumanRendering
      - Not Implemented
    * - Not Implemented
-     - JaxToNumpy
-     - VectorJaxToNumpy
+     - :class:`experimental.wrappers.JaxToNumpy`
+     - VectorJaxToNumpy (*)
    * - Not Implemented
-     - JaxToTorch
-     - VectorJaxToTorch
+     - :class:`experimental.wrappers.JaxToTorch`
+     - VectorJaxToTorch (*)
 ```

 ### Vector Only Wrappers
gymnasium/experimental/wrappers/__init__.py

@@ -12,6 +12,8 @@ from gymnasium.experimental.wrappers.lambda_action import (
 )
 from gymnasium.experimental.wrappers.lambda_observations import LambdaObservationV0
 from gymnasium.experimental.wrappers.lambda_reward import ClipRewardV0, LambdaRewardV0
+from gymnasium.experimental.wrappers.numpy_to_jax import JaxToNumpyV0
+from gymnasium.experimental.wrappers.torch_to_jax import JaxToTorchV0
 from gymnasium.experimental.wrappers.sticky_action import StickyActionV0
 from gymnasium.experimental.wrappers.time_aware_observation import (
     TimeAwareObservationV0,
@@ -32,4 +34,7 @@ __all__ = [
     # Lambda Reward
     "LambdaRewardV0",
     "ClipRewardV0",
+    # Jax conversion wrappers
+    "JaxToNumpyV0",
+    "JaxToTorchV0",
 ]
gymnasium/experimental/wrappers/numpy_to_jax.py (new file, 134 lines)
@@ -0,0 +1,134 @@
"""Helper functions and wrapper class for converting between numpy and Jax."""
from __future__ import annotations

import functools
import numbers
from collections import abc
from typing import Any, Iterable, Mapping, SupportsFloat

import jax.numpy as jnp
import numpy as np

from gymnasium import Env, Wrapper
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType


@functools.singledispatch
def numpy_to_jax(value: Any) -> Any:
    """Converts a value to a Jax DeviceArray."""
    raise Exception(
        f"No conversion for Numpy to Jax registered for type: {type(value)}"
    )


@numpy_to_jax.register(numbers.Number)
@numpy_to_jax.register(np.ndarray)
def _number_ndarray_numpy_to_jax(value: np.ndarray | numbers.Number) -> jnp.DeviceArray:
    """Converts a numpy array or number (int, float, etc.) to a Jax DeviceArray."""
    return jnp.array(value)


@numpy_to_jax.register(abc.Mapping)
def _mapping_numpy_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
    """Converts a dictionary of numpy arrays to a mapping of Jax DeviceArrays."""
    return type(value)(**{k: numpy_to_jax(v) for k, v in value.items()})


@numpy_to_jax.register(abc.Iterable)
def _iterable_numpy_to_jax(
    value: Iterable[np.ndarray | Any],
) -> Iterable[jnp.DeviceArray | Any]:
    """Converts an Iterable from Numpy Arrays to an iterable of Jax DeviceArrays."""
    return type(value)(numpy_to_jax(v) for v in value)


@functools.singledispatch
def jax_to_numpy(value: Any) -> Any:
    """Converts a value to a numpy array."""
    raise Exception(
        f"No conversion for Jax to Numpy registered for type: {type(value)}"
    )


@jax_to_numpy.register(jnp.DeviceArray)
def _devicearray_jax_to_numpy(value: jnp.DeviceArray) -> np.ndarray:
    """Converts a Jax DeviceArray to a numpy array."""
    return np.array(value)


@jax_to_numpy.register(abc.Mapping)
def _mapping_jax_to_numpy(
    value: Mapping[str, jnp.DeviceArray | Any]
) -> Mapping[str, np.ndarray | Any]:
    """Converts a dictionary of Jax DeviceArrays to a mapping of numpy arrays."""
    return type(value)(**{k: jax_to_numpy(v) for k, v in value.items()})


@jax_to_numpy.register(abc.Iterable)
def _iterable_jax_to_numpy(
    value: Iterable[jnp.DeviceArray | Any],
) -> Iterable[np.ndarray | Any]:
    """Converts an Iterable of Jax DeviceArrays to an iterable of numpy arrays."""
    return type(value)(jax_to_numpy(v) for v in value)
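
These singledispatch converters recurse through mappings and iterables, rebuilding each container with `type(value)(...)` and converting only the leaves; a minimal sketch of that behaviour, assuming `jax` and `numpy` are installed:

```python
import numpy as np

from gymnasium.experimental.wrappers.numpy_to_jax import jax_to_numpy, numpy_to_jax

# Nested containers are rebuilt with their own type; leaves are converted.
sample = {"obs": np.zeros(3, dtype=np.float32), "extras": (1, 2.0)}
as_jax = numpy_to_jax(sample)   # leaves are now jax arrays
back = jax_to_numpy(as_jax)     # leaves are numpy arrays again
print(type(as_jax["obs"]), type(back["obs"]))
```
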
class JaxToNumpyV0(Wrapper):
    """Wraps a jax environment so that it can be interacted with through numpy arrays.

    Actions must be provided as numpy arrays and observations will be returned as numpy arrays.

    Notes:
        The Jax To Numpy and Numpy to Jax conversions do not guarantee a roundtrip (jax -> numpy -> jax) and vice versa.
        The reason is that jax does not support non-array values, therefore ``np.int32(5) -> DeviceArray([5], dtype=jnp.int32)``.
    """

    def __init__(self, env: Env):
        """Wraps an environment such that the input and outputs are numpy arrays.

        Args:
            env: the environment to wrap
        """
        super().__init__(env)

    def step(
        self, action: WrapperActType
    ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
        """Transforms the action to a jax array.

        Args:
            action: the action to perform as a numpy array

        Returns:
            A tuple containing the next observation, reward, termination, truncation, and extra info.
        """
        jax_action = numpy_to_jax(action)
        obs, reward, terminated, truncated, info = self.env.step(jax_action)

        return (
            jax_to_numpy(obs),
            float(reward),
            bool(terminated),
            bool(truncated),
            jax_to_numpy(info),
        )

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[WrapperObsType, dict[str, Any]]:
        """Resets the environment returning numpy-based observation and info.

        Args:
            seed: The seed for resetting the environment
            options: The options for resetting the environment, these are converted to jax arrays.

        Returns:
            Numpy-based observations and info
        """
        if options:
            options = numpy_to_jax(options)

        return jax_to_numpy(self.env.reset(seed=seed, options=options))

    def render(self) -> RenderFrame | list[RenderFrame] | None:
        """Returns the rendered frames as a numpy array."""
        return jax_to_numpy(self.env.render())
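
A minimal usage sketch for the wrapper, with a hypothetical toy jax-based environment (any `Env` whose `reset`/`step` return jax values would behave the same):

```python
import jax.numpy as jnp
import numpy as np

import gymnasium as gym
from gymnasium.experimental.wrappers import JaxToNumpyV0


class EchoJaxEnv(gym.Env):
    """Toy jax-based env (hypothetical, for illustration only)."""

    def reset(self, *, seed=None, options=None):
        return jnp.zeros(3), {}

    def step(self, action):
        # Echo the action back as the observation, all values as jax arrays.
        return jnp.asarray(action), jnp.array(0.0), jnp.array(False), jnp.array(False), {}


env = JaxToNumpyV0(EchoJaxEnv())
obs, info = env.reset(seed=0)        # obs is now a np.ndarray
obs, reward, terminated, truncated, info = env.step(np.array([1.0, 2.0, 3.0]))
assert isinstance(obs, np.ndarray) and isinstance(reward, float)
```
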
gymnasium/experimental/wrappers/torch_to_jax.py (new file, 163 lines)
@@ -0,0 +1,163 @@
|
||||
# This wrapper will convert torch inputs for the actions and observations to Jax arrays
|
||||
# for an underlying Jax environment then convert the return observations from Jax arrays
|
||||
# back to torch tensors.
|
||||
#
|
||||
# Functionality for converting between torch and jax types originally copied from
|
||||
# https://github.com/google/brax/blob/9d6b7ced2a13da0d074b5e9fbd3aad8311e26997/brax/io/torch.py
|
||||
# Under the Apache 2.0 license. Copyright is held by the authors
|
||||
|
||||
"""Helper functions and wrapper class for converting between PyTorch and Jax."""
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import numbers
|
||||
from collections import abc
|
||||
from typing import Any, Iterable, Mapping, SupportsFloat, Union
|
||||
|
||||
import jax.numpy as jnp
|
||||
from jax import dlpack as jax_dlpack
|
||||
|
||||
from gymnasium import Env, Wrapper
|
||||
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
|
||||
from gymnasium.error import DependencyNotInstalled
|
||||
from gymnasium.experimental.wrappers.numpy_to_jax import jax_to_numpy
|
||||
|
||||
try:
|
||||
import torch
|
||||
from torch.utils import dlpack as torch_dlpack
|
||||
except ImportError:
|
||||
raise DependencyNotInstalled("torch is not installed, run `pip install torch`")
|
||||
|
||||
|
||||
Device = Union[str, torch.device]
|
||||
|
||||
|
||||
@functools.singledispatch
|
||||
def torch_to_jax(value: Any) -> Any:
|
||||
"""Converts a PyTorch Tensor into a Jax DeviceArray."""
|
||||
raise Exception(
|
||||
f"No conversion for PyTorch to Jax registered for type: {type(value)}"
|
||||
)
|
||||
|
||||
|
||||
@torch_to_jax.register(numbers.Number)
|
||||
def _number_torch_to_jax(value: numbers.Number) -> Any:
|
||||
return jnp.array(value)
|
||||
|
||||
|
||||
@torch_to_jax.register(torch.Tensor)
|
||||
def _tensor_torch_to_jax(value: torch.Tensor) -> jnp.DeviceArray:
|
||||
"""Converts a PyTorch Tensor into a Jax DeviceArray."""
|
||||
tensor = torch_dlpack.to_dlpack(value) # pyright: ignore[reportPrivateImportUsage]
|
||||
tensor = jax_dlpack.from_dlpack(tensor) # pyright: ignore[reportPrivateImportUsage]
|
||||
return tensor
|
||||
|
||||
|
||||
@torch_to_jax.register(abc.Mapping)
|
||||
def _mapping_torch_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
"""Converts a mapping of PyTorch Tensors into a Dictionary of Jax DeviceArrays."""
|
||||
return type(value)(**{k: torch_to_jax(v) for k, v in value.items()})
|
||||
|
||||
|
||||
@torch_to_jax.register(abc.Iterable)
|
||||
def _iterable_torch_to_jax(value: Iterable[Any]) -> Iterable[Any]:
|
||||
"""Converts an Iterable from PyTorch Tensors to an iterable of Jax DeviceArrays."""
|
||||
return type(value)(torch_to_jax(v) for v in value)
|
||||
|
||||
|
||||
@functools.singledispatch
|
||||
def jax_to_torch(value: Any, device: Device | None = None) -> Any:
|
||||
"""Converts a Jax DeviceArray into a PyTorch Tensor."""
|
||||
raise Exception(
|
||||
f"No conversion for Jax to PyTorch registered for type={type(value)} and device: {device}"
|
||||
)
|
||||
|
||||
|
||||
@jax_to_torch.register(jnp.DeviceArray)
|
||||
def _devicearray_jax_to_torch(
|
||||
value: jnp.DeviceArray, device: Device | None = None
|
||||
) -> torch.Tensor:
|
||||
"""Converts a Jax DeviceArray into a PyTorch Tensor."""
|
||||
dlpack = jax_dlpack.to_dlpack(value) # pyright: ignore[reportPrivateImportUsage]
|
||||
tensor = torch_dlpack.from_dlpack(dlpack)
|
||||
if device:
|
||||
return tensor.to(device=device)
|
||||
return tensor
|
||||
|
||||
|
||||
@jax_to_torch.register(abc.Mapping)
|
||||
def _jax_mapping_to_torch(
|
||||
value: Mapping[str, Any], device: Device | None = None
|
||||
) -> Mapping[str, Any]:
|
||||
"""Converts a mapping of Jax DeviceArrays into a Dictionary of PyTorch Tensors."""
|
||||
return type(value)(**{k: jax_to_torch(v, device) for k, v in value.items()})
|
||||
|
||||
|
||||
@jax_to_torch.register(abc.Iterable)
|
||||
def _jax_iterable_to_torch(
|
||||
value: Iterable[Any], device: Device | None = None
|
||||
) -> Iterable[Any]:
|
||||
"""Converts an Iterable from Jax DeviceArrays to an iterable of PyTorch Tensors."""
|
||||
return type(value)(jax_to_torch(v, device) for v in value)
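
`jax_to_torch` threads the optional `device` argument through the container handlers, while tensor leaves cross frameworks via dlpack (typically without a host copy when both sides share a device). A small sketch, assuming both torch and jax are installed:

```python
import jax.numpy as jnp

from gymnasium.experimental.wrappers.torch_to_jax import jax_to_torch, torch_to_jax

value = {"x": jnp.arange(4.0), "y": (jnp.array(1),)}
tensors = jax_to_torch(value, device="cpu")   # every leaf becomes a torch.Tensor on `device`
roundtrip = torch_to_jax(tensors)             # and back to jax arrays via dlpack
```
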
class JaxToTorchV0(Wrapper):
    """Wraps a jax-based environment so that it can be interacted with through PyTorch Tensors.

    Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors.

    Note that rendered frames are returned as NumPy arrays, not PyTorch Tensors.
    """

    def __init__(self, env: Env, device: Device | None = None):
        """Wrapper class to change inputs and outputs of environment to PyTorch tensors.

        Args:
            env: The Jax-based environment to wrap
            device: The device the torch Tensors should be moved to
        """
        super().__init__(env)
        self.device: Device | None = device

    def step(
        self, action: WrapperActType
    ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
        """Performs the given action within the environment.

        Args:
            action: The action to perform as a PyTorch Tensor

        Returns:
            The next observation, reward, termination, truncation, and extra info
        """
        jax_action = torch_to_jax(action)
        obs, reward, terminated, truncated, info = self.env.step(jax_action)

        return (
            jax_to_torch(obs, self.device),
            float(reward),
            bool(terminated),
            bool(truncated),
            jax_to_torch(info, self.device),
        )

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[WrapperObsType, dict[str, Any]]:
        """Resets the environment returning PyTorch-based observation and info.

        Args:
            seed: The seed for resetting the environment
            options: The options for resetting the environment, these are converted to jax arrays.

        Returns:
            PyTorch-based observations and info
        """
        if options:
            options = torch_to_jax(options)

        return jax_to_torch(self.env.reset(seed=seed, options=options), self.device)

    def render(self) -> RenderFrame | list[RenderFrame] | None:
        """Returns the rendered frames as a NumPy array."""
        return jax_to_numpy(self.env.render())
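
A matching usage sketch for the torch wrapper, again with a hypothetical toy env; the `device` argument controls where the returned tensors are placed:

```python
import jax.numpy as jnp
import torch

import gymnasium as gym
from gymnasium.experimental.wrappers import JaxToTorchV0


class EchoJaxEnv(gym.Env):
    """Toy jax-based env (hypothetical, same shape as the earlier sketch)."""

    def reset(self, *, seed=None, options=None):
        return jnp.zeros(3), {}

    def step(self, action):
        return jnp.asarray(action), jnp.array(0.0), jnp.array(False), jnp.array(False), {}


env = JaxToTorchV0(EchoJaxEnv(), device="cpu")
obs, info = env.reset(seed=0)                  # obs is a torch.Tensor on the cpu
obs, reward, terminated, truncated, info = env.step(torch.tensor([1.0, 2.0, 3.0]))
assert isinstance(obs, torch.Tensor) and isinstance(reward, float)
```
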
setup.py

@@ -42,7 +42,14 @@ extras: Dict[str, List[str]] = {
"mujoco": ["mujoco>=2.3.0", "imageio>=2.14.1"],
|
||||
"toy_text": ["pygame==2.1.0"],
|
||||
"jax": ["jax==0.3.20", "jaxlib==0.3.20"],
|
||||
"other": ["lz4>=3.1.0", "opencv-python>=3.0", "matplotlib>=3.0", "moviepy>=1.0.0"],
|
||||
"other": [
|
||||
"lz4>=3.1.0",
|
||||
"opencv-python>=3.0",
|
||||
"matplotlib>=3.0",
|
||||
"moviepy>=1.0.0",
|
||||
"tensorflow>=2.1.0",
|
||||
"torch>=1.0.0",
|
||||
],
|
||||
}
|
||||
|
||||
# All dependency groups - accept rom license as requires user to run
|
||||
|
tests/experimental/wrappers/test_numpy_to_jax.py (new file, 83 lines)
@@ -0,0 +1,83 @@
import jax.numpy as jnp
import numpy as np
import pytest

from gymnasium.experimental.wrappers import JaxToNumpyV0
from gymnasium.experimental.wrappers.numpy_to_jax import jax_to_numpy, numpy_to_jax
from gymnasium.utils.env_checker import data_equivalence
from tests.testing_env import GenericTestEnv


@pytest.mark.parametrize(
    "value, expected_value",
    [
        (1.0, np.array(1.0)),
        (2, np.array(2)),
        ((3.0, 4), (np.array(3.0), np.array(4))),
        ([3.0, 4], [np.array(3.0), np.array(4)]),
        (
            {
                "a": 6.0,
                "b": 7,
            },
            {"a": np.array(6.0), "b": np.array(7)},
        ),
        (np.array(1.0), np.array(1.0)),
        (np.array([1, 2]), np.array([1, 2])),
        (np.array([[1.0], [2.0]]), np.array([[1.0], [2.0]])),
        (
            {"a": (1, np.array(2.0), np.array([3, 4])), "b": {"c": 5}},
            {
                "a": (np.array(1), np.array(2.0), np.array([3, 4])),
                "b": {"c": np.array(5)},
            },
        ),
    ],
)
def test_roundtripping(value, expected_value):
    """We test numpy -> jax -> numpy as this is the direction used in the JaxToNumpyV0 wrapper."""
    assert data_equivalence(jax_to_numpy(numpy_to_jax(value)), expected_value)


def jax_reset_func(self, seed=None, options=None):
    return jnp.array([1.0, 2.0, 3.0]), {"data": jnp.array([1, 2, 3])}


def jax_step_func(self, action):
    assert isinstance(action, jnp.DeviceArray), type(action)
    return (
        jnp.array([1, 2, 3]),
        jnp.array(5.0),
        jnp.array(True),
        jnp.array(False),
        {"data": jnp.array([1.0, 2.0])},
    )


def test_jax_to_numpy():
    jax_env = GenericTestEnv(reset_fn=jax_reset_func, step_fn=jax_step_func)

    # Check that the reset and step for jax environment are as expected
    obs, info = jax_env.reset()
    assert isinstance(obs, jnp.DeviceArray)
    assert isinstance(info, dict) and isinstance(info["data"], jnp.DeviceArray)

    obs, reward, terminated, truncated, info = jax_env.step(jnp.array([1, 2]))
    assert isinstance(obs, jnp.DeviceArray)
    assert isinstance(reward, jnp.DeviceArray)
    assert isinstance(terminated, jnp.DeviceArray) and isinstance(
        truncated, jnp.DeviceArray
    )
    assert isinstance(info, dict) and isinstance(info["data"], jnp.DeviceArray)

    # Check that the wrapped version is correct.
    numpy_env = JaxToNumpyV0(jax_env)
    obs, info = numpy_env.reset()
    assert isinstance(obs, np.ndarray)
    assert isinstance(info, dict) and isinstance(info["data"], np.ndarray)

    obs, reward, terminated, truncated, info = numpy_env.step(np.array([1, 2]))
    assert isinstance(obs, np.ndarray)
    assert isinstance(reward, float)
    assert isinstance(terminated, bool) and isinstance(truncated, bool)
    assert isinstance(info, dict) and isinstance(info["data"], np.ndarray)
tests/experimental/wrappers/test_torch_to_jax.py (new file, 103 lines)
@@ -0,0 +1,103 @@
import jax.numpy as jnp
import numpy as np
import pytest
import torch

from gymnasium.experimental.wrappers import JaxToTorchV0
from gymnasium.experimental.wrappers.torch_to_jax import jax_to_torch, torch_to_jax
from tests.testing_env import GenericTestEnv


def torch_data_equivalence(data_1, data_2) -> bool:
    if type(data_1) == type(data_2):
        if isinstance(data_1, dict):
            return data_1.keys() == data_2.keys() and all(
                torch_data_equivalence(data_1[k], data_2[k]) for k in data_1.keys()
            )
        elif isinstance(data_1, (tuple, list)):
            return len(data_1) == len(data_2) and all(
                torch_data_equivalence(o_1, o_2) for o_1, o_2 in zip(data_1, data_2)
            )
        elif isinstance(data_1, torch.Tensor):
            return data_1.shape == data_2.shape and np.allclose(
                data_1, data_2, atol=0.00001
            )
        else:
            return data_1 == data_2
    else:
        return False


@pytest.mark.parametrize(
    "value, expected_value",
    [
        (1.0, torch.tensor(1.0)),
        (2, torch.tensor(2)),
        ((3.0, 4), (torch.tensor(3.0), torch.tensor(4))),
        ([3.0, 4], [torch.tensor(3.0), torch.tensor(4)]),
        (
            {
                "a": 6.0,
                "b": 7,
            },
            {"a": torch.tensor(6.0), "b": torch.tensor(7)},
        ),
        (torch.tensor(1.0), torch.tensor(1.0)),
        (torch.tensor([1, 2]), torch.tensor([1, 2])),
        (torch.tensor([[1.0], [2.0]]), torch.tensor([[1.0], [2.0]])),
        (
            {"a": (1, torch.tensor(2.0), torch.tensor([3, 4])), "b": {"c": 5}},
            {
                "a": (torch.tensor(1), torch.tensor(2.0), torch.tensor([3, 4])),
                "b": {"c": torch.tensor(5)},
            },
        ),
    ],
)
def test_roundtripping(value, expected_value):
    """We test torch -> jax -> torch as this is the direction used in the JaxToTorchV0 wrapper."""
    assert torch_data_equivalence(jax_to_torch(torch_to_jax(value)), expected_value)


def jax_reset_func(self, seed=None, options=None):
    return jnp.array([1.0, 2.0, 3.0]), {"data": jnp.array([1, 2, 3])}


def jax_step_func(self, action):
    assert isinstance(action, jnp.DeviceArray), type(action)
    return (
        jnp.array([1, 2, 3]),
        jnp.array(5.0),
        jnp.array(True),
        jnp.array(False),
        {"data": jnp.array([1.0, 2.0])},
    )


def test_jax_to_torch():
    env = GenericTestEnv(reset_fn=jax_reset_func, step_fn=jax_step_func)

    # Check that the reset and step for jax environment are as expected
    obs, info = env.reset()
    assert isinstance(obs, jnp.DeviceArray)
    assert isinstance(info, dict) and isinstance(info["data"], jnp.DeviceArray)

    obs, reward, terminated, truncated, info = env.step(jnp.array([1, 2]))
    assert isinstance(obs, jnp.DeviceArray)
    assert isinstance(reward, jnp.DeviceArray)
    assert isinstance(terminated, jnp.DeviceArray) and isinstance(
        truncated, jnp.DeviceArray
    )
    assert isinstance(info, dict) and isinstance(info["data"], jnp.DeviceArray)

    # Check that the wrapped version is correct.
    wrapped_env = JaxToTorchV0(env)
    obs, info = wrapped_env.reset()
    assert isinstance(obs, torch.Tensor)
    assert isinstance(info, dict) and isinstance(info["data"], torch.Tensor)

    obs, reward, terminated, truncated, info = wrapped_env.step(torch.tensor([1, 2]))
    assert isinstance(obs, torch.Tensor)
    assert isinstance(reward, float)
    assert isinstance(terminated, bool) and isinstance(truncated, bool)
    assert isinstance(info, dict) and isinstance(info["data"], torch.Tensor)