2023-11-07 13:27:25 +00:00
|
|
|
"""Test suite for the NumpyToTorch wrapper."""
|
2022-12-10 22:04:14 +00:00
|
|
|
|
2022-12-03 19:46:01 +00:00
|
|
|
import numpy as np
|
|
|
|
import pytest
|
|
|
|
|
2023-04-25 03:47:51 -07:00
|
|
|
|
2023-11-07 13:27:25 +00:00
|
|
|
torch = pytest.importorskip("torch")
|
|
|
|
|
2023-04-25 03:47:51 -07:00
|
|
|
|
|
|
|
from gymnasium.utils.env_checker import data_equivalence # noqa: E402
|
2023-11-07 13:27:25 +00:00
|
|
|
from gymnasium.wrappers.numpy_to_torch import ( # noqa: E402
|
|
|
|
NumpyToTorch,
|
|
|
|
numpy_to_torch,
|
|
|
|
torch_to_numpy,
|
|
|
|
)
|
2023-04-25 03:47:51 -07:00
|
|
|
from tests.testing_env import GenericTestEnv # noqa: E402
|
2022-12-03 19:46:01 +00:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("value", "expected_value"),
    [
        # Plain Python scalars: expected back as 0-d float32 / int64 arrays.
        (1.0, np.array(1.0, dtype=np.float32)),
        (2, np.array(2, dtype=np.int64)),
        # Flat containers of scalars keep their container type.
        ((3.0, 4), (np.array(3.0, dtype=np.float32), np.array(4, dtype=np.int64))),
        ([3.0, 4], [np.array(3.0, dtype=np.float32), np.array(4, dtype=np.int64)]),
        (
            dict(a=6.0, b=7),
            dict(a=np.array(6.0, dtype=np.float32), b=np.array(7, dtype=np.int64)),
        ),
        # NumPy arrays are expected to come back with dtype and shape intact.
        (np.array(1.0, dtype=np.float32), np.array(1.0, dtype=np.float32)),
        (np.array(1.0, dtype=np.uint8), np.array(1.0, dtype=np.uint8)),
        (np.array([1, 2], dtype=np.int32), np.array([1, 2], dtype=np.int32)),
        (
            np.array([[1.0], [2.0]], dtype=np.int32),
            np.array([[1.0], [2.0]], dtype=np.int32),
        ),
        # Nested mix of dict / tuple / scalar / array.
        (
            {
                "a": (
                    1,
                    np.array(2.0, dtype=np.float32),
                    np.array([3, 4], dtype=np.int32),
                ),
                "b": {"c": 5},
            },
            {
                "a": (
                    np.array(1, dtype=np.int64),
                    np.array(2.0, dtype=np.float32),
                    np.array([3, 4], dtype=np.int32),
                ),
                "b": {"c": np.array(5, dtype=np.int64)},
            },
        ),
    ],
)
def test_roundtripping(value, expected_value):
    """Check numpy -> torch -> numpy, the direction used by the NumpyToTorch wrapper.

    Each ``value`` is converted with ``numpy_to_torch`` and back with
    ``torch_to_numpy``; the result must be equivalent to ``expected_value``
    according to ``data_equivalence``.
    """
    assert data_equivalence(torch_to_numpy(numpy_to_torch(value)), expected_value)
|
2022-12-03 19:46:01 +00:00
|
|
|
|
|
|
|
|
2023-11-07 13:27:25 +00:00
|
|
|
def numpy_reset_func(self, seed=None, options=None):
    """Stub ``reset`` that returns a fixed NumPy observation and info dict.

    ``seed`` and ``options`` are accepted for signature compatibility but are
    not used.

    Returns:
        A ``(observation, info)`` pair where the observation is a 3-element
        float array and ``info["data"]`` is a 3-element integer array.
    """
    observation = np.array([1.0, 2.0, 3.0])
    info = {"data": np.array([1, 2, 3])}
    return observation, info
|
2022-12-03 19:46:01 +00:00
|
|
|
|
|
|
|
|
2023-11-07 13:27:25 +00:00
|
|
|
def numpy_step_func(self, action):
    """Stub ``step`` that insists on a NumPy action and returns fixed results.

    Args:
        action: The action passed to the environment; must be an
            ``np.ndarray`` (the assertion fails with the offending type
            otherwise).

    Returns:
        A ``(observation, reward, terminated, truncated, info)`` tuple with
        constant values: an integer array observation, reward ``5.0``,
        ``terminated=True``, ``truncated=False`` and a float array under
        ``info["data"]``.
    """
    assert isinstance(action, np.ndarray), type(action)

    observation = np.array([1, 2, 3])
    reward = 5.0
    terminated = True
    truncated = False
    info = {"data": np.array([1.0, 2.0])}
    return observation, reward, terminated, truncated, info
|
|
|
|
|
|
|
|
|
2023-11-07 13:27:25 +00:00
|
|
|
def test_numpy_to_torch():
    """Tests the ``NumpyToTorch`` wrapper.

    First confirms the unwrapped environment speaks NumPy, then checks that
    the wrapped environment returns torch tensors for observations and info
    data while reward, ``terminated`` and ``truncated`` remain plain Python
    values.
    """
    numpy_env = GenericTestEnv(reset_func=numpy_reset_func, step_func=numpy_step_func)

    # Baseline: the raw environment returns NumPy data from reset ...
    np_obs, np_info = numpy_env.reset()
    assert isinstance(np_obs, np.ndarray)
    assert isinstance(np_info, dict)
    assert isinstance(np_info["data"], np.ndarray)

    # ... and from step, given a NumPy action.
    np_obs, np_reward, np_terminated, np_truncated, np_info = numpy_env.step(
        np.array([1, 2])
    )
    assert isinstance(np_obs, np.ndarray)
    assert isinstance(np_reward, float)
    assert isinstance(np_terminated, bool)
    assert isinstance(np_truncated, bool)
    assert isinstance(np_info, dict)
    assert isinstance(np_info["data"], np.ndarray)

    # Wrap the same environment and repeat the checks against torch types.
    torch_env = NumpyToTorch(numpy_env)

    torch_obs, torch_info = torch_env.reset()
    assert isinstance(torch_obs, torch.Tensor)
    assert isinstance(torch_info, dict)
    assert isinstance(torch_info["data"], torch.Tensor)

    # The action is a tensor here; numpy_step_func asserts it receives an ndarray.
    step_result = torch_env.step(torch.tensor([1, 2]))
    torch_obs, torch_reward, torch_terminated, torch_truncated, torch_info = step_result
    assert isinstance(torch_obs, torch.Tensor)
    assert isinstance(torch_reward, float)
    assert isinstance(torch_terminated, bool)
    assert isinstance(torch_truncated, bool)
    assert isinstance(torch_info, dict)
    assert isinstance(torch_info["data"], torch.Tensor)
|