mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-20 22:12:03 +00:00
Add support for NamedTuple in jax->torch and numpy->torch (#811)
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
"""Test suite for NumPyToTorch wrapper."""
|
||||
from typing import NamedTuple
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
@@ -16,6 +17,11 @@ from gymnasium.wrappers.numpy_to_torch import ( # noqa: E402
|
||||
from tests.testing_env import GenericTestEnv # noqa: E402
|
||||
|
||||
|
||||
class ExampleNamedTuple(NamedTuple):
|
||||
a: np.ndarray
|
||||
b: np.ndarray
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"value, expected_value",
|
||||
[
|
||||
@@ -55,13 +61,21 @@ from tests.testing_env import GenericTestEnv # noqa: E402
|
||||
"b": {"c": np.array(5, dtype=np.int64)},
|
||||
},
|
||||
),
|
||||
(
|
||||
ExampleNamedTuple(
|
||||
a=np.array([1, 2], dtype=np.int32),
|
||||
b=np.array([1.0, 2.0], dtype=np.float32),
|
||||
),
|
||||
ExampleNamedTuple(
|
||||
a=np.array([1, 2], dtype=np.int32),
|
||||
b=np.array([1.0, 2.0], dtype=np.float32),
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_roundtripping(value, expected_value):
|
||||
"""We test numpy -> torch -> numpy as this is direction in the NumpyToTorch wrapper."""
|
||||
torch_value = numpy_to_torch(value)
|
||||
roundtripped_value = torch_to_numpy(torch_value)
|
||||
# roundtripped_value = torch_to_numpy(numpy_to_torch(value))
|
||||
roundtripped_value = torch_to_numpy(numpy_to_torch(value))
|
||||
assert data_equivalence(roundtripped_value, expected_value)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user