Mirror of https://github.com/Farama-Foundation/Gymnasium.git (synced 2025-08-21 06:20:15 +00:00)
Add support for NamedTuple in jax->torch and numpy->torch (#811)
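For context, a minimal sketch (not part of the commit; Point and double are illustrative names) of why the conversion helpers below need a separate branch for namedtuples: a namedtuple constructor expects one positional argument per field, so rebuilding from a generator has to go through the generated _make classmethod, which the hasattr(value, "_make") check detects.

from typing import NamedTuple

class Point(NamedTuple):  # illustrative type, not part of the commit
    x: int
    y: int

def double(value):
    """Rebuild an iterable with doubled entries, preserving namedtuple types."""
    if hasattr(value, "_make"):
        # namedtuples are rebuilt from an iterator via the generated _make classmethod
        return type(value)._make(v * 2 for v in value)
    else:
        # list, tuple, etc. accept a single iterable argument
        return type(value)(v * 2 for v in value)

print(double(Point(1, 2)))  # Point(x=2, y=4)
print(double([1, 2, 3]))    # [2, 4, 6]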
@@ -74,7 +74,12 @@ def _mapping_torch_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
 @torch_to_jax.register(abc.Iterable)
 def _iterable_torch_to_jax(value: Iterable[Any]) -> Iterable[Any]:
     """Converts an Iterable from PyTorch Tensors to an iterable of Jax Array."""
-    return type(value)(torch_to_jax(v) for v in value)
+    if hasattr(value, "_make"):
+        # namedtuple - underline used to prevent potential name conflicts
+        # noinspection PyProtectedMember
+        return type(value)._make(torch_to_jax(v) for v in value)
+    else:
+        return type(value)(torch_to_jax(v) for v in value)
 
 
 @functools.singledispatch
@@ -111,7 +116,12 @@ def _jax_iterable_to_torch(
     value: Iterable[Any], device: Device | None = None
 ) -> Iterable[Any]:
     """Converts an Iterable from Jax Array to an iterable of PyTorch Tensors."""
-    return type(value)(jax_to_torch(v, device) for v in value)
+    if hasattr(value, "_make"):
+        # namedtuple - underline used to prevent potential name conflicts
+        # noinspection PyProtectedMember
+        return type(value)._make(jax_to_torch(v) for v in value)
+    else:
+        return type(value)(jax_to_torch(v) for v in value)
 
 
 class JaxToTorch(gym.Wrapper, gym.utils.RecordConstructorArgs):
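A hedged round-trip sketch of what the change above enables for the jax/torch helpers; it mirrors the test case added further below and assumes torch_to_jax and jax_to_torch are importable from gymnasium.wrappers.jax_to_torch with torch and jax installed.

from typing import NamedTuple

import torch

from gymnasium.wrappers.jax_to_torch import jax_to_torch, torch_to_jax

class ExampleNamedTuple(NamedTuple):  # same shape as the tuple used in the tests below
    a: torch.Tensor
    b: torch.Tensor

value = ExampleNamedTuple(a=torch.tensor([1, 2]), b=torch.tensor([1.0, 2.0]))
jax_value = torch_to_jax(value)         # ExampleNamedTuple with jax.Array fields
roundtripped = jax_to_torch(jax_value)  # back to ExampleNamedTuple with torch.Tensor fields
assert type(roundtripped) is ExampleNamedTuple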
@@ -50,7 +50,12 @@ def _mapping_torch_to_numpy(value: Mapping[str, Any]) -> Mapping[str, Any]:
 @torch_to_numpy.register(abc.Iterable)
 def _iterable_torch_to_numpy(value: Iterable[Any]) -> Iterable[Any]:
     """Converts an Iterable from PyTorch Tensors to an iterable of Jax Array."""
-    return type(value)(torch_to_numpy(v) for v in value)
+    if hasattr(value, "_make"):
+        # namedtuple - underline used to prevent potential name conflicts
+        # noinspection PyProtectedMember
+        return type(value)._make(torch_to_numpy(v) for v in value)
+    else:
+        return type(value)(torch_to_numpy(v) for v in value)
 
 
 @functools.singledispatch
@@ -85,7 +90,12 @@ def _numpy_iterable_to_torch(
     value: Iterable[Any], device: Device | None = None
 ) -> Iterable[Any]:
     """Converts an Iterable from Jax Array to an iterable of PyTorch Tensors."""
-    return type(value)(tuple(numpy_to_torch(v, device) for v in value))
+    if hasattr(value, "_make"):
+        # namedtuple - underline used to prevent potential name conflicts
+        # noinspection PyProtectedMember
+        return type(value)._make(numpy_to_torch(v) for v in value)
+    else:
+        return type(value)(numpy_to_torch(v) for v in value)
 
 
 class NumpyToTorch(gym.Wrapper, gym.utils.RecordConstructorArgs):
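The numpy/torch helpers gain the same behaviour; a short sketch mirroring the test entry added below (assumes numpy_to_torch and torch_to_numpy are importable from gymnasium.wrappers.numpy_to_torch).

from typing import NamedTuple

import numpy as np

from gymnasium.wrappers.numpy_to_torch import numpy_to_torch, torch_to_numpy

class ExampleNamedTuple(NamedTuple):
    a: np.ndarray
    b: np.ndarray

value = ExampleNamedTuple(
    a=np.array([1, 2], dtype=np.int32),
    b=np.array([1.0, 2.0], dtype=np.float32),
)
torch_value = numpy_to_torch(value)         # same NamedTuple type with torch.Tensor fields
roundtripped = torch_to_numpy(torch_value)  # back to numpy arrays
assert isinstance(roundtripped, ExampleNamedTuple)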
@@ -17,7 +17,7 @@ from gymnasium.wrappers.jax_to_numpy import ( # noqa: E402
 from tests.testing_env import GenericTestEnv  # noqa: E402
 
 
-class TestingNamedTuple(NamedTuple):
+class ExampleNamedTuple(NamedTuple):
     a: jax.Array
     b: jax.Array
 
@@ -62,11 +62,11 @@ class TestingNamedTuple(NamedTuple):
             },
         ),
         (
-            TestingNamedTuple(
+            ExampleNamedTuple(
                 a=np.array([1, 2], dtype=np.int32),
                 b=np.array([1.0, 2.0], dtype=np.float32),
             ),
-            TestingNamedTuple(
+            ExampleNamedTuple(
                 a=np.array([1, 2], dtype=np.int32),
                 b=np.array([1.0, 2.0], dtype=np.float32),
             ),
@@ -1,4 +1,5 @@
 """Test suite for TorchToJax wrapper."""
+from typing import NamedTuple
 
 import numpy as np
 import pytest
@@ -37,6 +38,11 @@ def torch_data_equivalence(data_1, data_2) -> bool:
         return False
 
 
+class ExampleNamedTuple(NamedTuple):
+    a: torch.Tensor
+    b: torch.Tensor
+
+
 @pytest.mark.parametrize(
     "value, expected_value",
     [
@@ -52,19 +58,47 @@ def torch_data_equivalence(data_1, data_2) -> bool:
             {"a": torch.tensor(6.0), "b": torch.tensor(7)},
         ),
         (torch.tensor(1.0), torch.tensor(1.0)),
+        (torch.tensor(1.0), torch.tensor(1.0)),
         (torch.tensor([1, 2]), torch.tensor([1, 2])),
-        (torch.tensor([[1.0], [2.0]]), torch.tensor([[1.0], [2.0]])),
         (
-            {"a": (1, torch.tensor(2.0), torch.tensor([3, 4])), "b": {"c": 5}},
+            torch.tensor([[1.0], [2.0]]),
+            torch.tensor([[1.0], [2.0]]),
+        ),
+        (
             {
-                "a": (torch.tensor(1), torch.tensor(2.0), torch.tensor([3, 4])),
+                "a": (
+                    1,
+                    torch.tensor(2.0),
+                    torch.tensor([3, 4]),
+                ),
+                "b": {"c": 5},
+            },
+            {
+                "a": (
+                    torch.tensor(1),
+                    torch.tensor(2.0),
+                    torch.tensor([3, 4]),
+                ),
                 "b": {"c": torch.tensor(5)},
             },
         ),
+        (
+            ExampleNamedTuple(
+                a=torch.tensor([1, 2]),
+                b=torch.tensor([1.0, 2.0]),
+            ),
+            ExampleNamedTuple(
+                a=torch.tensor([1, 2]),
+                b=torch.tensor([1.0, 2.0]),
+            ),
+        ),
     ],
 )
 def test_roundtripping(value, expected_value):
     """We test numpy -> jax -> numpy as this is direction in the NumpyToJax wrapper."""
+    print(f"{value=}")
+    print(f"{torch_to_jax(value)=}")
+    print(f"{jax_to_torch(torch_to_jax(value))=}")
     roundtripped_value = jax_to_torch(torch_to_jax(value))
     assert torch_data_equivalence(roundtripped_value, expected_value)
 
@@ -1,4 +1,5 @@
 """Test suite for NumPyToTorch wrapper."""
+from typing import NamedTuple
 
 import numpy as np
 import pytest
@@ -16,6 +17,11 @@ from gymnasium.wrappers.numpy_to_torch import ( # noqa: E402
 from tests.testing_env import GenericTestEnv  # noqa: E402
 
 
+class ExampleNamedTuple(NamedTuple):
+    a: np.ndarray
+    b: np.ndarray
+
+
 @pytest.mark.parametrize(
     "value, expected_value",
     [
@@ -55,13 +61,21 @@ from tests.testing_env import GenericTestEnv # noqa: E402
                 "b": {"c": np.array(5, dtype=np.int64)},
             },
         ),
+        (
+            ExampleNamedTuple(
+                a=np.array([1, 2], dtype=np.int32),
+                b=np.array([1.0, 2.0], dtype=np.float32),
+            ),
+            ExampleNamedTuple(
+                a=np.array([1, 2], dtype=np.int32),
+                b=np.array([1.0, 2.0], dtype=np.float32),
+            ),
+        ),
     ],
 )
 def test_roundtripping(value, expected_value):
     """We test numpy -> torch -> numpy as this is direction in the NumpyToTorch wrapper."""
-    torch_value = numpy_to_torch(value)
-    roundtripped_value = torch_to_numpy(torch_value)
-    # roundtripped_value = torch_to_numpy(numpy_to_torch(value))
+    roundtripped_value = torch_to_numpy(numpy_to_torch(value))
     assert data_equivalence(roundtripped_value, expected_value)
 