Add support for NamedTuple in jax->torch and numpy->torch (#811)

This commit is contained in:
Mark Towers
2023-12-04 12:14:19 +00:00
committed by GitHub
parent b57b9139cd
commit 359cb59e8d
5 changed files with 81 additions and 13 deletions

View File

@@ -1,4 +1,5 @@
"""Test suite for NumPyToTorch wrapper."""
from typing import NamedTuple
import numpy as np
import pytest
@@ -16,6 +17,11 @@ from gymnasium.wrappers.numpy_to_torch import ( # noqa: E402
from tests.testing_env import GenericTestEnv # noqa: E402
class ExampleNamedTuple(NamedTuple):
a: np.ndarray
b: np.ndarray
@pytest.mark.parametrize(
"value, expected_value",
[
@@ -55,13 +61,21 @@ from tests.testing_env import GenericTestEnv # noqa: E402
"b": {"c": np.array(5, dtype=np.int64)},
},
),
(
ExampleNamedTuple(
a=np.array([1, 2], dtype=np.int32),
b=np.array([1.0, 2.0], dtype=np.float32),
),
ExampleNamedTuple(
a=np.array([1, 2], dtype=np.int32),
b=np.array([1.0, 2.0], dtype=np.float32),
),
),
],
)
def test_roundtripping(value, expected_value):
"""We test numpy -> torch -> numpy as this is direction in the NumpyToTorch wrapper."""
torch_value = numpy_to_torch(value)
roundtripped_value = torch_to_numpy(torch_value)
# roundtripped_value = torch_to_numpy(numpy_to_torch(value))
roundtripped_value = torch_to_numpy(numpy_to_torch(value))
assert data_equivalence(roundtripped_value, expected_value)