Files
Gymnasium/tests/wrappers/test_jax_to_torch.py

160 lines
5.2 KiB
Python

"""Test suite for TorchToJax wrapper."""
import pickle
from typing import NamedTuple
import pytest
import gymnasium
jax = pytest.importorskip("jax")
jnp = pytest.importorskip("jax.numpy")
torch = pytest.importorskip("torch")
from gymnasium.wrappers.jax_to_torch import ( # noqa: E402
JaxToTorch,
jax_to_torch,
torch_to_jax,
)
from tests.testing_env import GenericTestEnv # noqa: E402
def torch_data_equivalence(data_1, data_2) -> bool:
    """Return whether two (possibly nested) values containing ``torch.Tensor`` are equivalent.

    Both values must have exactly the same type. Dicts are compared key by key,
    tuples and lists (including named tuples) element by element, and tensors by
    shape, dtype and approximate value. Any other value falls back to ``==``.

    Args:
        data_1: First value to compare.
        data_2: Second value to compare.

    Returns:
        ``True`` if both values are structurally and numerically equivalent,
        ``False`` otherwise.
    """
    if type(data_1) is not type(data_2):
        return False
    if isinstance(data_1, dict):
        return data_1.keys() == data_2.keys() and all(
            torch_data_equivalence(data_1[k], data_2[k]) for k in data_1
        )
    if isinstance(data_1, (tuple, list)):
        return len(data_1) == len(data_2) and all(
            torch_data_equivalence(o_1, o_2) for o_1, o_2 in zip(data_1, data_2)
        )
    if isinstance(data_1, torch.Tensor):
        # ``torch.allclose`` raises on mismatched dtypes rather than returning
        # False, so the dtype is compared explicitly before the values.
        return (
            data_1.shape == data_2.shape
            and data_1.dtype == data_2.dtype
            and torch.allclose(data_1, data_2, atol=0.00001)
        )
    return data_1 == data_2
class ExampleNamedTuple(NamedTuple):
    """Two-field named tuple of tensors used as a roundtrip conversion test case."""

    a: torch.Tensor  # first tensor field
    b: torch.Tensor  # second tensor field
# JAX refuses to create double precision arrays unless ``jax_enable_x64`` is set at
# startup, so every array is expected to be single precision after a roundtrip.
@pytest.mark.parametrize(
    "value, expected_value",
    [
        (1.0, torch.tensor(1.0)),
        (2, torch.tensor(2, dtype=torch.int32)),
        ((3.0, 4), (torch.tensor(3.0), torch.tensor(4, dtype=torch.int32))),
        ([3.0, 4], [torch.tensor(3.0), torch.tensor(4, dtype=torch.int32)]),
        (
            {
                "a": 6.0,
                "b": 7,
            },
            {"a": torch.tensor(6.0), "b": torch.tensor(7, dtype=torch.int32)},
        ),
        (torch.tensor(1.0), torch.tensor(1.0)),
        (torch.tensor(1), torch.tensor(1, dtype=torch.int32)),
        (torch.tensor([1, 2]), torch.tensor([1, 2], dtype=torch.int32)),
        (
            torch.tensor([[1.0], [2.0]]),
            torch.tensor([[1.0], [2.0]]),
        ),
        (
            {
                "a": (
                    1,
                    torch.tensor(2.0),
                    torch.tensor([3, 4]),
                ),
                "b": {"c": 5},
            },
            {
                "a": (
                    torch.tensor(1, dtype=torch.int32),
                    torch.tensor(2.0),
                    torch.tensor([3, 4], dtype=torch.int32),
                ),
                "b": {"c": torch.tensor(5, dtype=torch.int32)},
            },
        ),
        (
            ExampleNamedTuple(
                a=torch.tensor([1, 2]),
                b=torch.tensor([1.0, 2.0]),
            ),
            ExampleNamedTuple(
                a=torch.tensor([1, 2], dtype=torch.int32),
                b=torch.tensor([1.0, 2.0]),
            ),
        ),
        (None, None),
    ],
)
def test_roundtripping(value, expected_value):
    """Check that converting torch -> jax -> torch yields the expected value."""
    # Convert in two explicit steps so a failure points at the direction that broke.
    jax_value = torch_to_jax(value)
    roundtripped_value = jax_to_torch(jax_value)
    assert torch_data_equivalence(roundtripped_value, expected_value)
def _jax_reset_func(self, seed=None, options=None):
return jnp.array([1.0, 2.0, 3.0]), {"data": jnp.array([1, 2, 3])}
def _jax_step_func(self, action):
assert isinstance(action, jax.Array), type(action)
return (
jnp.array([1, 2, 3]),
jnp.array(5.0),
jnp.array(True),
jnp.array(False),
{"data": jnp.array([1.0, 2.0])},
)
def test_jax_to_torch_wrapper():
    """Tests the `JaxToTorch` wrapper."""
    jax_env = GenericTestEnv(reset_func=_jax_reset_func, step_func=_jax_step_func)

    # Sanity check: the unwrapped environment returns jax arrays from reset and step.
    obs, info = jax_env.reset()
    assert isinstance(obs, jax.Array)
    assert isinstance(info, dict)
    assert isinstance(info["data"], jax.Array)

    obs, reward, terminated, truncated, info = jax_env.step(jnp.array([1, 2]))
    assert isinstance(obs, jax.Array)
    assert isinstance(reward, jax.Array)
    assert isinstance(terminated, jax.Array)
    assert isinstance(truncated, jax.Array)
    assert isinstance(info, dict)
    assert isinstance(info["data"], jax.Array)

    # The wrapped environment must return torch tensors and plain Python scalars.
    torch_env = JaxToTorch(jax_env)
    obs, info = torch_env.reset()
    assert isinstance(obs, torch.Tensor)
    assert isinstance(info, dict)
    assert isinstance(info["data"], torch.Tensor)

    obs, reward, terminated, truncated, info = torch_env.step(torch.tensor([1, 2]))
    assert isinstance(obs, torch.Tensor)
    assert isinstance(reward, float)
    assert isinstance(terminated, bool)
    assert isinstance(truncated, bool)
    assert isinstance(info, dict)
    assert isinstance(info["data"], torch.Tensor)

    # Rendering implicitly returns None, which exercises the None -> None conversion.
    torch_env.render()

    # The wrapper must survive a pickle roundtrip.
    cartpole_env = gymnasium.make("CartPole-v1", disable_env_checker=True)
    wrapped_cartpole = JaxToTorch(cartpole_env)
    serialized = pickle.dumps(wrapped_cartpole)
    pickle.loads(serialized)