diff --git a/gymnasium/core.py b/gymnasium/core.py index 789834796..51cbeaadf 100644 --- a/gymnasium/core.py +++ b/gymnasium/core.py @@ -11,7 +11,7 @@ from gymnasium.utils import RecordConstructorArgs, seeding if TYPE_CHECKING: - from gymnasium.envs.registration import EnvSpec + from gymnasium.envs.registration import EnvSpec, WrapperSpec ObsType = TypeVar("ObsType") ActType = TypeVar("ActType") @@ -266,6 +266,8 @@ class Wrapper( self._reward_range: tuple[SupportsFloat, SupportsFloat] | None = None self._metadata: dict[str, Any] | None = None + self._cached_spec: EnvSpec | None = None + def __getattr__(self, name: str) -> Any: """Returns an attribute with ``name``, unless ``name`` starts with an underscore.""" if name == "_np_random": @@ -279,11 +281,11 @@ class Wrapper( @property def spec(self) -> EnvSpec | None: """Returns the :attr:`Env` :attr:`spec` attribute with the `WrapperSpec` if the wrapper inherits from `EzPickle`.""" + if self._cached_spec is not None: + return self._cached_spec + env_spec = self.env.spec - if env_spec is not None: - from gymnasium.envs.registration import WrapperSpec - # See if the wrapper inherits from `RecordConstructorArgs` then add the kwargs otherwise use `None` for the wrapper kwargs. This will raise an error in `make` if isinstance(self, RecordConstructorArgs): kwargs = getattr(self, "_saved_kwargs") @@ -293,6 +295,8 @@ class Wrapper( else: kwargs = None + from gymnasium.envs.registration import WrapperSpec + wrapper_spec = WrapperSpec( name=self.class_name(), entry_point=f"{self.__module__}:{type(self).__name__}", @@ -301,10 +305,22 @@ class Wrapper( # to avoid reference issues we deepcopy the prior environments spec and add the new information env_spec = deepcopy(env_spec) - env_spec.applied_wrappers += (wrapper_spec,) + env_spec.additional_wrappers += (wrapper_spec,) + self._cached_spec = env_spec return env_spec + @classmethod + def wrapper_spec(cls, **kwargs: Any) -> WrapperSpec: + """Generates a `WrapperSpec` for the wrappers.""" + from gymnasium.envs.registration import WrapperSpec + + return WrapperSpec( + name=cls.class_name(), + entry_point=f"{cls.__module__}:{cls.__name__}", + kwargs=kwargs, + ) + @classmethod def class_name(cls) -> str: """Returns the class name of the wrapper.""" diff --git a/gymnasium/envs/registration.py b/gymnasium/envs/registration.py index 8852ea098..5c7e032c5 100644 --- a/gymnasium/envs/registration.py +++ b/gymnasium/envs/registration.py @@ -99,7 +99,7 @@ class EnvSpec: * **autoreset**: If to automatically reset the environment on episode end * **disable_env_checker**: If to disable the environment checker wrapper in :meth:`gymnasium.make`, by default False (runs the environment checker) * **kwargs**: Additional keyword arguments passed to the environment during initialisation - * **applied_wrappers**: A tuple of applied wrappers (WrapperSpec) + * **additional_wrappers**: A tuple of additional wrappers applied to the environment (WrapperSpec) * **vector_entry_point**: The location of the vectorized environment to create from """ @@ -126,7 +126,7 @@ class EnvSpec: version: int | None = field(init=False) # applied wrappers - applied_wrappers: tuple[WrapperSpec, ...] = field(init=False, default_factory=tuple) + additional_wrappers: tuple[WrapperSpec, ...] = field(default_factory=tuple) # Vectorized environment entry point vector_entry_point: VectorEnvCreator | str | None = field(default=None) @@ -187,7 +187,7 @@ class EnvSpec: parsed_env_spec = json.loads(json_env_spec) applied_wrapper_specs: list[WrapperSpec] = [] - for wrapper_spec_json in parsed_env_spec.pop("applied_wrappers"): + for wrapper_spec_json in parsed_env_spec.pop("additional_wrappers"): try: applied_wrapper_specs.append(WrapperSpec(**wrapper_spec_json)) except Exception as e: @@ -197,7 +197,7 @@ class EnvSpec: try: env_spec = EnvSpec(**parsed_env_spec) - env_spec.applied_wrappers = tuple(applied_wrapper_specs) + env_spec.additional_wrappers = tuple(applied_wrapper_specs) except Exception as e: raise ValueError( f"An issue occurred when trying to make {parsed_env_spec} an EnvSpec" @@ -241,9 +241,9 @@ class EnvSpec: if print_all or self.apply_api_compatibility is not False: output += f"\napplied_api_compatibility={self.apply_api_compatibility}" - if print_all or self.applied_wrappers: + if print_all or self.additional_wrappers: wrapper_output: list[str] = [] - for wrapper_spec in self.applied_wrappers: + for wrapper_spec in self.additional_wrappers: if include_entry_points: wrapper_output.append( f"\n\tname={wrapper_spec.name}, entry_point={wrapper_spec.entry_point}, kwargs={wrapper_spec.kwargs}" @@ -254,9 +254,9 @@ class EnvSpec: ) if len(wrapper_output) == 0: - output += "\napplied_wrappers=[]" + output += "\nadditional_wrappers=[]" else: - output += f"\napplied_wrappers=[{','.join(wrapper_output)}\n]" + output += f"\nadditional_wrappers=[{','.join(wrapper_output)}\n]" if disable_print: return output @@ -555,161 +555,6 @@ def load_env_creator(name: str) -> EnvCreator | VectorEnvCreator: return fn -def _create_from_env_spec( - env_spec: EnvSpec, - kwargs: dict[str, Any], -) -> Env: - """Recreates an environment spec using a list of wrapper specs.""" - if callable(env_spec.entry_point): - env_creator = env_spec.entry_point - else: - env_creator: EnvCreator = load_env_creator(env_spec.entry_point) - - # Create the environment - env: Env = env_creator(**env_spec.kwargs, **kwargs) - - # Set the `EnvSpec` to the environment - new_env_spec = copy.deepcopy(env_spec) - new_env_spec.applied_wrappers = () - new_env_spec.kwargs.update(kwargs) - env.unwrapped.spec = new_env_spec - - # Check if the environment spec - assert env.spec is not None # this is for pyright - num_prior_wrappers = len(env.spec.applied_wrappers) - if env_spec.applied_wrappers[:num_prior_wrappers] != env.spec.applied_wrappers: - for env_spec_wrapper_spec, recreated_wrapper_spec in zip( - env_spec.applied_wrappers, env.spec.applied_wrappers - ): - raise ValueError( - f"The environment's wrapper spec {recreated_wrapper_spec} is different from the saved `EnvSpec` applied_wrappers {env_spec_wrapper_spec}" - ) - - for wrapper_spec in env_spec.applied_wrappers[num_prior_wrappers:]: - if wrapper_spec.kwargs is None: - raise ValueError( - f"{wrapper_spec.name} wrapper does not inherit from `gymnasium.utils.RecordConstructorArgs`, therefore, the wrapper cannot be recreated." - ) - - env = load_env_creator(wrapper_spec.entry_point)(env=env, **wrapper_spec.kwargs) - - return env - - -def _create_from_env_id( - env_spec: EnvSpec, - kwargs: dict[str, Any], - max_episode_steps: int | None = None, - autoreset: bool = False, - apply_api_compatibility: bool | None = None, - disable_env_checker: bool | None = None, -) -> Env: - """Creates an environment based on the `env_spec` along with wrapper options. See `make` for their meaning.""" - spec_kwargs = copy.deepcopy(env_spec.kwargs) - spec_kwargs.update(kwargs) - - # Load the environment creator - if env_spec.entry_point is None: - raise error.Error(f"{env_spec.id} registered but entry_point is not specified") - elif callable(env_spec.entry_point): - env_creator = env_spec.entry_point - else: - # Assume it's a string - env_creator = load_env_creator(env_spec.entry_point) - - # Determine if to use the rendering - render_modes: list[str] | None = None - if hasattr(env_creator, "metadata"): - _check_metadata(env_creator.metadata) - render_modes = env_creator.metadata.get("render_modes") - mode = spec_kwargs.get("render_mode") - apply_human_rendering = False - apply_render_collection = False - - # If mode is not valid, try applying HumanRendering/RenderCollection wrappers - if mode is not None and render_modes is not None and mode not in render_modes: - displayable_modes = {"rgb_array", "rgb_array_list"}.intersection(render_modes) - if mode == "human" and len(displayable_modes) > 0: - logger.warn( - "You are trying to use 'human' rendering for an environment that doesn't natively support it. " - "The HumanRendering wrapper is being applied to your environment." - ) - spec_kwargs["render_mode"] = displayable_modes.pop() - apply_human_rendering = True - elif mode.endswith("_list") and mode[: -len("_list")] in render_modes: - spec_kwargs["render_mode"] = mode[: -len("_list")] - apply_render_collection = True - else: - logger.warn( - f"The environment is being initialised with render_mode={mode!r} " - f"that is not in the possible render_modes ({render_modes})." - ) - - if apply_api_compatibility or ( - apply_api_compatibility is None and env_spec.apply_api_compatibility - ): - # If we use the compatibility layer, we treat the render mode explicitly and don't pass it to the env creator - render_mode = spec_kwargs.pop("render_mode", None) - else: - render_mode = None - - try: - env = env_creator(**spec_kwargs) - except TypeError as e: - if ( - str(e).find("got an unexpected keyword argument 'render_mode'") >= 0 - and apply_human_rendering - ): - raise error.Error( - f"You passed render_mode='human' although {env_spec.id} doesn't implement human-rendering natively. " - "Gym tried to apply the HumanRendering wrapper but it looks like your environment is using the old " - "rendering API, which is not supported by the HumanRendering wrapper." - ) from e - else: - raise e - - # Copies the environment creation specification and kwargs to add to the environment specification details - env_spec = copy.deepcopy(env_spec) - env_spec.kwargs = spec_kwargs - env.unwrapped.spec = env_spec - - # Add step API wrapper - if apply_api_compatibility is True or ( - apply_api_compatibility is None and env_spec.apply_api_compatibility is True - ): - env = EnvCompatibility(env, render_mode) - - # Run the environment checker as the lowest level wrapper - if disable_env_checker is False or ( - disable_env_checker is None and env_spec.disable_env_checker is False - ): - env = PassiveEnvChecker(env) - - # Add the order enforcing wrapper - if env_spec.order_enforce: - env = OrderEnforcing(env) - - # Add the time limit wrapper - if max_episode_steps is not None: - assert env.unwrapped.spec is not None # for pyright - env.unwrapped.spec.max_episode_steps = max_episode_steps - env = TimeLimit(env, max_episode_steps) - elif env_spec.max_episode_steps is not None: - env = TimeLimit(env, env_spec.max_episode_steps) - - # Add the auto-reset wrapper - if autoreset: - env = AutoResetWrapper(env) - - # Add human rendering wrapper - if apply_human_rendering: - env = HumanRendering(env) - elif apply_render_collection: - env = RenderCollection(env) - - return env - - def load_plugin_envs(entry_point: str = "gymnasium.envs"): """Load modules (plugins) using the gymnasium entry points in order to register external module's environments on ``import gymnasium``. @@ -779,6 +624,7 @@ def register( autoreset: bool = False, disable_env_checker: bool = False, apply_api_compatibility: bool = False, + additional_wrappers: tuple[WrapperSpec, ...] = (), vector_entry_point: VectorEnvCreator | str | None = None, **kwargs: Any, ): @@ -801,6 +647,7 @@ def register( disable_env_checker: If to disable the :class:`gymnasium.wrappers.PassiveEnvChecker` to the environment. apply_api_compatibility: If to apply the :class:`gymnasium.wrappers.StepAPICompatibility` wrapper to the environment. Use if the environment is implemented in the gym v0.21 environment API. + additional_wrappers: Additional wrappers to apply the environment. vector_entry_point: The entry point for creating the vector environment **kwargs: arbitrary keyword arguments which are passed to the environment constructor on initialisation. """ @@ -835,8 +682,9 @@ def register( order_enforce=order_enforce, autoreset=autoreset, disable_env_checker=disable_env_checker, - **kwargs, apply_api_compatibility=apply_api_compatibility, + **kwargs, + additional_wrappers=additional_wrappers, vector_entry_point=vector_entry_point, ) _check_spec_register(new_spec) @@ -849,7 +697,7 @@ def register( def make( id: str | EnvSpec, max_episode_steps: int | None = None, - autoreset: bool = False, + autoreset: bool | None = None, apply_api_compatibility: bool | None = None, disable_env_checker: bool | None = None, **kwargs: Any, @@ -878,32 +726,12 @@ def make( Error: If the ``id`` doesn't exist in the :attr:`registry` """ if isinstance(id, EnvSpec): - if hasattr(id, "applied_wrappers") and id.applied_wrappers is not None: - if max_episode_steps is not None: - logger.warn( - f"For `gymnasium.make` with an `EnvSpec`, the `max_episode_step` parameter is ignored, use `gym.make({id.id}, max_episode_steps={max_episode_steps})` and any additional wrappers" - ) - if autoreset is True: - logger.warn( - f"For `gymnasium.make` with an `EnvSpec`, the `autoreset` parameter is ignored, use `gym.make({id.id}, autoreset={autoreset})` and any additional wrappers" - ) - if apply_api_compatibility is not None: - logger.warn( - f"For `gymnasium.make` with an `EnvSpec`, the `apply_api_compatibility` parameter is ignored, use `gym.make({id.id}, apply_api_compatibility={apply_api_compatibility})` and any additional wrappers" - ) - if disable_env_checker is not None: - logger.warn( - f"For `gymnasium.make` with an `EnvSpec`, the `disable_env_checker` parameter is ignored, use `gym.make({id.id}, disable_env_checker={disable_env_checker})` and any additional wrappers" - ) - - return _create_from_env_spec( - id, - kwargs, - ) - else: - raise ValueError( - f"The EnvSpec used does not contain `applied_wrappers` parameters or is `None`. Expected to be a tuple, actually {id}." + env_spec = id + if not hasattr(env_spec, "additional_wrappers"): + logger.warn( + f"The env spec passed to `make` does not have a `additional_wrappers`, set it to an empty tuple. Env_spec={env_spec}" ) + env_spec.additional_wrappers = () else: # For string id's, load the environment spec from the registry then make the environment spec assert isinstance(id, str) @@ -911,14 +739,150 @@ def make( # The environment name can include an unloaded module in "module:env_name" style env_spec = _find_spec(id) - return _create_from_env_id( - env_spec, - kwargs, - max_episode_steps=max_episode_steps, - autoreset=autoreset, - apply_api_compatibility=apply_api_compatibility, - disable_env_checker=disable_env_checker, - ) + assert isinstance(env_spec, EnvSpec) + + # Update the env spec kwargs with the `make` kwargs + env_spec_kwargs = copy.deepcopy(env_spec.kwargs) + env_spec_kwargs.update(kwargs) + + # Load the environment creator + if env_spec.entry_point is None: + raise error.Error(f"{env_spec.id} registered but entry_point is not specified") + elif callable(env_spec.entry_point): + env_creator = env_spec.entry_point + else: + # Assume it's a string + env_creator = load_env_creator(env_spec.entry_point) + + # Determine if to use the rendering + render_modes: list[str] | None = None + if hasattr(env_creator, "metadata"): + _check_metadata(env_creator.metadata) + render_modes = env_creator.metadata.get("render_modes") + render_mode = env_spec_kwargs.get("render_mode") + apply_human_rendering = False + apply_render_collection = False + + # If mode is not valid, try applying HumanRendering/RenderCollection wrappers + if ( + render_mode is not None + and render_modes is not None + and render_mode not in render_modes + ): + displayable_modes = {"rgb_array", "rgb_array_list"}.intersection(render_modes) + if render_mode == "human" and len(displayable_modes) > 0: + logger.warn( + "You are trying to use 'human' rendering for an environment that doesn't natively support it. " + "The HumanRendering wrapper is being applied to your environment." + ) + env_spec_kwargs["render_mode"] = displayable_modes.pop() + apply_human_rendering = True + elif ( + render_mode.endswith("_list") + and render_mode[: -len("_list")] in render_modes + ): + env_spec_kwargs["render_mode"] = render_mode[: -len("_list")] + apply_render_collection = True + else: + logger.warn( + f"The environment is being initialised with render_mode={render_mode!r} " + f"that is not in the possible render_modes ({render_modes})." + ) + + if apply_api_compatibility or ( + apply_api_compatibility is None and env_spec.apply_api_compatibility + ): + # If we use the compatibility layer, we treat the render mode explicitly and don't pass it to the env creator + render_mode = env_spec_kwargs.pop("render_mode", None) + else: + render_mode = None + + try: + env = env_creator(**env_spec_kwargs) + except TypeError as e: + if ( + str(e).find("got an unexpected keyword argument 'render_mode'") >= 0 + and apply_human_rendering + ): + raise error.Error( + f"You passed render_mode='human' although {env_spec.id} doesn't implement human-rendering natively. " + "Gym tried to apply the HumanRendering wrapper but it looks like your environment is using the old " + "rendering API, which is not supported by the HumanRendering wrapper." + ) from e + else: + raise e + + # Set the minimal env spec for the environment. + env.unwrapped.spec = EnvSpec( + id=env_spec.id, + entry_point=env_spec.entry_point, + reward_threshold=env_spec.reward_threshold, + nondeterministic=env_spec.nondeterministic, + max_episode_steps=None, + order_enforce=False, + autoreset=False, + disable_env_checker=True, + apply_api_compatibility=False, + kwargs=env_spec_kwargs, + additional_wrappers=(), + vector_entry_point=env_spec.vector_entry_point, + ) + + # Check if pre-wrapped wrappers + assert env.spec is not None + num_prior_wrappers = len(env.spec.additional_wrappers) + if ( + env_spec.additional_wrappers[:num_prior_wrappers] + != env.spec.additional_wrappers + ): + for env_spec_wrapper_spec, recreated_wrapper_spec in zip( + env_spec.additional_wrappers, env.spec.additional_wrappers + ): + raise ValueError( + f"The environment's wrapper spec {recreated_wrapper_spec} is different from the saved `EnvSpec` additional wrapper {env_spec_wrapper_spec}" + ) + + # Add step API wrapper + if apply_api_compatibility is True or ( + apply_api_compatibility is None and env_spec.apply_api_compatibility is True + ): + env = EnvCompatibility(env, render_mode) + + # Run the environment checker as the lowest level wrapper + if disable_env_checker is False or ( + disable_env_checker is None and env_spec.disable_env_checker is False + ): + env = PassiveEnvChecker(env) + + # Add the order enforcing wrapper + if env_spec.order_enforce: + env = OrderEnforcing(env) + + # Add the time limit wrapper + if max_episode_steps is not None: + env = TimeLimit(env, max_episode_steps) + elif env_spec.max_episode_steps is not None: + env = TimeLimit(env, env_spec.max_episode_steps) + + # Add the auto-reset wrapper + if autoreset is True or (autoreset is None and env_spec.autoreset is True): + env = AutoResetWrapper(env) + + for wrapper_spec in env_spec.additional_wrappers[num_prior_wrappers:]: + if wrapper_spec.kwargs is None: + raise ValueError( + f"{wrapper_spec.name} wrapper does not inherit from `gymnasium.utils.RecordConstructorArgs`, therefore, the wrapper cannot be recreated." + ) + + env = load_env_creator(wrapper_spec.entry_point)(env=env, **wrapper_spec.kwargs) + + # Add human rendering wrapper + if apply_human_rendering: + env = HumanRendering(env) + elif apply_render_collection: + env = RenderCollection(env) + + return env def make_vec( diff --git a/gymnasium/wrappers/autoreset.py b/gymnasium/wrappers/autoreset.py index 84c0c16ea..09012e0b9 100644 --- a/gymnasium/wrappers/autoreset.py +++ b/gymnasium/wrappers/autoreset.py @@ -1,7 +1,16 @@ """Wrapper that autoreset environments when `terminated=True` or `truncated=True`.""" +from __future__ import annotations + +from copy import deepcopy +from typing import TYPE_CHECKING + import gymnasium as gym +if TYPE_CHECKING: + from gymnasium.envs.registration import EnvSpec + + class AutoResetWrapper(gym.Wrapper, gym.utils.RecordConstructorArgs): """A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`. @@ -60,3 +69,17 @@ class AutoResetWrapper(gym.Wrapper, gym.utils.RecordConstructorArgs): info = new_info return obs, reward, terminated, truncated, info + + @property + def spec(self) -> EnvSpec | None: + """Modifies the environment spec to specify the `autoreset=True`.""" + if self._cached_spec is not None: + return self._cached_spec + + env_spec = self.env.spec + if env_spec is not None: + env_spec = deepcopy(env_spec) + env_spec.autoreset = True + + self._cached_spec = env_spec + return env_spec diff --git a/gymnasium/wrappers/env_checker.py b/gymnasium/wrappers/env_checker.py index 51115d1fc..ae86724ea 100644 --- a/gymnasium/wrappers/env_checker.py +++ b/gymnasium/wrappers/env_checker.py @@ -1,4 +1,9 @@ """A passive environment checker wrapper for an environment's observation and action space along with the reset, step and render functions.""" +from __future__ import annotations + +from copy import deepcopy +from typing import TYPE_CHECKING + import gymnasium as gym from gymnasium.core import ActType from gymnasium.utils.passive_env_checker import ( @@ -10,6 +15,10 @@ from gymnasium.utils.passive_env_checker import ( ) +if TYPE_CHECKING: + from gymnasium.envs.registration import EnvSpec + + class PassiveEnvChecker(gym.Wrapper, gym.utils.RecordConstructorArgs): """A passive environment checker wrapper that surrounds the step, reset and render functions to check they follow the gymnasium API.""" @@ -54,3 +63,17 @@ class PassiveEnvChecker(gym.Wrapper, gym.utils.RecordConstructorArgs): return env_render_passive_checker(self.env, *args, **kwargs) else: return self.env.render(*args, **kwargs) + + @property + def spec(self) -> EnvSpec | None: + """Modifies the environment spec to such that `disable_env_checker=False`.""" + if self._cached_spec is not None: + return self._cached_spec + + env_spec = self.env.spec + if env_spec is not None: + env_spec = deepcopy(env_spec) + env_spec.disable_env_checker = False + + self._cached_spec = env_spec + return env_spec diff --git a/gymnasium/wrappers/order_enforcing.py b/gymnasium/wrappers/order_enforcing.py index 21efa3304..abfb1be0f 100644 --- a/gymnasium/wrappers/order_enforcing.py +++ b/gymnasium/wrappers/order_enforcing.py @@ -1,8 +1,17 @@ """Wrapper to enforce the proper ordering of environment operations.""" +from __future__ import annotations + +from copy import deepcopy +from typing import TYPE_CHECKING + import gymnasium as gym from gymnasium.error import ResetNeeded +if TYPE_CHECKING: + from gymnasium.envs.registration import EnvSpec + + class OrderEnforcing(gym.Wrapper, gym.utils.RecordConstructorArgs): """A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`. @@ -64,3 +73,17 @@ class OrderEnforcing(gym.Wrapper, gym.utils.RecordConstructorArgs): def has_reset(self): """Returns if the environment has been reset before.""" return self._has_reset + + @property + def spec(self) -> EnvSpec | None: + """Modifies the environment spec to add the `order_enforce=True`.""" + if self._cached_spec is not None: + return self._cached_spec + + env_spec = self.env.spec + if env_spec is not None: + env_spec = deepcopy(env_spec) + env_spec.order_enforce = True + + self._cached_spec = env_spec + return env_spec diff --git a/gymnasium/wrappers/time_limit.py b/gymnasium/wrappers/time_limit.py index e4c61caf3..48d564ad9 100644 --- a/gymnasium/wrappers/time_limit.py +++ b/gymnasium/wrappers/time_limit.py @@ -1,9 +1,16 @@ """Wrapper for limiting the time steps of an environment.""" -from typing import Optional +from __future__ import annotations + +from copy import deepcopy +from typing import TYPE_CHECKING import gymnasium as gym +if TYPE_CHECKING: + from gymnasium.envs.registration import EnvSpec + + class TimeLimit(gym.Wrapper, gym.utils.RecordConstructorArgs): """This wrapper will issue a `truncated` signal if a maximum number of timesteps is exceeded. @@ -20,7 +27,7 @@ class TimeLimit(gym.Wrapper, gym.utils.RecordConstructorArgs): def __init__( self, env: gym.Env, - max_episode_steps: Optional[int] = None, + max_episode_steps: int, ): """Initializes the :class:`TimeLimit` wrapper with an environment and the number of steps after which truncation will occur. @@ -33,9 +40,6 @@ class TimeLimit(gym.Wrapper, gym.utils.RecordConstructorArgs): ) gym.Wrapper.__init__(self, env) - if max_episode_steps is None and self.env.spec is not None: - assert env.spec is not None - max_episode_steps = env.spec.max_episode_steps self._max_episode_steps = max_episode_steps self._elapsed_steps = None @@ -69,3 +73,17 @@ class TimeLimit(gym.Wrapper, gym.utils.RecordConstructorArgs): """ self._elapsed_steps = 0 return self.env.reset(**kwargs) + + @property + def spec(self) -> EnvSpec | None: + """Modifies the environment spec to include the `max_episode_steps=self._max_episode_steps`.""" + if self._cached_spec is not None: + return self._cached_spec + + env_spec = self.env.spec + if env_spec is not None: + env_spec = deepcopy(env_spec) + env_spec.max_episode_steps = self._max_episode_steps + + self._cached_spec = env_spec + return env_spec diff --git a/tests/envs/registration/test_env_spec.py b/tests/envs/registration/test_env_spec.py index 8a72fa286..bdc78de29 100644 --- a/tests/envs/registration/test_env_spec.py +++ b/tests/envs/registration/test_env_spec.py @@ -1,4 +1,4 @@ -"""Example file showing usage of env.specstack.""" +"""Test for the `EnvSpec`, in particular, a full integration with `EnvSpec`.""" import pickle import pytest @@ -11,15 +11,16 @@ from gymnasium.utils.env_checker import data_equivalence def test_full_integration(): # Create an environment to test with - env = gym.make("CartPole-v1", render_mode="rgb_array").unwrapped + env = gym.make("CartPole-v1", render_mode="rgb_array") - env = gym.wrappers.FlattenObservation(env) env = gym.wrappers.TimeAwareObservation(env) env = gym.wrappers.NormalizeReward(env, gamma=0.8) # Generate the spec_stack env_spec = env.spec assert isinstance(env_spec, EnvSpec) + # additional_wrappers = (TimeAwareObservation, NormalizeReward) + assert len(env_spec.additional_wrappers) == 2 # env_spec.pprint() # Serialize the spec_stack @@ -30,10 +31,7 @@ def test_full_integration(): recreate_env_spec = EnvSpec.from_json(env_spec_json) # recreate_env_spec.pprint() - for wrapper_spec, recreated_wrapper_spec in zip( - env_spec.applied_wrappers, recreate_env_spec.applied_wrappers - ): - assert wrapper_spec == recreated_wrapper_spec + assert env_spec.additional_wrappers == recreate_env_spec.additional_wrappers assert recreate_env_spec == env_spec # Recreate the environment using the spec_stack @@ -86,43 +84,6 @@ def test_env_spec_to_from_json(env_spec: EnvSpec): assert env_spec == recreated_env_spec -def test_wrapped_env_entry_point(): - def _create_env(): - _env = gym.make("CartPole-v1", render_mode="rgb_array") - _env = gym.wrappers.FlattenObservation(_env) - return _env - - gym.register("TestingEnv-v0", entry_point=_create_env) - - env = gym.make("TestingEnv-v0") - env = gym.wrappers.TimeAwareObservation(env) - env = gym.wrappers.NormalizeReward(env, gamma=0.8) - - recreated_env = gym.make(env.spec) - - obs, info = env.reset(seed=42) - recreated_obs, recreated_info = recreated_env.reset(seed=42) - assert data_equivalence(obs, recreated_obs) - assert data_equivalence(info, recreated_info) - - action = env.action_space.sample() - obs, reward, terminated, truncated, info = env.step(action) - ( - recreated_obs, - recreated_reward, - recreated_terminated, - recreated_truncated, - recreated_info, - ) = recreated_env.step(action) - assert data_equivalence(obs, recreated_obs) - assert data_equivalence(reward, recreated_reward) - assert data_equivalence(terminated, recreated_terminated) - assert data_equivalence(truncated, recreated_truncated) - assert data_equivalence(info, recreated_info) - - del gym.registry["TestingEnv-v0"] - - def test_pickling_env_stack(): env = gym.make("CartPole-v1", render_mode="rgb_array") @@ -163,6 +124,8 @@ def test_pickling_env_stack(): def test_env_spec_pprint(): env = gym.make("CartPole-v1") + env = gym.wrappers.TimeAwareObservation(env) + env_spec = env.spec assert env_spec is not None @@ -172,10 +135,8 @@ def test_env_spec_pprint(): == """id=CartPole-v1 reward_threshold=475.0 max_episode_steps=500 -applied_wrappers=[ - name=PassiveEnvChecker, kwargs={}, - name=OrderEnforcing, kwargs={'disable_render_order_enforcing': False}, - name=TimeLimit, kwargs={'max_episode_steps': 500} +additional_wrappers=[ + name=TimeAwareObservation, kwargs={} ]""" ) @@ -186,10 +147,8 @@ applied_wrappers=[ entry_point=gymnasium.envs.classic_control.cartpole:CartPoleEnv reward_threshold=475.0 max_episode_steps=500 -applied_wrappers=[ - name=PassiveEnvChecker, entry_point=gymnasium.wrappers.env_checker:PassiveEnvChecker, kwargs={}, - name=OrderEnforcing, entry_point=gymnasium.wrappers.order_enforcing:OrderEnforcing, kwargs={'disable_render_order_enforcing': False}, - name=TimeLimit, entry_point=gymnasium.wrappers.time_limit:TimeLimit, kwargs={'max_episode_steps': 500} +additional_wrappers=[ + name=TimeAwareObservation, entry_point=gymnasium.wrappers.time_aware_observation:TimeAwareObservation, kwargs={} ]""" ) @@ -205,14 +164,12 @@ order_enforce=True autoreset=False disable_env_checker=False applied_api_compatibility=False -applied_wrappers=[ - name=PassiveEnvChecker, kwargs={}, - name=OrderEnforcing, kwargs={'disable_render_order_enforcing': False}, - name=TimeLimit, kwargs={'max_episode_steps': 500} +additional_wrappers=[ + name=TimeAwareObservation, kwargs={} ]""" ) - env_spec.applied_wrappers = () + env_spec.additional_wrappers = () output = env_spec.pprint(disable_print=True) assert ( output @@ -233,5 +190,5 @@ order_enforce=True autoreset=False disable_env_checker=False applied_api_compatibility=False -applied_wrappers=[]""" +additional_wrappers=[]""" ) diff --git a/tests/envs/registration/test_make.py b/tests/envs/registration/test_make.py index 25b189664..552babce2 100644 --- a/tests/envs/registration/test_make.py +++ b/tests/envs/registration/test_make.py @@ -1,4 +1,4 @@ -"""Tests that gym.make works as expected.""" +"""Tests that `gym.make` works as expected.""" from __future__ import annotations import re @@ -11,6 +11,8 @@ import gymnasium as gym from gymnasium import Env from gymnasium.core import ActType, ObsType, WrapperObsType from gymnasium.envs.classic_control import CartPoleEnv +from gymnasium.error import NameNotFound +from gymnasium.utils.env_checker import data_equivalence from gymnasium.wrappers import ( AutoResetWrapper, HumanRendering, @@ -24,111 +26,101 @@ from tests.testing_env import GenericTestEnv, old_step_func from tests.wrappers.utils import has_wrapper -try: - import shimmy -except ImportError: - shimmy = None +# Tests +# * basic example +# * parameters (equivalent for str and EnvSpec) +# 1. max_episode_steps +# 2. autoreset +# 3. apply_api_compatibility +# 4. disable_env_checker +# * rendering +# 1. render_mode +# 2. HumanRendering +# 3. RenderCollection +# * make kwargs +# * make import module +# * make env spec additional wrappers +# * env_id str errors -@pytest.fixture(scope="function") -def register_testing_envs(): - """Registers testing envs for `gym.make`""" - gym.register( - id="test.ArgumentEnv-v0", - entry_point="tests.envs.registration.utils_envs:ArgumentEnv", - kwargs={ - "arg1": "arg1", - "arg2": "arg2", - }, - ) +def test_no_arguments(env_id: str = "CartPole-v1"): + """Test `gym.make` using str and EnvSpec with no arguments.""" + env_from_id = gym.make(env_id) + assert env_from_id.spec is not None + assert env_from_id.spec.id == env_id + assert isinstance(env_from_id.unwrapped, CartPoleEnv) - gym.register( - id="test/NoHuman-v0", - entry_point="tests.envs.registration.utils_envs:NoHuman", - ) - gym.register( - id="test/NoHumanOldAPI-v0", - entry_point="tests.envs.registration.utils_envs:NoHumanOldAPI", - ) + env_spec = gym.spec(env_id) + env_from_spec = gym.make(env_spec) + assert env_from_spec.spec is not None + assert env_from_spec.spec.id == env_id + assert isinstance(env_from_spec.unwrapped, CartPoleEnv) - gym.register( - id="test/NoHumanNoRGB-v0", - entry_point="tests.envs.registration.utils_envs:NoHumanNoRGB", - ) - - gym.register( - id="test/NoRenderModesMetadata-v0", - entry_point="tests.envs.registration.utils_envs:NoRenderModesMetadata", - ) - - yield - - del gym.envs.registration.registry["test.ArgumentEnv-v0"] - del gym.envs.registration.registry["test/NoRenderModesMetadata-v0"] - del gym.envs.registration.registry["test/NoHuman-v0"] - del gym.envs.registration.registry["test/NoHumanOldAPI-v0"] - del gym.envs.registration.registry["test/NoHumanNoRGB-v0"] + assert env_from_id.spec == env_from_spec.spec -def test_make(): - """Test basic `gym.make`.""" - env = gym.make("CartPole-v1") - assert env.spec is not None - assert env.spec.id == "CartPole-v1" - assert isinstance(env.unwrapped, CartPoleEnv) - env.close() +def test_max_episode_steps(register_parameter_envs): + """Test the `max_episode_steps` parameter in `gym.make`.""" + for make_id in ["CartPole-v1", gym.spec("CartPole-v1")]: + env_spec = gym.spec(make_id) if isinstance(make_id, str) else make_id + + # Use the spec's value + env = gym.make(make_id) + assert has_wrapper(env, TimeLimit) + assert env.spec is not None + assert env.spec.max_episode_steps == env_spec.max_episode_steps + + # Set a custom max episode steps value + assert env_spec.max_episode_steps != 100 + env = gym.make(make_id, max_episode_steps=100) + assert has_wrapper(env, TimeLimit) + assert env.spec is not None + assert env.spec.max_episode_steps == 100, make_id + + for make_id in ["NoMaxEpisodeStepsEnv-v0", gym.spec("NoMaxEpisodeStepsEnv-v0")]: + env_spec = gym.spec(make_id) if isinstance(make_id, str) else make_id + + # env spec has no max episode steps + assert env_spec.max_episode_steps is None + env = gym.make(make_id) + assert env.spec is not None + assert env.spec.max_episode_steps is None + assert has_wrapper(env, TimeLimit) is False + + # set a custom max episode steps values + env = gym.make(make_id, max_episode_steps=100) + assert env.spec is not None + assert env.spec.max_episode_steps == 100 + assert has_wrapper(env, TimeLimit) -def test_make_deprecated(): - """Test make with a deprecated environment (i.e., doesn't exist).""" - with warnings.catch_warnings(record=True): - with pytest.raises( - gym.error.Error, - match=re.escape( - "Environment version v0 for `Humanoid` is deprecated. Please use `Humanoid-v4` instead." - ), - ): - gym.make("Humanoid-v0") +def test_autorest(register_parameter_envs): + """Test the `autoreset` parameter in `gym.make`.""" + for make_id in [ + "CartPole-v1", + gym.spec("CartPole-v1"), + "AutoresetEnv-v0", + gym.spec("AutoresetEnv-v0"), + ]: + env_spec = gym.spec(make_id) if isinstance(make_id, str) else make_id + # Use the spec's value + env = gym.make(make_id) + assert env.spec is not None + assert env.spec.autoreset == env_spec.autoreset + assert has_wrapper(env, AutoResetWrapper) is env_spec.autoreset -def test_make_max_episode_steps(register_testing_envs): - # Default, uses the spec's - env = gym.make("CartPole-v1") - assert has_wrapper(env, TimeLimit) - assert env.spec is not None - assert env.spec.max_episode_steps == gym.spec("CartPole-v1").max_episode_steps - env.close() + # Set autoreset is True + env = gym.make(make_id, autoreset=True) + assert has_wrapper(env, AutoResetWrapper) + assert env.spec is not None + assert env.spec.autoreset is True - # Custom max episode steps - assert gym.spec("CartPole-v1").max_episode_steps != 100 - env = gym.make("CartPole-v1", max_episode_steps=100) - assert has_wrapper(env, TimeLimit) - assert env.spec is not None - assert env.spec.max_episode_steps == 100 - env.close() - - # Env spec has no max episode steps - assert gym.spec("test.ArgumentEnv-v0").max_episode_steps is None - env = gym.make("test.ArgumentEnv-v0", arg1=None, arg2=None, arg3=None) - assert env.spec is not None - assert env.spec.max_episode_steps is None - assert has_wrapper(env, TimeLimit) is False - env.close() - - -def test_make_autoreset(): - """Tests that `gym.make` autoreset wrapper is applied only when `gym.make(..., autoreset=True)`.""" - env = gym.make("CartPole-v1") - assert has_wrapper(env, AutoResetWrapper) is False - env.close() - - env = gym.make("CartPole-v1", autoreset=False) - assert has_wrapper(env, AutoResetWrapper) is False - env.close() - - env = gym.make("CartPole-v1", autoreset=True) - assert has_wrapper(env, AutoResetWrapper) - env.close() + # Set autoreset is False + env = gym.make(make_id, autoreset=False) + assert has_wrapper(env, AutoResetWrapper) is False + assert env.spec is not None + assert env.spec.autoreset is False @pytest.mark.parametrize( @@ -142,7 +134,7 @@ def test_make_autoreset(): [True, None, True], ], ) -def test_make_disable_env_checker( +def test_disable_env_checker( registration_disabled: bool, make_disabled: bool | None, if_disabled: bool ): """Tests that `gym.make` disable env checker is applied only when `gym.make(..., disable_env_checker=False)`. @@ -150,84 +142,76 @@ def test_make_disable_env_checker( The ordering is 1. if the `make(..., disable_env_checker=...)` is bool, then the `registration(..., disable_env_checker=...)` """ gym.register( - "testing-env-v0", + "DisableEnvCheckerEnv-v0", lambda: GenericTestEnv(), disable_env_checker=registration_disabled, ) # Test when the registered EnvSpec.disable_env_checker = False - env = gym.make("testing-env-v0", disable_env_checker=make_disabled) + env = gym.make("DisableEnvCheckerEnv-v0", disable_env_checker=make_disabled) assert has_wrapper(env, PassiveEnvChecker) is not if_disabled - env.close() - del gym.registry["testing-env-v0"] + env_spec = gym.spec("DisableEnvCheckerEnv-v0") + env = gym.make(env_spec, disable_env_checker=make_disabled) + assert has_wrapper(env, PassiveEnvChecker) is not if_disabled + + del gym.registry["DisableEnvCheckerEnv-v0"] -def test_make_apply_api_compatibility(): - """Test the API compatibility wrapper.""" - gym.register( - "testing-old-env", - lambda: GenericTestEnv(step_func=old_step_func), - apply_api_compatibility=True, - max_episode_steps=3, - ) +def test_apply_api_compatibility(register_parameter_envs): + """Test the `apply_api_compatibility` parameter for `gym.make`.""" # Apply the environment compatibility and check it works as intended - env = gym.make("testing-old-env") - assert isinstance(env.unwrapped, gym.wrappers.EnvCompatibility) + for make_id in ["EnabledApplyApiComp-v0", gym.spec("EnabledApplyApiComp-v0")]: + env = gym.make(make_id) + assert isinstance(env.unwrapped, gym.wrappers.EnvCompatibility) - env.reset() - assert len(env.step(env.action_space.sample())) == 5 - env.step(env.action_space.sample()) - _, _, termination, truncation, _ = env.step(env.action_space.sample()) - assert termination is False and truncation is True - - # Turn off the spec api compatibility - gym.spec("testing-old-env").apply_api_compatibility = False - env = gym.make("testing-old-env") - assert isinstance(env.unwrapped, gym.wrappers.EnvCompatibility) is False - env.reset() - with pytest.raises( - ValueError, match=re.escape("not enough values to unpack (expected 5, got 4)") - ): + # env has time limit of 3 enabling this test + env.reset() + assert len(env.step(env.action_space.sample())) == 5 env.step(env.action_space.sample()) + _, _, termination, truncation, _ = env.step(env.action_space.sample()) + assert termination is False and truncation is True - # Apply the environment compatibility and check it works as intended - env = gym.make("testing-old-env", apply_api_compatibility=True) - assert isinstance(env.unwrapped, gym.wrappers.EnvCompatibility) + for make_id in ["DisabledApplyApiComp-v0", gym.spec("DisabledApplyApiComp-v0")]: + # Turn off the spec api compatibility + env = gym.make(make_id) + assert isinstance(env.unwrapped, gym.wrappers.EnvCompatibility) is False + env.reset() + with pytest.raises( + ValueError, + match=re.escape("not enough values to unpack (expected 5, got 4)"), + ): + env.step(env.action_space.sample()) - env.reset() - assert len(env.step(env.action_space.sample())) == 5 - env.step(env.action_space.sample()) - _, _, termination, truncation, _ = env.step(env.action_space.sample()) - assert termination is False and truncation is True + # Apply the environment compatibility and check it works as intended + assert env.spec is not None + assert env.spec.apply_api_compatibility is False + env = gym.make(make_id, apply_api_compatibility=True) + assert isinstance(env.unwrapped, gym.wrappers.EnvCompatibility) - del gym.registry["testing-old-env"] + env.reset() + assert len(env.step(env.action_space.sample())) == 5 + env.step(env.action_space.sample()) + _, _, termination, truncation, _ = env.step(env.action_space.sample()) + assert termination is False and truncation is True -def test_make_order_enforcing(): +def test_order_enforcing(register_parameter_envs): """Checks that gym.make wrappers the environment with the OrderEnforcing wrapper.""" - assert all(spec.order_enforce is True for spec in all_testing_env_specs) + assert all(spec.order_enforce is False for spec in all_testing_env_specs) - env = gym.make("CartPole-v1") - assert has_wrapper(env, OrderEnforcing) - # We can assume that there all other specs will also have the order enforcing - env.close() + for make_id in ["CartPole-v1", gym.spec("CartPole-v1")]: + env = gym.make(make_id) + assert has_wrapper(env, OrderEnforcing) - gym.register( - id="test.OrderlessArgumentEnv-v0", - entry_point="tests.envs.registration.utils_envs:ArgumentEnv", - order_enforce=False, - kwargs={"arg1": None, "arg2": None, "arg3": None}, - ) - - env = gym.make("test.OrderlessArgumentEnv-v0") - assert has_wrapper(env, OrderEnforcing) is False - env.close() + for make_id in ["OrderlessEnv-v0", gym.spec("OrderlessEnv-v0")]: + env = gym.make(make_id) + assert has_wrapper(env, OrderEnforcing) is False # There is no `make(..., order_enforcing=...)` so we don't test that -def test_make_render_mode(): +def test_make_with_render_mode(): """Test the `make(..., render_mode=...)`, in particular, if to apply the `RenderCollection` or the `HumanRendering`.""" env = gym.make("CartPole-v1", render_mode=None) assert env.render_mode is None @@ -267,10 +251,12 @@ def test_make_render_collection(): env.close() -def test_make_human_rendering(register_testing_envs): +def test_make_human_rendering(register_rendering_testing_envs): # Make sure that native rendering is used when possible env = gym.make("CartPole-v1", render_mode="human") - assert not has_wrapper(env, HumanRendering) # Should use native human-rendering + assert ( + has_wrapper(env, HumanRendering) is False + ) # Should use native human-rendering assert env.render_mode == "human" env.close() @@ -281,7 +267,7 @@ def test_make_human_rendering(register_testing_envs): ), ): # Make sure that `HumanRendering` is applied here as the environment doesn't use native rendering - env = gym.make("test/NoHuman-v0", render_mode="human") + env = gym.make("NoHumanRendering-v0", render_mode="human") assert has_wrapper(env, HumanRendering) assert env.render_mode == "human" env.close() @@ -290,7 +276,7 @@ def test_make_human_rendering(register_testing_envs): TypeError, match=re.escape("got an unexpected keyword argument 'render_mode'") ): gym.make( - "test/NoHumanOldAPI-v0", + "NoHumanRenderingOldAPI-v0", render_mode="rgb_array_list", ) @@ -299,10 +285,10 @@ def test_make_human_rendering(register_testing_envs): with pytest.raises( gym.error.Error, match=re.escape( - "You passed render_mode='human' although test/NoHumanOldAPI-v0 doesn't implement human-rendering natively." + "You passed render_mode='human' although NoHumanRenderingOldAPI-v0 doesn't implement human-rendering natively." ), ): - gym.make("test/NoHumanOldAPI-v0", render_mode="human") + gym.make("NoHumanRenderingOldAPI-v0", render_mode="human") # This test ensures that the additional exception "Gym tried to apply the HumanRendering wrapper but it looks like # your environment is using the old rendering API" is *not* triggered by a TypeError that originate from @@ -320,10 +306,10 @@ def test_make_human_rendering(register_testing_envs): "\x1b[33mWARN: The environment is being initialised with render_mode='rgb_array' that is not in the possible render_modes ([]).\x1b[0m" ), ): - gym.make("test/NoRenderModesMetadata-v0", render_mode="rgb_array") + gym.make("NoRenderModesMetadata-v0", render_mode="rgb_array") -def test_make_kwargs(register_testing_envs): +def test_make_kwargs(register_kwargs_env): env = gym.make( "test.ArgumentEnv-v0", arg2="override_arg2", @@ -367,12 +353,11 @@ class NoRecordArgsWrapper(gym.ObservationWrapper): return self.observation_space.sample() -def test_make_env_spec(): +def test_make_with_env_spec(): # make - env_1 = gym.make(gym.spec("CartPole-v1")) - assert isinstance(env_1, CartPoleEnv) - assert env_1 is env_1.unwrapped - env_1.close() + id_env = gym.make("CartPole-v1") + spec_env = gym.make(gym.spec("CartPole-v1")) + assert id_env.spec == spec_env.spec # make with applied wrappers env_2 = gym.wrappers.NormalizeReward( @@ -396,30 +381,202 @@ def test_make_env_spec(): # make with wrapper in env-creator gym.register( - "CartPole-v3", lambda: gym.wrappers.TimeAwareObservation(CartPoleEnv()) + "CartPole-v3", + lambda: gym.wrappers.TimeAwareObservation(CartPoleEnv()), + disable_env_checker=True, + order_enforce=False, ) env_4 = gym.make(gym.spec("CartPole-v3")) assert isinstance(env_4, gym.wrappers.TimeAwareObservation) assert isinstance(env_4.env, CartPoleEnv) env_4.close() - # make with no ezpickle wrapper - env_5 = NoRecordArgsWrapper(gym.make("CartPole-v1").unwrapped) + gym.register( + "CartPole-v4", + lambda: CartPoleEnv(), + disable_env_checker=True, + order_enforce=False, + additional_wrappers=(gym.wrappers.TimeAwareObservation.wrapper_spec(),), + ) + env_5 = gym.make(gym.spec("CartPole-v4")) + assert isinstance(env_5, gym.wrappers.TimeAwareObservation) + assert isinstance(env_5.env, CartPoleEnv) env_5.close() + + # make with no ezpickle wrapper + env_6 = NoRecordArgsWrapper(gym.make("CartPole-v1")) with pytest.raises( ValueError, match=re.escape( "NoRecordArgsWrapper wrapper does not inherit from `gymnasium.utils.RecordConstructorArgs`, therefore, the wrapper cannot be recreated." ), ): - gym.make(env_5.spec) + gym.make(env_6.spec) # make with no ezpickle wrapper but in the entry point - gym.register("CartPole-v4", entry_point=lambda: NoRecordArgsWrapper(CartPoleEnv())) - env_6 = gym.make(gym.spec("CartPole-v4")) - assert isinstance(env_6, NoRecordArgsWrapper) - assert isinstance(env_6.unwrapped, CartPoleEnv) + gym.register( + "CartPole-v5", + entry_point=lambda: NoRecordArgsWrapper(CartPoleEnv()), + disable_env_checker=True, + order_enforce=False, + ) + env_7 = gym.make(gym.spec("CartPole-v5")) + assert isinstance(env_7, NoRecordArgsWrapper) + assert isinstance(env_7.unwrapped, CartPoleEnv) + + gym.register( + "CartPole-v6", + entry_point=lambda: CartPoleEnv(), + disable_env_checker=True, + order_enforce=False, + additional_wrappers=(NoRecordArgsWrapper.wrapper_spec(),), + ) del gym.registry["CartPole-v2"] del gym.registry["CartPole-v3"] del gym.registry["CartPole-v4"] + del gym.registry["CartPole-v5"] + del gym.registry["CartPole-v6"] + + +def test_make_with_env_spec_levels(): + """Test that we can recreate the environment at each 'level'.""" + env = gym.wrappers.NormalizeReward( + gym.wrappers.TimeAwareObservation( + gym.wrappers.FlattenObservation( + gym.make("CartPole-v1", render_mode="rgb_array") + ) + ), + gamma=0.8, + ) + + while env is not env.unwrapped: + recreated_env = gym.make(env.spec) + assert env.spec == recreated_env.spec + + env = env.env + + +def test_wrapped_env_entry_point(): + def _create_env(): + _env = gym.make("CartPole-v1", render_mode="rgb_array") + _env = gym.wrappers.FlattenObservation(_env) + return _env + + gym.register("TestingEnv-v0", entry_point=_create_env) + + env = gym.make("TestingEnv-v0") + env = gym.wrappers.TimeAwareObservation(env) + env = gym.wrappers.NormalizeReward(env, gamma=0.8) + + recreated_env = gym.make(env.spec) + + obs, info = env.reset(seed=42) + recreated_obs, recreated_info = recreated_env.reset(seed=42) + assert data_equivalence(obs, recreated_obs) + assert data_equivalence(info, recreated_info) + + action = env.action_space.sample() + obs, reward, terminated, truncated, info = env.step(action) + ( + recreated_obs, + recreated_reward, + recreated_terminated, + recreated_truncated, + recreated_info, + ) = recreated_env.step(action) + assert data_equivalence(obs, recreated_obs) + assert data_equivalence(reward, recreated_reward) + assert data_equivalence(terminated, recreated_terminated) + assert data_equivalence(truncated, recreated_truncated) + assert data_equivalence(info, recreated_info) + + del gym.registry["TestingEnv-v0"] + + +def test_make_errors(): + """Test make with a deprecated environment (i.e., doesn't exist).""" + with warnings.catch_warnings(record=True): + with pytest.raises( + gym.error.Error, + match=re.escape( + "Environment version v0 for `Humanoid` is deprecated. Please use `Humanoid-v4` instead." + ), + ): + gym.make("Humanoid-v0") + + with pytest.raises( + NameNotFound, match=re.escape("Environment `NonExistenceEnv` doesn't exist.") + ): + gym.make("NonExistenceEnv-v0") + + +@pytest.fixture(scope="function") +def register_parameter_envs(): + gym.register( + "NoMaxEpisodeStepsEnv-v0", lambda: GenericTestEnv(), max_episode_steps=None + ) + gym.register("AutoresetEnv-v0", lambda: GenericTestEnv(), autoreset=True) + + gym.register( + "EnabledApplyApiComp-v0", + lambda: GenericTestEnv(step_func=old_step_func), + apply_api_compatibility=True, + max_episode_steps=3, + ) + gym.register( + "DisabledApplyApiComp-v0", + lambda: GenericTestEnv(step_func=old_step_func), + apply_api_compatibility=False, + max_episode_steps=3, + ) + gym.register("OrderlessEnv-v0", lambda: GenericTestEnv(), order_enforce=False) + + yield + + del gym.registry["NoMaxEpisodeStepsEnv-v0"] + del gym.registry["AutoresetEnv-v0"] + del gym.registry["EnabledApplyApiComp-v0"] + del gym.registry["DisabledApplyApiComp-v0"] + del gym.registry["OrderlessEnv-v0"] + + +@pytest.fixture(scope="function") +def register_kwargs_env(): + gym.register( + id="test.ArgumentEnv-v0", + entry_point="tests.envs.registration.utils_envs:ArgumentEnv", + kwargs={ + "arg1": "arg1", + "arg2": "arg2", + }, + ) + + +@pytest.fixture(scope="function") +def register_rendering_testing_envs(): + gym.register( + id="NoHumanRendering-v0", + entry_point="tests.envs.registration.utils_envs:NoHuman", + ) + gym.register( + id="NoHumanRenderingOldAPI-v0", + entry_point="tests.envs.registration.utils_envs:NoHumanOldAPI", + ) + + gym.register( + id="NoHumanRenderingNoRGB-v0", + entry_point="tests.envs.registration.utils_envs:NoHumanNoRGB", + ) + + gym.register( + id="NoRenderModesMetadata-v0", + entry_point="tests.envs.registration.utils_envs:NoRenderModesMetadata", + ) + + yield + + del gym.envs.registration.registry["NoHumanRendering-v0"] + del gym.envs.registration.registry["NoHumanRenderingOldAPI-v0"] + del gym.envs.registration.registry["NoHumanRenderingNoRGB-v0"] + del gym.envs.registration.registry["NoRenderModesMetadata-v0"] diff --git a/tests/test_core.py b/tests/test_core.py index da2aa0e4c..048338691 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -171,7 +171,7 @@ def test_compatibility_with_old_style_env(): """Test compatibility with old style environment.""" env = OldStyleEnv() env = OrderEnforcing(env) - env = TimeLimit(env) + env = TimeLimit(env, 100) obs = env.reset() assert obs == 0 diff --git a/tests/wrappers/test_time_limit.py b/tests/wrappers/test_time_limit.py index 269a203b6..ca479a818 100644 --- a/tests/wrappers/test_time_limit.py +++ b/tests/wrappers/test_time_limit.py @@ -7,7 +7,7 @@ from gymnasium.wrappers import TimeLimit def test_time_limit_reset_info(): env = gym.make("CartPole-v1", disable_env_checker=True) - env = TimeLimit(env) + env = TimeLimit(env, 100) ob_space = env.observation_space obs, info = env.reset() assert ob_space.contains(obs)