Update EnvSpec to improve backward compatibility (#355)

This commit is contained in:
Mark Towers
2023-03-08 14:07:09 +00:00
committed by GitHub
parent cb3f0d19cd
commit 9d8db14e7f
10 changed files with 620 additions and 439 deletions

View File

@@ -11,7 +11,7 @@ from gymnasium.utils import RecordConstructorArgs, seeding
if TYPE_CHECKING:
from gymnasium.envs.registration import EnvSpec
from gymnasium.envs.registration import EnvSpec, WrapperSpec
ObsType = TypeVar("ObsType")
ActType = TypeVar("ActType")
@@ -266,6 +266,8 @@ class Wrapper(
self._reward_range: tuple[SupportsFloat, SupportsFloat] | None = None
self._metadata: dict[str, Any] | None = None
self._cached_spec: EnvSpec | None = None
def __getattr__(self, name: str) -> Any:
"""Returns an attribute with ``name``, unless ``name`` starts with an underscore."""
if name == "_np_random":
@@ -279,11 +281,11 @@ class Wrapper(
@property
def spec(self) -> EnvSpec | None:
"""Returns the :attr:`Env` :attr:`spec` attribute with the `WrapperSpec` if the wrapper inherits from `EzPickle`."""
if self._cached_spec is not None:
return self._cached_spec
env_spec = self.env.spec
if env_spec is not None:
from gymnasium.envs.registration import WrapperSpec
# See if the wrapper inherits from `RecordConstructorArgs` then add the kwargs otherwise use `None` for the wrapper kwargs. This will raise an error in `make`
if isinstance(self, RecordConstructorArgs):
kwargs = getattr(self, "_saved_kwargs")
@@ -293,6 +295,8 @@ class Wrapper(
else:
kwargs = None
from gymnasium.envs.registration import WrapperSpec
wrapper_spec = WrapperSpec(
name=self.class_name(),
entry_point=f"{self.__module__}:{type(self).__name__}",
@@ -301,10 +305,22 @@ class Wrapper(
# to avoid reference issues we deepcopy the prior environments spec and add the new information
env_spec = deepcopy(env_spec)
env_spec.applied_wrappers += (wrapper_spec,)
env_spec.additional_wrappers += (wrapper_spec,)
self._cached_spec = env_spec
return env_spec
@classmethod
def wrapper_spec(cls, **kwargs: Any) -> WrapperSpec:
"""Generates a `WrapperSpec` for the wrappers."""
from gymnasium.envs.registration import WrapperSpec
return WrapperSpec(
name=cls.class_name(),
entry_point=f"{cls.__module__}:{cls.__name__}",
kwargs=kwargs,
)
@classmethod
def class_name(cls) -> str:
"""Returns the class name of the wrapper."""

View File

@@ -99,7 +99,7 @@ class EnvSpec:
* **autoreset**: If to automatically reset the environment on episode end
* **disable_env_checker**: If to disable the environment checker wrapper in :meth:`gymnasium.make`, by default False (runs the environment checker)
* **kwargs**: Additional keyword arguments passed to the environment during initialisation
* **applied_wrappers**: A tuple of applied wrappers (WrapperSpec)
* **additional_wrappers**: A tuple of additional wrappers applied to the environment (WrapperSpec)
* **vector_entry_point**: The location of the vectorized environment to create from
"""
@@ -126,7 +126,7 @@ class EnvSpec:
version: int | None = field(init=False)
# applied wrappers
applied_wrappers: tuple[WrapperSpec, ...] = field(init=False, default_factory=tuple)
additional_wrappers: tuple[WrapperSpec, ...] = field(default_factory=tuple)
# Vectorized environment entry point
vector_entry_point: VectorEnvCreator | str | None = field(default=None)
@@ -187,7 +187,7 @@ class EnvSpec:
parsed_env_spec = json.loads(json_env_spec)
applied_wrapper_specs: list[WrapperSpec] = []
for wrapper_spec_json in parsed_env_spec.pop("applied_wrappers"):
for wrapper_spec_json in parsed_env_spec.pop("additional_wrappers"):
try:
applied_wrapper_specs.append(WrapperSpec(**wrapper_spec_json))
except Exception as e:
@@ -197,7 +197,7 @@ class EnvSpec:
try:
env_spec = EnvSpec(**parsed_env_spec)
env_spec.applied_wrappers = tuple(applied_wrapper_specs)
env_spec.additional_wrappers = tuple(applied_wrapper_specs)
except Exception as e:
raise ValueError(
f"An issue occurred when trying to make {parsed_env_spec} an EnvSpec"
@@ -241,9 +241,9 @@ class EnvSpec:
if print_all or self.apply_api_compatibility is not False:
output += f"\napplied_api_compatibility={self.apply_api_compatibility}"
if print_all or self.applied_wrappers:
if print_all or self.additional_wrappers:
wrapper_output: list[str] = []
for wrapper_spec in self.applied_wrappers:
for wrapper_spec in self.additional_wrappers:
if include_entry_points:
wrapper_output.append(
f"\n\tname={wrapper_spec.name}, entry_point={wrapper_spec.entry_point}, kwargs={wrapper_spec.kwargs}"
@@ -254,9 +254,9 @@ class EnvSpec:
)
if len(wrapper_output) == 0:
output += "\napplied_wrappers=[]"
output += "\nadditional_wrappers=[]"
else:
output += f"\napplied_wrappers=[{','.join(wrapper_output)}\n]"
output += f"\nadditional_wrappers=[{','.join(wrapper_output)}\n]"
if disable_print:
return output
@@ -555,161 +555,6 @@ def load_env_creator(name: str) -> EnvCreator | VectorEnvCreator:
return fn
def _create_from_env_spec(
env_spec: EnvSpec,
kwargs: dict[str, Any],
) -> Env:
"""Recreates an environment spec using a list of wrapper specs."""
if callable(env_spec.entry_point):
env_creator = env_spec.entry_point
else:
env_creator: EnvCreator = load_env_creator(env_spec.entry_point)
# Create the environment
env: Env = env_creator(**env_spec.kwargs, **kwargs)
# Set the `EnvSpec` to the environment
new_env_spec = copy.deepcopy(env_spec)
new_env_spec.applied_wrappers = ()
new_env_spec.kwargs.update(kwargs)
env.unwrapped.spec = new_env_spec
# Check if the environment spec
assert env.spec is not None # this is for pyright
num_prior_wrappers = len(env.spec.applied_wrappers)
if env_spec.applied_wrappers[:num_prior_wrappers] != env.spec.applied_wrappers:
for env_spec_wrapper_spec, recreated_wrapper_spec in zip(
env_spec.applied_wrappers, env.spec.applied_wrappers
):
raise ValueError(
f"The environment's wrapper spec {recreated_wrapper_spec} is different from the saved `EnvSpec` applied_wrappers {env_spec_wrapper_spec}"
)
for wrapper_spec in env_spec.applied_wrappers[num_prior_wrappers:]:
if wrapper_spec.kwargs is None:
raise ValueError(
f"{wrapper_spec.name} wrapper does not inherit from `gymnasium.utils.RecordConstructorArgs`, therefore, the wrapper cannot be recreated."
)
env = load_env_creator(wrapper_spec.entry_point)(env=env, **wrapper_spec.kwargs)
return env
def _create_from_env_id(
env_spec: EnvSpec,
kwargs: dict[str, Any],
max_episode_steps: int | None = None,
autoreset: bool = False,
apply_api_compatibility: bool | None = None,
disable_env_checker: bool | None = None,
) -> Env:
"""Creates an environment based on the `env_spec` along with wrapper options. See `make` for their meaning."""
spec_kwargs = copy.deepcopy(env_spec.kwargs)
spec_kwargs.update(kwargs)
# Load the environment creator
if env_spec.entry_point is None:
raise error.Error(f"{env_spec.id} registered but entry_point is not specified")
elif callable(env_spec.entry_point):
env_creator = env_spec.entry_point
else:
# Assume it's a string
env_creator = load_env_creator(env_spec.entry_point)
# Determine if to use the rendering
render_modes: list[str] | None = None
if hasattr(env_creator, "metadata"):
_check_metadata(env_creator.metadata)
render_modes = env_creator.metadata.get("render_modes")
mode = spec_kwargs.get("render_mode")
apply_human_rendering = False
apply_render_collection = False
# If mode is not valid, try applying HumanRendering/RenderCollection wrappers
if mode is not None and render_modes is not None and mode not in render_modes:
displayable_modes = {"rgb_array", "rgb_array_list"}.intersection(render_modes)
if mode == "human" and len(displayable_modes) > 0:
logger.warn(
"You are trying to use 'human' rendering for an environment that doesn't natively support it. "
"The HumanRendering wrapper is being applied to your environment."
)
spec_kwargs["render_mode"] = displayable_modes.pop()
apply_human_rendering = True
elif mode.endswith("_list") and mode[: -len("_list")] in render_modes:
spec_kwargs["render_mode"] = mode[: -len("_list")]
apply_render_collection = True
else:
logger.warn(
f"The environment is being initialised with render_mode={mode!r} "
f"that is not in the possible render_modes ({render_modes})."
)
if apply_api_compatibility or (
apply_api_compatibility is None and env_spec.apply_api_compatibility
):
# If we use the compatibility layer, we treat the render mode explicitly and don't pass it to the env creator
render_mode = spec_kwargs.pop("render_mode", None)
else:
render_mode = None
try:
env = env_creator(**spec_kwargs)
except TypeError as e:
if (
str(e).find("got an unexpected keyword argument 'render_mode'") >= 0
and apply_human_rendering
):
raise error.Error(
f"You passed render_mode='human' although {env_spec.id} doesn't implement human-rendering natively. "
"Gym tried to apply the HumanRendering wrapper but it looks like your environment is using the old "
"rendering API, which is not supported by the HumanRendering wrapper."
) from e
else:
raise e
# Copies the environment creation specification and kwargs to add to the environment specification details
env_spec = copy.deepcopy(env_spec)
env_spec.kwargs = spec_kwargs
env.unwrapped.spec = env_spec
# Add step API wrapper
if apply_api_compatibility is True or (
apply_api_compatibility is None and env_spec.apply_api_compatibility is True
):
env = EnvCompatibility(env, render_mode)
# Run the environment checker as the lowest level wrapper
if disable_env_checker is False or (
disable_env_checker is None and env_spec.disable_env_checker is False
):
env = PassiveEnvChecker(env)
# Add the order enforcing wrapper
if env_spec.order_enforce:
env = OrderEnforcing(env)
# Add the time limit wrapper
if max_episode_steps is not None:
assert env.unwrapped.spec is not None # for pyright
env.unwrapped.spec.max_episode_steps = max_episode_steps
env = TimeLimit(env, max_episode_steps)
elif env_spec.max_episode_steps is not None:
env = TimeLimit(env, env_spec.max_episode_steps)
# Add the auto-reset wrapper
if autoreset:
env = AutoResetWrapper(env)
# Add human rendering wrapper
if apply_human_rendering:
env = HumanRendering(env)
elif apply_render_collection:
env = RenderCollection(env)
return env
def load_plugin_envs(entry_point: str = "gymnasium.envs"):
"""Load modules (plugins) using the gymnasium entry points in order to register external module's environments on ``import gymnasium``.
@@ -779,6 +624,7 @@ def register(
autoreset: bool = False,
disable_env_checker: bool = False,
apply_api_compatibility: bool = False,
additional_wrappers: tuple[WrapperSpec, ...] = (),
vector_entry_point: VectorEnvCreator | str | None = None,
**kwargs: Any,
):
@@ -801,6 +647,7 @@ def register(
disable_env_checker: If to disable the :class:`gymnasium.wrappers.PassiveEnvChecker` to the environment.
apply_api_compatibility: If to apply the :class:`gymnasium.wrappers.StepAPICompatibility` wrapper to the environment.
Use if the environment is implemented in the gym v0.21 environment API.
additional_wrappers: Additional wrappers to apply the environment.
vector_entry_point: The entry point for creating the vector environment
**kwargs: arbitrary keyword arguments which are passed to the environment constructor on initialisation.
"""
@@ -835,8 +682,9 @@ def register(
order_enforce=order_enforce,
autoreset=autoreset,
disable_env_checker=disable_env_checker,
**kwargs,
apply_api_compatibility=apply_api_compatibility,
**kwargs,
additional_wrappers=additional_wrappers,
vector_entry_point=vector_entry_point,
)
_check_spec_register(new_spec)
@@ -849,7 +697,7 @@ def register(
def make(
id: str | EnvSpec,
max_episode_steps: int | None = None,
autoreset: bool = False,
autoreset: bool | None = None,
apply_api_compatibility: bool | None = None,
disable_env_checker: bool | None = None,
**kwargs: Any,
@@ -878,32 +726,12 @@ def make(
Error: If the ``id`` doesn't exist in the :attr:`registry`
"""
if isinstance(id, EnvSpec):
if hasattr(id, "applied_wrappers") and id.applied_wrappers is not None:
if max_episode_steps is not None:
logger.warn(
f"For `gymnasium.make` with an `EnvSpec`, the `max_episode_step` parameter is ignored, use `gym.make({id.id}, max_episode_steps={max_episode_steps})` and any additional wrappers"
)
if autoreset is True:
logger.warn(
f"For `gymnasium.make` with an `EnvSpec`, the `autoreset` parameter is ignored, use `gym.make({id.id}, autoreset={autoreset})` and any additional wrappers"
)
if apply_api_compatibility is not None:
logger.warn(
f"For `gymnasium.make` with an `EnvSpec`, the `apply_api_compatibility` parameter is ignored, use `gym.make({id.id}, apply_api_compatibility={apply_api_compatibility})` and any additional wrappers"
)
if disable_env_checker is not None:
logger.warn(
f"For `gymnasium.make` with an `EnvSpec`, the `disable_env_checker` parameter is ignored, use `gym.make({id.id}, disable_env_checker={disable_env_checker})` and any additional wrappers"
)
return _create_from_env_spec(
id,
kwargs,
)
else:
raise ValueError(
f"The EnvSpec used does not contain `applied_wrappers` parameters or is `None`. Expected to be a tuple, actually {id}."
env_spec = id
if not hasattr(env_spec, "additional_wrappers"):
logger.warn(
f"The env spec passed to `make` does not have a `additional_wrappers`, set it to an empty tuple. Env_spec={env_spec}"
)
env_spec.additional_wrappers = ()
else:
# For string id's, load the environment spec from the registry then make the environment spec
assert isinstance(id, str)
@@ -911,14 +739,150 @@ def make(
# The environment name can include an unloaded module in "module:env_name" style
env_spec = _find_spec(id)
return _create_from_env_id(
env_spec,
kwargs,
max_episode_steps=max_episode_steps,
autoreset=autoreset,
apply_api_compatibility=apply_api_compatibility,
disable_env_checker=disable_env_checker,
)
assert isinstance(env_spec, EnvSpec)
# Update the env spec kwargs with the `make` kwargs
env_spec_kwargs = copy.deepcopy(env_spec.kwargs)
env_spec_kwargs.update(kwargs)
# Load the environment creator
if env_spec.entry_point is None:
raise error.Error(f"{env_spec.id} registered but entry_point is not specified")
elif callable(env_spec.entry_point):
env_creator = env_spec.entry_point
else:
# Assume it's a string
env_creator = load_env_creator(env_spec.entry_point)
# Determine if to use the rendering
render_modes: list[str] | None = None
if hasattr(env_creator, "metadata"):
_check_metadata(env_creator.metadata)
render_modes = env_creator.metadata.get("render_modes")
render_mode = env_spec_kwargs.get("render_mode")
apply_human_rendering = False
apply_render_collection = False
# If mode is not valid, try applying HumanRendering/RenderCollection wrappers
if (
render_mode is not None
and render_modes is not None
and render_mode not in render_modes
):
displayable_modes = {"rgb_array", "rgb_array_list"}.intersection(render_modes)
if render_mode == "human" and len(displayable_modes) > 0:
logger.warn(
"You are trying to use 'human' rendering for an environment that doesn't natively support it. "
"The HumanRendering wrapper is being applied to your environment."
)
env_spec_kwargs["render_mode"] = displayable_modes.pop()
apply_human_rendering = True
elif (
render_mode.endswith("_list")
and render_mode[: -len("_list")] in render_modes
):
env_spec_kwargs["render_mode"] = render_mode[: -len("_list")]
apply_render_collection = True
else:
logger.warn(
f"The environment is being initialised with render_mode={render_mode!r} "
f"that is not in the possible render_modes ({render_modes})."
)
if apply_api_compatibility or (
apply_api_compatibility is None and env_spec.apply_api_compatibility
):
# If we use the compatibility layer, we treat the render mode explicitly and don't pass it to the env creator
render_mode = env_spec_kwargs.pop("render_mode", None)
else:
render_mode = None
try:
env = env_creator(**env_spec_kwargs)
except TypeError as e:
if (
str(e).find("got an unexpected keyword argument 'render_mode'") >= 0
and apply_human_rendering
):
raise error.Error(
f"You passed render_mode='human' although {env_spec.id} doesn't implement human-rendering natively. "
"Gym tried to apply the HumanRendering wrapper but it looks like your environment is using the old "
"rendering API, which is not supported by the HumanRendering wrapper."
) from e
else:
raise e
# Set the minimal env spec for the environment.
env.unwrapped.spec = EnvSpec(
id=env_spec.id,
entry_point=env_spec.entry_point,
reward_threshold=env_spec.reward_threshold,
nondeterministic=env_spec.nondeterministic,
max_episode_steps=None,
order_enforce=False,
autoreset=False,
disable_env_checker=True,
apply_api_compatibility=False,
kwargs=env_spec_kwargs,
additional_wrappers=(),
vector_entry_point=env_spec.vector_entry_point,
)
# Check if pre-wrapped wrappers
assert env.spec is not None
num_prior_wrappers = len(env.spec.additional_wrappers)
if (
env_spec.additional_wrappers[:num_prior_wrappers]
!= env.spec.additional_wrappers
):
for env_spec_wrapper_spec, recreated_wrapper_spec in zip(
env_spec.additional_wrappers, env.spec.additional_wrappers
):
raise ValueError(
f"The environment's wrapper spec {recreated_wrapper_spec} is different from the saved `EnvSpec` additional wrapper {env_spec_wrapper_spec}"
)
# Add step API wrapper
if apply_api_compatibility is True or (
apply_api_compatibility is None and env_spec.apply_api_compatibility is True
):
env = EnvCompatibility(env, render_mode)
# Run the environment checker as the lowest level wrapper
if disable_env_checker is False or (
disable_env_checker is None and env_spec.disable_env_checker is False
):
env = PassiveEnvChecker(env)
# Add the order enforcing wrapper
if env_spec.order_enforce:
env = OrderEnforcing(env)
# Add the time limit wrapper
if max_episode_steps is not None:
env = TimeLimit(env, max_episode_steps)
elif env_spec.max_episode_steps is not None:
env = TimeLimit(env, env_spec.max_episode_steps)
# Add the auto-reset wrapper
if autoreset is True or (autoreset is None and env_spec.autoreset is True):
env = AutoResetWrapper(env)
for wrapper_spec in env_spec.additional_wrappers[num_prior_wrappers:]:
if wrapper_spec.kwargs is None:
raise ValueError(
f"{wrapper_spec.name} wrapper does not inherit from `gymnasium.utils.RecordConstructorArgs`, therefore, the wrapper cannot be recreated."
)
env = load_env_creator(wrapper_spec.entry_point)(env=env, **wrapper_spec.kwargs)
# Add human rendering wrapper
if apply_human_rendering:
env = HumanRendering(env)
elif apply_render_collection:
env = RenderCollection(env)
return env
def make_vec(

View File

@@ -1,7 +1,16 @@
"""Wrapper that autoreset environments when `terminated=True` or `truncated=True`."""
from __future__ import annotations
from copy import deepcopy
from typing import TYPE_CHECKING
import gymnasium as gym
if TYPE_CHECKING:
from gymnasium.envs.registration import EnvSpec
class AutoResetWrapper(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`.
@@ -60,3 +69,17 @@ class AutoResetWrapper(gym.Wrapper, gym.utils.RecordConstructorArgs):
info = new_info
return obs, reward, terminated, truncated, info
@property
def spec(self) -> EnvSpec | None:
"""Modifies the environment spec to specify the `autoreset=True`."""
if self._cached_spec is not None:
return self._cached_spec
env_spec = self.env.spec
if env_spec is not None:
env_spec = deepcopy(env_spec)
env_spec.autoreset = True
self._cached_spec = env_spec
return env_spec

View File

@@ -1,4 +1,9 @@
"""A passive environment checker wrapper for an environment's observation and action space along with the reset, step and render functions."""
from __future__ import annotations
from copy import deepcopy
from typing import TYPE_CHECKING
import gymnasium as gym
from gymnasium.core import ActType
from gymnasium.utils.passive_env_checker import (
@@ -10,6 +15,10 @@ from gymnasium.utils.passive_env_checker import (
)
if TYPE_CHECKING:
from gymnasium.envs.registration import EnvSpec
class PassiveEnvChecker(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""A passive environment checker wrapper that surrounds the step, reset and render functions to check they follow the gymnasium API."""
@@ -54,3 +63,17 @@ class PassiveEnvChecker(gym.Wrapper, gym.utils.RecordConstructorArgs):
return env_render_passive_checker(self.env, *args, **kwargs)
else:
return self.env.render(*args, **kwargs)
@property
def spec(self) -> EnvSpec | None:
"""Modifies the environment spec to such that `disable_env_checker=False`."""
if self._cached_spec is not None:
return self._cached_spec
env_spec = self.env.spec
if env_spec is not None:
env_spec = deepcopy(env_spec)
env_spec.disable_env_checker = False
self._cached_spec = env_spec
return env_spec

View File

@@ -1,8 +1,17 @@
"""Wrapper to enforce the proper ordering of environment operations."""
from __future__ import annotations
from copy import deepcopy
from typing import TYPE_CHECKING
import gymnasium as gym
from gymnasium.error import ResetNeeded
if TYPE_CHECKING:
from gymnasium.envs.registration import EnvSpec
class OrderEnforcing(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`.
@@ -64,3 +73,17 @@ class OrderEnforcing(gym.Wrapper, gym.utils.RecordConstructorArgs):
def has_reset(self):
"""Returns if the environment has been reset before."""
return self._has_reset
@property
def spec(self) -> EnvSpec | None:
"""Modifies the environment spec to add the `order_enforce=True`."""
if self._cached_spec is not None:
return self._cached_spec
env_spec = self.env.spec
if env_spec is not None:
env_spec = deepcopy(env_spec)
env_spec.order_enforce = True
self._cached_spec = env_spec
return env_spec

View File

@@ -1,9 +1,16 @@
"""Wrapper for limiting the time steps of an environment."""
from typing import Optional
from __future__ import annotations
from copy import deepcopy
from typing import TYPE_CHECKING
import gymnasium as gym
if TYPE_CHECKING:
from gymnasium.envs.registration import EnvSpec
class TimeLimit(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""This wrapper will issue a `truncated` signal if a maximum number of timesteps is exceeded.
@@ -20,7 +27,7 @@ class TimeLimit(gym.Wrapper, gym.utils.RecordConstructorArgs):
def __init__(
self,
env: gym.Env,
max_episode_steps: Optional[int] = None,
max_episode_steps: int,
):
"""Initializes the :class:`TimeLimit` wrapper with an environment and the number of steps after which truncation will occur.
@@ -33,9 +40,6 @@ class TimeLimit(gym.Wrapper, gym.utils.RecordConstructorArgs):
)
gym.Wrapper.__init__(self, env)
if max_episode_steps is None and self.env.spec is not None:
assert env.spec is not None
max_episode_steps = env.spec.max_episode_steps
self._max_episode_steps = max_episode_steps
self._elapsed_steps = None
@@ -69,3 +73,17 @@ class TimeLimit(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""
self._elapsed_steps = 0
return self.env.reset(**kwargs)
@property
def spec(self) -> EnvSpec | None:
"""Modifies the environment spec to include the `max_episode_steps=self._max_episode_steps`."""
if self._cached_spec is not None:
return self._cached_spec
env_spec = self.env.spec
if env_spec is not None:
env_spec = deepcopy(env_spec)
env_spec.max_episode_steps = self._max_episode_steps
self._cached_spec = env_spec
return env_spec

View File

@@ -1,4 +1,4 @@
"""Example file showing usage of env.specstack."""
"""Test for the `EnvSpec`, in particular, a full integration with `EnvSpec`."""
import pickle
import pytest
@@ -11,15 +11,16 @@ from gymnasium.utils.env_checker import data_equivalence
def test_full_integration():
# Create an environment to test with
env = gym.make("CartPole-v1", render_mode="rgb_array").unwrapped
env = gym.make("CartPole-v1", render_mode="rgb_array")
env = gym.wrappers.FlattenObservation(env)
env = gym.wrappers.TimeAwareObservation(env)
env = gym.wrappers.NormalizeReward(env, gamma=0.8)
# Generate the spec_stack
env_spec = env.spec
assert isinstance(env_spec, EnvSpec)
# additional_wrappers = (TimeAwareObservation, NormalizeReward)
assert len(env_spec.additional_wrappers) == 2
# env_spec.pprint()
# Serialize the spec_stack
@@ -30,10 +31,7 @@ def test_full_integration():
recreate_env_spec = EnvSpec.from_json(env_spec_json)
# recreate_env_spec.pprint()
for wrapper_spec, recreated_wrapper_spec in zip(
env_spec.applied_wrappers, recreate_env_spec.applied_wrappers
):
assert wrapper_spec == recreated_wrapper_spec
assert env_spec.additional_wrappers == recreate_env_spec.additional_wrappers
assert recreate_env_spec == env_spec
# Recreate the environment using the spec_stack
@@ -86,43 +84,6 @@ def test_env_spec_to_from_json(env_spec: EnvSpec):
assert env_spec == recreated_env_spec
def test_wrapped_env_entry_point():
def _create_env():
_env = gym.make("CartPole-v1", render_mode="rgb_array")
_env = gym.wrappers.FlattenObservation(_env)
return _env
gym.register("TestingEnv-v0", entry_point=_create_env)
env = gym.make("TestingEnv-v0")
env = gym.wrappers.TimeAwareObservation(env)
env = gym.wrappers.NormalizeReward(env, gamma=0.8)
recreated_env = gym.make(env.spec)
obs, info = env.reset(seed=42)
recreated_obs, recreated_info = recreated_env.reset(seed=42)
assert data_equivalence(obs, recreated_obs)
assert data_equivalence(info, recreated_info)
action = env.action_space.sample()
obs, reward, terminated, truncated, info = env.step(action)
(
recreated_obs,
recreated_reward,
recreated_terminated,
recreated_truncated,
recreated_info,
) = recreated_env.step(action)
assert data_equivalence(obs, recreated_obs)
assert data_equivalence(reward, recreated_reward)
assert data_equivalence(terminated, recreated_terminated)
assert data_equivalence(truncated, recreated_truncated)
assert data_equivalence(info, recreated_info)
del gym.registry["TestingEnv-v0"]
def test_pickling_env_stack():
env = gym.make("CartPole-v1", render_mode="rgb_array")
@@ -163,6 +124,8 @@ def test_pickling_env_stack():
def test_env_spec_pprint():
env = gym.make("CartPole-v1")
env = gym.wrappers.TimeAwareObservation(env)
env_spec = env.spec
assert env_spec is not None
@@ -172,10 +135,8 @@ def test_env_spec_pprint():
== """id=CartPole-v1
reward_threshold=475.0
max_episode_steps=500
applied_wrappers=[
name=PassiveEnvChecker, kwargs={},
name=OrderEnforcing, kwargs={'disable_render_order_enforcing': False},
name=TimeLimit, kwargs={'max_episode_steps': 500}
additional_wrappers=[
name=TimeAwareObservation, kwargs={}
]"""
)
@@ -186,10 +147,8 @@ applied_wrappers=[
entry_point=gymnasium.envs.classic_control.cartpole:CartPoleEnv
reward_threshold=475.0
max_episode_steps=500
applied_wrappers=[
name=PassiveEnvChecker, entry_point=gymnasium.wrappers.env_checker:PassiveEnvChecker, kwargs={},
name=OrderEnforcing, entry_point=gymnasium.wrappers.order_enforcing:OrderEnforcing, kwargs={'disable_render_order_enforcing': False},
name=TimeLimit, entry_point=gymnasium.wrappers.time_limit:TimeLimit, kwargs={'max_episode_steps': 500}
additional_wrappers=[
name=TimeAwareObservation, entry_point=gymnasium.wrappers.time_aware_observation:TimeAwareObservation, kwargs={}
]"""
)
@@ -205,14 +164,12 @@ order_enforce=True
autoreset=False
disable_env_checker=False
applied_api_compatibility=False
applied_wrappers=[
name=PassiveEnvChecker, kwargs={},
name=OrderEnforcing, kwargs={'disable_render_order_enforcing': False},
name=TimeLimit, kwargs={'max_episode_steps': 500}
additional_wrappers=[
name=TimeAwareObservation, kwargs={}
]"""
)
env_spec.applied_wrappers = ()
env_spec.additional_wrappers = ()
output = env_spec.pprint(disable_print=True)
assert (
output
@@ -233,5 +190,5 @@ order_enforce=True
autoreset=False
disable_env_checker=False
applied_api_compatibility=False
applied_wrappers=[]"""
additional_wrappers=[]"""
)

View File

@@ -1,4 +1,4 @@
"""Tests that gym.make works as expected."""
"""Tests that `gym.make` works as expected."""
from __future__ import annotations
import re
@@ -11,6 +11,8 @@ import gymnasium as gym
from gymnasium import Env
from gymnasium.core import ActType, ObsType, WrapperObsType
from gymnasium.envs.classic_control import CartPoleEnv
from gymnasium.error import NameNotFound
from gymnasium.utils.env_checker import data_equivalence
from gymnasium.wrappers import (
AutoResetWrapper,
HumanRendering,
@@ -24,111 +26,101 @@ from tests.testing_env import GenericTestEnv, old_step_func
from tests.wrappers.utils import has_wrapper
try:
import shimmy
except ImportError:
shimmy = None
# Tests
# * basic example
# * parameters (equivalent for str and EnvSpec)
# 1. max_episode_steps
# 2. autoreset
# 3. apply_api_compatibility
# 4. disable_env_checker
# * rendering
# 1. render_mode
# 2. HumanRendering
# 3. RenderCollection
# * make kwargs
# * make import module
# * make env spec additional wrappers
# * env_id str errors
@pytest.fixture(scope="function")
def register_testing_envs():
"""Registers testing envs for `gym.make`"""
gym.register(
id="test.ArgumentEnv-v0",
entry_point="tests.envs.registration.utils_envs:ArgumentEnv",
kwargs={
"arg1": "arg1",
"arg2": "arg2",
},
)
def test_no_arguments(env_id: str = "CartPole-v1"):
"""Test `gym.make` using str and EnvSpec with no arguments."""
env_from_id = gym.make(env_id)
assert env_from_id.spec is not None
assert env_from_id.spec.id == env_id
assert isinstance(env_from_id.unwrapped, CartPoleEnv)
gym.register(
id="test/NoHuman-v0",
entry_point="tests.envs.registration.utils_envs:NoHuman",
)
gym.register(
id="test/NoHumanOldAPI-v0",
entry_point="tests.envs.registration.utils_envs:NoHumanOldAPI",
)
env_spec = gym.spec(env_id)
env_from_spec = gym.make(env_spec)
assert env_from_spec.spec is not None
assert env_from_spec.spec.id == env_id
assert isinstance(env_from_spec.unwrapped, CartPoleEnv)
gym.register(
id="test/NoHumanNoRGB-v0",
entry_point="tests.envs.registration.utils_envs:NoHumanNoRGB",
)
gym.register(
id="test/NoRenderModesMetadata-v0",
entry_point="tests.envs.registration.utils_envs:NoRenderModesMetadata",
)
yield
del gym.envs.registration.registry["test.ArgumentEnv-v0"]
del gym.envs.registration.registry["test/NoRenderModesMetadata-v0"]
del gym.envs.registration.registry["test/NoHuman-v0"]
del gym.envs.registration.registry["test/NoHumanOldAPI-v0"]
del gym.envs.registration.registry["test/NoHumanNoRGB-v0"]
assert env_from_id.spec == env_from_spec.spec
def test_make():
"""Test basic `gym.make`."""
env = gym.make("CartPole-v1")
assert env.spec is not None
assert env.spec.id == "CartPole-v1"
assert isinstance(env.unwrapped, CartPoleEnv)
env.close()
def test_max_episode_steps(register_parameter_envs):
"""Test the `max_episode_steps` parameter in `gym.make`."""
for make_id in ["CartPole-v1", gym.spec("CartPole-v1")]:
env_spec = gym.spec(make_id) if isinstance(make_id, str) else make_id
# Use the spec's value
env = gym.make(make_id)
assert has_wrapper(env, TimeLimit)
assert env.spec is not None
assert env.spec.max_episode_steps == env_spec.max_episode_steps
# Set a custom max episode steps value
assert env_spec.max_episode_steps != 100
env = gym.make(make_id, max_episode_steps=100)
assert has_wrapper(env, TimeLimit)
assert env.spec is not None
assert env.spec.max_episode_steps == 100, make_id
for make_id in ["NoMaxEpisodeStepsEnv-v0", gym.spec("NoMaxEpisodeStepsEnv-v0")]:
env_spec = gym.spec(make_id) if isinstance(make_id, str) else make_id
# env spec has no max episode steps
assert env_spec.max_episode_steps is None
env = gym.make(make_id)
assert env.spec is not None
assert env.spec.max_episode_steps is None
assert has_wrapper(env, TimeLimit) is False
# set a custom max episode steps values
env = gym.make(make_id, max_episode_steps=100)
assert env.spec is not None
assert env.spec.max_episode_steps == 100
assert has_wrapper(env, TimeLimit)
def test_make_deprecated():
"""Test make with a deprecated environment (i.e., doesn't exist)."""
with warnings.catch_warnings(record=True):
with pytest.raises(
gym.error.Error,
match=re.escape(
"Environment version v0 for `Humanoid` is deprecated. Please use `Humanoid-v4` instead."
),
):
gym.make("Humanoid-v0")
def test_autorest(register_parameter_envs):
"""Test the `autoreset` parameter in `gym.make`."""
for make_id in [
"CartPole-v1",
gym.spec("CartPole-v1"),
"AutoresetEnv-v0",
gym.spec("AutoresetEnv-v0"),
]:
env_spec = gym.spec(make_id) if isinstance(make_id, str) else make_id
# Use the spec's value
env = gym.make(make_id)
assert env.spec is not None
assert env.spec.autoreset == env_spec.autoreset
assert has_wrapper(env, AutoResetWrapper) is env_spec.autoreset
def test_make_max_episode_steps(register_testing_envs):
# Default, uses the spec's
env = gym.make("CartPole-v1")
assert has_wrapper(env, TimeLimit)
assert env.spec is not None
assert env.spec.max_episode_steps == gym.spec("CartPole-v1").max_episode_steps
env.close()
# Set autoreset is True
env = gym.make(make_id, autoreset=True)
assert has_wrapper(env, AutoResetWrapper)
assert env.spec is not None
assert env.spec.autoreset is True
# Custom max episode steps
assert gym.spec("CartPole-v1").max_episode_steps != 100
env = gym.make("CartPole-v1", max_episode_steps=100)
assert has_wrapper(env, TimeLimit)
assert env.spec is not None
assert env.spec.max_episode_steps == 100
env.close()
# Env spec has no max episode steps
assert gym.spec("test.ArgumentEnv-v0").max_episode_steps is None
env = gym.make("test.ArgumentEnv-v0", arg1=None, arg2=None, arg3=None)
assert env.spec is not None
assert env.spec.max_episode_steps is None
assert has_wrapper(env, TimeLimit) is False
env.close()
def test_make_autoreset():
"""Tests that `gym.make` autoreset wrapper is applied only when `gym.make(..., autoreset=True)`."""
env = gym.make("CartPole-v1")
assert has_wrapper(env, AutoResetWrapper) is False
env.close()
env = gym.make("CartPole-v1", autoreset=False)
assert has_wrapper(env, AutoResetWrapper) is False
env.close()
env = gym.make("CartPole-v1", autoreset=True)
assert has_wrapper(env, AutoResetWrapper)
env.close()
# Set autoreset is False
env = gym.make(make_id, autoreset=False)
assert has_wrapper(env, AutoResetWrapper) is False
assert env.spec is not None
assert env.spec.autoreset is False
@pytest.mark.parametrize(
@@ -142,7 +134,7 @@ def test_make_autoreset():
[True, None, True],
],
)
def test_make_disable_env_checker(
def test_disable_env_checker(
registration_disabled: bool, make_disabled: bool | None, if_disabled: bool
):
"""Tests that `gym.make` disable env checker is applied only when `gym.make(..., disable_env_checker=False)`.
@@ -150,84 +142,76 @@ def test_make_disable_env_checker(
The ordering is 1. if the `make(..., disable_env_checker=...)` is bool, then the `registration(..., disable_env_checker=...)`
"""
gym.register(
"testing-env-v0",
"DisableEnvCheckerEnv-v0",
lambda: GenericTestEnv(),
disable_env_checker=registration_disabled,
)
# Test when the registered EnvSpec.disable_env_checker = False
env = gym.make("testing-env-v0", disable_env_checker=make_disabled)
env = gym.make("DisableEnvCheckerEnv-v0", disable_env_checker=make_disabled)
assert has_wrapper(env, PassiveEnvChecker) is not if_disabled
env.close()
del gym.registry["testing-env-v0"]
env_spec = gym.spec("DisableEnvCheckerEnv-v0")
env = gym.make(env_spec, disable_env_checker=make_disabled)
assert has_wrapper(env, PassiveEnvChecker) is not if_disabled
del gym.registry["DisableEnvCheckerEnv-v0"]
def test_make_apply_api_compatibility():
"""Test the API compatibility wrapper."""
gym.register(
"testing-old-env",
lambda: GenericTestEnv(step_func=old_step_func),
apply_api_compatibility=True,
max_episode_steps=3,
)
def test_apply_api_compatibility(register_parameter_envs):
"""Test the `apply_api_compatibility` parameter for `gym.make`."""
# Apply the environment compatibility and check it works as intended
env = gym.make("testing-old-env")
assert isinstance(env.unwrapped, gym.wrappers.EnvCompatibility)
for make_id in ["EnabledApplyApiComp-v0", gym.spec("EnabledApplyApiComp-v0")]:
env = gym.make(make_id)
assert isinstance(env.unwrapped, gym.wrappers.EnvCompatibility)
env.reset()
assert len(env.step(env.action_space.sample())) == 5
env.step(env.action_space.sample())
_, _, termination, truncation, _ = env.step(env.action_space.sample())
assert termination is False and truncation is True
# Turn off the spec api compatibility
gym.spec("testing-old-env").apply_api_compatibility = False
env = gym.make("testing-old-env")
assert isinstance(env.unwrapped, gym.wrappers.EnvCompatibility) is False
env.reset()
with pytest.raises(
ValueError, match=re.escape("not enough values to unpack (expected 5, got 4)")
):
# env has time limit of 3 enabling this test
env.reset()
assert len(env.step(env.action_space.sample())) == 5
env.step(env.action_space.sample())
_, _, termination, truncation, _ = env.step(env.action_space.sample())
assert termination is False and truncation is True
# Apply the environment compatibility and check it works as intended
env = gym.make("testing-old-env", apply_api_compatibility=True)
assert isinstance(env.unwrapped, gym.wrappers.EnvCompatibility)
for make_id in ["DisabledApplyApiComp-v0", gym.spec("DisabledApplyApiComp-v0")]:
# Turn off the spec api compatibility
env = gym.make(make_id)
assert isinstance(env.unwrapped, gym.wrappers.EnvCompatibility) is False
env.reset()
with pytest.raises(
ValueError,
match=re.escape("not enough values to unpack (expected 5, got 4)"),
):
env.step(env.action_space.sample())
env.reset()
assert len(env.step(env.action_space.sample())) == 5
env.step(env.action_space.sample())
_, _, termination, truncation, _ = env.step(env.action_space.sample())
assert termination is False and truncation is True
# Apply the environment compatibility and check it works as intended
assert env.spec is not None
assert env.spec.apply_api_compatibility is False
env = gym.make(make_id, apply_api_compatibility=True)
assert isinstance(env.unwrapped, gym.wrappers.EnvCompatibility)
del gym.registry["testing-old-env"]
env.reset()
assert len(env.step(env.action_space.sample())) == 5
env.step(env.action_space.sample())
_, _, termination, truncation, _ = env.step(env.action_space.sample())
assert termination is False and truncation is True
def test_make_order_enforcing():
def test_order_enforcing(register_parameter_envs):
"""Checks that gym.make wrappers the environment with the OrderEnforcing wrapper."""
assert all(spec.order_enforce is True for spec in all_testing_env_specs)
assert all(spec.order_enforce is False for spec in all_testing_env_specs)
env = gym.make("CartPole-v1")
assert has_wrapper(env, OrderEnforcing)
# We can assume that there all other specs will also have the order enforcing
env.close()
for make_id in ["CartPole-v1", gym.spec("CartPole-v1")]:
env = gym.make(make_id)
assert has_wrapper(env, OrderEnforcing)
gym.register(
id="test.OrderlessArgumentEnv-v0",
entry_point="tests.envs.registration.utils_envs:ArgumentEnv",
order_enforce=False,
kwargs={"arg1": None, "arg2": None, "arg3": None},
)
env = gym.make("test.OrderlessArgumentEnv-v0")
assert has_wrapper(env, OrderEnforcing) is False
env.close()
for make_id in ["OrderlessEnv-v0", gym.spec("OrderlessEnv-v0")]:
env = gym.make(make_id)
assert has_wrapper(env, OrderEnforcing) is False
# There is no `make(..., order_enforcing=...)` so we don't test that
def test_make_render_mode():
def test_make_with_render_mode():
"""Test the `make(..., render_mode=...)`, in particular, if to apply the `RenderCollection` or the `HumanRendering`."""
env = gym.make("CartPole-v1", render_mode=None)
assert env.render_mode is None
@@ -267,10 +251,12 @@ def test_make_render_collection():
env.close()
def test_make_human_rendering(register_testing_envs):
def test_make_human_rendering(register_rendering_testing_envs):
# Make sure that native rendering is used when possible
env = gym.make("CartPole-v1", render_mode="human")
assert not has_wrapper(env, HumanRendering) # Should use native human-rendering
assert (
has_wrapper(env, HumanRendering) is False
) # Should use native human-rendering
assert env.render_mode == "human"
env.close()
@@ -281,7 +267,7 @@ def test_make_human_rendering(register_testing_envs):
),
):
# Make sure that `HumanRendering` is applied here as the environment doesn't use native rendering
env = gym.make("test/NoHuman-v0", render_mode="human")
env = gym.make("NoHumanRendering-v0", render_mode="human")
assert has_wrapper(env, HumanRendering)
assert env.render_mode == "human"
env.close()
@@ -290,7 +276,7 @@ def test_make_human_rendering(register_testing_envs):
TypeError, match=re.escape("got an unexpected keyword argument 'render_mode'")
):
gym.make(
"test/NoHumanOldAPI-v0",
"NoHumanRenderingOldAPI-v0",
render_mode="rgb_array_list",
)
@@ -299,10 +285,10 @@ def test_make_human_rendering(register_testing_envs):
with pytest.raises(
gym.error.Error,
match=re.escape(
"You passed render_mode='human' although test/NoHumanOldAPI-v0 doesn't implement human-rendering natively."
"You passed render_mode='human' although NoHumanRenderingOldAPI-v0 doesn't implement human-rendering natively."
),
):
gym.make("test/NoHumanOldAPI-v0", render_mode="human")
gym.make("NoHumanRenderingOldAPI-v0", render_mode="human")
# This test ensures that the additional exception "Gym tried to apply the HumanRendering wrapper but it looks like
# your environment is using the old rendering API" is *not* triggered by a TypeError that originate from
@@ -320,10 +306,10 @@ def test_make_human_rendering(register_testing_envs):
"\x1b[33mWARN: The environment is being initialised with render_mode='rgb_array' that is not in the possible render_modes ([]).\x1b[0m"
),
):
gym.make("test/NoRenderModesMetadata-v0", render_mode="rgb_array")
gym.make("NoRenderModesMetadata-v0", render_mode="rgb_array")
def test_make_kwargs(register_testing_envs):
def test_make_kwargs(register_kwargs_env):
env = gym.make(
"test.ArgumentEnv-v0",
arg2="override_arg2",
@@ -367,12 +353,11 @@ class NoRecordArgsWrapper(gym.ObservationWrapper):
return self.observation_space.sample()
def test_make_env_spec():
def test_make_with_env_spec():
# make
env_1 = gym.make(gym.spec("CartPole-v1"))
assert isinstance(env_1, CartPoleEnv)
assert env_1 is env_1.unwrapped
env_1.close()
id_env = gym.make("CartPole-v1")
spec_env = gym.make(gym.spec("CartPole-v1"))
assert id_env.spec == spec_env.spec
# make with applied wrappers
env_2 = gym.wrappers.NormalizeReward(
@@ -396,30 +381,202 @@ def test_make_env_spec():
# make with wrapper in env-creator
gym.register(
"CartPole-v3", lambda: gym.wrappers.TimeAwareObservation(CartPoleEnv())
"CartPole-v3",
lambda: gym.wrappers.TimeAwareObservation(CartPoleEnv()),
disable_env_checker=True,
order_enforce=False,
)
env_4 = gym.make(gym.spec("CartPole-v3"))
assert isinstance(env_4, gym.wrappers.TimeAwareObservation)
assert isinstance(env_4.env, CartPoleEnv)
env_4.close()
# make with no ezpickle wrapper
env_5 = NoRecordArgsWrapper(gym.make("CartPole-v1").unwrapped)
gym.register(
"CartPole-v4",
lambda: CartPoleEnv(),
disable_env_checker=True,
order_enforce=False,
additional_wrappers=(gym.wrappers.TimeAwareObservation.wrapper_spec(),),
)
env_5 = gym.make(gym.spec("CartPole-v4"))
assert isinstance(env_5, gym.wrappers.TimeAwareObservation)
assert isinstance(env_5.env, CartPoleEnv)
env_5.close()
# make with no ezpickle wrapper
env_6 = NoRecordArgsWrapper(gym.make("CartPole-v1"))
with pytest.raises(
ValueError,
match=re.escape(
"NoRecordArgsWrapper wrapper does not inherit from `gymnasium.utils.RecordConstructorArgs`, therefore, the wrapper cannot be recreated."
),
):
gym.make(env_5.spec)
gym.make(env_6.spec)
# make with no ezpickle wrapper but in the entry point
gym.register("CartPole-v4", entry_point=lambda: NoRecordArgsWrapper(CartPoleEnv()))
env_6 = gym.make(gym.spec("CartPole-v4"))
assert isinstance(env_6, NoRecordArgsWrapper)
assert isinstance(env_6.unwrapped, CartPoleEnv)
gym.register(
"CartPole-v5",
entry_point=lambda: NoRecordArgsWrapper(CartPoleEnv()),
disable_env_checker=True,
order_enforce=False,
)
env_7 = gym.make(gym.spec("CartPole-v5"))
assert isinstance(env_7, NoRecordArgsWrapper)
assert isinstance(env_7.unwrapped, CartPoleEnv)
gym.register(
"CartPole-v6",
entry_point=lambda: CartPoleEnv(),
disable_env_checker=True,
order_enforce=False,
additional_wrappers=(NoRecordArgsWrapper.wrapper_spec(),),
)
del gym.registry["CartPole-v2"]
del gym.registry["CartPole-v3"]
del gym.registry["CartPole-v4"]
del gym.registry["CartPole-v5"]
del gym.registry["CartPole-v6"]
def test_make_with_env_spec_levels():
"""Test that we can recreate the environment at each 'level'."""
env = gym.wrappers.NormalizeReward(
gym.wrappers.TimeAwareObservation(
gym.wrappers.FlattenObservation(
gym.make("CartPole-v1", render_mode="rgb_array")
)
),
gamma=0.8,
)
while env is not env.unwrapped:
recreated_env = gym.make(env.spec)
assert env.spec == recreated_env.spec
env = env.env
def test_wrapped_env_entry_point():
def _create_env():
_env = gym.make("CartPole-v1", render_mode="rgb_array")
_env = gym.wrappers.FlattenObservation(_env)
return _env
gym.register("TestingEnv-v0", entry_point=_create_env)
env = gym.make("TestingEnv-v0")
env = gym.wrappers.TimeAwareObservation(env)
env = gym.wrappers.NormalizeReward(env, gamma=0.8)
recreated_env = gym.make(env.spec)
obs, info = env.reset(seed=42)
recreated_obs, recreated_info = recreated_env.reset(seed=42)
assert data_equivalence(obs, recreated_obs)
assert data_equivalence(info, recreated_info)
action = env.action_space.sample()
obs, reward, terminated, truncated, info = env.step(action)
(
recreated_obs,
recreated_reward,
recreated_terminated,
recreated_truncated,
recreated_info,
) = recreated_env.step(action)
assert data_equivalence(obs, recreated_obs)
assert data_equivalence(reward, recreated_reward)
assert data_equivalence(terminated, recreated_terminated)
assert data_equivalence(truncated, recreated_truncated)
assert data_equivalence(info, recreated_info)
del gym.registry["TestingEnv-v0"]
def test_make_errors():
"""Test make with a deprecated environment (i.e., doesn't exist)."""
with warnings.catch_warnings(record=True):
with pytest.raises(
gym.error.Error,
match=re.escape(
"Environment version v0 for `Humanoid` is deprecated. Please use `Humanoid-v4` instead."
),
):
gym.make("Humanoid-v0")
with pytest.raises(
NameNotFound, match=re.escape("Environment `NonExistenceEnv` doesn't exist.")
):
gym.make("NonExistenceEnv-v0")
@pytest.fixture(scope="function")
def register_parameter_envs():
gym.register(
"NoMaxEpisodeStepsEnv-v0", lambda: GenericTestEnv(), max_episode_steps=None
)
gym.register("AutoresetEnv-v0", lambda: GenericTestEnv(), autoreset=True)
gym.register(
"EnabledApplyApiComp-v0",
lambda: GenericTestEnv(step_func=old_step_func),
apply_api_compatibility=True,
max_episode_steps=3,
)
gym.register(
"DisabledApplyApiComp-v0",
lambda: GenericTestEnv(step_func=old_step_func),
apply_api_compatibility=False,
max_episode_steps=3,
)
gym.register("OrderlessEnv-v0", lambda: GenericTestEnv(), order_enforce=False)
yield
del gym.registry["NoMaxEpisodeStepsEnv-v0"]
del gym.registry["AutoresetEnv-v0"]
del gym.registry["EnabledApplyApiComp-v0"]
del gym.registry["DisabledApplyApiComp-v0"]
del gym.registry["OrderlessEnv-v0"]
@pytest.fixture(scope="function")
def register_kwargs_env():
gym.register(
id="test.ArgumentEnv-v0",
entry_point="tests.envs.registration.utils_envs:ArgumentEnv",
kwargs={
"arg1": "arg1",
"arg2": "arg2",
},
)
@pytest.fixture(scope="function")
def register_rendering_testing_envs():
gym.register(
id="NoHumanRendering-v0",
entry_point="tests.envs.registration.utils_envs:NoHuman",
)
gym.register(
id="NoHumanRenderingOldAPI-v0",
entry_point="tests.envs.registration.utils_envs:NoHumanOldAPI",
)
gym.register(
id="NoHumanRenderingNoRGB-v0",
entry_point="tests.envs.registration.utils_envs:NoHumanNoRGB",
)
gym.register(
id="NoRenderModesMetadata-v0",
entry_point="tests.envs.registration.utils_envs:NoRenderModesMetadata",
)
yield
del gym.envs.registration.registry["NoHumanRendering-v0"]
del gym.envs.registration.registry["NoHumanRenderingOldAPI-v0"]
del gym.envs.registration.registry["NoHumanRenderingNoRGB-v0"]
del gym.envs.registration.registry["NoRenderModesMetadata-v0"]

View File

@@ -171,7 +171,7 @@ def test_compatibility_with_old_style_env():
"""Test compatibility with old style environment."""
env = OldStyleEnv()
env = OrderEnforcing(env)
env = TimeLimit(env)
env = TimeLimit(env, 100)
obs = env.reset()
assert obs == 0

View File

@@ -7,7 +7,7 @@ from gymnasium.wrappers import TimeLimit
def test_time_limit_reset_info():
env = gym.make("CartPole-v1", disable_env_checker=True)
env = TimeLimit(env)
env = TimeLimit(env, 100)
ob_space = env.observation_space
obs, info = env.reset()
assert ob_space.contains(obs)