mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-30 01:50:19 +00:00
Add support for NamedTuple in jax->torch and numpy->torch (#811)
This commit is contained in:
@@ -17,7 +17,7 @@ from gymnasium.wrappers.jax_to_numpy import ( # noqa: E402
|
||||
from tests.testing_env import GenericTestEnv # noqa: E402
|
||||
|
||||
|
||||
class TestingNamedTuple(NamedTuple):
|
||||
class ExampleNamedTuple(NamedTuple):
|
||||
a: jax.Array
|
||||
b: jax.Array
|
||||
|
||||
@@ -62,11 +62,11 @@ class TestingNamedTuple(NamedTuple):
|
||||
},
|
||||
),
|
||||
(
|
||||
TestingNamedTuple(
|
||||
ExampleNamedTuple(
|
||||
a=np.array([1, 2], dtype=np.int32),
|
||||
b=np.array([1.0, 2.0], dtype=np.float32),
|
||||
),
|
||||
TestingNamedTuple(
|
||||
ExampleNamedTuple(
|
||||
a=np.array([1, 2], dtype=np.int32),
|
||||
b=np.array([1.0, 2.0], dtype=np.float32),
|
||||
),
|
||||
|
Reference in New Issue
Block a user