Add support for NamedTuple in jax->torch and numpy->torch (#811)

This commit is contained in:
Mark Towers
2023-12-04 12:14:19 +00:00
committed by GitHub
parent b57b9139cd
commit 359cb59e8d
5 changed files with 81 additions and 13 deletions

View File

@@ -74,6 +74,11 @@ def _mapping_torch_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
@torch_to_jax.register(abc.Iterable)
def _iterable_torch_to_jax(value: Iterable[Any]) -> Iterable[Any]:
    """Convert every element of an iterable with ``torch_to_jax``.

    Args:
        value: The iterable whose elements are converted one by one.

    Returns:
        A new instance of ``type(value)`` holding the converted elements.
    """
    container_type = type(value)
    converted = (torch_to_jax(element) for element in value)
    if not hasattr(value, "_make"):
        return container_type(converted)
    # A ``_make`` attribute marks a namedtuple (the leading underscore only
    # prevents clashes with field names); its constructor expects one argument
    # per field, so the instance is rebuilt from the iterable through ``_make``.
    # noinspection PyProtectedMember
    return container_type._make(converted)
def _jax_iterable_to_torch(
    value: Iterable[Any], device: Device | None = None
) -> Iterable[Any]:
    """Convert every element of an iterable with ``jax_to_torch``.

    Args:
        value: The iterable whose elements are converted one by one.
        device: Forwarded unchanged to ``jax_to_torch`` for every element.

    Returns:
        A new instance of ``type(value)`` holding the converted elements.
    """
    container_type = type(value)
    # ``device`` has to be passed on for each element; without it the nested
    # conversions would silently ignore the device requested by the caller.
    converted = (jax_to_torch(element, device) for element in value)
    if not hasattr(value, "_make"):
        return container_type(converted)
    # A ``_make`` attribute marks a namedtuple (the leading underscore only
    # prevents clashes with field names); its constructor expects one argument
    # per field, so the instance is rebuilt from the iterable through ``_make``.
    # noinspection PyProtectedMember
    return container_type._make(converted)
class JaxToTorch(gym.Wrapper, gym.utils.RecordConstructorArgs): class JaxToTorch(gym.Wrapper, gym.utils.RecordConstructorArgs):

View File

@@ -50,6 +50,11 @@ def _mapping_torch_to_numpy(value: Mapping[str, Any]) -> Mapping[str, Any]:
@torch_to_numpy.register(abc.Iterable)
def _iterable_torch_to_numpy(value: Iterable[Any]) -> Iterable[Any]:
    """Convert every element of an iterable with ``torch_to_numpy``.

    Args:
        value: The iterable whose elements are converted one by one.

    Returns:
        A new instance of ``type(value)`` holding the converted elements.
    """
    container_type = type(value)
    converted = (torch_to_numpy(element) for element in value)
    if not hasattr(value, "_make"):
        return container_type(converted)
    # A ``_make`` attribute marks a namedtuple (the leading underscore only
    # prevents clashes with field names); its constructor expects one argument
    # per field, so the instance is rebuilt from the iterable through ``_make``.
    # noinspection PyProtectedMember
    return container_type._make(converted)
def _numpy_iterable_to_torch(
    value: Iterable[Any], device: Device | None = None
) -> Iterable[Any]:
    """Convert every element of an iterable with ``numpy_to_torch``.

    Args:
        value: The iterable whose elements are converted one by one.
        device: Forwarded unchanged to ``numpy_to_torch`` for every element.

    Returns:
        A new instance of ``type(value)`` holding the converted elements.
    """
    container_type = type(value)
    # ``device`` has to be passed on for each element; without it the nested
    # conversions would silently ignore the device requested by the caller.
    converted = (numpy_to_torch(element, device) for element in value)
    if not hasattr(value, "_make"):
        return container_type(converted)
    # A ``_make`` attribute marks a namedtuple (the leading underscore only
    # prevents clashes with field names); its constructor expects one argument
    # per field, so the instance is rebuilt from the iterable through ``_make``.
    # noinspection PyProtectedMember
    return container_type._make(converted)
class NumpyToTorch(gym.Wrapper, gym.utils.RecordConstructorArgs): class NumpyToTorch(gym.Wrapper, gym.utils.RecordConstructorArgs):

View File

@@ -17,7 +17,7 @@ from gymnasium.wrappers.jax_to_numpy import ( # noqa: E402
from tests.testing_env import GenericTestEnv # noqa: E402 from tests.testing_env import GenericTestEnv # noqa: E402
class TestingNamedTuple(NamedTuple): class ExampleNamedTuple(NamedTuple):
a: jax.Array a: jax.Array
b: jax.Array b: jax.Array
@@ -62,11 +62,11 @@ class TestingNamedTuple(NamedTuple):
}, },
), ),
( (
TestingNamedTuple( ExampleNamedTuple(
a=np.array([1, 2], dtype=np.int32), a=np.array([1, 2], dtype=np.int32),
b=np.array([1.0, 2.0], dtype=np.float32), b=np.array([1.0, 2.0], dtype=np.float32),
), ),
TestingNamedTuple( ExampleNamedTuple(
a=np.array([1, 2], dtype=np.int32), a=np.array([1, 2], dtype=np.int32),
b=np.array([1.0, 2.0], dtype=np.float32), b=np.array([1.0, 2.0], dtype=np.float32),
), ),

View File

@@ -1,4 +1,5 @@
"""Test suite for TorchToJax wrapper.""" """Test suite for TorchToJax wrapper."""
from typing import NamedTuple
import numpy as np import numpy as np
import pytest import pytest
@@ -37,6 +38,11 @@ def torch_data_equivalence(data_1, data_2) -> bool:
return False return False
class ExampleNamedTuple(NamedTuple):
    """Two-field named tuple used as a parametrized case for ``test_roundtripping``."""

    # Both fields are annotated as PyTorch tensors.
    a: torch.Tensor
    b: torch.Tensor
@pytest.mark.parametrize( @pytest.mark.parametrize(
"value, expected_value", "value, expected_value",
[ [
@@ -52,19 +58,47 @@ def torch_data_equivalence(data_1, data_2) -> bool:
{"a": torch.tensor(6.0), "b": torch.tensor(7)}, {"a": torch.tensor(6.0), "b": torch.tensor(7)},
), ),
(torch.tensor(1.0), torch.tensor(1.0)), (torch.tensor(1.0), torch.tensor(1.0)),
(torch.tensor(1.0), torch.tensor(1.0)),
(torch.tensor([1, 2]), torch.tensor([1, 2])), (torch.tensor([1, 2]), torch.tensor([1, 2])),
(torch.tensor([[1.0], [2.0]]), torch.tensor([[1.0], [2.0]])),
( (
{"a": (1, torch.tensor(2.0), torch.tensor([3, 4])), "b": {"c": 5}}, torch.tensor([[1.0], [2.0]]),
torch.tensor([[1.0], [2.0]]),
),
(
{ {
"a": (torch.tensor(1), torch.tensor(2.0), torch.tensor([3, 4])), "a": (
1,
torch.tensor(2.0),
torch.tensor([3, 4]),
),
"b": {"c": 5},
},
{
"a": (
torch.tensor(1),
torch.tensor(2.0),
torch.tensor([3, 4]),
),
"b": {"c": torch.tensor(5)}, "b": {"c": torch.tensor(5)},
}, },
), ),
(
ExampleNamedTuple(
a=torch.tensor([1, 2]),
b=torch.tensor([1.0, 2.0]),
),
ExampleNamedTuple(
a=torch.tensor([1, 2]),
b=torch.tensor([1.0, 2.0]),
),
),
], ],
) )
def test_roundtripping(value, expected_value):
    """Check that a value survives the torch -> jax -> torch round trip.

    Args:
        value: The input handed to ``torch_to_jax``.
        expected_value: What ``jax_to_torch`` is expected to return for the
            converted input.
    """
    roundtripped_value = jax_to_torch(torch_to_jax(value))
    assert torch_data_equivalence(roundtripped_value, expected_value)

View File

@@ -1,4 +1,5 @@
"""Test suite for NumPyToTorch wrapper.""" """Test suite for NumPyToTorch wrapper."""
from typing import NamedTuple
import numpy as np import numpy as np
import pytest import pytest
@@ -16,6 +17,11 @@ from gymnasium.wrappers.numpy_to_torch import ( # noqa: E402
from tests.testing_env import GenericTestEnv # noqa: E402 from tests.testing_env import GenericTestEnv # noqa: E402
class ExampleNamedTuple(NamedTuple):
    """Two-field named tuple used as a parametrized case for ``test_roundtripping``."""

    # Both fields are annotated as NumPy arrays.
    a: np.ndarray
    b: np.ndarray
@pytest.mark.parametrize( @pytest.mark.parametrize(
"value, expected_value", "value, expected_value",
[ [
@@ -55,13 +61,21 @@ from tests.testing_env import GenericTestEnv # noqa: E402
"b": {"c": np.array(5, dtype=np.int64)}, "b": {"c": np.array(5, dtype=np.int64)},
}, },
), ),
(
ExampleNamedTuple(
a=np.array([1, 2], dtype=np.int32),
b=np.array([1.0, 2.0], dtype=np.float32),
),
ExampleNamedTuple(
a=np.array([1, 2], dtype=np.int32),
b=np.array([1.0, 2.0], dtype=np.float32),
),
),
], ],
) )
def test_roundtripping(value, expected_value):
    """Check that a value survives the numpy -> torch -> numpy round trip.

    Args:
        value: The input handed to ``numpy_to_torch``.
        expected_value: What ``torch_to_numpy`` is expected to return for the
            converted input.
    """
    torch_value = numpy_to_torch(value)
    roundtripped_value = torch_to_numpy(torch_value)
    assert data_equivalence(roundtripped_value, expected_value)