Files
Gymnasium/tests/experimental/wrappers/test_jax_to_numpy.py
2022-12-10 22:04:14 +00:00

89 lines
3.0 KiB
Python

"""Test suite for JaxToNumpyV0."""
import jax
import jax.numpy as jnp
import numpy as np
import pytest

from gymnasium.experimental.wrappers import JaxToNumpyV0
from gymnasium.experimental.wrappers.jax_to_numpy import jax_to_numpy, numpy_to_jax
from gymnasium.utils.env_checker import data_equivalence
from tests.testing_env import GenericTestEnv
@pytest.mark.parametrize(
    "value, expected_value",
    [
        # Python scalars
        (1.0, np.array(1.0)),
        (2, np.array(2)),
        # Flat containers of scalars
        ((3.0, 4), (np.array(3.0), np.array(4))),
        ([3.0, 4], [np.array(3.0), np.array(4)]),
        ({"a": 6.0, "b": 7}, {"a": np.array(6.0), "b": np.array(7)}),
        # NumPy arrays of rank 0, 1 and 2
        (np.array(1.0), np.array(1.0)),
        (np.array([1, 2]), np.array([1, 2])),
        (np.array([[1.0], [2.0]]), np.array([[1.0], [2.0]])),
        # Nested mix of dicts, tuples, scalars and arrays
        (
            {"a": (1, np.array(2.0), np.array([3, 4])), "b": {"c": 5}},
            {
                "a": (np.array(1), np.array(2.0), np.array([3, 4])),
                "b": {"c": np.array(5)},
            },
        ),
    ],
)
def test_roundtripping(value, expected_value):
    """Check that ``numpy_to_jax`` followed by ``jax_to_numpy`` yields the expected data.

    Each parametrized case supplies an input ``value`` and the NumPy-based
    ``expected_value`` that the numpy -> jax -> numpy round trip should produce,
    compared with ``data_equivalence``.
    """
    as_jax = numpy_to_jax(value)
    back_to_numpy = jax_to_numpy(as_jax)
    assert data_equivalence(back_to_numpy, expected_value)
def jax_reset_func(self, seed=None, options=None):
    """Reset function for the test environment that returns jax data.

    ``seed`` and ``options`` are accepted for signature compatibility but are
    not used; the same fixed observation and info are returned on every call.

    Returns:
        A ``(observation, info)`` pair where the observation is a jax array
        ``[1.0, 2.0, 3.0]`` and ``info["data"]`` is a jax array ``[1, 2, 3]``.
    """
    observation = jnp.array([1.0, 2.0, 3.0])
    info = {"data": jnp.array([1, 2, 3])}
    return observation, info
def jax_step_func(self, action):
    """Step function for the test environment that consumes and returns jax data.

    Args:
        action: The action passed to the environment; it must already be a
            jax array, otherwise the assertion below fails.

    Returns:
        A fixed ``(observation, reward, terminated, truncated, info)`` tuple in
        which every value (including ``info["data"]``) is a jax array.
    """
    # ``jax.Array`` is the public array type; the older ``jnp.DeviceArray``
    # alias was deprecated in its favour and is absent from current JAX
    # releases, where referencing it raises ``AttributeError``.
    assert isinstance(action, jax.Array), type(action)
    return (
        jnp.array([1, 2, 3]),
        jnp.array(5.0),
        jnp.array(True),
        jnp.array(False),
        {"data": jnp.array([1.0, 2.0])},
    )
def test_jax_to_numpy_wrapper():
    """Tests the ``JaxToNumpyV0`` wrapper.

    First confirms that the unwrapped environment built from ``jax_reset_func``
    and ``jax_step_func`` produces jax arrays, then confirms that the wrapped
    environment accepts a NumPy action and returns NumPy arrays / Python
    scalars from ``reset`` and ``step``.
    """
    jax_env = GenericTestEnv(reset_func=jax_reset_func, step_func=jax_step_func)

    # Check that the reset and step for jax environment are as expected.
    # ``jax.Array`` is used instead of ``jnp.DeviceArray`` because the latter
    # was deprecated and is absent from current JAX releases.
    obs, info = jax_env.reset()
    assert isinstance(obs, jax.Array)
    assert isinstance(info, dict) and isinstance(info["data"], jax.Array)

    obs, reward, terminated, truncated, info = jax_env.step(jnp.array([1, 2]))
    assert isinstance(obs, jax.Array)
    assert isinstance(reward, jax.Array)
    assert isinstance(terminated, jax.Array) and isinstance(truncated, jax.Array)
    assert isinstance(info, dict) and isinstance(info["data"], jax.Array)

    # Check that the wrapped version is correct.
    numpy_env = JaxToNumpyV0(jax_env)
    obs, info = numpy_env.reset()
    assert isinstance(obs, np.ndarray)
    assert isinstance(info, dict) and isinstance(info["data"], np.ndarray)

    # The wrapped env is stepped with a NumPy action; ``jax_step_func`` asserts
    # that it receives a jax array, so this also exercises action conversion.
    obs, reward, terminated, truncated, info = numpy_env.step(np.array([1, 2]))
    assert isinstance(obs, np.ndarray)
    assert isinstance(reward, float)
    assert isinstance(terminated, bool) and isinstance(truncated, bool)
    assert isinstance(info, dict) and isinstance(info["data"], np.ndarray)