Update EnvSpec and make to support reproducing the "whole" environment spec including wrappers (#292)

Co-authored-by: will <will2346@live.co.uk>
Co-authored-by: Will Dudley <14932240+WillDudley@users.noreply.github.com>
Co-authored-by: Mark Towers <marktowers@Marks-MacBook-Pro.local>
This commit is contained in:
Mark Towers
2023-02-24 11:34:20 +00:00
committed by GitHub
parent 6f35e7f87f
commit 9e3200d000
47 changed files with 1178 additions and 319 deletions

View File

@@ -1,12 +1,13 @@
"""Core API for Environment, Wrapper, ActionWrapper, RewardWrapper and ObservationWrapper.""" """Core API for Environment, Wrapper, ActionWrapper, RewardWrapper and ObservationWrapper."""
from __future__ import annotations from __future__ import annotations
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Generic, SupportsFloat, TypeVar from typing import TYPE_CHECKING, Any, Generic, SupportsFloat, TypeVar
import numpy as np import numpy as np
from gymnasium import spaces from gymnasium import spaces
from gymnasium.utils import seeding from gymnasium.utils import RecordConstructorArgs, seeding
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -277,8 +278,32 @@ class Wrapper(
@property @property
def spec(self) -> EnvSpec | None: def spec(self) -> EnvSpec | None:
"""Returns the :attr:`Env` :attr:`spec` attribute.""" """Returns the :attr:`Env` :attr:`spec` attribute with the `WrapperSpec` if the wrapper inherits from `EzPickle`."""
return self.env.spec env_spec = self.env.spec
if env_spec is not None:
from gymnasium.envs.registration import WrapperSpec
# See if the wrapper inherits from `RecordConstructorArgs` then add the kwargs otherwise use `None` for the wrapper kwargs. This will raise an error in `make`
if isinstance(self, RecordConstructorArgs):
kwargs = getattr(self, "_saved_kwargs")
if "env" in kwargs:
kwargs = deepcopy(kwargs)
kwargs.pop("env")
else:
kwargs = None
wrapper_spec = WrapperSpec(
name=self.class_name(),
entry_point=f"{self.__module__}:{type(self).__name__}",
kwargs=kwargs,
)
# to avoid reference issues we deepcopy the prior environments spec and add the new information
env_spec = deepcopy(env_spec)
env_spec.applied_wrappers += (wrapper_spec,)
return env_spec
@classmethod @classmethod
def class_name(cls) -> str: def class_name(cls) -> str:
@@ -409,7 +434,7 @@ class ObservationWrapper(Wrapper[WrapperObsType, ActType, ObsType, ActType]):
def __init__(self, env: Env[ObsType, ActType]): def __init__(self, env: Env[ObsType, ActType]):
"""Constructor for the observation wrapper.""" """Constructor for the observation wrapper."""
super().__init__(env) Wrapper.__init__(self, env)
def reset( def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None self, *, seed: int | None = None, options: dict[str, Any] | None = None
@@ -449,7 +474,7 @@ class RewardWrapper(Wrapper[ObsType, ActType, ObsType, ActType]):
def __init__(self, env: Env[ObsType, ActType]): def __init__(self, env: Env[ObsType, ActType]):
"""Constructor for the Reward wrapper.""" """Constructor for the Reward wrapper."""
super().__init__(env) Wrapper.__init__(self, env)
def step( def step(
self, action: ActType self, action: ActType
@@ -485,7 +510,7 @@ class ActionWrapper(Wrapper[ObsType, WrapperActType, ObsType, ActType]):
def __init__(self, env: Env[ObsType, ActType]): def __init__(self, env: Env[ObsType, ActType]):
"""Constructor for the action wrapper.""" """Constructor for the action wrapper."""
super().__init__(env) Wrapper.__init__(self, env)
def step( def step(
self, action: WrapperActType self, action: WrapperActType

View File

@@ -3,9 +3,11 @@ from __future__ import annotations
import contextlib import contextlib
import copy import copy
import dataclasses
import difflib import difflib
import importlib import importlib
import importlib.util import importlib.util
import json
import re import re
import sys import sys
import traceback import traceback
@@ -17,13 +19,13 @@ from gymnasium import Env, Wrapper, error, logger
from gymnasium.experimental.vector import AsyncVectorEnv, SyncVectorEnv, VectorEnv from gymnasium.experimental.vector import AsyncVectorEnv, SyncVectorEnv, VectorEnv
from gymnasium.wrappers import ( from gymnasium.wrappers import (
AutoResetWrapper, AutoResetWrapper,
EnvCompatibility,
HumanRendering, HumanRendering,
OrderEnforcing, OrderEnforcing,
PassiveEnvChecker,
RenderCollection, RenderCollection,
TimeLimit, TimeLimit,
) )
from gymnasium.wrappers.compatibility import EnvCompatibility
from gymnasium.wrappers.env_checker import PassiveEnvChecker
if sys.version_info < (3, 10): if sys.version_info < (3, 10):
@@ -43,11 +45,14 @@ ENV_ID_RE = re.compile(
__all__ = [ __all__ = [
"EnvSpec",
"registry", "registry",
"current_namespace", "current_namespace",
"EnvSpec",
"WrapperSpec",
# Functions
"register", "register",
"make", "make",
"make_vec",
"spec", "spec",
"pprint_registry", "pprint_registry",
] ]
@@ -67,6 +72,20 @@ class VectorEnvCreator(Protocol):
... ...
@dataclass
class WrapperSpec:
    """A specification for recording wrapper configs.

    * name: The name of the wrapper.
    * entry_point: The location of the wrapper to create from.
    * kwargs: Additional keyword arguments passed to the wrapper. If the wrapper doesn't inherit from ``gymnasium.utils.RecordConstructorArgs`` then this is ``None``
    """

    name: str
    entry_point: str
    # ``None`` marks a wrapper whose constructor arguments were not recorded
    kwargs: dict[str, Any] | None
@dataclass @dataclass
class EnvSpec: class EnvSpec:
"""A specification for creating environments with :meth:`gymnasium.make`. """A specification for creating environments with :meth:`gymnasium.make`.
@@ -80,6 +99,7 @@ class EnvSpec:
* **autoreset**: If to automatically reset the environment on episode end * **autoreset**: If to automatically reset the environment on episode end
* **disable_env_checker**: If to disable the environment checker wrapper in :meth:`gymnasium.make`, by default False (runs the environment checker) * **disable_env_checker**: If to disable the environment checker wrapper in :meth:`gymnasium.make`, by default False (runs the environment checker)
* **kwargs**: Additional keyword arguments passed to the environment during initialisation * **kwargs**: Additional keyword arguments passed to the environment during initialisation
* **applied_wrappers**: A tuple of applied wrappers (WrapperSpec)
* **vector_entry_point**: The location of the vectorized environment to create from * **vector_entry_point**: The location of the vectorized environment to create from
""" """
@@ -97,27 +117,152 @@ class EnvSpec:
disable_env_checker: bool = field(default=False) disable_env_checker: bool = field(default=False)
apply_api_compatibility: bool = field(default=False) apply_api_compatibility: bool = field(default=False)
# Environment arguments
kwargs: dict = field(default_factory=dict)
# post-init attributes # post-init attributes
namespace: str | None = field(init=False) namespace: str | None = field(init=False)
name: str = field(init=False) name: str = field(init=False)
version: int | None = field(init=False) version: int | None = field(init=False)
# Environment arguments # applied wrappers
kwargs: dict = field(default_factory=dict) applied_wrappers: tuple[WrapperSpec, ...] = field(init=False, default_factory=tuple)
# Vectorized environment # Vectorized environment entry point
vector_entry_point: str | None = field(default=None) vector_entry_point: VectorEnvCreator | str | None = field(default=None)
def __post_init__(self): def __post_init__(self):
"""Calls after the spec is created to extract the namespace, name and version from the id.""" """Calls after the spec is created to extract the namespace, name and version from the environment id."""
# Initialize namespace, name, version
self.namespace, self.name, self.version = parse_env_id(self.id) self.namespace, self.name, self.version = parse_env_id(self.id)
def make(self, **kwargs: Any) -> Env: def make(self, **kwargs: Any) -> Env:
"""Calls ``make`` using the environment spec and any keyword arguments.""" """Calls ``make`` using the environment spec and any keyword arguments."""
# For compatibility purposes
return make(self, **kwargs) return make(self, **kwargs)
def to_json(self) -> str:
    """Serialises the environment spec to a JSON string.

    The ``namespace``, ``name`` and ``version`` attributes are left out of the output
    as they are derived from the ``id`` after initialisation.

    Returns:
        A JSON string describing the environment spec
    """
    spec_as_dict = dataclasses.asdict(self)

    # Drop the attributes that are filled in after `init` rather than passed to it
    for derived_attribute in ("namespace", "name", "version"):
        spec_as_dict.pop(derived_attribute)

    # Validate the remaining entries before handing them to the json encoder
    self._check_can_jsonify(spec_as_dict)

    return json.dumps(spec_as_dict)
@staticmethod
def _check_can_jsonify(env_spec: dict[str, Any]):
"""Warns the user about serialisation failing if the spec contains a callable.
Args:
env_spec: An environment or wrapper specification.
Returns: The specification with lambda functions converted to strings.
"""
spec_name = env_spec["name"] if "name" in env_spec else env_spec["id"]
for key, value in env_spec.items():
if callable(value):
ValueError(
f"Callable found in {spec_name} for {key} attribute with value={value}. Currently, Gymnasium does not support serialising callables."
)
@staticmethod
def from_json(json_env_spec: str) -> EnvSpec:
    """Builds an environment spec, including its applied wrapper specs, from a JSON string.

    Args:
        json_env_spec: A JSON string representing the env specification.

    Returns:
        An environment spec

    Raises:
        ValueError: If a wrapper entry or the remaining data cannot be turned into a ``WrapperSpec`` / ``EnvSpec``
    """
    spec_fields = json.loads(json_env_spec)

    # `applied_wrappers` is not an `init` argument of `EnvSpec` so it is taken out and rebuilt separately
    wrapper_specs: list[WrapperSpec] = []
    for wrapper_fields in spec_fields.pop("applied_wrappers"):
        try:
            wrapper_spec = WrapperSpec(**wrapper_fields)
        except Exception as e:
            raise ValueError(
                f"An issue occurred when trying to make {wrapper_fields} a WrapperSpec"
            ) from e
        wrapper_specs.append(wrapper_spec)

    try:
        env_spec = EnvSpec(**spec_fields)
        env_spec.applied_wrappers = tuple(wrapper_specs)
    except Exception as e:
        raise ValueError(
            f"An issue occurred when trying to make {spec_fields} an EnvSpec"
        ) from e

    return env_spec
def pprint(
    self,
    disable_print: bool = False,
    include_entry_points: bool = False,
    print_all: bool = False,
) -> str | None:
    """Pretty prints the environment spec, one attribute per line.

    Apart from the ``id``, an attribute is only listed when it differs from the value
    tested for below, unless ``print_all`` is set.

    Args:
        disable_print: If to disable print and return the output
        include_entry_points: If to include the entry_points in the output
        print_all: If to print all information, including variables with default values

    Returns:
        If ``disable_print is True`` a string otherwise ``None``
    """
    parts = [f"id={self.id}"]

    if print_all or include_entry_points:
        parts.append(f"\nentry_point={self.entry_point}")

    if print_all or self.reward_threshold is not None:
        parts.append(f"\nreward_threshold={self.reward_threshold}")
    if print_all or self.nondeterministic is not False:
        parts.append(f"\nnondeterministic={self.nondeterministic}")

    if print_all or self.max_episode_steps is not None:
        parts.append(f"\nmax_episode_steps={self.max_episode_steps}")
    if print_all or self.order_enforce is not True:
        parts.append(f"\norder_enforce={self.order_enforce}")
    if print_all or self.autoreset is not False:
        parts.append(f"\nautoreset={self.autoreset}")
    if print_all or self.disable_env_checker is not False:
        parts.append(f"\ndisable_env_checker={self.disable_env_checker}")
    if print_all or self.apply_api_compatibility is not False:
        parts.append(f"\napplied_api_compatibility={self.apply_api_compatibility}")

    if print_all or self.applied_wrappers:
        # One entry per wrapper, the wrapper entry point is only shown on request
        wrapper_entries = [
            (
                f"\n\tname={wrapper_spec.name}, entry_point={wrapper_spec.entry_point}, kwargs={wrapper_spec.kwargs}"
                if include_entry_points
                else f"\n\tname={wrapper_spec.name}, kwargs={wrapper_spec.kwargs}"
            )
            for wrapper_spec in self.applied_wrappers
        ]

        if wrapper_entries:
            parts.append(f"\napplied_wrappers=[{','.join(wrapper_entries)}\n]")
        else:
            parts.append("\napplied_wrappers=[]")

    output = "".join(parts)
    if disable_print:
        return output
    print(output)
    return None
# Global registry of environments. Meant to be accessed through `register` and `make` # Global registry of environments. Meant to be accessed through `register` and `make`
registry: dict[str, EnvSpec] = {} registry: dict[str, EnvSpec] = {}
@@ -352,8 +497,12 @@ def _check_metadata(testing_metadata: dict[str, Any]):
) )
def _find_spec(id: str) -> EnvSpec: def _find_spec(env_id: str) -> EnvSpec:
module, env_name = (None, id) if ":" not in id else id.split(":") # For string id's, load the environment spec from the registry then make the environment spec
assert isinstance(env_id, str)
# The environment name can include an unloaded module in "module:env_name" style
module, env_name = (None, env_id) if ":" not in env_id else env_id.split(":")
if module is not None: if module is not None:
try: try:
importlib.import_module(module) importlib.import_module(module)
@@ -391,7 +540,7 @@ def _find_spec(id: str) -> EnvSpec:
return env_spec return env_spec
def load_env(name: str) -> EnvCreator: def load_env_creator(name: str) -> EnvCreator | VectorEnvCreator:
"""Loads an environment with name of style ``"(import path):(environment name)"`` and returns the environment creation function, normally the environment class type. """Loads an environment with name of style ``"(import path):(environment name)"`` and returns the environment creation function, normally the environment class type.
Args: Args:
@@ -406,6 +555,161 @@ def load_env(name: str) -> EnvCreator:
return fn return fn
def _create_from_env_spec(
env_spec: EnvSpec,
kwargs: dict[str, Any],
) -> Env:
"""Recreates an environment spec using a list of wrapper specs."""
if callable(env_spec.entry_point):
env_creator = env_spec.entry_point
else:
env_creator: EnvCreator = load_env_creator(env_spec.entry_point)
# Create the environment
env: Env = env_creator(**env_spec.kwargs, **kwargs)
# Set the `EnvSpec` to the environment
new_env_spec = copy.deepcopy(env_spec)
new_env_spec.applied_wrappers = ()
new_env_spec.kwargs.update(kwargs)
env.unwrapped.spec = new_env_spec
# Check if the environment spec
assert env.spec is not None # this is for pyright
num_prior_wrappers = len(env.spec.applied_wrappers)
if env_spec.applied_wrappers[:num_prior_wrappers] != env.spec.applied_wrappers:
for env_spec_wrapper_spec, recreated_wrapper_spec in zip(
env_spec.applied_wrappers, env.spec.applied_wrappers
):
raise ValueError(
f"The environment's wrapper spec {recreated_wrapper_spec} is different from the saved `EnvSpec` applied_wrappers {env_spec_wrapper_spec}"
)
for wrapper_spec in env_spec.applied_wrappers[num_prior_wrappers:]:
if wrapper_spec.kwargs is None:
raise ValueError(
f"{wrapper_spec.name} wrapper does not inherit from `gymnasium.utils.RecordConstructorArgs`, therefore, the wrapper cannot be recreated."
)
env = load_env_creator(wrapper_spec.entry_point)(env=env, **wrapper_spec.kwargs)
return env
def _create_from_env_id(
    env_spec: EnvSpec,
    kwargs: dict[str, Any],
    max_episode_steps: int | None = None,
    autoreset: bool = False,
    apply_api_compatibility: bool | None = None,
    disable_env_checker: bool | None = None,
) -> Env:
    """Creates an environment based on the `env_spec` along with wrapper options. See `make` for their meaning.

    Args:
        env_spec: The environment specification to create the environment from
        kwargs: Additional keyword arguments for the environment creator, overriding the spec's kwargs
        max_episode_steps: If not ``None``, the time limit to apply instead of the spec's ``max_episode_steps``
        autoreset: If to apply the auto-reset wrapper
        apply_api_compatibility: If to apply the compatibility wrapper, ``None`` defers to the spec
        disable_env_checker: If to skip the passive environment checker, ``None`` defers to the spec

    Returns:
        The created environment with the selected wrappers applied

    Raises:
        Error: If the spec has no entry point, or human rendering is requested for an environment that
            rejects the ``render_mode`` keyword argument
    """
    # Keyword arguments given to `make` take precedence over the ones stored in the spec
    creator_kwargs = copy.deepcopy(env_spec.kwargs)
    creator_kwargs.update(kwargs)

    # Load the environment creator
    if env_spec.entry_point is None:
        raise error.Error(f"{env_spec.id} registered but entry_point is not specified")
    if callable(env_spec.entry_point):
        env_creator = env_spec.entry_point
    else:
        # Assume it's a string
        env_creator = load_env_creator(env_spec.entry_point)

    # Collect the render modes that the environment declares, if any
    render_modes: list[str] | None = None
    if hasattr(env_creator, "metadata"):
        _check_metadata(env_creator.metadata)
        render_modes = env_creator.metadata.get("render_modes")

    requested_mode = creator_kwargs.get("render_mode")
    apply_human_rendering = False
    apply_render_collection = False

    # If the requested mode is not declared, try applying HumanRendering/RenderCollection wrappers
    if (
        requested_mode is not None
        and render_modes is not None
        and requested_mode not in render_modes
    ):
        displayable_modes = {"rgb_array", "rgb_array_list"}.intersection(render_modes)
        if requested_mode == "human" and displayable_modes:
            logger.warn(
                "You are trying to use 'human' rendering for an environment that doesn't natively support it. "
                "The HumanRendering wrapper is being applied to your environment."
            )
            creator_kwargs["render_mode"] = displayable_modes.pop()
            apply_human_rendering = True
        elif (
            requested_mode.endswith("_list")
            and requested_mode[: -len("_list")] in render_modes
        ):
            creator_kwargs["render_mode"] = requested_mode[: -len("_list")]
            apply_render_collection = True
        else:
            logger.warn(
                f"The environment is being initialised with render_mode={requested_mode!r} "
                f"that is not in the possible render_modes ({render_modes})."
            )

    # With the compatibility layer, the render mode is handed to the wrapper and not to the env creator
    render_mode = None
    if apply_api_compatibility or (
        apply_api_compatibility is None and env_spec.apply_api_compatibility
    ):
        render_mode = creator_kwargs.pop("render_mode", None)

    try:
        env = env_creator(**creator_kwargs)
    except TypeError as e:
        rejected_render_mode = (
            "got an unexpected keyword argument 'render_mode'" in str(e)
        )
        if rejected_render_mode and apply_human_rendering:
            raise error.Error(
                f"You passed render_mode='human' although {env_spec.id} doesn't implement human-rendering natively. "
                "Gym tried to apply the HumanRendering wrapper but it looks like your environment is using the old "
                "rendering API, which is not supported by the HumanRendering wrapper."
            ) from e
        raise

    # Record a copy of the spec, with the kwargs actually used, on the created environment
    recorded_spec = copy.deepcopy(env_spec)
    recorded_spec.kwargs = creator_kwargs
    env.unwrapped.spec = recorded_spec

    # Add step API wrapper
    if apply_api_compatibility is True or (
        apply_api_compatibility is None
        and recorded_spec.apply_api_compatibility is True
    ):
        env = EnvCompatibility(env, render_mode)

    # Run the environment checker, applied before the remaining wrappers below
    if disable_env_checker is False or (
        disable_env_checker is None and recorded_spec.disable_env_checker is False
    ):
        env = PassiveEnvChecker(env)

    # Add the order enforcing wrapper
    if recorded_spec.order_enforce:
        env = OrderEnforcing(env)

    # Add the time limit wrapper, an explicit `max_episode_steps` is also written to the environment's spec
    if max_episode_steps is not None:
        assert env.unwrapped.spec is not None  # for pyright
        env.unwrapped.spec.max_episode_steps = max_episode_steps
        env = TimeLimit(env, max_episode_steps)
    elif recorded_spec.max_episode_steps is not None:
        env = TimeLimit(env, recorded_spec.max_episode_steps)

    # Add the auto-reset wrapper
    if autoreset:
        env = AutoResetWrapper(env)

    # Add human rendering wrapper, or the render collection wrapper for "_list" modes
    if apply_human_rendering:
        env = HumanRendering(env)
    elif apply_render_collection:
        env = RenderCollection(env)

    return env
def load_plugin_envs(entry_point: str = "gymnasium.envs"): def load_plugin_envs(entry_point: str = "gymnasium.envs"):
"""Load modules (plugins) using the gymnasium entry points in order to register external module's environments on ``import gymnasium``. """Load modules (plugins) using the gymnasium entry points in order to register external module's environments on ``import gymnasium``.
@@ -437,10 +741,8 @@ def load_plugin_envs(entry_point: str = "gymnasium.envs"):
context = namespace(plugin.name) context = namespace(plugin.name)
if plugin.name.startswith("__") and plugin.name.endswith("__"): if plugin.name.startswith("__") and plugin.name.endswith("__"):
# `__internal__` is an artifact of the plugin system when # `__internal__` is an artifact of the plugin system when the root namespace had an allow-list.
# the root namespace had an allow-list. The allow-list is now # The allow-list is now removed and plugins can register environments in the root namespace with the `__root__` magic key.
# removed and plugins can register environments in the root
# namespace with the `__root__` magic key.
if plugin.name == "__root__" or plugin.name == "__internal__": if plugin.name == "__root__" or plugin.name == "__internal__":
context = contextlib.nullcontext() context = contextlib.nullcontext()
else: else:
@@ -533,9 +835,9 @@ def register(
order_enforce=order_enforce, order_enforce=order_enforce,
autoreset=autoreset, autoreset=autoreset,
disable_env_checker=disable_env_checker, disable_env_checker=disable_env_checker,
**kwargs,
apply_api_compatibility=apply_api_compatibility, apply_api_compatibility=apply_api_compatibility,
vector_entry_point=vector_entry_point, vector_entry_point=vector_entry_point,
**kwargs,
) )
_check_spec_register(new_spec) _check_spec_register(new_spec)
@@ -576,116 +878,47 @@ def make(
Error: If the ``id`` doesn't exist in the :attr:`registry` Error: If the ``id`` doesn't exist in the :attr:`registry`
""" """
if isinstance(id, EnvSpec): if isinstance(id, EnvSpec):
env_spec = id if hasattr(id, "applied_wrappers") and id.applied_wrappers is not None:
if max_episode_steps is not None:
logger.warn(
f"For `gymnasium.make` with an `EnvSpec`, the `max_episode_step` parameter is ignored, use `gym.make({id.id}, max_episode_steps={max_episode_steps})` and any additional wrappers"
)
if autoreset is True:
logger.warn(
f"For `gymnasium.make` with an `EnvSpec`, the `autoreset` parameter is ignored, use `gym.make({id.id}, autoreset={autoreset})` and any additional wrappers"
)
if apply_api_compatibility is not None:
logger.warn(
f"For `gymnasium.make` with an `EnvSpec`, the `apply_api_compatibility` parameter is ignored, use `gym.make({id.id}, apply_api_compatibility={apply_api_compatibility})` and any additional wrappers"
)
if disable_env_checker is not None:
logger.warn(
f"For `gymnasium.make` with an `EnvSpec`, the `disable_env_checker` parameter is ignored, use `gym.make({id.id}, disable_env_checker={disable_env_checker})` and any additional wrappers"
)
return _create_from_env_spec(
id,
kwargs,
)
else:
raise ValueError(
f"The EnvSpec used does not contain `applied_wrappers` parameters or is `None`. Expected to be a tuple, actually {id}."
)
else: else:
# For string id's, load the environment spec from the registry then make the environment spec
assert isinstance(id, str)
# The environment name can include an unloaded module in "module:env_name" style # The environment name can include an unloaded module in "module:env_name" style
env_spec = _find_spec(id) env_spec = _find_spec(id)
assert isinstance( return _create_from_env_id(
env_spec, EnvSpec env_spec,
), f"We expected to collect an `EnvSpec`, actually collected a {type(env_spec)}" kwargs,
# Extract the spec kwargs and append the make kwargs max_episode_steps=max_episode_steps,
spec_kwargs = env_spec.kwargs.copy() autoreset=autoreset,
spec_kwargs.update(kwargs) apply_api_compatibility=apply_api_compatibility,
disable_env_checker=disable_env_checker,
# Load the environment creator )
if env_spec.entry_point is None:
raise error.Error(f"{env_spec.id} registered but entry_point is not specified")
elif callable(env_spec.entry_point):
env_creator = env_spec.entry_point
else:
# Assume it's a string
env_creator = load_env(env_spec.entry_point)
# Determine if to use the rendering
render_modes: list[str] | None = None
if hasattr(env_creator, "metadata"):
_check_metadata(env_creator.metadata)
render_modes = env_creator.metadata.get("render_modes")
mode = spec_kwargs.get("render_mode")
apply_human_rendering = False
apply_render_collection = False
# If mode is not valid, try applying HumanRendering/RenderCollection wrappers
if mode is not None and render_modes is not None and mode not in render_modes:
displayable_modes = {"rgb_array", "rgb_array_list"}.intersection(render_modes)
if mode == "human" and len(displayable_modes) > 0:
logger.warn(
"You are trying to use 'human' rendering for an environment that doesn't natively support it. "
"The HumanRendering wrapper is being applied to your environment."
)
spec_kwargs["render_mode"] = displayable_modes.pop()
apply_human_rendering = True
elif mode.endswith("_list") and mode[: -len("_list")] in render_modes:
spec_kwargs["render_mode"] = mode[: -len("_list")]
apply_render_collection = True
else:
logger.warn(
f"The environment is being initialised with render_mode={mode!r} "
f"that is not in the possible render_modes ({render_modes})."
)
if apply_api_compatibility or (
apply_api_compatibility is None and env_spec.apply_api_compatibility
):
# If we use the compatibility layer, we treat the render mode explicitly and don't pass it to the env creator
render_mode = spec_kwargs.pop("render_mode", None)
else:
render_mode = None
try:
env = env_creator(**spec_kwargs)
except TypeError as e:
if (
str(e).find("got an unexpected keyword argument 'render_mode'") >= 0
and apply_human_rendering
):
raise error.Error(
f"You passed render_mode='human' although {id} doesn't implement human-rendering natively. "
"Gym tried to apply the HumanRendering wrapper but it looks like your environment is using the old "
"rendering API, which is not supported by the HumanRendering wrapper."
) from e
else:
raise e
# Copies the environment creation specification and kwargs to add to the environment specification details
env_spec = copy.deepcopy(env_spec)
env_spec.kwargs = spec_kwargs
env.unwrapped.spec = env_spec
# Add step API wrapper
if apply_api_compatibility is True or (
apply_api_compatibility is None and env_spec.apply_api_compatibility is True
):
env = EnvCompatibility(env, render_mode)
# Run the environment checker as the lowest level wrapper
if disable_env_checker is False or (
disable_env_checker is None and env_spec.disable_env_checker is False
):
env = PassiveEnvChecker(env)
# Add the order enforcing wrapper
if env_spec.order_enforce:
env = OrderEnforcing(env)
# Add the time limit wrapper
if max_episode_steps is not None:
env = TimeLimit(env, max_episode_steps)
elif env_spec.max_episode_steps is not None:
env = TimeLimit(env, env_spec.max_episode_steps)
# Add the autoreset wrapper
if autoreset:
env = AutoResetWrapper(env)
# Add human rendering wrapper
if apply_human_rendering:
env = HumanRendering(env)
elif apply_render_collection:
env = RenderCollection(env)
return env
def make_vec( def make_vec(
@@ -752,7 +985,7 @@ def make_vec(
env_creator = entry_point env_creator = entry_point
else: else:
# Assume it's a string # Assume it's a string
env_creator = load_env(entry_point) env_creator = load_env_creator(entry_point)
def _create_env(): def _create_env():
# Env creator for use with sync and async modes # Env creator for use with sync and async modes

View File

@@ -11,7 +11,7 @@ except ImportError:
cv2 = None cv2 = None
class AtariPreprocessingV0(gym.Wrapper): class AtariPreprocessingV0(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""Atari 2600 preprocessing wrapper. """Atari 2600 preprocessing wrapper.
This class follows the guidelines in Machado et al. (2018), This class follows the guidelines in Machado et al. (2018),
@@ -60,7 +60,18 @@ class AtariPreprocessingV0(gym.Wrapper):
DependencyNotInstalled: opencv-python package not installed DependencyNotInstalled: opencv-python package not installed
ValueError: Disable frame-skipping in the original env ValueError: Disable frame-skipping in the original env
""" """
super().__init__(env) gym.utils.RecordConstructorArgs.__init__(
self,
noop_max=noop_max,
frame_skip=frame_skip,
screen_size=screen_size,
terminal_on_life_loss=terminal_on_life_loss,
grayscale_obs=grayscale_obs,
grayscale_newaxis=grayscale_newaxis,
scale_obs=scale_obs,
)
gym.Wrapper.__init__(self, env)
if cv2 is None: if cv2 is None:
raise gym.error.DependencyNotInstalled( raise gym.error.DependencyNotInstalled(
"opencv-python package not installed, run `pip install gymnasium[other]` to get dependencies for atari" "opencv-python package not installed, run `pip install gymnasium[other]` to get dependencies for atari"

View File

@@ -14,7 +14,6 @@ from typing import Any, SupportsFloat
import numpy as np import numpy as np
import gymnasium as gym import gymnasium as gym
from gymnasium import Env
from gymnasium.core import ActType, ObsType, RenderFrame from gymnasium.core import ActType, ObsType, RenderFrame
from gymnasium.error import ResetNeeded from gymnasium.error import ResetNeeded
from gymnasium.utils.passive_env_checker import ( from gymnasium.utils.passive_env_checker import (
@@ -26,7 +25,9 @@ from gymnasium.utils.passive_env_checker import (
) )
class AutoresetV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]): class AutoresetV0(
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
"""A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`.""" """A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`."""
def __init__(self, env: gym.Env[ObsType, ActType]): def __init__(self, env: gym.Env[ObsType, ActType]):
@@ -35,7 +36,9 @@ class AutoresetV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
Args: Args:
env (gym.Env): The environment to apply the wrapper env (gym.Env): The environment to apply the wrapper
""" """
super().__init__(env) gym.utils.RecordConstructorArgs.__init__(self)
gym.Wrapper.__init__(self, env)
self._episode_ended: bool = False self._episode_ended: bool = False
self._reset_options: dict[str, Any] | None = None self._reset_options: dict[str, Any] | None = None
@@ -68,12 +71,15 @@ class AutoresetV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
return super().reset(seed=seed, options=self._reset_options) return super().reset(seed=seed, options=self._reset_options)
class PassiveEnvCheckerV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]): class PassiveEnvCheckerV0(
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
"""A passive environment checker wrapper that surrounds the step, reset and render functions to check they follow the gymnasium API.""" """A passive environment checker wrapper that surrounds the step, reset and render functions to check they follow the gymnasium API."""
def __init__(self, env: Env[ObsType, ActType]): def __init__(self, env: gym.Env[ObsType, ActType]):
"""Initialises the wrapper with the environments, run the observation and action space tests.""" """Initialises the wrapper with the environments, run the observation and action space tests."""
super().__init__(env) gym.utils.RecordConstructorArgs.__init__(self)
gym.Wrapper.__init__(self, env)
assert hasattr( assert hasattr(
env, "action_space" env, "action_space"
@@ -117,7 +123,9 @@ class PassiveEnvCheckerV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
return self.env.render() return self.env.render()
class OrderEnforcingV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]): class OrderEnforcingV0(
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
"""A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`. """A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`.
Example: Example:
@@ -150,7 +158,11 @@ class OrderEnforcingV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
env: The environment to wrap env: The environment to wrap
disable_render_order_enforcing: If to disable render order enforcing disable_render_order_enforcing: If to disable render order enforcing
""" """
super().__init__(env) gym.utils.RecordConstructorArgs.__init__(
self, disable_render_order_enforcing=disable_render_order_enforcing
)
gym.Wrapper.__init__(self, env)
self._has_reset: bool = False self._has_reset: bool = False
self._disable_render_order_enforcing: bool = disable_render_order_enforcing self._disable_render_order_enforcing: bool = disable_render_order_enforcing
@@ -182,7 +194,9 @@ class OrderEnforcingV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
return self._has_reset return self._has_reset
class RecordEpisodeStatisticsV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]): class RecordEpisodeStatisticsV0(
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
"""This wrapper will keep track of cumulative rewards and episode lengths. """This wrapper will keep track of cumulative rewards and episode lengths.
At the end of an episode, the statistics of the episode will be added to ``info`` At the end of an episode, the statistics of the episode will be added to ``info``
@@ -226,7 +240,7 @@ class RecordEpisodeStatisticsV0(gym.Wrapper[ObsType, ActType, ObsType, ActType])
def __init__( def __init__(
self, self,
env: Env[ObsType, ActType], env: gym.Env[ObsType, ActType],
buffer_length: int | None = 100, buffer_length: int | None = 100,
stats_key: str = "episode", stats_key: str = "episode",
): ):
@@ -237,7 +251,8 @@ class RecordEpisodeStatisticsV0(gym.Wrapper[ObsType, ActType, ObsType, ActType])
buffer_length: The size of the buffers :attr:`return_queue` and :attr:`length_queue` buffer_length: The size of the buffers :attr:`return_queue` and :attr:`length_queue`
stats_key: The info key for the episode statistics stats_key: The info key for the episode statistics
""" """
super().__init__(env) gym.utils.RecordConstructorArgs.__init__(self)
gym.Wrapper.__init__(self, env)
self._stats_key = stats_key self._stats_key = stats_key

View File

@@ -8,7 +8,7 @@ from typing import Any, Iterable, Mapping, SupportsFloat
import numpy as np import numpy as np
from gymnasium import Env, Wrapper import gymnasium as gym
from gymnasium.core import ActType, ObsType, RenderFrame, WrapperActType, WrapperObsType from gymnasium.core import ActType, ObsType, RenderFrame, WrapperActType, WrapperObsType
from gymnasium.error import DependencyNotInstalled from gymnasium.error import DependencyNotInstalled
@@ -92,7 +92,10 @@ if jnp is not None:
return type(value)(jax_to_numpy(v) for v in value) return type(value)(jax_to_numpy(v) for v in value)
class JaxToNumpyV0(Wrapper[WrapperObsType, WrapperActType, ObsType, ActType]): class JaxToNumpyV0(
gym.Wrapper[WrapperObsType, WrapperActType, ObsType, ActType],
gym.utils.RecordConstructorArgs,
):
"""Wraps a jax environment so that it can be interacted with through numpy arrays. """Wraps a jax environment so that it can be interacted with through numpy arrays.
Actions must be provided as numpy arrays and observations will be returned as numpy arrays. Actions must be provided as numpy arrays and observations will be returned as numpy arrays.
@@ -102,7 +105,7 @@ class JaxToNumpyV0(Wrapper[WrapperObsType, WrapperActType, ObsType, ActType]):
The reason for this is jax does not support non-array values, therefore numpy ``int_32(5) -> DeviceArray([5], dtype=jnp.int32)`` The reason for this is jax does not support non-array values, therefore numpy ``int_32(5) -> DeviceArray([5], dtype=jnp.int32)``
""" """
def __init__(self, env: Env[ObsType, ActType]): def __init__(self, env: gym.Env[ObsType, ActType]):
"""Wraps an environment such that the input and outputs are numpy arrays. """Wraps an environment such that the input and outputs are numpy arrays.
Args: Args:
@@ -112,7 +115,8 @@ class JaxToNumpyV0(Wrapper[WrapperObsType, WrapperActType, ObsType, ActType]):
raise DependencyNotInstalled( raise DependencyNotInstalled(
"jax is not installed, run `pip install gymnasium[jax]`" "jax is not installed, run `pip install gymnasium[jax]`"
) )
super().__init__(env) gym.utils.RecordConstructorArgs.__init__(self)
gym.Wrapper.__init__(self, env)
def step( def step(
self, action: WrapperActType self, action: WrapperActType

View File

@@ -14,7 +14,7 @@ import numbers
from collections import abc from collections import abc
from typing import Any, Iterable, Mapping, SupportsFloat, Union from typing import Any, Iterable, Mapping, SupportsFloat, Union
from gymnasium import Env, Wrapper import gymnasium as gym
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
from gymnasium.error import DependencyNotInstalled from gymnasium.error import DependencyNotInstalled
from gymnasium.experimental.wrappers.conversion.jax_to_numpy import jax_to_numpy from gymnasium.experimental.wrappers.conversion.jax_to_numpy import jax_to_numpy
@@ -131,7 +131,7 @@ if torch is not None and jnp is not None:
return type(value)(jax_to_torch(v, device) for v in value) return type(value)(jax_to_torch(v, device) for v in value)
class JaxToTorchV0(Wrapper): class JaxToTorchV0(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""Wraps a jax-based environment so that it can be interacted with through PyTorch Tensors. """Wraps a jax-based environment so that it can be interacted with through PyTorch Tensors.
Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors. Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors.
@@ -140,7 +140,7 @@ class JaxToTorchV0(Wrapper):
For ``rendered`` this is returned as a NumPy array not a pytorch Tensor. For ``rendered`` this is returned as a NumPy array not a pytorch Tensor.
""" """
def __init__(self, env: Env, device: Device | None = None): def __init__(self, env: gym.Env, device: Device | None = None):
"""Wrapper class to change inputs and outputs of environment to PyTorch tensors. """Wrapper class to change inputs and outputs of environment to PyTorch tensors.
Args: Args:
@@ -156,7 +156,9 @@ class JaxToTorchV0(Wrapper):
"jax is not installed, run `pip install gymnasium[jax]`" "jax is not installed, run `pip install gymnasium[jax]`"
) )
super().__init__(env) gym.utils.RecordConstructorArgs.__init__(self, device=device)
gym.Wrapper.__init__(self, env)
self.device: Device | None = device self.device: Device | None = device
def step( def step(

View File

@@ -8,7 +8,7 @@ from typing import Any, Iterable, Mapping, SupportsFloat, Union
import numpy as np import numpy as np
from gymnasium import Env, Wrapper import gymnasium as gym
from gymnasium.core import WrapperActType, WrapperObsType from gymnasium.core import WrapperActType, WrapperObsType
from gymnasium.error import DependencyNotInstalled from gymnasium.error import DependencyNotInstalled
@@ -94,7 +94,7 @@ if torch is not None:
return type(value)(numpy_to_torch(v, device) for v in value) return type(value)(numpy_to_torch(v, device) for v in value)
class NumpyToTorchV0(Wrapper): class NumpyToTorchV0(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""Wraps a numpy-based environment so that it can be interacted with through PyTorch Tensors. """Wraps a numpy-based environment so that it can be interacted with through PyTorch Tensors.
Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors. Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors.
@@ -103,7 +103,7 @@ class NumpyToTorchV0(Wrapper):
For ``rendered`` this is returned as a NumPy array not a pytorch Tensor. For ``rendered`` this is returned as a NumPy array not a pytorch Tensor.
""" """
def __init__(self, env: Env, device: Device | None = None): def __init__(self, env: gym.Env, device: Device | None = None):
"""Wrapper class to change inputs and outputs of environment to PyTorch tensors. """Wrapper class to change inputs and outputs of environment to PyTorch tensors.
Args: Args:
@@ -115,7 +115,9 @@ class NumpyToTorchV0(Wrapper):
"torch is not installed, run `pip install torch`" "torch is not installed, run `pip install torch`"
) )
super().__init__(env) gym.utils.RecordConstructorArgs.__init__(self, device=device)
gym.Wrapper.__init__(self, env)
self.device: Device | None = device self.device: Device | None = device
def step( def step(

View File

@@ -20,7 +20,9 @@ from gymnasium.core import ActType, ObsType, WrapperActType
from gymnasium.spaces import Box, Space from gymnasium.spaces import Box, Space
class LambdaActionV0(gym.ActionWrapper[ObsType, WrapperActType, ActType]): class LambdaActionV0(
gym.ActionWrapper[ObsType, WrapperActType, ActType], gym.utils.RecordConstructorArgs
):
"""A wrapper that provides a function to modify the action passed to :meth:`step`.""" """A wrapper that provides a function to modify the action passed to :meth:`step`."""
def __init__( def __init__(
@@ -36,7 +38,11 @@ class LambdaActionV0(gym.ActionWrapper[ObsType, WrapperActType, ActType]):
func: Function to apply to ``step`` ``action`` func: Function to apply to ``step`` ``action``
action_space: The updated action space of the wrapper given the function. action_space: The updated action space of the wrapper given the function.
""" """
super().__init__(env) gym.utils.RecordConstructorArgs.__init__(
self, func=func, action_space=action_space
)
gym.Wrapper.__init__(self, env)
if action_space is not None: if action_space is not None:
self.action_space = action_space self.action_space = action_space
@@ -47,7 +53,9 @@ class LambdaActionV0(gym.ActionWrapper[ObsType, WrapperActType, ActType]):
return self.func(action) return self.func(action)
class ClipActionV0(LambdaActionV0[ObsType, WrapperActType, ActType]): class ClipActionV0(
LambdaActionV0[ObsType, WrapperActType, ActType], gym.utils.RecordConstructorArgs
):
"""Clip the continuous action within the valid :class:`Box` action space bound. """Clip the continuous action within the valid :class:`Box` action space bound.
Example: Example:
@@ -71,10 +79,14 @@ class ClipActionV0(LambdaActionV0[ObsType, WrapperActType, ActType]):
""" """
assert isinstance(env.action_space, Box) assert isinstance(env.action_space, Box)
super().__init__( gym.utils.RecordConstructorArgs.__init__(self)
env, LambdaActionV0.__init__(
lambda action: jp.clip(action, env.action_space.low, env.action_space.high), self,
Box( env=env,
func=lambda action: jp.clip(
action, env.action_space.low, env.action_space.high
),
action_space=Box(
-np.inf, -np.inf,
np.inf, np.inf,
shape=env.action_space.shape, shape=env.action_space.shape,
@@ -83,7 +95,9 @@ class ClipActionV0(LambdaActionV0[ObsType, WrapperActType, ActType]):
) )
class RescaleActionV0(LambdaActionV0[ObsType, WrapperActType, ActType]): class RescaleActionV0(
LambdaActionV0[ObsType, WrapperActType, ActType], gym.utils.RecordConstructorArgs
):
"""Affinely rescales the continuous action space of the environment to the range [min_action, max_action]. """Affinely rescales the continuous action space of the environment to the range [min_action, max_action].
The base environment :attr:`env` must have an action space of type :class:`spaces.Box`. If :attr:`min_action` The base environment :attr:`env` must have an action space of type :class:`spaces.Box`. If :attr:`min_action`
@@ -118,6 +132,10 @@ class RescaleActionV0(LambdaActionV0[ObsType, WrapperActType, ActType]):
min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar. min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar.
max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar. max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar.
""" """
gym.utils.RecordConstructorArgs.__init__(
self, min_action=min_action, max_action=max_action
)
assert isinstance(env.action_space, Box) assert isinstance(env.action_space, Box)
assert not np.any(env.action_space.low == np.inf) and not np.any( assert not np.any(env.action_space.low == np.inf) and not np.any(
env.action_space.high == np.inf env.action_space.high == np.inf
@@ -149,10 +167,11 @@ class RescaleActionV0(LambdaActionV0[ObsType, WrapperActType, ActType]):
) )
intercept = gradient * -min_action + env.action_space.low intercept = gradient * -min_action + env.action_space.low
super().__init__( LambdaActionV0.__init__(
env, self,
lambda action: gradient * action + intercept, env=env,
Box( func=lambda action: gradient * action + intercept,
action_space=Box(
low=min_action, low=min_action,
high=max_action, high=max_action,
shape=env.action_space.shape, shape=env.action_space.shape,

View File

@@ -24,14 +24,16 @@ except ImportError as e:
import numpy as np import numpy as np
import gymnasium as gym import gymnasium as gym
from gymnasium import Env, spaces from gymnasium import spaces
from gymnasium.core import ActType, ObservationWrapper, ObsType, WrapperObsType from gymnasium.core import ActType, ObsType, WrapperObsType
from gymnasium.error import DependencyNotInstalled from gymnasium.error import DependencyNotInstalled
from gymnasium.experimental.wrappers.utils import RunningMeanStd from gymnasium.experimental.wrappers.utils import RunningMeanStd
from gymnasium.spaces import Box, Dict, utils
class LambdaObservationV0(gym.ObservationWrapper[WrapperObsType, ActType, ObsType]): class LambdaObservationV0(
gym.ObservationWrapper[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""Transforms an observation via a function provided to the wrapper. """Transforms an observation via a function provided to the wrapper.
The function :attr:`func` will be applied to all observations. The function :attr:`func` will be applied to all observations.
@@ -61,7 +63,11 @@ class LambdaObservationV0(gym.ObservationWrapper[WrapperObsType, ActType, ObsTyp
func: A function that will transform an observation. If this transformed observation is outside the observation space of `env.observation_space` then provide an `observation_space`. func: A function that will transform an observation. If this transformed observation is outside the observation space of `env.observation_space` then provide an `observation_space`.
observation_space: The observation spaces of the wrapper, if None, then it is assumed the same as `env.observation_space`. observation_space: The observation spaces of the wrapper, if None, then it is assumed the same as `env.observation_space`.
""" """
super().__init__(env) gym.utils.RecordConstructorArgs.__init__(
self, func=func, observation_space=observation_space
)
gym.ObservationWrapper.__init__(self, env)
if observation_space is not None: if observation_space is not None:
self.observation_space = observation_space self.observation_space = observation_space
@@ -72,7 +78,10 @@ class LambdaObservationV0(gym.ObservationWrapper[WrapperObsType, ActType, ObsTyp
return self.func(observation) return self.func(observation)
class FilterObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]): class FilterObservationV0(
LambdaObservationV0[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""Filter Dict observation space by the keys. """Filter Dict observation space by the keys.
Example: Example:
@@ -96,6 +105,7 @@ class FilterObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType])
): ):
"""Constructor for an environment with a dictionary observation space where all :attr:`filter_keys` are in the observation space keys.""" """Constructor for an environment with a dictionary observation space where all :attr:`filter_keys` are in the observation space keys."""
assert isinstance(filter_keys, Sequence) assert isinstance(filter_keys, Sequence)
gym.utils.RecordConstructorArgs.__init__(self, filter_keys=filter_keys)
# Filters for dictionary space # Filters for dictionary space
if isinstance(env.observation_space, spaces.Dict): if isinstance(env.observation_space, spaces.Dict):
@@ -124,10 +134,11 @@ class FilterObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType])
"The observation space is empty due to filtering all keys." "The observation space is empty due to filtering all keys."
) )
super().__init__( LambdaObservationV0.__init__(
env, self,
lambda obs: {key: obs[key] for key in filter_keys}, env=env,
new_observation_space, func=lambda obs: {key: obs[key] for key in filter_keys},
observation_space=new_observation_space,
) )
# Filter for tuple observation # Filter for tuple observation
elif isinstance(env.observation_space, spaces.Tuple): elif isinstance(env.observation_space, spaces.Tuple):
@@ -158,10 +169,11 @@ class FilterObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType])
"The observation space is empty due to filtering all keys." "The observation space is empty due to filtering all keys."
) )
super().__init__( LambdaObservationV0.__init__(
env, self,
lambda obs: tuple(obs[key] for key in filter_keys), env=env,
new_observation_spaces, func=lambda obs: tuple(obs[key] for key in filter_keys),
observation_space=new_observation_spaces,
) )
else: else:
raise ValueError( raise ValueError(
@@ -171,7 +183,10 @@ class FilterObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType])
self.filter_keys: Final[Sequence[str | int]] = filter_keys self.filter_keys: Final[Sequence[str | int]] = filter_keys
class FlattenObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]): class FlattenObservationV0(
LambdaObservationV0[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""Observation wrapper that flattens the observation. """Observation wrapper that flattens the observation.
Example: Example:
@@ -190,14 +205,19 @@ class FlattenObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]
def __init__(self, env: gym.Env[ObsType, ActType]): def __init__(self, env: gym.Env[ObsType, ActType]):
"""Constructor for any environment's observation space that implements ``spaces.utils.flatten_space`` and ``spaces.utils.flatten``.""" """Constructor for any environment's observation space that implements ``spaces.utils.flatten_space`` and ``spaces.utils.flatten``."""
super().__init__( gym.utils.RecordConstructorArgs.__init__(self)
env, LambdaObservationV0.__init__(
lambda obs: utils.flatten(env.observation_space, obs), self,
utils.flatten_space(env.observation_space), env=env,
func=lambda obs: spaces.utils.flatten(env.observation_space, obs),
observation_space=spaces.utils.flatten_space(env.observation_space),
) )
class GrayscaleObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]): class GrayscaleObservationV0(
LambdaObservationV0[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""Observation wrapper that converts an RGB image to grayscale. """Observation wrapper that converts an RGB image to grayscale.
The :attr:`keep_dim` will keep the channel dimension The :attr:`keep_dim` will keep the channel dimension
@@ -228,6 +248,7 @@ class GrayscaleObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsTyp
and np.all(env.observation_space.high == 255) and np.all(env.observation_space.high == 255)
and env.observation_space.dtype == np.uint8 and env.observation_space.dtype == np.uint8
) )
gym.utils.RecordConstructorArgs.__init__(self, keep_dim=keep_dim)
self.keep_dim: Final[bool] = keep_dim self.keep_dim: Final[bool] = keep_dim
if keep_dim: if keep_dim:
@@ -237,30 +258,35 @@ class GrayscaleObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsTyp
shape=env.observation_space.shape[:2] + (1,), shape=env.observation_space.shape[:2] + (1,),
dtype=np.uint8, dtype=np.uint8,
) )
super().__init__( LambdaObservationV0.__init__(
env, self,
lambda obs: jp.expand_dims( env=env,
func=lambda obs: jp.expand_dims(
jp.sum( jp.sum(
jp.multiply(obs, jp.array([0.2125, 0.7154, 0.0721])), axis=-1 jp.multiply(obs, jp.array([0.2125, 0.7154, 0.0721])), axis=-1
).astype(np.uint8), ).astype(np.uint8),
axis=-1, axis=-1,
), ),
new_observation_space, observation_space=new_observation_space,
) )
else: else:
new_observation_space = spaces.Box( new_observation_space = spaces.Box(
low=0, high=255, shape=env.observation_space.shape[:2], dtype=np.uint8 low=0, high=255, shape=env.observation_space.shape[:2], dtype=np.uint8
) )
super().__init__( LambdaObservationV0.__init__(
env, self,
lambda obs: jp.sum( env=env,
func=lambda obs: jp.sum(
jp.multiply(obs, jp.array([0.2125, 0.7154, 0.0721])), axis=-1 jp.multiply(obs, jp.array([0.2125, 0.7154, 0.0721])), axis=-1
).astype(np.uint8), ).astype(np.uint8),
new_observation_space, observation_space=new_observation_space,
) )
class ResizeObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]): class ResizeObservationV0(
LambdaObservationV0[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""Resizes image observations using OpenCV to shape. """Resizes image observations using OpenCV to shape.
Example: Example:
@@ -299,14 +325,20 @@ class ResizeObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType])
new_observation_space = spaces.Box( new_observation_space = spaces.Box(
low=0, high=255, shape=self.shape + env.observation_space.shape[2:] low=0, high=255, shape=self.shape + env.observation_space.shape[2:]
) )
super().__init__(
env, gym.utils.RecordConstructorArgs.__init__(self, shape=shape)
lambda obs: cv2.resize(obs, self.shape, interpolation=cv2.INTER_AREA), LambdaObservationV0.__init__(
new_observation_space, self,
env=env,
func=lambda obs: cv2.resize(obs, self.shape, interpolation=cv2.INTER_AREA),
observation_space=new_observation_space,
) )
class ReshapeObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]): class ReshapeObservationV0(
LambdaObservationV0[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""Reshapes array-based observations to a given shape. """Reshapes array-based observations to a given shape.
Example: Example:
@@ -336,10 +368,20 @@ class ReshapeObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]
dtype=env.observation_space.dtype, dtype=env.observation_space.dtype,
) )
self.shape = shape self.shape = shape
super().__init__(env, lambda obs: jp.reshape(obs, shape), new_observation_space)
gym.utils.RecordConstructorArgs.__init__(self, shape=shape)
LambdaObservationV0.__init__(
self,
env=env,
func=lambda obs: jp.reshape(obs, shape),
observation_space=new_observation_space,
)
class RescaleObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]): class RescaleObservationV0(
LambdaObservationV0[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""Linearly rescales observation to between a minimum and maximum value. """Linearly rescales observation to between a minimum and maximum value.
Example: Example:
@@ -392,10 +434,12 @@ class RescaleObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]
) )
intercept = gradient * -env.observation_space.low + min_obs intercept = gradient * -env.observation_space.low + min_obs
super().__init__( gym.utils.RecordConstructorArgs.__init__(self, min_obs=min_obs, max_obs=max_obs)
env, LambdaObservationV0.__init__(
lambda obs: gradient * obs + intercept, self,
Box( env=env,
func=lambda obs: gradient * obs + intercept,
observation_space=spaces.Box(
low=min_obs, low=min_obs,
high=max_obs, high=max_obs,
shape=env.observation_space.shape, shape=env.observation_space.shape,
@@ -404,7 +448,10 @@ class RescaleObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]
) )
class DtypeObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]): class DtypeObservationV0(
LambdaObservationV0[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""Observation wrapper for transforming the dtype of an observation.""" """Observation wrapper for transforming the dtype of an observation."""
def __init__(self, env: gym.Env[ObsType, ActType], dtype: Any): def __init__(self, env: gym.Env[ObsType, ActType], dtype: Any):
@@ -445,10 +492,19 @@ class DtypeObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
"DtypeObservation is only compatible with value / array-based observations." "DtypeObservation is only compatible with value / array-based observations."
) )
super().__init__(env, lambda obs: dtype(obs), new_observation_space) gym.utils.RecordConstructorArgs.__init__(self, dtype=dtype)
LambdaObservationV0.__init__(
self,
env=env,
func=lambda obs: dtype(obs),
observation_space=new_observation_space,
)
class PixelObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]): class PixelObservationV0(
LambdaObservationV0[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""Augment observations by pixel values. """Augment observations by pixel values.
Observations of this wrapper will be dictionaries of images. Observations of this wrapper will be dictionaries of images.
@@ -461,7 +517,7 @@ class PixelObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
def __init__( def __init__(
self, self,
env: Env[ObsType, ActType], env: gym.Env[ObsType, ActType],
pixels_only: bool = True, pixels_only: bool = True,
pixels_key: str = "pixels", pixels_key: str = "pixels",
obs_key: str = "state", obs_key: str = "state",
@@ -478,30 +534,49 @@ class PixelObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
pixels_key: Optional custom string specifying the pixel key. Defaults to "pixels" pixels_key: Optional custom string specifying the pixel key. Defaults to "pixels"
obs_key: Optional custom string specifying the obs key. Defaults to "state" obs_key: Optional custom string specifying the obs key. Defaults to "state"
""" """
gym.utils.RecordConstructorArgs.__init__(
self, pixels_only=pixels_only, pixels_key=pixels_key, obs_key=obs_key
)
assert env.render_mode is not None and env.render_mode != "human" assert env.render_mode is not None and env.render_mode != "human"
env.reset() env.reset()
pixels = env.render() pixels = env.render()
assert pixels is not None and isinstance(pixels, np.ndarray) assert pixels is not None and isinstance(pixels, np.ndarray)
pixel_space = Box(low=0, high=255, shape=pixels.shape, dtype=np.uint8) pixel_space = spaces.Box(low=0, high=255, shape=pixels.shape, dtype=np.uint8)
if pixels_only: if pixels_only:
obs_space = pixel_space obs_space = pixel_space
super().__init__(env, lambda _: self.render(), obs_space) LambdaObservationV0.__init__(
elif isinstance(env.observation_space, Dict): self, env=env, func=lambda _: self.render(), observation_space=obs_space
)
elif isinstance(env.observation_space, spaces.Dict):
assert pixels_key not in env.observation_space.spaces.keys() assert pixels_key not in env.observation_space.spaces.keys()
obs_space = Dict({pixels_key: pixel_space, **env.observation_space.spaces}) obs_space = spaces.Dict(
super().__init__( {pixels_key: pixel_space, **env.observation_space.spaces}
env, lambda obs: {pixels_key: self.render(), **obs_space}, obs_space )
LambdaObservationV0.__init__(
self,
env=env,
func=lambda obs: {pixels_key: self.render(), **obs_space},
observation_space=obs_space,
) )
else: else:
obs_space = Dict({obs_key: env.observation_space, pixels_key: pixel_space}) obs_space = spaces.Dict(
super().__init__( {obs_key: env.observation_space, pixels_key: pixel_space}
env, lambda obs: {obs_key: obs, pixels_key: self.render()}, obs_space )
LambdaObservationV0.__init__(
self,
env=env,
func=lambda obs: {obs_key: obs, pixels_key: self.render()},
observation_space=obs_space,
) )
class NormalizeObservationV0(ObservationWrapper[WrapperObsType, ActType, ObsType]): class NormalizeObservationV0(
gym.ObservationWrapper[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""This wrapper will normalize observations s.t. each coordinate is centered with unit variance. """This wrapper will normalize observations s.t. each coordinate is centered with unit variance.
The property `_update_running_mean` allows to freeze/continue the running mean calculation of the observation The property `_update_running_mean` allows to freeze/continue the running mean calculation of the observation
@@ -520,7 +595,9 @@ class NormalizeObservationV0(ObservationWrapper[WrapperObsType, ActType, ObsType
env (Env): The environment to apply the wrapper env (Env): The environment to apply the wrapper
epsilon: A stability parameter that is used when scaling the observations. epsilon: A stability parameter that is used when scaling the observations.
""" """
super().__init__(env) gym.utils.RecordConstructorArgs.__init__(self, epsilon=epsilon)
gym.ObservationWrapper.__init__(self, env)
self.obs_rms = RunningMeanStd(shape=self.observation_space.shape) self.obs_rms = RunningMeanStd(shape=self.observation_space.shape)
self.epsilon = epsilon self.epsilon = epsilon
self._update_running_mean = True self._update_running_mean = True

View File

@@ -15,7 +15,9 @@ from gymnasium.error import InvalidBound
from gymnasium.experimental.wrappers.utils import RunningMeanStd from gymnasium.experimental.wrappers.utils import RunningMeanStd
class LambdaRewardV0(gym.RewardWrapper[ObsType, ActType]): class LambdaRewardV0(
gym.RewardWrapper[ObsType, ActType], gym.utils.RecordConstructorArgs
):
"""A reward wrapper that allows a custom function to modify the step reward. """A reward wrapper that allows a custom function to modify the step reward.
Example: Example:
@@ -40,7 +42,8 @@ class LambdaRewardV0(gym.RewardWrapper[ObsType, ActType]):
env (Env): The environment to apply the wrapper env (Env): The environment to apply the wrapper
func (Callable): The function to apply to reward func (Callable): The function to apply to reward
""" """
super().__init__(env) gym.utils.RecordConstructorArgs.__init__(self, func=func)
gym.RewardWrapper.__init__(self, env)
self.func = func self.func = func
@@ -53,7 +56,7 @@ class LambdaRewardV0(gym.RewardWrapper[ObsType, ActType]):
return self.func(reward) return self.func(reward)
class ClipRewardV0(LambdaRewardV0[ObsType, ActType]): class ClipRewardV0(LambdaRewardV0[ObsType, ActType], gym.utils.RecordConstructorArgs):
"""A wrapper that clips the rewards for an environment between an upper and lower bound. """A wrapper that clips the rewards for an environment between an upper and lower bound.
Example: Example:
@@ -89,10 +92,17 @@ class ClipRewardV0(LambdaRewardV0[ObsType, ActType]):
f"Min reward ({min_reward}) must be smaller than max reward ({max_reward})" f"Min reward ({min_reward}) must be smaller than max reward ({max_reward})"
) )
super().__init__(env, lambda x: np.clip(x, a_min=min_reward, a_max=max_reward)) gym.utils.RecordConstructorArgs.__init__(
self, min_reward=min_reward, max_reward=max_reward
)
LambdaRewardV0.__init__(
self, env=env, func=lambda x: np.clip(x, a_min=min_reward, a_max=max_reward)
)
class NormalizeRewardV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]): class NormalizeRewardV0(
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
r"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance. r"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance.
The exponential moving average will have variance :math:`(1 - \gamma)^2`. The exponential moving average will have variance :math:`(1 - \gamma)^2`.
@@ -119,7 +129,9 @@ class NormalizeRewardV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
epsilon (float): A stability parameter epsilon (float): A stability parameter
gamma (float): The discount factor that is used in the exponential moving average. gamma (float): The discount factor that is used in the exponential moving average.
""" """
super().__init__(env) gym.utils.RecordConstructorArgs.__init__(self, gamma=gamma, epsilon=epsilon)
gym.Wrapper.__init__(self, env)
self.rewards_running_means = RunningMeanStd(shape=()) self.rewards_running_means = RunningMeanStd(shape=())
self.discounted_reward: np.array = np.array([0.0]) self.discounted_reward: np.array = np.array([0.0])
self.gamma = gamma self.gamma = gamma

View File

@@ -18,7 +18,9 @@ from gymnasium.core import ActType, ObsType, RenderFrame
from gymnasium.error import DependencyNotInstalled from gymnasium.error import DependencyNotInstalled
class RenderCollectionV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]): class RenderCollectionV0(
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
"""Collect rendered frames of an environment such that ``render`` returns a ``list[RenderedFrame]``.""" """Collect rendered frames of an environment such that ``render`` returns a ``list[RenderedFrame]``."""
def __init__( def __init__(
@@ -34,7 +36,11 @@ class RenderCollectionV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
pop_frames (bool): If true, clear the collection frames after ``meth:render`` is called. Default value is ``True``. pop_frames (bool): If true, clear the collection frames after ``meth:render`` is called. Default value is ``True``.
reset_clean (bool): If true, clear the collection frames when ``meth:reset`` is called. Default value is ``True``. reset_clean (bool): If true, clear the collection frames when ``meth:reset`` is called. Default value is ``True``.
""" """
super().__init__(env) gym.utils.RecordConstructorArgs.__init__(
self, pop_frames=pop_frames, reset_clean=reset_clean
)
gym.Wrapper.__init__(self, env)
assert env.render_mode is not None assert env.render_mode is not None
assert not env.render_mode.endswith("_list") assert not env.render_mode.endswith("_list")
@@ -80,7 +86,9 @@ class RenderCollectionV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
return frames return frames
class RecordVideoV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]): class RecordVideoV0(
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
"""This wrapper records videos of rollouts. """This wrapper records videos of rollouts.
Usually, you only want to record episodes intermittently, say every hundredth episode. Usually, you only want to record episodes intermittently, say every hundredth episode.
@@ -117,9 +125,18 @@ class RecordVideoV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
Otherwise, snippets of the specified length are captured Otherwise, snippets of the specified length are captured
name_prefix (str): Will be prepended to the filename of the recordings name_prefix (str): Will be prepended to the filename of the recordings
disable_logger (bool): Whether to disable moviepy logger or not disable_logger (bool): Whether to disable moviepy logger or not
""" """
super().__init__(env) gym.utils.RecordConstructorArgs.__init__(
self,
video_folder=video_folder,
episode_trigger=episode_trigger,
step_trigger=step_trigger,
video_length=video_length,
name_prefix=name_prefix,
disable_logger=disable_logger,
)
gym.Wrapper.__init__(self, env)
try: try:
import moviepy # noqa: F401 import moviepy # noqa: F401
except ImportError as e: except ImportError as e:
@@ -277,7 +294,9 @@ class RecordVideoV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
logger.warn("Unable to save last video! Did you call close()?") logger.warn("Unable to save last video! Did you call close()?")
class HumanRenderingV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]): class HumanRenderingV0(
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
"""Performs human rendering for an environment that only supports "rgb_array"rendering. """Performs human rendering for an environment that only supports "rgb_array"rendering.
This wrapper is particularly useful when you have implemented an environment that can produce This wrapper is particularly useful when you have implemented an environment that can produce
@@ -317,7 +336,9 @@ class HumanRenderingV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
Args: Args:
env: The environment that is being wrapped env: The environment that is being wrapped
""" """
super().__init__(env) gym.utils.RecordConstructorArgs.__init__(self)
gym.Wrapper.__init__(self, env)
assert env.render_mode in [ assert env.render_mode in [
"rgb_array", "rgb_array",
"rgb_array_list", "rgb_array_list",

View File

@@ -4,11 +4,13 @@ from __future__ import annotations
from typing import Any from typing import Any
import gymnasium as gym import gymnasium as gym
from gymnasium.core import ActionWrapper, ActType, ObsType from gymnasium.core import ActType, ObsType
from gymnasium.error import InvalidProbability from gymnasium.error import InvalidProbability
class StickyActionV0(ActionWrapper[ObsType, ActType, ActType]): class StickyActionV0(
gym.ActionWrapper[ObsType, ActType, ActType], gym.utils.RecordConstructorArgs
):
"""Wrapper which adds a probability of repeating the previous action. """Wrapper which adds a probability of repeating the previous action.
This wrapper follows the implementation proposed by `Machado et al., 2018 <https://arxiv.org/pdf/1709.06009.pdf>`_ This wrapper follows the implementation proposed by `Machado et al., 2018 <https://arxiv.org/pdf/1709.06009.pdf>`_
@@ -29,7 +31,11 @@ class StickyActionV0(ActionWrapper[ObsType, ActType, ActType]):
f"repeat_action_probability should be in the interval [0,1). Received {repeat_action_probability}" f"repeat_action_probability should be in the interval [0,1). Received {repeat_action_probability}"
) )
super().__init__(env) gym.utils.RecordConstructorArgs.__init__(
self, repeat_action_probability=repeat_action_probability
)
gym.ActionWrapper.__init__(self, env)
self.repeat_action_probability = repeat_action_probability self.repeat_action_probability = repeat_action_probability
self.last_action: ActType | None = None self.last_action: ActType | None = None

View File

@@ -13,8 +13,8 @@ from typing_extensions import Final
import numpy as np import numpy as np
import gymnasium as gym
import gymnasium.spaces as spaces import gymnasium.spaces as spaces
from gymnasium import Env, ObservationWrapper, Space, Wrapper
from gymnasium.core import ActType, ObsType, WrapperActType, WrapperObsType from gymnasium.core import ActType, ObsType, WrapperActType, WrapperObsType
from gymnasium.experimental.vector.utils import ( from gymnasium.experimental.vector.utils import (
batch_space, batch_space,
@@ -25,7 +25,9 @@ from gymnasium.experimental.wrappers.utils import create_zero_array
from gymnasium.spaces import Box, Dict, Tuple from gymnasium.spaces import Box, Dict, Tuple
class DelayObservationV0(ObservationWrapper[ObsType, ActType, ObsType]): class DelayObservationV0(
gym.ObservationWrapper[ObsType, ActType, ObsType], gym.utils.RecordConstructorArgs
):
"""Wrapper which adds a delay to the returned observation. """Wrapper which adds a delay to the returned observation.
Before reaching the :attr:`delay` number of timesteps, returned observations is an array of zeros with Before reaching the :attr:`delay` number of timesteps, returned observations is an array of zeros with
@@ -49,15 +51,13 @@ class DelayObservationV0(ObservationWrapper[ObsType, ActType, ObsType]):
This does not support random delay values, if users are interested, please raise an issue or pull request to add this feature. This does not support random delay values, if users are interested, please raise an issue or pull request to add this feature.
""" """
def __init__(self, env: Env[ObsType, ActType], delay: int): def __init__(self, env: gym.Env[ObsType, ActType], delay: int):
"""Initialises the DelayObservation wrapper with an integer. """Initialises the DelayObservation wrapper with an integer.
Args: Args:
env: The environment to wrap env: The environment to wrap
delay: The number of timesteps to delay observations delay: The number of timesteps to delay observations
""" """
super().__init__(env)
if not np.issubdtype(type(delay), np.integer): if not np.issubdtype(type(delay), np.integer):
raise TypeError( raise TypeError(
f"The delay is expected to be an integer, actual type: {type(delay)}" f"The delay is expected to be an integer, actual type: {type(delay)}"
@@ -67,6 +67,9 @@ class DelayObservationV0(ObservationWrapper[ObsType, ActType, ObsType]):
f"The delay needs to be greater than zero, actual value: {delay}" f"The delay needs to be greater than zero, actual value: {delay}"
) )
gym.utils.RecordConstructorArgs.__init__(self, delay=delay)
gym.ObservationWrapper.__init__(self, env)
self.delay: Final[int] = int(delay) self.delay: Final[int] = int(delay)
self.observation_queue: Final[deque] = deque() self.observation_queue: Final[deque] = deque()
@@ -88,7 +91,10 @@ class DelayObservationV0(ObservationWrapper[ObsType, ActType, ObsType]):
return create_zero_array(self.observation_space) return create_zero_array(self.observation_space)
class TimeAwareObservationV0(ObservationWrapper[WrapperObsType, ActType, ObsType]): class TimeAwareObservationV0(
gym.ObservationWrapper[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""Augment the observation with time information of the episode. """Augment the observation with time information of the episode.
The :attr:`normalize_time` if ``True`` represents time as a normalized value between [0,1] The :attr:`normalize_time` if ``True`` represents time as a normalized value between [0,1]
@@ -144,7 +150,7 @@ class TimeAwareObservationV0(ObservationWrapper[WrapperObsType, ActType, ObsType
def __init__( def __init__(
self, self,
env: Env[ObsType, ActType], env: gym.Env[ObsType, ActType],
flatten: bool = False, flatten: bool = False,
normalize_time: bool = True, normalize_time: bool = True,
*, *,
@@ -159,7 +165,13 @@ class TimeAwareObservationV0(ObservationWrapper[WrapperObsType, ActType, ObsType
otherwise return time as remaining timesteps before truncation otherwise return time as remaining timesteps before truncation
dict_time_key: For environment with a ``Dict`` observation space, the key for the time space. By default, `"time"`. dict_time_key: For environment with a ``Dict`` observation space, the key for the time space. By default, `"time"`.
""" """
super().__init__(env) gym.utils.RecordConstructorArgs.__init__(
self,
flatten=flatten,
normalize_time=normalize_time,
dict_time_key=dict_time_key,
)
gym.ObservationWrapper.__init__(self, env)
self.flatten: Final[bool] = flatten self.flatten: Final[bool] = flatten
self.normalize_time: Final[bool] = normalize_time self.normalize_time: Final[bool] = normalize_time
@@ -203,14 +215,14 @@ class TimeAwareObservationV0(ObservationWrapper[WrapperObsType, ActType, ObsType
# If to flatten the observation space # If to flatten the observation space
if self.flatten: if self.flatten:
self.observation_space: Space[WrapperObsType] = spaces.flatten_space( self.observation_space: gym.Space[WrapperObsType] = spaces.flatten_space(
observation_space observation_space
) )
self._obs_postprocess_func = lambda obs: spaces.flatten( self._obs_postprocess_func = lambda obs: spaces.flatten(
observation_space, obs observation_space, obs
) )
else: else:
self.observation_space: Space[WrapperObsType] = observation_space self.observation_space: gym.Space[WrapperObsType] = observation_space
self._obs_postprocess_func = lambda obs: obs self._obs_postprocess_func = lambda obs: obs
def observation(self, observation: ObsType) -> WrapperObsType: def observation(self, observation: ObsType) -> WrapperObsType:
@@ -260,7 +272,10 @@ class TimeAwareObservationV0(ObservationWrapper[WrapperObsType, ActType, ObsType
return super().reset(seed=seed, options=options) return super().reset(seed=seed, options=options)
class FrameStackObservationV0(Wrapper[WrapperObsType, ActType, ObsType, ActType]): class FrameStackObservationV0(
gym.Wrapper[WrapperObsType, ActType, ObsType, ActType],
gym.utils.RecordConstructorArgs,
):
"""Observation wrapper that stacks the observations in a rolling manner. """Observation wrapper that stacks the observations in a rolling manner.
For example, if the number of stacks is 4, then the returned observation contains For example, if the number of stacks is 4, then the returned observation contains
@@ -286,7 +301,7 @@ class FrameStackObservationV0(Wrapper[WrapperObsType, ActType, ObsType, ActType]
def __init__( def __init__(
self, self,
env: Env[ObsType, ActType], env: gym.Env[ObsType, ActType],
stack_size: int, stack_size: int,
*, *,
zeros_obs: ObsType | None = None, zeros_obs: ObsType | None = None,
@@ -298,8 +313,6 @@ class FrameStackObservationV0(Wrapper[WrapperObsType, ActType, ObsType, ActType]
stack_size: The number of frames to stack with zero_obs being used originally. stack_size: The number of frames to stack with zero_obs being used originally.
zeros_obs: Keyword only parameter that allows a custom padding observation at :meth:`reset` zeros_obs: Keyword only parameter that allows a custom padding observation at :meth:`reset`
""" """
super().__init__(env)
if not np.issubdtype(type(stack_size), np.integer): if not np.issubdtype(type(stack_size), np.integer):
raise TypeError( raise TypeError(
f"The stack_size is expected to be an integer, actual type: {type(stack_size)}" f"The stack_size is expected to be an integer, actual type: {type(stack_size)}"
@@ -309,6 +322,9 @@ class FrameStackObservationV0(Wrapper[WrapperObsType, ActType, ObsType, ActType]
f"The stack_size needs to be greater than one, actual value: {stack_size}" f"The stack_size needs to be greater than one, actual value: {stack_size}"
) )
gym.utils.RecordConstructorArgs.__init__(self, stack_size=stack_size)
gym.Wrapper.__init__(self, env)
self.observation_space = batch_space(env.observation_space, n=stack_size) self.observation_space = batch_space(env.observation_space, n=stack_size)
self.stack_size: Final[int] = stack_size self.stack_size: Final[int] = stack_size

View File

@@ -8,6 +8,7 @@ These are not intended as API functions, and will not remain stable over time.
# that verify that our dependencies are actually present. # that verify that our dependencies are actually present.
from gymnasium.utils.colorize import colorize from gymnasium.utils.colorize import colorize
from gymnasium.utils.ezpickle import EzPickle from gymnasium.utils.ezpickle import EzPickle
from gymnasium.utils.record_constructor import RecordConstructorArgs
__all__ = ["colorize", "EzPickle"] __all__ = ["colorize", "EzPickle", "RecordConstructorArgs"]

View File

@@ -1,13 +1,15 @@
"""Class for pickling and unpickling objects via their constructor arguments.""" """Class for pickling and unpickling objects via their constructor arguments."""
from typing import Any
class EzPickle: class EzPickle:
"""Objects that are pickled and unpickled via their constructor arguments. """Objects that are pickled and unpickled via their constructor arguments.
Example: Example:
>>> class Dog(Animal, EzPickle): # doctest: +SKIP >>> class Animal: pass
>>> class Dog(Animal, EzPickle):
... def __init__(self, furcolor, tailkind="bushy"): ... def __init__(self, furcolor, tailkind="bushy"):
... Animal.__init__() ... Animal.__init__(self)
... EzPickle.__init__(self, furcolor, tailkind) ... EzPickle.__init__(self, furcolor, tailkind)
When this object is unpickled, a new ``Dog`` will be constructed by passing the provided furcolor and tailkind into the constructor. When this object is unpickled, a new ``Dog`` will be constructed by passing the provided furcolor and tailkind into the constructor.
@@ -16,7 +18,7 @@ class EzPickle:
This is generally needed only for environments which wrap C/C++ code, such as MuJoCo and Atari. This is generally needed only for environments which wrap C/C++ code, such as MuJoCo and Atari.
""" """
def __init__(self, *args, **kwargs): def __init__(self, *args: Any, **kwargs: Any):
"""Uses the ``args`` and ``kwargs`` from the object's constructor for pickling.""" """Uses the ``args`` and ``kwargs`` from the object's constructor for pickling."""
self._ezpickle_args = args self._ezpickle_args = args
self._ezpickle_kwargs = kwargs self._ezpickle_kwargs = kwargs

View File

@@ -0,0 +1,33 @@
"""Allows attributes passed to `RecordConstructorArgs` to be saved. This is used by the `Wrapper.spec` to know the constructor arguments of implemented wrappers."""
from __future__ import annotations
from copy import deepcopy
from typing import Any
class RecordConstructorArgs:
    """Records all keyword arguments passed to the constructor in `_saved_kwargs`.

    This can be used to save and reproduce class constructor arguments.

    Note:
        If several classes in an inheritance chain each call `RecordConstructorArgs.__init__(self, ...)`,
        only the first call is recorded; all subsequent calls are ignored because `_saved_kwargs` already exists.
        Therefore, always call `RecordConstructorArgs.__init__` before the `Class.__init__`.
    """

    def __init__(self, *, _disable_deepcopy: bool = False, **kwargs: Any):
        """Records all keyword arguments passed to the constructor in `_saved_kwargs`.

        Args:
            _disable_deepcopy: If true, the argument values are stored without being deep copied
            **kwargs: Arguments to save
        """
        # Only the first call on an instance records anything, see the class docstring.
        if not hasattr(self, "_saved_kwargs"):
            # Truthiness test instead of `is False`: a falsy non-bool flag (``0``, ``None``,
            # ``numpy.False_``) must still produce an isolated deep copy of the arguments.
            if not _disable_deepcopy:
                kwargs = deepcopy(kwargs)
            self._saved_kwargs: dict[str, Any] = kwargs

View File

@@ -11,7 +11,7 @@ except ImportError:
cv2 = None cv2 = None
class AtariPreprocessing(gym.Wrapper): class AtariPreprocessing(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""Atari 2600 preprocessing wrapper. """Atari 2600 preprocessing wrapper.
This class follows the guidelines in Machado et al. (2018), This class follows the guidelines in Machado et al. (2018),
@@ -60,7 +60,18 @@ class AtariPreprocessing(gym.Wrapper):
DependencyNotInstalled: opencv-python package not installed DependencyNotInstalled: opencv-python package not installed
ValueError: Disable frame-skipping in the original env ValueError: Disable frame-skipping in the original env
""" """
super().__init__(env) gym.utils.RecordConstructorArgs.__init__(
self,
noop_max=noop_max,
frame_skip=frame_skip,
screen_size=screen_size,
terminal_on_life_loss=terminal_on_life_loss,
grayscale_obs=grayscale_obs,
grayscale_newaxis=grayscale_newaxis,
scale_obs=scale_obs,
)
gym.Wrapper.__init__(self, env)
if cv2 is None: if cv2 is None:
raise gym.error.DependencyNotInstalled( raise gym.error.DependencyNotInstalled(
"opencv-python package not installed, run `pip install gymnasium[other]` to get dependencies for atari" "opencv-python package not installed, run `pip install gymnasium[other]` to get dependencies for atari"

View File

@@ -2,7 +2,7 @@
import gymnasium as gym import gymnasium as gym
class AutoResetWrapper(gym.Wrapper): class AutoResetWrapper(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`. """A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`.
When calling step causes :meth:`Env.step` to return `terminated=True` or `truncated=True`, :meth:`Env.reset` is called, When calling step causes :meth:`Env.step` to return `terminated=True` or `truncated=True`, :meth:`Env.reset` is called,
@@ -31,7 +31,8 @@ class AutoResetWrapper(gym.Wrapper):
Args: Args:
env (gym.Env): The environment to apply the wrapper env (gym.Env): The environment to apply the wrapper
""" """
super().__init__(env) gym.utils.RecordConstructorArgs.__init__(self)
gym.Wrapper.__init__(self, env)
def step(self, action): def step(self, action):
"""Steps through the environment with action and resets the environment if a terminated or truncated signal is encountered. """Steps through the environment with action and resets the environment if a terminated or truncated signal is encountered.

View File

@@ -2,11 +2,10 @@
import numpy as np import numpy as np
import gymnasium as gym import gymnasium as gym
from gymnasium import ActionWrapper
from gymnasium.spaces import Box from gymnasium.spaces import Box
class ClipAction(ActionWrapper): class ClipAction(gym.ActionWrapper, gym.utils.RecordConstructorArgs):
"""Clip the continuous action within the valid :class:`Box` observation space bound. """Clip the continuous action within the valid :class:`Box` observation space bound.
Example: Example:
@@ -28,7 +27,9 @@ class ClipAction(ActionWrapper):
env: The environment to apply the wrapper env: The environment to apply the wrapper
""" """
assert isinstance(env.action_space, Box) assert isinstance(env.action_space, Box)
super().__init__(env)
gym.utils.RecordConstructorArgs.__init__(self)
gym.ActionWrapper.__init__(self, env)
def action(self, action): def action(self, action):
"""Clips the action within the valid bounds. """Clips the action within the valid bounds.

View File

@@ -68,11 +68,12 @@ class EnvCompatibility(gym.Env):
"The `gymnasium.make(..., apply_api_compatibility=...)` parameter is deprecated and will be removed in v0.29. " "The `gymnasium.make(..., apply_api_compatibility=...)` parameter is deprecated and will be removed in v0.29. "
"Instead use `gym.make('GymV21Environment-v0', env_name=...)` or `from shimmy import GymV21CompatibilityV0`" "Instead use `gym.make('GymV21Environment-v0', env_name=...)` or `from shimmy import GymV21CompatibilityV0`"
) )
self.env = old_env
self.metadata = getattr(old_env, "metadata", {"render_modes": []}) self.metadata = getattr(old_env, "metadata", {"render_modes": []})
self.render_mode = render_mode self.render_mode = render_mode
self.reward_range = getattr(old_env, "reward_range", None) self.reward_range = getattr(old_env, "reward_range", None)
self.spec = getattr(old_env, "spec", None) self.spec = getattr(old_env, "spec", None)
self.env = old_env
self.observation_space = old_env.observation_space self.observation_space = old_env.observation_space
self.action_space = old_env.action_space self.action_space = old_env.action_space

View File

@@ -10,12 +10,13 @@ from gymnasium.utils.passive_env_checker import (
) )
class PassiveEnvChecker(gym.Wrapper): class PassiveEnvChecker(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""A passive environment checker wrapper that surrounds the step, reset and render functions to check they follow the gymnasium API.""" """A passive environment checker wrapper that surrounds the step, reset and render functions to check they follow the gymnasium API."""
def __init__(self, env): def __init__(self, env):
"""Initialises the wrapper with the environments, run the observation and action space tests.""" """Initialises the wrapper with the environments, run the observation and action space tests."""
super().__init__(env) gym.utils.RecordConstructorArgs.__init__(self)
gym.Wrapper.__init__(self, env)
assert hasattr( assert hasattr(
env, "action_space" env, "action_space"

View File

@@ -6,7 +6,7 @@ import gymnasium as gym
from gymnasium import spaces from gymnasium import spaces
class FilterObservation(gym.ObservationWrapper): class FilterObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
"""Filter Dict observation space by the keys. """Filter Dict observation space by the keys.
Example: Example:
@@ -35,7 +35,8 @@ class FilterObservation(gym.ObservationWrapper):
ValueError: If the environment's observation space is not :class:`spaces.Dict` ValueError: If the environment's observation space is not :class:`spaces.Dict`
ValueError: If any of the `filter_keys` are not included in the original `env`'s observation space ValueError: If any of the `filter_keys` are not included in the original `env`'s observation space
""" """
super().__init__(env) gym.utils.RecordConstructorArgs.__init__(self, filter_keys=filter_keys)
gym.ObservationWrapper.__init__(self, env)
wrapped_observation_space = env.observation_space wrapped_observation_space = env.observation_space
if not isinstance(wrapped_observation_space, spaces.Dict): if not isinstance(wrapped_observation_space, spaces.Dict):

View File

@@ -3,7 +3,7 @@ import gymnasium as gym
from gymnasium import spaces from gymnasium import spaces
class FlattenObservation(gym.ObservationWrapper): class FlattenObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
"""Observation wrapper that flattens the observation. """Observation wrapper that flattens the observation.
Example: Example:
@@ -26,7 +26,9 @@ class FlattenObservation(gym.ObservationWrapper):
Args: Args:
env: The environment to apply the wrapper env: The environment to apply the wrapper
""" """
super().__init__(env) gym.utils.RecordConstructorArgs.__init__(self)
gym.ObservationWrapper.__init__(self, env)
self.observation_space = spaces.flatten_space(env.observation_space) self.observation_space = spaces.flatten_space(env.observation_space)
def observation(self, observation): def observation(self, observation):

View File

@@ -97,7 +97,7 @@ class LazyFrames:
return frame return frame
class FrameStack(gym.ObservationWrapper): class FrameStack(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
"""Observation wrapper that stacks the observations in a rolling manner. """Observation wrapper that stacks the observations in a rolling manner.
For example, if the number of stacks is 4, then the returned observation contains For example, if the number of stacks is 4, then the returned observation contains
@@ -137,7 +137,11 @@ class FrameStack(gym.ObservationWrapper):
num_stack (int): The number of frames to stack num_stack (int): The number of frames to stack
lz4_compress (bool): Use lz4 to compress the frames internally lz4_compress (bool): Use lz4 to compress the frames internally
""" """
super().__init__(env) gym.utils.RecordConstructorArgs.__init__(
self, num_stack=num_stack, lz4_compress=lz4_compress
)
gym.ObservationWrapper.__init__(self, env)
self.num_stack = num_stack self.num_stack = num_stack
self.lz4_compress = lz4_compress self.lz4_compress = lz4_compress

View File

@@ -5,7 +5,7 @@ import gymnasium as gym
from gymnasium.spaces import Box from gymnasium.spaces import Box
class GrayScaleObservation(gym.ObservationWrapper): class GrayScaleObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
"""Convert the image observation from RGB to gray scale. """Convert the image observation from RGB to gray scale.
Example: Example:
@@ -30,7 +30,9 @@ class GrayScaleObservation(gym.ObservationWrapper):
keep_dim (bool): If `True`, a singleton dimension will be added, i.e. observations are of the shape AxBx1. keep_dim (bool): If `True`, a singleton dimension will be added, i.e. observations are of the shape AxBx1.
Otherwise, they are of shape AxB. Otherwise, they are of shape AxB.
""" """
super().__init__(env) gym.utils.RecordConstructorArgs.__init__(self, keep_dim=keep_dim)
gym.ObservationWrapper.__init__(self, env)
self.keep_dim = keep_dim self.keep_dim = keep_dim
assert ( assert (

View File

@@ -7,7 +7,7 @@ import gymnasium as gym
from gymnasium.error import DependencyNotInstalled from gymnasium.error import DependencyNotInstalled
class HumanRendering(gym.Wrapper): class HumanRendering(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""Performs human rendering for an environment that only supports "rgb_array"rendering. """Performs human rendering for an environment that only supports "rgb_array"rendering.
This wrapper is particularly useful when you have implemented an environment that can produce This wrapper is particularly useful when you have implemented an environment that can produce
@@ -47,7 +47,9 @@ class HumanRendering(gym.Wrapper):
Args: Args:
env: The environment that is being wrapped env: The environment that is being wrapped
""" """
super().__init__(env) gym.utils.RecordConstructorArgs.__init__(self)
gym.Wrapper.__init__(self, env)
assert env.render_mode in [ assert env.render_mode in [
"rgb_array", "rgb_array",
"rgb_array_list", "rgb_array_list",
@@ -64,6 +66,8 @@ class HumanRendering(gym.Wrapper):
if "human" not in self.metadata["render_modes"]: if "human" not in self.metadata["render_modes"]:
self.metadata["render_modes"].append("human") self.metadata["render_modes"].append("human")
gym.utils.RecordConstructorArgs.__init__(self)
@property @property
def render_mode(self): def render_mode(self):
"""Always returns ``'human'``.""" """Always returns ``'human'``."""

View File

@@ -45,7 +45,7 @@ def update_mean_var_count_from_moments(
return new_mean, new_var, new_count return new_mean, new_var, new_count
class NormalizeObservation(gym.Wrapper): class NormalizeObservation(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""This wrapper will normalize observations s.t. each coordinate is centered with unit variance. """This wrapper will normalize observations s.t. each coordinate is centered with unit variance.
Note: Note:
@@ -60,7 +60,9 @@ class NormalizeObservation(gym.Wrapper):
env (Env): The environment to apply the wrapper env (Env): The environment to apply the wrapper
epsilon: A stability parameter that is used when scaling the observations. epsilon: A stability parameter that is used when scaling the observations.
""" """
super().__init__(env) gym.utils.RecordConstructorArgs.__init__(self, epsilon=epsilon)
gym.Wrapper.__init__(self, env)
self.num_envs = getattr(env, "num_envs", 1) self.num_envs = getattr(env, "num_envs", 1)
self.is_vector_env = getattr(env, "is_vector_env", False) self.is_vector_env = getattr(env, "is_vector_env", False)
if self.is_vector_env: if self.is_vector_env:
@@ -93,7 +95,7 @@ class NormalizeObservation(gym.Wrapper):
return (obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.epsilon) return (obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.epsilon)
class NormalizeReward(gym.core.Wrapper): class NormalizeReward(gym.core.Wrapper, gym.utils.RecordConstructorArgs):
r"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance. r"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance.
The exponential moving average will have variance :math:`(1 - \gamma)^2`. The exponential moving average will have variance :math:`(1 - \gamma)^2`.
@@ -116,7 +118,9 @@ class NormalizeReward(gym.core.Wrapper):
epsilon (float): A stability parameter epsilon (float): A stability parameter
gamma (float): The discount factor that is used in the exponential moving average. gamma (float): The discount factor that is used in the exponential moving average.
""" """
super().__init__(env) gym.utils.RecordConstructorArgs.__init__(self, gamma=gamma, epsilon=epsilon)
gym.Wrapper.__init__(self, env)
self.num_envs = getattr(env, "num_envs", 1) self.num_envs = getattr(env, "num_envs", 1)
self.is_vector_env = getattr(env, "is_vector_env", False) self.is_vector_env = getattr(env, "is_vector_env", False)
self.return_rms = RunningMeanStd(shape=()) self.return_rms = RunningMeanStd(shape=())

View File

@@ -3,7 +3,7 @@ import gymnasium as gym
from gymnasium.error import ResetNeeded from gymnasium.error import ResetNeeded
class OrderEnforcing(gym.Wrapper): class OrderEnforcing(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`. """A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`.
Example: Example:
@@ -32,7 +32,11 @@ class OrderEnforcing(gym.Wrapper):
env: The environment to wrap env: The environment to wrap
disable_render_order_enforcing: If to disable render order enforcing disable_render_order_enforcing: If to disable render order enforcing
""" """
super().__init__(env) gym.utils.RecordConstructorArgs.__init__(
self, disable_render_order_enforcing=disable_render_order_enforcing
)
gym.Wrapper.__init__(self, env)
self._has_reset: bool = False self._has_reset: bool = False
self._disable_render_order_enforcing: bool = disable_render_order_enforcing self._disable_render_order_enforcing: bool = disable_render_order_enforcing

View File

@@ -13,7 +13,7 @@ from gymnasium import spaces
STATE_KEY = "state" STATE_KEY = "state"
class PixelObservationWrapper(gym.ObservationWrapper): class PixelObservationWrapper(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
"""Augment observations by pixel values. """Augment observations by pixel values.
Observations of this wrapper will be dictionaries of images. Observations of this wrapper will be dictionaries of images.
@@ -79,7 +79,13 @@ class PixelObservationWrapper(gym.ObservationWrapper):
specified ``pixel_keys``. specified ``pixel_keys``.
TypeError: When an unexpected pixel type is used TypeError: When an unexpected pixel type is used
""" """
super().__init__(env) gym.utils.RecordConstructorArgs.__init__(
self,
pixels_only=pixels_only,
render_kwargs=render_kwargs,
pixel_keys=pixel_keys,
)
gym.ObservationWrapper.__init__(self, env)
# Avoid side-effects that occur when render_kwargs is manipulated # Avoid side-effects that occur when render_kwargs is manipulated
render_kwargs = copy.deepcopy(render_kwargs) render_kwargs = copy.deepcopy(render_kwargs)

View File

@@ -8,7 +8,7 @@ import numpy as np
import gymnasium as gym import gymnasium as gym
class RecordEpisodeStatistics(gym.Wrapper): class RecordEpisodeStatistics(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""This wrapper will keep track of cumulative rewards and episode lengths. """This wrapper will keep track of cumulative rewards and episode lengths.
At the end of an episode, the statistics of the episode will be added to ``info`` At the end of an episode, the statistics of the episode will be added to ``info``
@@ -56,7 +56,9 @@ class RecordEpisodeStatistics(gym.Wrapper):
env (Env): The environment to apply the wrapper env (Env): The environment to apply the wrapper
deque_size: The size of the buffers :attr:`return_queue` and :attr:`length_queue` deque_size: The size of the buffers :attr:`return_queue` and :attr:`length_queue`
""" """
super().__init__(env) gym.utils.RecordConstructorArgs.__init__(self, deque_size=deque_size)
gym.Wrapper.__init__(self, env)
self.num_envs = getattr(env, "num_envs", 1) self.num_envs = getattr(env, "num_envs", 1)
self.episode_count = 0 self.episode_count = 0
self.episode_start_times: np.ndarray = None self.episode_start_times: np.ndarray = None

View File

@@ -24,7 +24,7 @@ def capped_cubic_video_schedule(episode_id: int) -> bool:
return episode_id % 1000 == 0 return episode_id % 1000 == 0
class RecordVideo(gym.Wrapper): class RecordVideo(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""This wrapper records videos of rollouts. """This wrapper records videos of rollouts.
Usually, you only want to record episodes intermittently, say every hundredth episode. Usually, you only want to record episodes intermittently, say every hundredth episode.
@@ -58,9 +58,17 @@ class RecordVideo(gym.Wrapper):
Otherwise, snippets of the specified length are captured Otherwise, snippets of the specified length are captured
name_prefix (str): Will be prepended to the filename of the recordings name_prefix (str): Will be prepended to the filename of the recordings
disable_logger (bool): Whether to disable moviepy logger or not. disable_logger (bool): Whether to disable moviepy logger or not.
""" """
super().__init__(env) gym.utils.RecordConstructorArgs.__init__(
self,
video_folder=video_folder,
episode_trigger=episode_trigger,
step_trigger=step_trigger,
video_length=video_length,
name_prefix=name_prefix,
disable_logger=disable_logger,
)
gym.Wrapper.__init__(self, env)
if episode_trigger is None and step_trigger is None: if episode_trigger is None and step_trigger is None:
episode_trigger = capped_cubic_video_schedule episode_trigger = capped_cubic_video_schedule

View File

@@ -4,7 +4,7 @@ import copy
import gymnasium as gym import gymnasium as gym
class RenderCollection(gym.Wrapper): class RenderCollection(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""Save collection of render frames.""" """Save collection of render frames."""
def __init__(self, env: gym.Env, pop_frames: bool = True, reset_clean: bool = True): def __init__(self, env: gym.Env, pop_frames: bool = True, reset_clean: bool = True):
@@ -17,7 +17,11 @@ class RenderCollection(gym.Wrapper):
reset_clean (bool): If true, clear the collection frames when .reset() is called. reset_clean (bool): If true, clear the collection frames when .reset() is called.
Default value is True. Default value is True.
""" """
super().__init__(env) gym.utils.RecordConstructorArgs.__init__(
self, pop_frames=pop_frames, reset_clean=reset_clean
)
gym.Wrapper.__init__(self, env)
assert env.render_mode is not None assert env.render_mode is not None
assert not env.render_mode.endswith("_list") assert not env.render_mode.endswith("_list")
self.frame_list = [] self.frame_list = []

View File

@@ -7,7 +7,7 @@ import gymnasium as gym
from gymnasium.spaces import Box from gymnasium.spaces import Box
class RescaleAction(gym.ActionWrapper): class RescaleAction(gym.ActionWrapper, gym.utils.RecordConstructorArgs):
"""Affinely rescales the continuous action space of the environment to the range [min_action, max_action]. """Affinely rescales the continuous action space of the environment to the range [min_action, max_action].
The base environment :attr:`env` must have an action space of type :class:`spaces.Box`. If :attr:`min_action` The base environment :attr:`env` must have an action space of type :class:`spaces.Box`. If :attr:`min_action`
@@ -47,7 +47,11 @@ class RescaleAction(gym.ActionWrapper):
), f"expected Box action space, got {type(env.action_space)}" ), f"expected Box action space, got {type(env.action_space)}"
assert np.less_equal(min_action, max_action).all(), (min_action, max_action) assert np.less_equal(min_action, max_action).all(), (min_action, max_action)
super().__init__(env) gym.utils.RecordConstructorArgs.__init__(
self, min_action=min_action, max_action=max_action
)
gym.ActionWrapper.__init__(self, env)
self.min_action = ( self.min_action = (
np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + min_action np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + min_action
) )

View File

@@ -8,7 +8,7 @@ from gymnasium.error import DependencyNotInstalled
from gymnasium.spaces import Box from gymnasium.spaces import Box
class ResizeObservation(gym.ObservationWrapper): class ResizeObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
"""Resize the image observation. """Resize the image observation.
This wrapper works on environments with image observations. More generally, This wrapper works on environments with image observations. More generally,
@@ -36,7 +36,9 @@ class ResizeObservation(gym.ObservationWrapper):
env: The environment to apply the wrapper env: The environment to apply the wrapper
shape: The shape of the resized observations shape: The shape of the resized observations
""" """
super().__init__(env) gym.utils.RecordConstructorArgs.__init__(self, shape=shape)
gym.ObservationWrapper.__init__(self, env)
if isinstance(shape, int): if isinstance(shape, int):
shape = (shape, shape) shape = (shape, shape)
assert len(shape) == 2 and all( assert len(shape) == 2 and all(

View File

@@ -4,7 +4,7 @@ from gymnasium.logger import deprecation
from gymnasium.utils.step_api_compatibility import step_api_compatibility from gymnasium.utils.step_api_compatibility import step_api_compatibility
class StepAPICompatibility(gym.Wrapper): class StepAPICompatibility(gym.Wrapper, gym.utils.RecordConstructorArgs):
r"""A wrapper which can transform an environment from new step API to old and vice-versa. r"""A wrapper which can transform an environment from new step API to old and vice-versa.
Old step API refers to step() method returning (observation, reward, done, info) Old step API refers to step() method returning (observation, reward, done, info)
@@ -29,7 +29,11 @@ class StepAPICompatibility(gym.Wrapper):
env (gym.Env): the env to wrap. Can be in old or new API env (gym.Env): the env to wrap. Can be in old or new API
output_truncation_bool (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API) output_truncation_bool (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
""" """
super().__init__(env) gym.utils.RecordConstructorArgs.__init__(
self, output_truncation_bool=output_truncation_bool
)
gym.Wrapper.__init__(self, env)
self.is_vector_env = isinstance(env.unwrapped, gym.vector.VectorEnv) self.is_vector_env = isinstance(env.unwrapped, gym.vector.VectorEnv)
self.output_truncation_bool = output_truncation_bool self.output_truncation_bool = output_truncation_bool
if not self.output_truncation_bool: if not self.output_truncation_bool:

View File

@@ -5,7 +5,7 @@ import gymnasium as gym
from gymnasium.spaces import Box from gymnasium.spaces import Box
class TimeAwareObservation(gym.ObservationWrapper): class TimeAwareObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
"""Augment the observation with the current time step in the episode. """Augment the observation with the current time step in the episode.
The observation space of the wrapped environment is assumed to be a flat :class:`Box`. The observation space of the wrapped environment is assumed to be a flat :class:`Box`.
@@ -29,7 +29,9 @@ class TimeAwareObservation(gym.ObservationWrapper):
Args: Args:
env: The environment to apply the wrapper env: The environment to apply the wrapper
""" """
super().__init__(env) gym.utils.RecordConstructorArgs.__init__(self)
gym.ObservationWrapper.__init__(self, env)
assert isinstance(env.observation_space, Box) assert isinstance(env.observation_space, Box)
assert env.observation_space.dtype == np.float32 assert env.observation_space.dtype == np.float32
low = np.append(self.observation_space.low, 0.0) low = np.append(self.observation_space.low, 0.0)

View File

@@ -4,7 +4,7 @@ from typing import Optional
import gymnasium as gym import gymnasium as gym
class TimeLimit(gym.Wrapper): class TimeLimit(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""This wrapper will issue a `truncated` signal if a maximum number of timesteps is exceeded. """This wrapper will issue a `truncated` signal if a maximum number of timesteps is exceeded.
If a truncation is not defined inside the environment itself, this is the only place that the truncation signal is issued. If a truncation is not defined inside the environment itself, this is the only place that the truncation signal is issued.
@@ -26,14 +26,16 @@ class TimeLimit(gym.Wrapper):
Args: Args:
env: The environment to apply the wrapper env: The environment to apply the wrapper
max_episode_steps: An optional max episode steps (if ``Ǹone``, ``env.spec.max_episode_steps`` is used) max_episode_steps: An optional max episode steps (if ``None``, ``env.spec.max_episode_steps`` is used)
""" """
super().__init__(env) gym.utils.RecordConstructorArgs.__init__(
self, max_episode_steps=max_episode_steps
)
gym.Wrapper.__init__(self, env)
if max_episode_steps is None and self.env.spec is not None: if max_episode_steps is None and self.env.spec is not None:
assert env.spec is not None assert env.spec is not None
max_episode_steps = env.spec.max_episode_steps max_episode_steps = env.spec.max_episode_steps
if self.env.spec is not None:
self.env.spec.max_episode_steps = max_episode_steps
self._max_episode_steps = max_episode_steps self._max_episode_steps = max_episode_steps
self._elapsed_steps = None self._elapsed_steps = None

View File

@@ -4,7 +4,7 @@ from typing import Any, Callable
import gymnasium as gym import gymnasium as gym
class TransformObservation(gym.ObservationWrapper): class TransformObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
"""Transform the observation via an arbitrary function :attr:`f`. """Transform the observation via an arbitrary function :attr:`f`.
The function :attr:`f` should be defined on the observation space of the base environment, ``env``, and should, ideally, return values in the same space. The function :attr:`f` should be defined on the observation space of the base environment, ``env``, and should, ideally, return values in the same space.
@@ -29,7 +29,9 @@ class TransformObservation(gym.ObservationWrapper):
env: The environment to apply the wrapper env: The environment to apply the wrapper
f: A function that transforms the observation f: A function that transforms the observation
""" """
super().__init__(env) gym.utils.RecordConstructorArgs.__init__(self, f=f)
gym.ObservationWrapper.__init__(self, env)
assert callable(f) assert callable(f)
self.f = f self.f = f

View File

@@ -2,10 +2,9 @@
from typing import Callable from typing import Callable
import gymnasium as gym import gymnasium as gym
from gymnasium import RewardWrapper
class TransformReward(RewardWrapper): class TransformReward(gym.RewardWrapper, gym.utils.RecordConstructorArgs):
"""Transform the reward via an arbitrary function. """Transform the reward via an arbitrary function.
Warning: Warning:
@@ -29,7 +28,9 @@ class TransformReward(RewardWrapper):
env: The environment to apply the wrapper env: The environment to apply the wrapper
f: A function that transforms the reward f: A function that transforms the reward
""" """
super().__init__(env) gym.utils.RecordConstructorArgs.__init__(self, f=f)
gym.RewardWrapper.__init__(self, env)
assert callable(f) assert callable(f)
self.f = f self.f = f

View File

@@ -5,7 +5,7 @@ from typing import List
import gymnasium as gym import gymnasium as gym
class VectorListInfo(gym.Wrapper): class VectorListInfo(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""Converts infos of vectorized environments from dict to List[dict]. """Converts infos of vectorized environments from dict to List[dict].
This wrapper converts the info format of a This wrapper converts the info format of a
@@ -51,7 +51,9 @@ class VectorListInfo(gym.Wrapper):
assert getattr( assert getattr(
env, "is_vector_env", False env, "is_vector_env", False
), "This wrapper can only be used in vectorized environments." ), "This wrapper can only be used in vectorized environments."
super().__init__(env)
gym.utils.RecordConstructorArgs.__init__(self)
gym.Wrapper.__init__(self, env)
def step(self, action): def step(self, action):
"""Steps through the environment, convert dict info to list.""" """Steps through the environment, convert dict info to list."""

View File

@@ -0,0 +1,237 @@
"""Example file showing usage of env.specstack."""
import pickle
import pytest
import gymnasium as gym
from gymnasium.envs.classic_control import CartPoleEnv
from gymnasium.envs.registration import EnvSpec
from gymnasium.utils.env_checker import data_equivalence
def test_full_integration():
    """Round-trip a wrapped environment's spec through JSON and rebuild the environment.

    The test wraps an unwrapped ``CartPole-v1`` by hand, serialises ``env.spec``
    with ``to_json``, restores it with ``EnvSpec.from_json``, recreates the
    environment with ``gym.make`` and checks that the original and recreated
    environments agree on one seeded reset and one step.
    """
    # Create an environment to test with
    env = gym.make("CartPole-v1", render_mode="rgb_array").unwrapped

    env = gym.wrappers.FlattenObservation(env)
    env = gym.wrappers.TimeAwareObservation(env)
    env = gym.wrappers.NormalizeReward(env, gamma=0.8)

    # Generate the spec_stack
    env_spec = env.spec
    assert isinstance(env_spec, EnvSpec)

    # Serialize the spec_stack
    env_spec_json = env_spec.to_json()
    assert isinstance(env_spec_json, str)

    # Deserialize the spec_stack
    recreate_env_spec = EnvSpec.from_json(env_spec_json)

    for wrapper_spec, recreated_wrapper_spec in zip(
        env_spec.applied_wrappers, recreate_env_spec.applied_wrappers
    ):
        assert wrapper_spec == recreated_wrapper_spec
    assert recreate_env_spec == env_spec

    # Recreate the environment using the spec_stack
    recreated_env = gym.make(recreate_env_spec)
    assert recreated_env.render_mode == "rgb_array"
    assert isinstance(recreated_env, gym.wrappers.NormalizeReward)
    assert recreated_env.gamma == 0.8
    assert isinstance(recreated_env.env, gym.wrappers.TimeAwareObservation)
    assert isinstance(recreated_env.unwrapped, CartPoleEnv)

    obs, info = env.reset(seed=42)
    recreated_obs, recreated_info = recreated_env.reset(seed=42)
    assert data_equivalence(obs, recreated_obs)
    assert data_equivalence(info, recreated_info)

    # The same action is applied to both environments so the results are comparable.
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action)
    (
        recreated_obs,
        recreated_reward,
        recreated_terminated,
        recreated_truncated,
        recreated_info,
    ) = recreated_env.step(action)
    assert data_equivalence(obs, recreated_obs)
    assert data_equivalence(reward, recreated_reward)
    assert data_equivalence(terminated, recreated_terminated)
    assert data_equivalence(truncated, recreated_truncated)
    assert data_equivalence(info, recreated_info)

    # Test the pprint of the spec_stack: the spec restored from JSON must render
    # exactly like the original one (comparing the original with itself would
    # make this assertion vacuous).
    spec_stack_output = env_spec.pprint(disable_print=True)
    json_spec_stack_output = recreate_env_spec.pprint(disable_print=True)
    assert spec_stack_output == json_spec_stack_output
@pytest.mark.parametrize(
    "env_spec",
    [
        gym.spec("CartPole-v1"),
        gym.make("CartPole-v1").unwrapped.spec,
        gym.make("CartPole-v1").spec,
        gym.wrappers.NormalizeReward(gym.make("CartPole-v1")).spec,
    ],
)
def test_env_spec_to_from_json(env_spec: EnvSpec):
    """Check that a spec survives a ``to_json`` / ``from_json`` round trip unchanged."""
    round_tripped_spec = EnvSpec.from_json(env_spec.to_json())

    assert env_spec == round_tripped_spec
def test_wrapped_env_entry_point():
    """Recreate an environment whose registered entry point already applies a wrapper.

    ``TestingEnv-v0`` is registered with a callable entry point that returns a
    ``FlattenObservation``-wrapped CartPole. Two more wrappers are added on top,
    the environment is recreated from ``env.spec`` and both environments are
    compared on one seeded reset and one step.
    """

    def _create_env():
        _env = gym.make("CartPole-v1", render_mode="rgb_array")
        _env = gym.wrappers.FlattenObservation(_env)
        return _env

    gym.register("TestingEnv-v0", entry_point=_create_env)

    try:
        env = gym.make("TestingEnv-v0")
        env = gym.wrappers.TimeAwareObservation(env)
        env = gym.wrappers.NormalizeReward(env, gamma=0.8)

        recreated_env = gym.make(env.spec)

        obs, info = env.reset(seed=42)
        recreated_obs, recreated_info = recreated_env.reset(seed=42)
        assert data_equivalence(obs, recreated_obs)
        assert data_equivalence(info, recreated_info)

        # The same action is applied to both environments so the results are comparable.
        action = env.action_space.sample()
        obs, reward, terminated, truncated, info = env.step(action)
        (
            recreated_obs,
            recreated_reward,
            recreated_terminated,
            recreated_truncated,
            recreated_info,
        ) = recreated_env.step(action)
        assert data_equivalence(obs, recreated_obs)
        assert data_equivalence(reward, recreated_reward)
        assert data_equivalence(terminated, recreated_terminated)
        assert data_equivalence(truncated, recreated_truncated)
        assert data_equivalence(info, recreated_info)
    finally:
        # Unregister even when an assertion above fails, so the temporary id
        # does not stay in the shared registry for the rest of the session.
        del gym.registry["TestingEnv-v0"]
def test_pickling_env_stack():
    """Check that a pickled copy of a wrapped environment matches the original.

    The original and the unpickled copy are compared on one reset with the same
    seed and one step with the same action, then both are closed.
    """
    env = gym.wrappers.NormalizeReward(
        gym.wrappers.TimeAwareObservation(
            gym.wrappers.FlattenObservation(
                gym.make("CartPole-v1", render_mode="rgb_array")
            )
        ),
        gamma=0.8,
    )

    serialised_env = pickle.dumps(env)
    pickled_env = pickle.loads(serialised_env)

    # Reset both environments with an identical seed and compare pairwise.
    obs, info = env.reset(seed=123)
    pickled_obs, pickled_info = pickled_env.reset(seed=123)
    for original_value, pickled_value in ((obs, pickled_obs), (info, pickled_info)):
        assert data_equivalence(original_value, pickled_value)

    # Step both environments with the same sampled action and compare pairwise.
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action)
    (
        pickled_obs,
        pickled_reward,
        pickled_terminated,
        pickled_truncated,
        pickled_info,
    ) = pickled_env.step(action)
    step_pairs = (
        (obs, pickled_obs),
        (reward, pickled_reward),
        (terminated, pickled_terminated),
        (truncated, pickled_truncated),
        (info, pickled_info),
    )
    for original_value, pickled_value in step_pairs:
        assert data_equivalence(original_value, pickled_value)

    env.close()
    pickled_env.close()
# flake8: noqa
def test_env_spec_pprint():
    """Compare ``EnvSpec.pprint(disable_print=True)`` output with fixed expected strings.

    The spec of ``gym.make("CartPole-v1")`` is rendered with the default
    options, with ``include_entry_points=True`` and with ``print_all=True``;
    ``applied_wrappers`` is then cleared and two of the renderings are repeated.
    """
    # NOTE(review): the expected literals below are kept exactly as they appear
    # in this view; leading whitespace inside the multi-line strings may have
    # been lost - confirm against the real ``pprint`` output.
    env = gym.make("CartPole-v1")
    env_spec = env.spec
    assert env_spec is not None

    # Default rendering: id, reward threshold, step limit and applied wrappers.
    output = env_spec.pprint(disable_print=True)
    assert (
        output
        == """id=CartPole-v1
reward_threshold=475.0
max_episode_steps=500
applied_wrappers=[
name=PassiveEnvChecker, kwargs={},
name=OrderEnforcing, kwargs={'disable_render_order_enforcing': False},
name=TimeLimit, kwargs={'max_episode_steps': 500}
]"""
    )

    # With include_entry_points=True the environment and every wrapper also show an entry point.
    output = env_spec.pprint(disable_print=True, include_entry_points=True)
    assert (
        output
        == """id=CartPole-v1
entry_point=gymnasium.envs.classic_control.cartpole:CartPoleEnv
reward_threshold=475.0
max_episode_steps=500
applied_wrappers=[
name=PassiveEnvChecker, entry_point=gymnasium.wrappers.env_checker:PassiveEnvChecker, kwargs={},
name=OrderEnforcing, entry_point=gymnasium.wrappers.order_enforcing:OrderEnforcing, kwargs={'disable_render_order_enforcing': False},
name=TimeLimit, entry_point=gymnasium.wrappers.time_limit:TimeLimit, kwargs={'max_episode_steps': 500}
]"""
    )

    # With print_all=True the expected output contains more spec fields; the wrapper lines carry no entry points.
    output = env_spec.pprint(disable_print=True, print_all=True)
    assert (
        output
        == """id=CartPole-v1
entry_point=gymnasium.envs.classic_control.cartpole:CartPoleEnv
reward_threshold=475.0
nondeterministic=False
max_episode_steps=500
order_enforce=True
autoreset=False
disable_env_checker=False
applied_api_compatibility=False
applied_wrappers=[
name=PassiveEnvChecker, kwargs={},
name=OrderEnforcing, kwargs={'disable_render_order_enforcing': False},
name=TimeLimit, kwargs={'max_episode_steps': 500}
]"""
    )

    # Clear the wrappers on this spec object to exercise the empty case.
    env_spec.applied_wrappers = ()

    # The default rendering is expected to omit the applied_wrappers line when it is empty.
    output = env_spec.pprint(disable_print=True)
    assert (
        output
        == """id=CartPole-v1
reward_threshold=475.0
max_episode_steps=500"""
    )

    # With print_all=True the empty wrapper list is expected to be shown as "[]".
    output = env_spec.pprint(disable_print=True, print_all=True)
    assert (
        output
        == """id=CartPole-v1
entry_point=gymnasium.envs.classic_control.cartpole:CartPoleEnv
reward_threshold=475.0
nondeterministic=False
max_episode_steps=500
order_enforce=True
autoreset=False
disable_env_checker=False
applied_api_compatibility=False
applied_wrappers=[]"""
    )

View File

@@ -8,6 +8,8 @@ import numpy as np
import pytest import pytest
import gymnasium as gym import gymnasium as gym
from gymnasium import Env
from gymnasium.core import ActType, ObsType, WrapperObsType
from gymnasium.envs.classic_control import CartPoleEnv from gymnasium.envs.classic_control import CartPoleEnv
from gymnasium.wrappers import ( from gymnasium.wrappers import (
AutoResetWrapper, AutoResetWrapper,
@@ -355,3 +357,69 @@ def test_import_module_during_make():
env.close() env.close()
del gym.registry["RegisterDuringMake-v0"] del gym.registry["RegisterDuringMake-v0"]
class NoRecordArgsWrapper(gym.ObservationWrapper):
    """Observation wrapper that only subclasses ``gym.ObservationWrapper``.

    The tests in this file use it as a wrapper that does not inherit from
    ``gymnasium.utils.RecordConstructorArgs``.
    """

    def __init__(self, env: Env[ObsType, ActType]):
        """Wrap ``env`` without doing anything beyond the base-class initialisation."""
        gym.ObservationWrapper.__init__(self, env)

    def observation(self, observation: ObsType) -> WrapperObsType:
        """Ignore ``observation`` and return a sample drawn from ``self.observation_space``."""
        sampled_observation = self.observation_space.sample()
        return sampled_observation
def test_make_env_spec():
    """Exercise ``gym.make`` when it is given an ``EnvSpec`` instead of an id.

    Covers a plain registered spec, a spec carrying applied wrappers, callable
    entry points (with and without a wrapper inside the entry point) and a
    wrapper that does not inherit from ``RecordConstructorArgs``, for which
    ``gym.make`` is expected to raise ``ValueError``.
    """
    temporary_env_ids = ("CartPole-v2", "CartPole-v3", "CartPole-v4")

    try:
        # make
        env_1 = gym.make(gym.spec("CartPole-v1"))
        assert isinstance(env_1, CartPoleEnv)
        assert env_1 is env_1.unwrapped
        env_1.close()

        # make with applied wrappers
        env_2 = gym.wrappers.NormalizeReward(
            gym.wrappers.TimeAwareObservation(
                gym.wrappers.FlattenObservation(
                    gym.make("CartPole-v1", render_mode="rgb_array")
                )
            ),
            gamma=0.8,
        )
        env_2_recreated = gym.make(env_2.spec)
        assert env_2.spec == env_2_recreated.spec
        env_2.close()
        env_2_recreated.close()

        # make with callable entry point
        gym.register("CartPole-v2", lambda: CartPoleEnv())
        env_3 = gym.make("CartPole-v2")
        assert isinstance(env_3.unwrapped, CartPoleEnv)
        env_3.close()

        # make with wrapper in env-creator
        gym.register(
            "CartPole-v3", lambda: gym.wrappers.TimeAwareObservation(CartPoleEnv())
        )
        env_4 = gym.make(gym.spec("CartPole-v3"))
        assert isinstance(env_4, gym.wrappers.TimeAwareObservation)
        assert isinstance(env_4.env, CartPoleEnv)
        env_4.close()

        # make with a wrapper that does not inherit from RecordConstructorArgs
        env_5 = NoRecordArgsWrapper(gym.make("CartPole-v1").unwrapped)
        env_5.close()
        with pytest.raises(
            ValueError,
            match=re.escape(
                "NoRecordArgsWrapper wrapper does not inherit from `gymnasium.utils.RecordConstructorArgs`, therefore, the wrapper cannot be recreated."
            ),
        ):
            gym.make(env_5.spec)

        # make with a non-RecordConstructorArgs wrapper applied inside the entry point
        gym.register(
            "CartPole-v4", entry_point=lambda: NoRecordArgsWrapper(CartPoleEnv())
        )
        env_6 = gym.make(gym.spec("CartPole-v4"))
        assert isinstance(env_6, NoRecordArgsWrapper)
        assert isinstance(env_6.unwrapped, CartPoleEnv)
        env_6.close()
    finally:
        # Remove whichever temporary ids were registered, even when an assertion
        # above failed, so they do not stay in the shared registry.
        for env_id in temporary_env_ids:
            if env_id in gym.registry:
                del gym.registry[env_id]

View File

@@ -22,7 +22,7 @@ def test_mujoco_action_dimensions(env_spec: EnvSpec):
* Too many dimensions * Too many dimensions
* Incorrect shape * Incorrect shape
""" """
env = env_spec.make(disable_env_checker=True) env = env_spec.make()
env.reset() env.reset()
# Too few actions # Too few actions

View File

@@ -42,7 +42,7 @@ CHECK_ENV_IGNORE_WARNINGS = [
def test_all_env_api(spec): def test_all_env_api(spec):
"""Check that all environments pass the environment checker with no warnings other than the expected.""" """Check that all environments pass the environment checker with no warnings other than the expected."""
with warnings.catch_warnings(record=True) as caught_warnings: with warnings.catch_warnings(record=True) as caught_warnings:
env = spec.make(disable_env_checker=True).unwrapped env = spec.make().unwrapped
check_env(env, skip_render_check=True) check_env(env, skip_render_check=True)
env.close() env.close()

View File

@@ -15,8 +15,8 @@ def verify_environments_match(
): ):
"""Verifies with two environment ids (old and new) are identical in obs, reward and done """Verifies with two environment ids (old and new) are identical in obs, reward and done
(except info where all old info must be contained in new info).""" (except info where all old info must be contained in new info)."""
old_env = envs.make(old_env_id, disable_env_checker=True) old_env = envs.make(old_env_id)
new_env = envs.make(new_env_id, disable_env_checker=True) new_env = envs.make(new_env_id)
old_reset_obs, old_info = old_env.reset(seed=seed) old_reset_obs, old_info = old_env.reset(seed=seed)
new_reset_obs, new_info = new_env.reset(seed=seed) new_reset_obs, new_info = new_env.reset(seed=seed)

View File

@@ -106,7 +106,7 @@ class TestNestedDictWrapper:
observation_space = env.observation_space observation_space = env.observation_space
assert isinstance(observation_space, Dict) assert isinstance(observation_space, Dict)
wrapped_env = FlattenObservation(FilterObservation(env, env.obs_keys)) wrapped_env = FlattenObservation(FilterObservation(env, list(env.obs_keys)))
assert wrapped_env.observation_space.shape == flat_shape assert wrapped_env.observation_space.shape == flat_shape
assert wrapped_env.observation_space.dtype == np.float32 assert wrapped_env.observation_space.dtype == np.float32
@@ -114,7 +114,7 @@ class TestNestedDictWrapper:
@pytest.mark.parametrize("observation_space, flat_shape", NESTED_DICT_TEST_CASES) @pytest.mark.parametrize("observation_space, flat_shape", NESTED_DICT_TEST_CASES)
def test_nested_dicts_ravel(self, observation_space, flat_shape): def test_nested_dicts_ravel(self, observation_space, flat_shape):
env = FakeEnvironment(observation_space=observation_space) env = FakeEnvironment(observation_space=observation_space)
wrapped_env = FlattenObservation(FilterObservation(env, env.obs_keys)) wrapped_env = FlattenObservation(FilterObservation(env, list(env.obs_keys)))
obs, info = wrapped_env.reset() obs, info = wrapped_env.reset()
assert obs.shape == wrapped_env.observation_space.shape assert obs.shape == wrapped_env.observation_space.shape
assert isinstance(info, dict) assert isinstance(info, dict)

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import gymnasium as gym import gymnasium as gym
def has_wrapper(wrapped_env: gym.Env, wrapper_type: type) -> bool: def has_wrapper(wrapped_env: gym.Env, wrapper_type: type[gym.Wrapper]) -> bool:
while isinstance(wrapped_env, gym.Wrapper): while isinstance(wrapped_env, gym.Wrapper):
if isinstance(wrapped_env, wrapper_type): if isinstance(wrapped_env, wrapper_type):
return True return True