Add support for NamedTuple in jax->torch and numpy->torch (#811)

This commit is contained in:
Mark Towers
2023-12-04 12:14:19 +00:00
committed by GitHub
parent b57b9139cd
commit 359cb59e8d
5 changed files with 81 additions and 13 deletions

View File

@@ -17,7 +17,7 @@ from gymnasium.wrappers.jax_to_numpy import ( # noqa: E402
from tests.testing_env import GenericTestEnv # noqa: E402
class TestingNamedTuple(NamedTuple):
class ExampleNamedTuple(NamedTuple):
a: jax.Array
b: jax.Array
@@ -62,11 +62,11 @@ class TestingNamedTuple(NamedTuple):
},
),
(
TestingNamedTuple(
ExampleNamedTuple(
a=np.array([1, 2], dtype=np.int32),
b=np.array([1.0, 2.0], dtype=np.float32),
),
TestingNamedTuple(
ExampleNamedTuple(
a=np.array([1, 2], dtype=np.int32),
b=np.array([1.0, 2.0], dtype=np.float32),
),