Files
Gymnasium/tests/wrappers/test_jax_to_numpy.py

145 lines
4.6 KiB
Python
Raw Normal View History

"""Test suite for JaxToNumpy wrapper."""
2024-06-10 17:07:47 +01:00
import pickle
from typing import NamedTuple
2022-12-10 22:04:14 +00:00
import numpy as np
import pytest
import gymnasium
2023-07-03 23:53:57 +02:00
jax = pytest.importorskip("jax")
jnp = pytest.importorskip("jax.numpy")
from gymnasium.utils.env_checker import data_equivalence # noqa: E402
from gymnasium.wrappers.jax_to_numpy import ( # noqa: E402
JaxToNumpy,
jax_to_numpy,
numpy_to_jax,
)
from tests.testing_env import GenericTestEnv # noqa: E402
class ExampleNamedTuple(NamedTuple):
    """Two-field ``NamedTuple`` used as a parametrized roundtrip input in this module.

    It checks that the numpy <-> jax conversions rebuild ``NamedTuple`` containers with
    the same type. The ``jax.Array`` annotations are not enforced at runtime; the
    roundtrip cases populate the fields with numpy arrays.
    """

    # First field (an int32 array in the roundtrip cases).
    a: jax.Array
    # Second field (a float32 array in the roundtrip cases).
    b: jax.Array
@pytest.mark.parametrize(
    "value, expected_value",
    [
        (1.0, np.array(1.0, dtype=np.float32)),
        (2, np.array(2, dtype=np.int32)),
        ((3.0, 4), (np.array(3.0, dtype=np.float32), np.array(4, dtype=np.int32))),
        ([3.0, 4], [np.array(3.0, dtype=np.float32), np.array(4, dtype=np.int32)]),
        (
            {
                "a": 6.0,
                "b": 7,
            },
            {"a": np.array(6.0, dtype=np.float32), "b": np.array(7, dtype=np.int32)},
        ),
        (np.array(1.0, dtype=np.float32), np.array(1.0, dtype=np.float32)),
        # Integer-dtype arrays are written with integer literals so the data matches the dtype.
        (np.array(1, dtype=np.uint8), np.array(1, dtype=np.uint8)),
        (np.array([1, 2], dtype=np.int32), np.array([1, 2], dtype=np.int32)),
        (
            np.array([[1], [2]], dtype=np.int32),
            np.array([[1], [2]], dtype=np.int32),
        ),
        (
            {
                "a": (
                    1,
                    np.array(2.0, dtype=np.float32),
                    np.array([3, 4], dtype=np.int32),
                ),
                "b": {"c": 5},
            },
            {
                "a": (
                    np.array(1, dtype=np.int32),
                    np.array(2.0, dtype=np.float32),
                    np.array([3, 4], dtype=np.int32),
                ),
                "b": {"c": np.array(5, dtype=np.int32)},
            },
        ),
        (
            ExampleNamedTuple(
                a=np.array([1, 2], dtype=np.int32),
                b=np.array([1.0, 2.0], dtype=np.float32),
            ),
            ExampleNamedTuple(
                a=np.array([1, 2], dtype=np.int32),
                b=np.array([1.0, 2.0], dtype=np.float32),
            ),
        ),
        (None, None),
    ],
)
def test_roundtripping(value, expected_value):
    """Check that ``jax_to_numpy(numpy_to_jax(value))`` reproduces ``expected_value``.

    numpy -> jax -> numpy is the direction exercised by the ``JaxToNumpy`` wrapper. Numpy
    actions are converted to jax before reaching the environment, and the jax results
    are converted back to numpy.

    The expected values pin down the following behavior:

    - Python scalars come back as 0-d arrays.
    - Containers (tuple, list, dict, ``NamedTuple``, nested mixes) come back with the same
      container type and every leaf converted.
    - ``None`` is passed through unchanged.

    Warning: Jax doesn't support float64 out of the box, therefore, we only test float32 in this test.
    """
    roundtripped_value = jax_to_numpy(numpy_to_jax(value))
    assert data_equivalence(roundtripped_value, expected_value)
def jax_reset_func(self, seed=None, options=None):
    """Reset stub for ``GenericTestEnv`` whose outputs are all jax arrays.

    ``self``, ``seed`` and ``options`` exist only to match the reset signature and are ignored.

    Returns:
        ``(observation, info)``: the float observation ``[1.0, 2.0, 3.0]`` and an info dict
        whose ``"data"`` entry is the integer array ``[1, 2, 3]``.
    """
    observation = jnp.array([1.0, 2.0, 3.0])
    info = {"data": jnp.array([1, 2, 3])}
    return observation, info
def jax_step_func(self, action):
    """Step stub for ``GenericTestEnv`` that requires a jax action and emits jax arrays.

    Args:
        self: The environment instance (unused).
        action: Must be a ``jax.Array``. Otherwise an ``AssertionError`` carrying the
            action's type is raised.

    Returns:
        ``(observation, reward, terminated, truncated, info)``. Every element is a jax array,
        including ``info["data"]``, so a wrapper's conversion of each piece can be checked.
    """
    # When driven through ``JaxToNumpy`` with a numpy action, this verifies the wrapper
    # converted the action to jax before forwarding it.
    assert isinstance(action, jax.Array), type(action)

    observation = jnp.array([1, 2, 3])
    reward = jnp.array(5.0)
    terminated, truncated = jnp.array(True), jnp.array(False)
    info = {"data": jnp.array([1.0, 2.0])}
    return observation, reward, terminated, truncated, info
2022-12-10 22:04:14 +00:00
def test_jax_to_numpy_wrapper():
    """Tests the ``JaxToNumpy`` wrapper.

    The test checks, in order:

    - The unwrapped ``GenericTestEnv`` (driven by the jax reset/step stubs) really returns
      jax arrays.
    - After wrapping in ``JaxToNumpy``:
      - observations and infos come back as numpy arrays with the same values;
      - reward and termination flags come back as Python ``float`` / ``bool``;
      - a ``None`` render result passes through.
    - A wrapped environment survives a pickle roundtrip.
    """
    jax_env = GenericTestEnv(reset_func=jax_reset_func, step_func=jax_step_func)

    # Check that the reset and step for jax environment are as expected
    obs, info = jax_env.reset()
    assert isinstance(obs, jax.Array)
    assert isinstance(info, dict) and isinstance(info["data"], jax.Array)

    obs, reward, terminated, truncated, info = jax_env.step(jnp.array([1, 2]))
    assert isinstance(obs, jax.Array)
    assert isinstance(reward, jax.Array)
    assert isinstance(terminated, jax.Array) and isinstance(truncated, jax.Array)
    assert isinstance(info, dict) and isinstance(info["data"], jax.Array)

    # Check that the wrapped version is correct, both in type and in value.
    numpy_env = JaxToNumpy(jax_env)
    obs, info = numpy_env.reset()
    assert isinstance(obs, np.ndarray)
    assert isinstance(info, dict) and isinstance(info["data"], np.ndarray)
    # Values come from ``jax_reset_func``.
    np.testing.assert_array_equal(obs, [1.0, 2.0, 3.0])
    np.testing.assert_array_equal(info["data"], [1, 2, 3])

    # A numpy action is passed; ``jax_step_func`` asserts it arrives as a jax array.
    obs, reward, terminated, truncated, info = numpy_env.step(
        np.array([1, 2], dtype=np.int32)
    )
    assert isinstance(obs, np.ndarray)
    assert isinstance(reward, float)
    assert isinstance(terminated, bool) and isinstance(truncated, bool)
    assert isinstance(info, dict) and isinstance(info["data"], np.ndarray)
    # Values come from ``jax_step_func``.
    np.testing.assert_array_equal(obs, [1, 2, 3])
    assert reward == 5.0
    assert terminated is True and truncated is False
    np.testing.assert_array_equal(info["data"], [1.0, 2.0])

    # Check that the wrapped environment can render. This implicitly returns None and requires a
    # None -> None conversion
    assert numpy_env.render() is None
    # Closing the wrapper also closes the underlying test env.
    numpy_env.close()

    # Test that the wrapped environment can be pickled and restored as the same wrapper type.
    env = gymnasium.make("CartPole-v1", disable_env_checker=True)
    wrapped_env = JaxToNumpy(env)
    restored_env = pickle.loads(pickle.dumps(wrapped_env))
    assert isinstance(restored_env, JaxToNumpy)
    restored_env.close()
    wrapped_env.close()