mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-22 07:02:19 +00:00
Merge v1.0.0 (#682)
Co-authored-by: Kallinteris Andreas <30759571+Kallinteris-Andreas@users.noreply.github.com> Co-authored-by: Jet <38184875+jjshoots@users.noreply.github.com> Co-authored-by: Omar Younis <42100908+younik@users.noreply.github.com>
This commit is contained in:
@@ -1,9 +1,91 @@
|
||||
"""Utility functions for testing the wrappers."""
|
||||
from __future__ import annotations
|
||||
|
||||
import gymnasium as gym
|
||||
from tests.spaces.utils import TESTING_SPACES, TESTING_SPACES_IDS
|
||||
from tests.testing_env import GenericTestEnv
|
||||
|
||||
|
||||
SEED = 42
|
||||
ENV_ID = "CartPole-v1"
|
||||
DISCRETE_ACTION = 0
|
||||
NUM_ENVS = 3
|
||||
NUM_STEPS = 20
|
||||
|
||||
|
||||
def record_obs_reset(self: gym.Env, seed=None, options: dict = None):
|
||||
"""Records and uses an observation passed through options."""
|
||||
return options["obs"], {"obs": options["obs"]}
|
||||
|
||||
|
||||
def record_random_obs_reset(self: gym.Env, seed=None, options=None):
|
||||
"""Records random observation generated by the environment."""
|
||||
obs = self.observation_space.sample()
|
||||
return obs, {"obs": obs}
|
||||
|
||||
|
||||
def record_action_step(self: gym.Env, action):
|
||||
"""Records the actions passed to the environment."""
|
||||
return 0, 0, False, False, {"action": action}
|
||||
|
||||
|
||||
def record_random_obs_step(self: gym.Env, action):
|
||||
"""Records the observation generated by the environment."""
|
||||
obs = self.observation_space.sample()
|
||||
return obs, 0, False, False, {"obs": obs}
|
||||
|
||||
|
||||
def record_action_as_obs_step(self: gym.Env, action):
|
||||
"""Uses the action as the observation."""
|
||||
return action, 0, False, False, {"obs": action}
|
||||
|
||||
|
||||
def record_action_as_record_step(self: gym.Env, action):
|
||||
"""Uses the action as the reward."""
|
||||
return 0, action, False, False, {"reward": action}
|
||||
|
||||
|
||||
def check_obs(
|
||||
env: gym.Env,
|
||||
wrapped_env: gym.Wrapper,
|
||||
transformed_obs,
|
||||
original_obs,
|
||||
strict: bool = True,
|
||||
):
|
||||
"""Checks that the original and transformed observations using the environment and wrapped environment.
|
||||
|
||||
Args:
|
||||
env: The base environment
|
||||
wrapped_env: The wrapped environment
|
||||
transformed_obs: The transformed observation by the wrapped environment
|
||||
original_obs: The original observation by the base environment.
|
||||
strict: If to check that the observations aren't contained in the other environment.
|
||||
"""
|
||||
assert (
|
||||
transformed_obs in wrapped_env.observation_space
|
||||
), f"{transformed_obs}, {wrapped_env.observation_space}"
|
||||
assert (
|
||||
original_obs in env.observation_space
|
||||
), f"{original_obs}, {env.observation_space}"
|
||||
|
||||
if strict:
|
||||
assert (
|
||||
transformed_obs not in env.observation_space
|
||||
), f"{transformed_obs}, {env.observation_space}"
|
||||
assert (
|
||||
original_obs not in wrapped_env.observation_space
|
||||
), f"{original_obs}, {wrapped_env.observation_space}"
|
||||
|
||||
|
||||
TESTING_OBS_ENVS = [GenericTestEnv(observation_space=space) for space in TESTING_SPACES]
|
||||
TESTING_OBS_ENVS_IDS = TESTING_SPACES_IDS
|
||||
|
||||
TESTING_ACTION_ENVS = [GenericTestEnv(action_space=space) for space in TESTING_SPACES]
|
||||
TESTING_ACTION_ENVS_IDS = TESTING_SPACES_IDS
|
||||
|
||||
|
||||
def has_wrapper(wrapped_env: gym.Env, wrapper_type: type[gym.Wrapper]) -> bool:
|
||||
"""Checks if the wrapper type is within the wrapped environment stack."""
|
||||
while isinstance(wrapped_env, gym.Wrapper):
|
||||
if isinstance(wrapped_env, wrapper_type):
|
||||
return True
|
||||
|
Reference in New Issue
Block a user