mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-07-31 05:44:31 +00:00
Update EnvSpec
and make
to support reproducing the "whole" environment spec including wrappers (#292)
Co-authored-by: will <will2346@live.co.uk> Co-authored-by: Will Dudley <14932240+WillDudley@users.noreply.github.com> Co-authored-by: Mark Towers <marktowers@Marks-MacBook-Pro.local>
This commit is contained in:
@@ -1,12 +1,13 @@
|
||||
"""Core API for Environment, Wrapper, ActionWrapper, RewardWrapper and ObservationWrapper."""
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import TYPE_CHECKING, Any, Generic, SupportsFloat, TypeVar
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.utils import seeding
|
||||
from gymnasium.utils import RecordConstructorArgs, seeding
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -277,8 +278,32 @@ class Wrapper(
|
||||
|
||||
@property
|
||||
def spec(self) -> EnvSpec | None:
|
||||
"""Returns the :attr:`Env` :attr:`spec` attribute."""
|
||||
return self.env.spec
|
||||
"""Returns the :attr:`Env` :attr:`spec` attribute with the `WrapperSpec` if the wrapper inherits from `EzPickle`."""
|
||||
env_spec = self.env.spec
|
||||
|
||||
if env_spec is not None:
|
||||
from gymnasium.envs.registration import WrapperSpec
|
||||
|
||||
# See if the wrapper inherits from `RecordConstructorArgs` then add the kwargs otherwise use `None` for the wrapper kwargs. This will raise an error in `make`
|
||||
if isinstance(self, RecordConstructorArgs):
|
||||
kwargs = getattr(self, "_saved_kwargs")
|
||||
if "env" in kwargs:
|
||||
kwargs = deepcopy(kwargs)
|
||||
kwargs.pop("env")
|
||||
else:
|
||||
kwargs = None
|
||||
|
||||
wrapper_spec = WrapperSpec(
|
||||
name=self.class_name(),
|
||||
entry_point=f"{self.__module__}:{type(self).__name__}",
|
||||
kwargs=kwargs,
|
||||
)
|
||||
|
||||
# to avoid reference issues we deepcopy the prior environments spec and add the new information
|
||||
env_spec = deepcopy(env_spec)
|
||||
env_spec.applied_wrappers += (wrapper_spec,)
|
||||
|
||||
return env_spec
|
||||
|
||||
@classmethod
|
||||
def class_name(cls) -> str:
|
||||
@@ -409,7 +434,7 @@ class ObservationWrapper(Wrapper[WrapperObsType, ActType, ObsType, ActType]):
|
||||
|
||||
def __init__(self, env: Env[ObsType, ActType]):
|
||||
"""Constructor for the observation wrapper."""
|
||||
super().__init__(env)
|
||||
Wrapper.__init__(self, env)
|
||||
|
||||
def reset(
|
||||
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
||||
@@ -449,7 +474,7 @@ class RewardWrapper(Wrapper[ObsType, ActType, ObsType, ActType]):
|
||||
|
||||
def __init__(self, env: Env[ObsType, ActType]):
|
||||
"""Constructor for the Reward wrapper."""
|
||||
super().__init__(env)
|
||||
Wrapper.__init__(self, env)
|
||||
|
||||
def step(
|
||||
self, action: ActType
|
||||
@@ -485,7 +510,7 @@ class ActionWrapper(Wrapper[ObsType, WrapperActType, ObsType, ActType]):
|
||||
|
||||
def __init__(self, env: Env[ObsType, ActType]):
|
||||
"""Constructor for the action wrapper."""
|
||||
super().__init__(env)
|
||||
Wrapper.__init__(self, env)
|
||||
|
||||
def step(
|
||||
self, action: WrapperActType
|
||||
|
@@ -3,9 +3,11 @@ from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import copy
|
||||
import dataclasses
|
||||
import difflib
|
||||
import importlib
|
||||
import importlib.util
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
import traceback
|
||||
@@ -17,13 +19,13 @@ from gymnasium import Env, Wrapper, error, logger
|
||||
from gymnasium.experimental.vector import AsyncVectorEnv, SyncVectorEnv, VectorEnv
|
||||
from gymnasium.wrappers import (
|
||||
AutoResetWrapper,
|
||||
EnvCompatibility,
|
||||
HumanRendering,
|
||||
OrderEnforcing,
|
||||
PassiveEnvChecker,
|
||||
RenderCollection,
|
||||
TimeLimit,
|
||||
)
|
||||
from gymnasium.wrappers.compatibility import EnvCompatibility
|
||||
from gymnasium.wrappers.env_checker import PassiveEnvChecker
|
||||
|
||||
|
||||
if sys.version_info < (3, 10):
|
||||
@@ -43,11 +45,14 @@ ENV_ID_RE = re.compile(
|
||||
|
||||
|
||||
__all__ = [
|
||||
"EnvSpec",
|
||||
"registry",
|
||||
"current_namespace",
|
||||
"EnvSpec",
|
||||
"WrapperSpec",
|
||||
# Functions
|
||||
"register",
|
||||
"make",
|
||||
"make_vec",
|
||||
"spec",
|
||||
"pprint_registry",
|
||||
]
|
||||
@@ -67,6 +72,20 @@ class VectorEnvCreator(Protocol):
|
||||
...
|
||||
|
||||
|
||||
@dataclass
|
||||
class WrapperSpec:
|
||||
"""A specification for recording wrapper configs.
|
||||
|
||||
* name: The name of the wrapper.
|
||||
* entry_point: The location of the wrapper to create from.
|
||||
* kwargs: Additional keyword arguments passed to the wrapper. If the wrapper doesn't inherit from EzPickle then this is ``None``
|
||||
"""
|
||||
|
||||
name: str
|
||||
entry_point: str
|
||||
kwargs: dict[str, Any] | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnvSpec:
|
||||
"""A specification for creating environments with :meth:`gymnasium.make`.
|
||||
@@ -80,6 +99,7 @@ class EnvSpec:
|
||||
* **autoreset**: If to automatically reset the environment on episode end
|
||||
* **disable_env_checker**: If to disable the environment checker wrapper in :meth:`gymnasium.make`, by default False (runs the environment checker)
|
||||
* **kwargs**: Additional keyword arguments passed to the environment during initialisation
|
||||
* **applied_wrappers**: A tuple of applied wrappers (WrapperSpec)
|
||||
* **vector_entry_point**: The location of the vectorized environment to create from
|
||||
"""
|
||||
|
||||
@@ -97,27 +117,152 @@ class EnvSpec:
|
||||
disable_env_checker: bool = field(default=False)
|
||||
apply_api_compatibility: bool = field(default=False)
|
||||
|
||||
# Environment arguments
|
||||
kwargs: dict = field(default_factory=dict)
|
||||
|
||||
# post-init attributes
|
||||
namespace: str | None = field(init=False)
|
||||
name: str = field(init=False)
|
||||
version: int | None = field(init=False)
|
||||
|
||||
# Environment arguments
|
||||
kwargs: dict = field(default_factory=dict)
|
||||
# applied wrappers
|
||||
applied_wrappers: tuple[WrapperSpec, ...] = field(init=False, default_factory=tuple)
|
||||
|
||||
# Vectorized environment
|
||||
vector_entry_point: str | None = field(default=None)
|
||||
# Vectorized environment entry point
|
||||
vector_entry_point: VectorEnvCreator | str | None = field(default=None)
|
||||
|
||||
def __post_init__(self):
|
||||
"""Calls after the spec is created to extract the namespace, name and version from the id."""
|
||||
# Initialize namespace, name, version
|
||||
"""Calls after the spec is created to extract the namespace, name and version from the environment id."""
|
||||
self.namespace, self.name, self.version = parse_env_id(self.id)
|
||||
|
||||
def make(self, **kwargs: Any) -> Env:
|
||||
"""Calls ``make`` using the environment spec and any keyword arguments."""
|
||||
# For compatibility purposes
|
||||
return make(self, **kwargs)
|
||||
|
||||
def to_json(self) -> str:
|
||||
"""Converts the environment spec into a json compatible string.
|
||||
|
||||
Returns:
|
||||
A jsonifyied string for the environment spec
|
||||
"""
|
||||
env_spec_dict = dataclasses.asdict(self)
|
||||
# As the namespace, name and version are initialised after `init` then we remove the attributes
|
||||
env_spec_dict.pop("namespace")
|
||||
env_spec_dict.pop("name")
|
||||
env_spec_dict.pop("version")
|
||||
|
||||
# To check that the environment spec can be transformed to a json compatible type
|
||||
self._check_can_jsonify(env_spec_dict)
|
||||
|
||||
return json.dumps(env_spec_dict)
|
||||
|
||||
@staticmethod
|
||||
def _check_can_jsonify(env_spec: dict[str, Any]):
|
||||
"""Warns the user about serialisation failing if the spec contains a callable.
|
||||
|
||||
Args:
|
||||
env_spec: An environment or wrapper specification.
|
||||
|
||||
Returns: The specification with lambda functions converted to strings.
|
||||
|
||||
"""
|
||||
spec_name = env_spec["name"] if "name" in env_spec else env_spec["id"]
|
||||
|
||||
for key, value in env_spec.items():
|
||||
if callable(value):
|
||||
ValueError(
|
||||
f"Callable found in {spec_name} for {key} attribute with value={value}. Currently, Gymnasium does not support serialising callables."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_json(json_env_spec: str) -> EnvSpec:
|
||||
"""Converts a JSON string into a specification stack.
|
||||
|
||||
Args:
|
||||
json_env_spec: A JSON string representing the env specification.
|
||||
|
||||
Returns:
|
||||
An environment spec
|
||||
"""
|
||||
parsed_env_spec = json.loads(json_env_spec)
|
||||
|
||||
applied_wrapper_specs: list[WrapperSpec] = []
|
||||
for wrapper_spec_json in parsed_env_spec.pop("applied_wrappers"):
|
||||
try:
|
||||
applied_wrapper_specs.append(WrapperSpec(**wrapper_spec_json))
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"An issue occurred when trying to make {wrapper_spec_json} a WrapperSpec"
|
||||
) from e
|
||||
|
||||
try:
|
||||
env_spec = EnvSpec(**parsed_env_spec)
|
||||
env_spec.applied_wrappers = tuple(applied_wrapper_specs)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"An issue occurred when trying to make {parsed_env_spec} an EnvSpec"
|
||||
) from e
|
||||
|
||||
return env_spec
|
||||
|
||||
def pprint(
|
||||
self,
|
||||
disable_print: bool = False,
|
||||
include_entry_points: bool = False,
|
||||
print_all: bool = False,
|
||||
) -> str | None:
|
||||
"""Pretty prints the environment spec.
|
||||
|
||||
Args:
|
||||
disable_print: If to disable print and return the output
|
||||
include_entry_points: If to include the entry_points in the output
|
||||
print_all: If to print all information, including variables with default values
|
||||
|
||||
Returns:
|
||||
If ``disable_print is True`` a string otherwise ``None``
|
||||
"""
|
||||
output = f"id={self.id}"
|
||||
if print_all or include_entry_points:
|
||||
output += f"\nentry_point={self.entry_point}"
|
||||
|
||||
if print_all or self.reward_threshold is not None:
|
||||
output += f"\nreward_threshold={self.reward_threshold}"
|
||||
if print_all or self.nondeterministic is not False:
|
||||
output += f"\nnondeterministic={self.nondeterministic}"
|
||||
|
||||
if print_all or self.max_episode_steps is not None:
|
||||
output += f"\nmax_episode_steps={self.max_episode_steps}"
|
||||
if print_all or self.order_enforce is not True:
|
||||
output += f"\norder_enforce={self.order_enforce}"
|
||||
if print_all or self.autoreset is not False:
|
||||
output += f"\nautoreset={self.autoreset}"
|
||||
if print_all or self.disable_env_checker is not False:
|
||||
output += f"\ndisable_env_checker={self.disable_env_checker}"
|
||||
if print_all or self.apply_api_compatibility is not False:
|
||||
output += f"\napplied_api_compatibility={self.apply_api_compatibility}"
|
||||
|
||||
if print_all or self.applied_wrappers:
|
||||
wrapper_output: list[str] = []
|
||||
for wrapper_spec in self.applied_wrappers:
|
||||
if include_entry_points:
|
||||
wrapper_output.append(
|
||||
f"\n\tname={wrapper_spec.name}, entry_point={wrapper_spec.entry_point}, kwargs={wrapper_spec.kwargs}"
|
||||
)
|
||||
else:
|
||||
wrapper_output.append(
|
||||
f"\n\tname={wrapper_spec.name}, kwargs={wrapper_spec.kwargs}"
|
||||
)
|
||||
|
||||
if len(wrapper_output) == 0:
|
||||
output += "\napplied_wrappers=[]"
|
||||
else:
|
||||
output += f"\napplied_wrappers=[{','.join(wrapper_output)}\n]"
|
||||
|
||||
if disable_print:
|
||||
return output
|
||||
else:
|
||||
print(output)
|
||||
|
||||
|
||||
# Global registry of environments. Meant to be accessed through `register` and `make`
|
||||
registry: dict[str, EnvSpec] = {}
|
||||
@@ -352,8 +497,12 @@ def _check_metadata(testing_metadata: dict[str, Any]):
|
||||
)
|
||||
|
||||
|
||||
def _find_spec(id: str) -> EnvSpec:
|
||||
module, env_name = (None, id) if ":" not in id else id.split(":")
|
||||
def _find_spec(env_id: str) -> EnvSpec:
|
||||
# For string id's, load the environment spec from the registry then make the environment spec
|
||||
assert isinstance(env_id, str)
|
||||
|
||||
# The environment name can include an unloaded module in "module:env_name" style
|
||||
module, env_name = (None, env_id) if ":" not in env_id else env_id.split(":")
|
||||
if module is not None:
|
||||
try:
|
||||
importlib.import_module(module)
|
||||
@@ -391,7 +540,7 @@ def _find_spec(id: str) -> EnvSpec:
|
||||
return env_spec
|
||||
|
||||
|
||||
def load_env(name: str) -> EnvCreator:
|
||||
def load_env_creator(name: str) -> EnvCreator | VectorEnvCreator:
|
||||
"""Loads an environment with name of style ``"(import path):(environment name)"`` and returns the environment creation function, normally the environment class type.
|
||||
|
||||
Args:
|
||||
@@ -406,6 +555,161 @@ def load_env(name: str) -> EnvCreator:
|
||||
return fn
|
||||
|
||||
|
||||
def _create_from_env_spec(
|
||||
env_spec: EnvSpec,
|
||||
kwargs: dict[str, Any],
|
||||
) -> Env:
|
||||
"""Recreates an environment spec using a list of wrapper specs."""
|
||||
if callable(env_spec.entry_point):
|
||||
env_creator = env_spec.entry_point
|
||||
else:
|
||||
env_creator: EnvCreator = load_env_creator(env_spec.entry_point)
|
||||
|
||||
# Create the environment
|
||||
env: Env = env_creator(**env_spec.kwargs, **kwargs)
|
||||
|
||||
# Set the `EnvSpec` to the environment
|
||||
new_env_spec = copy.deepcopy(env_spec)
|
||||
new_env_spec.applied_wrappers = ()
|
||||
new_env_spec.kwargs.update(kwargs)
|
||||
env.unwrapped.spec = new_env_spec
|
||||
|
||||
# Check if the environment spec
|
||||
assert env.spec is not None # this is for pyright
|
||||
num_prior_wrappers = len(env.spec.applied_wrappers)
|
||||
if env_spec.applied_wrappers[:num_prior_wrappers] != env.spec.applied_wrappers:
|
||||
for env_spec_wrapper_spec, recreated_wrapper_spec in zip(
|
||||
env_spec.applied_wrappers, env.spec.applied_wrappers
|
||||
):
|
||||
raise ValueError(
|
||||
f"The environment's wrapper spec {recreated_wrapper_spec} is different from the saved `EnvSpec` applied_wrappers {env_spec_wrapper_spec}"
|
||||
)
|
||||
|
||||
for wrapper_spec in env_spec.applied_wrappers[num_prior_wrappers:]:
|
||||
if wrapper_spec.kwargs is None:
|
||||
raise ValueError(
|
||||
f"{wrapper_spec.name} wrapper does not inherit from `gymnasium.utils.RecordConstructorArgs`, therefore, the wrapper cannot be recreated."
|
||||
)
|
||||
|
||||
env = load_env_creator(wrapper_spec.entry_point)(env=env, **wrapper_spec.kwargs)
|
||||
|
||||
return env
|
||||
|
||||
|
||||
def _create_from_env_id(
|
||||
env_spec: EnvSpec,
|
||||
kwargs: dict[str, Any],
|
||||
max_episode_steps: int | None = None,
|
||||
autoreset: bool = False,
|
||||
apply_api_compatibility: bool | None = None,
|
||||
disable_env_checker: bool | None = None,
|
||||
) -> Env:
|
||||
"""Creates an environment based on the `env_spec` along with wrapper options. See `make` for their meaning."""
|
||||
spec_kwargs = copy.deepcopy(env_spec.kwargs)
|
||||
spec_kwargs.update(kwargs)
|
||||
|
||||
# Load the environment creator
|
||||
if env_spec.entry_point is None:
|
||||
raise error.Error(f"{env_spec.id} registered but entry_point is not specified")
|
||||
elif callable(env_spec.entry_point):
|
||||
env_creator = env_spec.entry_point
|
||||
else:
|
||||
# Assume it's a string
|
||||
env_creator = load_env_creator(env_spec.entry_point)
|
||||
|
||||
# Determine if to use the rendering
|
||||
render_modes: list[str] | None = None
|
||||
if hasattr(env_creator, "metadata"):
|
||||
_check_metadata(env_creator.metadata)
|
||||
render_modes = env_creator.metadata.get("render_modes")
|
||||
mode = spec_kwargs.get("render_mode")
|
||||
apply_human_rendering = False
|
||||
apply_render_collection = False
|
||||
|
||||
# If mode is not valid, try applying HumanRendering/RenderCollection wrappers
|
||||
if mode is not None and render_modes is not None and mode not in render_modes:
|
||||
displayable_modes = {"rgb_array", "rgb_array_list"}.intersection(render_modes)
|
||||
if mode == "human" and len(displayable_modes) > 0:
|
||||
logger.warn(
|
||||
"You are trying to use 'human' rendering for an environment that doesn't natively support it. "
|
||||
"The HumanRendering wrapper is being applied to your environment."
|
||||
)
|
||||
spec_kwargs["render_mode"] = displayable_modes.pop()
|
||||
apply_human_rendering = True
|
||||
elif mode.endswith("_list") and mode[: -len("_list")] in render_modes:
|
||||
spec_kwargs["render_mode"] = mode[: -len("_list")]
|
||||
apply_render_collection = True
|
||||
else:
|
||||
logger.warn(
|
||||
f"The environment is being initialised with render_mode={mode!r} "
|
||||
f"that is not in the possible render_modes ({render_modes})."
|
||||
)
|
||||
|
||||
if apply_api_compatibility or (
|
||||
apply_api_compatibility is None and env_spec.apply_api_compatibility
|
||||
):
|
||||
# If we use the compatibility layer, we treat the render mode explicitly and don't pass it to the env creator
|
||||
render_mode = spec_kwargs.pop("render_mode", None)
|
||||
else:
|
||||
render_mode = None
|
||||
|
||||
try:
|
||||
env = env_creator(**spec_kwargs)
|
||||
except TypeError as e:
|
||||
if (
|
||||
str(e).find("got an unexpected keyword argument 'render_mode'") >= 0
|
||||
and apply_human_rendering
|
||||
):
|
||||
raise error.Error(
|
||||
f"You passed render_mode='human' although {env_spec.id} doesn't implement human-rendering natively. "
|
||||
"Gym tried to apply the HumanRendering wrapper but it looks like your environment is using the old "
|
||||
"rendering API, which is not supported by the HumanRendering wrapper."
|
||||
) from e
|
||||
else:
|
||||
raise e
|
||||
|
||||
# Copies the environment creation specification and kwargs to add to the environment specification details
|
||||
env_spec = copy.deepcopy(env_spec)
|
||||
env_spec.kwargs = spec_kwargs
|
||||
env.unwrapped.spec = env_spec
|
||||
|
||||
# Add step API wrapper
|
||||
if apply_api_compatibility is True or (
|
||||
apply_api_compatibility is None and env_spec.apply_api_compatibility is True
|
||||
):
|
||||
env = EnvCompatibility(env, render_mode)
|
||||
|
||||
# Run the environment checker as the lowest level wrapper
|
||||
if disable_env_checker is False or (
|
||||
disable_env_checker is None and env_spec.disable_env_checker is False
|
||||
):
|
||||
env = PassiveEnvChecker(env)
|
||||
|
||||
# Add the order enforcing wrapper
|
||||
if env_spec.order_enforce:
|
||||
env = OrderEnforcing(env)
|
||||
|
||||
# Add the time limit wrapper
|
||||
if max_episode_steps is not None:
|
||||
assert env.unwrapped.spec is not None # for pyright
|
||||
env.unwrapped.spec.max_episode_steps = max_episode_steps
|
||||
env = TimeLimit(env, max_episode_steps)
|
||||
elif env_spec.max_episode_steps is not None:
|
||||
env = TimeLimit(env, env_spec.max_episode_steps)
|
||||
|
||||
# Add the auto-reset wrapper
|
||||
if autoreset:
|
||||
env = AutoResetWrapper(env)
|
||||
|
||||
# Add human rendering wrapper
|
||||
if apply_human_rendering:
|
||||
env = HumanRendering(env)
|
||||
elif apply_render_collection:
|
||||
env = RenderCollection(env)
|
||||
|
||||
return env
|
||||
|
||||
|
||||
def load_plugin_envs(entry_point: str = "gymnasium.envs"):
|
||||
"""Load modules (plugins) using the gymnasium entry points in order to register external module's environments on ``import gymnasium``.
|
||||
|
||||
@@ -437,10 +741,8 @@ def load_plugin_envs(entry_point: str = "gymnasium.envs"):
|
||||
|
||||
context = namespace(plugin.name)
|
||||
if plugin.name.startswith("__") and plugin.name.endswith("__"):
|
||||
# `__internal__` is an artifact of the plugin system when
|
||||
# the root namespace had an allow-list. The allow-list is now
|
||||
# removed and plugins can register environments in the root
|
||||
# namespace with the `__root__` magic key.
|
||||
# `__internal__` is an artifact of the plugin system when the root namespace had an allow-list.
|
||||
# The allow-list is now removed and plugins can register environments in the root namespace with the `__root__` magic key.
|
||||
if plugin.name == "__root__" or plugin.name == "__internal__":
|
||||
context = contextlib.nullcontext()
|
||||
else:
|
||||
@@ -533,9 +835,9 @@ def register(
|
||||
order_enforce=order_enforce,
|
||||
autoreset=autoreset,
|
||||
disable_env_checker=disable_env_checker,
|
||||
**kwargs,
|
||||
apply_api_compatibility=apply_api_compatibility,
|
||||
vector_entry_point=vector_entry_point,
|
||||
**kwargs,
|
||||
)
|
||||
_check_spec_register(new_spec)
|
||||
|
||||
@@ -576,116 +878,47 @@ def make(
|
||||
Error: If the ``id`` doesn't exist in the :attr:`registry`
|
||||
"""
|
||||
if isinstance(id, EnvSpec):
|
||||
env_spec = id
|
||||
if hasattr(id, "applied_wrappers") and id.applied_wrappers is not None:
|
||||
if max_episode_steps is not None:
|
||||
logger.warn(
|
||||
f"For `gymnasium.make` with an `EnvSpec`, the `max_episode_step` parameter is ignored, use `gym.make({id.id}, max_episode_steps={max_episode_steps})` and any additional wrappers"
|
||||
)
|
||||
if autoreset is True:
|
||||
logger.warn(
|
||||
f"For `gymnasium.make` with an `EnvSpec`, the `autoreset` parameter is ignored, use `gym.make({id.id}, autoreset={autoreset})` and any additional wrappers"
|
||||
)
|
||||
if apply_api_compatibility is not None:
|
||||
logger.warn(
|
||||
f"For `gymnasium.make` with an `EnvSpec`, the `apply_api_compatibility` parameter is ignored, use `gym.make({id.id}, apply_api_compatibility={apply_api_compatibility})` and any additional wrappers"
|
||||
)
|
||||
if disable_env_checker is not None:
|
||||
logger.warn(
|
||||
f"For `gymnasium.make` with an `EnvSpec`, the `disable_env_checker` parameter is ignored, use `gym.make({id.id}, disable_env_checker={disable_env_checker})` and any additional wrappers"
|
||||
)
|
||||
|
||||
return _create_from_env_spec(
|
||||
id,
|
||||
kwargs,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"The EnvSpec used does not contain `applied_wrappers` parameters or is `None`. Expected to be a tuple, actually {id}."
|
||||
)
|
||||
else:
|
||||
# For string id's, load the environment spec from the registry then make the environment spec
|
||||
assert isinstance(id, str)
|
||||
|
||||
# The environment name can include an unloaded module in "module:env_name" style
|
||||
env_spec = _find_spec(id)
|
||||
|
||||
assert isinstance(
|
||||
env_spec, EnvSpec
|
||||
), f"We expected to collect an `EnvSpec`, actually collected a {type(env_spec)}"
|
||||
# Extract the spec kwargs and append the make kwargs
|
||||
spec_kwargs = env_spec.kwargs.copy()
|
||||
spec_kwargs.update(kwargs)
|
||||
|
||||
# Load the environment creator
|
||||
if env_spec.entry_point is None:
|
||||
raise error.Error(f"{env_spec.id} registered but entry_point is not specified")
|
||||
elif callable(env_spec.entry_point):
|
||||
env_creator = env_spec.entry_point
|
||||
else:
|
||||
# Assume it's a string
|
||||
env_creator = load_env(env_spec.entry_point)
|
||||
|
||||
# Determine if to use the rendering
|
||||
render_modes: list[str] | None = None
|
||||
if hasattr(env_creator, "metadata"):
|
||||
_check_metadata(env_creator.metadata)
|
||||
render_modes = env_creator.metadata.get("render_modes")
|
||||
mode = spec_kwargs.get("render_mode")
|
||||
apply_human_rendering = False
|
||||
apply_render_collection = False
|
||||
|
||||
# If mode is not valid, try applying HumanRendering/RenderCollection wrappers
|
||||
if mode is not None and render_modes is not None and mode not in render_modes:
|
||||
displayable_modes = {"rgb_array", "rgb_array_list"}.intersection(render_modes)
|
||||
if mode == "human" and len(displayable_modes) > 0:
|
||||
logger.warn(
|
||||
"You are trying to use 'human' rendering for an environment that doesn't natively support it. "
|
||||
"The HumanRendering wrapper is being applied to your environment."
|
||||
)
|
||||
spec_kwargs["render_mode"] = displayable_modes.pop()
|
||||
apply_human_rendering = True
|
||||
elif mode.endswith("_list") and mode[: -len("_list")] in render_modes:
|
||||
spec_kwargs["render_mode"] = mode[: -len("_list")]
|
||||
apply_render_collection = True
|
||||
else:
|
||||
logger.warn(
|
||||
f"The environment is being initialised with render_mode={mode!r} "
|
||||
f"that is not in the possible render_modes ({render_modes})."
|
||||
)
|
||||
|
||||
if apply_api_compatibility or (
|
||||
apply_api_compatibility is None and env_spec.apply_api_compatibility
|
||||
):
|
||||
# If we use the compatibility layer, we treat the render mode explicitly and don't pass it to the env creator
|
||||
render_mode = spec_kwargs.pop("render_mode", None)
|
||||
else:
|
||||
render_mode = None
|
||||
|
||||
try:
|
||||
env = env_creator(**spec_kwargs)
|
||||
except TypeError as e:
|
||||
if (
|
||||
str(e).find("got an unexpected keyword argument 'render_mode'") >= 0
|
||||
and apply_human_rendering
|
||||
):
|
||||
raise error.Error(
|
||||
f"You passed render_mode='human' although {id} doesn't implement human-rendering natively. "
|
||||
"Gym tried to apply the HumanRendering wrapper but it looks like your environment is using the old "
|
||||
"rendering API, which is not supported by the HumanRendering wrapper."
|
||||
) from e
|
||||
else:
|
||||
raise e
|
||||
|
||||
# Copies the environment creation specification and kwargs to add to the environment specification details
|
||||
env_spec = copy.deepcopy(env_spec)
|
||||
env_spec.kwargs = spec_kwargs
|
||||
env.unwrapped.spec = env_spec
|
||||
|
||||
# Add step API wrapper
|
||||
if apply_api_compatibility is True or (
|
||||
apply_api_compatibility is None and env_spec.apply_api_compatibility is True
|
||||
):
|
||||
env = EnvCompatibility(env, render_mode)
|
||||
|
||||
# Run the environment checker as the lowest level wrapper
|
||||
if disable_env_checker is False or (
|
||||
disable_env_checker is None and env_spec.disable_env_checker is False
|
||||
):
|
||||
env = PassiveEnvChecker(env)
|
||||
|
||||
# Add the order enforcing wrapper
|
||||
if env_spec.order_enforce:
|
||||
env = OrderEnforcing(env)
|
||||
|
||||
# Add the time limit wrapper
|
||||
if max_episode_steps is not None:
|
||||
env = TimeLimit(env, max_episode_steps)
|
||||
elif env_spec.max_episode_steps is not None:
|
||||
env = TimeLimit(env, env_spec.max_episode_steps)
|
||||
|
||||
# Add the autoreset wrapper
|
||||
if autoreset:
|
||||
env = AutoResetWrapper(env)
|
||||
|
||||
# Add human rendering wrapper
|
||||
if apply_human_rendering:
|
||||
env = HumanRendering(env)
|
||||
elif apply_render_collection:
|
||||
env = RenderCollection(env)
|
||||
|
||||
return env
|
||||
return _create_from_env_id(
|
||||
env_spec,
|
||||
kwargs,
|
||||
max_episode_steps=max_episode_steps,
|
||||
autoreset=autoreset,
|
||||
apply_api_compatibility=apply_api_compatibility,
|
||||
disable_env_checker=disable_env_checker,
|
||||
)
|
||||
|
||||
|
||||
def make_vec(
|
||||
@@ -752,7 +985,7 @@ def make_vec(
|
||||
env_creator = entry_point
|
||||
else:
|
||||
# Assume it's a string
|
||||
env_creator = load_env(entry_point)
|
||||
env_creator = load_env_creator(entry_point)
|
||||
|
||||
def _create_env():
|
||||
# Env creator for use with sync and async modes
|
||||
|
@@ -11,7 +11,7 @@ except ImportError:
|
||||
cv2 = None
|
||||
|
||||
|
||||
class AtariPreprocessingV0(gym.Wrapper):
|
||||
class AtariPreprocessingV0(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||
"""Atari 2600 preprocessing wrapper.
|
||||
|
||||
This class follows the guidelines in Machado et al. (2018),
|
||||
@@ -60,7 +60,18 @@ class AtariPreprocessingV0(gym.Wrapper):
|
||||
DependencyNotInstalled: opencv-python package not installed
|
||||
ValueError: Disable frame-skipping in the original env
|
||||
"""
|
||||
super().__init__(env)
|
||||
gym.utils.RecordConstructorArgs.__init__(
|
||||
self,
|
||||
noop_max=noop_max,
|
||||
frame_skip=frame_skip,
|
||||
screen_size=screen_size,
|
||||
terminal_on_life_loss=terminal_on_life_loss,
|
||||
grayscale_obs=grayscale_obs,
|
||||
grayscale_newaxis=grayscale_newaxis,
|
||||
scale_obs=scale_obs,
|
||||
)
|
||||
gym.Wrapper.__init__(self, env)
|
||||
|
||||
if cv2 is None:
|
||||
raise gym.error.DependencyNotInstalled(
|
||||
"opencv-python package not installed, run `pip install gymnasium[other]` to get dependencies for atari"
|
||||
|
@@ -14,7 +14,6 @@ from typing import Any, SupportsFloat
|
||||
import numpy as np
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium import Env
|
||||
from gymnasium.core import ActType, ObsType, RenderFrame
|
||||
from gymnasium.error import ResetNeeded
|
||||
from gymnasium.utils.passive_env_checker import (
|
||||
@@ -26,7 +25,9 @@ from gymnasium.utils.passive_env_checker import (
|
||||
)
|
||||
|
||||
|
||||
class AutoresetV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
||||
class AutoresetV0(
|
||||
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
|
||||
):
|
||||
"""A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`."""
|
||||
|
||||
def __init__(self, env: gym.Env[ObsType, ActType]):
|
||||
@@ -35,7 +36,9 @@ class AutoresetV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
||||
Args:
|
||||
env (gym.Env): The environment to apply the wrapper
|
||||
"""
|
||||
super().__init__(env)
|
||||
gym.utils.RecordConstructorArgs.__init__(self)
|
||||
gym.Wrapper.__init__(self, env)
|
||||
|
||||
self._episode_ended: bool = False
|
||||
self._reset_options: dict[str, Any] | None = None
|
||||
|
||||
@@ -68,12 +71,15 @@ class AutoresetV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
||||
return super().reset(seed=seed, options=self._reset_options)
|
||||
|
||||
|
||||
class PassiveEnvCheckerV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
||||
class PassiveEnvCheckerV0(
|
||||
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
|
||||
):
|
||||
"""A passive environment checker wrapper that surrounds the step, reset and render functions to check they follow the gymnasium API."""
|
||||
|
||||
def __init__(self, env: Env[ObsType, ActType]):
|
||||
def __init__(self, env: gym.Env[ObsType, ActType]):
|
||||
"""Initialises the wrapper with the environments, run the observation and action space tests."""
|
||||
super().__init__(env)
|
||||
gym.utils.RecordConstructorArgs.__init__(self)
|
||||
gym.Wrapper.__init__(self, env)
|
||||
|
||||
assert hasattr(
|
||||
env, "action_space"
|
||||
@@ -117,7 +123,9 @@ class PassiveEnvCheckerV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
||||
return self.env.render()
|
||||
|
||||
|
||||
class OrderEnforcingV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
||||
class OrderEnforcingV0(
|
||||
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
|
||||
):
|
||||
"""A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`.
|
||||
|
||||
Example:
|
||||
@@ -150,7 +158,11 @@ class OrderEnforcingV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
||||
env: The environment to wrap
|
||||
disable_render_order_enforcing: If to disable render order enforcing
|
||||
"""
|
||||
super().__init__(env)
|
||||
gym.utils.RecordConstructorArgs.__init__(
|
||||
self, disable_render_order_enforcing=disable_render_order_enforcing
|
||||
)
|
||||
gym.Wrapper.__init__(self, env)
|
||||
|
||||
self._has_reset: bool = False
|
||||
self._disable_render_order_enforcing: bool = disable_render_order_enforcing
|
||||
|
||||
@@ -182,7 +194,9 @@ class OrderEnforcingV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
||||
return self._has_reset
|
||||
|
||||
|
||||
class RecordEpisodeStatisticsV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
||||
class RecordEpisodeStatisticsV0(
|
||||
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
|
||||
):
|
||||
"""This wrapper will keep track of cumulative rewards and episode lengths.
|
||||
|
||||
At the end of an episode, the statistics of the episode will be added to ``info``
|
||||
@@ -226,7 +240,7 @@ class RecordEpisodeStatisticsV0(gym.Wrapper[ObsType, ActType, ObsType, ActType])
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: Env[ObsType, ActType],
|
||||
env: gym.Env[ObsType, ActType],
|
||||
buffer_length: int | None = 100,
|
||||
stats_key: str = "episode",
|
||||
):
|
||||
@@ -237,7 +251,8 @@ class RecordEpisodeStatisticsV0(gym.Wrapper[ObsType, ActType, ObsType, ActType])
|
||||
buffer_length: The size of the buffers :attr:`return_queue` and :attr:`length_queue`
|
||||
stats_key: The info key for the episode statistics
|
||||
"""
|
||||
super().__init__(env)
|
||||
gym.utils.RecordConstructorArgs.__init__(self)
|
||||
gym.Wrapper.__init__(self, env)
|
||||
|
||||
self._stats_key = stats_key
|
||||
|
||||
|
@@ -8,7 +8,7 @@ from typing import Any, Iterable, Mapping, SupportsFloat
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gymnasium import Env, Wrapper
|
||||
import gymnasium as gym
|
||||
from gymnasium.core import ActType, ObsType, RenderFrame, WrapperActType, WrapperObsType
|
||||
from gymnasium.error import DependencyNotInstalled
|
||||
|
||||
@@ -92,7 +92,10 @@ if jnp is not None:
|
||||
return type(value)(jax_to_numpy(v) for v in value)
|
||||
|
||||
|
||||
class JaxToNumpyV0(Wrapper[WrapperObsType, WrapperActType, ObsType, ActType]):
|
||||
class JaxToNumpyV0(
|
||||
gym.Wrapper[WrapperObsType, WrapperActType, ObsType, ActType],
|
||||
gym.utils.RecordConstructorArgs,
|
||||
):
|
||||
"""Wraps a jax environment so that it can be interacted with through numpy arrays.
|
||||
|
||||
Actions must be provided as numpy arrays and observations will be returned as numpy arrays.
|
||||
@@ -102,7 +105,7 @@ class JaxToNumpyV0(Wrapper[WrapperObsType, WrapperActType, ObsType, ActType]):
|
||||
The reason for this is jax does not support non-array values, therefore numpy ``int_32(5) -> DeviceArray([5], dtype=jnp.int23)``
|
||||
"""
|
||||
|
||||
def __init__(self, env: Env[ObsType, ActType]):
|
||||
def __init__(self, env: gym.Env[ObsType, ActType]):
|
||||
"""Wraps an environment such that the input and outputs are numpy arrays.
|
||||
|
||||
Args:
|
||||
@@ -112,7 +115,8 @@ class JaxToNumpyV0(Wrapper[WrapperObsType, WrapperActType, ObsType, ActType]):
|
||||
raise DependencyNotInstalled(
|
||||
"jax is not installed, run `pip install gymnasium[jax]`"
|
||||
)
|
||||
super().__init__(env)
|
||||
gym.utils.RecordConstructorArgs.__init__(self)
|
||||
gym.Wrapper.__init__(self, env)
|
||||
|
||||
def step(
|
||||
self, action: WrapperActType
|
||||
|
@@ -14,7 +14,7 @@ import numbers
|
||||
from collections import abc
|
||||
from typing import Any, Iterable, Mapping, SupportsFloat, Union
|
||||
|
||||
from gymnasium import Env, Wrapper
|
||||
import gymnasium as gym
|
||||
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
|
||||
from gymnasium.error import DependencyNotInstalled
|
||||
from gymnasium.experimental.wrappers.conversion.jax_to_numpy import jax_to_numpy
|
||||
@@ -131,7 +131,7 @@ if torch is not None and jnp is not None:
|
||||
return type(value)(jax_to_torch(v, device) for v in value)
|
||||
|
||||
|
||||
class JaxToTorchV0(Wrapper):
|
||||
class JaxToTorchV0(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||
"""Wraps a jax-based environment so that it can be interacted with through PyTorch Tensors.
|
||||
|
||||
Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors.
|
||||
@@ -140,7 +140,7 @@ class JaxToTorchV0(Wrapper):
|
||||
For ``rendered`` this is returned as a NumPy array not a pytorch Tensor.
|
||||
"""
|
||||
|
||||
def __init__(self, env: Env, device: Device | None = None):
|
||||
def __init__(self, env: gym.Env, device: Device | None = None):
|
||||
"""Wrapper class to change inputs and outputs of environment to PyTorch tensors.
|
||||
|
||||
Args:
|
||||
@@ -156,7 +156,9 @@ class JaxToTorchV0(Wrapper):
|
||||
"jax is not installed, run `pip install gymnasium[jax]`"
|
||||
)
|
||||
|
||||
super().__init__(env)
|
||||
gym.utils.RecordConstructorArgs.__init__(self, device=device)
|
||||
gym.Wrapper.__init__(self, env)
|
||||
|
||||
self.device: Device | None = device
|
||||
|
||||
def step(
|
||||
|
@@ -8,7 +8,7 @@ from typing import Any, Iterable, Mapping, SupportsFloat, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gymnasium import Env, Wrapper
|
||||
import gymnasium as gym
|
||||
from gymnasium.core import WrapperActType, WrapperObsType
|
||||
from gymnasium.error import DependencyNotInstalled
|
||||
|
||||
@@ -94,7 +94,7 @@ if torch is not None:
|
||||
return type(value)(numpy_to_torch(v, device) for v in value)
|
||||
|
||||
|
||||
class NumpyToTorchV0(Wrapper):
|
||||
class NumpyToTorchV0(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||
"""Wraps a numpy-based environment so that it can be interacted with through PyTorch Tensors.
|
||||
|
||||
Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors.
|
||||
@@ -103,7 +103,7 @@ class NumpyToTorchV0(Wrapper):
|
||||
For ``rendered`` this is returned as a NumPy array not a pytorch Tensor.
|
||||
"""
|
||||
|
||||
def __init__(self, env: Env, device: Device | None = None):
|
||||
def __init__(self, env: gym.Env, device: Device | None = None):
|
||||
"""Wrapper class to change inputs and outputs of environment to PyTorch tensors.
|
||||
|
||||
Args:
|
||||
@@ -115,7 +115,9 @@ class NumpyToTorchV0(Wrapper):
|
||||
"torch is not installed, run `pip install torch`"
|
||||
)
|
||||
|
||||
super().__init__(env)
|
||||
gym.utils.RecordConstructorArgs.__init__(self, device=device)
|
||||
gym.Wrapper.__init__(self, env)
|
||||
|
||||
self.device: Device | None = device
|
||||
|
||||
def step(
|
||||
|
@@ -20,7 +20,9 @@ from gymnasium.core import ActType, ObsType, WrapperActType
|
||||
from gymnasium.spaces import Box, Space
|
||||
|
||||
|
||||
class LambdaActionV0(gym.ActionWrapper[ObsType, WrapperActType, ActType]):
|
||||
class LambdaActionV0(
|
||||
gym.ActionWrapper[ObsType, WrapperActType, ActType], gym.utils.RecordConstructorArgs
|
||||
):
|
||||
"""A wrapper that provides a function to modify the action passed to :meth:`step`."""
|
||||
|
||||
def __init__(
|
||||
@@ -36,7 +38,11 @@ class LambdaActionV0(gym.ActionWrapper[ObsType, WrapperActType, ActType]):
|
||||
func: Function to apply to ``step`` ``action``
|
||||
action_space: The updated action space of the wrapper given the function.
|
||||
"""
|
||||
super().__init__(env)
|
||||
gym.utils.RecordConstructorArgs.__init__(
|
||||
self, func=func, action_space=action_space
|
||||
)
|
||||
gym.Wrapper.__init__(self, env)
|
||||
|
||||
if action_space is not None:
|
||||
self.action_space = action_space
|
||||
|
||||
@@ -47,7 +53,9 @@ class LambdaActionV0(gym.ActionWrapper[ObsType, WrapperActType, ActType]):
|
||||
return self.func(action)
|
||||
|
||||
|
||||
class ClipActionV0(LambdaActionV0[ObsType, WrapperActType, ActType]):
|
||||
class ClipActionV0(
|
||||
LambdaActionV0[ObsType, WrapperActType, ActType], gym.utils.RecordConstructorArgs
|
||||
):
|
||||
"""Clip the continuous action within the valid :class:`Box` observation space bound.
|
||||
|
||||
Example:
|
||||
@@ -71,10 +79,14 @@ class ClipActionV0(LambdaActionV0[ObsType, WrapperActType, ActType]):
|
||||
"""
|
||||
assert isinstance(env.action_space, Box)
|
||||
|
||||
super().__init__(
|
||||
env,
|
||||
lambda action: jp.clip(action, env.action_space.low, env.action_space.high),
|
||||
Box(
|
||||
gym.utils.RecordConstructorArgs.__init__(self)
|
||||
LambdaActionV0.__init__(
|
||||
self,
|
||||
env=env,
|
||||
func=lambda action: jp.clip(
|
||||
action, env.action_space.low, env.action_space.high
|
||||
),
|
||||
action_space=Box(
|
||||
-np.inf,
|
||||
np.inf,
|
||||
shape=env.action_space.shape,
|
||||
@@ -83,7 +95,9 @@ class ClipActionV0(LambdaActionV0[ObsType, WrapperActType, ActType]):
|
||||
)
|
||||
|
||||
|
||||
class RescaleActionV0(LambdaActionV0[ObsType, WrapperActType, ActType]):
|
||||
class RescaleActionV0(
|
||||
LambdaActionV0[ObsType, WrapperActType, ActType], gym.utils.RecordConstructorArgs
|
||||
):
|
||||
"""Affinely rescales the continuous action space of the environment to the range [min_action, max_action].
|
||||
|
||||
The base environment :attr:`env` must have an action space of type :class:`spaces.Box`. If :attr:`min_action`
|
||||
@@ -118,6 +132,10 @@ class RescaleActionV0(LambdaActionV0[ObsType, WrapperActType, ActType]):
|
||||
min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar.
|
||||
max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar.
|
||||
"""
|
||||
gym.utils.RecordConstructorArgs.__init__(
|
||||
self, min_action=min_action, max_action=max_action
|
||||
)
|
||||
|
||||
assert isinstance(env.action_space, Box)
|
||||
assert not np.any(env.action_space.low == np.inf) and not np.any(
|
||||
env.action_space.high == np.inf
|
||||
@@ -149,10 +167,11 @@ class RescaleActionV0(LambdaActionV0[ObsType, WrapperActType, ActType]):
|
||||
)
|
||||
intercept = gradient * -min_action + env.action_space.low
|
||||
|
||||
super().__init__(
|
||||
env,
|
||||
lambda action: gradient * action + intercept,
|
||||
Box(
|
||||
LambdaActionV0.__init__(
|
||||
self,
|
||||
env=env,
|
||||
func=lambda action: gradient * action + intercept,
|
||||
action_space=Box(
|
||||
low=min_action,
|
||||
high=max_action,
|
||||
shape=env.action_space.shape,
|
||||
|
@@ -24,14 +24,16 @@ except ImportError as e:
|
||||
import numpy as np
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium import Env, spaces
|
||||
from gymnasium.core import ActType, ObservationWrapper, ObsType, WrapperObsType
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ActType, ObsType, WrapperObsType
|
||||
from gymnasium.error import DependencyNotInstalled
|
||||
from gymnasium.experimental.wrappers.utils import RunningMeanStd
|
||||
from gymnasium.spaces import Box, Dict, utils
|
||||
|
||||
|
||||
class LambdaObservationV0(gym.ObservationWrapper[WrapperObsType, ActType, ObsType]):
|
||||
class LambdaObservationV0(
|
||||
gym.ObservationWrapper[WrapperObsType, ActType, ObsType],
|
||||
gym.utils.RecordConstructorArgs,
|
||||
):
|
||||
"""Transforms an observation via a function provided to the wrapper.
|
||||
|
||||
The function :attr:`func` will be applied to all observations.
|
||||
@@ -61,7 +63,11 @@ class LambdaObservationV0(gym.ObservationWrapper[WrapperObsType, ActType, ObsTyp
|
||||
func: A function that will transform an observation. If this transformed observation is outside the observation space of `env.observation_space` then provide an `observation_space`.
|
||||
observation_space: The observation spaces of the wrapper, if None, then it is assumed the same as `env.observation_space`.
|
||||
"""
|
||||
super().__init__(env)
|
||||
gym.utils.RecordConstructorArgs.__init__(
|
||||
self, func=func, observation_space=observation_space
|
||||
)
|
||||
gym.ObservationWrapper.__init__(self, env)
|
||||
|
||||
if observation_space is not None:
|
||||
self.observation_space = observation_space
|
||||
|
||||
@@ -72,7 +78,10 @@ class LambdaObservationV0(gym.ObservationWrapper[WrapperObsType, ActType, ObsTyp
|
||||
return self.func(observation)
|
||||
|
||||
|
||||
class FilterObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
|
||||
class FilterObservationV0(
|
||||
LambdaObservationV0[WrapperObsType, ActType, ObsType],
|
||||
gym.utils.RecordConstructorArgs,
|
||||
):
|
||||
"""Filter Dict observation space by the keys.
|
||||
|
||||
Example:
|
||||
@@ -96,6 +105,7 @@ class FilterObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType])
|
||||
):
|
||||
"""Constructor for an environment with a dictionary observation space where all :attr:`filter_keys` are in the observation space keys."""
|
||||
assert isinstance(filter_keys, Sequence)
|
||||
gym.utils.RecordConstructorArgs.__init__(self, filter_keys=filter_keys)
|
||||
|
||||
# Filters for dictionary space
|
||||
if isinstance(env.observation_space, spaces.Dict):
|
||||
@@ -124,10 +134,11 @@ class FilterObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType])
|
||||
"The observation space is empty due to filtering all keys."
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
env,
|
||||
lambda obs: {key: obs[key] for key in filter_keys},
|
||||
new_observation_space,
|
||||
LambdaObservationV0.__init__(
|
||||
self,
|
||||
env=env,
|
||||
func=lambda obs: {key: obs[key] for key in filter_keys},
|
||||
observation_space=new_observation_space,
|
||||
)
|
||||
# Filter for tuple observation
|
||||
elif isinstance(env.observation_space, spaces.Tuple):
|
||||
@@ -158,10 +169,11 @@ class FilterObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType])
|
||||
"The observation space is empty due to filtering all keys."
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
env,
|
||||
lambda obs: tuple(obs[key] for key in filter_keys),
|
||||
new_observation_spaces,
|
||||
LambdaObservationV0.__init__(
|
||||
self,
|
||||
env=env,
|
||||
func=lambda obs: tuple(obs[key] for key in filter_keys),
|
||||
observation_space=new_observation_spaces,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -171,7 +183,10 @@ class FilterObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType])
|
||||
self.filter_keys: Final[Sequence[str | int]] = filter_keys
|
||||
|
||||
|
||||
class FlattenObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
|
||||
class FlattenObservationV0(
|
||||
LambdaObservationV0[WrapperObsType, ActType, ObsType],
|
||||
gym.utils.RecordConstructorArgs,
|
||||
):
|
||||
"""Observation wrapper that flattens the observation.
|
||||
|
||||
Example:
|
||||
@@ -190,14 +205,19 @@ class FlattenObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]
|
||||
|
||||
def __init__(self, env: gym.Env[ObsType, ActType]):
|
||||
"""Constructor for any environment's observation space that implements ``spaces.utils.flatten_space`` and ``spaces.utils.flatten``."""
|
||||
super().__init__(
|
||||
env,
|
||||
lambda obs: utils.flatten(env.observation_space, obs),
|
||||
utils.flatten_space(env.observation_space),
|
||||
gym.utils.RecordConstructorArgs.__init__(self)
|
||||
LambdaObservationV0.__init__(
|
||||
self,
|
||||
env=env,
|
||||
func=lambda obs: spaces.utils.flatten(env.observation_space, obs),
|
||||
observation_space=spaces.utils.flatten_space(env.observation_space),
|
||||
)
|
||||
|
||||
|
||||
class GrayscaleObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
|
||||
class GrayscaleObservationV0(
|
||||
LambdaObservationV0[WrapperObsType, ActType, ObsType],
|
||||
gym.utils.RecordConstructorArgs,
|
||||
):
|
||||
"""Observation wrapper that converts an RGB image to grayscale.
|
||||
|
||||
The :attr:`keep_dim` will keep the channel dimension
|
||||
@@ -228,6 +248,7 @@ class GrayscaleObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsTyp
|
||||
and np.all(env.observation_space.high == 255)
|
||||
and env.observation_space.dtype == np.uint8
|
||||
)
|
||||
gym.utils.RecordConstructorArgs.__init__(self, keep_dim=keep_dim)
|
||||
|
||||
self.keep_dim: Final[bool] = keep_dim
|
||||
if keep_dim:
|
||||
@@ -237,30 +258,35 @@ class GrayscaleObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsTyp
|
||||
shape=env.observation_space.shape[:2] + (1,),
|
||||
dtype=np.uint8,
|
||||
)
|
||||
super().__init__(
|
||||
env,
|
||||
lambda obs: jp.expand_dims(
|
||||
LambdaObservationV0.__init__(
|
||||
self,
|
||||
env=env,
|
||||
func=lambda obs: jp.expand_dims(
|
||||
jp.sum(
|
||||
jp.multiply(obs, jp.array([0.2125, 0.7154, 0.0721])), axis=-1
|
||||
).astype(np.uint8),
|
||||
axis=-1,
|
||||
),
|
||||
new_observation_space,
|
||||
observation_space=new_observation_space,
|
||||
)
|
||||
else:
|
||||
new_observation_space = spaces.Box(
|
||||
low=0, high=255, shape=env.observation_space.shape[:2], dtype=np.uint8
|
||||
)
|
||||
super().__init__(
|
||||
env,
|
||||
lambda obs: jp.sum(
|
||||
LambdaObservationV0.__init__(
|
||||
self,
|
||||
env=env,
|
||||
func=lambda obs: jp.sum(
|
||||
jp.multiply(obs, jp.array([0.2125, 0.7154, 0.0721])), axis=-1
|
||||
).astype(np.uint8),
|
||||
new_observation_space,
|
||||
observation_space=new_observation_space,
|
||||
)
|
||||
|
||||
|
||||
class ResizeObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
|
||||
class ResizeObservationV0(
|
||||
LambdaObservationV0[WrapperObsType, ActType, ObsType],
|
||||
gym.utils.RecordConstructorArgs,
|
||||
):
|
||||
"""Resizes image observations using OpenCV to shape.
|
||||
|
||||
Example:
|
||||
@@ -299,14 +325,20 @@ class ResizeObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType])
|
||||
new_observation_space = spaces.Box(
|
||||
low=0, high=255, shape=self.shape + env.observation_space.shape[2:]
|
||||
)
|
||||
super().__init__(
|
||||
env,
|
||||
lambda obs: cv2.resize(obs, self.shape, interpolation=cv2.INTER_AREA),
|
||||
new_observation_space,
|
||||
|
||||
gym.utils.RecordConstructorArgs.__init__(self, shape=shape)
|
||||
LambdaObservationV0.__init__(
|
||||
self,
|
||||
env=env,
|
||||
func=lambda obs: cv2.resize(obs, self.shape, interpolation=cv2.INTER_AREA),
|
||||
observation_space=new_observation_space,
|
||||
)
|
||||
|
||||
|
||||
class ReshapeObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
|
||||
class ReshapeObservationV0(
|
||||
LambdaObservationV0[WrapperObsType, ActType, ObsType],
|
||||
gym.utils.RecordConstructorArgs,
|
||||
):
|
||||
"""Reshapes array based observations to shapes.
|
||||
|
||||
Example:
|
||||
@@ -336,10 +368,20 @@ class ReshapeObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]
|
||||
dtype=env.observation_space.dtype,
|
||||
)
|
||||
self.shape = shape
|
||||
super().__init__(env, lambda obs: jp.reshape(obs, shape), new_observation_space)
|
||||
|
||||
gym.utils.RecordConstructorArgs.__init__(self, shape=shape)
|
||||
LambdaObservationV0.__init__(
|
||||
self,
|
||||
env=env,
|
||||
func=lambda obs: jp.reshape(obs, shape),
|
||||
observation_space=new_observation_space,
|
||||
)
|
||||
|
||||
|
||||
class RescaleObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
|
||||
class RescaleObservationV0(
|
||||
LambdaObservationV0[WrapperObsType, ActType, ObsType],
|
||||
gym.utils.RecordConstructorArgs,
|
||||
):
|
||||
"""Linearly rescales observation to between a minimum and maximum value.
|
||||
|
||||
Example:
|
||||
@@ -392,10 +434,12 @@ class RescaleObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]
|
||||
)
|
||||
intercept = gradient * -env.observation_space.low + min_obs
|
||||
|
||||
super().__init__(
|
||||
env,
|
||||
lambda obs: gradient * obs + intercept,
|
||||
Box(
|
||||
gym.utils.RecordConstructorArgs.__init__(self, min_obs=min_obs, max_obs=max_obs)
|
||||
LambdaObservationV0.__init__(
|
||||
self,
|
||||
env=env,
|
||||
func=lambda obs: gradient * obs + intercept,
|
||||
observation_space=spaces.Box(
|
||||
low=min_obs,
|
||||
high=max_obs,
|
||||
shape=env.observation_space.shape,
|
||||
@@ -404,7 +448,10 @@ class RescaleObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]
|
||||
)
|
||||
|
||||
|
||||
class DtypeObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
|
||||
class DtypeObservationV0(
|
||||
LambdaObservationV0[WrapperObsType, ActType, ObsType],
|
||||
gym.utils.RecordConstructorArgs,
|
||||
):
|
||||
"""Observation wrapper for transforming the dtype of an observation."""
|
||||
|
||||
def __init__(self, env: gym.Env[ObsType, ActType], dtype: Any):
|
||||
@@ -445,10 +492,19 @@ class DtypeObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
|
||||
"DtypeObservation is only compatible with value / array-based observations."
|
||||
)
|
||||
|
||||
super().__init__(env, lambda obs: dtype(obs), new_observation_space)
|
||||
gym.utils.RecordConstructorArgs.__init__(self, dtype=dtype)
|
||||
LambdaObservationV0.__init__(
|
||||
self,
|
||||
env=env,
|
||||
func=lambda obs: dtype(obs),
|
||||
observation_space=new_observation_space,
|
||||
)
|
||||
|
||||
|
||||
class PixelObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
|
||||
class PixelObservationV0(
|
||||
LambdaObservationV0[WrapperObsType, ActType, ObsType],
|
||||
gym.utils.RecordConstructorArgs,
|
||||
):
|
||||
"""Augment observations by pixel values.
|
||||
|
||||
Observations of this wrapper will be dictionaries of images.
|
||||
@@ -461,7 +517,7 @@ class PixelObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: Env[ObsType, ActType],
|
||||
env: gym.Env[ObsType, ActType],
|
||||
pixels_only: bool = True,
|
||||
pixels_key: str = "pixels",
|
||||
obs_key: str = "state",
|
||||
@@ -478,30 +534,49 @@ class PixelObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
|
||||
pixels_key: Optional custom string specifying the pixel key. Defaults to "pixels"
|
||||
obs_key: Optional custom string specifying the obs key. Defaults to "state"
|
||||
"""
|
||||
gym.utils.RecordConstructorArgs.__init__(
|
||||
self, pixels_only=pixels_only, pixels_key=pixels_key, obs_key=obs_key
|
||||
)
|
||||
|
||||
assert env.render_mode is not None and env.render_mode != "human"
|
||||
env.reset()
|
||||
pixels = env.render()
|
||||
assert pixels is not None and isinstance(pixels, np.ndarray)
|
||||
pixel_space = Box(low=0, high=255, shape=pixels.shape, dtype=np.uint8)
|
||||
pixel_space = spaces.Box(low=0, high=255, shape=pixels.shape, dtype=np.uint8)
|
||||
|
||||
if pixels_only:
|
||||
obs_space = pixel_space
|
||||
super().__init__(env, lambda _: self.render(), obs_space)
|
||||
elif isinstance(env.observation_space, Dict):
|
||||
LambdaObservationV0.__init__(
|
||||
self, env=env, func=lambda _: self.render(), observation_space=obs_space
|
||||
)
|
||||
elif isinstance(env.observation_space, spaces.Dict):
|
||||
assert pixels_key not in env.observation_space.spaces.keys()
|
||||
|
||||
obs_space = Dict({pixels_key: pixel_space, **env.observation_space.spaces})
|
||||
super().__init__(
|
||||
env, lambda obs: {pixels_key: self.render(), **obs_space}, obs_space
|
||||
obs_space = spaces.Dict(
|
||||
{pixels_key: pixel_space, **env.observation_space.spaces}
|
||||
)
|
||||
LambdaObservationV0.__init__(
|
||||
self,
|
||||
env=env,
|
||||
func=lambda obs: {pixels_key: self.render(), **obs_space},
|
||||
observation_space=obs_space,
|
||||
)
|
||||
else:
|
||||
obs_space = Dict({obs_key: env.observation_space, pixels_key: pixel_space})
|
||||
super().__init__(
|
||||
env, lambda obs: {obs_key: obs, pixels_key: self.render()}, obs_space
|
||||
obs_space = spaces.Dict(
|
||||
{obs_key: env.observation_space, pixels_key: pixel_space}
|
||||
)
|
||||
LambdaObservationV0.__init__(
|
||||
self,
|
||||
env=env,
|
||||
func=lambda obs: {obs_key: obs, pixels_key: self.render()},
|
||||
observation_space=obs_space,
|
||||
)
|
||||
|
||||
|
||||
class NormalizeObservationV0(ObservationWrapper[WrapperObsType, ActType, ObsType]):
|
||||
class NormalizeObservationV0(
|
||||
gym.ObservationWrapper[WrapperObsType, ActType, ObsType],
|
||||
gym.utils.RecordConstructorArgs,
|
||||
):
|
||||
"""This wrapper will normalize observations s.t. each coordinate is centered with unit variance.
|
||||
|
||||
The property `_update_running_mean` allows to freeze/continue the running mean calculation of the observation
|
||||
@@ -520,7 +595,9 @@ class NormalizeObservationV0(ObservationWrapper[WrapperObsType, ActType, ObsType
|
||||
env (Env): The environment to apply the wrapper
|
||||
epsilon: A stability parameter that is used when scaling the observations.
|
||||
"""
|
||||
super().__init__(env)
|
||||
gym.utils.RecordConstructorArgs.__init__(self, epsilon=epsilon)
|
||||
gym.ObservationWrapper.__init__(self, env)
|
||||
|
||||
self.obs_rms = RunningMeanStd(shape=self.observation_space.shape)
|
||||
self.epsilon = epsilon
|
||||
self._update_running_mean = True
|
||||
|
@@ -15,7 +15,9 @@ from gymnasium.error import InvalidBound
|
||||
from gymnasium.experimental.wrappers.utils import RunningMeanStd
|
||||
|
||||
|
||||
class LambdaRewardV0(gym.RewardWrapper[ObsType, ActType]):
|
||||
class LambdaRewardV0(
|
||||
gym.RewardWrapper[ObsType, ActType], gym.utils.RecordConstructorArgs
|
||||
):
|
||||
"""A reward wrapper that allows a custom function to modify the step reward.
|
||||
|
||||
Example:
|
||||
@@ -40,7 +42,8 @@ class LambdaRewardV0(gym.RewardWrapper[ObsType, ActType]):
|
||||
env (Env): The environment to apply the wrapper
|
||||
func: (Callable): The function to apply to reward
|
||||
"""
|
||||
super().__init__(env)
|
||||
gym.utils.RecordConstructorArgs.__init__(self, func=func)
|
||||
gym.RewardWrapper.__init__(self, env)
|
||||
|
||||
self.func = func
|
||||
|
||||
@@ -53,7 +56,7 @@ class LambdaRewardV0(gym.RewardWrapper[ObsType, ActType]):
|
||||
return self.func(reward)
|
||||
|
||||
|
||||
class ClipRewardV0(LambdaRewardV0[ObsType, ActType]):
|
||||
class ClipRewardV0(LambdaRewardV0[ObsType, ActType], gym.utils.RecordConstructorArgs):
|
||||
"""A wrapper that clips the rewards for an environment between an upper and lower bound.
|
||||
|
||||
Example:
|
||||
@@ -89,10 +92,17 @@ class ClipRewardV0(LambdaRewardV0[ObsType, ActType]):
|
||||
f"Min reward ({min_reward}) must be smaller than max reward ({max_reward})"
|
||||
)
|
||||
|
||||
super().__init__(env, lambda x: np.clip(x, a_min=min_reward, a_max=max_reward))
|
||||
gym.utils.RecordConstructorArgs.__init__(
|
||||
self, min_reward=min_reward, max_reward=max_reward
|
||||
)
|
||||
LambdaRewardV0.__init__(
|
||||
self, env=env, func=lambda x: np.clip(x, a_min=min_reward, a_max=max_reward)
|
||||
)
|
||||
|
||||
|
||||
class NormalizeRewardV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
||||
class NormalizeRewardV0(
|
||||
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
|
||||
):
|
||||
r"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance.
|
||||
|
||||
The exponential moving average will have variance :math:`(1 - \gamma)^2`.
|
||||
@@ -119,7 +129,9 @@ class NormalizeRewardV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
||||
epsilon (float): A stability parameter
|
||||
gamma (float): The discount factor that is used in the exponential moving average.
|
||||
"""
|
||||
super().__init__(env)
|
||||
gym.utils.RecordConstructorArgs.__init__(self, gamma=gamma, epsilon=epsilon)
|
||||
gym.Wrapper.__init__(self, env)
|
||||
|
||||
self.rewards_running_means = RunningMeanStd(shape=())
|
||||
self.discounted_reward: np.array = np.array([0.0])
|
||||
self.gamma = gamma
|
||||
|
@@ -18,7 +18,9 @@ from gymnasium.core import ActType, ObsType, RenderFrame
|
||||
from gymnasium.error import DependencyNotInstalled
|
||||
|
||||
|
||||
class RenderCollectionV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
||||
class RenderCollectionV0(
|
||||
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
|
||||
):
|
||||
"""Collect rendered frames of an environment such ``render`` returns a ``list[RenderedFrame]``."""
|
||||
|
||||
def __init__(
|
||||
@@ -34,7 +36,11 @@ class RenderCollectionV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
||||
pop_frames (bool): If true, clear the collection frames after ``meth:render`` is called. Default value is ``True``.
|
||||
reset_clean (bool): If true, clear the collection frames when ``meth:reset`` is called. Default value is ``True``.
|
||||
"""
|
||||
super().__init__(env)
|
||||
gym.utils.RecordConstructorArgs.__init__(
|
||||
self, pop_frames=pop_frames, reset_clean=reset_clean
|
||||
)
|
||||
gym.Wrapper.__init__(self, env)
|
||||
|
||||
assert env.render_mode is not None
|
||||
assert not env.render_mode.endswith("_list")
|
||||
|
||||
@@ -80,7 +86,9 @@ class RenderCollectionV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
||||
return frames
|
||||
|
||||
|
||||
class RecordVideoV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
||||
class RecordVideoV0(
|
||||
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
|
||||
):
|
||||
"""This wrapper records videos of rollouts.
|
||||
|
||||
Usually, you only want to record episodes intermittently, say every hundredth episode.
|
||||
@@ -117,9 +125,18 @@ class RecordVideoV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
||||
Otherwise, snippets of the specified length are captured
|
||||
name_prefix (str): Will be prepended to the filename of the recordings
|
||||
disable_logger (bool): Whether to disable moviepy logger or not
|
||||
|
||||
"""
|
||||
super().__init__(env)
|
||||
gym.utils.RecordConstructorArgs.__init__(
|
||||
self,
|
||||
video_folder=video_folder,
|
||||
episode_trigger=episode_trigger,
|
||||
step_trigger=step_trigger,
|
||||
video_length=video_length,
|
||||
name_prefix=name_prefix,
|
||||
disable_logger=disable_logger,
|
||||
)
|
||||
gym.Wrapper.__init__(self, env)
|
||||
|
||||
try:
|
||||
import moviepy # noqa: F401
|
||||
except ImportError as e:
|
||||
@@ -277,7 +294,9 @@ class RecordVideoV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
||||
logger.warn("Unable to save last video! Did you call close()?")
|
||||
|
||||
|
||||
class HumanRenderingV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
||||
class HumanRenderingV0(
|
||||
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
|
||||
):
|
||||
"""Performs human rendering for an environment that only supports "rgb_array"rendering.
|
||||
|
||||
This wrapper is particularly useful when you have implemented an environment that can produce
|
||||
@@ -317,7 +336,9 @@ class HumanRenderingV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
||||
Args:
|
||||
env: The environment that is being wrapped
|
||||
"""
|
||||
super().__init__(env)
|
||||
gym.utils.RecordConstructorArgs.__init__(self)
|
||||
gym.Wrapper.__init__(self, env)
|
||||
|
||||
assert env.render_mode in [
|
||||
"rgb_array",
|
||||
"rgb_array_list",
|
||||
|
@@ -4,11 +4,13 @@ from __future__ import annotations
|
||||
from typing import Any
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.core import ActionWrapper, ActType, ObsType
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from gymnasium.error import InvalidProbability
|
||||
|
||||
|
||||
class StickyActionV0(ActionWrapper[ObsType, ActType, ActType]):
|
||||
class StickyActionV0(
|
||||
gym.ActionWrapper[ObsType, ActType, ActType], gym.utils.RecordConstructorArgs
|
||||
):
|
||||
"""Wrapper which adds a probability of repeating the previous action.
|
||||
|
||||
This wrapper follows the implementation proposed by `Machado et al., 2018 <https://arxiv.org/pdf/1709.06009.pdf>`_
|
||||
@@ -29,7 +31,11 @@ class StickyActionV0(ActionWrapper[ObsType, ActType, ActType]):
|
||||
f"repeat_action_probability should be in the interval [0,1). Received {repeat_action_probability}"
|
||||
)
|
||||
|
||||
super().__init__(env)
|
||||
gym.utils.RecordConstructorArgs.__init__(
|
||||
self, repeat_action_probability=repeat_action_probability
|
||||
)
|
||||
gym.ActionWrapper.__init__(self, env)
|
||||
|
||||
self.repeat_action_probability = repeat_action_probability
|
||||
self.last_action: ActType | None = None
|
||||
|
||||
|
@@ -13,8 +13,8 @@ from typing_extensions import Final
|
||||
|
||||
import numpy as np
|
||||
|
||||
import gymnasium as gym
|
||||
import gymnasium.spaces as spaces
|
||||
from gymnasium import Env, ObservationWrapper, Space, Wrapper
|
||||
from gymnasium.core import ActType, ObsType, WrapperActType, WrapperObsType
|
||||
from gymnasium.experimental.vector.utils import (
|
||||
batch_space,
|
||||
@@ -25,7 +25,9 @@ from gymnasium.experimental.wrappers.utils import create_zero_array
|
||||
from gymnasium.spaces import Box, Dict, Tuple
|
||||
|
||||
|
||||
class DelayObservationV0(ObservationWrapper[ObsType, ActType, ObsType]):
|
||||
class DelayObservationV0(
|
||||
gym.ObservationWrapper[ObsType, ActType, ObsType], gym.utils.RecordConstructorArgs
|
||||
):
|
||||
"""Wrapper which adds a delay to the returned observation.
|
||||
|
||||
Before reaching the :attr:`delay` number of timesteps, returned observations is an array of zeros with
|
||||
@@ -49,15 +51,13 @@ class DelayObservationV0(ObservationWrapper[ObsType, ActType, ObsType]):
|
||||
This does not support random delay values, if users are interested, please raise an issue or pull request to add this feature.
|
||||
"""
|
||||
|
||||
def __init__(self, env: Env[ObsType, ActType], delay: int):
|
||||
def __init__(self, env: gym.Env[ObsType, ActType], delay: int):
|
||||
"""Initialises the DelayObservation wrapper with an integer.
|
||||
|
||||
Args:
|
||||
env: The environment to wrap
|
||||
delay: The number of timesteps to delay observations
|
||||
"""
|
||||
super().__init__(env)
|
||||
|
||||
if not np.issubdtype(type(delay), np.integer):
|
||||
raise TypeError(
|
||||
f"The delay is expected to be an integer, actual type: {type(delay)}"
|
||||
@@ -67,6 +67,9 @@ class DelayObservationV0(ObservationWrapper[ObsType, ActType, ObsType]):
|
||||
f"The delay needs to be greater than zero, actual value: {delay}"
|
||||
)
|
||||
|
||||
gym.utils.RecordConstructorArgs.__init__(self, delay=delay)
|
||||
gym.ObservationWrapper.__init__(self, env)
|
||||
|
||||
self.delay: Final[int] = int(delay)
|
||||
self.observation_queue: Final[deque] = deque()
|
||||
|
||||
@@ -88,7 +91,10 @@ class DelayObservationV0(ObservationWrapper[ObsType, ActType, ObsType]):
|
||||
return create_zero_array(self.observation_space)
|
||||
|
||||
|
||||
class TimeAwareObservationV0(ObservationWrapper[WrapperObsType, ActType, ObsType]):
|
||||
class TimeAwareObservationV0(
|
||||
gym.ObservationWrapper[WrapperObsType, ActType, ObsType],
|
||||
gym.utils.RecordConstructorArgs,
|
||||
):
|
||||
"""Augment the observation with time information of the episode.
|
||||
|
||||
The :attr:`normalize_time` if ``True`` represents time as a normalized value between [0,1]
|
||||
@@ -144,7 +150,7 @@ class TimeAwareObservationV0(ObservationWrapper[WrapperObsType, ActType, ObsType
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: Env[ObsType, ActType],
|
||||
env: gym.Env[ObsType, ActType],
|
||||
flatten: bool = False,
|
||||
normalize_time: bool = True,
|
||||
*,
|
||||
@@ -159,7 +165,13 @@ class TimeAwareObservationV0(ObservationWrapper[WrapperObsType, ActType, ObsType
|
||||
otherwise return time as remaining timesteps before truncation
|
||||
dict_time_key: For environment with a ``Dict`` observation space, the key for the time space. By default, `"time"`.
|
||||
"""
|
||||
super().__init__(env)
|
||||
gym.utils.RecordConstructorArgs.__init__(
|
||||
self,
|
||||
flatten=flatten,
|
||||
normalize_time=normalize_time,
|
||||
dict_time_key=dict_time_key,
|
||||
)
|
||||
gym.ObservationWrapper.__init__(self, env)
|
||||
|
||||
self.flatten: Final[bool] = flatten
|
||||
self.normalize_time: Final[bool] = normalize_time
|
||||
@@ -203,14 +215,14 @@ class TimeAwareObservationV0(ObservationWrapper[WrapperObsType, ActType, ObsType
|
||||
|
||||
# If to flatten the observation space
|
||||
if self.flatten:
|
||||
self.observation_space: Space[WrapperObsType] = spaces.flatten_space(
|
||||
self.observation_space: gym.Space[WrapperObsType] = spaces.flatten_space(
|
||||
observation_space
|
||||
)
|
||||
self._obs_postprocess_func = lambda obs: spaces.flatten(
|
||||
observation_space, obs
|
||||
)
|
||||
else:
|
||||
self.observation_space: Space[WrapperObsType] = observation_space
|
||||
self.observation_space: gym.Space[WrapperObsType] = observation_space
|
||||
self._obs_postprocess_func = lambda obs: obs
|
||||
|
||||
def observation(self, observation: ObsType) -> WrapperObsType:
|
||||
@@ -260,7 +272,10 @@ class TimeAwareObservationV0(ObservationWrapper[WrapperObsType, ActType, ObsType
|
||||
return super().reset(seed=seed, options=options)
|
||||
|
||||
|
||||
class FrameStackObservationV0(Wrapper[WrapperObsType, ActType, ObsType, ActType]):
|
||||
class FrameStackObservationV0(
|
||||
gym.Wrapper[WrapperObsType, ActType, ObsType, ActType],
|
||||
gym.utils.RecordConstructorArgs,
|
||||
):
|
||||
"""Observation wrapper that stacks the observations in a rolling manner.
|
||||
|
||||
For example, if the number of stacks is 4, then the returned observation contains
|
||||
@@ -286,7 +301,7 @@ class FrameStackObservationV0(Wrapper[WrapperObsType, ActType, ObsType, ActType]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: Env[ObsType, ActType],
|
||||
env: gym.Env[ObsType, ActType],
|
||||
stack_size: int,
|
||||
*,
|
||||
zeros_obs: ObsType | None = None,
|
||||
@@ -298,8 +313,6 @@ class FrameStackObservationV0(Wrapper[WrapperObsType, ActType, ObsType, ActType]
|
||||
stack_size: The number of frames to stack with zero_obs being used originally.
|
||||
zeros_obs: Keyword only parameter that allows a custom padding observation at :meth:`reset`
|
||||
"""
|
||||
super().__init__(env)
|
||||
|
||||
if not np.issubdtype(type(stack_size), np.integer):
|
||||
raise TypeError(
|
||||
f"The stack_size is expected to be an integer, actual type: {type(stack_size)}"
|
||||
@@ -309,6 +322,9 @@ class FrameStackObservationV0(Wrapper[WrapperObsType, ActType, ObsType, ActType]
|
||||
f"The stack_size needs to be greater than one, actual value: {stack_size}"
|
||||
)
|
||||
|
||||
gym.utils.RecordConstructorArgs.__init__(self, stack_size=stack_size)
|
||||
gym.Wrapper.__init__(self, env)
|
||||
|
||||
self.observation_space = batch_space(env.observation_space, n=stack_size)
|
||||
self.stack_size: Final[int] = stack_size
|
||||
|
||||
|
@@ -8,6 +8,7 @@ These are not intended as API functions, and will not remain stable over time.
|
||||
# that verify that our dependencies are actually present.
|
||||
from gymnasium.utils.colorize import colorize
|
||||
from gymnasium.utils.ezpickle import EzPickle
|
||||
from gymnasium.utils.record_constructor import RecordConstructorArgs
|
||||
|
||||
|
||||
__all__ = ["colorize", "EzPickle"]
|
||||
__all__ = ["colorize", "EzPickle", "RecordConstructorArgs"]
|
||||
|
@@ -1,13 +1,15 @@
|
||||
"""Class for pickling and unpickling objects via their constructor arguments."""
|
||||
from typing import Any
|
||||
|
||||
|
||||
class EzPickle:
|
||||
"""Objects that are pickled and unpickled via their constructor arguments.
|
||||
|
||||
Example:
|
||||
>>> class Dog(Animal, EzPickle): # doctest: +SKIP
|
||||
>>> class Animal: pass
|
||||
>>> class Dog(Animal, EzPickle):
|
||||
... def __init__(self, furcolor, tailkind="bushy"):
|
||||
... Animal.__init__()
|
||||
... Animal.__init__(self)
|
||||
... EzPickle.__init__(self, furcolor, tailkind)
|
||||
|
||||
When this object is unpickled, a new ``Dog`` will be constructed by passing the provided furcolor and tailkind into the constructor.
|
||||
@@ -16,7 +18,7 @@ class EzPickle:
|
||||
This is generally needed only for environments which wrap C/C++ code, such as MuJoCo and Atari.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, *args: Any, **kwargs: Any):
|
||||
"""Uses the ``args`` and ``kwargs`` from the object's constructor for pickling."""
|
||||
self._ezpickle_args = args
|
||||
self._ezpickle_kwargs = kwargs
|
||||
|
33
gymnasium/utils/record_constructor.py
Normal file
33
gymnasium/utils/record_constructor.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""Allows attributes passed to `RecordConstructorArgs` to be saved. This is used by the `Wrapper.spec` to know the constructor arguments of implemented wrappers."""
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
|
||||
class RecordConstructorArgs:
|
||||
"""Records all arguments passed to constructor to `_saved_kwargs`.
|
||||
|
||||
This can be used to save and reproduce class constructor arguments.
|
||||
|
||||
Note:
|
||||
If two class inherit from RecordConstructorArgs then the first class to call `RecordConstructorArgs.__init__(self, ...)` will have
|
||||
their kwargs saved will all subsequent `RecordConstructorArgs.__init__` being ignored.
|
||||
|
||||
Therefore, always call `RecordConstructorArgs.__init__` before the `Class.__init__`
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, *, _disable_deepcopy: bool = False, **kwargs: Any):
|
||||
"""Records all arguments passed to constructor to `_saved_kwargs`.
|
||||
|
||||
Args:
|
||||
_disable_deepcopy: If to not deepcopy the kwargs passed
|
||||
**kwargs: Arguments to save
|
||||
"""
|
||||
# See class docstring for explanation
|
||||
if not hasattr(self, "_saved_kwargs"):
|
||||
if _disable_deepcopy is False:
|
||||
kwargs = deepcopy(kwargs)
|
||||
self._saved_kwargs: dict[str, Any] = kwargs
|
@@ -11,7 +11,7 @@ except ImportError:
|
||||
cv2 = None
|
||||
|
||||
|
||||
class AtariPreprocessing(gym.Wrapper):
|
||||
class AtariPreprocessing(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||
"""Atari 2600 preprocessing wrapper.
|
||||
|
||||
This class follows the guidelines in Machado et al. (2018),
|
||||
@@ -60,7 +60,18 @@ class AtariPreprocessing(gym.Wrapper):
|
||||
DependencyNotInstalled: opencv-python package not installed
|
||||
ValueError: Disable frame-skipping in the original env
|
||||
"""
|
||||
super().__init__(env)
|
||||
gym.utils.RecordConstructorArgs.__init__(
|
||||
self,
|
||||
noop_max=noop_max,
|
||||
frame_skip=frame_skip,
|
||||
screen_size=screen_size,
|
||||
terminal_on_life_loss=terminal_on_life_loss,
|
||||
grayscale_obs=grayscale_obs,
|
||||
grayscale_newaxis=grayscale_newaxis,
|
||||
scale_obs=scale_obs,
|
||||
)
|
||||
gym.Wrapper.__init__(self, env)
|
||||
|
||||
if cv2 is None:
|
||||
raise gym.error.DependencyNotInstalled(
|
||||
"opencv-python package not installed, run `pip install gymnasium[other]` to get dependencies for atari"
|
||||
|
@@ -2,7 +2,7 @@
|
||||
import gymnasium as gym
|
||||
|
||||
|
||||
class AutoResetWrapper(gym.Wrapper):
|
||||
class AutoResetWrapper(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||
"""A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`.
|
||||
|
||||
When calling step causes :meth:`Env.step` to return `terminated=True` or `truncated=True`, :meth:`Env.reset` is called,
|
||||
@@ -31,7 +31,8 @@ class AutoResetWrapper(gym.Wrapper):
|
||||
Args:
|
||||
env (gym.Env): The environment to apply the wrapper
|
||||
"""
|
||||
super().__init__(env)
|
||||
gym.utils.RecordConstructorArgs.__init__(self)
|
||||
gym.Wrapper.__init__(self, env)
|
||||
|
||||
def step(self, action):
|
||||
"""Steps through the environment with action and resets the environment if a terminated or truncated signal is encountered.
|
||||
|
@@ -2,11 +2,10 @@
|
||||
import numpy as np
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium import ActionWrapper
|
||||
from gymnasium.spaces import Box
|
||||
|
||||
|
||||
class ClipAction(ActionWrapper):
|
||||
class ClipAction(gym.ActionWrapper, gym.utils.RecordConstructorArgs):
|
||||
"""Clip the continuous action within the valid :class:`Box` observation space bound.
|
||||
|
||||
Example:
|
||||
@@ -28,7 +27,9 @@ class ClipAction(ActionWrapper):
|
||||
env: The environment to apply the wrapper
|
||||
"""
|
||||
assert isinstance(env.action_space, Box)
|
||||
super().__init__(env)
|
||||
|
||||
gym.utils.RecordConstructorArgs.__init__(self)
|
||||
gym.ActionWrapper.__init__(self, env)
|
||||
|
||||
def action(self, action):
|
||||
"""Clips the action within the valid bounds.
|
||||
|
@@ -68,11 +68,12 @@ class EnvCompatibility(gym.Env):
|
||||
"The `gymnasium.make(..., apply_api_compatibility=...)` parameter is deprecated and will be removed in v0.29. "
|
||||
"Instead use `gym.make('GymV21Environment-v0', env_name=...)` or `from shimmy import GymV21CompatibilityV0`"
|
||||
)
|
||||
|
||||
self.env = old_env
|
||||
self.metadata = getattr(old_env, "metadata", {"render_modes": []})
|
||||
self.render_mode = render_mode
|
||||
self.reward_range = getattr(old_env, "reward_range", None)
|
||||
self.spec = getattr(old_env, "spec", None)
|
||||
self.env = old_env
|
||||
|
||||
self.observation_space = old_env.observation_space
|
||||
self.action_space = old_env.action_space
|
||||
|
@@ -10,12 +10,13 @@ from gymnasium.utils.passive_env_checker import (
|
||||
)
|
||||
|
||||
|
||||
class PassiveEnvChecker(gym.Wrapper):
|
||||
class PassiveEnvChecker(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||
"""A passive environment checker wrapper that surrounds the step, reset and render functions to check they follow the gymnasium API."""
|
||||
|
||||
def __init__(self, env):
|
||||
"""Initialises the wrapper with the environments, run the observation and action space tests."""
|
||||
super().__init__(env)
|
||||
gym.utils.RecordConstructorArgs.__init__(self)
|
||||
gym.Wrapper.__init__(self, env)
|
||||
|
||||
assert hasattr(
|
||||
env, "action_space"
|
||||
|
@@ -6,7 +6,7 @@ import gymnasium as gym
|
||||
from gymnasium import spaces
|
||||
|
||||
|
||||
class FilterObservation(gym.ObservationWrapper):
|
||||
class FilterObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
|
||||
"""Filter Dict observation space by the keys.
|
||||
|
||||
Example:
|
||||
@@ -35,7 +35,8 @@ class FilterObservation(gym.ObservationWrapper):
|
||||
ValueError: If the environment's observation space is not :class:`spaces.Dict`
|
||||
ValueError: If any of the `filter_keys` are not included in the original `env`'s observation space
|
||||
"""
|
||||
super().__init__(env)
|
||||
gym.utils.RecordConstructorArgs.__init__(self, filter_keys=filter_keys)
|
||||
gym.ObservationWrapper.__init__(self, env)
|
||||
|
||||
wrapped_observation_space = env.observation_space
|
||||
if not isinstance(wrapped_observation_space, spaces.Dict):
|
||||
|
@@ -3,7 +3,7 @@ import gymnasium as gym
|
||||
from gymnasium import spaces
|
||||
|
||||
|
||||
class FlattenObservation(gym.ObservationWrapper):
|
||||
class FlattenObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
|
||||
"""Observation wrapper that flattens the observation.
|
||||
|
||||
Example:
|
||||
@@ -26,7 +26,9 @@ class FlattenObservation(gym.ObservationWrapper):
|
||||
Args:
|
||||
env: The environment to apply the wrapper
|
||||
"""
|
||||
super().__init__(env)
|
||||
gym.utils.RecordConstructorArgs.__init__(self)
|
||||
gym.ObservationWrapper.__init__(self, env)
|
||||
|
||||
self.observation_space = spaces.flatten_space(env.observation_space)
|
||||
|
||||
def observation(self, observation):
|
||||
|
@@ -97,7 +97,7 @@ class LazyFrames:
|
||||
return frame
|
||||
|
||||
|
||||
class FrameStack(gym.ObservationWrapper):
|
||||
class FrameStack(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
|
||||
"""Observation wrapper that stacks the observations in a rolling manner.
|
||||
|
||||
For example, if the number of stacks is 4, then the returned observation contains
|
||||
@@ -137,7 +137,11 @@ class FrameStack(gym.ObservationWrapper):
|
||||
num_stack (int): The number of frames to stack
|
||||
lz4_compress (bool): Use lz4 to compress the frames internally
|
||||
"""
|
||||
super().__init__(env)
|
||||
gym.utils.RecordConstructorArgs.__init__(
|
||||
self, num_stack=num_stack, lz4_compress=lz4_compress
|
||||
)
|
||||
gym.ObservationWrapper.__init__(self, env)
|
||||
|
||||
self.num_stack = num_stack
|
||||
self.lz4_compress = lz4_compress
|
||||
|
||||
|
@@ -5,7 +5,7 @@ import gymnasium as gym
|
||||
from gymnasium.spaces import Box
|
||||
|
||||
|
||||
class GrayScaleObservation(gym.ObservationWrapper):
|
||||
class GrayScaleObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
|
||||
"""Convert the image observation from RGB to gray scale.
|
||||
|
||||
Example:
|
||||
@@ -30,7 +30,9 @@ class GrayScaleObservation(gym.ObservationWrapper):
|
||||
keep_dim (bool): If `True`, a singleton dimension will be added, i.e. observations are of the shape AxBx1.
|
||||
Otherwise, they are of shape AxB.
|
||||
"""
|
||||
super().__init__(env)
|
||||
gym.utils.RecordConstructorArgs.__init__(self, keep_dim=keep_dim)
|
||||
gym.ObservationWrapper.__init__(self, env)
|
||||
|
||||
self.keep_dim = keep_dim
|
||||
|
||||
assert (
|
||||
|
@@ -7,7 +7,7 @@ import gymnasium as gym
|
||||
from gymnasium.error import DependencyNotInstalled
|
||||
|
||||
|
||||
class HumanRendering(gym.Wrapper):
|
||||
class HumanRendering(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||
"""Performs human rendering for an environment that only supports "rgb_array"rendering.
|
||||
|
||||
This wrapper is particularly useful when you have implemented an environment that can produce
|
||||
@@ -47,7 +47,9 @@ class HumanRendering(gym.Wrapper):
|
||||
Args:
|
||||
env: The environment that is being wrapped
|
||||
"""
|
||||
super().__init__(env)
|
||||
gym.utils.RecordConstructorArgs.__init__(self)
|
||||
gym.Wrapper.__init__(self, env)
|
||||
|
||||
assert env.render_mode in [
|
||||
"rgb_array",
|
||||
"rgb_array_list",
|
||||
@@ -64,6 +66,8 @@ class HumanRendering(gym.Wrapper):
|
||||
if "human" not in self.metadata["render_modes"]:
|
||||
self.metadata["render_modes"].append("human")
|
||||
|
||||
gym.utils.RecordConstructorArgs.__init__(self)
|
||||
|
||||
@property
|
||||
def render_mode(self):
|
||||
"""Always returns ``'human'``."""
|
||||
|
@@ -45,7 +45,7 @@ def update_mean_var_count_from_moments(
|
||||
return new_mean, new_var, new_count
|
||||
|
||||
|
||||
class NormalizeObservation(gym.Wrapper):
|
||||
class NormalizeObservation(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||
"""This wrapper will normalize observations s.t. each coordinate is centered with unit variance.
|
||||
|
||||
Note:
|
||||
@@ -60,7 +60,9 @@ class NormalizeObservation(gym.Wrapper):
|
||||
env (Env): The environment to apply the wrapper
|
||||
epsilon: A stability parameter that is used when scaling the observations.
|
||||
"""
|
||||
super().__init__(env)
|
||||
gym.utils.RecordConstructorArgs.__init__(self, epsilon=epsilon)
|
||||
gym.Wrapper.__init__(self, env)
|
||||
|
||||
self.num_envs = getattr(env, "num_envs", 1)
|
||||
self.is_vector_env = getattr(env, "is_vector_env", False)
|
||||
if self.is_vector_env:
|
||||
@@ -93,7 +95,7 @@ class NormalizeObservation(gym.Wrapper):
|
||||
return (obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.epsilon)
|
||||
|
||||
|
||||
class NormalizeReward(gym.core.Wrapper):
|
||||
class NormalizeReward(gym.core.Wrapper, gym.utils.RecordConstructorArgs):
|
||||
r"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance.
|
||||
|
||||
The exponential moving average will have variance :math:`(1 - \gamma)^2`.
|
||||
@@ -116,7 +118,9 @@ class NormalizeReward(gym.core.Wrapper):
|
||||
epsilon (float): A stability parameter
|
||||
gamma (float): The discount factor that is used in the exponential moving average.
|
||||
"""
|
||||
super().__init__(env)
|
||||
gym.utils.RecordConstructorArgs.__init__(self, gamma=gamma, epsilon=epsilon)
|
||||
gym.Wrapper.__init__(self, env)
|
||||
|
||||
self.num_envs = getattr(env, "num_envs", 1)
|
||||
self.is_vector_env = getattr(env, "is_vector_env", False)
|
||||
self.return_rms = RunningMeanStd(shape=())
|
||||
|
@@ -3,7 +3,7 @@ import gymnasium as gym
|
||||
from gymnasium.error import ResetNeeded
|
||||
|
||||
|
||||
class OrderEnforcing(gym.Wrapper):
|
||||
class OrderEnforcing(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||
"""A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`.
|
||||
|
||||
Example:
|
||||
@@ -32,7 +32,11 @@ class OrderEnforcing(gym.Wrapper):
|
||||
env: The environment to wrap
|
||||
disable_render_order_enforcing: If to disable render order enforcing
|
||||
"""
|
||||
super().__init__(env)
|
||||
gym.utils.RecordConstructorArgs.__init__(
|
||||
self, disable_render_order_enforcing=disable_render_order_enforcing
|
||||
)
|
||||
gym.Wrapper.__init__(self, env)
|
||||
|
||||
self._has_reset: bool = False
|
||||
self._disable_render_order_enforcing: bool = disable_render_order_enforcing
|
||||
|
||||
|
@@ -13,7 +13,7 @@ from gymnasium import spaces
|
||||
STATE_KEY = "state"
|
||||
|
||||
|
||||
class PixelObservationWrapper(gym.ObservationWrapper):
|
||||
class PixelObservationWrapper(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
|
||||
"""Augment observations by pixel values.
|
||||
|
||||
Observations of this wrapper will be dictionaries of images.
|
||||
@@ -79,7 +79,13 @@ class PixelObservationWrapper(gym.ObservationWrapper):
|
||||
specified ``pixel_keys``.
|
||||
TypeError: When an unexpected pixel type is used
|
||||
"""
|
||||
super().__init__(env)
|
||||
gym.utils.RecordConstructorArgs.__init__(
|
||||
self,
|
||||
pixels_only=pixels_only,
|
||||
render_kwargs=render_kwargs,
|
||||
pixel_keys=pixel_keys,
|
||||
)
|
||||
gym.ObservationWrapper.__init__(self, env)
|
||||
|
||||
# Avoid side-effects that occur when render_kwargs is manipulated
|
||||
render_kwargs = copy.deepcopy(render_kwargs)
|
||||
|
@@ -8,7 +8,7 @@ import numpy as np
|
||||
import gymnasium as gym
|
||||
|
||||
|
||||
class RecordEpisodeStatistics(gym.Wrapper):
|
||||
class RecordEpisodeStatistics(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||
"""This wrapper will keep track of cumulative rewards and episode lengths.
|
||||
|
||||
At the end of an episode, the statistics of the episode will be added to ``info``
|
||||
@@ -56,7 +56,9 @@ class RecordEpisodeStatistics(gym.Wrapper):
|
||||
env (Env): The environment to apply the wrapper
|
||||
deque_size: The size of the buffers :attr:`return_queue` and :attr:`length_queue`
|
||||
"""
|
||||
super().__init__(env)
|
||||
gym.utils.RecordConstructorArgs.__init__(self, deque_size=deque_size)
|
||||
gym.Wrapper.__init__(self, env)
|
||||
|
||||
self.num_envs = getattr(env, "num_envs", 1)
|
||||
self.episode_count = 0
|
||||
self.episode_start_times: np.ndarray = None
|
||||
|
@@ -24,7 +24,7 @@ def capped_cubic_video_schedule(episode_id: int) -> bool:
|
||||
return episode_id % 1000 == 0
|
||||
|
||||
|
||||
class RecordVideo(gym.Wrapper):
|
||||
class RecordVideo(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||
"""This wrapper records videos of rollouts.
|
||||
|
||||
Usually, you only want to record episodes intermittently, say every hundredth episode.
|
||||
@@ -58,9 +58,17 @@ class RecordVideo(gym.Wrapper):
|
||||
Otherwise, snippets of the specified length are captured
|
||||
name_prefix (str): Will be prepended to the filename of the recordings
|
||||
disable_logger (bool): Whether to disable moviepy logger or not.
|
||||
|
||||
"""
|
||||
super().__init__(env)
|
||||
gym.utils.RecordConstructorArgs.__init__(
|
||||
self,
|
||||
video_folder=video_folder,
|
||||
episode_trigger=episode_trigger,
|
||||
step_trigger=step_trigger,
|
||||
video_length=video_length,
|
||||
name_prefix=name_prefix,
|
||||
disable_logger=disable_logger,
|
||||
)
|
||||
gym.Wrapper.__init__(self, env)
|
||||
|
||||
if episode_trigger is None and step_trigger is None:
|
||||
episode_trigger = capped_cubic_video_schedule
|
||||
|
@@ -4,7 +4,7 @@ import copy
|
||||
import gymnasium as gym
|
||||
|
||||
|
||||
class RenderCollection(gym.Wrapper):
|
||||
class RenderCollection(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||
"""Save collection of render frames."""
|
||||
|
||||
def __init__(self, env: gym.Env, pop_frames: bool = True, reset_clean: bool = True):
|
||||
@@ -17,7 +17,11 @@ class RenderCollection(gym.Wrapper):
|
||||
reset_clean (bool): If true, clear the collection frames when .reset() is called.
|
||||
Default value is True.
|
||||
"""
|
||||
super().__init__(env)
|
||||
gym.utils.RecordConstructorArgs.__init__(
|
||||
self, pop_frames=pop_frames, reset_clean=reset_clean
|
||||
)
|
||||
gym.Wrapper.__init__(self, env)
|
||||
|
||||
assert env.render_mode is not None
|
||||
assert not env.render_mode.endswith("_list")
|
||||
self.frame_list = []
|
||||
|
@@ -7,7 +7,7 @@ import gymnasium as gym
|
||||
from gymnasium.spaces import Box
|
||||
|
||||
|
||||
class RescaleAction(gym.ActionWrapper):
|
||||
class RescaleAction(gym.ActionWrapper, gym.utils.RecordConstructorArgs):
|
||||
"""Affinely rescales the continuous action space of the environment to the range [min_action, max_action].
|
||||
|
||||
The base environment :attr:`env` must have an action space of type :class:`spaces.Box`. If :attr:`min_action`
|
||||
@@ -47,7 +47,11 @@ class RescaleAction(gym.ActionWrapper):
|
||||
), f"expected Box action space, got {type(env.action_space)}"
|
||||
assert np.less_equal(min_action, max_action).all(), (min_action, max_action)
|
||||
|
||||
super().__init__(env)
|
||||
gym.utils.RecordConstructorArgs.__init__(
|
||||
self, min_action=min_action, max_action=max_action
|
||||
)
|
||||
gym.ActionWrapper.__init__(self, env)
|
||||
|
||||
self.min_action = (
|
||||
np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + min_action
|
||||
)
|
||||
|
@@ -8,7 +8,7 @@ from gymnasium.error import DependencyNotInstalled
|
||||
from gymnasium.spaces import Box
|
||||
|
||||
|
||||
class ResizeObservation(gym.ObservationWrapper):
|
||||
class ResizeObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
|
||||
"""Resize the image observation.
|
||||
|
||||
This wrapper works on environments with image observations. More generally,
|
||||
@@ -36,7 +36,9 @@ class ResizeObservation(gym.ObservationWrapper):
|
||||
env: The environment to apply the wrapper
|
||||
shape: The shape of the resized observations
|
||||
"""
|
||||
super().__init__(env)
|
||||
gym.utils.RecordConstructorArgs.__init__(self, shape=shape)
|
||||
gym.ObservationWrapper.__init__(self, env)
|
||||
|
||||
if isinstance(shape, int):
|
||||
shape = (shape, shape)
|
||||
assert len(shape) == 2 and all(
|
||||
|
@@ -4,7 +4,7 @@ from gymnasium.logger import deprecation
|
||||
from gymnasium.utils.step_api_compatibility import step_api_compatibility
|
||||
|
||||
|
||||
class StepAPICompatibility(gym.Wrapper):
|
||||
class StepAPICompatibility(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||
r"""A wrapper which can transform an environment from new step API to old and vice-versa.
|
||||
|
||||
Old step API refers to step() method returning (observation, reward, done, info)
|
||||
@@ -29,7 +29,11 @@ class StepAPICompatibility(gym.Wrapper):
|
||||
env (gym.Env): the env to wrap. Can be in old or new API
|
||||
output_truncation_bool (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
|
||||
"""
|
||||
super().__init__(env)
|
||||
gym.utils.RecordConstructorArgs.__init__(
|
||||
self, output_truncation_bool=output_truncation_bool
|
||||
)
|
||||
gym.Wrapper.__init__(self, env)
|
||||
|
||||
self.is_vector_env = isinstance(env.unwrapped, gym.vector.VectorEnv)
|
||||
self.output_truncation_bool = output_truncation_bool
|
||||
if not self.output_truncation_bool:
|
||||
|
@@ -5,7 +5,7 @@ import gymnasium as gym
|
||||
from gymnasium.spaces import Box
|
||||
|
||||
|
||||
class TimeAwareObservation(gym.ObservationWrapper):
|
||||
class TimeAwareObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
|
||||
"""Augment the observation with the current time step in the episode.
|
||||
|
||||
The observation space of the wrapped environment is assumed to be a flat :class:`Box`.
|
||||
@@ -29,7 +29,9 @@ class TimeAwareObservation(gym.ObservationWrapper):
|
||||
Args:
|
||||
env: The environment to apply the wrapper
|
||||
"""
|
||||
super().__init__(env)
|
||||
gym.utils.RecordConstructorArgs.__init__(self)
|
||||
gym.ObservationWrapper.__init__(self, env)
|
||||
|
||||
assert isinstance(env.observation_space, Box)
|
||||
assert env.observation_space.dtype == np.float32
|
||||
low = np.append(self.observation_space.low, 0.0)
|
||||
|
@@ -4,7 +4,7 @@ from typing import Optional
|
||||
import gymnasium as gym
|
||||
|
||||
|
||||
class TimeLimit(gym.Wrapper):
|
||||
class TimeLimit(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||
"""This wrapper will issue a `truncated` signal if a maximum number of timesteps is exceeded.
|
||||
|
||||
If a truncation is not defined inside the environment itself, this is the only place that the truncation signal is issued.
|
||||
@@ -26,14 +26,16 @@ class TimeLimit(gym.Wrapper):
|
||||
|
||||
Args:
|
||||
env: The environment to apply the wrapper
|
||||
max_episode_steps: An optional max episode steps (if ``Ǹone``, ``env.spec.max_episode_steps`` is used)
|
||||
max_episode_steps: An optional max episode steps (if ``None``, ``env.spec.max_episode_steps`` is used)
|
||||
"""
|
||||
super().__init__(env)
|
||||
gym.utils.RecordConstructorArgs.__init__(
|
||||
self, max_episode_steps=max_episode_steps
|
||||
)
|
||||
gym.Wrapper.__init__(self, env)
|
||||
|
||||
if max_episode_steps is None and self.env.spec is not None:
|
||||
assert env.spec is not None
|
||||
max_episode_steps = env.spec.max_episode_steps
|
||||
if self.env.spec is not None:
|
||||
self.env.spec.max_episode_steps = max_episode_steps
|
||||
self._max_episode_steps = max_episode_steps
|
||||
self._elapsed_steps = None
|
||||
|
||||
|
@@ -4,7 +4,7 @@ from typing import Any, Callable
|
||||
import gymnasium as gym
|
||||
|
||||
|
||||
class TransformObservation(gym.ObservationWrapper):
|
||||
class TransformObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
|
||||
"""Transform the observation via an arbitrary function :attr:`f`.
|
||||
|
||||
The function :attr:`f` should be defined on the observation space of the base environment, ``env``, and should, ideally, return values in the same space.
|
||||
@@ -29,7 +29,9 @@ class TransformObservation(gym.ObservationWrapper):
|
||||
env: The environment to apply the wrapper
|
||||
f: A function that transforms the observation
|
||||
"""
|
||||
super().__init__(env)
|
||||
gym.utils.RecordConstructorArgs.__init__(self, f=f)
|
||||
gym.ObservationWrapper.__init__(self, env)
|
||||
|
||||
assert callable(f)
|
||||
self.f = f
|
||||
|
||||
|
@@ -2,10 +2,9 @@
|
||||
from typing import Callable
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium import RewardWrapper
|
||||
|
||||
|
||||
class TransformReward(RewardWrapper):
|
||||
class TransformReward(gym.RewardWrapper, gym.utils.RecordConstructorArgs):
|
||||
"""Transform the reward via an arbitrary function.
|
||||
|
||||
Warning:
|
||||
@@ -29,7 +28,9 @@ class TransformReward(RewardWrapper):
|
||||
env: The environment to apply the wrapper
|
||||
f: A function that transforms the reward
|
||||
"""
|
||||
super().__init__(env)
|
||||
gym.utils.RecordConstructorArgs.__init__(self, f=f)
|
||||
gym.RewardWrapper.__init__(self, env)
|
||||
|
||||
assert callable(f)
|
||||
self.f = f
|
||||
|
||||
|
@@ -5,7 +5,7 @@ from typing import List
|
||||
import gymnasium as gym
|
||||
|
||||
|
||||
class VectorListInfo(gym.Wrapper):
|
||||
class VectorListInfo(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||
"""Converts infos of vectorized environments from dict to List[dict].
|
||||
|
||||
This wrapper converts the info format of a
|
||||
@@ -51,7 +51,9 @@ class VectorListInfo(gym.Wrapper):
|
||||
assert getattr(
|
||||
env, "is_vector_env", False
|
||||
), "This wrapper can only be used in vectorized environments."
|
||||
super().__init__(env)
|
||||
|
||||
gym.utils.RecordConstructorArgs.__init__(self)
|
||||
gym.Wrapper.__init__(self, env)
|
||||
|
||||
def step(self, action):
|
||||
"""Steps through the environment, convert dict info to list."""
|
||||
|
237
tests/envs/registration/test_env_spec.py
Normal file
237
tests/envs/registration/test_env_spec.py
Normal file
@@ -0,0 +1,237 @@
|
||||
"""Example file showing usage of env.specstack."""
|
||||
import pickle
|
||||
|
||||
import pytest
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.envs.classic_control import CartPoleEnv
|
||||
from gymnasium.envs.registration import EnvSpec
|
||||
from gymnasium.utils.env_checker import data_equivalence
|
||||
|
||||
|
||||
def test_full_integration():
|
||||
# Create an environment to test with
|
||||
env = gym.make("CartPole-v1", render_mode="rgb_array").unwrapped
|
||||
|
||||
env = gym.wrappers.FlattenObservation(env)
|
||||
env = gym.wrappers.TimeAwareObservation(env)
|
||||
env = gym.wrappers.NormalizeReward(env, gamma=0.8)
|
||||
|
||||
# Generate the spec_stack
|
||||
env_spec = env.spec
|
||||
assert isinstance(env_spec, EnvSpec)
|
||||
# env_spec.pprint()
|
||||
|
||||
# Serialize the spec_stack
|
||||
env_spec_json = env_spec.to_json()
|
||||
assert isinstance(env_spec_json, str)
|
||||
|
||||
# Deserialize the spec_stack
|
||||
recreate_env_spec = EnvSpec.from_json(env_spec_json)
|
||||
# recreate_env_spec.pprint()
|
||||
|
||||
for wrapper_spec, recreated_wrapper_spec in zip(
|
||||
env_spec.applied_wrappers, recreate_env_spec.applied_wrappers
|
||||
):
|
||||
assert wrapper_spec == recreated_wrapper_spec
|
||||
assert recreate_env_spec == env_spec
|
||||
|
||||
# Recreate the environment using the spec_stack
|
||||
recreated_env = gym.make(recreate_env_spec)
|
||||
assert recreated_env.render_mode == "rgb_array"
|
||||
assert isinstance(recreated_env, gym.wrappers.NormalizeReward)
|
||||
assert recreated_env.gamma == 0.8
|
||||
assert isinstance(recreated_env.env, gym.wrappers.TimeAwareObservation)
|
||||
assert isinstance(recreated_env.unwrapped, CartPoleEnv)
|
||||
|
||||
obs, info = env.reset(seed=42)
|
||||
recreated_obs, recreated_info = recreated_env.reset(seed=42)
|
||||
assert data_equivalence(obs, recreated_obs)
|
||||
assert data_equivalence(info, recreated_info)
|
||||
|
||||
action = env.action_space.sample()
|
||||
obs, reward, terminated, truncated, info = env.step(action)
|
||||
(
|
||||
recreated_obs,
|
||||
recreated_reward,
|
||||
recreated_terminated,
|
||||
recreated_truncated,
|
||||
recreated_info,
|
||||
) = recreated_env.step(action)
|
||||
assert data_equivalence(obs, recreated_obs)
|
||||
assert data_equivalence(reward, recreated_reward)
|
||||
assert data_equivalence(terminated, recreated_terminated)
|
||||
assert data_equivalence(truncated, recreated_truncated)
|
||||
assert data_equivalence(info, recreated_info)
|
||||
|
||||
# Test the pprint of the spec_stack
|
||||
spec_stack_output = env_spec.pprint(disable_print=True)
|
||||
json_spec_stack_output = env_spec.pprint(disable_print=True)
|
||||
assert spec_stack_output == json_spec_stack_output
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"env_spec",
|
||||
[
|
||||
gym.spec("CartPole-v1"),
|
||||
gym.make("CartPole-v1").unwrapped.spec,
|
||||
gym.make("CartPole-v1").spec,
|
||||
gym.wrappers.NormalizeReward(gym.make("CartPole-v1")).spec,
|
||||
],
|
||||
)
|
||||
def test_env_spec_to_from_json(env_spec: EnvSpec):
|
||||
json_spec = env_spec.to_json()
|
||||
recreated_env_spec = EnvSpec.from_json(json_spec)
|
||||
|
||||
assert env_spec == recreated_env_spec
|
||||
|
||||
|
||||
def test_wrapped_env_entry_point():
|
||||
def _create_env():
|
||||
_env = gym.make("CartPole-v1", render_mode="rgb_array")
|
||||
_env = gym.wrappers.FlattenObservation(_env)
|
||||
return _env
|
||||
|
||||
gym.register("TestingEnv-v0", entry_point=_create_env)
|
||||
|
||||
env = gym.make("TestingEnv-v0")
|
||||
env = gym.wrappers.TimeAwareObservation(env)
|
||||
env = gym.wrappers.NormalizeReward(env, gamma=0.8)
|
||||
|
||||
recreated_env = gym.make(env.spec)
|
||||
|
||||
obs, info = env.reset(seed=42)
|
||||
recreated_obs, recreated_info = recreated_env.reset(seed=42)
|
||||
assert data_equivalence(obs, recreated_obs)
|
||||
assert data_equivalence(info, recreated_info)
|
||||
|
||||
action = env.action_space.sample()
|
||||
obs, reward, terminated, truncated, info = env.step(action)
|
||||
(
|
||||
recreated_obs,
|
||||
recreated_reward,
|
||||
recreated_terminated,
|
||||
recreated_truncated,
|
||||
recreated_info,
|
||||
) = recreated_env.step(action)
|
||||
assert data_equivalence(obs, recreated_obs)
|
||||
assert data_equivalence(reward, recreated_reward)
|
||||
assert data_equivalence(terminated, recreated_terminated)
|
||||
assert data_equivalence(truncated, recreated_truncated)
|
||||
assert data_equivalence(info, recreated_info)
|
||||
|
||||
del gym.registry["TestingEnv-v0"]
|
||||
|
||||
|
||||
def test_pickling_env_stack():
|
||||
env = gym.make("CartPole-v1", render_mode="rgb_array")
|
||||
|
||||
env = gym.wrappers.FlattenObservation(env)
|
||||
env = gym.wrappers.TimeAwareObservation(env)
|
||||
env = gym.wrappers.NormalizeReward(env, gamma=0.8)
|
||||
|
||||
pickled_env = pickle.loads(pickle.dumps(env))
|
||||
|
||||
obs, info = env.reset(seed=123)
|
||||
pickled_obs, pickled_info = pickled_env.reset(seed=123)
|
||||
|
||||
assert data_equivalence(obs, pickled_obs)
|
||||
assert data_equivalence(info, pickled_info)
|
||||
|
||||
action = env.action_space.sample()
|
||||
obs, reward, terminated, truncated, info = env.step(action)
|
||||
(
|
||||
pickled_obs,
|
||||
pickled_reward,
|
||||
pickled_terminated,
|
||||
pickled_truncated,
|
||||
pickled_info,
|
||||
) = pickled_env.step(action)
|
||||
|
||||
assert data_equivalence(obs, pickled_obs)
|
||||
assert data_equivalence(reward, pickled_reward)
|
||||
assert data_equivalence(terminated, pickled_terminated)
|
||||
assert data_equivalence(truncated, pickled_truncated)
|
||||
assert data_equivalence(info, pickled_info)
|
||||
|
||||
env.close()
|
||||
pickled_env.close()
|
||||
|
||||
|
||||
# flake8: noqa
|
||||
|
||||
|
||||
def test_env_spec_pprint():
|
||||
env = gym.make("CartPole-v1")
|
||||
env_spec = env.spec
|
||||
assert env_spec is not None
|
||||
|
||||
output = env_spec.pprint(disable_print=True)
|
||||
assert (
|
||||
output
|
||||
== """id=CartPole-v1
|
||||
reward_threshold=475.0
|
||||
max_episode_steps=500
|
||||
applied_wrappers=[
|
||||
name=PassiveEnvChecker, kwargs={},
|
||||
name=OrderEnforcing, kwargs={'disable_render_order_enforcing': False},
|
||||
name=TimeLimit, kwargs={'max_episode_steps': 500}
|
||||
]"""
|
||||
)
|
||||
|
||||
output = env_spec.pprint(disable_print=True, include_entry_points=True)
|
||||
assert (
|
||||
output
|
||||
== """id=CartPole-v1
|
||||
entry_point=gymnasium.envs.classic_control.cartpole:CartPoleEnv
|
||||
reward_threshold=475.0
|
||||
max_episode_steps=500
|
||||
applied_wrappers=[
|
||||
name=PassiveEnvChecker, entry_point=gymnasium.wrappers.env_checker:PassiveEnvChecker, kwargs={},
|
||||
name=OrderEnforcing, entry_point=gymnasium.wrappers.order_enforcing:OrderEnforcing, kwargs={'disable_render_order_enforcing': False},
|
||||
name=TimeLimit, entry_point=gymnasium.wrappers.time_limit:TimeLimit, kwargs={'max_episode_steps': 500}
|
||||
]"""
|
||||
)
|
||||
|
||||
output = env_spec.pprint(disable_print=True, print_all=True)
|
||||
assert (
|
||||
output
|
||||
== """id=CartPole-v1
|
||||
entry_point=gymnasium.envs.classic_control.cartpole:CartPoleEnv
|
||||
reward_threshold=475.0
|
||||
nondeterministic=False
|
||||
max_episode_steps=500
|
||||
order_enforce=True
|
||||
autoreset=False
|
||||
disable_env_checker=False
|
||||
applied_api_compatibility=False
|
||||
applied_wrappers=[
|
||||
name=PassiveEnvChecker, kwargs={},
|
||||
name=OrderEnforcing, kwargs={'disable_render_order_enforcing': False},
|
||||
name=TimeLimit, kwargs={'max_episode_steps': 500}
|
||||
]"""
|
||||
)
|
||||
|
||||
env_spec.applied_wrappers = ()
|
||||
output = env_spec.pprint(disable_print=True)
|
||||
assert (
|
||||
output
|
||||
== """id=CartPole-v1
|
||||
reward_threshold=475.0
|
||||
max_episode_steps=500"""
|
||||
)
|
||||
|
||||
output = env_spec.pprint(disable_print=True, print_all=True)
|
||||
assert (
|
||||
output
|
||||
== """id=CartPole-v1
|
||||
entry_point=gymnasium.envs.classic_control.cartpole:CartPoleEnv
|
||||
reward_threshold=475.0
|
||||
nondeterministic=False
|
||||
max_episode_steps=500
|
||||
order_enforce=True
|
||||
autoreset=False
|
||||
disable_env_checker=False
|
||||
applied_api_compatibility=False
|
||||
applied_wrappers=[]"""
|
||||
)
|
@@ -8,6 +8,8 @@ import numpy as np
|
||||
import pytest
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium import Env
|
||||
from gymnasium.core import ActType, ObsType, WrapperObsType
|
||||
from gymnasium.envs.classic_control import CartPoleEnv
|
||||
from gymnasium.wrappers import (
|
||||
AutoResetWrapper,
|
||||
@@ -355,3 +357,69 @@ def test_import_module_during_make():
|
||||
env.close()
|
||||
|
||||
del gym.registry["RegisterDuringMake-v0"]
|
||||
|
||||
|
||||
class NoRecordArgsWrapper(gym.ObservationWrapper):
|
||||
def __init__(self, env: Env[ObsType, ActType]):
|
||||
super().__init__(env)
|
||||
|
||||
def observation(self, observation: ObsType) -> WrapperObsType:
|
||||
return self.observation_space.sample()
|
||||
|
||||
|
||||
def test_make_env_spec():
|
||||
# make
|
||||
env_1 = gym.make(gym.spec("CartPole-v1"))
|
||||
assert isinstance(env_1, CartPoleEnv)
|
||||
assert env_1 is env_1.unwrapped
|
||||
env_1.close()
|
||||
|
||||
# make with applied wrappers
|
||||
env_2 = gym.wrappers.NormalizeReward(
|
||||
gym.wrappers.TimeAwareObservation(
|
||||
gym.wrappers.FlattenObservation(
|
||||
gym.make("CartPole-v1", render_mode="rgb_array")
|
||||
)
|
||||
),
|
||||
gamma=0.8,
|
||||
)
|
||||
env_2_recreated = gym.make(env_2.spec)
|
||||
assert env_2.spec == env_2_recreated.spec
|
||||
env_2.close()
|
||||
env_2_recreated.close()
|
||||
|
||||
# make with callable entry point
|
||||
gym.register("CartPole-v2", lambda: CartPoleEnv())
|
||||
env_3 = gym.make("CartPole-v2")
|
||||
assert isinstance(env_3.unwrapped, CartPoleEnv)
|
||||
env_3.close()
|
||||
|
||||
# make with wrapper in env-creator
|
||||
gym.register(
|
||||
"CartPole-v3", lambda: gym.wrappers.TimeAwareObservation(CartPoleEnv())
|
||||
)
|
||||
env_4 = gym.make(gym.spec("CartPole-v3"))
|
||||
assert isinstance(env_4, gym.wrappers.TimeAwareObservation)
|
||||
assert isinstance(env_4.env, CartPoleEnv)
|
||||
env_4.close()
|
||||
|
||||
# make with no ezpickle wrapper
|
||||
env_5 = NoRecordArgsWrapper(gym.make("CartPole-v1").unwrapped)
|
||||
env_5.close()
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=re.escape(
|
||||
"NoRecordArgsWrapper wrapper does not inherit from `gymnasium.utils.RecordConstructorArgs`, therefore, the wrapper cannot be recreated."
|
||||
),
|
||||
):
|
||||
gym.make(env_5.spec)
|
||||
|
||||
# make with no ezpickle wrapper but in the entry point
|
||||
gym.register("CartPole-v4", entry_point=lambda: NoRecordArgsWrapper(CartPoleEnv()))
|
||||
env_6 = gym.make(gym.spec("CartPole-v4"))
|
||||
assert isinstance(env_6, NoRecordArgsWrapper)
|
||||
assert isinstance(env_6.unwrapped, CartPoleEnv)
|
||||
|
||||
del gym.registry["CartPole-v2"]
|
||||
del gym.registry["CartPole-v3"]
|
||||
del gym.registry["CartPole-v4"]
|
||||
|
@@ -22,7 +22,7 @@ def test_mujoco_action_dimensions(env_spec: EnvSpec):
|
||||
* Too many dimensions
|
||||
* Incorrect shape
|
||||
"""
|
||||
env = env_spec.make(disable_env_checker=True)
|
||||
env = env_spec.make()
|
||||
env.reset()
|
||||
|
||||
# Too few actions
|
||||
|
@@ -42,7 +42,7 @@ CHECK_ENV_IGNORE_WARNINGS = [
|
||||
def test_all_env_api(spec):
|
||||
"""Check that all environments pass the environment checker with no warnings other than the expected."""
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
env = spec.make(disable_env_checker=True).unwrapped
|
||||
env = spec.make().unwrapped
|
||||
check_env(env, skip_render_check=True)
|
||||
|
||||
env.close()
|
||||
|
@@ -15,8 +15,8 @@ def verify_environments_match(
|
||||
):
|
||||
"""Verifies with two environment ids (old and new) are identical in obs, reward and done
|
||||
(except info where all old info must be contained in new info)."""
|
||||
old_env = envs.make(old_env_id, disable_env_checker=True)
|
||||
new_env = envs.make(new_env_id, disable_env_checker=True)
|
||||
old_env = envs.make(old_env_id)
|
||||
new_env = envs.make(new_env_id)
|
||||
|
||||
old_reset_obs, old_info = old_env.reset(seed=seed)
|
||||
new_reset_obs, new_info = new_env.reset(seed=seed)
|
||||
|
@@ -106,7 +106,7 @@ class TestNestedDictWrapper:
|
||||
observation_space = env.observation_space
|
||||
assert isinstance(observation_space, Dict)
|
||||
|
||||
wrapped_env = FlattenObservation(FilterObservation(env, env.obs_keys))
|
||||
wrapped_env = FlattenObservation(FilterObservation(env, list(env.obs_keys)))
|
||||
assert wrapped_env.observation_space.shape == flat_shape
|
||||
|
||||
assert wrapped_env.observation_space.dtype == np.float32
|
||||
@@ -114,7 +114,7 @@ class TestNestedDictWrapper:
|
||||
@pytest.mark.parametrize("observation_space, flat_shape", NESTED_DICT_TEST_CASES)
|
||||
def test_nested_dicts_ravel(self, observation_space, flat_shape):
|
||||
env = FakeEnvironment(observation_space=observation_space)
|
||||
wrapped_env = FlattenObservation(FilterObservation(env, env.obs_keys))
|
||||
wrapped_env = FlattenObservation(FilterObservation(env, list(env.obs_keys)))
|
||||
obs, info = wrapped_env.reset()
|
||||
assert obs.shape == wrapped_env.observation_space.shape
|
||||
assert isinstance(info, dict)
|
||||
|
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import gymnasium as gym
|
||||
|
||||
|
||||
def has_wrapper(wrapped_env: gym.Env, wrapper_type: type) -> bool:
|
||||
def has_wrapper(wrapped_env: gym.Env, wrapper_type: type[gym.Wrapper]) -> bool:
|
||||
while isinstance(wrapped_env, gym.Wrapper):
|
||||
if isinstance(wrapped_env, wrapper_type):
|
||||
return True
|
||||
|
Reference in New Issue
Block a user