Update EnvSpec to improve backward compatibility (#355)

This commit is contained in:
Mark Towers
2023-03-08 14:07:09 +00:00
committed by GitHub
parent cb3f0d19cd
commit 9d8db14e7f
10 changed files with 620 additions and 439 deletions

View File

@@ -11,7 +11,7 @@ from gymnasium.utils import RecordConstructorArgs, seeding
if TYPE_CHECKING: if TYPE_CHECKING:
from gymnasium.envs.registration import EnvSpec from gymnasium.envs.registration import EnvSpec, WrapperSpec
ObsType = TypeVar("ObsType") ObsType = TypeVar("ObsType")
ActType = TypeVar("ActType") ActType = TypeVar("ActType")
@@ -266,6 +266,8 @@ class Wrapper(
self._reward_range: tuple[SupportsFloat, SupportsFloat] | None = None self._reward_range: tuple[SupportsFloat, SupportsFloat] | None = None
self._metadata: dict[str, Any] | None = None self._metadata: dict[str, Any] | None = None
self._cached_spec: EnvSpec | None = None
def __getattr__(self, name: str) -> Any: def __getattr__(self, name: str) -> Any:
"""Returns an attribute with ``name``, unless ``name`` starts with an underscore.""" """Returns an attribute with ``name``, unless ``name`` starts with an underscore."""
if name == "_np_random": if name == "_np_random":
@@ -279,11 +281,11 @@ class Wrapper(
@property @property
def spec(self) -> EnvSpec | None: def spec(self) -> EnvSpec | None:
"""Returns the :attr:`Env` :attr:`spec` attribute with the `WrapperSpec` if the wrapper inherits from `EzPickle`.""" """Returns the :attr:`Env` :attr:`spec` attribute with the `WrapperSpec` if the wrapper inherits from `EzPickle`."""
if self._cached_spec is not None:
return self._cached_spec
env_spec = self.env.spec env_spec = self.env.spec
if env_spec is not None: if env_spec is not None:
from gymnasium.envs.registration import WrapperSpec
# See if the wrapper inherits from `RecordConstructorArgs` then add the kwargs otherwise use `None` for the wrapper kwargs. This will raise an error in `make` # See if the wrapper inherits from `RecordConstructorArgs` then add the kwargs otherwise use `None` for the wrapper kwargs. This will raise an error in `make`
if isinstance(self, RecordConstructorArgs): if isinstance(self, RecordConstructorArgs):
kwargs = getattr(self, "_saved_kwargs") kwargs = getattr(self, "_saved_kwargs")
@@ -293,6 +295,8 @@ class Wrapper(
else: else:
kwargs = None kwargs = None
from gymnasium.envs.registration import WrapperSpec
wrapper_spec = WrapperSpec( wrapper_spec = WrapperSpec(
name=self.class_name(), name=self.class_name(),
entry_point=f"{self.__module__}:{type(self).__name__}", entry_point=f"{self.__module__}:{type(self).__name__}",
@@ -301,10 +305,22 @@ class Wrapper(
# to avoid reference issues we deepcopy the prior environments spec and add the new information # to avoid reference issues we deepcopy the prior environments spec and add the new information
env_spec = deepcopy(env_spec) env_spec = deepcopy(env_spec)
env_spec.applied_wrappers += (wrapper_spec,) env_spec.additional_wrappers += (wrapper_spec,)
self._cached_spec = env_spec
return env_spec return env_spec
@classmethod
def wrapper_spec(cls, **kwargs: Any) -> WrapperSpec:
    """Build a `WrapperSpec` describing this wrapper class.

    Args:
        **kwargs: Keyword arguments stored on the returned spec as its ``kwargs``.

    Returns:
        A `WrapperSpec` named ``cls.class_name()`` whose entry point is
        ``"<module>:<class name>"``.
    """
    # Imported inside the method: at module level this name is only imported under `TYPE_CHECKING`.
    from gymnasium.envs.registration import WrapperSpec

    spec_name = cls.class_name()
    spec_entry_point = f"{cls.__module__}:{cls.__name__}"
    return WrapperSpec(name=spec_name, entry_point=spec_entry_point, kwargs=kwargs)
@classmethod @classmethod
def class_name(cls) -> str: def class_name(cls) -> str:
"""Returns the class name of the wrapper.""" """Returns the class name of the wrapper."""

View File

@@ -99,7 +99,7 @@ class EnvSpec:
* **autoreset**: If to automatically reset the environment on episode end * **autoreset**: If to automatically reset the environment on episode end
* **disable_env_checker**: If to disable the environment checker wrapper in :meth:`gymnasium.make`, by default False (runs the environment checker) * **disable_env_checker**: If to disable the environment checker wrapper in :meth:`gymnasium.make`, by default False (runs the environment checker)
* **kwargs**: Additional keyword arguments passed to the environment during initialisation * **kwargs**: Additional keyword arguments passed to the environment during initialisation
* **applied_wrappers**: A tuple of applied wrappers (WrapperSpec) * **additional_wrappers**: A tuple of additional wrappers applied to the environment (WrapperSpec)
* **vector_entry_point**: The location of the vectorized environment to create from * **vector_entry_point**: The location of the vectorized environment to create from
""" """
@@ -126,7 +126,7 @@ class EnvSpec:
version: int | None = field(init=False) version: int | None = field(init=False)
# applied wrappers # applied wrappers
applied_wrappers: tuple[WrapperSpec, ...] = field(init=False, default_factory=tuple) additional_wrappers: tuple[WrapperSpec, ...] = field(default_factory=tuple)
# Vectorized environment entry point # Vectorized environment entry point
vector_entry_point: VectorEnvCreator | str | None = field(default=None) vector_entry_point: VectorEnvCreator | str | None = field(default=None)
@@ -187,7 +187,7 @@ class EnvSpec:
parsed_env_spec = json.loads(json_env_spec) parsed_env_spec = json.loads(json_env_spec)
applied_wrapper_specs: list[WrapperSpec] = [] applied_wrapper_specs: list[WrapperSpec] = []
for wrapper_spec_json in parsed_env_spec.pop("applied_wrappers"): for wrapper_spec_json in parsed_env_spec.pop("additional_wrappers"):
try: try:
applied_wrapper_specs.append(WrapperSpec(**wrapper_spec_json)) applied_wrapper_specs.append(WrapperSpec(**wrapper_spec_json))
except Exception as e: except Exception as e:
@@ -197,7 +197,7 @@ class EnvSpec:
try: try:
env_spec = EnvSpec(**parsed_env_spec) env_spec = EnvSpec(**parsed_env_spec)
env_spec.applied_wrappers = tuple(applied_wrapper_specs) env_spec.additional_wrappers = tuple(applied_wrapper_specs)
except Exception as e: except Exception as e:
raise ValueError( raise ValueError(
f"An issue occurred when trying to make {parsed_env_spec} an EnvSpec" f"An issue occurred when trying to make {parsed_env_spec} an EnvSpec"
@@ -241,9 +241,9 @@ class EnvSpec:
if print_all or self.apply_api_compatibility is not False: if print_all or self.apply_api_compatibility is not False:
output += f"\napplied_api_compatibility={self.apply_api_compatibility}" output += f"\napplied_api_compatibility={self.apply_api_compatibility}"
if print_all or self.applied_wrappers: if print_all or self.additional_wrappers:
wrapper_output: list[str] = [] wrapper_output: list[str] = []
for wrapper_spec in self.applied_wrappers: for wrapper_spec in self.additional_wrappers:
if include_entry_points: if include_entry_points:
wrapper_output.append( wrapper_output.append(
f"\n\tname={wrapper_spec.name}, entry_point={wrapper_spec.entry_point}, kwargs={wrapper_spec.kwargs}" f"\n\tname={wrapper_spec.name}, entry_point={wrapper_spec.entry_point}, kwargs={wrapper_spec.kwargs}"
@@ -254,9 +254,9 @@ class EnvSpec:
) )
if len(wrapper_output) == 0: if len(wrapper_output) == 0:
output += "\napplied_wrappers=[]" output += "\nadditional_wrappers=[]"
else: else:
output += f"\napplied_wrappers=[{','.join(wrapper_output)}\n]" output += f"\nadditional_wrappers=[{','.join(wrapper_output)}\n]"
if disable_print: if disable_print:
return output return output
@@ -555,161 +555,6 @@ def load_env_creator(name: str) -> EnvCreator | VectorEnvCreator:
return fn return fn
def _create_from_env_spec(
    env_spec: EnvSpec,
    kwargs: dict[str, Any],
) -> Env:
    """Recreate an environment, including its wrappers, from a saved `EnvSpec`.

    Args:
        env_spec: The saved spec. Its ``entry_point`` is either a callable or a
            string passed to ``load_env_creator``; its ``applied_wrappers`` lists
            the wrapper specs to re-apply.
        kwargs: Extra keyword arguments for the environment constructor. On a
            key clash these override the values in ``env_spec.kwargs``.

    Returns:
        The created environment with every wrapper from ``env_spec.applied_wrappers``
        applied that the entry point did not already apply itself.

    Raises:
        ValueError: If the wrappers already applied by the entry point do not
            match the start of ``env_spec.applied_wrappers``, or if a wrapper to
            re-apply has ``kwargs=None`` and so cannot be recreated.
    """
    if callable(env_spec.entry_point):
        env_creator = env_spec.entry_point
    else:
        env_creator = load_env_creator(env_spec.entry_point)

    # Merge before calling: passing both dicts with `**` raises `TypeError` when a key
    # appears in both, whereas the spec update below shows that `kwargs` should win.
    env_kwargs = {**env_spec.kwargs, **kwargs}
    env: Env = env_creator(**env_kwargs)

    # Give the base environment a copy of the spec with no wrappers recorded.
    new_env_spec = copy.deepcopy(env_spec)
    new_env_spec.applied_wrappers = ()
    new_env_spec.kwargs.update(kwargs)
    env.unwrapped.spec = new_env_spec

    # Wrappers that the entry point applied itself must match the saved spec's prefix.
    assert env.spec is not None  # this is for pyright
    prior_wrappers = env.spec.applied_wrappers
    num_prior_wrappers = len(prior_wrappers)
    if env_spec.applied_wrappers[:num_prior_wrappers] != prior_wrappers:
        # Report the first pair that actually differs, not merely the first pair.
        for env_spec_wrapper_spec, recreated_wrapper_spec in zip(
            env_spec.applied_wrappers, prior_wrappers
        ):
            if env_spec_wrapper_spec != recreated_wrapper_spec:
                raise ValueError(
                    f"The environment's wrapper spec {recreated_wrapper_spec} is different from the saved `EnvSpec` applied_wrappers {env_spec_wrapper_spec}"
                )
        # Every shared pair matched, so the saved spec lists fewer wrappers than the
        # entry point applied; without this raise the mismatch would pass silently.
        raise ValueError(
            f"The environment was created with {num_prior_wrappers} wrappers already applied but the saved `EnvSpec` applied_wrappers only contains {len(env_spec.applied_wrappers)}"
        )

    # Re-apply the remaining wrappers in their saved order.
    for wrapper_spec in env_spec.applied_wrappers[num_prior_wrappers:]:
        if wrapper_spec.kwargs is None:
            raise ValueError(
                f"{wrapper_spec.name} wrapper does not inherit from `gymnasium.utils.RecordConstructorArgs`, therefore, the wrapper cannot be recreated."
            )

        env = load_env_creator(wrapper_spec.entry_point)(env=env, **wrapper_spec.kwargs)

    return env
def _create_from_env_id(
    env_spec: EnvSpec,
    kwargs: dict[str, Any],
    max_episode_steps: int | None = None,
    autoreset: bool = False,
    apply_api_compatibility: bool | None = None,
    disable_env_checker: bool | None = None,
) -> Env:
    """Build an environment from ``env_spec`` and apply the standard wrappers. See `make` for the option meanings.

    Args:
        env_spec: The spec giving the entry point, default kwargs and wrapper defaults.
        kwargs: Constructor keyword arguments; these override ``env_spec.kwargs``.
        max_episode_steps: If not ``None``, overrides the spec's time limit.
        autoreset: If true, wrap the environment in ``AutoResetWrapper``.
        apply_api_compatibility: If ``None``, the spec's value decides whether
            ``EnvCompatibility`` is applied.
        disable_env_checker: If ``None``, the spec's value decides whether
            ``PassiveEnvChecker`` is applied.

    Returns:
        The constructed and wrapped environment.

    Raises:
        error.Error: If the spec has no entry point, or if human rendering was
            requested through ``HumanRendering`` but the constructor rejects
            the ``render_mode`` keyword.
    """
    list_suffix = "_list"

    # Start from a copy of the spec kwargs so the registered spec is not mutated.
    merged_kwargs = copy.deepcopy(env_spec.kwargs)
    merged_kwargs.update(kwargs)

    # Resolve the callable that constructs the environment.
    entry_point = env_spec.entry_point
    if entry_point is None:
        raise error.Error(f"{env_spec.id} registered but entry_point is not specified")
    if callable(entry_point):
        env_creator = entry_point
    else:
        # Assume it's a string
        env_creator = load_env_creator(entry_point)

    # Read the supported render modes from the creator's metadata, if it has any.
    render_modes: list[str] | None = None
    if hasattr(env_creator, "metadata"):
        _check_metadata(env_creator.metadata)
        render_modes = env_creator.metadata.get("render_modes")

    requested_mode = merged_kwargs.get("render_mode")
    apply_human_rendering = False
    apply_render_collection = False

    # An unsupported mode may still be served through HumanRendering / RenderCollection.
    mode_unsupported = (
        requested_mode is not None
        and render_modes is not None
        and requested_mode not in render_modes
    )
    if mode_unsupported:
        displayable_modes = {"rgb_array", "rgb_array_list"}.intersection(render_modes)
        if requested_mode == "human" and displayable_modes:
            logger.warn(
                "You are trying to use 'human' rendering for an environment that doesn't natively support it. "
                "The HumanRendering wrapper is being applied to your environment."
            )
            merged_kwargs["render_mode"] = displayable_modes.pop()
            apply_human_rendering = True
        elif (
            requested_mode.endswith(list_suffix)
            and requested_mode[: -len(list_suffix)] in render_modes
        ):
            merged_kwargs["render_mode"] = requested_mode[: -len(list_suffix)]
            apply_render_collection = True
        else:
            logger.warn(
                f"The environment is being initialised with render_mode={requested_mode!r} "
                f"that is not in the possible render_modes ({render_modes})."
            )

    # With the compatibility layer the render mode goes to the wrapper, not the creator.
    render_mode = None
    if apply_api_compatibility or (
        apply_api_compatibility is None and env_spec.apply_api_compatibility
    ):
        render_mode = merged_kwargs.pop("render_mode", None)

    try:
        env = env_creator(**merged_kwargs)
    except TypeError as e:
        rejected_render_mode = (
            "got an unexpected keyword argument 'render_mode'" in str(e)
        )
        if rejected_render_mode and apply_human_rendering:
            raise error.Error(
                f"You passed render_mode='human' although {env_spec.id} doesn't implement human-rendering natively. "
                "Gym tried to apply the HumanRendering wrapper but it looks like your environment is using the old "
                "rendering API, which is not supported by the HumanRendering wrapper."
            ) from e
        raise e

    # Attach a copy of the spec, carrying the kwargs actually used, to the base environment.
    env_spec = copy.deepcopy(env_spec)
    env_spec.kwargs = merged_kwargs
    env.unwrapped.spec = env_spec

    # Step API compatibility wrapper.
    if apply_api_compatibility is True or (
        apply_api_compatibility is None and env_spec.apply_api_compatibility is True
    ):
        env = EnvCompatibility(env, render_mode)

    # Environment checker, applied as the lowest-level wrapper.
    if disable_env_checker is False or (
        disable_env_checker is None and env_spec.disable_env_checker is False
    ):
        env = PassiveEnvChecker(env)

    # Order enforcing wrapper.
    if env_spec.order_enforce:
        env = OrderEnforcing(env)

    # Time limit wrapper: an explicit argument takes priority over the spec.
    if max_episode_steps is not None:
        assert env.unwrapped.spec is not None  # for pyright
        env.unwrapped.spec.max_episode_steps = max_episode_steps
        env = TimeLimit(env, max_episode_steps)
    elif env_spec.max_episode_steps is not None:
        env = TimeLimit(env, env_spec.max_episode_steps)

    # Auto-reset wrapper.
    if autoreset:
        env = AutoResetWrapper(env)

    # Rendering wrappers go outermost.
    if apply_human_rendering:
        env = HumanRendering(env)
    elif apply_render_collection:
        env = RenderCollection(env)

    return env
def load_plugin_envs(entry_point: str = "gymnasium.envs"): def load_plugin_envs(entry_point: str = "gymnasium.envs"):
"""Load modules (plugins) using the gymnasium entry points in order to register external module's environments on ``import gymnasium``. """Load modules (plugins) using the gymnasium entry points in order to register external module's environments on ``import gymnasium``.
@@ -779,6 +624,7 @@ def register(
autoreset: bool = False, autoreset: bool = False,
disable_env_checker: bool = False, disable_env_checker: bool = False,
apply_api_compatibility: bool = False, apply_api_compatibility: bool = False,
additional_wrappers: tuple[WrapperSpec, ...] = (),
vector_entry_point: VectorEnvCreator | str | None = None, vector_entry_point: VectorEnvCreator | str | None = None,
**kwargs: Any, **kwargs: Any,
): ):
@@ -801,6 +647,7 @@ def register(
disable_env_checker: If to disable the :class:`gymnasium.wrappers.PassiveEnvChecker` to the environment. disable_env_checker: If to disable the :class:`gymnasium.wrappers.PassiveEnvChecker` to the environment.
apply_api_compatibility: If to apply the :class:`gymnasium.wrappers.StepAPICompatibility` wrapper to the environment. apply_api_compatibility: If to apply the :class:`gymnasium.wrappers.StepAPICompatibility` wrapper to the environment.
Use if the environment is implemented in the gym v0.21 environment API. Use if the environment is implemented in the gym v0.21 environment API.
additional_wrappers: Additional wrappers to apply to the environment.
vector_entry_point: The entry point for creating the vector environment vector_entry_point: The entry point for creating the vector environment
**kwargs: arbitrary keyword arguments which are passed to the environment constructor on initialisation. **kwargs: arbitrary keyword arguments which are passed to the environment constructor on initialisation.
""" """
@@ -835,8 +682,9 @@ def register(
order_enforce=order_enforce, order_enforce=order_enforce,
autoreset=autoreset, autoreset=autoreset,
disable_env_checker=disable_env_checker, disable_env_checker=disable_env_checker,
**kwargs,
apply_api_compatibility=apply_api_compatibility, apply_api_compatibility=apply_api_compatibility,
**kwargs,
additional_wrappers=additional_wrappers,
vector_entry_point=vector_entry_point, vector_entry_point=vector_entry_point,
) )
_check_spec_register(new_spec) _check_spec_register(new_spec)
@@ -849,7 +697,7 @@ def register(
def make( def make(
id: str | EnvSpec, id: str | EnvSpec,
max_episode_steps: int | None = None, max_episode_steps: int | None = None,
autoreset: bool = False, autoreset: bool | None = None,
apply_api_compatibility: bool | None = None, apply_api_compatibility: bool | None = None,
disable_env_checker: bool | None = None, disable_env_checker: bool | None = None,
**kwargs: Any, **kwargs: Any,
@@ -878,32 +726,12 @@ def make(
Error: If the ``id`` doesn't exist in the :attr:`registry` Error: If the ``id`` doesn't exist in the :attr:`registry`
""" """
if isinstance(id, EnvSpec): if isinstance(id, EnvSpec):
if hasattr(id, "applied_wrappers") and id.applied_wrappers is not None: env_spec = id
if max_episode_steps is not None: if not hasattr(env_spec, "additional_wrappers"):
logger.warn( logger.warn(
f"For `gymnasium.make` with an `EnvSpec`, the `max_episode_step` parameter is ignored, use `gym.make({id.id}, max_episode_steps={max_episode_steps})` and any additional wrappers" f"The env spec passed to `make` does not have a `additional_wrappers`, set it to an empty tuple. Env_spec={env_spec}"
)
if autoreset is True:
logger.warn(
f"For `gymnasium.make` with an `EnvSpec`, the `autoreset` parameter is ignored, use `gym.make({id.id}, autoreset={autoreset})` and any additional wrappers"
)
if apply_api_compatibility is not None:
logger.warn(
f"For `gymnasium.make` with an `EnvSpec`, the `apply_api_compatibility` parameter is ignored, use `gym.make({id.id}, apply_api_compatibility={apply_api_compatibility})` and any additional wrappers"
)
if disable_env_checker is not None:
logger.warn(
f"For `gymnasium.make` with an `EnvSpec`, the `disable_env_checker` parameter is ignored, use `gym.make({id.id}, disable_env_checker={disable_env_checker})` and any additional wrappers"
)
return _create_from_env_spec(
id,
kwargs,
)
else:
raise ValueError(
f"The EnvSpec used does not contain `applied_wrappers` parameters or is `None`. Expected to be a tuple, actually {id}."
) )
env_spec.additional_wrappers = ()
else: else:
# For string id's, load the environment spec from the registry then make the environment spec # For string id's, load the environment spec from the registry then make the environment spec
assert isinstance(id, str) assert isinstance(id, str)
@@ -911,14 +739,150 @@ def make(
# The environment name can include an unloaded module in "module:env_name" style # The environment name can include an unloaded module in "module:env_name" style
env_spec = _find_spec(id) env_spec = _find_spec(id)
return _create_from_env_id( assert isinstance(env_spec, EnvSpec)
env_spec,
kwargs, # Update the env spec kwargs with the `make` kwargs
max_episode_steps=max_episode_steps, env_spec_kwargs = copy.deepcopy(env_spec.kwargs)
autoreset=autoreset, env_spec_kwargs.update(kwargs)
apply_api_compatibility=apply_api_compatibility,
disable_env_checker=disable_env_checker, # Load the environment creator
) if env_spec.entry_point is None:
raise error.Error(f"{env_spec.id} registered but entry_point is not specified")
elif callable(env_spec.entry_point):
env_creator = env_spec.entry_point
else:
# Assume it's a string
env_creator = load_env_creator(env_spec.entry_point)
# Determine if to use the rendering
render_modes: list[str] | None = None
if hasattr(env_creator, "metadata"):
_check_metadata(env_creator.metadata)
render_modes = env_creator.metadata.get("render_modes")
render_mode = env_spec_kwargs.get("render_mode")
apply_human_rendering = False
apply_render_collection = False
# If mode is not valid, try applying HumanRendering/RenderCollection wrappers
if (
render_mode is not None
and render_modes is not None
and render_mode not in render_modes
):
displayable_modes = {"rgb_array", "rgb_array_list"}.intersection(render_modes)
if render_mode == "human" and len(displayable_modes) > 0:
logger.warn(
"You are trying to use 'human' rendering for an environment that doesn't natively support it. "
"The HumanRendering wrapper is being applied to your environment."
)
env_spec_kwargs["render_mode"] = displayable_modes.pop()
apply_human_rendering = True
elif (
render_mode.endswith("_list")
and render_mode[: -len("_list")] in render_modes
):
env_spec_kwargs["render_mode"] = render_mode[: -len("_list")]
apply_render_collection = True
else:
logger.warn(
f"The environment is being initialised with render_mode={render_mode!r} "
f"that is not in the possible render_modes ({render_modes})."
)
if apply_api_compatibility or (
apply_api_compatibility is None and env_spec.apply_api_compatibility
):
# If we use the compatibility layer, we treat the render mode explicitly and don't pass it to the env creator
render_mode = env_spec_kwargs.pop("render_mode", None)
else:
render_mode = None
try:
env = env_creator(**env_spec_kwargs)
except TypeError as e:
if (
str(e).find("got an unexpected keyword argument 'render_mode'") >= 0
and apply_human_rendering
):
raise error.Error(
f"You passed render_mode='human' although {env_spec.id} doesn't implement human-rendering natively. "
"Gym tried to apply the HumanRendering wrapper but it looks like your environment is using the old "
"rendering API, which is not supported by the HumanRendering wrapper."
) from e
else:
raise e
# Set the minimal env spec for the environment.
env.unwrapped.spec = EnvSpec(
id=env_spec.id,
entry_point=env_spec.entry_point,
reward_threshold=env_spec.reward_threshold,
nondeterministic=env_spec.nondeterministic,
max_episode_steps=None,
order_enforce=False,
autoreset=False,
disable_env_checker=True,
apply_api_compatibility=False,
kwargs=env_spec_kwargs,
additional_wrappers=(),
vector_entry_point=env_spec.vector_entry_point,
)
# Check if pre-wrapped wrappers
assert env.spec is not None
num_prior_wrappers = len(env.spec.additional_wrappers)
if (
env_spec.additional_wrappers[:num_prior_wrappers]
!= env.spec.additional_wrappers
):
for env_spec_wrapper_spec, recreated_wrapper_spec in zip(
env_spec.additional_wrappers, env.spec.additional_wrappers
):
raise ValueError(
f"The environment's wrapper spec {recreated_wrapper_spec} is different from the saved `EnvSpec` additional wrapper {env_spec_wrapper_spec}"
)
# Add step API wrapper
if apply_api_compatibility is True or (
apply_api_compatibility is None and env_spec.apply_api_compatibility is True
):
env = EnvCompatibility(env, render_mode)
# Run the environment checker as the lowest level wrapper
if disable_env_checker is False or (
disable_env_checker is None and env_spec.disable_env_checker is False
):
env = PassiveEnvChecker(env)
# Add the order enforcing wrapper
if env_spec.order_enforce:
env = OrderEnforcing(env)
# Add the time limit wrapper
if max_episode_steps is not None:
env = TimeLimit(env, max_episode_steps)
elif env_spec.max_episode_steps is not None:
env = TimeLimit(env, env_spec.max_episode_steps)
# Add the auto-reset wrapper
if autoreset is True or (autoreset is None and env_spec.autoreset is True):
env = AutoResetWrapper(env)
for wrapper_spec in env_spec.additional_wrappers[num_prior_wrappers:]:
if wrapper_spec.kwargs is None:
raise ValueError(
f"{wrapper_spec.name} wrapper does not inherit from `gymnasium.utils.RecordConstructorArgs`, therefore, the wrapper cannot be recreated."
)
env = load_env_creator(wrapper_spec.entry_point)(env=env, **wrapper_spec.kwargs)
# Add human rendering wrapper
if apply_human_rendering:
env = HumanRendering(env)
elif apply_render_collection:
env = RenderCollection(env)
return env
def make_vec( def make_vec(

View File

@@ -1,7 +1,16 @@
"""Wrapper that autoreset environments when `terminated=True` or `truncated=True`.""" """Wrapper that autoreset environments when `terminated=True` or `truncated=True`."""
from __future__ import annotations
from copy import deepcopy
from typing import TYPE_CHECKING
import gymnasium as gym import gymnasium as gym
if TYPE_CHECKING:
from gymnasium.envs.registration import EnvSpec
class AutoResetWrapper(gym.Wrapper, gym.utils.RecordConstructorArgs): class AutoResetWrapper(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`. """A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`.
@@ -60,3 +69,17 @@ class AutoResetWrapper(gym.Wrapper, gym.utils.RecordConstructorArgs):
info = new_info info = new_info
return obs, reward, terminated, truncated, info return obs, reward, terminated, truncated, info
@property
def spec(self) -> EnvSpec | None:
    """Return the wrapped environment's spec, copied and marked with `autoreset=True`.

    The result is stored in ``self._cached_spec`` and reused on later accesses.
    Returns ``None`` when the wrapped environment has no spec.
    """
    cached = self._cached_spec
    if cached is not None:
        return cached

    inner_spec = self.env.spec
    if inner_spec is None:
        updated_spec = None
    else:
        # Copy first so the wrapped environment's own spec is left untouched.
        updated_spec = deepcopy(inner_spec)
        updated_spec.autoreset = True

    self._cached_spec = updated_spec
    return updated_spec

View File

@@ -1,4 +1,9 @@
"""A passive environment checker wrapper for an environment's observation and action space along with the reset, step and render functions.""" """A passive environment checker wrapper for an environment's observation and action space along with the reset, step and render functions."""
from __future__ import annotations
from copy import deepcopy
from typing import TYPE_CHECKING
import gymnasium as gym import gymnasium as gym
from gymnasium.core import ActType from gymnasium.core import ActType
from gymnasium.utils.passive_env_checker import ( from gymnasium.utils.passive_env_checker import (
@@ -10,6 +15,10 @@ from gymnasium.utils.passive_env_checker import (
) )
if TYPE_CHECKING:
from gymnasium.envs.registration import EnvSpec
class PassiveEnvChecker(gym.Wrapper, gym.utils.RecordConstructorArgs): class PassiveEnvChecker(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""A passive environment checker wrapper that surrounds the step, reset and render functions to check they follow the gymnasium API.""" """A passive environment checker wrapper that surrounds the step, reset and render functions to check they follow the gymnasium API."""
@@ -54,3 +63,17 @@ class PassiveEnvChecker(gym.Wrapper, gym.utils.RecordConstructorArgs):
return env_render_passive_checker(self.env, *args, **kwargs) return env_render_passive_checker(self.env, *args, **kwargs)
else: else:
return self.env.render(*args, **kwargs) return self.env.render(*args, **kwargs)
@property
def spec(self) -> EnvSpec | None:
    """Return the wrapped environment's spec, copied and marked with `disable_env_checker=False`.

    The result is stored in ``self._cached_spec`` and reused on later accesses.
    Returns ``None`` when the wrapped environment has no spec.
    """
    cached = self._cached_spec
    if cached is not None:
        return cached

    inner_spec = self.env.spec
    if inner_spec is None:
        updated_spec = None
    else:
        # Copy first so the wrapped environment's own spec is left untouched.
        updated_spec = deepcopy(inner_spec)
        updated_spec.disable_env_checker = False

    self._cached_spec = updated_spec
    return updated_spec

View File

@@ -1,8 +1,17 @@
"""Wrapper to enforce the proper ordering of environment operations.""" """Wrapper to enforce the proper ordering of environment operations."""
from __future__ import annotations
from copy import deepcopy
from typing import TYPE_CHECKING
import gymnasium as gym import gymnasium as gym
from gymnasium.error import ResetNeeded from gymnasium.error import ResetNeeded
if TYPE_CHECKING:
from gymnasium.envs.registration import EnvSpec
class OrderEnforcing(gym.Wrapper, gym.utils.RecordConstructorArgs): class OrderEnforcing(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`. """A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`.
@@ -64,3 +73,17 @@ class OrderEnforcing(gym.Wrapper, gym.utils.RecordConstructorArgs):
def has_reset(self): def has_reset(self):
"""Returns if the environment has been reset before.""" """Returns if the environment has been reset before."""
return self._has_reset return self._has_reset
@property
def spec(self) -> EnvSpec | None:
    """Return the wrapped environment's spec, copied and marked with `order_enforce=True`.

    The result is stored in ``self._cached_spec`` and reused on later accesses.
    Returns ``None`` when the wrapped environment has no spec.
    """
    cached = self._cached_spec
    if cached is not None:
        return cached

    inner_spec = self.env.spec
    if inner_spec is None:
        updated_spec = None
    else:
        # Copy first so the wrapped environment's own spec is left untouched.
        updated_spec = deepcopy(inner_spec)
        updated_spec.order_enforce = True

    self._cached_spec = updated_spec
    return updated_spec

View File

@@ -1,9 +1,16 @@
"""Wrapper for limiting the time steps of an environment.""" """Wrapper for limiting the time steps of an environment."""
from typing import Optional from __future__ import annotations
from copy import deepcopy
from typing import TYPE_CHECKING
import gymnasium as gym import gymnasium as gym
if TYPE_CHECKING:
from gymnasium.envs.registration import EnvSpec
class TimeLimit(gym.Wrapper, gym.utils.RecordConstructorArgs): class TimeLimit(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""This wrapper will issue a `truncated` signal if a maximum number of timesteps is exceeded. """This wrapper will issue a `truncated` signal if a maximum number of timesteps is exceeded.
@@ -20,7 +27,7 @@ class TimeLimit(gym.Wrapper, gym.utils.RecordConstructorArgs):
def __init__( def __init__(
self, self,
env: gym.Env, env: gym.Env,
max_episode_steps: Optional[int] = None, max_episode_steps: int,
): ):
"""Initializes the :class:`TimeLimit` wrapper with an environment and the number of steps after which truncation will occur. """Initializes the :class:`TimeLimit` wrapper with an environment and the number of steps after which truncation will occur.
@@ -33,9 +40,6 @@ class TimeLimit(gym.Wrapper, gym.utils.RecordConstructorArgs):
) )
gym.Wrapper.__init__(self, env) gym.Wrapper.__init__(self, env)
if max_episode_steps is None and self.env.spec is not None:
assert env.spec is not None
max_episode_steps = env.spec.max_episode_steps
self._max_episode_steps = max_episode_steps self._max_episode_steps = max_episode_steps
self._elapsed_steps = None self._elapsed_steps = None
@@ -69,3 +73,17 @@ class TimeLimit(gym.Wrapper, gym.utils.RecordConstructorArgs):
""" """
self._elapsed_steps = 0 self._elapsed_steps = 0
return self.env.reset(**kwargs) return self.env.reset(**kwargs)
@property
def spec(self) -> EnvSpec | None:
    """Return the wrapped environment's spec, copied with `max_episode_steps` set to this wrapper's limit.

    The result is stored in ``self._cached_spec`` and reused on later accesses.
    Returns ``None`` when the wrapped environment has no spec.
    """
    cached = self._cached_spec
    if cached is not None:
        return cached

    inner_spec = self.env.spec
    if inner_spec is None:
        updated_spec = None
    else:
        # Copy first so the wrapped environment's own spec is left untouched.
        updated_spec = deepcopy(inner_spec)
        updated_spec.max_episode_steps = self._max_episode_steps

    self._cached_spec = updated_spec
    return updated_spec

View File

@@ -1,4 +1,4 @@
"""Example file showing usage of env.specstack.""" """Test for the `EnvSpec`, in particular, a full integration with `EnvSpec`."""
import pickle import pickle
import pytest import pytest
@@ -11,15 +11,16 @@ from gymnasium.utils.env_checker import data_equivalence
def test_full_integration(): def test_full_integration():
# Create an environment to test with # Create an environment to test with
env = gym.make("CartPole-v1", render_mode="rgb_array").unwrapped env = gym.make("CartPole-v1", render_mode="rgb_array")
env = gym.wrappers.FlattenObservation(env)
env = gym.wrappers.TimeAwareObservation(env) env = gym.wrappers.TimeAwareObservation(env)
env = gym.wrappers.NormalizeReward(env, gamma=0.8) env = gym.wrappers.NormalizeReward(env, gamma=0.8)
# Generate the spec_stack # Generate the spec_stack
env_spec = env.spec env_spec = env.spec
assert isinstance(env_spec, EnvSpec) assert isinstance(env_spec, EnvSpec)
# additional_wrappers = (TimeAwareObservation, NormalizeReward)
assert len(env_spec.additional_wrappers) == 2
# env_spec.pprint() # env_spec.pprint()
# Serialize the spec_stack # Serialize the spec_stack
@@ -30,10 +31,7 @@ def test_full_integration():
recreate_env_spec = EnvSpec.from_json(env_spec_json) recreate_env_spec = EnvSpec.from_json(env_spec_json)
# recreate_env_spec.pprint() # recreate_env_spec.pprint()
for wrapper_spec, recreated_wrapper_spec in zip( assert env_spec.additional_wrappers == recreate_env_spec.additional_wrappers
env_spec.applied_wrappers, recreate_env_spec.applied_wrappers
):
assert wrapper_spec == recreated_wrapper_spec
assert recreate_env_spec == env_spec assert recreate_env_spec == env_spec
# Recreate the environment using the spec_stack # Recreate the environment using the spec_stack
@@ -86,43 +84,6 @@ def test_env_spec_to_from_json(env_spec: EnvSpec):
assert env_spec == recreated_env_spec assert env_spec == recreated_env_spec
def test_wrapped_env_entry_point():
    """Check that `gym.make(env.spec)` recreates an env whose entry point already applies a wrapper.

    The registered entry point returns a `FlattenObservation`-wrapped CartPole;
    two more wrappers are added on top, the environment is recreated from its
    spec, and both environments must give equivalent reset and step results.
    """

    def _create_env():
        _env = gym.make("CartPole-v1", render_mode="rgb_array")
        _env = gym.wrappers.FlattenObservation(_env)
        return _env

    gym.register("TestingEnv-v0", entry_point=_create_env)

    # Clean up in `finally` so a failing assertion does not leave "TestingEnv-v0"
    # registered for the tests that run afterwards.
    try:
        env = gym.make("TestingEnv-v0")
        env = gym.wrappers.TimeAwareObservation(env)
        env = gym.wrappers.NormalizeReward(env, gamma=0.8)

        recreated_env = gym.make(env.spec)

        obs, info = env.reset(seed=42)
        recreated_obs, recreated_info = recreated_env.reset(seed=42)
        assert data_equivalence(obs, recreated_obs)
        assert data_equivalence(info, recreated_info)

        action = env.action_space.sample()
        obs, reward, terminated, truncated, info = env.step(action)
        (
            recreated_obs,
            recreated_reward,
            recreated_terminated,
            recreated_truncated,
            recreated_info,
        ) = recreated_env.step(action)
        assert data_equivalence(obs, recreated_obs)
        assert data_equivalence(reward, recreated_reward)
        assert data_equivalence(terminated, recreated_terminated)
        assert data_equivalence(truncated, recreated_truncated)
        assert data_equivalence(info, recreated_info)
    finally:
        del gym.registry["TestingEnv-v0"]
def test_pickling_env_stack(): def test_pickling_env_stack():
env = gym.make("CartPole-v1", render_mode="rgb_array") env = gym.make("CartPole-v1", render_mode="rgb_array")
@@ -163,6 +124,8 @@ def test_pickling_env_stack():
def test_env_spec_pprint(): def test_env_spec_pprint():
env = gym.make("CartPole-v1") env = gym.make("CartPole-v1")
env = gym.wrappers.TimeAwareObservation(env)
env_spec = env.spec env_spec = env.spec
assert env_spec is not None assert env_spec is not None
@@ -172,10 +135,8 @@ def test_env_spec_pprint():
== """id=CartPole-v1 == """id=CartPole-v1
reward_threshold=475.0 reward_threshold=475.0
max_episode_steps=500 max_episode_steps=500
applied_wrappers=[ additional_wrappers=[
name=PassiveEnvChecker, kwargs={}, name=TimeAwareObservation, kwargs={}
name=OrderEnforcing, kwargs={'disable_render_order_enforcing': False},
name=TimeLimit, kwargs={'max_episode_steps': 500}
]""" ]"""
) )
@@ -186,10 +147,8 @@ applied_wrappers=[
entry_point=gymnasium.envs.classic_control.cartpole:CartPoleEnv entry_point=gymnasium.envs.classic_control.cartpole:CartPoleEnv
reward_threshold=475.0 reward_threshold=475.0
max_episode_steps=500 max_episode_steps=500
applied_wrappers=[ additional_wrappers=[
name=PassiveEnvChecker, entry_point=gymnasium.wrappers.env_checker:PassiveEnvChecker, kwargs={}, name=TimeAwareObservation, entry_point=gymnasium.wrappers.time_aware_observation:TimeAwareObservation, kwargs={}
name=OrderEnforcing, entry_point=gymnasium.wrappers.order_enforcing:OrderEnforcing, kwargs={'disable_render_order_enforcing': False},
name=TimeLimit, entry_point=gymnasium.wrappers.time_limit:TimeLimit, kwargs={'max_episode_steps': 500}
]""" ]"""
) )
@@ -205,14 +164,12 @@ order_enforce=True
autoreset=False autoreset=False
disable_env_checker=False disable_env_checker=False
applied_api_compatibility=False applied_api_compatibility=False
applied_wrappers=[ additional_wrappers=[
name=PassiveEnvChecker, kwargs={}, name=TimeAwareObservation, kwargs={}
name=OrderEnforcing, kwargs={'disable_render_order_enforcing': False},
name=TimeLimit, kwargs={'max_episode_steps': 500}
]""" ]"""
) )
env_spec.applied_wrappers = () env_spec.additional_wrappers = ()
output = env_spec.pprint(disable_print=True) output = env_spec.pprint(disable_print=True)
assert ( assert (
output output
@@ -233,5 +190,5 @@ order_enforce=True
autoreset=False autoreset=False
disable_env_checker=False disable_env_checker=False
applied_api_compatibility=False applied_api_compatibility=False
applied_wrappers=[]""" additional_wrappers=[]"""
) )

View File

@@ -1,4 +1,4 @@
"""Tests that gym.make works as expected.""" """Tests that `gym.make` works as expected."""
from __future__ import annotations from __future__ import annotations
import re import re
@@ -11,6 +11,8 @@ import gymnasium as gym
from gymnasium import Env from gymnasium import Env
from gymnasium.core import ActType, ObsType, WrapperObsType from gymnasium.core import ActType, ObsType, WrapperObsType
from gymnasium.envs.classic_control import CartPoleEnv from gymnasium.envs.classic_control import CartPoleEnv
from gymnasium.error import NameNotFound
from gymnasium.utils.env_checker import data_equivalence
from gymnasium.wrappers import ( from gymnasium.wrappers import (
AutoResetWrapper, AutoResetWrapper,
HumanRendering, HumanRendering,
@@ -24,111 +26,101 @@ from tests.testing_env import GenericTestEnv, old_step_func
from tests.wrappers.utils import has_wrapper from tests.wrappers.utils import has_wrapper
try: # Tests
import shimmy # * basic example
except ImportError: # * parameters (equivalent for str and EnvSpec)
shimmy = None # 1. max_episode_steps
# 2. autoreset
# 3. apply_api_compatibility
# 4. disable_env_checker
# * rendering
# 1. render_mode
# 2. HumanRendering
# 3. RenderCollection
# * make kwargs
# * make import module
# * make env spec additional wrappers
# * env_id str errors
@pytest.fixture(scope="function") def test_no_arguments(env_id: str = "CartPole-v1"):
def register_testing_envs(): """Test `gym.make` using str and EnvSpec with no arguments."""
"""Registers testing envs for `gym.make`""" env_from_id = gym.make(env_id)
gym.register( assert env_from_id.spec is not None
id="test.ArgumentEnv-v0", assert env_from_id.spec.id == env_id
entry_point="tests.envs.registration.utils_envs:ArgumentEnv", assert isinstance(env_from_id.unwrapped, CartPoleEnv)
kwargs={
"arg1": "arg1",
"arg2": "arg2",
},
)
gym.register( env_spec = gym.spec(env_id)
id="test/NoHuman-v0", env_from_spec = gym.make(env_spec)
entry_point="tests.envs.registration.utils_envs:NoHuman", assert env_from_spec.spec is not None
) assert env_from_spec.spec.id == env_id
gym.register( assert isinstance(env_from_spec.unwrapped, CartPoleEnv)
id="test/NoHumanOldAPI-v0",
entry_point="tests.envs.registration.utils_envs:NoHumanOldAPI",
)
gym.register( assert env_from_id.spec == env_from_spec.spec
id="test/NoHumanNoRGB-v0",
entry_point="tests.envs.registration.utils_envs:NoHumanNoRGB",
)
gym.register(
id="test/NoRenderModesMetadata-v0",
entry_point="tests.envs.registration.utils_envs:NoRenderModesMetadata",
)
yield
del gym.envs.registration.registry["test.ArgumentEnv-v0"]
del gym.envs.registration.registry["test/NoRenderModesMetadata-v0"]
del gym.envs.registration.registry["test/NoHuman-v0"]
del gym.envs.registration.registry["test/NoHumanOldAPI-v0"]
del gym.envs.registration.registry["test/NoHumanNoRGB-v0"]
def test_make(): def test_max_episode_steps(register_parameter_envs):
"""Test basic `gym.make`.""" """Test the `max_episode_steps` parameter in `gym.make`."""
env = gym.make("CartPole-v1") for make_id in ["CartPole-v1", gym.spec("CartPole-v1")]:
assert env.spec is not None env_spec = gym.spec(make_id) if isinstance(make_id, str) else make_id
assert env.spec.id == "CartPole-v1"
assert isinstance(env.unwrapped, CartPoleEnv) # Use the spec's value
env.close() env = gym.make(make_id)
assert has_wrapper(env, TimeLimit)
assert env.spec is not None
assert env.spec.max_episode_steps == env_spec.max_episode_steps
# Set a custom max episode steps value
assert env_spec.max_episode_steps != 100
env = gym.make(make_id, max_episode_steps=100)
assert has_wrapper(env, TimeLimit)
assert env.spec is not None
assert env.spec.max_episode_steps == 100, make_id
for make_id in ["NoMaxEpisodeStepsEnv-v0", gym.spec("NoMaxEpisodeStepsEnv-v0")]:
env_spec = gym.spec(make_id) if isinstance(make_id, str) else make_id
# env spec has no max episode steps
assert env_spec.max_episode_steps is None
env = gym.make(make_id)
assert env.spec is not None
assert env.spec.max_episode_steps is None
assert has_wrapper(env, TimeLimit) is False
# set a custom max episode steps values
env = gym.make(make_id, max_episode_steps=100)
assert env.spec is not None
assert env.spec.max_episode_steps == 100
assert has_wrapper(env, TimeLimit)
def test_make_deprecated(): def test_autorest(register_parameter_envs):
"""Test make with a deprecated environment (i.e., doesn't exist).""" """Test the `autoreset` parameter in `gym.make`."""
with warnings.catch_warnings(record=True): for make_id in [
with pytest.raises( "CartPole-v1",
gym.error.Error, gym.spec("CartPole-v1"),
match=re.escape( "AutoresetEnv-v0",
"Environment version v0 for `Humanoid` is deprecated. Please use `Humanoid-v4` instead." gym.spec("AutoresetEnv-v0"),
), ]:
): env_spec = gym.spec(make_id) if isinstance(make_id, str) else make_id
gym.make("Humanoid-v0")
# Use the spec's value
env = gym.make(make_id)
assert env.spec is not None
assert env.spec.autoreset == env_spec.autoreset
assert has_wrapper(env, AutoResetWrapper) is env_spec.autoreset
def test_make_max_episode_steps(register_testing_envs): # Set autoreset is True
# Default, uses the spec's env = gym.make(make_id, autoreset=True)
env = gym.make("CartPole-v1") assert has_wrapper(env, AutoResetWrapper)
assert has_wrapper(env, TimeLimit) assert env.spec is not None
assert env.spec is not None assert env.spec.autoreset is True
assert env.spec.max_episode_steps == gym.spec("CartPole-v1").max_episode_steps
env.close()
# Custom max episode steps # Set autoreset is False
assert gym.spec("CartPole-v1").max_episode_steps != 100 env = gym.make(make_id, autoreset=False)
env = gym.make("CartPole-v1", max_episode_steps=100) assert has_wrapper(env, AutoResetWrapper) is False
assert has_wrapper(env, TimeLimit) assert env.spec is not None
assert env.spec is not None assert env.spec.autoreset is False
assert env.spec.max_episode_steps == 100
env.close()
# Env spec has no max episode steps
assert gym.spec("test.ArgumentEnv-v0").max_episode_steps is None
env = gym.make("test.ArgumentEnv-v0", arg1=None, arg2=None, arg3=None)
assert env.spec is not None
assert env.spec.max_episode_steps is None
assert has_wrapper(env, TimeLimit) is False
env.close()
def test_make_autoreset():
"""Tests that `gym.make` autoreset wrapper is applied only when `gym.make(..., autoreset=True)`."""
env = gym.make("CartPole-v1")
assert has_wrapper(env, AutoResetWrapper) is False
env.close()
env = gym.make("CartPole-v1", autoreset=False)
assert has_wrapper(env, AutoResetWrapper) is False
env.close()
env = gym.make("CartPole-v1", autoreset=True)
assert has_wrapper(env, AutoResetWrapper)
env.close()
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -142,7 +134,7 @@ def test_make_autoreset():
[True, None, True], [True, None, True],
], ],
) )
def test_make_disable_env_checker( def test_disable_env_checker(
registration_disabled: bool, make_disabled: bool | None, if_disabled: bool registration_disabled: bool, make_disabled: bool | None, if_disabled: bool
): ):
"""Tests that `gym.make` disable env checker is applied only when `gym.make(..., disable_env_checker=False)`. """Tests that `gym.make` disable env checker is applied only when `gym.make(..., disable_env_checker=False)`.
@@ -150,84 +142,76 @@ def test_make_disable_env_checker(
The ordering is 1. if the `make(..., disable_env_checker=...)` is bool, then the `registration(..., disable_env_checker=...)` The ordering is 1. if the `make(..., disable_env_checker=...)` is bool, then the `registration(..., disable_env_checker=...)`
""" """
gym.register( gym.register(
"testing-env-v0", "DisableEnvCheckerEnv-v0",
lambda: GenericTestEnv(), lambda: GenericTestEnv(),
disable_env_checker=registration_disabled, disable_env_checker=registration_disabled,
) )
# Test when the registered EnvSpec.disable_env_checker = False # Test when the registered EnvSpec.disable_env_checker = False
env = gym.make("testing-env-v0", disable_env_checker=make_disabled) env = gym.make("DisableEnvCheckerEnv-v0", disable_env_checker=make_disabled)
assert has_wrapper(env, PassiveEnvChecker) is not if_disabled assert has_wrapper(env, PassiveEnvChecker) is not if_disabled
env.close()
del gym.registry["testing-env-v0"] env_spec = gym.spec("DisableEnvCheckerEnv-v0")
env = gym.make(env_spec, disable_env_checker=make_disabled)
assert has_wrapper(env, PassiveEnvChecker) is not if_disabled
del gym.registry["DisableEnvCheckerEnv-v0"]
def test_make_apply_api_compatibility(): def test_apply_api_compatibility(register_parameter_envs):
"""Test the API compatibility wrapper.""" """Test the `apply_api_compatibility` parameter for `gym.make`."""
gym.register(
"testing-old-env",
lambda: GenericTestEnv(step_func=old_step_func),
apply_api_compatibility=True,
max_episode_steps=3,
)
# Apply the environment compatibility and check it works as intended # Apply the environment compatibility and check it works as intended
env = gym.make("testing-old-env") for make_id in ["EnabledApplyApiComp-v0", gym.spec("EnabledApplyApiComp-v0")]:
assert isinstance(env.unwrapped, gym.wrappers.EnvCompatibility) env = gym.make(make_id)
assert isinstance(env.unwrapped, gym.wrappers.EnvCompatibility)
env.reset() # env has time limit of 3 enabling this test
assert len(env.step(env.action_space.sample())) == 5 env.reset()
env.step(env.action_space.sample()) assert len(env.step(env.action_space.sample())) == 5
_, _, termination, truncation, _ = env.step(env.action_space.sample())
assert termination is False and truncation is True
# Turn off the spec api compatibility
gym.spec("testing-old-env").apply_api_compatibility = False
env = gym.make("testing-old-env")
assert isinstance(env.unwrapped, gym.wrappers.EnvCompatibility) is False
env.reset()
with pytest.raises(
ValueError, match=re.escape("not enough values to unpack (expected 5, got 4)")
):
env.step(env.action_space.sample()) env.step(env.action_space.sample())
_, _, termination, truncation, _ = env.step(env.action_space.sample())
assert termination is False and truncation is True
# Apply the environment compatibility and check it works as intended for make_id in ["DisabledApplyApiComp-v0", gym.spec("DisabledApplyApiComp-v0")]:
env = gym.make("testing-old-env", apply_api_compatibility=True) # Turn off the spec api compatibility
assert isinstance(env.unwrapped, gym.wrappers.EnvCompatibility) env = gym.make(make_id)
assert isinstance(env.unwrapped, gym.wrappers.EnvCompatibility) is False
env.reset()
with pytest.raises(
ValueError,
match=re.escape("not enough values to unpack (expected 5, got 4)"),
):
env.step(env.action_space.sample())
env.reset() # Apply the environment compatibility and check it works as intended
assert len(env.step(env.action_space.sample())) == 5 assert env.spec is not None
env.step(env.action_space.sample()) assert env.spec.apply_api_compatibility is False
_, _, termination, truncation, _ = env.step(env.action_space.sample()) env = gym.make(make_id, apply_api_compatibility=True)
assert termination is False and truncation is True assert isinstance(env.unwrapped, gym.wrappers.EnvCompatibility)
del gym.registry["testing-old-env"] env.reset()
assert len(env.step(env.action_space.sample())) == 5
env.step(env.action_space.sample())
_, _, termination, truncation, _ = env.step(env.action_space.sample())
assert termination is False and truncation is True
def test_make_order_enforcing(): def test_order_enforcing(register_parameter_envs):
"""Checks that gym.make wrappers the environment with the OrderEnforcing wrapper.""" """Checks that gym.make wrappers the environment with the OrderEnforcing wrapper."""
assert all(spec.order_enforce is True for spec in all_testing_env_specs) assert all(spec.order_enforce is False for spec in all_testing_env_specs)
env = gym.make("CartPole-v1") for make_id in ["CartPole-v1", gym.spec("CartPole-v1")]:
assert has_wrapper(env, OrderEnforcing) env = gym.make(make_id)
# We can assume that there all other specs will also have the order enforcing assert has_wrapper(env, OrderEnforcing)
env.close()
gym.register( for make_id in ["OrderlessEnv-v0", gym.spec("OrderlessEnv-v0")]:
id="test.OrderlessArgumentEnv-v0", env = gym.make(make_id)
entry_point="tests.envs.registration.utils_envs:ArgumentEnv", assert has_wrapper(env, OrderEnforcing) is False
order_enforce=False,
kwargs={"arg1": None, "arg2": None, "arg3": None},
)
env = gym.make("test.OrderlessArgumentEnv-v0")
assert has_wrapper(env, OrderEnforcing) is False
env.close()
# There is no `make(..., order_enforcing=...)` so we don't test that # There is no `make(..., order_enforcing=...)` so we don't test that
def test_make_render_mode(): def test_make_with_render_mode():
"""Test the `make(..., render_mode=...)`, in particular, if to apply the `RenderCollection` or the `HumanRendering`.""" """Test the `make(..., render_mode=...)`, in particular, if to apply the `RenderCollection` or the `HumanRendering`."""
env = gym.make("CartPole-v1", render_mode=None) env = gym.make("CartPole-v1", render_mode=None)
assert env.render_mode is None assert env.render_mode is None
@@ -267,10 +251,12 @@ def test_make_render_collection():
env.close() env.close()
def test_make_human_rendering(register_testing_envs): def test_make_human_rendering(register_rendering_testing_envs):
# Make sure that native rendering is used when possible # Make sure that native rendering is used when possible
env = gym.make("CartPole-v1", render_mode="human") env = gym.make("CartPole-v1", render_mode="human")
assert not has_wrapper(env, HumanRendering) # Should use native human-rendering assert (
has_wrapper(env, HumanRendering) is False
) # Should use native human-rendering
assert env.render_mode == "human" assert env.render_mode == "human"
env.close() env.close()
@@ -281,7 +267,7 @@ def test_make_human_rendering(register_testing_envs):
), ),
): ):
# Make sure that `HumanRendering` is applied here as the environment doesn't use native rendering # Make sure that `HumanRendering` is applied here as the environment doesn't use native rendering
env = gym.make("test/NoHuman-v0", render_mode="human") env = gym.make("NoHumanRendering-v0", render_mode="human")
assert has_wrapper(env, HumanRendering) assert has_wrapper(env, HumanRendering)
assert env.render_mode == "human" assert env.render_mode == "human"
env.close() env.close()
@@ -290,7 +276,7 @@ def test_make_human_rendering(register_testing_envs):
TypeError, match=re.escape("got an unexpected keyword argument 'render_mode'") TypeError, match=re.escape("got an unexpected keyword argument 'render_mode'")
): ):
gym.make( gym.make(
"test/NoHumanOldAPI-v0", "NoHumanRenderingOldAPI-v0",
render_mode="rgb_array_list", render_mode="rgb_array_list",
) )
@@ -299,10 +285,10 @@ def test_make_human_rendering(register_testing_envs):
with pytest.raises( with pytest.raises(
gym.error.Error, gym.error.Error,
match=re.escape( match=re.escape(
"You passed render_mode='human' although test/NoHumanOldAPI-v0 doesn't implement human-rendering natively." "You passed render_mode='human' although NoHumanRenderingOldAPI-v0 doesn't implement human-rendering natively."
), ),
): ):
gym.make("test/NoHumanOldAPI-v0", render_mode="human") gym.make("NoHumanRenderingOldAPI-v0", render_mode="human")
# This test ensures that the additional exception "Gym tried to apply the HumanRendering wrapper but it looks like # This test ensures that the additional exception "Gym tried to apply the HumanRendering wrapper but it looks like
# your environment is using the old rendering API" is *not* triggered by a TypeError that originate from # your environment is using the old rendering API" is *not* triggered by a TypeError that originate from
@@ -320,10 +306,10 @@ def test_make_human_rendering(register_testing_envs):
"\x1b[33mWARN: The environment is being initialised with render_mode='rgb_array' that is not in the possible render_modes ([]).\x1b[0m" "\x1b[33mWARN: The environment is being initialised with render_mode='rgb_array' that is not in the possible render_modes ([]).\x1b[0m"
), ),
): ):
gym.make("test/NoRenderModesMetadata-v0", render_mode="rgb_array") gym.make("NoRenderModesMetadata-v0", render_mode="rgb_array")
def test_make_kwargs(register_testing_envs): def test_make_kwargs(register_kwargs_env):
env = gym.make( env = gym.make(
"test.ArgumentEnv-v0", "test.ArgumentEnv-v0",
arg2="override_arg2", arg2="override_arg2",
@@ -367,12 +353,11 @@ class NoRecordArgsWrapper(gym.ObservationWrapper):
return self.observation_space.sample() return self.observation_space.sample()
def test_make_env_spec(): def test_make_with_env_spec():
# make # make
env_1 = gym.make(gym.spec("CartPole-v1")) id_env = gym.make("CartPole-v1")
assert isinstance(env_1, CartPoleEnv) spec_env = gym.make(gym.spec("CartPole-v1"))
assert env_1 is env_1.unwrapped assert id_env.spec == spec_env.spec
env_1.close()
# make with applied wrappers # make with applied wrappers
env_2 = gym.wrappers.NormalizeReward( env_2 = gym.wrappers.NormalizeReward(
@@ -396,30 +381,202 @@ def test_make_env_spec():
# make with wrapper in env-creator # make with wrapper in env-creator
gym.register( gym.register(
"CartPole-v3", lambda: gym.wrappers.TimeAwareObservation(CartPoleEnv()) "CartPole-v3",
lambda: gym.wrappers.TimeAwareObservation(CartPoleEnv()),
disable_env_checker=True,
order_enforce=False,
) )
env_4 = gym.make(gym.spec("CartPole-v3")) env_4 = gym.make(gym.spec("CartPole-v3"))
assert isinstance(env_4, gym.wrappers.TimeAwareObservation) assert isinstance(env_4, gym.wrappers.TimeAwareObservation)
assert isinstance(env_4.env, CartPoleEnv) assert isinstance(env_4.env, CartPoleEnv)
env_4.close() env_4.close()
# make with no ezpickle wrapper gym.register(
env_5 = NoRecordArgsWrapper(gym.make("CartPole-v1").unwrapped) "CartPole-v4",
lambda: CartPoleEnv(),
disable_env_checker=True,
order_enforce=False,
additional_wrappers=(gym.wrappers.TimeAwareObservation.wrapper_spec(),),
)
env_5 = gym.make(gym.spec("CartPole-v4"))
assert isinstance(env_5, gym.wrappers.TimeAwareObservation)
assert isinstance(env_5.env, CartPoleEnv)
env_5.close() env_5.close()
# make with no ezpickle wrapper
env_6 = NoRecordArgsWrapper(gym.make("CartPole-v1"))
with pytest.raises( with pytest.raises(
ValueError, ValueError,
match=re.escape( match=re.escape(
"NoRecordArgsWrapper wrapper does not inherit from `gymnasium.utils.RecordConstructorArgs`, therefore, the wrapper cannot be recreated." "NoRecordArgsWrapper wrapper does not inherit from `gymnasium.utils.RecordConstructorArgs`, therefore, the wrapper cannot be recreated."
), ),
): ):
gym.make(env_5.spec) gym.make(env_6.spec)
# make with no ezpickle wrapper but in the entry point # make with no ezpickle wrapper but in the entry point
gym.register("CartPole-v4", entry_point=lambda: NoRecordArgsWrapper(CartPoleEnv())) gym.register(
env_6 = gym.make(gym.spec("CartPole-v4")) "CartPole-v5",
assert isinstance(env_6, NoRecordArgsWrapper) entry_point=lambda: NoRecordArgsWrapper(CartPoleEnv()),
assert isinstance(env_6.unwrapped, CartPoleEnv) disable_env_checker=True,
order_enforce=False,
)
env_7 = gym.make(gym.spec("CartPole-v5"))
assert isinstance(env_7, NoRecordArgsWrapper)
assert isinstance(env_7.unwrapped, CartPoleEnv)
gym.register(
"CartPole-v6",
entry_point=lambda: CartPoleEnv(),
disable_env_checker=True,
order_enforce=False,
additional_wrappers=(NoRecordArgsWrapper.wrapper_spec(),),
)
del gym.registry["CartPole-v2"] del gym.registry["CartPole-v2"]
del gym.registry["CartPole-v3"] del gym.registry["CartPole-v3"]
del gym.registry["CartPole-v4"] del gym.registry["CartPole-v4"]
del gym.registry["CartPole-v5"]
del gym.registry["CartPole-v6"]
def test_make_with_env_spec_levels():
    """Check that each layer of a wrapper stack can be rebuilt from that layer's own spec."""
    base_env = gym.make("CartPole-v1", render_mode="rgb_array")
    flattened_env = gym.wrappers.FlattenObservation(base_env)
    time_aware_env = gym.wrappers.TimeAwareObservation(flattened_env)
    layer = gym.wrappers.NormalizeReward(time_aware_env, gamma=0.8)

    # Peel off one wrapper per iteration; stop once only the bare environment is left.
    while layer is not layer.unwrapped:
        rebuilt = gym.make(layer.spec)
        assert layer.spec == rebuilt.spec

        layer = layer.env
def test_wrapped_env_entry_point():
    """Test that an env whose entry point already applies a wrapper can be recreated from its spec.

    The registered entry point returns a `FlattenObservation`-wrapped CartPole. Two more
    wrappers are applied on top, the stack is rebuilt with `gym.make(env.spec)`, and the
    outputs of `reset` and one `step` of both environments are compared.
    """

    def _create_env():
        _env = gym.make("CartPole-v1", render_mode="rgb_array")
        _env = gym.wrappers.FlattenObservation(_env)
        return _env

    gym.register("TestingEnv-v0", entry_point=_create_env)

    try:
        env = gym.make("TestingEnv-v0")
        env = gym.wrappers.TimeAwareObservation(env)
        env = gym.wrappers.NormalizeReward(env, gamma=0.8)

        recreated_env = gym.make(env.spec)

        # Same seed for both environments so the reset outputs are comparable.
        obs, info = env.reset(seed=42)
        recreated_obs, recreated_info = recreated_env.reset(seed=42)
        assert data_equivalence(obs, recreated_obs)
        assert data_equivalence(info, recreated_info)

        # Apply the same action to both environments and compare every step output.
        action = env.action_space.sample()
        obs, reward, terminated, truncated, info = env.step(action)
        (
            recreated_obs,
            recreated_reward,
            recreated_terminated,
            recreated_truncated,
            recreated_info,
        ) = recreated_env.step(action)
        assert data_equivalence(obs, recreated_obs)
        assert data_equivalence(reward, recreated_reward)
        assert data_equivalence(terminated, recreated_terminated)
        assert data_equivalence(truncated, recreated_truncated)
        assert data_equivalence(info, recreated_info)
    finally:
        # Always unregister, so a failing assertion cannot leak the id into other tests.
        del gym.registry["TestingEnv-v0"]
def test_make_errors():
    """Check the errors `gym.make` raises for a deprecated env version and for an unknown env id."""
    deprecated_message = re.escape(
        "Environment version v0 for `Humanoid` is deprecated. Please use `Humanoid-v4` instead."
    )
    # Warnings are recorded rather than emitted while the deprecated id is requested.
    with warnings.catch_warnings(record=True), pytest.raises(
        gym.error.Error, match=deprecated_message
    ):
        gym.make("Humanoid-v0")

    unknown_message = re.escape("Environment `NonExistenceEnv` doesn't exist.")
    with pytest.raises(NameNotFound, match=unknown_message):
        gym.make("NonExistenceEnv-v0")
@pytest.fixture(scope="function")
def register_parameter_envs():
    """Register the envs used by the `gym.make` parameter tests and unregister them on teardown."""
    # env id -> (entry point, extra keyword arguments forwarded to `gym.register`)
    registrations = {
        "NoMaxEpisodeStepsEnv-v0": (
            lambda: GenericTestEnv(),
            {"max_episode_steps": None},
        ),
        "AutoresetEnv-v0": (
            lambda: GenericTestEnv(),
            {"autoreset": True},
        ),
        "EnabledApplyApiComp-v0": (
            lambda: GenericTestEnv(step_func=old_step_func),
            {"apply_api_compatibility": True, "max_episode_steps": 3},
        ),
        "DisabledApplyApiComp-v0": (
            lambda: GenericTestEnv(step_func=old_step_func),
            {"apply_api_compatibility": False, "max_episode_steps": 3},
        ),
        "OrderlessEnv-v0": (
            lambda: GenericTestEnv(),
            {"order_enforce": False},
        ),
    }

    for env_id, (entry_point, register_kwargs) in registrations.items():
        gym.register(env_id, entry_point, **register_kwargs)

    yield

    for env_id in registrations:
        del gym.registry[env_id]
@pytest.fixture(scope="function")
def register_kwargs_env():
    """Register `test.ArgumentEnv-v0` with default kwargs for one test and unregister it afterwards.

    The environment is registered with default values for `arg1` and `arg2` only.
    """
    gym.register(
        id="test.ArgumentEnv-v0",
        entry_point="tests.envs.registration.utils_envs:ArgumentEnv",
        kwargs={
            "arg1": "arg1",
            "arg2": "arg2",
        },
    )

    yield

    # Remove the id again so it does not leak into the global registry across tests.
    del gym.registry["test.ArgumentEnv-v0"]
@pytest.fixture(scope="function")
def register_rendering_testing_envs():
    """Register the envs used by the rendering tests and unregister them on teardown."""
    # env id -> entry point string
    entry_points = {
        "NoHumanRendering-v0": "tests.envs.registration.utils_envs:NoHuman",
        "NoHumanRenderingOldAPI-v0": "tests.envs.registration.utils_envs:NoHumanOldAPI",
        "NoHumanRenderingNoRGB-v0": "tests.envs.registration.utils_envs:NoHumanNoRGB",
        "NoRenderModesMetadata-v0": "tests.envs.registration.utils_envs:NoRenderModesMetadata",
    }

    for env_id, entry_point in entry_points.items():
        gym.register(id=env_id, entry_point=entry_point)

    yield

    for env_id in entry_points:
        del gym.envs.registration.registry[env_id]

View File

@@ -171,7 +171,7 @@ def test_compatibility_with_old_style_env():
"""Test compatibility with old style environment.""" """Test compatibility with old style environment."""
env = OldStyleEnv() env = OldStyleEnv()
env = OrderEnforcing(env) env = OrderEnforcing(env)
env = TimeLimit(env) env = TimeLimit(env, 100)
obs = env.reset() obs = env.reset()
assert obs == 0 assert obs == 0

View File

@@ -7,7 +7,7 @@ from gymnasium.wrappers import TimeLimit
def test_time_limit_reset_info(): def test_time_limit_reset_info():
env = gym.make("CartPole-v1", disable_env_checker=True) env = gym.make("CartPole-v1", disable_env_checker=True)
env = TimeLimit(env) env = TimeLimit(env, 100)
ob_space = env.observation_space ob_space = env.observation_space
obs, info = env.reset() obs, info = env.reset()
assert ob_space.contains(obs) assert ob_space.contains(obs)