Update EnvSpec and make to support reproducing the "whole" environment spec including wrappers (#292)

Co-authored-by: will <will2346@live.co.uk>
Co-authored-by: Will Dudley <14932240+WillDudley@users.noreply.github.com>
Co-authored-by: Mark Towers <marktowers@Marks-MacBook-Pro.local>
This commit is contained in:
Mark Towers
2023-02-24 11:34:20 +00:00
committed by GitHub
parent 6f35e7f87f
commit 9e3200d000
47 changed files with 1178 additions and 319 deletions

View File

@@ -1,12 +1,13 @@
"""Core API for Environment, Wrapper, ActionWrapper, RewardWrapper and ObservationWrapper."""
from __future__ import annotations
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Generic, SupportsFloat, TypeVar
import numpy as np
from gymnasium import spaces
from gymnasium.utils import seeding
from gymnasium.utils import RecordConstructorArgs, seeding
if TYPE_CHECKING:
@@ -277,8 +278,32 @@ class Wrapper(
@property
def spec(self) -> EnvSpec | None:
"""Returns the :attr:`Env` :attr:`spec` attribute."""
return self.env.spec
"""Returns the :attr:`Env` :attr:`spec` attribute with the `WrapperSpec` if the wrapper inherits from `EzPickle`."""
env_spec = self.env.spec
if env_spec is not None:
from gymnasium.envs.registration import WrapperSpec
# See if the wrapper inherits from `RecordConstructorArgs` then add the kwargs otherwise use `None` for the wrapper kwargs. This will raise an error in `make`
if isinstance(self, RecordConstructorArgs):
kwargs = getattr(self, "_saved_kwargs")
if "env" in kwargs:
kwargs = deepcopy(kwargs)
kwargs.pop("env")
else:
kwargs = None
wrapper_spec = WrapperSpec(
name=self.class_name(),
entry_point=f"{self.__module__}:{type(self).__name__}",
kwargs=kwargs,
)
# to avoid reference issues we deepcopy the prior environments spec and add the new information
env_spec = deepcopy(env_spec)
env_spec.applied_wrappers += (wrapper_spec,)
return env_spec
@classmethod
def class_name(cls) -> str:
@@ -409,7 +434,7 @@ class ObservationWrapper(Wrapper[WrapperObsType, ActType, ObsType, ActType]):
def __init__(self, env: Env[ObsType, ActType]):
"""Constructor for the observation wrapper."""
super().__init__(env)
Wrapper.__init__(self, env)
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
@@ -449,7 +474,7 @@ class RewardWrapper(Wrapper[ObsType, ActType, ObsType, ActType]):
def __init__(self, env: Env[ObsType, ActType]):
"""Constructor for the Reward wrapper."""
super().__init__(env)
Wrapper.__init__(self, env)
def step(
self, action: ActType
@@ -485,7 +510,7 @@ class ActionWrapper(Wrapper[ObsType, WrapperActType, ObsType, ActType]):
def __init__(self, env: Env[ObsType, ActType]):
"""Constructor for the action wrapper."""
super().__init__(env)
Wrapper.__init__(self, env)
def step(
self, action: WrapperActType

View File

@@ -3,9 +3,11 @@ from __future__ import annotations
import contextlib
import copy
import dataclasses
import difflib
import importlib
import importlib.util
import json
import re
import sys
import traceback
@@ -17,13 +19,13 @@ from gymnasium import Env, Wrapper, error, logger
from gymnasium.experimental.vector import AsyncVectorEnv, SyncVectorEnv, VectorEnv
from gymnasium.wrappers import (
AutoResetWrapper,
EnvCompatibility,
HumanRendering,
OrderEnforcing,
PassiveEnvChecker,
RenderCollection,
TimeLimit,
)
from gymnasium.wrappers.compatibility import EnvCompatibility
from gymnasium.wrappers.env_checker import PassiveEnvChecker
if sys.version_info < (3, 10):
@@ -43,11 +45,14 @@ ENV_ID_RE = re.compile(
__all__ = [
"EnvSpec",
"registry",
"current_namespace",
"EnvSpec",
"WrapperSpec",
# Functions
"register",
"make",
"make_vec",
"spec",
"pprint_registry",
]
@@ -67,6 +72,20 @@ class VectorEnvCreator(Protocol):
...
@dataclass
class WrapperSpec:
"""A specification for recording wrapper configs.
* name: The name of the wrapper.
* entry_point: The location of the wrapper to create from.
* kwargs: Additional keyword arguments passed to the wrapper. If the wrapper doesn't inherit from EzPickle then this is ``None``
"""
name: str
entry_point: str
kwargs: dict[str, Any] | None
@dataclass
class EnvSpec:
"""A specification for creating environments with :meth:`gymnasium.make`.
@@ -80,6 +99,7 @@ class EnvSpec:
* **autoreset**: If to automatically reset the environment on episode end
* **disable_env_checker**: If to disable the environment checker wrapper in :meth:`gymnasium.make`, by default False (runs the environment checker)
* **kwargs**: Additional keyword arguments passed to the environment during initialisation
* **applied_wrappers**: A tuple of applied wrappers (WrapperSpec)
* **vector_entry_point**: The location of the vectorized environment to create from
"""
@@ -97,27 +117,152 @@ class EnvSpec:
disable_env_checker: bool = field(default=False)
apply_api_compatibility: bool = field(default=False)
# Environment arguments
kwargs: dict = field(default_factory=dict)
# post-init attributes
namespace: str | None = field(init=False)
name: str = field(init=False)
version: int | None = field(init=False)
# Environment arguments
kwargs: dict = field(default_factory=dict)
# applied wrappers
applied_wrappers: tuple[WrapperSpec, ...] = field(init=False, default_factory=tuple)
# Vectorized environment
vector_entry_point: str | None = field(default=None)
# Vectorized environment entry point
vector_entry_point: VectorEnvCreator | str | None = field(default=None)
def __post_init__(self):
"""Calls after the spec is created to extract the namespace, name and version from the id."""
# Initialize namespace, name, version
"""Calls after the spec is created to extract the namespace, name and version from the environment id."""
self.namespace, self.name, self.version = parse_env_id(self.id)
def make(self, **kwargs: Any) -> Env:
"""Calls ``make`` using the environment spec and any keyword arguments."""
# For compatibility purposes
return make(self, **kwargs)
def to_json(self) -> str:
"""Converts the environment spec into a json compatible string.
Returns:
A jsonifyied string for the environment spec
"""
env_spec_dict = dataclasses.asdict(self)
# As the namespace, name and version are initialised after `init` then we remove the attributes
env_spec_dict.pop("namespace")
env_spec_dict.pop("name")
env_spec_dict.pop("version")
# To check that the environment spec can be transformed to a json compatible type
self._check_can_jsonify(env_spec_dict)
return json.dumps(env_spec_dict)
@staticmethod
def _check_can_jsonify(env_spec: dict[str, Any]):
"""Warns the user about serialisation failing if the spec contains a callable.
Args:
env_spec: An environment or wrapper specification.
Returns: The specification with lambda functions converted to strings.
"""
spec_name = env_spec["name"] if "name" in env_spec else env_spec["id"]
for key, value in env_spec.items():
if callable(value):
ValueError(
f"Callable found in {spec_name} for {key} attribute with value={value}. Currently, Gymnasium does not support serialising callables."
)
@staticmethod
def from_json(json_env_spec: str) -> EnvSpec:
"""Converts a JSON string into a specification stack.
Args:
json_env_spec: A JSON string representing the env specification.
Returns:
An environment spec
"""
parsed_env_spec = json.loads(json_env_spec)
applied_wrapper_specs: list[WrapperSpec] = []
for wrapper_spec_json in parsed_env_spec.pop("applied_wrappers"):
try:
applied_wrapper_specs.append(WrapperSpec(**wrapper_spec_json))
except Exception as e:
raise ValueError(
f"An issue occurred when trying to make {wrapper_spec_json} a WrapperSpec"
) from e
try:
env_spec = EnvSpec(**parsed_env_spec)
env_spec.applied_wrappers = tuple(applied_wrapper_specs)
except Exception as e:
raise ValueError(
f"An issue occurred when trying to make {parsed_env_spec} an EnvSpec"
) from e
return env_spec
def pprint(
self,
disable_print: bool = False,
include_entry_points: bool = False,
print_all: bool = False,
) -> str | None:
"""Pretty prints the environment spec.
Args:
disable_print: If to disable print and return the output
include_entry_points: If to include the entry_points in the output
print_all: If to print all information, including variables with default values
Returns:
If ``disable_print is True`` a string otherwise ``None``
"""
output = f"id={self.id}"
if print_all or include_entry_points:
output += f"\nentry_point={self.entry_point}"
if print_all or self.reward_threshold is not None:
output += f"\nreward_threshold={self.reward_threshold}"
if print_all or self.nondeterministic is not False:
output += f"\nnondeterministic={self.nondeterministic}"
if print_all or self.max_episode_steps is not None:
output += f"\nmax_episode_steps={self.max_episode_steps}"
if print_all or self.order_enforce is not True:
output += f"\norder_enforce={self.order_enforce}"
if print_all or self.autoreset is not False:
output += f"\nautoreset={self.autoreset}"
if print_all or self.disable_env_checker is not False:
output += f"\ndisable_env_checker={self.disable_env_checker}"
if print_all or self.apply_api_compatibility is not False:
output += f"\napplied_api_compatibility={self.apply_api_compatibility}"
if print_all or self.applied_wrappers:
wrapper_output: list[str] = []
for wrapper_spec in self.applied_wrappers:
if include_entry_points:
wrapper_output.append(
f"\n\tname={wrapper_spec.name}, entry_point={wrapper_spec.entry_point}, kwargs={wrapper_spec.kwargs}"
)
else:
wrapper_output.append(
f"\n\tname={wrapper_spec.name}, kwargs={wrapper_spec.kwargs}"
)
if len(wrapper_output) == 0:
output += "\napplied_wrappers=[]"
else:
output += f"\napplied_wrappers=[{','.join(wrapper_output)}\n]"
if disable_print:
return output
else:
print(output)
# Global registry of environments. Meant to be accessed through `register` and `make`
registry: dict[str, EnvSpec] = {}
@@ -352,8 +497,12 @@ def _check_metadata(testing_metadata: dict[str, Any]):
)
def _find_spec(id: str) -> EnvSpec:
module, env_name = (None, id) if ":" not in id else id.split(":")
def _find_spec(env_id: str) -> EnvSpec:
# For string id's, load the environment spec from the registry then make the environment spec
assert isinstance(env_id, str)
# The environment name can include an unloaded module in "module:env_name" style
module, env_name = (None, env_id) if ":" not in env_id else env_id.split(":")
if module is not None:
try:
importlib.import_module(module)
@@ -391,7 +540,7 @@ def _find_spec(id: str) -> EnvSpec:
return env_spec
def load_env(name: str) -> EnvCreator:
def load_env_creator(name: str) -> EnvCreator | VectorEnvCreator:
"""Loads an environment with name of style ``"(import path):(environment name)"`` and returns the environment creation function, normally the environment class type.
Args:
@@ -406,6 +555,161 @@ def load_env(name: str) -> EnvCreator:
return fn
def _create_from_env_spec(
env_spec: EnvSpec,
kwargs: dict[str, Any],
) -> Env:
"""Recreates an environment spec using a list of wrapper specs."""
if callable(env_spec.entry_point):
env_creator = env_spec.entry_point
else:
env_creator: EnvCreator = load_env_creator(env_spec.entry_point)
# Create the environment
env: Env = env_creator(**env_spec.kwargs, **kwargs)
# Set the `EnvSpec` to the environment
new_env_spec = copy.deepcopy(env_spec)
new_env_spec.applied_wrappers = ()
new_env_spec.kwargs.update(kwargs)
env.unwrapped.spec = new_env_spec
# Check if the environment spec
assert env.spec is not None # this is for pyright
num_prior_wrappers = len(env.spec.applied_wrappers)
if env_spec.applied_wrappers[:num_prior_wrappers] != env.spec.applied_wrappers:
for env_spec_wrapper_spec, recreated_wrapper_spec in zip(
env_spec.applied_wrappers, env.spec.applied_wrappers
):
raise ValueError(
f"The environment's wrapper spec {recreated_wrapper_spec} is different from the saved `EnvSpec` applied_wrappers {env_spec_wrapper_spec}"
)
for wrapper_spec in env_spec.applied_wrappers[num_prior_wrappers:]:
if wrapper_spec.kwargs is None:
raise ValueError(
f"{wrapper_spec.name} wrapper does not inherit from `gymnasium.utils.RecordConstructorArgs`, therefore, the wrapper cannot be recreated."
)
env = load_env_creator(wrapper_spec.entry_point)(env=env, **wrapper_spec.kwargs)
return env
def _create_from_env_id(
env_spec: EnvSpec,
kwargs: dict[str, Any],
max_episode_steps: int | None = None,
autoreset: bool = False,
apply_api_compatibility: bool | None = None,
disable_env_checker: bool | None = None,
) -> Env:
"""Creates an environment based on the `env_spec` along with wrapper options. See `make` for their meaning."""
spec_kwargs = copy.deepcopy(env_spec.kwargs)
spec_kwargs.update(kwargs)
# Load the environment creator
if env_spec.entry_point is None:
raise error.Error(f"{env_spec.id} registered but entry_point is not specified")
elif callable(env_spec.entry_point):
env_creator = env_spec.entry_point
else:
# Assume it's a string
env_creator = load_env_creator(env_spec.entry_point)
# Determine if to use the rendering
render_modes: list[str] | None = None
if hasattr(env_creator, "metadata"):
_check_metadata(env_creator.metadata)
render_modes = env_creator.metadata.get("render_modes")
mode = spec_kwargs.get("render_mode")
apply_human_rendering = False
apply_render_collection = False
# If mode is not valid, try applying HumanRendering/RenderCollection wrappers
if mode is not None and render_modes is not None and mode not in render_modes:
displayable_modes = {"rgb_array", "rgb_array_list"}.intersection(render_modes)
if mode == "human" and len(displayable_modes) > 0:
logger.warn(
"You are trying to use 'human' rendering for an environment that doesn't natively support it. "
"The HumanRendering wrapper is being applied to your environment."
)
spec_kwargs["render_mode"] = displayable_modes.pop()
apply_human_rendering = True
elif mode.endswith("_list") and mode[: -len("_list")] in render_modes:
spec_kwargs["render_mode"] = mode[: -len("_list")]
apply_render_collection = True
else:
logger.warn(
f"The environment is being initialised with render_mode={mode!r} "
f"that is not in the possible render_modes ({render_modes})."
)
if apply_api_compatibility or (
apply_api_compatibility is None and env_spec.apply_api_compatibility
):
# If we use the compatibility layer, we treat the render mode explicitly and don't pass it to the env creator
render_mode = spec_kwargs.pop("render_mode", None)
else:
render_mode = None
try:
env = env_creator(**spec_kwargs)
except TypeError as e:
if (
str(e).find("got an unexpected keyword argument 'render_mode'") >= 0
and apply_human_rendering
):
raise error.Error(
f"You passed render_mode='human' although {env_spec.id} doesn't implement human-rendering natively. "
"Gym tried to apply the HumanRendering wrapper but it looks like your environment is using the old "
"rendering API, which is not supported by the HumanRendering wrapper."
) from e
else:
raise e
# Copies the environment creation specification and kwargs to add to the environment specification details
env_spec = copy.deepcopy(env_spec)
env_spec.kwargs = spec_kwargs
env.unwrapped.spec = env_spec
# Add step API wrapper
if apply_api_compatibility is True or (
apply_api_compatibility is None and env_spec.apply_api_compatibility is True
):
env = EnvCompatibility(env, render_mode)
# Run the environment checker as the lowest level wrapper
if disable_env_checker is False or (
disable_env_checker is None and env_spec.disable_env_checker is False
):
env = PassiveEnvChecker(env)
# Add the order enforcing wrapper
if env_spec.order_enforce:
env = OrderEnforcing(env)
# Add the time limit wrapper
if max_episode_steps is not None:
assert env.unwrapped.spec is not None # for pyright
env.unwrapped.spec.max_episode_steps = max_episode_steps
env = TimeLimit(env, max_episode_steps)
elif env_spec.max_episode_steps is not None:
env = TimeLimit(env, env_spec.max_episode_steps)
# Add the auto-reset wrapper
if autoreset:
env = AutoResetWrapper(env)
# Add human rendering wrapper
if apply_human_rendering:
env = HumanRendering(env)
elif apply_render_collection:
env = RenderCollection(env)
return env
def load_plugin_envs(entry_point: str = "gymnasium.envs"):
"""Load modules (plugins) using the gymnasium entry points in order to register external module's environments on ``import gymnasium``.
@@ -437,10 +741,8 @@ def load_plugin_envs(entry_point: str = "gymnasium.envs"):
context = namespace(plugin.name)
if plugin.name.startswith("__") and plugin.name.endswith("__"):
# `__internal__` is an artifact of the plugin system when
# the root namespace had an allow-list. The allow-list is now
# removed and plugins can register environments in the root
# namespace with the `__root__` magic key.
# `__internal__` is an artifact of the plugin system when the root namespace had an allow-list.
# The allow-list is now removed and plugins can register environments in the root namespace with the `__root__` magic key.
if plugin.name == "__root__" or plugin.name == "__internal__":
context = contextlib.nullcontext()
else:
@@ -533,9 +835,9 @@ def register(
order_enforce=order_enforce,
autoreset=autoreset,
disable_env_checker=disable_env_checker,
**kwargs,
apply_api_compatibility=apply_api_compatibility,
vector_entry_point=vector_entry_point,
**kwargs,
)
_check_spec_register(new_spec)
@@ -576,116 +878,47 @@ def make(
Error: If the ``id`` doesn't exist in the :attr:`registry`
"""
if isinstance(id, EnvSpec):
env_spec = id
if hasattr(id, "applied_wrappers") and id.applied_wrappers is not None:
if max_episode_steps is not None:
logger.warn(
f"For `gymnasium.make` with an `EnvSpec`, the `max_episode_step` parameter is ignored, use `gym.make({id.id}, max_episode_steps={max_episode_steps})` and any additional wrappers"
)
if autoreset is True:
logger.warn(
f"For `gymnasium.make` with an `EnvSpec`, the `autoreset` parameter is ignored, use `gym.make({id.id}, autoreset={autoreset})` and any additional wrappers"
)
if apply_api_compatibility is not None:
logger.warn(
f"For `gymnasium.make` with an `EnvSpec`, the `apply_api_compatibility` parameter is ignored, use `gym.make({id.id}, apply_api_compatibility={apply_api_compatibility})` and any additional wrappers"
)
if disable_env_checker is not None:
logger.warn(
f"For `gymnasium.make` with an `EnvSpec`, the `disable_env_checker` parameter is ignored, use `gym.make({id.id}, disable_env_checker={disable_env_checker})` and any additional wrappers"
)
return _create_from_env_spec(
id,
kwargs,
)
else:
raise ValueError(
f"The EnvSpec used does not contain `applied_wrappers` parameters or is `None`. Expected to be a tuple, actually {id}."
)
else:
# For string id's, load the environment spec from the registry then make the environment spec
assert isinstance(id, str)
# The environment name can include an unloaded module in "module:env_name" style
env_spec = _find_spec(id)
assert isinstance(
env_spec, EnvSpec
), f"We expected to collect an `EnvSpec`, actually collected a {type(env_spec)}"
# Extract the spec kwargs and append the make kwargs
spec_kwargs = env_spec.kwargs.copy()
spec_kwargs.update(kwargs)
# Load the environment creator
if env_spec.entry_point is None:
raise error.Error(f"{env_spec.id} registered but entry_point is not specified")
elif callable(env_spec.entry_point):
env_creator = env_spec.entry_point
else:
# Assume it's a string
env_creator = load_env(env_spec.entry_point)
# Determine if to use the rendering
render_modes: list[str] | None = None
if hasattr(env_creator, "metadata"):
_check_metadata(env_creator.metadata)
render_modes = env_creator.metadata.get("render_modes")
mode = spec_kwargs.get("render_mode")
apply_human_rendering = False
apply_render_collection = False
# If mode is not valid, try applying HumanRendering/RenderCollection wrappers
if mode is not None and render_modes is not None and mode not in render_modes:
displayable_modes = {"rgb_array", "rgb_array_list"}.intersection(render_modes)
if mode == "human" and len(displayable_modes) > 0:
logger.warn(
"You are trying to use 'human' rendering for an environment that doesn't natively support it. "
"The HumanRendering wrapper is being applied to your environment."
)
spec_kwargs["render_mode"] = displayable_modes.pop()
apply_human_rendering = True
elif mode.endswith("_list") and mode[: -len("_list")] in render_modes:
spec_kwargs["render_mode"] = mode[: -len("_list")]
apply_render_collection = True
else:
logger.warn(
f"The environment is being initialised with render_mode={mode!r} "
f"that is not in the possible render_modes ({render_modes})."
)
if apply_api_compatibility or (
apply_api_compatibility is None and env_spec.apply_api_compatibility
):
# If we use the compatibility layer, we treat the render mode explicitly and don't pass it to the env creator
render_mode = spec_kwargs.pop("render_mode", None)
else:
render_mode = None
try:
env = env_creator(**spec_kwargs)
except TypeError as e:
if (
str(e).find("got an unexpected keyword argument 'render_mode'") >= 0
and apply_human_rendering
):
raise error.Error(
f"You passed render_mode='human' although {id} doesn't implement human-rendering natively. "
"Gym tried to apply the HumanRendering wrapper but it looks like your environment is using the old "
"rendering API, which is not supported by the HumanRendering wrapper."
) from e
else:
raise e
# Copies the environment creation specification and kwargs to add to the environment specification details
env_spec = copy.deepcopy(env_spec)
env_spec.kwargs = spec_kwargs
env.unwrapped.spec = env_spec
# Add step API wrapper
if apply_api_compatibility is True or (
apply_api_compatibility is None and env_spec.apply_api_compatibility is True
):
env = EnvCompatibility(env, render_mode)
# Run the environment checker as the lowest level wrapper
if disable_env_checker is False or (
disable_env_checker is None and env_spec.disable_env_checker is False
):
env = PassiveEnvChecker(env)
# Add the order enforcing wrapper
if env_spec.order_enforce:
env = OrderEnforcing(env)
# Add the time limit wrapper
if max_episode_steps is not None:
env = TimeLimit(env, max_episode_steps)
elif env_spec.max_episode_steps is not None:
env = TimeLimit(env, env_spec.max_episode_steps)
# Add the autoreset wrapper
if autoreset:
env = AutoResetWrapper(env)
# Add human rendering wrapper
if apply_human_rendering:
env = HumanRendering(env)
elif apply_render_collection:
env = RenderCollection(env)
return env
return _create_from_env_id(
env_spec,
kwargs,
max_episode_steps=max_episode_steps,
autoreset=autoreset,
apply_api_compatibility=apply_api_compatibility,
disable_env_checker=disable_env_checker,
)
def make_vec(
@@ -752,7 +985,7 @@ def make_vec(
env_creator = entry_point
else:
# Assume it's a string
env_creator = load_env(entry_point)
env_creator = load_env_creator(entry_point)
def _create_env():
# Env creator for use with sync and async modes

View File

@@ -11,7 +11,7 @@ except ImportError:
cv2 = None
class AtariPreprocessingV0(gym.Wrapper):
class AtariPreprocessingV0(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""Atari 2600 preprocessing wrapper.
This class follows the guidelines in Machado et al. (2018),
@@ -60,7 +60,18 @@ class AtariPreprocessingV0(gym.Wrapper):
DependencyNotInstalled: opencv-python package not installed
ValueError: Disable frame-skipping in the original env
"""
super().__init__(env)
gym.utils.RecordConstructorArgs.__init__(
self,
noop_max=noop_max,
frame_skip=frame_skip,
screen_size=screen_size,
terminal_on_life_loss=terminal_on_life_loss,
grayscale_obs=grayscale_obs,
grayscale_newaxis=grayscale_newaxis,
scale_obs=scale_obs,
)
gym.Wrapper.__init__(self, env)
if cv2 is None:
raise gym.error.DependencyNotInstalled(
"opencv-python package not installed, run `pip install gymnasium[other]` to get dependencies for atari"

View File

@@ -14,7 +14,6 @@ from typing import Any, SupportsFloat
import numpy as np
import gymnasium as gym
from gymnasium import Env
from gymnasium.core import ActType, ObsType, RenderFrame
from gymnasium.error import ResetNeeded
from gymnasium.utils.passive_env_checker import (
@@ -26,7 +25,9 @@ from gymnasium.utils.passive_env_checker import (
)
class AutoresetV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
class AutoresetV0(
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
"""A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`."""
def __init__(self, env: gym.Env[ObsType, ActType]):
@@ -35,7 +36,9 @@ class AutoresetV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
Args:
env (gym.Env): The environment to apply the wrapper
"""
super().__init__(env)
gym.utils.RecordConstructorArgs.__init__(self)
gym.Wrapper.__init__(self, env)
self._episode_ended: bool = False
self._reset_options: dict[str, Any] | None = None
@@ -68,12 +71,15 @@ class AutoresetV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
return super().reset(seed=seed, options=self._reset_options)
class PassiveEnvCheckerV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
class PassiveEnvCheckerV0(
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
"""A passive environment checker wrapper that surrounds the step, reset and render functions to check they follow the gymnasium API."""
def __init__(self, env: Env[ObsType, ActType]):
def __init__(self, env: gym.Env[ObsType, ActType]):
"""Initialises the wrapper with the environments, run the observation and action space tests."""
super().__init__(env)
gym.utils.RecordConstructorArgs.__init__(self)
gym.Wrapper.__init__(self, env)
assert hasattr(
env, "action_space"
@@ -117,7 +123,9 @@ class PassiveEnvCheckerV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
return self.env.render()
class OrderEnforcingV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
class OrderEnforcingV0(
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
"""A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`.
Example:
@@ -150,7 +158,11 @@ class OrderEnforcingV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
env: The environment to wrap
disable_render_order_enforcing: If to disable render order enforcing
"""
super().__init__(env)
gym.utils.RecordConstructorArgs.__init__(
self, disable_render_order_enforcing=disable_render_order_enforcing
)
gym.Wrapper.__init__(self, env)
self._has_reset: bool = False
self._disable_render_order_enforcing: bool = disable_render_order_enforcing
@@ -182,7 +194,9 @@ class OrderEnforcingV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
return self._has_reset
class RecordEpisodeStatisticsV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
class RecordEpisodeStatisticsV0(
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
"""This wrapper will keep track of cumulative rewards and episode lengths.
At the end of an episode, the statistics of the episode will be added to ``info``
@@ -226,7 +240,7 @@ class RecordEpisodeStatisticsV0(gym.Wrapper[ObsType, ActType, ObsType, ActType])
def __init__(
self,
env: Env[ObsType, ActType],
env: gym.Env[ObsType, ActType],
buffer_length: int | None = 100,
stats_key: str = "episode",
):
@@ -237,7 +251,8 @@ class RecordEpisodeStatisticsV0(gym.Wrapper[ObsType, ActType, ObsType, ActType])
buffer_length: The size of the buffers :attr:`return_queue` and :attr:`length_queue`
stats_key: The info key for the episode statistics
"""
super().__init__(env)
gym.utils.RecordConstructorArgs.__init__(self)
gym.Wrapper.__init__(self, env)
self._stats_key = stats_key

View File

@@ -8,7 +8,7 @@ from typing import Any, Iterable, Mapping, SupportsFloat
import numpy as np
from gymnasium import Env, Wrapper
import gymnasium as gym
from gymnasium.core import ActType, ObsType, RenderFrame, WrapperActType, WrapperObsType
from gymnasium.error import DependencyNotInstalled
@@ -92,7 +92,10 @@ if jnp is not None:
return type(value)(jax_to_numpy(v) for v in value)
class JaxToNumpyV0(Wrapper[WrapperObsType, WrapperActType, ObsType, ActType]):
class JaxToNumpyV0(
gym.Wrapper[WrapperObsType, WrapperActType, ObsType, ActType],
gym.utils.RecordConstructorArgs,
):
"""Wraps a jax environment so that it can be interacted with through numpy arrays.
Actions must be provided as numpy arrays and observations will be returned as numpy arrays.
@@ -102,7 +105,7 @@ class JaxToNumpyV0(Wrapper[WrapperObsType, WrapperActType, ObsType, ActType]):
The reason for this is jax does not support non-array values, therefore numpy ``int_32(5) -> DeviceArray([5], dtype=jnp.int23)``
"""
def __init__(self, env: Env[ObsType, ActType]):
def __init__(self, env: gym.Env[ObsType, ActType]):
"""Wraps an environment such that the input and outputs are numpy arrays.
Args:
@@ -112,7 +115,8 @@ class JaxToNumpyV0(Wrapper[WrapperObsType, WrapperActType, ObsType, ActType]):
raise DependencyNotInstalled(
"jax is not installed, run `pip install gymnasium[jax]`"
)
super().__init__(env)
gym.utils.RecordConstructorArgs.__init__(self)
gym.Wrapper.__init__(self, env)
def step(
self, action: WrapperActType

View File

@@ -14,7 +14,7 @@ import numbers
from collections import abc
from typing import Any, Iterable, Mapping, SupportsFloat, Union
from gymnasium import Env, Wrapper
import gymnasium as gym
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
from gymnasium.error import DependencyNotInstalled
from gymnasium.experimental.wrappers.conversion.jax_to_numpy import jax_to_numpy
@@ -131,7 +131,7 @@ if torch is not None and jnp is not None:
return type(value)(jax_to_torch(v, device) for v in value)
class JaxToTorchV0(Wrapper):
class JaxToTorchV0(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""Wraps a jax-based environment so that it can be interacted with through PyTorch Tensors.
Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors.
@@ -140,7 +140,7 @@ class JaxToTorchV0(Wrapper):
For ``rendered`` this is returned as a NumPy array not a pytorch Tensor.
"""
def __init__(self, env: Env, device: Device | None = None):
def __init__(self, env: gym.Env, device: Device | None = None):
"""Wrapper class to change inputs and outputs of environment to PyTorch tensors.
Args:
@@ -156,7 +156,9 @@ class JaxToTorchV0(Wrapper):
"jax is not installed, run `pip install gymnasium[jax]`"
)
super().__init__(env)
gym.utils.RecordConstructorArgs.__init__(self, device=device)
gym.Wrapper.__init__(self, env)
self.device: Device | None = device
def step(

View File

@@ -8,7 +8,7 @@ from typing import Any, Iterable, Mapping, SupportsFloat, Union
import numpy as np
from gymnasium import Env, Wrapper
import gymnasium as gym
from gymnasium.core import WrapperActType, WrapperObsType
from gymnasium.error import DependencyNotInstalled
@@ -94,7 +94,7 @@ if torch is not None:
return type(value)(numpy_to_torch(v, device) for v in value)
class NumpyToTorchV0(Wrapper):
class NumpyToTorchV0(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""Wraps a numpy-based environment so that it can be interacted with through PyTorch Tensors.
Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors.
@@ -103,7 +103,7 @@ class NumpyToTorchV0(Wrapper):
For ``rendered`` this is returned as a NumPy array not a pytorch Tensor.
"""
def __init__(self, env: Env, device: Device | None = None):
def __init__(self, env: gym.Env, device: Device | None = None):
"""Wrapper class to change inputs and outputs of environment to PyTorch tensors.
Args:
@@ -115,7 +115,9 @@ class NumpyToTorchV0(Wrapper):
"torch is not installed, run `pip install torch`"
)
super().__init__(env)
gym.utils.RecordConstructorArgs.__init__(self, device=device)
gym.Wrapper.__init__(self, env)
self.device: Device | None = device
def step(

View File

@@ -20,7 +20,9 @@ from gymnasium.core import ActType, ObsType, WrapperActType
from gymnasium.spaces import Box, Space
class LambdaActionV0(gym.ActionWrapper[ObsType, WrapperActType, ActType]):
class LambdaActionV0(
gym.ActionWrapper[ObsType, WrapperActType, ActType], gym.utils.RecordConstructorArgs
):
"""A wrapper that provides a function to modify the action passed to :meth:`step`."""
def __init__(
@@ -36,7 +38,11 @@ class LambdaActionV0(gym.ActionWrapper[ObsType, WrapperActType, ActType]):
func: Function to apply to ``step`` ``action``
action_space: The updated action space of the wrapper given the function.
"""
super().__init__(env)
gym.utils.RecordConstructorArgs.__init__(
self, func=func, action_space=action_space
)
gym.Wrapper.__init__(self, env)
if action_space is not None:
self.action_space = action_space
@@ -47,7 +53,9 @@ class LambdaActionV0(gym.ActionWrapper[ObsType, WrapperActType, ActType]):
return self.func(action)
class ClipActionV0(LambdaActionV0[ObsType, WrapperActType, ActType]):
class ClipActionV0(
LambdaActionV0[ObsType, WrapperActType, ActType], gym.utils.RecordConstructorArgs
):
"""Clip the continuous action within the valid :class:`Box` observation space bound.
Example:
@@ -71,10 +79,14 @@ class ClipActionV0(LambdaActionV0[ObsType, WrapperActType, ActType]):
"""
assert isinstance(env.action_space, Box)
super().__init__(
env,
lambda action: jp.clip(action, env.action_space.low, env.action_space.high),
Box(
gym.utils.RecordConstructorArgs.__init__(self)
LambdaActionV0.__init__(
self,
env=env,
func=lambda action: jp.clip(
action, env.action_space.low, env.action_space.high
),
action_space=Box(
-np.inf,
np.inf,
shape=env.action_space.shape,
@@ -83,7 +95,9 @@ class ClipActionV0(LambdaActionV0[ObsType, WrapperActType, ActType]):
)
class RescaleActionV0(LambdaActionV0[ObsType, WrapperActType, ActType]):
class RescaleActionV0(
LambdaActionV0[ObsType, WrapperActType, ActType], gym.utils.RecordConstructorArgs
):
"""Affinely rescales the continuous action space of the environment to the range [min_action, max_action].
The base environment :attr:`env` must have an action space of type :class:`spaces.Box`. If :attr:`min_action`
@@ -118,6 +132,10 @@ class RescaleActionV0(LambdaActionV0[ObsType, WrapperActType, ActType]):
min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar.
max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar.
"""
gym.utils.RecordConstructorArgs.__init__(
self, min_action=min_action, max_action=max_action
)
assert isinstance(env.action_space, Box)
assert not np.any(env.action_space.low == np.inf) and not np.any(
env.action_space.high == np.inf
@@ -149,10 +167,11 @@ class RescaleActionV0(LambdaActionV0[ObsType, WrapperActType, ActType]):
)
intercept = gradient * -min_action + env.action_space.low
super().__init__(
env,
lambda action: gradient * action + intercept,
Box(
LambdaActionV0.__init__(
self,
env=env,
func=lambda action: gradient * action + intercept,
action_space=Box(
low=min_action,
high=max_action,
shape=env.action_space.shape,

View File

@@ -24,14 +24,16 @@ except ImportError as e:
import numpy as np
import gymnasium as gym
from gymnasium import Env, spaces
from gymnasium.core import ActType, ObservationWrapper, ObsType, WrapperObsType
from gymnasium import spaces
from gymnasium.core import ActType, ObsType, WrapperObsType
from gymnasium.error import DependencyNotInstalled
from gymnasium.experimental.wrappers.utils import RunningMeanStd
from gymnasium.spaces import Box, Dict, utils
class LambdaObservationV0(gym.ObservationWrapper[WrapperObsType, ActType, ObsType]):
class LambdaObservationV0(
gym.ObservationWrapper[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""Transforms an observation via a function provided to the wrapper.
The function :attr:`func` will be applied to all observations.
@@ -61,7 +63,11 @@ class LambdaObservationV0(gym.ObservationWrapper[WrapperObsType, ActType, ObsTyp
func: A function that will transform an observation. If this transformed observation is outside the observation space of `env.observation_space` then provide an `observation_space`.
observation_space: The observation spaces of the wrapper, if None, then it is assumed the same as `env.observation_space`.
"""
super().__init__(env)
gym.utils.RecordConstructorArgs.__init__(
self, func=func, observation_space=observation_space
)
gym.ObservationWrapper.__init__(self, env)
if observation_space is not None:
self.observation_space = observation_space
@@ -72,7 +78,10 @@ class LambdaObservationV0(gym.ObservationWrapper[WrapperObsType, ActType, ObsTyp
return self.func(observation)
class FilterObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
class FilterObservationV0(
LambdaObservationV0[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""Filter Dict observation space by the keys.
Example:
@@ -96,6 +105,7 @@ class FilterObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType])
):
"""Constructor for an environment with a dictionary observation space where all :attr:`filter_keys` are in the observation space keys."""
assert isinstance(filter_keys, Sequence)
gym.utils.RecordConstructorArgs.__init__(self, filter_keys=filter_keys)
# Filters for dictionary space
if isinstance(env.observation_space, spaces.Dict):
@@ -124,10 +134,11 @@ class FilterObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType])
"The observation space is empty due to filtering all keys."
)
super().__init__(
env,
lambda obs: {key: obs[key] for key in filter_keys},
new_observation_space,
LambdaObservationV0.__init__(
self,
env=env,
func=lambda obs: {key: obs[key] for key in filter_keys},
observation_space=new_observation_space,
)
# Filter for tuple observation
elif isinstance(env.observation_space, spaces.Tuple):
@@ -158,10 +169,11 @@ class FilterObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType])
"The observation space is empty due to filtering all keys."
)
super().__init__(
env,
lambda obs: tuple(obs[key] for key in filter_keys),
new_observation_spaces,
LambdaObservationV0.__init__(
self,
env=env,
func=lambda obs: tuple(obs[key] for key in filter_keys),
observation_space=new_observation_spaces,
)
else:
raise ValueError(
@@ -171,7 +183,10 @@ class FilterObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType])
self.filter_keys: Final[Sequence[str | int]] = filter_keys
class FlattenObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
class FlattenObservationV0(
LambdaObservationV0[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""Observation wrapper that flattens the observation.
Example:
@@ -190,14 +205,19 @@ class FlattenObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]
def __init__(self, env: gym.Env[ObsType, ActType]):
"""Constructor for any environment's observation space that implements ``spaces.utils.flatten_space`` and ``spaces.utils.flatten``."""
super().__init__(
env,
lambda obs: utils.flatten(env.observation_space, obs),
utils.flatten_space(env.observation_space),
gym.utils.RecordConstructorArgs.__init__(self)
LambdaObservationV0.__init__(
self,
env=env,
func=lambda obs: spaces.utils.flatten(env.observation_space, obs),
observation_space=spaces.utils.flatten_space(env.observation_space),
)
class GrayscaleObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
class GrayscaleObservationV0(
LambdaObservationV0[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""Observation wrapper that converts an RGB image to grayscale.
The :attr:`keep_dim` will keep the channel dimension
@@ -228,6 +248,7 @@ class GrayscaleObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsTyp
and np.all(env.observation_space.high == 255)
and env.observation_space.dtype == np.uint8
)
gym.utils.RecordConstructorArgs.__init__(self, keep_dim=keep_dim)
self.keep_dim: Final[bool] = keep_dim
if keep_dim:
@@ -237,30 +258,35 @@ class GrayscaleObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsTyp
shape=env.observation_space.shape[:2] + (1,),
dtype=np.uint8,
)
super().__init__(
env,
lambda obs: jp.expand_dims(
LambdaObservationV0.__init__(
self,
env=env,
func=lambda obs: jp.expand_dims(
jp.sum(
jp.multiply(obs, jp.array([0.2125, 0.7154, 0.0721])), axis=-1
).astype(np.uint8),
axis=-1,
),
new_observation_space,
observation_space=new_observation_space,
)
else:
new_observation_space = spaces.Box(
low=0, high=255, shape=env.observation_space.shape[:2], dtype=np.uint8
)
super().__init__(
env,
lambda obs: jp.sum(
LambdaObservationV0.__init__(
self,
env=env,
func=lambda obs: jp.sum(
jp.multiply(obs, jp.array([0.2125, 0.7154, 0.0721])), axis=-1
).astype(np.uint8),
new_observation_space,
observation_space=new_observation_space,
)
class ResizeObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
class ResizeObservationV0(
LambdaObservationV0[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""Resizes image observations using OpenCV to shape.
Example:
@@ -299,14 +325,20 @@ class ResizeObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType])
new_observation_space = spaces.Box(
low=0, high=255, shape=self.shape + env.observation_space.shape[2:]
)
super().__init__(
env,
lambda obs: cv2.resize(obs, self.shape, interpolation=cv2.INTER_AREA),
new_observation_space,
gym.utils.RecordConstructorArgs.__init__(self, shape=shape)
LambdaObservationV0.__init__(
self,
env=env,
func=lambda obs: cv2.resize(obs, self.shape, interpolation=cv2.INTER_AREA),
observation_space=new_observation_space,
)
class ReshapeObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
class ReshapeObservationV0(
LambdaObservationV0[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""Reshapes array based observations to shapes.
Example:
@@ -336,10 +368,20 @@ class ReshapeObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]
dtype=env.observation_space.dtype,
)
self.shape = shape
super().__init__(env, lambda obs: jp.reshape(obs, shape), new_observation_space)
gym.utils.RecordConstructorArgs.__init__(self, shape=shape)
LambdaObservationV0.__init__(
self,
env=env,
func=lambda obs: jp.reshape(obs, shape),
observation_space=new_observation_space,
)
class RescaleObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
class RescaleObservationV0(
LambdaObservationV0[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""Linearly rescales observation to between a minimum and maximum value.
Example:
@@ -392,10 +434,12 @@ class RescaleObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]
)
intercept = gradient * -env.observation_space.low + min_obs
super().__init__(
env,
lambda obs: gradient * obs + intercept,
Box(
gym.utils.RecordConstructorArgs.__init__(self, min_obs=min_obs, max_obs=max_obs)
LambdaObservationV0.__init__(
self,
env=env,
func=lambda obs: gradient * obs + intercept,
observation_space=spaces.Box(
low=min_obs,
high=max_obs,
shape=env.observation_space.shape,
@@ -404,7 +448,10 @@ class RescaleObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]
)
class DtypeObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
class DtypeObservationV0(
LambdaObservationV0[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""Observation wrapper for transforming the dtype of an observation."""
def __init__(self, env: gym.Env[ObsType, ActType], dtype: Any):
@@ -445,10 +492,19 @@ class DtypeObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
"DtypeObservation is only compatible with value / array-based observations."
)
super().__init__(env, lambda obs: dtype(obs), new_observation_space)
gym.utils.RecordConstructorArgs.__init__(self, dtype=dtype)
LambdaObservationV0.__init__(
self,
env=env,
func=lambda obs: dtype(obs),
observation_space=new_observation_space,
)
class PixelObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
class PixelObservationV0(
LambdaObservationV0[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""Augment observations by pixel values.
Observations of this wrapper will be dictionaries of images.
@@ -461,7 +517,7 @@ class PixelObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
def __init__(
self,
env: Env[ObsType, ActType],
env: gym.Env[ObsType, ActType],
pixels_only: bool = True,
pixels_key: str = "pixels",
obs_key: str = "state",
@@ -478,30 +534,49 @@ class PixelObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
pixels_key: Optional custom string specifying the pixel key. Defaults to "pixels"
obs_key: Optional custom string specifying the obs key. Defaults to "state"
"""
gym.utils.RecordConstructorArgs.__init__(
self, pixels_only=pixels_only, pixels_key=pixels_key, obs_key=obs_key
)
assert env.render_mode is not None and env.render_mode != "human"
env.reset()
pixels = env.render()
assert pixels is not None and isinstance(pixels, np.ndarray)
pixel_space = Box(low=0, high=255, shape=pixels.shape, dtype=np.uint8)
pixel_space = spaces.Box(low=0, high=255, shape=pixels.shape, dtype=np.uint8)
if pixels_only:
obs_space = pixel_space
super().__init__(env, lambda _: self.render(), obs_space)
elif isinstance(env.observation_space, Dict):
LambdaObservationV0.__init__(
self, env=env, func=lambda _: self.render(), observation_space=obs_space
)
elif isinstance(env.observation_space, spaces.Dict):
assert pixels_key not in env.observation_space.spaces.keys()
obs_space = Dict({pixels_key: pixel_space, **env.observation_space.spaces})
super().__init__(
env, lambda obs: {pixels_key: self.render(), **obs_space}, obs_space
obs_space = spaces.Dict(
{pixels_key: pixel_space, **env.observation_space.spaces}
)
LambdaObservationV0.__init__(
self,
env=env,
func=lambda obs: {pixels_key: self.render(), **obs_space},
observation_space=obs_space,
)
else:
obs_space = Dict({obs_key: env.observation_space, pixels_key: pixel_space})
super().__init__(
env, lambda obs: {obs_key: obs, pixels_key: self.render()}, obs_space
obs_space = spaces.Dict(
{obs_key: env.observation_space, pixels_key: pixel_space}
)
LambdaObservationV0.__init__(
self,
env=env,
func=lambda obs: {obs_key: obs, pixels_key: self.render()},
observation_space=obs_space,
)
class NormalizeObservationV0(ObservationWrapper[WrapperObsType, ActType, ObsType]):
class NormalizeObservationV0(
gym.ObservationWrapper[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""This wrapper will normalize observations s.t. each coordinate is centered with unit variance.
The property `_update_running_mean` allows to freeze/continue the running mean calculation of the observation
@@ -520,7 +595,9 @@ class NormalizeObservationV0(ObservationWrapper[WrapperObsType, ActType, ObsType
env (Env): The environment to apply the wrapper
epsilon: A stability parameter that is used when scaling the observations.
"""
super().__init__(env)
gym.utils.RecordConstructorArgs.__init__(self, epsilon=epsilon)
gym.ObservationWrapper.__init__(self, env)
self.obs_rms = RunningMeanStd(shape=self.observation_space.shape)
self.epsilon = epsilon
self._update_running_mean = True

View File

@@ -15,7 +15,9 @@ from gymnasium.error import InvalidBound
from gymnasium.experimental.wrappers.utils import RunningMeanStd
class LambdaRewardV0(gym.RewardWrapper[ObsType, ActType]):
class LambdaRewardV0(
gym.RewardWrapper[ObsType, ActType], gym.utils.RecordConstructorArgs
):
"""A reward wrapper that allows a custom function to modify the step reward.
Example:
@@ -40,7 +42,8 @@ class LambdaRewardV0(gym.RewardWrapper[ObsType, ActType]):
env (Env): The environment to apply the wrapper
func: (Callable): The function to apply to reward
"""
super().__init__(env)
gym.utils.RecordConstructorArgs.__init__(self, func=func)
gym.RewardWrapper.__init__(self, env)
self.func = func
@@ -53,7 +56,7 @@ class LambdaRewardV0(gym.RewardWrapper[ObsType, ActType]):
return self.func(reward)
class ClipRewardV0(LambdaRewardV0[ObsType, ActType]):
class ClipRewardV0(LambdaRewardV0[ObsType, ActType], gym.utils.RecordConstructorArgs):
"""A wrapper that clips the rewards for an environment between an upper and lower bound.
Example:
@@ -89,10 +92,17 @@ class ClipRewardV0(LambdaRewardV0[ObsType, ActType]):
f"Min reward ({min_reward}) must be smaller than max reward ({max_reward})"
)
super().__init__(env, lambda x: np.clip(x, a_min=min_reward, a_max=max_reward))
gym.utils.RecordConstructorArgs.__init__(
self, min_reward=min_reward, max_reward=max_reward
)
LambdaRewardV0.__init__(
self, env=env, func=lambda x: np.clip(x, a_min=min_reward, a_max=max_reward)
)
class NormalizeRewardV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
class NormalizeRewardV0(
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
r"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance.
The exponential moving average will have variance :math:`(1 - \gamma)^2`.
@@ -119,7 +129,9 @@ class NormalizeRewardV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
epsilon (float): A stability parameter
gamma (float): The discount factor that is used in the exponential moving average.
"""
super().__init__(env)
gym.utils.RecordConstructorArgs.__init__(self, gamma=gamma, epsilon=epsilon)
gym.Wrapper.__init__(self, env)
self.rewards_running_means = RunningMeanStd(shape=())
self.discounted_reward: np.array = np.array([0.0])
self.gamma = gamma

View File

@@ -18,7 +18,9 @@ from gymnasium.core import ActType, ObsType, RenderFrame
from gymnasium.error import DependencyNotInstalled
class RenderCollectionV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
class RenderCollectionV0(
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
"""Collect rendered frames of an environment such ``render`` returns a ``list[RenderedFrame]``."""
def __init__(
@@ -34,7 +36,11 @@ class RenderCollectionV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
pop_frames (bool): If true, clear the collection frames after ``meth:render`` is called. Default value is ``True``.
reset_clean (bool): If true, clear the collection frames when ``meth:reset`` is called. Default value is ``True``.
"""
super().__init__(env)
gym.utils.RecordConstructorArgs.__init__(
self, pop_frames=pop_frames, reset_clean=reset_clean
)
gym.Wrapper.__init__(self, env)
assert env.render_mode is not None
assert not env.render_mode.endswith("_list")
@@ -80,7 +86,9 @@ class RenderCollectionV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
return frames
class RecordVideoV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
class RecordVideoV0(
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
"""This wrapper records videos of rollouts.
Usually, you only want to record episodes intermittently, say every hundredth episode.
@@ -117,9 +125,18 @@ class RecordVideoV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
Otherwise, snippets of the specified length are captured
name_prefix (str): Will be prepended to the filename of the recordings
disable_logger (bool): Whether to disable moviepy logger or not
"""
super().__init__(env)
gym.utils.RecordConstructorArgs.__init__(
self,
video_folder=video_folder,
episode_trigger=episode_trigger,
step_trigger=step_trigger,
video_length=video_length,
name_prefix=name_prefix,
disable_logger=disable_logger,
)
gym.Wrapper.__init__(self, env)
try:
import moviepy # noqa: F401
except ImportError as e:
@@ -277,7 +294,9 @@ class RecordVideoV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
logger.warn("Unable to save last video! Did you call close()?")
class HumanRenderingV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
class HumanRenderingV0(
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
"""Performs human rendering for an environment that only supports "rgb_array"rendering.
This wrapper is particularly useful when you have implemented an environment that can produce
@@ -317,7 +336,9 @@ class HumanRenderingV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
Args:
env: The environment that is being wrapped
"""
super().__init__(env)
gym.utils.RecordConstructorArgs.__init__(self)
gym.Wrapper.__init__(self, env)
assert env.render_mode in [
"rgb_array",
"rgb_array_list",

View File

@@ -4,11 +4,13 @@ from __future__ import annotations
from typing import Any
import gymnasium as gym
from gymnasium.core import ActionWrapper, ActType, ObsType
from gymnasium.core import ActType, ObsType
from gymnasium.error import InvalidProbability
class StickyActionV0(ActionWrapper[ObsType, ActType, ActType]):
class StickyActionV0(
gym.ActionWrapper[ObsType, ActType, ActType], gym.utils.RecordConstructorArgs
):
"""Wrapper which adds a probability of repeating the previous action.
This wrapper follows the implementation proposed by `Machado et al., 2018 <https://arxiv.org/pdf/1709.06009.pdf>`_
@@ -29,7 +31,11 @@ class StickyActionV0(ActionWrapper[ObsType, ActType, ActType]):
f"repeat_action_probability should be in the interval [0,1). Received {repeat_action_probability}"
)
super().__init__(env)
gym.utils.RecordConstructorArgs.__init__(
self, repeat_action_probability=repeat_action_probability
)
gym.ActionWrapper.__init__(self, env)
self.repeat_action_probability = repeat_action_probability
self.last_action: ActType | None = None

View File

@@ -13,8 +13,8 @@ from typing_extensions import Final
import numpy as np
import gymnasium as gym
import gymnasium.spaces as spaces
from gymnasium import Env, ObservationWrapper, Space, Wrapper
from gymnasium.core import ActType, ObsType, WrapperActType, WrapperObsType
from gymnasium.experimental.vector.utils import (
batch_space,
@@ -25,7 +25,9 @@ from gymnasium.experimental.wrappers.utils import create_zero_array
from gymnasium.spaces import Box, Dict, Tuple
class DelayObservationV0(ObservationWrapper[ObsType, ActType, ObsType]):
class DelayObservationV0(
gym.ObservationWrapper[ObsType, ActType, ObsType], gym.utils.RecordConstructorArgs
):
"""Wrapper which adds a delay to the returned observation.
Before reaching the :attr:`delay` number of timesteps, returned observations is an array of zeros with
@@ -49,15 +51,13 @@ class DelayObservationV0(ObservationWrapper[ObsType, ActType, ObsType]):
This does not support random delay values, if users are interested, please raise an issue or pull request to add this feature.
"""
def __init__(self, env: Env[ObsType, ActType], delay: int):
def __init__(self, env: gym.Env[ObsType, ActType], delay: int):
"""Initialises the DelayObservation wrapper with an integer.
Args:
env: The environment to wrap
delay: The number of timesteps to delay observations
"""
super().__init__(env)
if not np.issubdtype(type(delay), np.integer):
raise TypeError(
f"The delay is expected to be an integer, actual type: {type(delay)}"
@@ -67,6 +67,9 @@ class DelayObservationV0(ObservationWrapper[ObsType, ActType, ObsType]):
f"The delay needs to be greater than zero, actual value: {delay}"
)
gym.utils.RecordConstructorArgs.__init__(self, delay=delay)
gym.ObservationWrapper.__init__(self, env)
self.delay: Final[int] = int(delay)
self.observation_queue: Final[deque] = deque()
@@ -88,7 +91,10 @@ class DelayObservationV0(ObservationWrapper[ObsType, ActType, ObsType]):
return create_zero_array(self.observation_space)
class TimeAwareObservationV0(ObservationWrapper[WrapperObsType, ActType, ObsType]):
class TimeAwareObservationV0(
gym.ObservationWrapper[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""Augment the observation with time information of the episode.
The :attr:`normalize_time` if ``True`` represents time as a normalized value between [0,1]
@@ -144,7 +150,7 @@ class TimeAwareObservationV0(ObservationWrapper[WrapperObsType, ActType, ObsType
def __init__(
self,
env: Env[ObsType, ActType],
env: gym.Env[ObsType, ActType],
flatten: bool = False,
normalize_time: bool = True,
*,
@@ -159,7 +165,13 @@ class TimeAwareObservationV0(ObservationWrapper[WrapperObsType, ActType, ObsType
otherwise return time as remaining timesteps before truncation
dict_time_key: For environment with a ``Dict`` observation space, the key for the time space. By default, `"time"`.
"""
super().__init__(env)
gym.utils.RecordConstructorArgs.__init__(
self,
flatten=flatten,
normalize_time=normalize_time,
dict_time_key=dict_time_key,
)
gym.ObservationWrapper.__init__(self, env)
self.flatten: Final[bool] = flatten
self.normalize_time: Final[bool] = normalize_time
@@ -203,14 +215,14 @@ class TimeAwareObservationV0(ObservationWrapper[WrapperObsType, ActType, ObsType
# If to flatten the observation space
if self.flatten:
self.observation_space: Space[WrapperObsType] = spaces.flatten_space(
self.observation_space: gym.Space[WrapperObsType] = spaces.flatten_space(
observation_space
)
self._obs_postprocess_func = lambda obs: spaces.flatten(
observation_space, obs
)
else:
self.observation_space: Space[WrapperObsType] = observation_space
self.observation_space: gym.Space[WrapperObsType] = observation_space
self._obs_postprocess_func = lambda obs: obs
def observation(self, observation: ObsType) -> WrapperObsType:
@@ -260,7 +272,10 @@ class TimeAwareObservationV0(ObservationWrapper[WrapperObsType, ActType, ObsType
return super().reset(seed=seed, options=options)
class FrameStackObservationV0(Wrapper[WrapperObsType, ActType, ObsType, ActType]):
class FrameStackObservationV0(
gym.Wrapper[WrapperObsType, ActType, ObsType, ActType],
gym.utils.RecordConstructorArgs,
):
"""Observation wrapper that stacks the observations in a rolling manner.
For example, if the number of stacks is 4, then the returned observation contains
@@ -286,7 +301,7 @@ class FrameStackObservationV0(Wrapper[WrapperObsType, ActType, ObsType, ActType]
def __init__(
self,
env: Env[ObsType, ActType],
env: gym.Env[ObsType, ActType],
stack_size: int,
*,
zeros_obs: ObsType | None = None,
@@ -298,8 +313,6 @@ class FrameStackObservationV0(Wrapper[WrapperObsType, ActType, ObsType, ActType]
stack_size: The number of frames to stack with zero_obs being used originally.
zeros_obs: Keyword only parameter that allows a custom padding observation at :meth:`reset`
"""
super().__init__(env)
if not np.issubdtype(type(stack_size), np.integer):
raise TypeError(
f"The stack_size is expected to be an integer, actual type: {type(stack_size)}"
@@ -309,6 +322,9 @@ class FrameStackObservationV0(Wrapper[WrapperObsType, ActType, ObsType, ActType]
f"The stack_size needs to be greater than one, actual value: {stack_size}"
)
gym.utils.RecordConstructorArgs.__init__(self, stack_size=stack_size)
gym.Wrapper.__init__(self, env)
self.observation_space = batch_space(env.observation_space, n=stack_size)
self.stack_size: Final[int] = stack_size

View File

@@ -8,6 +8,7 @@ These are not intended as API functions, and will not remain stable over time.
# that verify that our dependencies are actually present.
from gymnasium.utils.colorize import colorize
from gymnasium.utils.ezpickle import EzPickle
from gymnasium.utils.record_constructor import RecordConstructorArgs
__all__ = ["colorize", "EzPickle"]
__all__ = ["colorize", "EzPickle", "RecordConstructorArgs"]

View File

@@ -1,13 +1,15 @@
"""Class for pickling and unpickling objects via their constructor arguments."""
from typing import Any
class EzPickle:
"""Objects that are pickled and unpickled via their constructor arguments.
Example:
>>> class Dog(Animal, EzPickle): # doctest: +SKIP
>>> class Animal: pass
>>> class Dog(Animal, EzPickle):
... def __init__(self, furcolor, tailkind="bushy"):
... Animal.__init__()
... Animal.__init__(self)
... EzPickle.__init__(self, furcolor, tailkind)
When this object is unpickled, a new ``Dog`` will be constructed by passing the provided furcolor and tailkind into the constructor.
@@ -16,7 +18,7 @@ class EzPickle:
This is generally needed only for environments which wrap C/C++ code, such as MuJoCo and Atari.
"""
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
"""Uses the ``args`` and ``kwargs`` from the object's constructor for pickling."""
self._ezpickle_args = args
self._ezpickle_kwargs = kwargs

View File

@@ -0,0 +1,33 @@
"""Allows attributes passed to `RecordConstructorArgs` to be saved. This is used by the `Wrapper.spec` to know the constructor arguments of implemented wrappers."""
from __future__ import annotations
from copy import deepcopy
from typing import Any
class RecordConstructorArgs:
"""Records all arguments passed to constructor to `_saved_kwargs`.
This can be used to save and reproduce class constructor arguments.
Note:
If two class inherit from RecordConstructorArgs then the first class to call `RecordConstructorArgs.__init__(self, ...)` will have
their kwargs saved will all subsequent `RecordConstructorArgs.__init__` being ignored.
Therefore, always call `RecordConstructorArgs.__init__` before the `Class.__init__`
"""
def __init__(self, *, _disable_deepcopy: bool = False, **kwargs: Any):
"""Records all arguments passed to constructor to `_saved_kwargs`.
Args:
_disable_deepcopy: If to not deepcopy the kwargs passed
**kwargs: Arguments to save
"""
# See class docstring for explanation
if not hasattr(self, "_saved_kwargs"):
if _disable_deepcopy is False:
kwargs = deepcopy(kwargs)
self._saved_kwargs: dict[str, Any] = kwargs

View File

@@ -11,7 +11,7 @@ except ImportError:
cv2 = None
class AtariPreprocessing(gym.Wrapper):
class AtariPreprocessing(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""Atari 2600 preprocessing wrapper.
This class follows the guidelines in Machado et al. (2018),
@@ -60,7 +60,18 @@ class AtariPreprocessing(gym.Wrapper):
DependencyNotInstalled: opencv-python package not installed
ValueError: Disable frame-skipping in the original env
"""
super().__init__(env)
gym.utils.RecordConstructorArgs.__init__(
self,
noop_max=noop_max,
frame_skip=frame_skip,
screen_size=screen_size,
terminal_on_life_loss=terminal_on_life_loss,
grayscale_obs=grayscale_obs,
grayscale_newaxis=grayscale_newaxis,
scale_obs=scale_obs,
)
gym.Wrapper.__init__(self, env)
if cv2 is None:
raise gym.error.DependencyNotInstalled(
"opencv-python package not installed, run `pip install gymnasium[other]` to get dependencies for atari"

View File

@@ -2,7 +2,7 @@
import gymnasium as gym
class AutoResetWrapper(gym.Wrapper):
class AutoResetWrapper(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`.
When calling step causes :meth:`Env.step` to return `terminated=True` or `truncated=True`, :meth:`Env.reset` is called,
@@ -31,7 +31,8 @@ class AutoResetWrapper(gym.Wrapper):
Args:
env (gym.Env): The environment to apply the wrapper
"""
super().__init__(env)
gym.utils.RecordConstructorArgs.__init__(self)
gym.Wrapper.__init__(self, env)
def step(self, action):
"""Steps through the environment with action and resets the environment if a terminated or truncated signal is encountered.

View File

@@ -2,11 +2,10 @@
import numpy as np
import gymnasium as gym
from gymnasium import ActionWrapper
from gymnasium.spaces import Box
class ClipAction(ActionWrapper):
class ClipAction(gym.ActionWrapper, gym.utils.RecordConstructorArgs):
"""Clip the continuous action within the valid :class:`Box` observation space bound.
Example:
@@ -28,7 +27,9 @@ class ClipAction(ActionWrapper):
env: The environment to apply the wrapper
"""
assert isinstance(env.action_space, Box)
super().__init__(env)
gym.utils.RecordConstructorArgs.__init__(self)
gym.ActionWrapper.__init__(self, env)
def action(self, action):
"""Clips the action within the valid bounds.

View File

@@ -68,11 +68,12 @@ class EnvCompatibility(gym.Env):
"The `gymnasium.make(..., apply_api_compatibility=...)` parameter is deprecated and will be removed in v0.29. "
"Instead use `gym.make('GymV21Environment-v0', env_name=...)` or `from shimmy import GymV21CompatibilityV0`"
)
self.env = old_env
self.metadata = getattr(old_env, "metadata", {"render_modes": []})
self.render_mode = render_mode
self.reward_range = getattr(old_env, "reward_range", None)
self.spec = getattr(old_env, "spec", None)
self.env = old_env
self.observation_space = old_env.observation_space
self.action_space = old_env.action_space

View File

@@ -10,12 +10,13 @@ from gymnasium.utils.passive_env_checker import (
)
class PassiveEnvChecker(gym.Wrapper):
class PassiveEnvChecker(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""A passive environment checker wrapper that surrounds the step, reset and render functions to check they follow the gymnasium API."""
def __init__(self, env):
"""Initialises the wrapper with the environments, run the observation and action space tests."""
super().__init__(env)
gym.utils.RecordConstructorArgs.__init__(self)
gym.Wrapper.__init__(self, env)
assert hasattr(
env, "action_space"

View File

@@ -6,7 +6,7 @@ import gymnasium as gym
from gymnasium import spaces
class FilterObservation(gym.ObservationWrapper):
class FilterObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
"""Filter Dict observation space by the keys.
Example:
@@ -35,7 +35,8 @@ class FilterObservation(gym.ObservationWrapper):
ValueError: If the environment's observation space is not :class:`spaces.Dict`
ValueError: If any of the `filter_keys` are not included in the original `env`'s observation space
"""
super().__init__(env)
gym.utils.RecordConstructorArgs.__init__(self, filter_keys=filter_keys)
gym.ObservationWrapper.__init__(self, env)
wrapped_observation_space = env.observation_space
if not isinstance(wrapped_observation_space, spaces.Dict):

View File

@@ -3,7 +3,7 @@ import gymnasium as gym
from gymnasium import spaces
class FlattenObservation(gym.ObservationWrapper):
class FlattenObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
"""Observation wrapper that flattens the observation.
Example:
@@ -26,7 +26,9 @@ class FlattenObservation(gym.ObservationWrapper):
Args:
env: The environment to apply the wrapper
"""
super().__init__(env)
gym.utils.RecordConstructorArgs.__init__(self)
gym.ObservationWrapper.__init__(self, env)
self.observation_space = spaces.flatten_space(env.observation_space)
def observation(self, observation):

View File

@@ -97,7 +97,7 @@ class LazyFrames:
return frame
class FrameStack(gym.ObservationWrapper):
class FrameStack(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
"""Observation wrapper that stacks the observations in a rolling manner.
For example, if the number of stacks is 4, then the returned observation contains
@@ -137,7 +137,11 @@ class FrameStack(gym.ObservationWrapper):
num_stack (int): The number of frames to stack
lz4_compress (bool): Use lz4 to compress the frames internally
"""
super().__init__(env)
gym.utils.RecordConstructorArgs.__init__(
self, num_stack=num_stack, lz4_compress=lz4_compress
)
gym.ObservationWrapper.__init__(self, env)
self.num_stack = num_stack
self.lz4_compress = lz4_compress

View File

@@ -5,7 +5,7 @@ import gymnasium as gym
from gymnasium.spaces import Box
class GrayScaleObservation(gym.ObservationWrapper):
class GrayScaleObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
"""Convert the image observation from RGB to gray scale.
Example:
@@ -30,7 +30,9 @@ class GrayScaleObservation(gym.ObservationWrapper):
keep_dim (bool): If `True`, a singleton dimension will be added, i.e. observations are of the shape AxBx1.
Otherwise, they are of shape AxB.
"""
super().__init__(env)
gym.utils.RecordConstructorArgs.__init__(self, keep_dim=keep_dim)
gym.ObservationWrapper.__init__(self, env)
self.keep_dim = keep_dim
assert (

View File

@@ -7,7 +7,7 @@ import gymnasium as gym
from gymnasium.error import DependencyNotInstalled
class HumanRendering(gym.Wrapper):
class HumanRendering(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""Performs human rendering for an environment that only supports "rgb_array"rendering.
This wrapper is particularly useful when you have implemented an environment that can produce
@@ -47,7 +47,9 @@ class HumanRendering(gym.Wrapper):
Args:
env: The environment that is being wrapped
"""
super().__init__(env)
gym.utils.RecordConstructorArgs.__init__(self)
gym.Wrapper.__init__(self, env)
assert env.render_mode in [
"rgb_array",
"rgb_array_list",
@@ -64,6 +66,8 @@ class HumanRendering(gym.Wrapper):
if "human" not in self.metadata["render_modes"]:
self.metadata["render_modes"].append("human")
gym.utils.RecordConstructorArgs.__init__(self)
@property
def render_mode(self):
"""Always returns ``'human'``."""

View File

@@ -45,7 +45,7 @@ def update_mean_var_count_from_moments(
return new_mean, new_var, new_count
class NormalizeObservation(gym.Wrapper):
class NormalizeObservation(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""This wrapper will normalize observations s.t. each coordinate is centered with unit variance.
Note:
@@ -60,7 +60,9 @@ class NormalizeObservation(gym.Wrapper):
env (Env): The environment to apply the wrapper
epsilon: A stability parameter that is used when scaling the observations.
"""
super().__init__(env)
gym.utils.RecordConstructorArgs.__init__(self, epsilon=epsilon)
gym.Wrapper.__init__(self, env)
self.num_envs = getattr(env, "num_envs", 1)
self.is_vector_env = getattr(env, "is_vector_env", False)
if self.is_vector_env:
@@ -93,7 +95,7 @@ class NormalizeObservation(gym.Wrapper):
return (obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.epsilon)
class NormalizeReward(gym.core.Wrapper):
class NormalizeReward(gym.core.Wrapper, gym.utils.RecordConstructorArgs):
r"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance.
The exponential moving average will have variance :math:`(1 - \gamma)^2`.
@@ -116,7 +118,9 @@ class NormalizeReward(gym.core.Wrapper):
epsilon (float): A stability parameter
gamma (float): The discount factor that is used in the exponential moving average.
"""
super().__init__(env)
gym.utils.RecordConstructorArgs.__init__(self, gamma=gamma, epsilon=epsilon)
gym.Wrapper.__init__(self, env)
self.num_envs = getattr(env, "num_envs", 1)
self.is_vector_env = getattr(env, "is_vector_env", False)
self.return_rms = RunningMeanStd(shape=())

View File

@@ -3,7 +3,7 @@ import gymnasium as gym
from gymnasium.error import ResetNeeded
class OrderEnforcing(gym.Wrapper):
class OrderEnforcing(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`.
Example:
@@ -32,7 +32,11 @@ class OrderEnforcing(gym.Wrapper):
env: The environment to wrap
disable_render_order_enforcing: If to disable render order enforcing
"""
super().__init__(env)
gym.utils.RecordConstructorArgs.__init__(
self, disable_render_order_enforcing=disable_render_order_enforcing
)
gym.Wrapper.__init__(self, env)
self._has_reset: bool = False
self._disable_render_order_enforcing: bool = disable_render_order_enforcing

View File

@@ -13,7 +13,7 @@ from gymnasium import spaces
STATE_KEY = "state"
class PixelObservationWrapper(gym.ObservationWrapper):
class PixelObservationWrapper(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
"""Augment observations by pixel values.
Observations of this wrapper will be dictionaries of images.
@@ -79,7 +79,13 @@ class PixelObservationWrapper(gym.ObservationWrapper):
specified ``pixel_keys``.
TypeError: When an unexpected pixel type is used
"""
super().__init__(env)
gym.utils.RecordConstructorArgs.__init__(
self,
pixels_only=pixels_only,
render_kwargs=render_kwargs,
pixel_keys=pixel_keys,
)
gym.ObservationWrapper.__init__(self, env)
# Avoid side-effects that occur when render_kwargs is manipulated
render_kwargs = copy.deepcopy(render_kwargs)

View File

@@ -8,7 +8,7 @@ import numpy as np
import gymnasium as gym
class RecordEpisodeStatistics(gym.Wrapper):
class RecordEpisodeStatistics(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""This wrapper will keep track of cumulative rewards and episode lengths.
At the end of an episode, the statistics of the episode will be added to ``info``
@@ -56,7 +56,9 @@ class RecordEpisodeStatistics(gym.Wrapper):
env (Env): The environment to apply the wrapper
deque_size: The size of the buffers :attr:`return_queue` and :attr:`length_queue`
"""
super().__init__(env)
gym.utils.RecordConstructorArgs.__init__(self, deque_size=deque_size)
gym.Wrapper.__init__(self, env)
self.num_envs = getattr(env, "num_envs", 1)
self.episode_count = 0
self.episode_start_times: np.ndarray = None

View File

@@ -24,7 +24,7 @@ def capped_cubic_video_schedule(episode_id: int) -> bool:
return episode_id % 1000 == 0
class RecordVideo(gym.Wrapper):
class RecordVideo(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""This wrapper records videos of rollouts.
Usually, you only want to record episodes intermittently, say every hundredth episode.
@@ -58,9 +58,17 @@ class RecordVideo(gym.Wrapper):
Otherwise, snippets of the specified length are captured
name_prefix (str): Will be prepended to the filename of the recordings
disable_logger (bool): Whether to disable moviepy logger or not.
"""
super().__init__(env)
gym.utils.RecordConstructorArgs.__init__(
self,
video_folder=video_folder,
episode_trigger=episode_trigger,
step_trigger=step_trigger,
video_length=video_length,
name_prefix=name_prefix,
disable_logger=disable_logger,
)
gym.Wrapper.__init__(self, env)
if episode_trigger is None and step_trigger is None:
episode_trigger = capped_cubic_video_schedule

View File

@@ -4,7 +4,7 @@ import copy
import gymnasium as gym
class RenderCollection(gym.Wrapper):
class RenderCollection(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""Save collection of render frames."""
def __init__(self, env: gym.Env, pop_frames: bool = True, reset_clean: bool = True):
@@ -17,7 +17,11 @@ class RenderCollection(gym.Wrapper):
reset_clean (bool): If true, clear the collection frames when .reset() is called.
Default value is True.
"""
super().__init__(env)
gym.utils.RecordConstructorArgs.__init__(
self, pop_frames=pop_frames, reset_clean=reset_clean
)
gym.Wrapper.__init__(self, env)
assert env.render_mode is not None
assert not env.render_mode.endswith("_list")
self.frame_list = []

View File

@@ -7,7 +7,7 @@ import gymnasium as gym
from gymnasium.spaces import Box
class RescaleAction(gym.ActionWrapper):
class RescaleAction(gym.ActionWrapper, gym.utils.RecordConstructorArgs):
"""Affinely rescales the continuous action space of the environment to the range [min_action, max_action].
The base environment :attr:`env` must have an action space of type :class:`spaces.Box`. If :attr:`min_action`
@@ -47,7 +47,11 @@ class RescaleAction(gym.ActionWrapper):
), f"expected Box action space, got {type(env.action_space)}"
assert np.less_equal(min_action, max_action).all(), (min_action, max_action)
super().__init__(env)
gym.utils.RecordConstructorArgs.__init__(
self, min_action=min_action, max_action=max_action
)
gym.ActionWrapper.__init__(self, env)
self.min_action = (
np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + min_action
)

View File

@@ -8,7 +8,7 @@ from gymnasium.error import DependencyNotInstalled
from gymnasium.spaces import Box
class ResizeObservation(gym.ObservationWrapper):
class ResizeObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
"""Resize the image observation.
This wrapper works on environments with image observations. More generally,
@@ -36,7 +36,9 @@ class ResizeObservation(gym.ObservationWrapper):
env: The environment to apply the wrapper
shape: The shape of the resized observations
"""
super().__init__(env)
gym.utils.RecordConstructorArgs.__init__(self, shape=shape)
gym.ObservationWrapper.__init__(self, env)
if isinstance(shape, int):
shape = (shape, shape)
assert len(shape) == 2 and all(

View File

@@ -4,7 +4,7 @@ from gymnasium.logger import deprecation
from gymnasium.utils.step_api_compatibility import step_api_compatibility
class StepAPICompatibility(gym.Wrapper):
class StepAPICompatibility(gym.Wrapper, gym.utils.RecordConstructorArgs):
r"""A wrapper which can transform an environment from new step API to old and vice-versa.
Old step API refers to step() method returning (observation, reward, done, info)
@@ -29,7 +29,11 @@ class StepAPICompatibility(gym.Wrapper):
env (gym.Env): the env to wrap. Can be in old or new API
output_truncation_bool (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
"""
super().__init__(env)
gym.utils.RecordConstructorArgs.__init__(
self, output_truncation_bool=output_truncation_bool
)
gym.Wrapper.__init__(self, env)
self.is_vector_env = isinstance(env.unwrapped, gym.vector.VectorEnv)
self.output_truncation_bool = output_truncation_bool
if not self.output_truncation_bool:

View File

@@ -5,7 +5,7 @@ import gymnasium as gym
from gymnasium.spaces import Box
class TimeAwareObservation(gym.ObservationWrapper):
class TimeAwareObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
"""Augment the observation with the current time step in the episode.
The observation space of the wrapped environment is assumed to be a flat :class:`Box`.
@@ -29,7 +29,9 @@ class TimeAwareObservation(gym.ObservationWrapper):
Args:
env: The environment to apply the wrapper
"""
super().__init__(env)
gym.utils.RecordConstructorArgs.__init__(self)
gym.ObservationWrapper.__init__(self, env)
assert isinstance(env.observation_space, Box)
assert env.observation_space.dtype == np.float32
low = np.append(self.observation_space.low, 0.0)

View File

@@ -4,7 +4,7 @@ from typing import Optional
import gymnasium as gym
class TimeLimit(gym.Wrapper):
class TimeLimit(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""This wrapper will issue a `truncated` signal if a maximum number of timesteps is exceeded.
If a truncation is not defined inside the environment itself, this is the only place that the truncation signal is issued.
@@ -26,14 +26,16 @@ class TimeLimit(gym.Wrapper):
Args:
env: The environment to apply the wrapper
max_episode_steps: An optional max episode steps (if ``Ǹone``, ``env.spec.max_episode_steps`` is used)
max_episode_steps: An optional max episode steps (if ``None``, ``env.spec.max_episode_steps`` is used)
"""
super().__init__(env)
gym.utils.RecordConstructorArgs.__init__(
self, max_episode_steps=max_episode_steps
)
gym.Wrapper.__init__(self, env)
if max_episode_steps is None and self.env.spec is not None:
assert env.spec is not None
max_episode_steps = env.spec.max_episode_steps
if self.env.spec is not None:
self.env.spec.max_episode_steps = max_episode_steps
self._max_episode_steps = max_episode_steps
self._elapsed_steps = None

View File

@@ -4,7 +4,7 @@ from typing import Any, Callable
import gymnasium as gym
class TransformObservation(gym.ObservationWrapper):
class TransformObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
"""Transform the observation via an arbitrary function :attr:`f`.
The function :attr:`f` should be defined on the observation space of the base environment, ``env``, and should, ideally, return values in the same space.
@@ -29,7 +29,9 @@ class TransformObservation(gym.ObservationWrapper):
env: The environment to apply the wrapper
f: A function that transforms the observation
"""
super().__init__(env)
gym.utils.RecordConstructorArgs.__init__(self, f=f)
gym.ObservationWrapper.__init__(self, env)
assert callable(f)
self.f = f

View File

@@ -2,10 +2,9 @@
from typing import Callable
import gymnasium as gym
from gymnasium import RewardWrapper
class TransformReward(RewardWrapper):
class TransformReward(gym.RewardWrapper, gym.utils.RecordConstructorArgs):
"""Transform the reward via an arbitrary function.
Warning:
@@ -29,7 +28,9 @@ class TransformReward(RewardWrapper):
env: The environment to apply the wrapper
f: A function that transforms the reward
"""
super().__init__(env)
gym.utils.RecordConstructorArgs.__init__(self, f=f)
gym.RewardWrapper.__init__(self, env)
assert callable(f)
self.f = f

View File

@@ -5,7 +5,7 @@ from typing import List
import gymnasium as gym
class VectorListInfo(gym.Wrapper):
class VectorListInfo(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""Converts infos of vectorized environments from dict to List[dict].
This wrapper converts the info format of a
@@ -51,7 +51,9 @@ class VectorListInfo(gym.Wrapper):
assert getattr(
env, "is_vector_env", False
), "This wrapper can only be used in vectorized environments."
super().__init__(env)
gym.utils.RecordConstructorArgs.__init__(self)
gym.Wrapper.__init__(self, env)
def step(self, action):
"""Steps through the environment, convert dict info to list."""

View File

@@ -0,0 +1,237 @@
"""Example file showing usage of env.specstack."""
import pickle
import pytest
import gymnasium as gym
from gymnasium.envs.classic_control import CartPoleEnv
from gymnasium.envs.registration import EnvSpec
from gymnasium.utils.env_checker import data_equivalence
def test_full_integration():
# Create an environment to test with
env = gym.make("CartPole-v1", render_mode="rgb_array").unwrapped
env = gym.wrappers.FlattenObservation(env)
env = gym.wrappers.TimeAwareObservation(env)
env = gym.wrappers.NormalizeReward(env, gamma=0.8)
# Generate the spec_stack
env_spec = env.spec
assert isinstance(env_spec, EnvSpec)
# env_spec.pprint()
# Serialize the spec_stack
env_spec_json = env_spec.to_json()
assert isinstance(env_spec_json, str)
# Deserialize the spec_stack
recreate_env_spec = EnvSpec.from_json(env_spec_json)
# recreate_env_spec.pprint()
for wrapper_spec, recreated_wrapper_spec in zip(
env_spec.applied_wrappers, recreate_env_spec.applied_wrappers
):
assert wrapper_spec == recreated_wrapper_spec
assert recreate_env_spec == env_spec
# Recreate the environment using the spec_stack
recreated_env = gym.make(recreate_env_spec)
assert recreated_env.render_mode == "rgb_array"
assert isinstance(recreated_env, gym.wrappers.NormalizeReward)
assert recreated_env.gamma == 0.8
assert isinstance(recreated_env.env, gym.wrappers.TimeAwareObservation)
assert isinstance(recreated_env.unwrapped, CartPoleEnv)
obs, info = env.reset(seed=42)
recreated_obs, recreated_info = recreated_env.reset(seed=42)
assert data_equivalence(obs, recreated_obs)
assert data_equivalence(info, recreated_info)
action = env.action_space.sample()
obs, reward, terminated, truncated, info = env.step(action)
(
recreated_obs,
recreated_reward,
recreated_terminated,
recreated_truncated,
recreated_info,
) = recreated_env.step(action)
assert data_equivalence(obs, recreated_obs)
assert data_equivalence(reward, recreated_reward)
assert data_equivalence(terminated, recreated_terminated)
assert data_equivalence(truncated, recreated_truncated)
assert data_equivalence(info, recreated_info)
# Test the pprint of the spec_stack
spec_stack_output = env_spec.pprint(disable_print=True)
json_spec_stack_output = env_spec.pprint(disable_print=True)
assert spec_stack_output == json_spec_stack_output
@pytest.mark.parametrize(
"env_spec",
[
gym.spec("CartPole-v1"),
gym.make("CartPole-v1").unwrapped.spec,
gym.make("CartPole-v1").spec,
gym.wrappers.NormalizeReward(gym.make("CartPole-v1")).spec,
],
)
def test_env_spec_to_from_json(env_spec: EnvSpec):
json_spec = env_spec.to_json()
recreated_env_spec = EnvSpec.from_json(json_spec)
assert env_spec == recreated_env_spec
def test_wrapped_env_entry_point():
def _create_env():
_env = gym.make("CartPole-v1", render_mode="rgb_array")
_env = gym.wrappers.FlattenObservation(_env)
return _env
gym.register("TestingEnv-v0", entry_point=_create_env)
env = gym.make("TestingEnv-v0")
env = gym.wrappers.TimeAwareObservation(env)
env = gym.wrappers.NormalizeReward(env, gamma=0.8)
recreated_env = gym.make(env.spec)
obs, info = env.reset(seed=42)
recreated_obs, recreated_info = recreated_env.reset(seed=42)
assert data_equivalence(obs, recreated_obs)
assert data_equivalence(info, recreated_info)
action = env.action_space.sample()
obs, reward, terminated, truncated, info = env.step(action)
(
recreated_obs,
recreated_reward,
recreated_terminated,
recreated_truncated,
recreated_info,
) = recreated_env.step(action)
assert data_equivalence(obs, recreated_obs)
assert data_equivalence(reward, recreated_reward)
assert data_equivalence(terminated, recreated_terminated)
assert data_equivalence(truncated, recreated_truncated)
assert data_equivalence(info, recreated_info)
del gym.registry["TestingEnv-v0"]
def test_pickling_env_stack():
env = gym.make("CartPole-v1", render_mode="rgb_array")
env = gym.wrappers.FlattenObservation(env)
env = gym.wrappers.TimeAwareObservation(env)
env = gym.wrappers.NormalizeReward(env, gamma=0.8)
pickled_env = pickle.loads(pickle.dumps(env))
obs, info = env.reset(seed=123)
pickled_obs, pickled_info = pickled_env.reset(seed=123)
assert data_equivalence(obs, pickled_obs)
assert data_equivalence(info, pickled_info)
action = env.action_space.sample()
obs, reward, terminated, truncated, info = env.step(action)
(
pickled_obs,
pickled_reward,
pickled_terminated,
pickled_truncated,
pickled_info,
) = pickled_env.step(action)
assert data_equivalence(obs, pickled_obs)
assert data_equivalence(reward, pickled_reward)
assert data_equivalence(terminated, pickled_terminated)
assert data_equivalence(truncated, pickled_truncated)
assert data_equivalence(info, pickled_info)
env.close()
pickled_env.close()
# flake8: noqa
def test_env_spec_pprint():
env = gym.make("CartPole-v1")
env_spec = env.spec
assert env_spec is not None
output = env_spec.pprint(disable_print=True)
assert (
output
== """id=CartPole-v1
reward_threshold=475.0
max_episode_steps=500
applied_wrappers=[
name=PassiveEnvChecker, kwargs={},
name=OrderEnforcing, kwargs={'disable_render_order_enforcing': False},
name=TimeLimit, kwargs={'max_episode_steps': 500}
]"""
)
output = env_spec.pprint(disable_print=True, include_entry_points=True)
assert (
output
== """id=CartPole-v1
entry_point=gymnasium.envs.classic_control.cartpole:CartPoleEnv
reward_threshold=475.0
max_episode_steps=500
applied_wrappers=[
name=PassiveEnvChecker, entry_point=gymnasium.wrappers.env_checker:PassiveEnvChecker, kwargs={},
name=OrderEnforcing, entry_point=gymnasium.wrappers.order_enforcing:OrderEnforcing, kwargs={'disable_render_order_enforcing': False},
name=TimeLimit, entry_point=gymnasium.wrappers.time_limit:TimeLimit, kwargs={'max_episode_steps': 500}
]"""
)
output = env_spec.pprint(disable_print=True, print_all=True)
assert (
output
== """id=CartPole-v1
entry_point=gymnasium.envs.classic_control.cartpole:CartPoleEnv
reward_threshold=475.0
nondeterministic=False
max_episode_steps=500
order_enforce=True
autoreset=False
disable_env_checker=False
applied_api_compatibility=False
applied_wrappers=[
name=PassiveEnvChecker, kwargs={},
name=OrderEnforcing, kwargs={'disable_render_order_enforcing': False},
name=TimeLimit, kwargs={'max_episode_steps': 500}
]"""
)
env_spec.applied_wrappers = ()
output = env_spec.pprint(disable_print=True)
assert (
output
== """id=CartPole-v1
reward_threshold=475.0
max_episode_steps=500"""
)
output = env_spec.pprint(disable_print=True, print_all=True)
assert (
output
== """id=CartPole-v1
entry_point=gymnasium.envs.classic_control.cartpole:CartPoleEnv
reward_threshold=475.0
nondeterministic=False
max_episode_steps=500
order_enforce=True
autoreset=False
disable_env_checker=False
applied_api_compatibility=False
applied_wrappers=[]"""
)

View File

@@ -8,6 +8,8 @@ import numpy as np
import pytest
import gymnasium as gym
from gymnasium import Env
from gymnasium.core import ActType, ObsType, WrapperObsType
from gymnasium.envs.classic_control import CartPoleEnv
from gymnasium.wrappers import (
AutoResetWrapper,
@@ -355,3 +357,69 @@ def test_import_module_during_make():
env.close()
del gym.registry["RegisterDuringMake-v0"]
class NoRecordArgsWrapper(gym.ObservationWrapper):
def __init__(self, env: Env[ObsType, ActType]):
super().__init__(env)
def observation(self, observation: ObsType) -> WrapperObsType:
return self.observation_space.sample()
def test_make_env_spec():
# make
env_1 = gym.make(gym.spec("CartPole-v1"))
assert isinstance(env_1, CartPoleEnv)
assert env_1 is env_1.unwrapped
env_1.close()
# make with applied wrappers
env_2 = gym.wrappers.NormalizeReward(
gym.wrappers.TimeAwareObservation(
gym.wrappers.FlattenObservation(
gym.make("CartPole-v1", render_mode="rgb_array")
)
),
gamma=0.8,
)
env_2_recreated = gym.make(env_2.spec)
assert env_2.spec == env_2_recreated.spec
env_2.close()
env_2_recreated.close()
# make with callable entry point
gym.register("CartPole-v2", lambda: CartPoleEnv())
env_3 = gym.make("CartPole-v2")
assert isinstance(env_3.unwrapped, CartPoleEnv)
env_3.close()
# make with wrapper in env-creator
gym.register(
"CartPole-v3", lambda: gym.wrappers.TimeAwareObservation(CartPoleEnv())
)
env_4 = gym.make(gym.spec("CartPole-v3"))
assert isinstance(env_4, gym.wrappers.TimeAwareObservation)
assert isinstance(env_4.env, CartPoleEnv)
env_4.close()
# make with no ezpickle wrapper
env_5 = NoRecordArgsWrapper(gym.make("CartPole-v1").unwrapped)
env_5.close()
with pytest.raises(
ValueError,
match=re.escape(
"NoRecordArgsWrapper wrapper does not inherit from `gymnasium.utils.RecordConstructorArgs`, therefore, the wrapper cannot be recreated."
),
):
gym.make(env_5.spec)
# make with no ezpickle wrapper but in the entry point
gym.register("CartPole-v4", entry_point=lambda: NoRecordArgsWrapper(CartPoleEnv()))
env_6 = gym.make(gym.spec("CartPole-v4"))
assert isinstance(env_6, NoRecordArgsWrapper)
assert isinstance(env_6.unwrapped, CartPoleEnv)
del gym.registry["CartPole-v2"]
del gym.registry["CartPole-v3"]
del gym.registry["CartPole-v4"]

View File

@@ -22,7 +22,7 @@ def test_mujoco_action_dimensions(env_spec: EnvSpec):
* Too many dimensions
* Incorrect shape
"""
env = env_spec.make(disable_env_checker=True)
env = env_spec.make()
env.reset()
# Too few actions

View File

@@ -42,7 +42,7 @@ CHECK_ENV_IGNORE_WARNINGS = [
def test_all_env_api(spec):
"""Check that all environments pass the environment checker with no warnings other than the expected."""
with warnings.catch_warnings(record=True) as caught_warnings:
env = spec.make(disable_env_checker=True).unwrapped
env = spec.make().unwrapped
check_env(env, skip_render_check=True)
env.close()

View File

@@ -15,8 +15,8 @@ def verify_environments_match(
):
"""Verifies with two environment ids (old and new) are identical in obs, reward and done
(except info where all old info must be contained in new info)."""
old_env = envs.make(old_env_id, disable_env_checker=True)
new_env = envs.make(new_env_id, disable_env_checker=True)
old_env = envs.make(old_env_id)
new_env = envs.make(new_env_id)
old_reset_obs, old_info = old_env.reset(seed=seed)
new_reset_obs, new_info = new_env.reset(seed=seed)

View File

@@ -106,7 +106,7 @@ class TestNestedDictWrapper:
observation_space = env.observation_space
assert isinstance(observation_space, Dict)
wrapped_env = FlattenObservation(FilterObservation(env, env.obs_keys))
wrapped_env = FlattenObservation(FilterObservation(env, list(env.obs_keys)))
assert wrapped_env.observation_space.shape == flat_shape
assert wrapped_env.observation_space.dtype == np.float32
@@ -114,7 +114,7 @@ class TestNestedDictWrapper:
@pytest.mark.parametrize("observation_space, flat_shape", NESTED_DICT_TEST_CASES)
def test_nested_dicts_ravel(self, observation_space, flat_shape):
env = FakeEnvironment(observation_space=observation_space)
wrapped_env = FlattenObservation(FilterObservation(env, env.obs_keys))
wrapped_env = FlattenObservation(FilterObservation(env, list(env.obs_keys)))
obs, info = wrapped_env.reset()
assert obs.shape == wrapped_env.observation_space.shape
assert isinstance(info, dict)

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import gymnasium as gym
def has_wrapper(wrapped_env: gym.Env, wrapper_type: type) -> bool:
def has_wrapper(wrapped_env: gym.Env, wrapper_type: type[gym.Wrapper]) -> bool:
while isinstance(wrapped_env, gym.Wrapper):
if isinstance(wrapped_env, wrapper_type):
return True