diff --git a/gymnasium/wrappers/jax_to_torch.py b/gymnasium/wrappers/jax_to_torch.py
index ce813fe7f..fbc5de33b 100644
--- a/gymnasium/wrappers/jax_to_torch.py
+++ b/gymnasium/wrappers/jax_to_torch.py
@@ -74,7 +74,12 @@ def _mapping_torch_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
 @torch_to_jax.register(abc.Iterable)
 def _iterable_torch_to_jax(value: Iterable[Any]) -> Iterable[Any]:
     """Converts an Iterable from PyTorch Tensors to an iterable of Jax Array."""
-    return type(value)(torch_to_jax(v) for v in value)
+    if hasattr(value, "_make"):
+        # namedtuple - underscore used to prevent potential name conflicts
+        # noinspection PyProtectedMember
+        return type(value)._make(torch_to_jax(v) for v in value)
+    else:
+        return type(value)(torch_to_jax(v) for v in value)
 
 
 @functools.singledispatch
@@ -111,7 +116,12 @@ def _jax_iterable_to_torch(
     value: Iterable[Any], device: Device | None = None
 ) -> Iterable[Any]:
     """Converts an Iterable from Jax Array to an iterable of PyTorch Tensors."""
-    return type(value)(jax_to_torch(v, device) for v in value)
+    if hasattr(value, "_make"):
+        # namedtuple - underscore used to prevent potential name conflicts
+        # noinspection PyProtectedMember
+        return type(value)._make(jax_to_torch(v, device) for v in value)
+    else:
+        return type(value)(jax_to_torch(v, device) for v in value)
 
 
 class JaxToTorch(gym.Wrapper, gym.utils.RecordConstructorArgs):
diff --git a/gymnasium/wrappers/numpy_to_torch.py b/gymnasium/wrappers/numpy_to_torch.py
index d3d00cc9e..5d8139de0 100644
--- a/gymnasium/wrappers/numpy_to_torch.py
+++ b/gymnasium/wrappers/numpy_to_torch.py
@@ -50,7 +50,12 @@ def _mapping_torch_to_numpy(value: Mapping[str, Any]) -> Mapping[str, Any]:
 @torch_to_numpy.register(abc.Iterable)
 def _iterable_torch_to_numpy(value: Iterable[Any]) -> Iterable[Any]:
     """Converts an Iterable from PyTorch Tensors to an iterable of Jax Array."""
-    return type(value)(torch_to_numpy(v) for v in value)
+    if hasattr(value, "_make"):
+        # namedtuple - underscore used to prevent potential name conflicts
+        # noinspection PyProtectedMember
+        return type(value)._make(torch_to_numpy(v) for v in value)
+    else:
+        return type(value)(torch_to_numpy(v) for v in value)
 
 
 @functools.singledispatch
@@ -85,7 +90,12 @@ def _numpy_iterable_to_torch(
     value: Iterable[Any], device: Device | None = None
 ) -> Iterable[Any]:
     """Converts an Iterable from Jax Array to an iterable of PyTorch Tensors."""
-    return type(value)(tuple(numpy_to_torch(v, device) for v in value))
+    if hasattr(value, "_make"):
+        # namedtuple - underscore used to prevent potential name conflicts
+        # noinspection PyProtectedMember
+        return type(value)._make(numpy_to_torch(v, device) for v in value)
+    else:
+        return type(value)(numpy_to_torch(v, device) for v in value)
 
 
 class NumpyToTorch(gym.Wrapper, gym.utils.RecordConstructorArgs):
diff --git a/tests/wrappers/test_jax_to_numpy.py b/tests/wrappers/test_jax_to_numpy.py
index 61422a2df..72b923456 100644
--- a/tests/wrappers/test_jax_to_numpy.py
+++ b/tests/wrappers/test_jax_to_numpy.py
@@ -17,7 +17,7 @@ from gymnasium.wrappers.jax_to_numpy import (  # noqa: E402
 from tests.testing_env import GenericTestEnv  # noqa: E402
 
 
-class TestingNamedTuple(NamedTuple):
+class ExampleNamedTuple(NamedTuple):
     a: jax.Array
     b: jax.Array
 
@@ -62,11 +62,11 @@ class TestingNamedTuple(NamedTuple):
             },
         ),
         (
-            TestingNamedTuple(
+            ExampleNamedTuple(
                 a=np.array([1, 2], dtype=np.int32),
                 b=np.array([1.0, 2.0], dtype=np.float32),
             ),
-            TestingNamedTuple(
+            ExampleNamedTuple(
                 a=np.array([1, 2], dtype=np.int32),
                 b=np.array([1.0, 2.0], dtype=np.float32),
             ),
diff --git a/tests/wrappers/test_jax_to_torch.py b/tests/wrappers/test_jax_to_torch.py
index cb93898de..c7054eecd 100644
--- a/tests/wrappers/test_jax_to_torch.py
+++ b/tests/wrappers/test_jax_to_torch.py
@@ -1,4 +1,5 @@
 """Test suite for TorchToJax wrapper."""
+from typing import NamedTuple
 
 import numpy as np
 import pytest
@@ -37,6 +38,11 @@ def torch_data_equivalence(data_1, data_2) -> bool:
         return False
 
 
+class ExampleNamedTuple(NamedTuple):
+    a: torch.Tensor
+    b: torch.Tensor
+
+
 @pytest.mark.parametrize(
     "value, expected_value",
     [
@@ -52,19 +58,43 @@ def torch_data_equivalence(data_1, data_2) -> bool:
             {"a": torch.tensor(6.0), "b": torch.tensor(7)},
         ),
         (torch.tensor(1.0), torch.tensor(1.0)),
         (torch.tensor([1, 2]), torch.tensor([1, 2])),
-        (torch.tensor([[1.0], [2.0]]), torch.tensor([[1.0], [2.0]])),
         (
-            {"a": (1, torch.tensor(2.0), torch.tensor([3, 4])), "b": {"c": 5}},
+            torch.tensor([[1.0], [2.0]]),
+            torch.tensor([[1.0], [2.0]]),
+        ),
+        (
             {
-                "a": (torch.tensor(1), torch.tensor(2.0), torch.tensor([3, 4])),
+                "a": (
+                    1,
+                    torch.tensor(2.0),
+                    torch.tensor([3, 4]),
+                ),
+                "b": {"c": 5},
+            },
+            {
+                "a": (
+                    torch.tensor(1),
+                    torch.tensor(2.0),
+                    torch.tensor([3, 4]),
+                ),
                 "b": {"c": torch.tensor(5)},
             },
         ),
+        (
+            ExampleNamedTuple(
+                a=torch.tensor([1, 2]),
+                b=torch.tensor([1.0, 2.0]),
+            ),
+            ExampleNamedTuple(
+                a=torch.tensor([1, 2]),
+                b=torch.tensor([1.0, 2.0]),
+            ),
+        ),
     ],
 )
 def test_roundtripping(value, expected_value):
     """We test numpy -> jax -> numpy as this is direction in the NumpyToJax wrapper."""
     roundtripped_value = jax_to_torch(torch_to_jax(value))
     assert torch_data_equivalence(roundtripped_value, expected_value)
 
diff --git a/tests/wrappers/test_numpy_to_torch.py b/tests/wrappers/test_numpy_to_torch.py
index 8c304ab7d..c980c40f9 100644
--- a/tests/wrappers/test_numpy_to_torch.py
+++ b/tests/wrappers/test_numpy_to_torch.py
@@ -1,4 +1,5 @@
 """Test suite for NumPyToTorch wrapper."""
+from typing import NamedTuple
 
 import numpy as np
 import pytest
@@ -16,6 +17,11 @@ from gymnasium.wrappers.numpy_to_torch import (  # noqa: E402
 from tests.testing_env import GenericTestEnv  # noqa: E402
 
 
+class ExampleNamedTuple(NamedTuple):
+    a: np.ndarray
+    b: np.ndarray
+
+
 @pytest.mark.parametrize(
     "value, expected_value",
     [
@@ -55,13 +61,21 @@ from tests.testing_env import GenericTestEnv  # noqa: E402
                 "b": {"c": np.array(5, dtype=np.int64)},
             },
         ),
+        (
+            ExampleNamedTuple(
+                a=np.array([1, 2], dtype=np.int32),
+                b=np.array([1.0, 2.0], dtype=np.float32),
+            ),
+            ExampleNamedTuple(
+                a=np.array([1, 2], dtype=np.int32),
+                b=np.array([1.0, 2.0], dtype=np.float32),
+            ),
+        ),
     ],
 )
 def test_roundtripping(value, expected_value):
     """We test numpy -> torch -> numpy as this is direction in the NumpyToTorch wrapper."""
-    torch_value = numpy_to_torch(value)
-    roundtripped_value = torch_to_numpy(torch_value)
-    # roundtripped_value = torch_to_numpy(numpy_to_torch(value))
+    roundtripped_value = torch_to_numpy(numpy_to_torch(value))
     assert data_equivalence(roundtripped_value, expected_value)
 
 