"""Test suite for TorchToJax wrapper."""
from typing import NamedTuple
import numpy as np
import pytest
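# Skip the whole module at collection time if jax or torch is unavailable;
# everything below can then assume both packages import successfully.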
jax = pytest.importorskip("jax")
jnp = pytest.importorskip("jax.numpy")
torch = pytest.importorskip("torch")
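# These imports must come after the importorskip guards above, hence the
# "# noqa: E402" (module-level import not at top of file) suppressions.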
from gymnasium.wrappers.jax_to_torch import (  # noqa: E402
    JaxToTorch,
    jax_to_torch,
    torch_to_jax,
)
from tests.testing_env import GenericTestEnv  # noqa: E402


def torch_data_equivalence(data_1, data_2) -> bool:
    """Return whether two (possibly nested) structures that might contain ``torch.Tensor`` are equivalent."""
    if type(data_1) is type(data_2):
        if isinstance(data_1, dict):
            return data_1.keys() == data_2.keys() and all(
                torch_data_equivalence(data_1[k], data_2[k]) for k in data_1.keys()
            )
        elif isinstance(data_1, (tuple, list)):
            return len(data_1) == len(data_2) and all(
                torch_data_equivalence(o_1, o_2) for o_1, o_2 in zip(data_1, data_2)
            )
        elif isinstance(data_1, torch.Tensor):
            # Compare tensors by shape and value with a small absolute tolerance.
            return data_1.shape == data_2.shape and np.allclose(
                data_1, data_2, atol=0.00001
            )
        else:
            return data_1 == data_2
    else:
        return False
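
# Illustrative behaviour of the helper above (comments only, not executed):
#   torch_data_equivalence({"a": torch.tensor(1.0)}, {"a": torch.tensor(1.0)})  # True
#   torch_data_equivalence(torch.tensor(1.0), 1.0)  # False, the types differ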


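# NamedTuples appear in the roundtrip cases below because the converters are
# expected to rebuild the original NamedTuple subclass rather than collapse
# it into a plain tuple.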
class ExampleNamedTuple(NamedTuple):
    a: torch.Tensor
    b: torch.Tensor


@pytest.mark.parametrize(
    "value, expected_value",
    [
        (1.0, torch.tensor(1.0)),
        (2, torch.tensor(2)),
        ((3.0, 4), (torch.tensor(3.0), torch.tensor(4))),
        ([3.0, 4], [torch.tensor(3.0), torch.tensor(4)]),
        (
            {
                "a": 6.0,
                "b": 7,
            },
            {"a": torch.tensor(6.0), "b": torch.tensor(7)},
        ),
        (torch.tensor(1.0), torch.tensor(1.0)),
        (torch.tensor([1, 2]), torch.tensor([1, 2])),
        (
            torch.tensor([[1.0], [2.0]]),
            torch.tensor([[1.0], [2.0]]),
        ),
        (
            {
                "a": (
                    1,
                    torch.tensor(2.0),
                    torch.tensor([3, 4]),
                ),
                "b": {"c": 5},
            },
            {
                "a": (
                    torch.tensor(1),
                    torch.tensor(2.0),
                    torch.tensor([3, 4]),
                ),
                "b": {"c": torch.tensor(5)},
            },
        ),
        (
            ExampleNamedTuple(
                a=torch.tensor([1, 2]),
                b=torch.tensor([1.0, 2.0]),
            ),
            ExampleNamedTuple(
                a=torch.tensor([1, 2]),
                b=torch.tensor([1.0, 2.0]),
            ),
        ),
    ],
)
def test_roundtripping(value, expected_value):
    """Test torch -> jax -> torch roundtripping, as this is the direction used in the JaxToTorch wrapper."""
    # Printed values make parametrized failures easier to diagnose.
    print(f"{value=}")
    print(f"{torch_to_jax(value)=}")
    print(f"{jax_to_torch(torch_to_jax(value))=}")
    roundtripped_value = jax_to_torch(torch_to_jax(value))
    assert torch_data_equivalence(roundtripped_value, expected_value)


def _jax_reset_func(self, seed=None, options=None):
    """Reset function for the jax-based testing environment, returning jax arrays."""
    return jnp.array([1.0, 2.0, 3.0]), {"data": jnp.array([1, 2, 3])}


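# Returns the standard Gymnasium 5-tuple (obs, reward, terminated, truncated,
# info), built entirely from jax values so the wrapper's conversion is exercised.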
def _jax_step_func(self, action):
    """Step function for the jax-based testing environment."""
    assert isinstance(action, jax.Array), type(action)
    return (
        jnp.array([1, 2, 3]),
        jnp.array(5.0),
        jnp.array(True),
        jnp.array(False),
        {"data": jnp.array([1.0, 2.0])},
    )


def test_jax_to_torch_wrapper():
    """Tests the `JaxToTorch` wrapper."""
    env = GenericTestEnv(reset_func=_jax_reset_func, step_func=_jax_step_func)

    # Check that reset and step for the unwrapped jax environment return jax types
    obs, info = env.reset()
    assert isinstance(obs, jax.Array)
    assert isinstance(info, dict) and isinstance(info["data"], jax.Array)

    obs, reward, terminated, truncated, info = env.step(jnp.array([1, 2]))
    assert isinstance(obs, jax.Array)
    assert isinstance(reward, jax.Array)
    assert isinstance(terminated, jax.Array) and isinstance(truncated, jax.Array)
    assert isinstance(info, dict) and isinstance(info["data"], jax.Array)

    # Check that the wrapped environment converts jax values to torch equivalents
    wrapped_env = JaxToTorch(env)
    obs, info = wrapped_env.reset()
    assert isinstance(obs, torch.Tensor)
    assert isinstance(info, dict) and isinstance(info["data"], torch.Tensor)

    obs, reward, terminated, truncated, info = wrapped_env.step(torch.tensor([1, 2]))
    assert isinstance(obs, torch.Tensor)
    assert isinstance(reward, float)
    assert isinstance(terminated, bool) and isinstance(truncated, bool)
    assert isinstance(info, dict) and isinstance(info["data"], torch.Tensor)
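

# Illustrative usage sketch (comments only; assumes a registered jax-based
# environment id such as "phys2d/Pendulum-v0" and that a 1-element torch
# action is valid for it; neither assumption is verified by this test suite):
#
#   import gymnasium as gym
#   env = JaxToTorch(gym.make("phys2d/Pendulum-v0"))
#   obs, info = env.reset()  # obs is a torch.Tensor
#   obs, reward, terminated, truncated, info = env.step(torch.tensor([0.0]))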