mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-15 11:18:47 +00:00
Jax environments return jax data rather than numpy data (#817)
Co-authored-by: pseudo-rnd-thoughts <mark.m.towers@gmail.com>
This commit is contained in:
committed by
GitHub
parent
f0202ae350
commit
d43037920f
@@ -60,6 +60,7 @@ register(
|
|||||||
vector_entry_point="gymnasium.envs.phys2d.cartpole:CartPoleJaxVectorEnv",
|
vector_entry_point="gymnasium.envs.phys2d.cartpole:CartPoleJaxVectorEnv",
|
||||||
max_episode_steps=200,
|
max_episode_steps=200,
|
||||||
reward_threshold=195.0,
|
reward_threshold=195.0,
|
||||||
|
disable_env_checker=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
@@ -68,6 +69,7 @@ register(
|
|||||||
vector_entry_point="gymnasium.envs.phys2d.cartpole:CartPoleJaxVectorEnv",
|
vector_entry_point="gymnasium.envs.phys2d.cartpole:CartPoleJaxVectorEnv",
|
||||||
max_episode_steps=500,
|
max_episode_steps=500,
|
||||||
reward_threshold=475.0,
|
reward_threshold=475.0,
|
||||||
|
disable_env_checker=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
@@ -75,6 +77,7 @@ register(
|
|||||||
entry_point="gymnasium.envs.phys2d.pendulum:PendulumJaxEnv",
|
entry_point="gymnasium.envs.phys2d.pendulum:PendulumJaxEnv",
|
||||||
vector_entry_point="gymnasium.envs.phys2d.pendulum:PendulumJaxVectorEnv",
|
vector_entry_point="gymnasium.envs.phys2d.pendulum:PendulumJaxVectorEnv",
|
||||||
max_episode_steps=200,
|
max_episode_steps=200,
|
||||||
|
disable_env_checker=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Box2d
|
# Box2d
|
||||||
@@ -161,11 +164,13 @@ register(
|
|||||||
register(
|
register(
|
||||||
id="tabular/Blackjack-v0",
|
id="tabular/Blackjack-v0",
|
||||||
entry_point="gymnasium.envs.tabular.blackjack:BlackJackJaxEnv",
|
entry_point="gymnasium.envs.tabular.blackjack:BlackJackJaxEnv",
|
||||||
|
disable_env_checker=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id="tabular/CliffWalking-v0",
|
id="tabular/CliffWalking-v0",
|
||||||
entry_point="gymnasium.envs.tabular.cliffwalking:CliffWalkingJaxEnv",
|
entry_point="gymnasium.envs.tabular.cliffwalking:CliffWalkingJaxEnv",
|
||||||
|
disable_env_checker=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@@ -6,14 +6,12 @@ from typing import Any
|
|||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import jax.random as jrng
|
import jax.random as jrng
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium.envs.registration import EnvSpec
|
from gymnasium.envs.registration import EnvSpec
|
||||||
from gymnasium.functional import ActType, FuncEnv, StateType
|
from gymnasium.functional import ActType, FuncEnv, StateType
|
||||||
from gymnasium.utils import seeding
|
from gymnasium.utils import seeding
|
||||||
from gymnasium.vector.utils import batch_space
|
from gymnasium.vector.utils import batch_space
|
||||||
from gymnasium.wrappers.jax_to_numpy import jax_to_numpy
|
|
||||||
|
|
||||||
|
|
||||||
class FunctionalJaxEnv(gym.Env):
|
class FunctionalJaxEnv(gym.Env):
|
||||||
@@ -32,7 +30,8 @@ class FunctionalJaxEnv(gym.Env):
|
|||||||
):
|
):
|
||||||
"""Initialize the environment from a FuncEnv."""
|
"""Initialize the environment from a FuncEnv."""
|
||||||
if metadata is None:
|
if metadata is None:
|
||||||
metadata = {"render_mode": []}
|
# metadata.get("jax", False) can be used downstream to know that the environment returns jax arrays
|
||||||
|
metadata = {"render_mode": [], "jax": True}
|
||||||
|
|
||||||
self.func_env = func_env
|
self.func_env = func_env
|
||||||
|
|
||||||
@@ -45,8 +44,6 @@ class FunctionalJaxEnv(gym.Env):
|
|||||||
|
|
||||||
self.spec = spec
|
self.spec = spec
|
||||||
|
|
||||||
self._is_box_action_space = isinstance(self.action_space, gym.spaces.Box)
|
|
||||||
|
|
||||||
if self.render_mode == "rgb_array":
|
if self.render_mode == "rgb_array":
|
||||||
self.render_state = self.func_env.render_init()
|
self.render_state = self.func_env.render_init()
|
||||||
else:
|
else:
|
||||||
@@ -69,20 +66,10 @@ class FunctionalJaxEnv(gym.Env):
|
|||||||
obs = self.func_env.observation(self.state)
|
obs = self.func_env.observation(self.state)
|
||||||
info = self.func_env.state_info(self.state)
|
info = self.func_env.state_info(self.state)
|
||||||
|
|
||||||
obs = jax_to_numpy(obs)
|
|
||||||
|
|
||||||
return obs, info
|
return obs, info
|
||||||
|
|
||||||
def step(self, action: ActType):
|
def step(self, action: ActType):
|
||||||
"""Steps through the environment using the action."""
|
"""Steps through the environment using the action."""
|
||||||
if self._is_box_action_space:
|
|
||||||
assert isinstance(self.action_space, gym.spaces.Box) # For typing
|
|
||||||
action = np.clip(action, self.action_space.low, self.action_space.high)
|
|
||||||
else: # Discrete
|
|
||||||
# For now we assume jax envs don't use complex spaces
|
|
||||||
err_msg = f"{action!r} ({type(action)}) invalid"
|
|
||||||
assert self.action_space.contains(action), err_msg
|
|
||||||
|
|
||||||
rng, self.rng = jrng.split(self.rng)
|
rng, self.rng = jrng.split(self.rng)
|
||||||
|
|
||||||
next_state = self.func_env.transition(self.state, action, rng)
|
next_state = self.func_env.transition(self.state, action, rng)
|
||||||
@@ -92,8 +79,6 @@ class FunctionalJaxEnv(gym.Env):
|
|||||||
info = self.func_env.transition_info(self.state, action, next_state)
|
info = self.func_env.transition_info(self.state, action, next_state)
|
||||||
self.state = next_state
|
self.state = next_state
|
||||||
|
|
||||||
observation = jax_to_numpy(observation)
|
|
||||||
|
|
||||||
return observation, float(reward), bool(terminated), False, info
|
return observation, float(reward), bool(terminated), False, info
|
||||||
|
|
||||||
def render(self):
|
def render(self):
|
||||||
@@ -153,8 +138,6 @@ class FunctionalJaxVectorEnv(gym.vector.VectorEnv):
|
|||||||
|
|
||||||
self.autoreset_envs = jnp.zeros(self.num_envs, dtype=jnp.bool_)
|
self.autoreset_envs = jnp.zeros(self.num_envs, dtype=jnp.bool_)
|
||||||
|
|
||||||
self._is_box_action_space = isinstance(self.action_space, gym.spaces.Box)
|
|
||||||
|
|
||||||
if self.render_mode == "rgb_array":
|
if self.render_mode == "rgb_array":
|
||||||
self.render_state = self.func_env.render_init()
|
self.render_state = self.func_env.render_init()
|
||||||
else:
|
else:
|
||||||
@@ -183,20 +166,10 @@ class FunctionalJaxVectorEnv(gym.vector.VectorEnv):
|
|||||||
|
|
||||||
self.steps = jnp.zeros(self.num_envs, dtype=jnp.int32)
|
self.steps = jnp.zeros(self.num_envs, dtype=jnp.int32)
|
||||||
|
|
||||||
obs = jax_to_numpy(obs)
|
|
||||||
|
|
||||||
return obs, info
|
return obs, info
|
||||||
|
|
||||||
def step(self, action: ActType):
|
def step(self, action: ActType):
|
||||||
"""Steps through the environment using the action."""
|
"""Steps through the environment using the action."""
|
||||||
if self._is_box_action_space:
|
|
||||||
assert isinstance(self.action_space, gym.spaces.Box) # For typing
|
|
||||||
action = np.clip(action, self.action_space.low, self.action_space.high)
|
|
||||||
else: # Discrete
|
|
||||||
# For now we assume jax envs don't use complex spaces
|
|
||||||
assert self.action_space.contains(
|
|
||||||
action
|
|
||||||
), f"{action!r} ({type(action)}) invalid"
|
|
||||||
self.steps += 1
|
self.steps += 1
|
||||||
|
|
||||||
rng, self.rng = jrng.split(self.rng)
|
rng, self.rng = jrng.split(self.rng)
|
||||||
@@ -232,12 +205,6 @@ class FunctionalJaxVectorEnv(gym.vector.VectorEnv):
|
|||||||
self.autoreset_envs = done
|
self.autoreset_envs = done
|
||||||
|
|
||||||
observation = self.func_env.observation(next_state)
|
observation = self.func_env.observation(next_state)
|
||||||
observation = jax_to_numpy(observation)
|
|
||||||
|
|
||||||
reward = jax_to_numpy(reward)
|
|
||||||
|
|
||||||
terminated = jax_to_numpy(terminated)
|
|
||||||
truncated = jax_to_numpy(truncated)
|
|
||||||
|
|
||||||
self.state = next_state
|
self.state = next_state
|
||||||
|
|
||||||
|
@@ -245,7 +245,7 @@ class CartPoleFunctional(
|
|||||||
class CartPoleJaxEnv(FunctionalJaxEnv, EzPickle):
|
class CartPoleJaxEnv(FunctionalJaxEnv, EzPickle):
|
||||||
"""Jax-based implementation of the CartPole environment."""
|
"""Jax-based implementation of the CartPole environment."""
|
||||||
|
|
||||||
metadata = {"render_modes": ["rgb_array"], "render_fps": 50}
|
metadata = {"render_modes": ["rgb_array"], "render_fps": 50, "jax": True}
|
||||||
|
|
||||||
def __init__(self, render_mode: str | None = None, **kwargs: Any):
|
def __init__(self, render_mode: str | None = None, **kwargs: Any):
|
||||||
"""Constructor for the CartPole where the kwargs are applied to the functional environment."""
|
"""Constructor for the CartPole where the kwargs are applied to the functional environment."""
|
||||||
@@ -265,7 +265,7 @@ class CartPoleJaxEnv(FunctionalJaxEnv, EzPickle):
|
|||||||
class CartPoleJaxVectorEnv(FunctionalJaxVectorEnv, EzPickle):
|
class CartPoleJaxVectorEnv(FunctionalJaxVectorEnv, EzPickle):
|
||||||
"""Jax-based implementation of the vectorized CartPole environment."""
|
"""Jax-based implementation of the vectorized CartPole environment."""
|
||||||
|
|
||||||
metadata = {"render_modes": ["rgb_array"], "render_fps": 50}
|
metadata = {"render_modes": ["rgb_array"], "render_fps": 50, "jax": True}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@@ -223,7 +223,7 @@ class PendulumFunctional(
|
|||||||
class PendulumJaxEnv(FunctionalJaxEnv, EzPickle):
|
class PendulumJaxEnv(FunctionalJaxEnv, EzPickle):
|
||||||
"""Jax-based pendulum environment using the functional version as base."""
|
"""Jax-based pendulum environment using the functional version as base."""
|
||||||
|
|
||||||
metadata = {"render_modes": ["rgb_array"], "render_fps": 30}
|
metadata = {"render_modes": ["rgb_array"], "render_fps": 30, "jax": True}
|
||||||
|
|
||||||
def __init__(self, render_mode: str | None = None, **kwargs: Any):
|
def __init__(self, render_mode: str | None = None, **kwargs: Any):
|
||||||
"""Constructor where the kwargs are passed to the base environment to modify the parameters."""
|
"""Constructor where the kwargs are passed to the base environment to modify the parameters."""
|
||||||
@@ -242,7 +242,7 @@ class PendulumJaxEnv(FunctionalJaxEnv, EzPickle):
|
|||||||
class PendulumJaxVectorEnv(FunctionalJaxVectorEnv, EzPickle):
|
class PendulumJaxVectorEnv(FunctionalJaxVectorEnv, EzPickle):
|
||||||
"""Jax-based implementation of the vectorized CartPole environment."""
|
"""Jax-based implementation of the vectorized CartPole environment."""
|
||||||
|
|
||||||
metadata = {"render_modes": ["rgb_array"], "render_fps": 50}
|
metadata = {"render_modes": ["rgb_array"], "render_fps": 50, "jax": True}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@@ -496,7 +496,7 @@ class BlackjackFunctional(
|
|||||||
class BlackJackJaxEnv(FunctionalJaxEnv, EzPickle):
|
class BlackJackJaxEnv(FunctionalJaxEnv, EzPickle):
|
||||||
"""A Gymnasium Env wrapper for the functional blackjack env."""
|
"""A Gymnasium Env wrapper for the functional blackjack env."""
|
||||||
|
|
||||||
metadata = {"render_modes": ["rgb_array"], "render_fps": 50}
|
metadata = {"render_modes": ["rgb_array"], "render_fps": 50, "jax": True}
|
||||||
|
|
||||||
def __init__(self, render_mode: Optional[str] = None, **kwargs):
|
def __init__(self, render_mode: Optional[str] = None, **kwargs):
|
||||||
"""Initializes Gym wrapper for blackjack functional env."""
|
"""Initializes Gym wrapper for blackjack functional env."""
|
||||||
|
@@ -358,7 +358,7 @@ class CliffWalkingFunctional(
|
|||||||
class CliffWalkingJaxEnv(FunctionalJaxEnv, EzPickle):
|
class CliffWalkingJaxEnv(FunctionalJaxEnv, EzPickle):
|
||||||
"""A Gymnasium Env wrapper for the functional cliffwalking env."""
|
"""A Gymnasium Env wrapper for the functional cliffwalking env."""
|
||||||
|
|
||||||
metadata = {"render_modes": ["rgb_array"], "render_fps": 50}
|
metadata = {"render_modes": ["rgb_array"], "render_fps": 50, "jax": True}
|
||||||
|
|
||||||
def __init__(self, render_mode: str | None = None, **kwargs):
|
def __init__(self, render_mode: str | None = None, **kwargs):
|
||||||
"""Initializes Gym wrapper for cliffwalking functional env."""
|
"""Initializes Gym wrapper for cliffwalking functional env."""
|
||||||
|
@@ -367,6 +367,11 @@ def check_env(
|
|||||||
f"The environment ({env}) is different from the unwrapped version ({env.unwrapped}). This could effect the environment checker as the environment most likely has a wrapper applied to it. We recommend using the raw environment for `check_env` using `env.unwrapped`."
|
f"The environment ({env}) is different from the unwrapped version ({env.unwrapped}). This could effect the environment checker as the environment most likely has a wrapper applied to it. We recommend using the raw environment for `check_env` using `env.unwrapped`."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if env.metadata.get("jax", False):
|
||||||
|
env = gym.wrappers.JaxToNumpy(env)
|
||||||
|
elif env.metadata.get("torch", False):
|
||||||
|
env = gym.wrappers.JaxToTorch(env)
|
||||||
|
|
||||||
# ============= Check the spaces (observation and action) ================
|
# ============= Check the spaces (observation and action) ================
|
||||||
if not hasattr(env, "action_space"):
|
if not hasattr(env, "action_space"):
|
||||||
raise AttributeError(
|
raise AttributeError(
|
||||||
|
@@ -120,7 +120,7 @@ def test_vectorized(env_class):
|
|||||||
|
|
||||||
obs, info = env.reset(seed=0)
|
obs, info = env.reset(seed=0)
|
||||||
assert obs.shape == (10,) + env.single_observation_space.shape
|
assert obs.shape == (10,) + env.single_observation_space.shape
|
||||||
assert isinstance(obs, np.ndarray)
|
assert isinstance(obs, jax.Array)
|
||||||
assert isinstance(info, dict)
|
assert isinstance(info, dict)
|
||||||
|
|
||||||
for t in range(100):
|
for t in range(100):
|
||||||
@@ -128,13 +128,13 @@ def test_vectorized(env_class):
|
|||||||
obs, reward, terminated, truncated, info = env.step(action)
|
obs, reward, terminated, truncated, info = env.step(action)
|
||||||
|
|
||||||
assert obs.shape == (10,) + env.single_observation_space.shape
|
assert obs.shape == (10,) + env.single_observation_space.shape
|
||||||
assert isinstance(obs, np.ndarray)
|
assert isinstance(obs, jax.Array)
|
||||||
assert reward.shape == (10,)
|
assert reward.shape == (10,)
|
||||||
assert isinstance(reward, np.ndarray)
|
assert isinstance(reward, jax.Array)
|
||||||
assert terminated.shape == (10,)
|
assert terminated.shape == (10,)
|
||||||
assert isinstance(terminated, np.ndarray)
|
assert isinstance(terminated, jax.Array)
|
||||||
assert truncated.shape == (10,)
|
assert truncated.shape == (10,)
|
||||||
assert isinstance(truncated, np.ndarray)
|
assert isinstance(truncated, jax.Array)
|
||||||
assert isinstance(info, dict)
|
assert isinstance(info, dict)
|
||||||
|
|
||||||
# These were removed in the new autoreset order
|
# These were removed in the new autoreset order
|
||||||
|
@@ -70,6 +70,9 @@ def test_discrete_actions_out_of_bound(env: gym.Env):
|
|||||||
Args:
|
Args:
|
||||||
env (gym.Env): the gymnasium environment
|
env (gym.Env): the gymnasium environment
|
||||||
"""
|
"""
|
||||||
|
if env.metadata.get("jax", False):
|
||||||
|
return
|
||||||
|
|
||||||
assert isinstance(env.action_space, spaces.Discrete)
|
assert isinstance(env.action_space, spaces.Discrete)
|
||||||
upper_bound = env.action_space.start + env.action_space.n - 1
|
upper_bound = env.action_space.start + env.action_space.n - 1
|
||||||
|
|
||||||
@@ -102,6 +105,9 @@ def test_box_actions_out_of_bound(env: gym.Env):
|
|||||||
Args:
|
Args:
|
||||||
env (gym.Env): the gymnasium environment
|
env (gym.Env): the gymnasium environment
|
||||||
"""
|
"""
|
||||||
|
if env.metadata.get("jax", False):
|
||||||
|
return
|
||||||
|
|
||||||
env.reset(seed=42)
|
env.reset(seed=42)
|
||||||
|
|
||||||
assert env.spec is not None
|
assert env.spec is not None
|
||||||
|
@@ -7,11 +7,7 @@ import pytest
|
|||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium.envs.registration import EnvSpec
|
from gymnasium.envs.registration import EnvSpec
|
||||||
from gymnasium.utils.env_checker import check_env, data_equivalence
|
from gymnasium.utils.env_checker import check_env, data_equivalence
|
||||||
from tests.envs.utils import (
|
from tests.envs.utils import all_testing_env_specs, all_testing_initialised_envs
|
||||||
all_testing_env_specs,
|
|
||||||
all_testing_initialised_envs,
|
|
||||||
assert_equals,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# This runs a smoketest on each official registered env. We may want
|
# This runs a smoketest on each official registered env. We may want
|
||||||
@@ -42,6 +38,7 @@ def test_all_env_api(spec):
|
|||||||
"""Check that all environments pass the environment checker with no warnings other than the expected."""
|
"""Check that all environments pass the environment checker with no warnings other than the expected."""
|
||||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||||
env = spec.make().unwrapped
|
env = spec.make().unwrapped
|
||||||
|
|
||||||
check_env(env, skip_render_check=True)
|
check_env(env, skip_render_check=True)
|
||||||
|
|
||||||
env.close()
|
env.close()
|
||||||
@@ -98,9 +95,13 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
|
|||||||
env_1 = env_spec.make(disable_env_checker=True)
|
env_1 = env_spec.make(disable_env_checker=True)
|
||||||
env_2 = env_spec.make(disable_env_checker=True)
|
env_2 = env_spec.make(disable_env_checker=True)
|
||||||
|
|
||||||
|
if env_1.metadata.get("jax", False):
|
||||||
|
env_1 = gym.wrappers.JaxToNumpy(env_1)
|
||||||
|
env_2 = gym.wrappers.JaxToNumpy(env_2)
|
||||||
|
|
||||||
initial_obs_1, initial_info_1 = env_1.reset(seed=SEED)
|
initial_obs_1, initial_info_1 = env_1.reset(seed=SEED)
|
||||||
initial_obs_2, initial_info_2 = env_2.reset(seed=SEED)
|
initial_obs_2, initial_info_2 = env_2.reset(seed=SEED)
|
||||||
assert_equals(initial_obs_1, initial_obs_2)
|
assert data_equivalence(initial_obs_1, initial_obs_2, exact=True)
|
||||||
|
|
||||||
env_1.action_space.seed(SEED)
|
env_1.action_space.seed(SEED)
|
||||||
|
|
||||||
@@ -111,7 +112,9 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
|
|||||||
obs_1, rew_1, terminated_1, truncated_1, info_1 = env_1.step(action)
|
obs_1, rew_1, terminated_1, truncated_1, info_1 = env_1.step(action)
|
||||||
obs_2, rew_2, terminated_2, truncated_2, info_2 = env_2.step(action)
|
obs_2, rew_2, terminated_2, truncated_2, info_2 = env_2.step(action)
|
||||||
|
|
||||||
assert_equals(obs_1, obs_2, f"[{time_step}] ")
|
assert data_equivalence(
|
||||||
|
obs_1, obs_2, exact=True
|
||||||
|
), f"[{time_step}] obs_1={obs_1}, obs_2={obs_2}"
|
||||||
assert env_1.observation_space.contains(
|
assert env_1.observation_space.contains(
|
||||||
obs_1
|
obs_1
|
||||||
) # obs_2 verified by previous assertion
|
) # obs_2 verified by previous assertion
|
||||||
@@ -123,7 +126,9 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
|
|||||||
assert (
|
assert (
|
||||||
truncated_1 == truncated_2
|
truncated_1 == truncated_2
|
||||||
), f"[{time_step}] done 1={truncated_1}, done 2={truncated_2}"
|
), f"[{time_step}] done 1={truncated_1}, done 2={truncated_2}"
|
||||||
assert_equals(info_1, info_2, f"[{time_step}] ")
|
assert data_equivalence(
|
||||||
|
info_1, info_2, exact=True
|
||||||
|
), f"[{time_step}] info_1={info_1}, info_2={info_2}"
|
||||||
|
|
||||||
if (
|
if (
|
||||||
terminated_1 or truncated_1
|
terminated_1 or truncated_1
|
||||||
@@ -141,6 +146,9 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
|
|||||||
ids=[env.spec.id for env in all_testing_initialised_envs if env.spec is not None],
|
ids=[env.spec.id for env in all_testing_initialised_envs if env.spec is not None],
|
||||||
)
|
)
|
||||||
def test_pickle_env(env: gym.Env):
|
def test_pickle_env(env: gym.Env):
|
||||||
|
if env.metadata.get("jax", False):
|
||||||
|
env = gym.wrappers.JaxToNumpy(env)
|
||||||
|
|
||||||
pickled_env = pickle.loads(pickle.dumps(env))
|
pickled_env = pickle.loads(pickle.dumps(env))
|
||||||
|
|
||||||
data_equivalence(env.reset(), pickled_env.reset())
|
data_equivalence(env.reset(), pickled_env.reset())
|
||||||
|
@@ -1,8 +1,6 @@
|
|||||||
"""Finds all the specs that we can test with"""
|
"""Finds all the specs that we can test with"""
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium import logger
|
from gymnasium import logger
|
||||||
from gymnasium.envs.registration import EnvSpec
|
from gymnasium.envs.registration import EnvSpec
|
||||||
@@ -55,28 +53,3 @@ gym_testing_env_specs: List[EnvSpec] = [
|
|||||||
for ep in ["box2d", "classic_control", "toy_text"]
|
for ep in ["box2d", "classic_control", "toy_text"]
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def assert_equals(a, b, prefix=None):
|
|
||||||
"""Assert equality of data structures `a` and `b`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
a: first data structure
|
|
||||||
b: second data structure
|
|
||||||
prefix: prefix for failed assertion message for types and dicts
|
|
||||||
"""
|
|
||||||
assert type(a) is type(b), f"{prefix}Differing types: {a} and {b}"
|
|
||||||
if isinstance(a, dict):
|
|
||||||
assert list(a.keys()) == list(b.keys()), f"{prefix}Key sets differ: {a} and {b}"
|
|
||||||
|
|
||||||
for k in a.keys():
|
|
||||||
v_a = a[k]
|
|
||||||
v_b = b[k]
|
|
||||||
assert_equals(v_a, v_b)
|
|
||||||
elif isinstance(a, np.ndarray):
|
|
||||||
np.testing.assert_array_equal(a, b)
|
|
||||||
elif isinstance(a, tuple):
|
|
||||||
for elem_from_a, elem_from_b in zip(a, b):
|
|
||||||
assert_equals(elem_from_a, elem_from_b)
|
|
||||||
else:
|
|
||||||
assert a == b
|
|
||||||
|
@@ -19,6 +19,9 @@ from tests.testing_env import GenericTestEnv
|
|||||||
ids=[env.spec.id for env in all_testing_initialised_envs if env.spec is not None],
|
ids=[env.spec.id for env in all_testing_initialised_envs if env.spec is not None],
|
||||||
)
|
)
|
||||||
def test_passive_checker_wrapper_warnings(env):
|
def test_passive_checker_wrapper_warnings(env):
|
||||||
|
if env.spec is not None and env.spec.disable_env_checker:
|
||||||
|
return
|
||||||
|
|
||||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||||
checker_env = PassiveEnvChecker(env)
|
checker_env = PassiveEnvChecker(env)
|
||||||
checker_env.reset()
|
checker_env.reset()
|
||||||
|
Reference in New Issue
Block a user