mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-22 07:02:19 +00:00
Update EnvSpec
to improve backward compatibility (#355)
This commit is contained in:
@@ -11,7 +11,7 @@ from gymnasium.utils import RecordConstructorArgs, seeding
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from gymnasium.envs.registration import EnvSpec
|
from gymnasium.envs.registration import EnvSpec, WrapperSpec
|
||||||
|
|
||||||
ObsType = TypeVar("ObsType")
|
ObsType = TypeVar("ObsType")
|
||||||
ActType = TypeVar("ActType")
|
ActType = TypeVar("ActType")
|
||||||
@@ -266,6 +266,8 @@ class Wrapper(
|
|||||||
self._reward_range: tuple[SupportsFloat, SupportsFloat] | None = None
|
self._reward_range: tuple[SupportsFloat, SupportsFloat] | None = None
|
||||||
self._metadata: dict[str, Any] | None = None
|
self._metadata: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
self._cached_spec: EnvSpec | None = None
|
||||||
|
|
||||||
def __getattr__(self, name: str) -> Any:
|
def __getattr__(self, name: str) -> Any:
|
||||||
"""Returns an attribute with ``name``, unless ``name`` starts with an underscore."""
|
"""Returns an attribute with ``name``, unless ``name`` starts with an underscore."""
|
||||||
if name == "_np_random":
|
if name == "_np_random":
|
||||||
@@ -279,11 +281,11 @@ class Wrapper(
|
|||||||
@property
|
@property
|
||||||
def spec(self) -> EnvSpec | None:
|
def spec(self) -> EnvSpec | None:
|
||||||
"""Returns the :attr:`Env` :attr:`spec` attribute with the `WrapperSpec` if the wrapper inherits from `EzPickle`."""
|
"""Returns the :attr:`Env` :attr:`spec` attribute with the `WrapperSpec` if the wrapper inherits from `EzPickle`."""
|
||||||
|
if self._cached_spec is not None:
|
||||||
|
return self._cached_spec
|
||||||
|
|
||||||
env_spec = self.env.spec
|
env_spec = self.env.spec
|
||||||
|
|
||||||
if env_spec is not None:
|
if env_spec is not None:
|
||||||
from gymnasium.envs.registration import WrapperSpec
|
|
||||||
|
|
||||||
# See if the wrapper inherits from `RecordConstructorArgs` then add the kwargs otherwise use `None` for the wrapper kwargs. This will raise an error in `make`
|
# See if the wrapper inherits from `RecordConstructorArgs` then add the kwargs otherwise use `None` for the wrapper kwargs. This will raise an error in `make`
|
||||||
if isinstance(self, RecordConstructorArgs):
|
if isinstance(self, RecordConstructorArgs):
|
||||||
kwargs = getattr(self, "_saved_kwargs")
|
kwargs = getattr(self, "_saved_kwargs")
|
||||||
@@ -293,6 +295,8 @@ class Wrapper(
|
|||||||
else:
|
else:
|
||||||
kwargs = None
|
kwargs = None
|
||||||
|
|
||||||
|
from gymnasium.envs.registration import WrapperSpec
|
||||||
|
|
||||||
wrapper_spec = WrapperSpec(
|
wrapper_spec = WrapperSpec(
|
||||||
name=self.class_name(),
|
name=self.class_name(),
|
||||||
entry_point=f"{self.__module__}:{type(self).__name__}",
|
entry_point=f"{self.__module__}:{type(self).__name__}",
|
||||||
@@ -301,10 +305,22 @@ class Wrapper(
|
|||||||
|
|
||||||
# to avoid reference issues we deepcopy the prior environments spec and add the new information
|
# to avoid reference issues we deepcopy the prior environments spec and add the new information
|
||||||
env_spec = deepcopy(env_spec)
|
env_spec = deepcopy(env_spec)
|
||||||
env_spec.applied_wrappers += (wrapper_spec,)
|
env_spec.additional_wrappers += (wrapper_spec,)
|
||||||
|
|
||||||
|
self._cached_spec = env_spec
|
||||||
return env_spec
|
return env_spec
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def wrapper_spec(cls, **kwargs: Any) -> WrapperSpec:
|
||||||
|
"""Generates a `WrapperSpec` for the wrappers."""
|
||||||
|
from gymnasium.envs.registration import WrapperSpec
|
||||||
|
|
||||||
|
return WrapperSpec(
|
||||||
|
name=cls.class_name(),
|
||||||
|
entry_point=f"{cls.__module__}:{cls.__name__}",
|
||||||
|
kwargs=kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def class_name(cls) -> str:
|
def class_name(cls) -> str:
|
||||||
"""Returns the class name of the wrapper."""
|
"""Returns the class name of the wrapper."""
|
||||||
|
@@ -99,7 +99,7 @@ class EnvSpec:
|
|||||||
* **autoreset**: If to automatically reset the environment on episode end
|
* **autoreset**: If to automatically reset the environment on episode end
|
||||||
* **disable_env_checker**: If to disable the environment checker wrapper in :meth:`gymnasium.make`, by default False (runs the environment checker)
|
* **disable_env_checker**: If to disable the environment checker wrapper in :meth:`gymnasium.make`, by default False (runs the environment checker)
|
||||||
* **kwargs**: Additional keyword arguments passed to the environment during initialisation
|
* **kwargs**: Additional keyword arguments passed to the environment during initialisation
|
||||||
* **applied_wrappers**: A tuple of applied wrappers (WrapperSpec)
|
* **additional_wrappers**: A tuple of additional wrappers applied to the environment (WrapperSpec)
|
||||||
* **vector_entry_point**: The location of the vectorized environment to create from
|
* **vector_entry_point**: The location of the vectorized environment to create from
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -126,7 +126,7 @@ class EnvSpec:
|
|||||||
version: int | None = field(init=False)
|
version: int | None = field(init=False)
|
||||||
|
|
||||||
# applied wrappers
|
# applied wrappers
|
||||||
applied_wrappers: tuple[WrapperSpec, ...] = field(init=False, default_factory=tuple)
|
additional_wrappers: tuple[WrapperSpec, ...] = field(default_factory=tuple)
|
||||||
|
|
||||||
# Vectorized environment entry point
|
# Vectorized environment entry point
|
||||||
vector_entry_point: VectorEnvCreator | str | None = field(default=None)
|
vector_entry_point: VectorEnvCreator | str | None = field(default=None)
|
||||||
@@ -187,7 +187,7 @@ class EnvSpec:
|
|||||||
parsed_env_spec = json.loads(json_env_spec)
|
parsed_env_spec = json.loads(json_env_spec)
|
||||||
|
|
||||||
applied_wrapper_specs: list[WrapperSpec] = []
|
applied_wrapper_specs: list[WrapperSpec] = []
|
||||||
for wrapper_spec_json in parsed_env_spec.pop("applied_wrappers"):
|
for wrapper_spec_json in parsed_env_spec.pop("additional_wrappers"):
|
||||||
try:
|
try:
|
||||||
applied_wrapper_specs.append(WrapperSpec(**wrapper_spec_json))
|
applied_wrapper_specs.append(WrapperSpec(**wrapper_spec_json))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -197,7 +197,7 @@ class EnvSpec:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
env_spec = EnvSpec(**parsed_env_spec)
|
env_spec = EnvSpec(**parsed_env_spec)
|
||||||
env_spec.applied_wrappers = tuple(applied_wrapper_specs)
|
env_spec.additional_wrappers = tuple(applied_wrapper_specs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"An issue occurred when trying to make {parsed_env_spec} an EnvSpec"
|
f"An issue occurred when trying to make {parsed_env_spec} an EnvSpec"
|
||||||
@@ -241,9 +241,9 @@ class EnvSpec:
|
|||||||
if print_all or self.apply_api_compatibility is not False:
|
if print_all or self.apply_api_compatibility is not False:
|
||||||
output += f"\napplied_api_compatibility={self.apply_api_compatibility}"
|
output += f"\napplied_api_compatibility={self.apply_api_compatibility}"
|
||||||
|
|
||||||
if print_all or self.applied_wrappers:
|
if print_all or self.additional_wrappers:
|
||||||
wrapper_output: list[str] = []
|
wrapper_output: list[str] = []
|
||||||
for wrapper_spec in self.applied_wrappers:
|
for wrapper_spec in self.additional_wrappers:
|
||||||
if include_entry_points:
|
if include_entry_points:
|
||||||
wrapper_output.append(
|
wrapper_output.append(
|
||||||
f"\n\tname={wrapper_spec.name}, entry_point={wrapper_spec.entry_point}, kwargs={wrapper_spec.kwargs}"
|
f"\n\tname={wrapper_spec.name}, entry_point={wrapper_spec.entry_point}, kwargs={wrapper_spec.kwargs}"
|
||||||
@@ -254,9 +254,9 @@ class EnvSpec:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if len(wrapper_output) == 0:
|
if len(wrapper_output) == 0:
|
||||||
output += "\napplied_wrappers=[]"
|
output += "\nadditional_wrappers=[]"
|
||||||
else:
|
else:
|
||||||
output += f"\napplied_wrappers=[{','.join(wrapper_output)}\n]"
|
output += f"\nadditional_wrappers=[{','.join(wrapper_output)}\n]"
|
||||||
|
|
||||||
if disable_print:
|
if disable_print:
|
||||||
return output
|
return output
|
||||||
@@ -555,161 +555,6 @@ def load_env_creator(name: str) -> EnvCreator | VectorEnvCreator:
|
|||||||
return fn
|
return fn
|
||||||
|
|
||||||
|
|
||||||
def _create_from_env_spec(
|
|
||||||
env_spec: EnvSpec,
|
|
||||||
kwargs: dict[str, Any],
|
|
||||||
) -> Env:
|
|
||||||
"""Recreates an environment spec using a list of wrapper specs."""
|
|
||||||
if callable(env_spec.entry_point):
|
|
||||||
env_creator = env_spec.entry_point
|
|
||||||
else:
|
|
||||||
env_creator: EnvCreator = load_env_creator(env_spec.entry_point)
|
|
||||||
|
|
||||||
# Create the environment
|
|
||||||
env: Env = env_creator(**env_spec.kwargs, **kwargs)
|
|
||||||
|
|
||||||
# Set the `EnvSpec` to the environment
|
|
||||||
new_env_spec = copy.deepcopy(env_spec)
|
|
||||||
new_env_spec.applied_wrappers = ()
|
|
||||||
new_env_spec.kwargs.update(kwargs)
|
|
||||||
env.unwrapped.spec = new_env_spec
|
|
||||||
|
|
||||||
# Check if the environment spec
|
|
||||||
assert env.spec is not None # this is for pyright
|
|
||||||
num_prior_wrappers = len(env.spec.applied_wrappers)
|
|
||||||
if env_spec.applied_wrappers[:num_prior_wrappers] != env.spec.applied_wrappers:
|
|
||||||
for env_spec_wrapper_spec, recreated_wrapper_spec in zip(
|
|
||||||
env_spec.applied_wrappers, env.spec.applied_wrappers
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
f"The environment's wrapper spec {recreated_wrapper_spec} is different from the saved `EnvSpec` applied_wrappers {env_spec_wrapper_spec}"
|
|
||||||
)
|
|
||||||
|
|
||||||
for wrapper_spec in env_spec.applied_wrappers[num_prior_wrappers:]:
|
|
||||||
if wrapper_spec.kwargs is None:
|
|
||||||
raise ValueError(
|
|
||||||
f"{wrapper_spec.name} wrapper does not inherit from `gymnasium.utils.RecordConstructorArgs`, therefore, the wrapper cannot be recreated."
|
|
||||||
)
|
|
||||||
|
|
||||||
env = load_env_creator(wrapper_spec.entry_point)(env=env, **wrapper_spec.kwargs)
|
|
||||||
|
|
||||||
return env
|
|
||||||
|
|
||||||
|
|
||||||
def _create_from_env_id(
|
|
||||||
env_spec: EnvSpec,
|
|
||||||
kwargs: dict[str, Any],
|
|
||||||
max_episode_steps: int | None = None,
|
|
||||||
autoreset: bool = False,
|
|
||||||
apply_api_compatibility: bool | None = None,
|
|
||||||
disable_env_checker: bool | None = None,
|
|
||||||
) -> Env:
|
|
||||||
"""Creates an environment based on the `env_spec` along with wrapper options. See `make` for their meaning."""
|
|
||||||
spec_kwargs = copy.deepcopy(env_spec.kwargs)
|
|
||||||
spec_kwargs.update(kwargs)
|
|
||||||
|
|
||||||
# Load the environment creator
|
|
||||||
if env_spec.entry_point is None:
|
|
||||||
raise error.Error(f"{env_spec.id} registered but entry_point is not specified")
|
|
||||||
elif callable(env_spec.entry_point):
|
|
||||||
env_creator = env_spec.entry_point
|
|
||||||
else:
|
|
||||||
# Assume it's a string
|
|
||||||
env_creator = load_env_creator(env_spec.entry_point)
|
|
||||||
|
|
||||||
# Determine if to use the rendering
|
|
||||||
render_modes: list[str] | None = None
|
|
||||||
if hasattr(env_creator, "metadata"):
|
|
||||||
_check_metadata(env_creator.metadata)
|
|
||||||
render_modes = env_creator.metadata.get("render_modes")
|
|
||||||
mode = spec_kwargs.get("render_mode")
|
|
||||||
apply_human_rendering = False
|
|
||||||
apply_render_collection = False
|
|
||||||
|
|
||||||
# If mode is not valid, try applying HumanRendering/RenderCollection wrappers
|
|
||||||
if mode is not None and render_modes is not None and mode not in render_modes:
|
|
||||||
displayable_modes = {"rgb_array", "rgb_array_list"}.intersection(render_modes)
|
|
||||||
if mode == "human" and len(displayable_modes) > 0:
|
|
||||||
logger.warn(
|
|
||||||
"You are trying to use 'human' rendering for an environment that doesn't natively support it. "
|
|
||||||
"The HumanRendering wrapper is being applied to your environment."
|
|
||||||
)
|
|
||||||
spec_kwargs["render_mode"] = displayable_modes.pop()
|
|
||||||
apply_human_rendering = True
|
|
||||||
elif mode.endswith("_list") and mode[: -len("_list")] in render_modes:
|
|
||||||
spec_kwargs["render_mode"] = mode[: -len("_list")]
|
|
||||||
apply_render_collection = True
|
|
||||||
else:
|
|
||||||
logger.warn(
|
|
||||||
f"The environment is being initialised with render_mode={mode!r} "
|
|
||||||
f"that is not in the possible render_modes ({render_modes})."
|
|
||||||
)
|
|
||||||
|
|
||||||
if apply_api_compatibility or (
|
|
||||||
apply_api_compatibility is None and env_spec.apply_api_compatibility
|
|
||||||
):
|
|
||||||
# If we use the compatibility layer, we treat the render mode explicitly and don't pass it to the env creator
|
|
||||||
render_mode = spec_kwargs.pop("render_mode", None)
|
|
||||||
else:
|
|
||||||
render_mode = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
env = env_creator(**spec_kwargs)
|
|
||||||
except TypeError as e:
|
|
||||||
if (
|
|
||||||
str(e).find("got an unexpected keyword argument 'render_mode'") >= 0
|
|
||||||
and apply_human_rendering
|
|
||||||
):
|
|
||||||
raise error.Error(
|
|
||||||
f"You passed render_mode='human' although {env_spec.id} doesn't implement human-rendering natively. "
|
|
||||||
"Gym tried to apply the HumanRendering wrapper but it looks like your environment is using the old "
|
|
||||||
"rendering API, which is not supported by the HumanRendering wrapper."
|
|
||||||
) from e
|
|
||||||
else:
|
|
||||||
raise e
|
|
||||||
|
|
||||||
# Copies the environment creation specification and kwargs to add to the environment specification details
|
|
||||||
env_spec = copy.deepcopy(env_spec)
|
|
||||||
env_spec.kwargs = spec_kwargs
|
|
||||||
env.unwrapped.spec = env_spec
|
|
||||||
|
|
||||||
# Add step API wrapper
|
|
||||||
if apply_api_compatibility is True or (
|
|
||||||
apply_api_compatibility is None and env_spec.apply_api_compatibility is True
|
|
||||||
):
|
|
||||||
env = EnvCompatibility(env, render_mode)
|
|
||||||
|
|
||||||
# Run the environment checker as the lowest level wrapper
|
|
||||||
if disable_env_checker is False or (
|
|
||||||
disable_env_checker is None and env_spec.disable_env_checker is False
|
|
||||||
):
|
|
||||||
env = PassiveEnvChecker(env)
|
|
||||||
|
|
||||||
# Add the order enforcing wrapper
|
|
||||||
if env_spec.order_enforce:
|
|
||||||
env = OrderEnforcing(env)
|
|
||||||
|
|
||||||
# Add the time limit wrapper
|
|
||||||
if max_episode_steps is not None:
|
|
||||||
assert env.unwrapped.spec is not None # for pyright
|
|
||||||
env.unwrapped.spec.max_episode_steps = max_episode_steps
|
|
||||||
env = TimeLimit(env, max_episode_steps)
|
|
||||||
elif env_spec.max_episode_steps is not None:
|
|
||||||
env = TimeLimit(env, env_spec.max_episode_steps)
|
|
||||||
|
|
||||||
# Add the auto-reset wrapper
|
|
||||||
if autoreset:
|
|
||||||
env = AutoResetWrapper(env)
|
|
||||||
|
|
||||||
# Add human rendering wrapper
|
|
||||||
if apply_human_rendering:
|
|
||||||
env = HumanRendering(env)
|
|
||||||
elif apply_render_collection:
|
|
||||||
env = RenderCollection(env)
|
|
||||||
|
|
||||||
return env
|
|
||||||
|
|
||||||
|
|
||||||
def load_plugin_envs(entry_point: str = "gymnasium.envs"):
|
def load_plugin_envs(entry_point: str = "gymnasium.envs"):
|
||||||
"""Load modules (plugins) using the gymnasium entry points in order to register external module's environments on ``import gymnasium``.
|
"""Load modules (plugins) using the gymnasium entry points in order to register external module's environments on ``import gymnasium``.
|
||||||
|
|
||||||
@@ -779,6 +624,7 @@ def register(
|
|||||||
autoreset: bool = False,
|
autoreset: bool = False,
|
||||||
disable_env_checker: bool = False,
|
disable_env_checker: bool = False,
|
||||||
apply_api_compatibility: bool = False,
|
apply_api_compatibility: bool = False,
|
||||||
|
additional_wrappers: tuple[WrapperSpec, ...] = (),
|
||||||
vector_entry_point: VectorEnvCreator | str | None = None,
|
vector_entry_point: VectorEnvCreator | str | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
):
|
):
|
||||||
@@ -801,6 +647,7 @@ def register(
|
|||||||
disable_env_checker: If to disable the :class:`gymnasium.wrappers.PassiveEnvChecker` to the environment.
|
disable_env_checker: If to disable the :class:`gymnasium.wrappers.PassiveEnvChecker` to the environment.
|
||||||
apply_api_compatibility: If to apply the :class:`gymnasium.wrappers.StepAPICompatibility` wrapper to the environment.
|
apply_api_compatibility: If to apply the :class:`gymnasium.wrappers.StepAPICompatibility` wrapper to the environment.
|
||||||
Use if the environment is implemented in the gym v0.21 environment API.
|
Use if the environment is implemented in the gym v0.21 environment API.
|
||||||
|
additional_wrappers: Additional wrappers to apply the environment.
|
||||||
vector_entry_point: The entry point for creating the vector environment
|
vector_entry_point: The entry point for creating the vector environment
|
||||||
**kwargs: arbitrary keyword arguments which are passed to the environment constructor on initialisation.
|
**kwargs: arbitrary keyword arguments which are passed to the environment constructor on initialisation.
|
||||||
"""
|
"""
|
||||||
@@ -835,8 +682,9 @@ def register(
|
|||||||
order_enforce=order_enforce,
|
order_enforce=order_enforce,
|
||||||
autoreset=autoreset,
|
autoreset=autoreset,
|
||||||
disable_env_checker=disable_env_checker,
|
disable_env_checker=disable_env_checker,
|
||||||
**kwargs,
|
|
||||||
apply_api_compatibility=apply_api_compatibility,
|
apply_api_compatibility=apply_api_compatibility,
|
||||||
|
**kwargs,
|
||||||
|
additional_wrappers=additional_wrappers,
|
||||||
vector_entry_point=vector_entry_point,
|
vector_entry_point=vector_entry_point,
|
||||||
)
|
)
|
||||||
_check_spec_register(new_spec)
|
_check_spec_register(new_spec)
|
||||||
@@ -849,7 +697,7 @@ def register(
|
|||||||
def make(
|
def make(
|
||||||
id: str | EnvSpec,
|
id: str | EnvSpec,
|
||||||
max_episode_steps: int | None = None,
|
max_episode_steps: int | None = None,
|
||||||
autoreset: bool = False,
|
autoreset: bool | None = None,
|
||||||
apply_api_compatibility: bool | None = None,
|
apply_api_compatibility: bool | None = None,
|
||||||
disable_env_checker: bool | None = None,
|
disable_env_checker: bool | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
@@ -878,32 +726,12 @@ def make(
|
|||||||
Error: If the ``id`` doesn't exist in the :attr:`registry`
|
Error: If the ``id`` doesn't exist in the :attr:`registry`
|
||||||
"""
|
"""
|
||||||
if isinstance(id, EnvSpec):
|
if isinstance(id, EnvSpec):
|
||||||
if hasattr(id, "applied_wrappers") and id.applied_wrappers is not None:
|
env_spec = id
|
||||||
if max_episode_steps is not None:
|
if not hasattr(env_spec, "additional_wrappers"):
|
||||||
logger.warn(
|
logger.warn(
|
||||||
f"For `gymnasium.make` with an `EnvSpec`, the `max_episode_step` parameter is ignored, use `gym.make({id.id}, max_episode_steps={max_episode_steps})` and any additional wrappers"
|
f"The env spec passed to `make` does not have a `additional_wrappers`, set it to an empty tuple. Env_spec={env_spec}"
|
||||||
)
|
|
||||||
if autoreset is True:
|
|
||||||
logger.warn(
|
|
||||||
f"For `gymnasium.make` with an `EnvSpec`, the `autoreset` parameter is ignored, use `gym.make({id.id}, autoreset={autoreset})` and any additional wrappers"
|
|
||||||
)
|
|
||||||
if apply_api_compatibility is not None:
|
|
||||||
logger.warn(
|
|
||||||
f"For `gymnasium.make` with an `EnvSpec`, the `apply_api_compatibility` parameter is ignored, use `gym.make({id.id}, apply_api_compatibility={apply_api_compatibility})` and any additional wrappers"
|
|
||||||
)
|
|
||||||
if disable_env_checker is not None:
|
|
||||||
logger.warn(
|
|
||||||
f"For `gymnasium.make` with an `EnvSpec`, the `disable_env_checker` parameter is ignored, use `gym.make({id.id}, disable_env_checker={disable_env_checker})` and any additional wrappers"
|
|
||||||
)
|
|
||||||
|
|
||||||
return _create_from_env_spec(
|
|
||||||
id,
|
|
||||||
kwargs,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"The EnvSpec used does not contain `applied_wrappers` parameters or is `None`. Expected to be a tuple, actually {id}."
|
|
||||||
)
|
)
|
||||||
|
env_spec.additional_wrappers = ()
|
||||||
else:
|
else:
|
||||||
# For string id's, load the environment spec from the registry then make the environment spec
|
# For string id's, load the environment spec from the registry then make the environment spec
|
||||||
assert isinstance(id, str)
|
assert isinstance(id, str)
|
||||||
@@ -911,14 +739,150 @@ def make(
|
|||||||
# The environment name can include an unloaded module in "module:env_name" style
|
# The environment name can include an unloaded module in "module:env_name" style
|
||||||
env_spec = _find_spec(id)
|
env_spec = _find_spec(id)
|
||||||
|
|
||||||
return _create_from_env_id(
|
assert isinstance(env_spec, EnvSpec)
|
||||||
env_spec,
|
|
||||||
kwargs,
|
# Update the env spec kwargs with the `make` kwargs
|
||||||
max_episode_steps=max_episode_steps,
|
env_spec_kwargs = copy.deepcopy(env_spec.kwargs)
|
||||||
autoreset=autoreset,
|
env_spec_kwargs.update(kwargs)
|
||||||
apply_api_compatibility=apply_api_compatibility,
|
|
||||||
disable_env_checker=disable_env_checker,
|
# Load the environment creator
|
||||||
)
|
if env_spec.entry_point is None:
|
||||||
|
raise error.Error(f"{env_spec.id} registered but entry_point is not specified")
|
||||||
|
elif callable(env_spec.entry_point):
|
||||||
|
env_creator = env_spec.entry_point
|
||||||
|
else:
|
||||||
|
# Assume it's a string
|
||||||
|
env_creator = load_env_creator(env_spec.entry_point)
|
||||||
|
|
||||||
|
# Determine if to use the rendering
|
||||||
|
render_modes: list[str] | None = None
|
||||||
|
if hasattr(env_creator, "metadata"):
|
||||||
|
_check_metadata(env_creator.metadata)
|
||||||
|
render_modes = env_creator.metadata.get("render_modes")
|
||||||
|
render_mode = env_spec_kwargs.get("render_mode")
|
||||||
|
apply_human_rendering = False
|
||||||
|
apply_render_collection = False
|
||||||
|
|
||||||
|
# If mode is not valid, try applying HumanRendering/RenderCollection wrappers
|
||||||
|
if (
|
||||||
|
render_mode is not None
|
||||||
|
and render_modes is not None
|
||||||
|
and render_mode not in render_modes
|
||||||
|
):
|
||||||
|
displayable_modes = {"rgb_array", "rgb_array_list"}.intersection(render_modes)
|
||||||
|
if render_mode == "human" and len(displayable_modes) > 0:
|
||||||
|
logger.warn(
|
||||||
|
"You are trying to use 'human' rendering for an environment that doesn't natively support it. "
|
||||||
|
"The HumanRendering wrapper is being applied to your environment."
|
||||||
|
)
|
||||||
|
env_spec_kwargs["render_mode"] = displayable_modes.pop()
|
||||||
|
apply_human_rendering = True
|
||||||
|
elif (
|
||||||
|
render_mode.endswith("_list")
|
||||||
|
and render_mode[: -len("_list")] in render_modes
|
||||||
|
):
|
||||||
|
env_spec_kwargs["render_mode"] = render_mode[: -len("_list")]
|
||||||
|
apply_render_collection = True
|
||||||
|
else:
|
||||||
|
logger.warn(
|
||||||
|
f"The environment is being initialised with render_mode={render_mode!r} "
|
||||||
|
f"that is not in the possible render_modes ({render_modes})."
|
||||||
|
)
|
||||||
|
|
||||||
|
if apply_api_compatibility or (
|
||||||
|
apply_api_compatibility is None and env_spec.apply_api_compatibility
|
||||||
|
):
|
||||||
|
# If we use the compatibility layer, we treat the render mode explicitly and don't pass it to the env creator
|
||||||
|
render_mode = env_spec_kwargs.pop("render_mode", None)
|
||||||
|
else:
|
||||||
|
render_mode = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
env = env_creator(**env_spec_kwargs)
|
||||||
|
except TypeError as e:
|
||||||
|
if (
|
||||||
|
str(e).find("got an unexpected keyword argument 'render_mode'") >= 0
|
||||||
|
and apply_human_rendering
|
||||||
|
):
|
||||||
|
raise error.Error(
|
||||||
|
f"You passed render_mode='human' although {env_spec.id} doesn't implement human-rendering natively. "
|
||||||
|
"Gym tried to apply the HumanRendering wrapper but it looks like your environment is using the old "
|
||||||
|
"rendering API, which is not supported by the HumanRendering wrapper."
|
||||||
|
) from e
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
# Set the minimal env spec for the environment.
|
||||||
|
env.unwrapped.spec = EnvSpec(
|
||||||
|
id=env_spec.id,
|
||||||
|
entry_point=env_spec.entry_point,
|
||||||
|
reward_threshold=env_spec.reward_threshold,
|
||||||
|
nondeterministic=env_spec.nondeterministic,
|
||||||
|
max_episode_steps=None,
|
||||||
|
order_enforce=False,
|
||||||
|
autoreset=False,
|
||||||
|
disable_env_checker=True,
|
||||||
|
apply_api_compatibility=False,
|
||||||
|
kwargs=env_spec_kwargs,
|
||||||
|
additional_wrappers=(),
|
||||||
|
vector_entry_point=env_spec.vector_entry_point,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if pre-wrapped wrappers
|
||||||
|
assert env.spec is not None
|
||||||
|
num_prior_wrappers = len(env.spec.additional_wrappers)
|
||||||
|
if (
|
||||||
|
env_spec.additional_wrappers[:num_prior_wrappers]
|
||||||
|
!= env.spec.additional_wrappers
|
||||||
|
):
|
||||||
|
for env_spec_wrapper_spec, recreated_wrapper_spec in zip(
|
||||||
|
env_spec.additional_wrappers, env.spec.additional_wrappers
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"The environment's wrapper spec {recreated_wrapper_spec} is different from the saved `EnvSpec` additional wrapper {env_spec_wrapper_spec}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add step API wrapper
|
||||||
|
if apply_api_compatibility is True or (
|
||||||
|
apply_api_compatibility is None and env_spec.apply_api_compatibility is True
|
||||||
|
):
|
||||||
|
env = EnvCompatibility(env, render_mode)
|
||||||
|
|
||||||
|
# Run the environment checker as the lowest level wrapper
|
||||||
|
if disable_env_checker is False or (
|
||||||
|
disable_env_checker is None and env_spec.disable_env_checker is False
|
||||||
|
):
|
||||||
|
env = PassiveEnvChecker(env)
|
||||||
|
|
||||||
|
# Add the order enforcing wrapper
|
||||||
|
if env_spec.order_enforce:
|
||||||
|
env = OrderEnforcing(env)
|
||||||
|
|
||||||
|
# Add the time limit wrapper
|
||||||
|
if max_episode_steps is not None:
|
||||||
|
env = TimeLimit(env, max_episode_steps)
|
||||||
|
elif env_spec.max_episode_steps is not None:
|
||||||
|
env = TimeLimit(env, env_spec.max_episode_steps)
|
||||||
|
|
||||||
|
# Add the auto-reset wrapper
|
||||||
|
if autoreset is True or (autoreset is None and env_spec.autoreset is True):
|
||||||
|
env = AutoResetWrapper(env)
|
||||||
|
|
||||||
|
for wrapper_spec in env_spec.additional_wrappers[num_prior_wrappers:]:
|
||||||
|
if wrapper_spec.kwargs is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"{wrapper_spec.name} wrapper does not inherit from `gymnasium.utils.RecordConstructorArgs`, therefore, the wrapper cannot be recreated."
|
||||||
|
)
|
||||||
|
|
||||||
|
env = load_env_creator(wrapper_spec.entry_point)(env=env, **wrapper_spec.kwargs)
|
||||||
|
|
||||||
|
# Add human rendering wrapper
|
||||||
|
if apply_human_rendering:
|
||||||
|
env = HumanRendering(env)
|
||||||
|
elif apply_render_collection:
|
||||||
|
env = RenderCollection(env)
|
||||||
|
|
||||||
|
return env
|
||||||
|
|
||||||
|
|
||||||
def make_vec(
|
def make_vec(
|
||||||
|
@@ -1,7 +1,16 @@
|
|||||||
"""Wrapper that autoreset environments when `terminated=True` or `truncated=True`."""
|
"""Wrapper that autoreset environments when `terminated=True` or `truncated=True`."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from copy import deepcopy
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from gymnasium.envs.registration import EnvSpec
|
||||||
|
|
||||||
|
|
||||||
class AutoResetWrapper(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
class AutoResetWrapper(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||||
"""A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`.
|
"""A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`.
|
||||||
|
|
||||||
@@ -60,3 +69,17 @@ class AutoResetWrapper(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
|||||||
info = new_info
|
info = new_info
|
||||||
|
|
||||||
return obs, reward, terminated, truncated, info
|
return obs, reward, terminated, truncated, info
|
||||||
|
|
||||||
|
@property
|
||||||
|
def spec(self) -> EnvSpec | None:
|
||||||
|
"""Modifies the environment spec to specify the `autoreset=True`."""
|
||||||
|
if self._cached_spec is not None:
|
||||||
|
return self._cached_spec
|
||||||
|
|
||||||
|
env_spec = self.env.spec
|
||||||
|
if env_spec is not None:
|
||||||
|
env_spec = deepcopy(env_spec)
|
||||||
|
env_spec.autoreset = True
|
||||||
|
|
||||||
|
self._cached_spec = env_spec
|
||||||
|
return env_spec
|
||||||
|
@@ -1,4 +1,9 @@
|
|||||||
"""A passive environment checker wrapper for an environment's observation and action space along with the reset, step and render functions."""
|
"""A passive environment checker wrapper for an environment's observation and action space along with the reset, step and render functions."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from copy import deepcopy
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium.core import ActType
|
from gymnasium.core import ActType
|
||||||
from gymnasium.utils.passive_env_checker import (
|
from gymnasium.utils.passive_env_checker import (
|
||||||
@@ -10,6 +15,10 @@ from gymnasium.utils.passive_env_checker import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from gymnasium.envs.registration import EnvSpec
|
||||||
|
|
||||||
|
|
||||||
class PassiveEnvChecker(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
class PassiveEnvChecker(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||||
"""A passive environment checker wrapper that surrounds the step, reset and render functions to check they follow the gymnasium API."""
|
"""A passive environment checker wrapper that surrounds the step, reset and render functions to check they follow the gymnasium API."""
|
||||||
|
|
||||||
@@ -54,3 +63,17 @@ class PassiveEnvChecker(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
|||||||
return env_render_passive_checker(self.env, *args, **kwargs)
|
return env_render_passive_checker(self.env, *args, **kwargs)
|
||||||
else:
|
else:
|
||||||
return self.env.render(*args, **kwargs)
|
return self.env.render(*args, **kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def spec(self) -> EnvSpec | None:
|
||||||
|
"""Modifies the environment spec to such that `disable_env_checker=False`."""
|
||||||
|
if self._cached_spec is not None:
|
||||||
|
return self._cached_spec
|
||||||
|
|
||||||
|
env_spec = self.env.spec
|
||||||
|
if env_spec is not None:
|
||||||
|
env_spec = deepcopy(env_spec)
|
||||||
|
env_spec.disable_env_checker = False
|
||||||
|
|
||||||
|
self._cached_spec = env_spec
|
||||||
|
return env_spec
|
||||||
|
@@ -1,8 +1,17 @@
|
|||||||
"""Wrapper to enforce the proper ordering of environment operations."""
|
"""Wrapper to enforce the proper ordering of environment operations."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from copy import deepcopy
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium.error import ResetNeeded
|
from gymnasium.error import ResetNeeded
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from gymnasium.envs.registration import EnvSpec
|
||||||
|
|
||||||
|
|
||||||
class OrderEnforcing(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
class OrderEnforcing(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||||
"""A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`.
|
"""A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`.
|
||||||
|
|
||||||
@@ -64,3 +73,17 @@ class OrderEnforcing(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
|||||||
def has_reset(self):
|
def has_reset(self):
|
||||||
"""Returns if the environment has been reset before."""
|
"""Returns if the environment has been reset before."""
|
||||||
return self._has_reset
|
return self._has_reset
|
||||||
|
|
||||||
|
@property
|
||||||
|
def spec(self) -> EnvSpec | None:
|
||||||
|
"""Modifies the environment spec to add the `order_enforce=True`."""
|
||||||
|
if self._cached_spec is not None:
|
||||||
|
return self._cached_spec
|
||||||
|
|
||||||
|
env_spec = self.env.spec
|
||||||
|
if env_spec is not None:
|
||||||
|
env_spec = deepcopy(env_spec)
|
||||||
|
env_spec.order_enforce = True
|
||||||
|
|
||||||
|
self._cached_spec = env_spec
|
||||||
|
return env_spec
|
||||||
|
@@ -1,9 +1,16 @@
|
|||||||
"""Wrapper for limiting the time steps of an environment."""
|
"""Wrapper for limiting the time steps of an environment."""
|
||||||
from typing import Optional
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from copy import deepcopy
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from gymnasium.envs.registration import EnvSpec
|
||||||
|
|
||||||
|
|
||||||
class TimeLimit(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
class TimeLimit(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||||
"""This wrapper will issue a `truncated` signal if a maximum number of timesteps is exceeded.
|
"""This wrapper will issue a `truncated` signal if a maximum number of timesteps is exceeded.
|
||||||
|
|
||||||
@@ -20,7 +27,7 @@ class TimeLimit(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
env: gym.Env,
|
env: gym.Env,
|
||||||
max_episode_steps: Optional[int] = None,
|
max_episode_steps: int,
|
||||||
):
|
):
|
||||||
"""Initializes the :class:`TimeLimit` wrapper with an environment and the number of steps after which truncation will occur.
|
"""Initializes the :class:`TimeLimit` wrapper with an environment and the number of steps after which truncation will occur.
|
||||||
|
|
||||||
@@ -33,9 +40,6 @@ class TimeLimit(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
|||||||
)
|
)
|
||||||
gym.Wrapper.__init__(self, env)
|
gym.Wrapper.__init__(self, env)
|
||||||
|
|
||||||
if max_episode_steps is None and self.env.spec is not None:
|
|
||||||
assert env.spec is not None
|
|
||||||
max_episode_steps = env.spec.max_episode_steps
|
|
||||||
self._max_episode_steps = max_episode_steps
|
self._max_episode_steps = max_episode_steps
|
||||||
self._elapsed_steps = None
|
self._elapsed_steps = None
|
||||||
|
|
||||||
@@ -69,3 +73,17 @@ class TimeLimit(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
|||||||
"""
|
"""
|
||||||
self._elapsed_steps = 0
|
self._elapsed_steps = 0
|
||||||
return self.env.reset(**kwargs)
|
return self.env.reset(**kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def spec(self) -> EnvSpec | None:
|
||||||
|
"""Modifies the environment spec to include the `max_episode_steps=self._max_episode_steps`."""
|
||||||
|
if self._cached_spec is not None:
|
||||||
|
return self._cached_spec
|
||||||
|
|
||||||
|
env_spec = self.env.spec
|
||||||
|
if env_spec is not None:
|
||||||
|
env_spec = deepcopy(env_spec)
|
||||||
|
env_spec.max_episode_steps = self._max_episode_steps
|
||||||
|
|
||||||
|
self._cached_spec = env_spec
|
||||||
|
return env_spec
|
||||||
|
@@ -1,4 +1,4 @@
|
|||||||
"""Example file showing usage of env.specstack."""
|
"""Test for the `EnvSpec`, in particular, a full integration with `EnvSpec`."""
|
||||||
import pickle
|
import pickle
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -11,15 +11,16 @@ from gymnasium.utils.env_checker import data_equivalence
|
|||||||
|
|
||||||
def test_full_integration():
|
def test_full_integration():
|
||||||
# Create an environment to test with
|
# Create an environment to test with
|
||||||
env = gym.make("CartPole-v1", render_mode="rgb_array").unwrapped
|
env = gym.make("CartPole-v1", render_mode="rgb_array")
|
||||||
|
|
||||||
env = gym.wrappers.FlattenObservation(env)
|
|
||||||
env = gym.wrappers.TimeAwareObservation(env)
|
env = gym.wrappers.TimeAwareObservation(env)
|
||||||
env = gym.wrappers.NormalizeReward(env, gamma=0.8)
|
env = gym.wrappers.NormalizeReward(env, gamma=0.8)
|
||||||
|
|
||||||
# Generate the spec_stack
|
# Generate the spec_stack
|
||||||
env_spec = env.spec
|
env_spec = env.spec
|
||||||
assert isinstance(env_spec, EnvSpec)
|
assert isinstance(env_spec, EnvSpec)
|
||||||
|
# additional_wrappers = (TimeAwareObservation, NormalizeReward)
|
||||||
|
assert len(env_spec.additional_wrappers) == 2
|
||||||
# env_spec.pprint()
|
# env_spec.pprint()
|
||||||
|
|
||||||
# Serialize the spec_stack
|
# Serialize the spec_stack
|
||||||
@@ -30,10 +31,7 @@ def test_full_integration():
|
|||||||
recreate_env_spec = EnvSpec.from_json(env_spec_json)
|
recreate_env_spec = EnvSpec.from_json(env_spec_json)
|
||||||
# recreate_env_spec.pprint()
|
# recreate_env_spec.pprint()
|
||||||
|
|
||||||
for wrapper_spec, recreated_wrapper_spec in zip(
|
assert env_spec.additional_wrappers == recreate_env_spec.additional_wrappers
|
||||||
env_spec.applied_wrappers, recreate_env_spec.applied_wrappers
|
|
||||||
):
|
|
||||||
assert wrapper_spec == recreated_wrapper_spec
|
|
||||||
assert recreate_env_spec == env_spec
|
assert recreate_env_spec == env_spec
|
||||||
|
|
||||||
# Recreate the environment using the spec_stack
|
# Recreate the environment using the spec_stack
|
||||||
@@ -86,43 +84,6 @@ def test_env_spec_to_from_json(env_spec: EnvSpec):
|
|||||||
assert env_spec == recreated_env_spec
|
assert env_spec == recreated_env_spec
|
||||||
|
|
||||||
|
|
||||||
def test_wrapped_env_entry_point():
|
|
||||||
def _create_env():
|
|
||||||
_env = gym.make("CartPole-v1", render_mode="rgb_array")
|
|
||||||
_env = gym.wrappers.FlattenObservation(_env)
|
|
||||||
return _env
|
|
||||||
|
|
||||||
gym.register("TestingEnv-v0", entry_point=_create_env)
|
|
||||||
|
|
||||||
env = gym.make("TestingEnv-v0")
|
|
||||||
env = gym.wrappers.TimeAwareObservation(env)
|
|
||||||
env = gym.wrappers.NormalizeReward(env, gamma=0.8)
|
|
||||||
|
|
||||||
recreated_env = gym.make(env.spec)
|
|
||||||
|
|
||||||
obs, info = env.reset(seed=42)
|
|
||||||
recreated_obs, recreated_info = recreated_env.reset(seed=42)
|
|
||||||
assert data_equivalence(obs, recreated_obs)
|
|
||||||
assert data_equivalence(info, recreated_info)
|
|
||||||
|
|
||||||
action = env.action_space.sample()
|
|
||||||
obs, reward, terminated, truncated, info = env.step(action)
|
|
||||||
(
|
|
||||||
recreated_obs,
|
|
||||||
recreated_reward,
|
|
||||||
recreated_terminated,
|
|
||||||
recreated_truncated,
|
|
||||||
recreated_info,
|
|
||||||
) = recreated_env.step(action)
|
|
||||||
assert data_equivalence(obs, recreated_obs)
|
|
||||||
assert data_equivalence(reward, recreated_reward)
|
|
||||||
assert data_equivalence(terminated, recreated_terminated)
|
|
||||||
assert data_equivalence(truncated, recreated_truncated)
|
|
||||||
assert data_equivalence(info, recreated_info)
|
|
||||||
|
|
||||||
del gym.registry["TestingEnv-v0"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_pickling_env_stack():
|
def test_pickling_env_stack():
|
||||||
env = gym.make("CartPole-v1", render_mode="rgb_array")
|
env = gym.make("CartPole-v1", render_mode="rgb_array")
|
||||||
|
|
||||||
@@ -163,6 +124,8 @@ def test_pickling_env_stack():
|
|||||||
|
|
||||||
def test_env_spec_pprint():
|
def test_env_spec_pprint():
|
||||||
env = gym.make("CartPole-v1")
|
env = gym.make("CartPole-v1")
|
||||||
|
env = gym.wrappers.TimeAwareObservation(env)
|
||||||
|
|
||||||
env_spec = env.spec
|
env_spec = env.spec
|
||||||
assert env_spec is not None
|
assert env_spec is not None
|
||||||
|
|
||||||
@@ -172,10 +135,8 @@ def test_env_spec_pprint():
|
|||||||
== """id=CartPole-v1
|
== """id=CartPole-v1
|
||||||
reward_threshold=475.0
|
reward_threshold=475.0
|
||||||
max_episode_steps=500
|
max_episode_steps=500
|
||||||
applied_wrappers=[
|
additional_wrappers=[
|
||||||
name=PassiveEnvChecker, kwargs={},
|
name=TimeAwareObservation, kwargs={}
|
||||||
name=OrderEnforcing, kwargs={'disable_render_order_enforcing': False},
|
|
||||||
name=TimeLimit, kwargs={'max_episode_steps': 500}
|
|
||||||
]"""
|
]"""
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -186,10 +147,8 @@ applied_wrappers=[
|
|||||||
entry_point=gymnasium.envs.classic_control.cartpole:CartPoleEnv
|
entry_point=gymnasium.envs.classic_control.cartpole:CartPoleEnv
|
||||||
reward_threshold=475.0
|
reward_threshold=475.0
|
||||||
max_episode_steps=500
|
max_episode_steps=500
|
||||||
applied_wrappers=[
|
additional_wrappers=[
|
||||||
name=PassiveEnvChecker, entry_point=gymnasium.wrappers.env_checker:PassiveEnvChecker, kwargs={},
|
name=TimeAwareObservation, entry_point=gymnasium.wrappers.time_aware_observation:TimeAwareObservation, kwargs={}
|
||||||
name=OrderEnforcing, entry_point=gymnasium.wrappers.order_enforcing:OrderEnforcing, kwargs={'disable_render_order_enforcing': False},
|
|
||||||
name=TimeLimit, entry_point=gymnasium.wrappers.time_limit:TimeLimit, kwargs={'max_episode_steps': 500}
|
|
||||||
]"""
|
]"""
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -205,14 +164,12 @@ order_enforce=True
|
|||||||
autoreset=False
|
autoreset=False
|
||||||
disable_env_checker=False
|
disable_env_checker=False
|
||||||
applied_api_compatibility=False
|
applied_api_compatibility=False
|
||||||
applied_wrappers=[
|
additional_wrappers=[
|
||||||
name=PassiveEnvChecker, kwargs={},
|
name=TimeAwareObservation, kwargs={}
|
||||||
name=OrderEnforcing, kwargs={'disable_render_order_enforcing': False},
|
|
||||||
name=TimeLimit, kwargs={'max_episode_steps': 500}
|
|
||||||
]"""
|
]"""
|
||||||
)
|
)
|
||||||
|
|
||||||
env_spec.applied_wrappers = ()
|
env_spec.additional_wrappers = ()
|
||||||
output = env_spec.pprint(disable_print=True)
|
output = env_spec.pprint(disable_print=True)
|
||||||
assert (
|
assert (
|
||||||
output
|
output
|
||||||
@@ -233,5 +190,5 @@ order_enforce=True
|
|||||||
autoreset=False
|
autoreset=False
|
||||||
disable_env_checker=False
|
disable_env_checker=False
|
||||||
applied_api_compatibility=False
|
applied_api_compatibility=False
|
||||||
applied_wrappers=[]"""
|
additional_wrappers=[]"""
|
||||||
)
|
)
|
||||||
|
@@ -1,4 +1,4 @@
|
|||||||
"""Tests that gym.make works as expected."""
|
"""Tests that `gym.make` works as expected."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import re
|
import re
|
||||||
@@ -11,6 +11,8 @@ import gymnasium as gym
|
|||||||
from gymnasium import Env
|
from gymnasium import Env
|
||||||
from gymnasium.core import ActType, ObsType, WrapperObsType
|
from gymnasium.core import ActType, ObsType, WrapperObsType
|
||||||
from gymnasium.envs.classic_control import CartPoleEnv
|
from gymnasium.envs.classic_control import CartPoleEnv
|
||||||
|
from gymnasium.error import NameNotFound
|
||||||
|
from gymnasium.utils.env_checker import data_equivalence
|
||||||
from gymnasium.wrappers import (
|
from gymnasium.wrappers import (
|
||||||
AutoResetWrapper,
|
AutoResetWrapper,
|
||||||
HumanRendering,
|
HumanRendering,
|
||||||
@@ -24,111 +26,101 @@ from tests.testing_env import GenericTestEnv, old_step_func
|
|||||||
from tests.wrappers.utils import has_wrapper
|
from tests.wrappers.utils import has_wrapper
|
||||||
|
|
||||||
|
|
||||||
try:
|
# Tests
|
||||||
import shimmy
|
# * basic example
|
||||||
except ImportError:
|
# * parameters (equivalent for str and EnvSpec)
|
||||||
shimmy = None
|
# 1. max_episode_steps
|
||||||
|
# 2. autoreset
|
||||||
|
# 3. apply_api_compatibility
|
||||||
|
# 4. disable_env_checker
|
||||||
|
# * rendering
|
||||||
|
# 1. render_mode
|
||||||
|
# 2. HumanRendering
|
||||||
|
# 3. RenderCollection
|
||||||
|
# * make kwargs
|
||||||
|
# * make import module
|
||||||
|
# * make env spec additional wrappers
|
||||||
|
# * env_id str errors
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
def test_no_arguments(env_id: str = "CartPole-v1"):
|
||||||
def register_testing_envs():
|
"""Test `gym.make` using str and EnvSpec with no arguments."""
|
||||||
"""Registers testing envs for `gym.make`"""
|
env_from_id = gym.make(env_id)
|
||||||
gym.register(
|
assert env_from_id.spec is not None
|
||||||
id="test.ArgumentEnv-v0",
|
assert env_from_id.spec.id == env_id
|
||||||
entry_point="tests.envs.registration.utils_envs:ArgumentEnv",
|
assert isinstance(env_from_id.unwrapped, CartPoleEnv)
|
||||||
kwargs={
|
|
||||||
"arg1": "arg1",
|
|
||||||
"arg2": "arg2",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
gym.register(
|
env_spec = gym.spec(env_id)
|
||||||
id="test/NoHuman-v0",
|
env_from_spec = gym.make(env_spec)
|
||||||
entry_point="tests.envs.registration.utils_envs:NoHuman",
|
assert env_from_spec.spec is not None
|
||||||
)
|
assert env_from_spec.spec.id == env_id
|
||||||
gym.register(
|
assert isinstance(env_from_spec.unwrapped, CartPoleEnv)
|
||||||
id="test/NoHumanOldAPI-v0",
|
|
||||||
entry_point="tests.envs.registration.utils_envs:NoHumanOldAPI",
|
|
||||||
)
|
|
||||||
|
|
||||||
gym.register(
|
assert env_from_id.spec == env_from_spec.spec
|
||||||
id="test/NoHumanNoRGB-v0",
|
|
||||||
entry_point="tests.envs.registration.utils_envs:NoHumanNoRGB",
|
|
||||||
)
|
|
||||||
|
|
||||||
gym.register(
|
|
||||||
id="test/NoRenderModesMetadata-v0",
|
|
||||||
entry_point="tests.envs.registration.utils_envs:NoRenderModesMetadata",
|
|
||||||
)
|
|
||||||
|
|
||||||
yield
|
|
||||||
|
|
||||||
del gym.envs.registration.registry["test.ArgumentEnv-v0"]
|
|
||||||
del gym.envs.registration.registry["test/NoRenderModesMetadata-v0"]
|
|
||||||
del gym.envs.registration.registry["test/NoHuman-v0"]
|
|
||||||
del gym.envs.registration.registry["test/NoHumanOldAPI-v0"]
|
|
||||||
del gym.envs.registration.registry["test/NoHumanNoRGB-v0"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_make():
|
def test_max_episode_steps(register_parameter_envs):
|
||||||
"""Test basic `gym.make`."""
|
"""Test the `max_episode_steps` parameter in `gym.make`."""
|
||||||
env = gym.make("CartPole-v1")
|
for make_id in ["CartPole-v1", gym.spec("CartPole-v1")]:
|
||||||
assert env.spec is not None
|
env_spec = gym.spec(make_id) if isinstance(make_id, str) else make_id
|
||||||
assert env.spec.id == "CartPole-v1"
|
|
||||||
assert isinstance(env.unwrapped, CartPoleEnv)
|
# Use the spec's value
|
||||||
env.close()
|
env = gym.make(make_id)
|
||||||
|
assert has_wrapper(env, TimeLimit)
|
||||||
|
assert env.spec is not None
|
||||||
|
assert env.spec.max_episode_steps == env_spec.max_episode_steps
|
||||||
|
|
||||||
|
# Set a custom max episode steps value
|
||||||
|
assert env_spec.max_episode_steps != 100
|
||||||
|
env = gym.make(make_id, max_episode_steps=100)
|
||||||
|
assert has_wrapper(env, TimeLimit)
|
||||||
|
assert env.spec is not None
|
||||||
|
assert env.spec.max_episode_steps == 100, make_id
|
||||||
|
|
||||||
|
for make_id in ["NoMaxEpisodeStepsEnv-v0", gym.spec("NoMaxEpisodeStepsEnv-v0")]:
|
||||||
|
env_spec = gym.spec(make_id) if isinstance(make_id, str) else make_id
|
||||||
|
|
||||||
|
# env spec has no max episode steps
|
||||||
|
assert env_spec.max_episode_steps is None
|
||||||
|
env = gym.make(make_id)
|
||||||
|
assert env.spec is not None
|
||||||
|
assert env.spec.max_episode_steps is None
|
||||||
|
assert has_wrapper(env, TimeLimit) is False
|
||||||
|
|
||||||
|
# set a custom max episode steps values
|
||||||
|
env = gym.make(make_id, max_episode_steps=100)
|
||||||
|
assert env.spec is not None
|
||||||
|
assert env.spec.max_episode_steps == 100
|
||||||
|
assert has_wrapper(env, TimeLimit)
|
||||||
|
|
||||||
|
|
||||||
def test_make_deprecated():
|
def test_autorest(register_parameter_envs):
|
||||||
"""Test make with a deprecated environment (i.e., doesn't exist)."""
|
"""Test the `autoreset` parameter in `gym.make`."""
|
||||||
with warnings.catch_warnings(record=True):
|
for make_id in [
|
||||||
with pytest.raises(
|
"CartPole-v1",
|
||||||
gym.error.Error,
|
gym.spec("CartPole-v1"),
|
||||||
match=re.escape(
|
"AutoresetEnv-v0",
|
||||||
"Environment version v0 for `Humanoid` is deprecated. Please use `Humanoid-v4` instead."
|
gym.spec("AutoresetEnv-v0"),
|
||||||
),
|
]:
|
||||||
):
|
env_spec = gym.spec(make_id) if isinstance(make_id, str) else make_id
|
||||||
gym.make("Humanoid-v0")
|
|
||||||
|
|
||||||
|
# Use the spec's value
|
||||||
|
env = gym.make(make_id)
|
||||||
|
assert env.spec is not None
|
||||||
|
assert env.spec.autoreset == env_spec.autoreset
|
||||||
|
assert has_wrapper(env, AutoResetWrapper) is env_spec.autoreset
|
||||||
|
|
||||||
def test_make_max_episode_steps(register_testing_envs):
|
# Set autoreset is True
|
||||||
# Default, uses the spec's
|
env = gym.make(make_id, autoreset=True)
|
||||||
env = gym.make("CartPole-v1")
|
assert has_wrapper(env, AutoResetWrapper)
|
||||||
assert has_wrapper(env, TimeLimit)
|
assert env.spec is not None
|
||||||
assert env.spec is not None
|
assert env.spec.autoreset is True
|
||||||
assert env.spec.max_episode_steps == gym.spec("CartPole-v1").max_episode_steps
|
|
||||||
env.close()
|
|
||||||
|
|
||||||
# Custom max episode steps
|
# Set autoreset is False
|
||||||
assert gym.spec("CartPole-v1").max_episode_steps != 100
|
env = gym.make(make_id, autoreset=False)
|
||||||
env = gym.make("CartPole-v1", max_episode_steps=100)
|
assert has_wrapper(env, AutoResetWrapper) is False
|
||||||
assert has_wrapper(env, TimeLimit)
|
assert env.spec is not None
|
||||||
assert env.spec is not None
|
assert env.spec.autoreset is False
|
||||||
assert env.spec.max_episode_steps == 100
|
|
||||||
env.close()
|
|
||||||
|
|
||||||
# Env spec has no max episode steps
|
|
||||||
assert gym.spec("test.ArgumentEnv-v0").max_episode_steps is None
|
|
||||||
env = gym.make("test.ArgumentEnv-v0", arg1=None, arg2=None, arg3=None)
|
|
||||||
assert env.spec is not None
|
|
||||||
assert env.spec.max_episode_steps is None
|
|
||||||
assert has_wrapper(env, TimeLimit) is False
|
|
||||||
env.close()
|
|
||||||
|
|
||||||
|
|
||||||
def test_make_autoreset():
|
|
||||||
"""Tests that `gym.make` autoreset wrapper is applied only when `gym.make(..., autoreset=True)`."""
|
|
||||||
env = gym.make("CartPole-v1")
|
|
||||||
assert has_wrapper(env, AutoResetWrapper) is False
|
|
||||||
env.close()
|
|
||||||
|
|
||||||
env = gym.make("CartPole-v1", autoreset=False)
|
|
||||||
assert has_wrapper(env, AutoResetWrapper) is False
|
|
||||||
env.close()
|
|
||||||
|
|
||||||
env = gym.make("CartPole-v1", autoreset=True)
|
|
||||||
assert has_wrapper(env, AutoResetWrapper)
|
|
||||||
env.close()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -142,7 +134,7 @@ def test_make_autoreset():
|
|||||||
[True, None, True],
|
[True, None, True],
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_make_disable_env_checker(
|
def test_disable_env_checker(
|
||||||
registration_disabled: bool, make_disabled: bool | None, if_disabled: bool
|
registration_disabled: bool, make_disabled: bool | None, if_disabled: bool
|
||||||
):
|
):
|
||||||
"""Tests that `gym.make` disable env checker is applied only when `gym.make(..., disable_env_checker=False)`.
|
"""Tests that `gym.make` disable env checker is applied only when `gym.make(..., disable_env_checker=False)`.
|
||||||
@@ -150,84 +142,76 @@ def test_make_disable_env_checker(
|
|||||||
The ordering is 1. if the `make(..., disable_env_checker=...)` is bool, then the `registration(..., disable_env_checker=...)`
|
The ordering is 1. if the `make(..., disable_env_checker=...)` is bool, then the `registration(..., disable_env_checker=...)`
|
||||||
"""
|
"""
|
||||||
gym.register(
|
gym.register(
|
||||||
"testing-env-v0",
|
"DisableEnvCheckerEnv-v0",
|
||||||
lambda: GenericTestEnv(),
|
lambda: GenericTestEnv(),
|
||||||
disable_env_checker=registration_disabled,
|
disable_env_checker=registration_disabled,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Test when the registered EnvSpec.disable_env_checker = False
|
# Test when the registered EnvSpec.disable_env_checker = False
|
||||||
env = gym.make("testing-env-v0", disable_env_checker=make_disabled)
|
env = gym.make("DisableEnvCheckerEnv-v0", disable_env_checker=make_disabled)
|
||||||
assert has_wrapper(env, PassiveEnvChecker) is not if_disabled
|
assert has_wrapper(env, PassiveEnvChecker) is not if_disabled
|
||||||
env.close()
|
|
||||||
|
|
||||||
del gym.registry["testing-env-v0"]
|
env_spec = gym.spec("DisableEnvCheckerEnv-v0")
|
||||||
|
env = gym.make(env_spec, disable_env_checker=make_disabled)
|
||||||
|
assert has_wrapper(env, PassiveEnvChecker) is not if_disabled
|
||||||
|
|
||||||
|
del gym.registry["DisableEnvCheckerEnv-v0"]
|
||||||
|
|
||||||
|
|
||||||
def test_make_apply_api_compatibility():
|
def test_apply_api_compatibility(register_parameter_envs):
|
||||||
"""Test the API compatibility wrapper."""
|
"""Test the `apply_api_compatibility` parameter for `gym.make`."""
|
||||||
gym.register(
|
|
||||||
"testing-old-env",
|
|
||||||
lambda: GenericTestEnv(step_func=old_step_func),
|
|
||||||
apply_api_compatibility=True,
|
|
||||||
max_episode_steps=3,
|
|
||||||
)
|
|
||||||
# Apply the environment compatibility and check it works as intended
|
# Apply the environment compatibility and check it works as intended
|
||||||
env = gym.make("testing-old-env")
|
for make_id in ["EnabledApplyApiComp-v0", gym.spec("EnabledApplyApiComp-v0")]:
|
||||||
assert isinstance(env.unwrapped, gym.wrappers.EnvCompatibility)
|
env = gym.make(make_id)
|
||||||
|
assert isinstance(env.unwrapped, gym.wrappers.EnvCompatibility)
|
||||||
|
|
||||||
env.reset()
|
# env has time limit of 3 enabling this test
|
||||||
assert len(env.step(env.action_space.sample())) == 5
|
env.reset()
|
||||||
env.step(env.action_space.sample())
|
assert len(env.step(env.action_space.sample())) == 5
|
||||||
_, _, termination, truncation, _ = env.step(env.action_space.sample())
|
|
||||||
assert termination is False and truncation is True
|
|
||||||
|
|
||||||
# Turn off the spec api compatibility
|
|
||||||
gym.spec("testing-old-env").apply_api_compatibility = False
|
|
||||||
env = gym.make("testing-old-env")
|
|
||||||
assert isinstance(env.unwrapped, gym.wrappers.EnvCompatibility) is False
|
|
||||||
env.reset()
|
|
||||||
with pytest.raises(
|
|
||||||
ValueError, match=re.escape("not enough values to unpack (expected 5, got 4)")
|
|
||||||
):
|
|
||||||
env.step(env.action_space.sample())
|
env.step(env.action_space.sample())
|
||||||
|
_, _, termination, truncation, _ = env.step(env.action_space.sample())
|
||||||
|
assert termination is False and truncation is True
|
||||||
|
|
||||||
# Apply the environment compatibility and check it works as intended
|
for make_id in ["DisabledApplyApiComp-v0", gym.spec("DisabledApplyApiComp-v0")]:
|
||||||
env = gym.make("testing-old-env", apply_api_compatibility=True)
|
# Turn off the spec api compatibility
|
||||||
assert isinstance(env.unwrapped, gym.wrappers.EnvCompatibility)
|
env = gym.make(make_id)
|
||||||
|
assert isinstance(env.unwrapped, gym.wrappers.EnvCompatibility) is False
|
||||||
|
env.reset()
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match=re.escape("not enough values to unpack (expected 5, got 4)"),
|
||||||
|
):
|
||||||
|
env.step(env.action_space.sample())
|
||||||
|
|
||||||
env.reset()
|
# Apply the environment compatibility and check it works as intended
|
||||||
assert len(env.step(env.action_space.sample())) == 5
|
assert env.spec is not None
|
||||||
env.step(env.action_space.sample())
|
assert env.spec.apply_api_compatibility is False
|
||||||
_, _, termination, truncation, _ = env.step(env.action_space.sample())
|
env = gym.make(make_id, apply_api_compatibility=True)
|
||||||
assert termination is False and truncation is True
|
assert isinstance(env.unwrapped, gym.wrappers.EnvCompatibility)
|
||||||
|
|
||||||
del gym.registry["testing-old-env"]
|
env.reset()
|
||||||
|
assert len(env.step(env.action_space.sample())) == 5
|
||||||
|
env.step(env.action_space.sample())
|
||||||
|
_, _, termination, truncation, _ = env.step(env.action_space.sample())
|
||||||
|
assert termination is False and truncation is True
|
||||||
|
|
||||||
|
|
||||||
def test_make_order_enforcing():
|
def test_order_enforcing(register_parameter_envs):
|
||||||
"""Checks that gym.make wrappers the environment with the OrderEnforcing wrapper."""
|
"""Checks that gym.make wrappers the environment with the OrderEnforcing wrapper."""
|
||||||
assert all(spec.order_enforce is True for spec in all_testing_env_specs)
|
assert all(spec.order_enforce is True for spec in all_testing_env_specs)
|
||||||
|
|
||||||
env = gym.make("CartPole-v1")
|
for make_id in ["CartPole-v1", gym.spec("CartPole-v1")]:
|
||||||
assert has_wrapper(env, OrderEnforcing)
|
env = gym.make(make_id)
|
||||||
# We can assume that there all other specs will also have the order enforcing
|
assert has_wrapper(env, OrderEnforcing)
|
||||||
env.close()
|
|
||||||
|
|
||||||
gym.register(
|
for make_id in ["OrderlessEnv-v0", gym.spec("OrderlessEnv-v0")]:
|
||||||
id="test.OrderlessArgumentEnv-v0",
|
env = gym.make(make_id)
|
||||||
entry_point="tests.envs.registration.utils_envs:ArgumentEnv",
|
assert has_wrapper(env, OrderEnforcing) is False
|
||||||
order_enforce=False,
|
|
||||||
kwargs={"arg1": None, "arg2": None, "arg3": None},
|
|
||||||
)
|
|
||||||
|
|
||||||
env = gym.make("test.OrderlessArgumentEnv-v0")
|
|
||||||
assert has_wrapper(env, OrderEnforcing) is False
|
|
||||||
env.close()
|
|
||||||
|
|
||||||
# There is no `make(..., order_enforcing=...)` so we don't test that
|
# There is no `make(..., order_enforcing=...)` so we don't test that
|
||||||
|
|
||||||
|
|
||||||
def test_make_render_mode():
|
def test_make_with_render_mode():
|
||||||
"""Test the `make(..., render_mode=...)`, in particular, if to apply the `RenderCollection` or the `HumanRendering`."""
|
"""Test the `make(..., render_mode=...)`, in particular, if to apply the `RenderCollection` or the `HumanRendering`."""
|
||||||
env = gym.make("CartPole-v1", render_mode=None)
|
env = gym.make("CartPole-v1", render_mode=None)
|
||||||
assert env.render_mode is None
|
assert env.render_mode is None
|
||||||
@@ -267,10 +251,12 @@ def test_make_render_collection():
|
|||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
|
|
||||||
def test_make_human_rendering(register_testing_envs):
|
def test_make_human_rendering(register_rendering_testing_envs):
|
||||||
# Make sure that native rendering is used when possible
|
# Make sure that native rendering is used when possible
|
||||||
env = gym.make("CartPole-v1", render_mode="human")
|
env = gym.make("CartPole-v1", render_mode="human")
|
||||||
assert not has_wrapper(env, HumanRendering) # Should use native human-rendering
|
assert (
|
||||||
|
has_wrapper(env, HumanRendering) is False
|
||||||
|
) # Should use native human-rendering
|
||||||
assert env.render_mode == "human"
|
assert env.render_mode == "human"
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
@@ -281,7 +267,7 @@ def test_make_human_rendering(register_testing_envs):
|
|||||||
),
|
),
|
||||||
):
|
):
|
||||||
# Make sure that `HumanRendering` is applied here as the environment doesn't use native rendering
|
# Make sure that `HumanRendering` is applied here as the environment doesn't use native rendering
|
||||||
env = gym.make("test/NoHuman-v0", render_mode="human")
|
env = gym.make("NoHumanRendering-v0", render_mode="human")
|
||||||
assert has_wrapper(env, HumanRendering)
|
assert has_wrapper(env, HumanRendering)
|
||||||
assert env.render_mode == "human"
|
assert env.render_mode == "human"
|
||||||
env.close()
|
env.close()
|
||||||
@@ -290,7 +276,7 @@ def test_make_human_rendering(register_testing_envs):
|
|||||||
TypeError, match=re.escape("got an unexpected keyword argument 'render_mode'")
|
TypeError, match=re.escape("got an unexpected keyword argument 'render_mode'")
|
||||||
):
|
):
|
||||||
gym.make(
|
gym.make(
|
||||||
"test/NoHumanOldAPI-v0",
|
"NoHumanRenderingOldAPI-v0",
|
||||||
render_mode="rgb_array_list",
|
render_mode="rgb_array_list",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -299,10 +285,10 @@ def test_make_human_rendering(register_testing_envs):
|
|||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
gym.error.Error,
|
gym.error.Error,
|
||||||
match=re.escape(
|
match=re.escape(
|
||||||
"You passed render_mode='human' although test/NoHumanOldAPI-v0 doesn't implement human-rendering natively."
|
"You passed render_mode='human' although NoHumanRenderingOldAPI-v0 doesn't implement human-rendering natively."
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
gym.make("test/NoHumanOldAPI-v0", render_mode="human")
|
gym.make("NoHumanRenderingOldAPI-v0", render_mode="human")
|
||||||
|
|
||||||
# This test ensures that the additional exception "Gym tried to apply the HumanRendering wrapper but it looks like
|
# This test ensures that the additional exception "Gym tried to apply the HumanRendering wrapper but it looks like
|
||||||
# your environment is using the old rendering API" is *not* triggered by a TypeError that originate from
|
# your environment is using the old rendering API" is *not* triggered by a TypeError that originate from
|
||||||
@@ -320,10 +306,10 @@ def test_make_human_rendering(register_testing_envs):
|
|||||||
"\x1b[33mWARN: The environment is being initialised with render_mode='rgb_array' that is not in the possible render_modes ([]).\x1b[0m"
|
"\x1b[33mWARN: The environment is being initialised with render_mode='rgb_array' that is not in the possible render_modes ([]).\x1b[0m"
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
gym.make("test/NoRenderModesMetadata-v0", render_mode="rgb_array")
|
gym.make("NoRenderModesMetadata-v0", render_mode="rgb_array")
|
||||||
|
|
||||||
|
|
||||||
def test_make_kwargs(register_testing_envs):
|
def test_make_kwargs(register_kwargs_env):
|
||||||
env = gym.make(
|
env = gym.make(
|
||||||
"test.ArgumentEnv-v0",
|
"test.ArgumentEnv-v0",
|
||||||
arg2="override_arg2",
|
arg2="override_arg2",
|
||||||
@@ -367,12 +353,11 @@ class NoRecordArgsWrapper(gym.ObservationWrapper):
|
|||||||
return self.observation_space.sample()
|
return self.observation_space.sample()
|
||||||
|
|
||||||
|
|
||||||
def test_make_env_spec():
|
def test_make_with_env_spec():
|
||||||
# make
|
# make
|
||||||
env_1 = gym.make(gym.spec("CartPole-v1"))
|
id_env = gym.make("CartPole-v1")
|
||||||
assert isinstance(env_1, CartPoleEnv)
|
spec_env = gym.make(gym.spec("CartPole-v1"))
|
||||||
assert env_1 is env_1.unwrapped
|
assert id_env.spec == spec_env.spec
|
||||||
env_1.close()
|
|
||||||
|
|
||||||
# make with applied wrappers
|
# make with applied wrappers
|
||||||
env_2 = gym.wrappers.NormalizeReward(
|
env_2 = gym.wrappers.NormalizeReward(
|
||||||
@@ -396,30 +381,202 @@ def test_make_env_spec():
|
|||||||
|
|
||||||
# make with wrapper in env-creator
|
# make with wrapper in env-creator
|
||||||
gym.register(
|
gym.register(
|
||||||
"CartPole-v3", lambda: gym.wrappers.TimeAwareObservation(CartPoleEnv())
|
"CartPole-v3",
|
||||||
|
lambda: gym.wrappers.TimeAwareObservation(CartPoleEnv()),
|
||||||
|
disable_env_checker=True,
|
||||||
|
order_enforce=False,
|
||||||
)
|
)
|
||||||
env_4 = gym.make(gym.spec("CartPole-v3"))
|
env_4 = gym.make(gym.spec("CartPole-v3"))
|
||||||
assert isinstance(env_4, gym.wrappers.TimeAwareObservation)
|
assert isinstance(env_4, gym.wrappers.TimeAwareObservation)
|
||||||
assert isinstance(env_4.env, CartPoleEnv)
|
assert isinstance(env_4.env, CartPoleEnv)
|
||||||
env_4.close()
|
env_4.close()
|
||||||
|
|
||||||
# make with no ezpickle wrapper
|
gym.register(
|
||||||
env_5 = NoRecordArgsWrapper(gym.make("CartPole-v1").unwrapped)
|
"CartPole-v4",
|
||||||
|
lambda: CartPoleEnv(),
|
||||||
|
disable_env_checker=True,
|
||||||
|
order_enforce=False,
|
||||||
|
additional_wrappers=(gym.wrappers.TimeAwareObservation.wrapper_spec(),),
|
||||||
|
)
|
||||||
|
env_5 = gym.make(gym.spec("CartPole-v4"))
|
||||||
|
assert isinstance(env_5, gym.wrappers.TimeAwareObservation)
|
||||||
|
assert isinstance(env_5.env, CartPoleEnv)
|
||||||
env_5.close()
|
env_5.close()
|
||||||
|
|
||||||
|
# make with no ezpickle wrapper
|
||||||
|
env_6 = NoRecordArgsWrapper(gym.make("CartPole-v1"))
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
ValueError,
|
ValueError,
|
||||||
match=re.escape(
|
match=re.escape(
|
||||||
"NoRecordArgsWrapper wrapper does not inherit from `gymnasium.utils.RecordConstructorArgs`, therefore, the wrapper cannot be recreated."
|
"NoRecordArgsWrapper wrapper does not inherit from `gymnasium.utils.RecordConstructorArgs`, therefore, the wrapper cannot be recreated."
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
gym.make(env_5.spec)
|
gym.make(env_6.spec)
|
||||||
|
|
||||||
# make with no ezpickle wrapper but in the entry point
|
# make with no ezpickle wrapper but in the entry point
|
||||||
gym.register("CartPole-v4", entry_point=lambda: NoRecordArgsWrapper(CartPoleEnv()))
|
gym.register(
|
||||||
env_6 = gym.make(gym.spec("CartPole-v4"))
|
"CartPole-v5",
|
||||||
assert isinstance(env_6, NoRecordArgsWrapper)
|
entry_point=lambda: NoRecordArgsWrapper(CartPoleEnv()),
|
||||||
assert isinstance(env_6.unwrapped, CartPoleEnv)
|
disable_env_checker=True,
|
||||||
|
order_enforce=False,
|
||||||
|
)
|
||||||
|
env_7 = gym.make(gym.spec("CartPole-v5"))
|
||||||
|
assert isinstance(env_7, NoRecordArgsWrapper)
|
||||||
|
assert isinstance(env_7.unwrapped, CartPoleEnv)
|
||||||
|
|
||||||
|
gym.register(
|
||||||
|
"CartPole-v6",
|
||||||
|
entry_point=lambda: CartPoleEnv(),
|
||||||
|
disable_env_checker=True,
|
||||||
|
order_enforce=False,
|
||||||
|
additional_wrappers=(NoRecordArgsWrapper.wrapper_spec(),),
|
||||||
|
)
|
||||||
|
|
||||||
del gym.registry["CartPole-v2"]
|
del gym.registry["CartPole-v2"]
|
||||||
del gym.registry["CartPole-v3"]
|
del gym.registry["CartPole-v3"]
|
||||||
del gym.registry["CartPole-v4"]
|
del gym.registry["CartPole-v4"]
|
||||||
|
del gym.registry["CartPole-v5"]
|
||||||
|
del gym.registry["CartPole-v6"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_make_with_env_spec_levels():
|
||||||
|
"""Test that we can recreate the environment at each 'level'."""
|
||||||
|
env = gym.wrappers.NormalizeReward(
|
||||||
|
gym.wrappers.TimeAwareObservation(
|
||||||
|
gym.wrappers.FlattenObservation(
|
||||||
|
gym.make("CartPole-v1", render_mode="rgb_array")
|
||||||
|
)
|
||||||
|
),
|
||||||
|
gamma=0.8,
|
||||||
|
)
|
||||||
|
|
||||||
|
while env is not env.unwrapped:
|
||||||
|
recreated_env = gym.make(env.spec)
|
||||||
|
assert env.spec == recreated_env.spec
|
||||||
|
|
||||||
|
env = env.env
|
||||||
|
|
||||||
|
|
||||||
|
def test_wrapped_env_entry_point():
|
||||||
|
def _create_env():
|
||||||
|
_env = gym.make("CartPole-v1", render_mode="rgb_array")
|
||||||
|
_env = gym.wrappers.FlattenObservation(_env)
|
||||||
|
return _env
|
||||||
|
|
||||||
|
gym.register("TestingEnv-v0", entry_point=_create_env)
|
||||||
|
|
||||||
|
env = gym.make("TestingEnv-v0")
|
||||||
|
env = gym.wrappers.TimeAwareObservation(env)
|
||||||
|
env = gym.wrappers.NormalizeReward(env, gamma=0.8)
|
||||||
|
|
||||||
|
recreated_env = gym.make(env.spec)
|
||||||
|
|
||||||
|
obs, info = env.reset(seed=42)
|
||||||
|
recreated_obs, recreated_info = recreated_env.reset(seed=42)
|
||||||
|
assert data_equivalence(obs, recreated_obs)
|
||||||
|
assert data_equivalence(info, recreated_info)
|
||||||
|
|
||||||
|
action = env.action_space.sample()
|
||||||
|
obs, reward, terminated, truncated, info = env.step(action)
|
||||||
|
(
|
||||||
|
recreated_obs,
|
||||||
|
recreated_reward,
|
||||||
|
recreated_terminated,
|
||||||
|
recreated_truncated,
|
||||||
|
recreated_info,
|
||||||
|
) = recreated_env.step(action)
|
||||||
|
assert data_equivalence(obs, recreated_obs)
|
||||||
|
assert data_equivalence(reward, recreated_reward)
|
||||||
|
assert data_equivalence(terminated, recreated_terminated)
|
||||||
|
assert data_equivalence(truncated, recreated_truncated)
|
||||||
|
assert data_equivalence(info, recreated_info)
|
||||||
|
|
||||||
|
del gym.registry["TestingEnv-v0"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_make_errors():
|
||||||
|
"""Test make with a deprecated environment (i.e., doesn't exist)."""
|
||||||
|
with warnings.catch_warnings(record=True):
|
||||||
|
with pytest.raises(
|
||||||
|
gym.error.Error,
|
||||||
|
match=re.escape(
|
||||||
|
"Environment version v0 for `Humanoid` is deprecated. Please use `Humanoid-v4` instead."
|
||||||
|
),
|
||||||
|
):
|
||||||
|
gym.make("Humanoid-v0")
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
NameNotFound, match=re.escape("Environment `NonExistenceEnv` doesn't exist.")
|
||||||
|
):
|
||||||
|
gym.make("NonExistenceEnv-v0")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def register_parameter_envs():
|
||||||
|
gym.register(
|
||||||
|
"NoMaxEpisodeStepsEnv-v0", lambda: GenericTestEnv(), max_episode_steps=None
|
||||||
|
)
|
||||||
|
gym.register("AutoresetEnv-v0", lambda: GenericTestEnv(), autoreset=True)
|
||||||
|
|
||||||
|
gym.register(
|
||||||
|
"EnabledApplyApiComp-v0",
|
||||||
|
lambda: GenericTestEnv(step_func=old_step_func),
|
||||||
|
apply_api_compatibility=True,
|
||||||
|
max_episode_steps=3,
|
||||||
|
)
|
||||||
|
gym.register(
|
||||||
|
"DisabledApplyApiComp-v0",
|
||||||
|
lambda: GenericTestEnv(step_func=old_step_func),
|
||||||
|
apply_api_compatibility=False,
|
||||||
|
max_episode_steps=3,
|
||||||
|
)
|
||||||
|
gym.register("OrderlessEnv-v0", lambda: GenericTestEnv(), order_enforce=False)
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
del gym.registry["NoMaxEpisodeStepsEnv-v0"]
|
||||||
|
del gym.registry["AutoresetEnv-v0"]
|
||||||
|
del gym.registry["EnabledApplyApiComp-v0"]
|
||||||
|
del gym.registry["DisabledApplyApiComp-v0"]
|
||||||
|
del gym.registry["OrderlessEnv-v0"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def register_kwargs_env():
|
||||||
|
gym.register(
|
||||||
|
id="test.ArgumentEnv-v0",
|
||||||
|
entry_point="tests.envs.registration.utils_envs:ArgumentEnv",
|
||||||
|
kwargs={
|
||||||
|
"arg1": "arg1",
|
||||||
|
"arg2": "arg2",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def register_rendering_testing_envs():
|
||||||
|
gym.register(
|
||||||
|
id="NoHumanRendering-v0",
|
||||||
|
entry_point="tests.envs.registration.utils_envs:NoHuman",
|
||||||
|
)
|
||||||
|
gym.register(
|
||||||
|
id="NoHumanRenderingOldAPI-v0",
|
||||||
|
entry_point="tests.envs.registration.utils_envs:NoHumanOldAPI",
|
||||||
|
)
|
||||||
|
|
||||||
|
gym.register(
|
||||||
|
id="NoHumanRenderingNoRGB-v0",
|
||||||
|
entry_point="tests.envs.registration.utils_envs:NoHumanNoRGB",
|
||||||
|
)
|
||||||
|
|
||||||
|
gym.register(
|
||||||
|
id="NoRenderModesMetadata-v0",
|
||||||
|
entry_point="tests.envs.registration.utils_envs:NoRenderModesMetadata",
|
||||||
|
)
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
del gym.envs.registration.registry["NoHumanRendering-v0"]
|
||||||
|
del gym.envs.registration.registry["NoHumanRenderingOldAPI-v0"]
|
||||||
|
del gym.envs.registration.registry["NoHumanRenderingNoRGB-v0"]
|
||||||
|
del gym.envs.registration.registry["NoRenderModesMetadata-v0"]
|
||||||
|
@@ -171,7 +171,7 @@ def test_compatibility_with_old_style_env():
|
|||||||
"""Test compatibility with old style environment."""
|
"""Test compatibility with old style environment."""
|
||||||
env = OldStyleEnv()
|
env = OldStyleEnv()
|
||||||
env = OrderEnforcing(env)
|
env = OrderEnforcing(env)
|
||||||
env = TimeLimit(env)
|
env = TimeLimit(env, 100)
|
||||||
obs = env.reset()
|
obs = env.reset()
|
||||||
assert obs == 0
|
assert obs == 0
|
||||||
|
|
||||||
|
@@ -7,7 +7,7 @@ from gymnasium.wrappers import TimeLimit
|
|||||||
|
|
||||||
def test_time_limit_reset_info():
|
def test_time_limit_reset_info():
|
||||||
env = gym.make("CartPole-v1", disable_env_checker=True)
|
env = gym.make("CartPole-v1", disable_env_checker=True)
|
||||||
env = TimeLimit(env)
|
env = TimeLimit(env, 100)
|
||||||
ob_space = env.observation_space
|
ob_space = env.observation_space
|
||||||
obs, info = env.reset()
|
obs, info = env.reset()
|
||||||
assert ob_space.contains(obs)
|
assert ob_space.contains(obs)
|
||||||
|
Reference in New Issue
Block a user