mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-01 06:07:08 +00:00
Update EnvSpec
and make
to support reproducing the "whole" environment spec including wrappers (#292)
Co-authored-by: will <will2346@live.co.uk> Co-authored-by: Will Dudley <14932240+WillDudley@users.noreply.github.com> Co-authored-by: Mark Towers <marktowers@Marks-MacBook-Pro.local>
This commit is contained in:
@@ -1,12 +1,13 @@
|
|||||||
"""Core API for Environment, Wrapper, ActionWrapper, RewardWrapper and ObservationWrapper."""
|
"""Core API for Environment, Wrapper, ActionWrapper, RewardWrapper and ObservationWrapper."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from copy import deepcopy
|
||||||
from typing import TYPE_CHECKING, Any, Generic, SupportsFloat, TypeVar
|
from typing import TYPE_CHECKING, Any, Generic, SupportsFloat, TypeVar
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from gymnasium import spaces
|
from gymnasium import spaces
|
||||||
from gymnasium.utils import seeding
|
from gymnasium.utils import RecordConstructorArgs, seeding
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -277,8 +278,32 @@ class Wrapper(
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def spec(self) -> EnvSpec | None:
|
def spec(self) -> EnvSpec | None:
|
||||||
"""Returns the :attr:`Env` :attr:`spec` attribute."""
|
"""Returns the :attr:`Env` :attr:`spec` attribute with the `WrapperSpec` if the wrapper inherits from `EzPickle`."""
|
||||||
return self.env.spec
|
env_spec = self.env.spec
|
||||||
|
|
||||||
|
if env_spec is not None:
|
||||||
|
from gymnasium.envs.registration import WrapperSpec
|
||||||
|
|
||||||
|
# See if the wrapper inherits from `RecordConstructorArgs` then add the kwargs otherwise use `None` for the wrapper kwargs. This will raise an error in `make`
|
||||||
|
if isinstance(self, RecordConstructorArgs):
|
||||||
|
kwargs = getattr(self, "_saved_kwargs")
|
||||||
|
if "env" in kwargs:
|
||||||
|
kwargs = deepcopy(kwargs)
|
||||||
|
kwargs.pop("env")
|
||||||
|
else:
|
||||||
|
kwargs = None
|
||||||
|
|
||||||
|
wrapper_spec = WrapperSpec(
|
||||||
|
name=self.class_name(),
|
||||||
|
entry_point=f"{self.__module__}:{type(self).__name__}",
|
||||||
|
kwargs=kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# to avoid reference issues we deepcopy the prior environments spec and add the new information
|
||||||
|
env_spec = deepcopy(env_spec)
|
||||||
|
env_spec.applied_wrappers += (wrapper_spec,)
|
||||||
|
|
||||||
|
return env_spec
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def class_name(cls) -> str:
|
def class_name(cls) -> str:
|
||||||
@@ -409,7 +434,7 @@ class ObservationWrapper(Wrapper[WrapperObsType, ActType, ObsType, ActType]):
|
|||||||
|
|
||||||
def __init__(self, env: Env[ObsType, ActType]):
|
def __init__(self, env: Env[ObsType, ActType]):
|
||||||
"""Constructor for the observation wrapper."""
|
"""Constructor for the observation wrapper."""
|
||||||
super().__init__(env)
|
Wrapper.__init__(self, env)
|
||||||
|
|
||||||
def reset(
|
def reset(
|
||||||
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
||||||
@@ -449,7 +474,7 @@ class RewardWrapper(Wrapper[ObsType, ActType, ObsType, ActType]):
|
|||||||
|
|
||||||
def __init__(self, env: Env[ObsType, ActType]):
|
def __init__(self, env: Env[ObsType, ActType]):
|
||||||
"""Constructor for the Reward wrapper."""
|
"""Constructor for the Reward wrapper."""
|
||||||
super().__init__(env)
|
Wrapper.__init__(self, env)
|
||||||
|
|
||||||
def step(
|
def step(
|
||||||
self, action: ActType
|
self, action: ActType
|
||||||
@@ -485,7 +510,7 @@ class ActionWrapper(Wrapper[ObsType, WrapperActType, ObsType, ActType]):
|
|||||||
|
|
||||||
def __init__(self, env: Env[ObsType, ActType]):
|
def __init__(self, env: Env[ObsType, ActType]):
|
||||||
"""Constructor for the action wrapper."""
|
"""Constructor for the action wrapper."""
|
||||||
super().__init__(env)
|
Wrapper.__init__(self, env)
|
||||||
|
|
||||||
def step(
|
def step(
|
||||||
self, action: WrapperActType
|
self, action: WrapperActType
|
||||||
|
@@ -3,9 +3,11 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import copy
|
import copy
|
||||||
|
import dataclasses
|
||||||
import difflib
|
import difflib
|
||||||
import importlib
|
import importlib
|
||||||
import importlib.util
|
import importlib.util
|
||||||
|
import json
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
@@ -17,13 +19,13 @@ from gymnasium import Env, Wrapper, error, logger
|
|||||||
from gymnasium.experimental.vector import AsyncVectorEnv, SyncVectorEnv, VectorEnv
|
from gymnasium.experimental.vector import AsyncVectorEnv, SyncVectorEnv, VectorEnv
|
||||||
from gymnasium.wrappers import (
|
from gymnasium.wrappers import (
|
||||||
AutoResetWrapper,
|
AutoResetWrapper,
|
||||||
|
EnvCompatibility,
|
||||||
HumanRendering,
|
HumanRendering,
|
||||||
OrderEnforcing,
|
OrderEnforcing,
|
||||||
|
PassiveEnvChecker,
|
||||||
RenderCollection,
|
RenderCollection,
|
||||||
TimeLimit,
|
TimeLimit,
|
||||||
)
|
)
|
||||||
from gymnasium.wrappers.compatibility import EnvCompatibility
|
|
||||||
from gymnasium.wrappers.env_checker import PassiveEnvChecker
|
|
||||||
|
|
||||||
|
|
||||||
if sys.version_info < (3, 10):
|
if sys.version_info < (3, 10):
|
||||||
@@ -43,11 +45,14 @@ ENV_ID_RE = re.compile(
|
|||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"EnvSpec",
|
|
||||||
"registry",
|
"registry",
|
||||||
"current_namespace",
|
"current_namespace",
|
||||||
|
"EnvSpec",
|
||||||
|
"WrapperSpec",
|
||||||
|
# Functions
|
||||||
"register",
|
"register",
|
||||||
"make",
|
"make",
|
||||||
|
"make_vec",
|
||||||
"spec",
|
"spec",
|
||||||
"pprint_registry",
|
"pprint_registry",
|
||||||
]
|
]
|
||||||
@@ -67,6 +72,20 @@ class VectorEnvCreator(Protocol):
|
|||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class WrapperSpec:
|
||||||
|
"""A specification for recording wrapper configs.
|
||||||
|
|
||||||
|
* name: The name of the wrapper.
|
||||||
|
* entry_point: The location of the wrapper to create from.
|
||||||
|
* kwargs: Additional keyword arguments passed to the wrapper. If the wrapper doesn't inherit from EzPickle then this is ``None``
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
entry_point: str
|
||||||
|
kwargs: dict[str, Any] | None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EnvSpec:
|
class EnvSpec:
|
||||||
"""A specification for creating environments with :meth:`gymnasium.make`.
|
"""A specification for creating environments with :meth:`gymnasium.make`.
|
||||||
@@ -80,6 +99,7 @@ class EnvSpec:
|
|||||||
* **autoreset**: If to automatically reset the environment on episode end
|
* **autoreset**: If to automatically reset the environment on episode end
|
||||||
* **disable_env_checker**: If to disable the environment checker wrapper in :meth:`gymnasium.make`, by default False (runs the environment checker)
|
* **disable_env_checker**: If to disable the environment checker wrapper in :meth:`gymnasium.make`, by default False (runs the environment checker)
|
||||||
* **kwargs**: Additional keyword arguments passed to the environment during initialisation
|
* **kwargs**: Additional keyword arguments passed to the environment during initialisation
|
||||||
|
* **applied_wrappers**: A tuple of applied wrappers (WrapperSpec)
|
||||||
* **vector_entry_point**: The location of the vectorized environment to create from
|
* **vector_entry_point**: The location of the vectorized environment to create from
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -97,27 +117,152 @@ class EnvSpec:
|
|||||||
disable_env_checker: bool = field(default=False)
|
disable_env_checker: bool = field(default=False)
|
||||||
apply_api_compatibility: bool = field(default=False)
|
apply_api_compatibility: bool = field(default=False)
|
||||||
|
|
||||||
|
# Environment arguments
|
||||||
|
kwargs: dict = field(default_factory=dict)
|
||||||
|
|
||||||
# post-init attributes
|
# post-init attributes
|
||||||
namespace: str | None = field(init=False)
|
namespace: str | None = field(init=False)
|
||||||
name: str = field(init=False)
|
name: str = field(init=False)
|
||||||
version: int | None = field(init=False)
|
version: int | None = field(init=False)
|
||||||
|
|
||||||
# Environment arguments
|
# applied wrappers
|
||||||
kwargs: dict = field(default_factory=dict)
|
applied_wrappers: tuple[WrapperSpec, ...] = field(init=False, default_factory=tuple)
|
||||||
|
|
||||||
# Vectorized environment
|
# Vectorized environment entry point
|
||||||
vector_entry_point: str | None = field(default=None)
|
vector_entry_point: VectorEnvCreator | str | None = field(default=None)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
"""Calls after the spec is created to extract the namespace, name and version from the id."""
|
"""Calls after the spec is created to extract the namespace, name and version from the environment id."""
|
||||||
# Initialize namespace, name, version
|
|
||||||
self.namespace, self.name, self.version = parse_env_id(self.id)
|
self.namespace, self.name, self.version = parse_env_id(self.id)
|
||||||
|
|
||||||
def make(self, **kwargs: Any) -> Env:
|
def make(self, **kwargs: Any) -> Env:
|
||||||
"""Calls ``make`` using the environment spec and any keyword arguments."""
|
"""Calls ``make`` using the environment spec and any keyword arguments."""
|
||||||
# For compatibility purposes
|
|
||||||
return make(self, **kwargs)
|
return make(self, **kwargs)
|
||||||
|
|
||||||
|
def to_json(self) -> str:
|
||||||
|
"""Converts the environment spec into a json compatible string.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A jsonifyied string for the environment spec
|
||||||
|
"""
|
||||||
|
env_spec_dict = dataclasses.asdict(self)
|
||||||
|
# As the namespace, name and version are initialised after `init` then we remove the attributes
|
||||||
|
env_spec_dict.pop("namespace")
|
||||||
|
env_spec_dict.pop("name")
|
||||||
|
env_spec_dict.pop("version")
|
||||||
|
|
||||||
|
# To check that the environment spec can be transformed to a json compatible type
|
||||||
|
self._check_can_jsonify(env_spec_dict)
|
||||||
|
|
||||||
|
return json.dumps(env_spec_dict)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _check_can_jsonify(env_spec: dict[str, Any]):
|
||||||
|
"""Warns the user about serialisation failing if the spec contains a callable.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env_spec: An environment or wrapper specification.
|
||||||
|
|
||||||
|
Returns: The specification with lambda functions converted to strings.
|
||||||
|
|
||||||
|
"""
|
||||||
|
spec_name = env_spec["name"] if "name" in env_spec else env_spec["id"]
|
||||||
|
|
||||||
|
for key, value in env_spec.items():
|
||||||
|
if callable(value):
|
||||||
|
ValueError(
|
||||||
|
f"Callable found in {spec_name} for {key} attribute with value={value}. Currently, Gymnasium does not support serialising callables."
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_json(json_env_spec: str) -> EnvSpec:
|
||||||
|
"""Converts a JSON string into a specification stack.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
json_env_spec: A JSON string representing the env specification.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An environment spec
|
||||||
|
"""
|
||||||
|
parsed_env_spec = json.loads(json_env_spec)
|
||||||
|
|
||||||
|
applied_wrapper_specs: list[WrapperSpec] = []
|
||||||
|
for wrapper_spec_json in parsed_env_spec.pop("applied_wrappers"):
|
||||||
|
try:
|
||||||
|
applied_wrapper_specs.append(WrapperSpec(**wrapper_spec_json))
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(
|
||||||
|
f"An issue occurred when trying to make {wrapper_spec_json} a WrapperSpec"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
try:
|
||||||
|
env_spec = EnvSpec(**parsed_env_spec)
|
||||||
|
env_spec.applied_wrappers = tuple(applied_wrapper_specs)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(
|
||||||
|
f"An issue occurred when trying to make {parsed_env_spec} an EnvSpec"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
return env_spec
|
||||||
|
|
||||||
|
def pprint(
|
||||||
|
self,
|
||||||
|
disable_print: bool = False,
|
||||||
|
include_entry_points: bool = False,
|
||||||
|
print_all: bool = False,
|
||||||
|
) -> str | None:
|
||||||
|
"""Pretty prints the environment spec.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
disable_print: If to disable print and return the output
|
||||||
|
include_entry_points: If to include the entry_points in the output
|
||||||
|
print_all: If to print all information, including variables with default values
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
If ``disable_print is True`` a string otherwise ``None``
|
||||||
|
"""
|
||||||
|
output = f"id={self.id}"
|
||||||
|
if print_all or include_entry_points:
|
||||||
|
output += f"\nentry_point={self.entry_point}"
|
||||||
|
|
||||||
|
if print_all or self.reward_threshold is not None:
|
||||||
|
output += f"\nreward_threshold={self.reward_threshold}"
|
||||||
|
if print_all or self.nondeterministic is not False:
|
||||||
|
output += f"\nnondeterministic={self.nondeterministic}"
|
||||||
|
|
||||||
|
if print_all or self.max_episode_steps is not None:
|
||||||
|
output += f"\nmax_episode_steps={self.max_episode_steps}"
|
||||||
|
if print_all or self.order_enforce is not True:
|
||||||
|
output += f"\norder_enforce={self.order_enforce}"
|
||||||
|
if print_all or self.autoreset is not False:
|
||||||
|
output += f"\nautoreset={self.autoreset}"
|
||||||
|
if print_all or self.disable_env_checker is not False:
|
||||||
|
output += f"\ndisable_env_checker={self.disable_env_checker}"
|
||||||
|
if print_all or self.apply_api_compatibility is not False:
|
||||||
|
output += f"\napplied_api_compatibility={self.apply_api_compatibility}"
|
||||||
|
|
||||||
|
if print_all or self.applied_wrappers:
|
||||||
|
wrapper_output: list[str] = []
|
||||||
|
for wrapper_spec in self.applied_wrappers:
|
||||||
|
if include_entry_points:
|
||||||
|
wrapper_output.append(
|
||||||
|
f"\n\tname={wrapper_spec.name}, entry_point={wrapper_spec.entry_point}, kwargs={wrapper_spec.kwargs}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
wrapper_output.append(
|
||||||
|
f"\n\tname={wrapper_spec.name}, kwargs={wrapper_spec.kwargs}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(wrapper_output) == 0:
|
||||||
|
output += "\napplied_wrappers=[]"
|
||||||
|
else:
|
||||||
|
output += f"\napplied_wrappers=[{','.join(wrapper_output)}\n]"
|
||||||
|
|
||||||
|
if disable_print:
|
||||||
|
return output
|
||||||
|
else:
|
||||||
|
print(output)
|
||||||
|
|
||||||
|
|
||||||
# Global registry of environments. Meant to be accessed through `register` and `make`
|
# Global registry of environments. Meant to be accessed through `register` and `make`
|
||||||
registry: dict[str, EnvSpec] = {}
|
registry: dict[str, EnvSpec] = {}
|
||||||
@@ -352,8 +497,12 @@ def _check_metadata(testing_metadata: dict[str, Any]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _find_spec(id: str) -> EnvSpec:
|
def _find_spec(env_id: str) -> EnvSpec:
|
||||||
module, env_name = (None, id) if ":" not in id else id.split(":")
|
# For string id's, load the environment spec from the registry then make the environment spec
|
||||||
|
assert isinstance(env_id, str)
|
||||||
|
|
||||||
|
# The environment name can include an unloaded module in "module:env_name" style
|
||||||
|
module, env_name = (None, env_id) if ":" not in env_id else env_id.split(":")
|
||||||
if module is not None:
|
if module is not None:
|
||||||
try:
|
try:
|
||||||
importlib.import_module(module)
|
importlib.import_module(module)
|
||||||
@@ -391,7 +540,7 @@ def _find_spec(id: str) -> EnvSpec:
|
|||||||
return env_spec
|
return env_spec
|
||||||
|
|
||||||
|
|
||||||
def load_env(name: str) -> EnvCreator:
|
def load_env_creator(name: str) -> EnvCreator | VectorEnvCreator:
|
||||||
"""Loads an environment with name of style ``"(import path):(environment name)"`` and returns the environment creation function, normally the environment class type.
|
"""Loads an environment with name of style ``"(import path):(environment name)"`` and returns the environment creation function, normally the environment class type.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -406,6 +555,161 @@ def load_env(name: str) -> EnvCreator:
|
|||||||
return fn
|
return fn
|
||||||
|
|
||||||
|
|
||||||
|
def _create_from_env_spec(
|
||||||
|
env_spec: EnvSpec,
|
||||||
|
kwargs: dict[str, Any],
|
||||||
|
) -> Env:
|
||||||
|
"""Recreates an environment spec using a list of wrapper specs."""
|
||||||
|
if callable(env_spec.entry_point):
|
||||||
|
env_creator = env_spec.entry_point
|
||||||
|
else:
|
||||||
|
env_creator: EnvCreator = load_env_creator(env_spec.entry_point)
|
||||||
|
|
||||||
|
# Create the environment
|
||||||
|
env: Env = env_creator(**env_spec.kwargs, **kwargs)
|
||||||
|
|
||||||
|
# Set the `EnvSpec` to the environment
|
||||||
|
new_env_spec = copy.deepcopy(env_spec)
|
||||||
|
new_env_spec.applied_wrappers = ()
|
||||||
|
new_env_spec.kwargs.update(kwargs)
|
||||||
|
env.unwrapped.spec = new_env_spec
|
||||||
|
|
||||||
|
# Check if the environment spec
|
||||||
|
assert env.spec is not None # this is for pyright
|
||||||
|
num_prior_wrappers = len(env.spec.applied_wrappers)
|
||||||
|
if env_spec.applied_wrappers[:num_prior_wrappers] != env.spec.applied_wrappers:
|
||||||
|
for env_spec_wrapper_spec, recreated_wrapper_spec in zip(
|
||||||
|
env_spec.applied_wrappers, env.spec.applied_wrappers
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"The environment's wrapper spec {recreated_wrapper_spec} is different from the saved `EnvSpec` applied_wrappers {env_spec_wrapper_spec}"
|
||||||
|
)
|
||||||
|
|
||||||
|
for wrapper_spec in env_spec.applied_wrappers[num_prior_wrappers:]:
|
||||||
|
if wrapper_spec.kwargs is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"{wrapper_spec.name} wrapper does not inherit from `gymnasium.utils.RecordConstructorArgs`, therefore, the wrapper cannot be recreated."
|
||||||
|
)
|
||||||
|
|
||||||
|
env = load_env_creator(wrapper_spec.entry_point)(env=env, **wrapper_spec.kwargs)
|
||||||
|
|
||||||
|
return env
|
||||||
|
|
||||||
|
|
||||||
|
def _create_from_env_id(
|
||||||
|
env_spec: EnvSpec,
|
||||||
|
kwargs: dict[str, Any],
|
||||||
|
max_episode_steps: int | None = None,
|
||||||
|
autoreset: bool = False,
|
||||||
|
apply_api_compatibility: bool | None = None,
|
||||||
|
disable_env_checker: bool | None = None,
|
||||||
|
) -> Env:
|
||||||
|
"""Creates an environment based on the `env_spec` along with wrapper options. See `make` for their meaning."""
|
||||||
|
spec_kwargs = copy.deepcopy(env_spec.kwargs)
|
||||||
|
spec_kwargs.update(kwargs)
|
||||||
|
|
||||||
|
# Load the environment creator
|
||||||
|
if env_spec.entry_point is None:
|
||||||
|
raise error.Error(f"{env_spec.id} registered but entry_point is not specified")
|
||||||
|
elif callable(env_spec.entry_point):
|
||||||
|
env_creator = env_spec.entry_point
|
||||||
|
else:
|
||||||
|
# Assume it's a string
|
||||||
|
env_creator = load_env_creator(env_spec.entry_point)
|
||||||
|
|
||||||
|
# Determine if to use the rendering
|
||||||
|
render_modes: list[str] | None = None
|
||||||
|
if hasattr(env_creator, "metadata"):
|
||||||
|
_check_metadata(env_creator.metadata)
|
||||||
|
render_modes = env_creator.metadata.get("render_modes")
|
||||||
|
mode = spec_kwargs.get("render_mode")
|
||||||
|
apply_human_rendering = False
|
||||||
|
apply_render_collection = False
|
||||||
|
|
||||||
|
# If mode is not valid, try applying HumanRendering/RenderCollection wrappers
|
||||||
|
if mode is not None and render_modes is not None and mode not in render_modes:
|
||||||
|
displayable_modes = {"rgb_array", "rgb_array_list"}.intersection(render_modes)
|
||||||
|
if mode == "human" and len(displayable_modes) > 0:
|
||||||
|
logger.warn(
|
||||||
|
"You are trying to use 'human' rendering for an environment that doesn't natively support it. "
|
||||||
|
"The HumanRendering wrapper is being applied to your environment."
|
||||||
|
)
|
||||||
|
spec_kwargs["render_mode"] = displayable_modes.pop()
|
||||||
|
apply_human_rendering = True
|
||||||
|
elif mode.endswith("_list") and mode[: -len("_list")] in render_modes:
|
||||||
|
spec_kwargs["render_mode"] = mode[: -len("_list")]
|
||||||
|
apply_render_collection = True
|
||||||
|
else:
|
||||||
|
logger.warn(
|
||||||
|
f"The environment is being initialised with render_mode={mode!r} "
|
||||||
|
f"that is not in the possible render_modes ({render_modes})."
|
||||||
|
)
|
||||||
|
|
||||||
|
if apply_api_compatibility or (
|
||||||
|
apply_api_compatibility is None and env_spec.apply_api_compatibility
|
||||||
|
):
|
||||||
|
# If we use the compatibility layer, we treat the render mode explicitly and don't pass it to the env creator
|
||||||
|
render_mode = spec_kwargs.pop("render_mode", None)
|
||||||
|
else:
|
||||||
|
render_mode = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
env = env_creator(**spec_kwargs)
|
||||||
|
except TypeError as e:
|
||||||
|
if (
|
||||||
|
str(e).find("got an unexpected keyword argument 'render_mode'") >= 0
|
||||||
|
and apply_human_rendering
|
||||||
|
):
|
||||||
|
raise error.Error(
|
||||||
|
f"You passed render_mode='human' although {env_spec.id} doesn't implement human-rendering natively. "
|
||||||
|
"Gym tried to apply the HumanRendering wrapper but it looks like your environment is using the old "
|
||||||
|
"rendering API, which is not supported by the HumanRendering wrapper."
|
||||||
|
) from e
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
# Copies the environment creation specification and kwargs to add to the environment specification details
|
||||||
|
env_spec = copy.deepcopy(env_spec)
|
||||||
|
env_spec.kwargs = spec_kwargs
|
||||||
|
env.unwrapped.spec = env_spec
|
||||||
|
|
||||||
|
# Add step API wrapper
|
||||||
|
if apply_api_compatibility is True or (
|
||||||
|
apply_api_compatibility is None and env_spec.apply_api_compatibility is True
|
||||||
|
):
|
||||||
|
env = EnvCompatibility(env, render_mode)
|
||||||
|
|
||||||
|
# Run the environment checker as the lowest level wrapper
|
||||||
|
if disable_env_checker is False or (
|
||||||
|
disable_env_checker is None and env_spec.disable_env_checker is False
|
||||||
|
):
|
||||||
|
env = PassiveEnvChecker(env)
|
||||||
|
|
||||||
|
# Add the order enforcing wrapper
|
||||||
|
if env_spec.order_enforce:
|
||||||
|
env = OrderEnforcing(env)
|
||||||
|
|
||||||
|
# Add the time limit wrapper
|
||||||
|
if max_episode_steps is not None:
|
||||||
|
assert env.unwrapped.spec is not None # for pyright
|
||||||
|
env.unwrapped.spec.max_episode_steps = max_episode_steps
|
||||||
|
env = TimeLimit(env, max_episode_steps)
|
||||||
|
elif env_spec.max_episode_steps is not None:
|
||||||
|
env = TimeLimit(env, env_spec.max_episode_steps)
|
||||||
|
|
||||||
|
# Add the auto-reset wrapper
|
||||||
|
if autoreset:
|
||||||
|
env = AutoResetWrapper(env)
|
||||||
|
|
||||||
|
# Add human rendering wrapper
|
||||||
|
if apply_human_rendering:
|
||||||
|
env = HumanRendering(env)
|
||||||
|
elif apply_render_collection:
|
||||||
|
env = RenderCollection(env)
|
||||||
|
|
||||||
|
return env
|
||||||
|
|
||||||
|
|
||||||
def load_plugin_envs(entry_point: str = "gymnasium.envs"):
|
def load_plugin_envs(entry_point: str = "gymnasium.envs"):
|
||||||
"""Load modules (plugins) using the gymnasium entry points in order to register external module's environments on ``import gymnasium``.
|
"""Load modules (plugins) using the gymnasium entry points in order to register external module's environments on ``import gymnasium``.
|
||||||
|
|
||||||
@@ -437,10 +741,8 @@ def load_plugin_envs(entry_point: str = "gymnasium.envs"):
|
|||||||
|
|
||||||
context = namespace(plugin.name)
|
context = namespace(plugin.name)
|
||||||
if plugin.name.startswith("__") and plugin.name.endswith("__"):
|
if plugin.name.startswith("__") and plugin.name.endswith("__"):
|
||||||
# `__internal__` is an artifact of the plugin system when
|
# `__internal__` is an artifact of the plugin system when the root namespace had an allow-list.
|
||||||
# the root namespace had an allow-list. The allow-list is now
|
# The allow-list is now removed and plugins can register environments in the root namespace with the `__root__` magic key.
|
||||||
# removed and plugins can register environments in the root
|
|
||||||
# namespace with the `__root__` magic key.
|
|
||||||
if plugin.name == "__root__" or plugin.name == "__internal__":
|
if plugin.name == "__root__" or plugin.name == "__internal__":
|
||||||
context = contextlib.nullcontext()
|
context = contextlib.nullcontext()
|
||||||
else:
|
else:
|
||||||
@@ -533,9 +835,9 @@ def register(
|
|||||||
order_enforce=order_enforce,
|
order_enforce=order_enforce,
|
||||||
autoreset=autoreset,
|
autoreset=autoreset,
|
||||||
disable_env_checker=disable_env_checker,
|
disable_env_checker=disable_env_checker,
|
||||||
|
**kwargs,
|
||||||
apply_api_compatibility=apply_api_compatibility,
|
apply_api_compatibility=apply_api_compatibility,
|
||||||
vector_entry_point=vector_entry_point,
|
vector_entry_point=vector_entry_point,
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
_check_spec_register(new_spec)
|
_check_spec_register(new_spec)
|
||||||
|
|
||||||
@@ -576,116 +878,47 @@ def make(
|
|||||||
Error: If the ``id`` doesn't exist in the :attr:`registry`
|
Error: If the ``id`` doesn't exist in the :attr:`registry`
|
||||||
"""
|
"""
|
||||||
if isinstance(id, EnvSpec):
|
if isinstance(id, EnvSpec):
|
||||||
env_spec = id
|
if hasattr(id, "applied_wrappers") and id.applied_wrappers is not None:
|
||||||
|
if max_episode_steps is not None:
|
||||||
|
logger.warn(
|
||||||
|
f"For `gymnasium.make` with an `EnvSpec`, the `max_episode_step` parameter is ignored, use `gym.make({id.id}, max_episode_steps={max_episode_steps})` and any additional wrappers"
|
||||||
|
)
|
||||||
|
if autoreset is True:
|
||||||
|
logger.warn(
|
||||||
|
f"For `gymnasium.make` with an `EnvSpec`, the `autoreset` parameter is ignored, use `gym.make({id.id}, autoreset={autoreset})` and any additional wrappers"
|
||||||
|
)
|
||||||
|
if apply_api_compatibility is not None:
|
||||||
|
logger.warn(
|
||||||
|
f"For `gymnasium.make` with an `EnvSpec`, the `apply_api_compatibility` parameter is ignored, use `gym.make({id.id}, apply_api_compatibility={apply_api_compatibility})` and any additional wrappers"
|
||||||
|
)
|
||||||
|
if disable_env_checker is not None:
|
||||||
|
logger.warn(
|
||||||
|
f"For `gymnasium.make` with an `EnvSpec`, the `disable_env_checker` parameter is ignored, use `gym.make({id.id}, disable_env_checker={disable_env_checker})` and any additional wrappers"
|
||||||
|
)
|
||||||
|
|
||||||
|
return _create_from_env_spec(
|
||||||
|
id,
|
||||||
|
kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"The EnvSpec used does not contain `applied_wrappers` parameters or is `None`. Expected to be a tuple, actually {id}."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
|
# For string id's, load the environment spec from the registry then make the environment spec
|
||||||
|
assert isinstance(id, str)
|
||||||
|
|
||||||
# The environment name can include an unloaded module in "module:env_name" style
|
# The environment name can include an unloaded module in "module:env_name" style
|
||||||
env_spec = _find_spec(id)
|
env_spec = _find_spec(id)
|
||||||
|
|
||||||
assert isinstance(
|
return _create_from_env_id(
|
||||||
env_spec, EnvSpec
|
env_spec,
|
||||||
), f"We expected to collect an `EnvSpec`, actually collected a {type(env_spec)}"
|
kwargs,
|
||||||
# Extract the spec kwargs and append the make kwargs
|
max_episode_steps=max_episode_steps,
|
||||||
spec_kwargs = env_spec.kwargs.copy()
|
autoreset=autoreset,
|
||||||
spec_kwargs.update(kwargs)
|
apply_api_compatibility=apply_api_compatibility,
|
||||||
|
disable_env_checker=disable_env_checker,
|
||||||
# Load the environment creator
|
)
|
||||||
if env_spec.entry_point is None:
|
|
||||||
raise error.Error(f"{env_spec.id} registered but entry_point is not specified")
|
|
||||||
elif callable(env_spec.entry_point):
|
|
||||||
env_creator = env_spec.entry_point
|
|
||||||
else:
|
|
||||||
# Assume it's a string
|
|
||||||
env_creator = load_env(env_spec.entry_point)
|
|
||||||
|
|
||||||
# Determine if to use the rendering
|
|
||||||
render_modes: list[str] | None = None
|
|
||||||
if hasattr(env_creator, "metadata"):
|
|
||||||
_check_metadata(env_creator.metadata)
|
|
||||||
render_modes = env_creator.metadata.get("render_modes")
|
|
||||||
mode = spec_kwargs.get("render_mode")
|
|
||||||
apply_human_rendering = False
|
|
||||||
apply_render_collection = False
|
|
||||||
|
|
||||||
# If mode is not valid, try applying HumanRendering/RenderCollection wrappers
|
|
||||||
if mode is not None and render_modes is not None and mode not in render_modes:
|
|
||||||
displayable_modes = {"rgb_array", "rgb_array_list"}.intersection(render_modes)
|
|
||||||
if mode == "human" and len(displayable_modes) > 0:
|
|
||||||
logger.warn(
|
|
||||||
"You are trying to use 'human' rendering for an environment that doesn't natively support it. "
|
|
||||||
"The HumanRendering wrapper is being applied to your environment."
|
|
||||||
)
|
|
||||||
spec_kwargs["render_mode"] = displayable_modes.pop()
|
|
||||||
apply_human_rendering = True
|
|
||||||
elif mode.endswith("_list") and mode[: -len("_list")] in render_modes:
|
|
||||||
spec_kwargs["render_mode"] = mode[: -len("_list")]
|
|
||||||
apply_render_collection = True
|
|
||||||
else:
|
|
||||||
logger.warn(
|
|
||||||
f"The environment is being initialised with render_mode={mode!r} "
|
|
||||||
f"that is not in the possible render_modes ({render_modes})."
|
|
||||||
)
|
|
||||||
|
|
||||||
if apply_api_compatibility or (
|
|
||||||
apply_api_compatibility is None and env_spec.apply_api_compatibility
|
|
||||||
):
|
|
||||||
# If we use the compatibility layer, we treat the render mode explicitly and don't pass it to the env creator
|
|
||||||
render_mode = spec_kwargs.pop("render_mode", None)
|
|
||||||
else:
|
|
||||||
render_mode = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
env = env_creator(**spec_kwargs)
|
|
||||||
except TypeError as e:
|
|
||||||
if (
|
|
||||||
str(e).find("got an unexpected keyword argument 'render_mode'") >= 0
|
|
||||||
and apply_human_rendering
|
|
||||||
):
|
|
||||||
raise error.Error(
|
|
||||||
f"You passed render_mode='human' although {id} doesn't implement human-rendering natively. "
|
|
||||||
"Gym tried to apply the HumanRendering wrapper but it looks like your environment is using the old "
|
|
||||||
"rendering API, which is not supported by the HumanRendering wrapper."
|
|
||||||
) from e
|
|
||||||
else:
|
|
||||||
raise e
|
|
||||||
|
|
||||||
# Copies the environment creation specification and kwargs to add to the environment specification details
|
|
||||||
env_spec = copy.deepcopy(env_spec)
|
|
||||||
env_spec.kwargs = spec_kwargs
|
|
||||||
env.unwrapped.spec = env_spec
|
|
||||||
|
|
||||||
# Add step API wrapper
|
|
||||||
if apply_api_compatibility is True or (
|
|
||||||
apply_api_compatibility is None and env_spec.apply_api_compatibility is True
|
|
||||||
):
|
|
||||||
env = EnvCompatibility(env, render_mode)
|
|
||||||
|
|
||||||
# Run the environment checker as the lowest level wrapper
|
|
||||||
if disable_env_checker is False or (
|
|
||||||
disable_env_checker is None and env_spec.disable_env_checker is False
|
|
||||||
):
|
|
||||||
env = PassiveEnvChecker(env)
|
|
||||||
|
|
||||||
# Add the order enforcing wrapper
|
|
||||||
if env_spec.order_enforce:
|
|
||||||
env = OrderEnforcing(env)
|
|
||||||
|
|
||||||
# Add the time limit wrapper
|
|
||||||
if max_episode_steps is not None:
|
|
||||||
env = TimeLimit(env, max_episode_steps)
|
|
||||||
elif env_spec.max_episode_steps is not None:
|
|
||||||
env = TimeLimit(env, env_spec.max_episode_steps)
|
|
||||||
|
|
||||||
# Add the autoreset wrapper
|
|
||||||
if autoreset:
|
|
||||||
env = AutoResetWrapper(env)
|
|
||||||
|
|
||||||
# Add human rendering wrapper
|
|
||||||
if apply_human_rendering:
|
|
||||||
env = HumanRendering(env)
|
|
||||||
elif apply_render_collection:
|
|
||||||
env = RenderCollection(env)
|
|
||||||
|
|
||||||
return env
|
|
||||||
|
|
||||||
|
|
||||||
def make_vec(
|
def make_vec(
|
||||||
@@ -752,7 +985,7 @@ def make_vec(
|
|||||||
env_creator = entry_point
|
env_creator = entry_point
|
||||||
else:
|
else:
|
||||||
# Assume it's a string
|
# Assume it's a string
|
||||||
env_creator = load_env(entry_point)
|
env_creator = load_env_creator(entry_point)
|
||||||
|
|
||||||
def _create_env():
|
def _create_env():
|
||||||
# Env creator for use with sync and async modes
|
# Env creator for use with sync and async modes
|
||||||
|
@@ -11,7 +11,7 @@ except ImportError:
|
|||||||
cv2 = None
|
cv2 = None
|
||||||
|
|
||||||
|
|
||||||
class AtariPreprocessingV0(gym.Wrapper):
|
class AtariPreprocessingV0(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||||
"""Atari 2600 preprocessing wrapper.
|
"""Atari 2600 preprocessing wrapper.
|
||||||
|
|
||||||
This class follows the guidelines in Machado et al. (2018),
|
This class follows the guidelines in Machado et al. (2018),
|
||||||
@@ -60,7 +60,18 @@ class AtariPreprocessingV0(gym.Wrapper):
|
|||||||
DependencyNotInstalled: opencv-python package not installed
|
DependencyNotInstalled: opencv-python package not installed
|
||||||
ValueError: Disable frame-skipping in the original env
|
ValueError: Disable frame-skipping in the original env
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
gym.utils.RecordConstructorArgs.__init__(
|
||||||
|
self,
|
||||||
|
noop_max=noop_max,
|
||||||
|
frame_skip=frame_skip,
|
||||||
|
screen_size=screen_size,
|
||||||
|
terminal_on_life_loss=terminal_on_life_loss,
|
||||||
|
grayscale_obs=grayscale_obs,
|
||||||
|
grayscale_newaxis=grayscale_newaxis,
|
||||||
|
scale_obs=scale_obs,
|
||||||
|
)
|
||||||
|
gym.Wrapper.__init__(self, env)
|
||||||
|
|
||||||
if cv2 is None:
|
if cv2 is None:
|
||||||
raise gym.error.DependencyNotInstalled(
|
raise gym.error.DependencyNotInstalled(
|
||||||
"opencv-python package not installed, run `pip install gymnasium[other]` to get dependencies for atari"
|
"opencv-python package not installed, run `pip install gymnasium[other]` to get dependencies for atari"
|
||||||
|
@@ -14,7 +14,6 @@ from typing import Any, SupportsFloat
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium import Env
|
|
||||||
from gymnasium.core import ActType, ObsType, RenderFrame
|
from gymnasium.core import ActType, ObsType, RenderFrame
|
||||||
from gymnasium.error import ResetNeeded
|
from gymnasium.error import ResetNeeded
|
||||||
from gymnasium.utils.passive_env_checker import (
|
from gymnasium.utils.passive_env_checker import (
|
||||||
@@ -26,7 +25,9 @@ from gymnasium.utils.passive_env_checker import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class AutoresetV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
class AutoresetV0(
|
||||||
|
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
|
||||||
|
):
|
||||||
"""A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`."""
|
"""A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`."""
|
||||||
|
|
||||||
def __init__(self, env: gym.Env[ObsType, ActType]):
|
def __init__(self, env: gym.Env[ObsType, ActType]):
|
||||||
@@ -35,7 +36,9 @@ class AutoresetV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
|||||||
Args:
|
Args:
|
||||||
env (gym.Env): The environment to apply the wrapper
|
env (gym.Env): The environment to apply the wrapper
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
gym.utils.RecordConstructorArgs.__init__(self)
|
||||||
|
gym.Wrapper.__init__(self, env)
|
||||||
|
|
||||||
self._episode_ended: bool = False
|
self._episode_ended: bool = False
|
||||||
self._reset_options: dict[str, Any] | None = None
|
self._reset_options: dict[str, Any] | None = None
|
||||||
|
|
||||||
@@ -68,12 +71,15 @@ class AutoresetV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
|||||||
return super().reset(seed=seed, options=self._reset_options)
|
return super().reset(seed=seed, options=self._reset_options)
|
||||||
|
|
||||||
|
|
||||||
class PassiveEnvCheckerV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
class PassiveEnvCheckerV0(
|
||||||
|
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
|
||||||
|
):
|
||||||
"""A passive environment checker wrapper that surrounds the step, reset and render functions to check they follow the gymnasium API."""
|
"""A passive environment checker wrapper that surrounds the step, reset and render functions to check they follow the gymnasium API."""
|
||||||
|
|
||||||
def __init__(self, env: Env[ObsType, ActType]):
|
def __init__(self, env: gym.Env[ObsType, ActType]):
|
||||||
"""Initialises the wrapper with the environments, run the observation and action space tests."""
|
"""Initialises the wrapper with the environments, run the observation and action space tests."""
|
||||||
super().__init__(env)
|
gym.utils.RecordConstructorArgs.__init__(self)
|
||||||
|
gym.Wrapper.__init__(self, env)
|
||||||
|
|
||||||
assert hasattr(
|
assert hasattr(
|
||||||
env, "action_space"
|
env, "action_space"
|
||||||
@@ -117,7 +123,9 @@ class PassiveEnvCheckerV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
|||||||
return self.env.render()
|
return self.env.render()
|
||||||
|
|
||||||
|
|
||||||
class OrderEnforcingV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
class OrderEnforcingV0(
|
||||||
|
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
|
||||||
|
):
|
||||||
"""A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`.
|
"""A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
@@ -150,7 +158,11 @@ class OrderEnforcingV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
|||||||
env: The environment to wrap
|
env: The environment to wrap
|
||||||
disable_render_order_enforcing: If to disable render order enforcing
|
disable_render_order_enforcing: If to disable render order enforcing
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
gym.utils.RecordConstructorArgs.__init__(
|
||||||
|
self, disable_render_order_enforcing=disable_render_order_enforcing
|
||||||
|
)
|
||||||
|
gym.Wrapper.__init__(self, env)
|
||||||
|
|
||||||
self._has_reset: bool = False
|
self._has_reset: bool = False
|
||||||
self._disable_render_order_enforcing: bool = disable_render_order_enforcing
|
self._disable_render_order_enforcing: bool = disable_render_order_enforcing
|
||||||
|
|
||||||
@@ -182,7 +194,9 @@ class OrderEnforcingV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
|||||||
return self._has_reset
|
return self._has_reset
|
||||||
|
|
||||||
|
|
||||||
class RecordEpisodeStatisticsV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
class RecordEpisodeStatisticsV0(
|
||||||
|
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
|
||||||
|
):
|
||||||
"""This wrapper will keep track of cumulative rewards and episode lengths.
|
"""This wrapper will keep track of cumulative rewards and episode lengths.
|
||||||
|
|
||||||
At the end of an episode, the statistics of the episode will be added to ``info``
|
At the end of an episode, the statistics of the episode will be added to ``info``
|
||||||
@@ -226,7 +240,7 @@ class RecordEpisodeStatisticsV0(gym.Wrapper[ObsType, ActType, ObsType, ActType])
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
env: Env[ObsType, ActType],
|
env: gym.Env[ObsType, ActType],
|
||||||
buffer_length: int | None = 100,
|
buffer_length: int | None = 100,
|
||||||
stats_key: str = "episode",
|
stats_key: str = "episode",
|
||||||
):
|
):
|
||||||
@@ -237,7 +251,8 @@ class RecordEpisodeStatisticsV0(gym.Wrapper[ObsType, ActType, ObsType, ActType])
|
|||||||
buffer_length: The size of the buffers :attr:`return_queue` and :attr:`length_queue`
|
buffer_length: The size of the buffers :attr:`return_queue` and :attr:`length_queue`
|
||||||
stats_key: The info key for the episode statistics
|
stats_key: The info key for the episode statistics
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
gym.utils.RecordConstructorArgs.__init__(self)
|
||||||
|
gym.Wrapper.__init__(self, env)
|
||||||
|
|
||||||
self._stats_key = stats_key
|
self._stats_key = stats_key
|
||||||
|
|
||||||
|
@@ -8,7 +8,7 @@ from typing import Any, Iterable, Mapping, SupportsFloat
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from gymnasium import Env, Wrapper
|
import gymnasium as gym
|
||||||
from gymnasium.core import ActType, ObsType, RenderFrame, WrapperActType, WrapperObsType
|
from gymnasium.core import ActType, ObsType, RenderFrame, WrapperActType, WrapperObsType
|
||||||
from gymnasium.error import DependencyNotInstalled
|
from gymnasium.error import DependencyNotInstalled
|
||||||
|
|
||||||
@@ -92,7 +92,10 @@ if jnp is not None:
|
|||||||
return type(value)(jax_to_numpy(v) for v in value)
|
return type(value)(jax_to_numpy(v) for v in value)
|
||||||
|
|
||||||
|
|
||||||
class JaxToNumpyV0(Wrapper[WrapperObsType, WrapperActType, ObsType, ActType]):
|
class JaxToNumpyV0(
|
||||||
|
gym.Wrapper[WrapperObsType, WrapperActType, ObsType, ActType],
|
||||||
|
gym.utils.RecordConstructorArgs,
|
||||||
|
):
|
||||||
"""Wraps a jax environment so that it can be interacted with through numpy arrays.
|
"""Wraps a jax environment so that it can be interacted with through numpy arrays.
|
||||||
|
|
||||||
Actions must be provided as numpy arrays and observations will be returned as numpy arrays.
|
Actions must be provided as numpy arrays and observations will be returned as numpy arrays.
|
||||||
@@ -102,7 +105,7 @@ class JaxToNumpyV0(Wrapper[WrapperObsType, WrapperActType, ObsType, ActType]):
|
|||||||
The reason for this is jax does not support non-array values, therefore numpy ``int_32(5) -> DeviceArray([5], dtype=jnp.int23)``
|
The reason for this is jax does not support non-array values, therefore numpy ``int_32(5) -> DeviceArray([5], dtype=jnp.int23)``
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, env: Env[ObsType, ActType]):
|
def __init__(self, env: gym.Env[ObsType, ActType]):
|
||||||
"""Wraps an environment such that the input and outputs are numpy arrays.
|
"""Wraps an environment such that the input and outputs are numpy arrays.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -112,7 +115,8 @@ class JaxToNumpyV0(Wrapper[WrapperObsType, WrapperActType, ObsType, ActType]):
|
|||||||
raise DependencyNotInstalled(
|
raise DependencyNotInstalled(
|
||||||
"jax is not installed, run `pip install gymnasium[jax]`"
|
"jax is not installed, run `pip install gymnasium[jax]`"
|
||||||
)
|
)
|
||||||
super().__init__(env)
|
gym.utils.RecordConstructorArgs.__init__(self)
|
||||||
|
gym.Wrapper.__init__(self, env)
|
||||||
|
|
||||||
def step(
|
def step(
|
||||||
self, action: WrapperActType
|
self, action: WrapperActType
|
||||||
|
@@ -14,7 +14,7 @@ import numbers
|
|||||||
from collections import abc
|
from collections import abc
|
||||||
from typing import Any, Iterable, Mapping, SupportsFloat, Union
|
from typing import Any, Iterable, Mapping, SupportsFloat, Union
|
||||||
|
|
||||||
from gymnasium import Env, Wrapper
|
import gymnasium as gym
|
||||||
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
|
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
|
||||||
from gymnasium.error import DependencyNotInstalled
|
from gymnasium.error import DependencyNotInstalled
|
||||||
from gymnasium.experimental.wrappers.conversion.jax_to_numpy import jax_to_numpy
|
from gymnasium.experimental.wrappers.conversion.jax_to_numpy import jax_to_numpy
|
||||||
@@ -131,7 +131,7 @@ if torch is not None and jnp is not None:
|
|||||||
return type(value)(jax_to_torch(v, device) for v in value)
|
return type(value)(jax_to_torch(v, device) for v in value)
|
||||||
|
|
||||||
|
|
||||||
class JaxToTorchV0(Wrapper):
|
class JaxToTorchV0(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||||
"""Wraps a jax-based environment so that it can be interacted with through PyTorch Tensors.
|
"""Wraps a jax-based environment so that it can be interacted with through PyTorch Tensors.
|
||||||
|
|
||||||
Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors.
|
Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors.
|
||||||
@@ -140,7 +140,7 @@ class JaxToTorchV0(Wrapper):
|
|||||||
For ``rendered`` this is returned as a NumPy array not a pytorch Tensor.
|
For ``rendered`` this is returned as a NumPy array not a pytorch Tensor.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, env: Env, device: Device | None = None):
|
def __init__(self, env: gym.Env, device: Device | None = None):
|
||||||
"""Wrapper class to change inputs and outputs of environment to PyTorch tensors.
|
"""Wrapper class to change inputs and outputs of environment to PyTorch tensors.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -156,7 +156,9 @@ class JaxToTorchV0(Wrapper):
|
|||||||
"jax is not installed, run `pip install gymnasium[jax]`"
|
"jax is not installed, run `pip install gymnasium[jax]`"
|
||||||
)
|
)
|
||||||
|
|
||||||
super().__init__(env)
|
gym.utils.RecordConstructorArgs.__init__(self, device=device)
|
||||||
|
gym.Wrapper.__init__(self, env)
|
||||||
|
|
||||||
self.device: Device | None = device
|
self.device: Device | None = device
|
||||||
|
|
||||||
def step(
|
def step(
|
||||||
|
@@ -8,7 +8,7 @@ from typing import Any, Iterable, Mapping, SupportsFloat, Union
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from gymnasium import Env, Wrapper
|
import gymnasium as gym
|
||||||
from gymnasium.core import WrapperActType, WrapperObsType
|
from gymnasium.core import WrapperActType, WrapperObsType
|
||||||
from gymnasium.error import DependencyNotInstalled
|
from gymnasium.error import DependencyNotInstalled
|
||||||
|
|
||||||
@@ -94,7 +94,7 @@ if torch is not None:
|
|||||||
return type(value)(numpy_to_torch(v, device) for v in value)
|
return type(value)(numpy_to_torch(v, device) for v in value)
|
||||||
|
|
||||||
|
|
||||||
class NumpyToTorchV0(Wrapper):
|
class NumpyToTorchV0(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||||
"""Wraps a numpy-based environment so that it can be interacted with through PyTorch Tensors.
|
"""Wraps a numpy-based environment so that it can be interacted with through PyTorch Tensors.
|
||||||
|
|
||||||
Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors.
|
Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors.
|
||||||
@@ -103,7 +103,7 @@ class NumpyToTorchV0(Wrapper):
|
|||||||
For ``rendered`` this is returned as a NumPy array not a pytorch Tensor.
|
For ``rendered`` this is returned as a NumPy array not a pytorch Tensor.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, env: Env, device: Device | None = None):
|
def __init__(self, env: gym.Env, device: Device | None = None):
|
||||||
"""Wrapper class to change inputs and outputs of environment to PyTorch tensors.
|
"""Wrapper class to change inputs and outputs of environment to PyTorch tensors.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -115,7 +115,9 @@ class NumpyToTorchV0(Wrapper):
|
|||||||
"torch is not installed, run `pip install torch`"
|
"torch is not installed, run `pip install torch`"
|
||||||
)
|
)
|
||||||
|
|
||||||
super().__init__(env)
|
gym.utils.RecordConstructorArgs.__init__(self, device=device)
|
||||||
|
gym.Wrapper.__init__(self, env)
|
||||||
|
|
||||||
self.device: Device | None = device
|
self.device: Device | None = device
|
||||||
|
|
||||||
def step(
|
def step(
|
||||||
|
@@ -20,7 +20,9 @@ from gymnasium.core import ActType, ObsType, WrapperActType
|
|||||||
from gymnasium.spaces import Box, Space
|
from gymnasium.spaces import Box, Space
|
||||||
|
|
||||||
|
|
||||||
class LambdaActionV0(gym.ActionWrapper[ObsType, WrapperActType, ActType]):
|
class LambdaActionV0(
|
||||||
|
gym.ActionWrapper[ObsType, WrapperActType, ActType], gym.utils.RecordConstructorArgs
|
||||||
|
):
|
||||||
"""A wrapper that provides a function to modify the action passed to :meth:`step`."""
|
"""A wrapper that provides a function to modify the action passed to :meth:`step`."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -36,7 +38,11 @@ class LambdaActionV0(gym.ActionWrapper[ObsType, WrapperActType, ActType]):
|
|||||||
func: Function to apply to ``step`` ``action``
|
func: Function to apply to ``step`` ``action``
|
||||||
action_space: The updated action space of the wrapper given the function.
|
action_space: The updated action space of the wrapper given the function.
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
gym.utils.RecordConstructorArgs.__init__(
|
||||||
|
self, func=func, action_space=action_space
|
||||||
|
)
|
||||||
|
gym.Wrapper.__init__(self, env)
|
||||||
|
|
||||||
if action_space is not None:
|
if action_space is not None:
|
||||||
self.action_space = action_space
|
self.action_space = action_space
|
||||||
|
|
||||||
@@ -47,7 +53,9 @@ class LambdaActionV0(gym.ActionWrapper[ObsType, WrapperActType, ActType]):
|
|||||||
return self.func(action)
|
return self.func(action)
|
||||||
|
|
||||||
|
|
||||||
class ClipActionV0(LambdaActionV0[ObsType, WrapperActType, ActType]):
|
class ClipActionV0(
|
||||||
|
LambdaActionV0[ObsType, WrapperActType, ActType], gym.utils.RecordConstructorArgs
|
||||||
|
):
|
||||||
"""Clip the continuous action within the valid :class:`Box` observation space bound.
|
"""Clip the continuous action within the valid :class:`Box` observation space bound.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
@@ -71,10 +79,14 @@ class ClipActionV0(LambdaActionV0[ObsType, WrapperActType, ActType]):
|
|||||||
"""
|
"""
|
||||||
assert isinstance(env.action_space, Box)
|
assert isinstance(env.action_space, Box)
|
||||||
|
|
||||||
super().__init__(
|
gym.utils.RecordConstructorArgs.__init__(self)
|
||||||
env,
|
LambdaActionV0.__init__(
|
||||||
lambda action: jp.clip(action, env.action_space.low, env.action_space.high),
|
self,
|
||||||
Box(
|
env=env,
|
||||||
|
func=lambda action: jp.clip(
|
||||||
|
action, env.action_space.low, env.action_space.high
|
||||||
|
),
|
||||||
|
action_space=Box(
|
||||||
-np.inf,
|
-np.inf,
|
||||||
np.inf,
|
np.inf,
|
||||||
shape=env.action_space.shape,
|
shape=env.action_space.shape,
|
||||||
@@ -83,7 +95,9 @@ class ClipActionV0(LambdaActionV0[ObsType, WrapperActType, ActType]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class RescaleActionV0(LambdaActionV0[ObsType, WrapperActType, ActType]):
|
class RescaleActionV0(
|
||||||
|
LambdaActionV0[ObsType, WrapperActType, ActType], gym.utils.RecordConstructorArgs
|
||||||
|
):
|
||||||
"""Affinely rescales the continuous action space of the environment to the range [min_action, max_action].
|
"""Affinely rescales the continuous action space of the environment to the range [min_action, max_action].
|
||||||
|
|
||||||
The base environment :attr:`env` must have an action space of type :class:`spaces.Box`. If :attr:`min_action`
|
The base environment :attr:`env` must have an action space of type :class:`spaces.Box`. If :attr:`min_action`
|
||||||
@@ -118,6 +132,10 @@ class RescaleActionV0(LambdaActionV0[ObsType, WrapperActType, ActType]):
|
|||||||
min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar.
|
min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar.
|
||||||
max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar.
|
max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar.
|
||||||
"""
|
"""
|
||||||
|
gym.utils.RecordConstructorArgs.__init__(
|
||||||
|
self, min_action=min_action, max_action=max_action
|
||||||
|
)
|
||||||
|
|
||||||
assert isinstance(env.action_space, Box)
|
assert isinstance(env.action_space, Box)
|
||||||
assert not np.any(env.action_space.low == np.inf) and not np.any(
|
assert not np.any(env.action_space.low == np.inf) and not np.any(
|
||||||
env.action_space.high == np.inf
|
env.action_space.high == np.inf
|
||||||
@@ -149,10 +167,11 @@ class RescaleActionV0(LambdaActionV0[ObsType, WrapperActType, ActType]):
|
|||||||
)
|
)
|
||||||
intercept = gradient * -min_action + env.action_space.low
|
intercept = gradient * -min_action + env.action_space.low
|
||||||
|
|
||||||
super().__init__(
|
LambdaActionV0.__init__(
|
||||||
env,
|
self,
|
||||||
lambda action: gradient * action + intercept,
|
env=env,
|
||||||
Box(
|
func=lambda action: gradient * action + intercept,
|
||||||
|
action_space=Box(
|
||||||
low=min_action,
|
low=min_action,
|
||||||
high=max_action,
|
high=max_action,
|
||||||
shape=env.action_space.shape,
|
shape=env.action_space.shape,
|
||||||
|
@@ -24,14 +24,16 @@ except ImportError as e:
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium import Env, spaces
|
from gymnasium import spaces
|
||||||
from gymnasium.core import ActType, ObservationWrapper, ObsType, WrapperObsType
|
from gymnasium.core import ActType, ObsType, WrapperObsType
|
||||||
from gymnasium.error import DependencyNotInstalled
|
from gymnasium.error import DependencyNotInstalled
|
||||||
from gymnasium.experimental.wrappers.utils import RunningMeanStd
|
from gymnasium.experimental.wrappers.utils import RunningMeanStd
|
||||||
from gymnasium.spaces import Box, Dict, utils
|
|
||||||
|
|
||||||
|
|
||||||
class LambdaObservationV0(gym.ObservationWrapper[WrapperObsType, ActType, ObsType]):
|
class LambdaObservationV0(
|
||||||
|
gym.ObservationWrapper[WrapperObsType, ActType, ObsType],
|
||||||
|
gym.utils.RecordConstructorArgs,
|
||||||
|
):
|
||||||
"""Transforms an observation via a function provided to the wrapper.
|
"""Transforms an observation via a function provided to the wrapper.
|
||||||
|
|
||||||
The function :attr:`func` will be applied to all observations.
|
The function :attr:`func` will be applied to all observations.
|
||||||
@@ -61,7 +63,11 @@ class LambdaObservationV0(gym.ObservationWrapper[WrapperObsType, ActType, ObsTyp
|
|||||||
func: A function that will transform an observation. If this transformed observation is outside the observation space of `env.observation_space` then provide an `observation_space`.
|
func: A function that will transform an observation. If this transformed observation is outside the observation space of `env.observation_space` then provide an `observation_space`.
|
||||||
observation_space: The observation spaces of the wrapper, if None, then it is assumed the same as `env.observation_space`.
|
observation_space: The observation spaces of the wrapper, if None, then it is assumed the same as `env.observation_space`.
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
gym.utils.RecordConstructorArgs.__init__(
|
||||||
|
self, func=func, observation_space=observation_space
|
||||||
|
)
|
||||||
|
gym.ObservationWrapper.__init__(self, env)
|
||||||
|
|
||||||
if observation_space is not None:
|
if observation_space is not None:
|
||||||
self.observation_space = observation_space
|
self.observation_space = observation_space
|
||||||
|
|
||||||
@@ -72,7 +78,10 @@ class LambdaObservationV0(gym.ObservationWrapper[WrapperObsType, ActType, ObsTyp
|
|||||||
return self.func(observation)
|
return self.func(observation)
|
||||||
|
|
||||||
|
|
||||||
class FilterObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
|
class FilterObservationV0(
|
||||||
|
LambdaObservationV0[WrapperObsType, ActType, ObsType],
|
||||||
|
gym.utils.RecordConstructorArgs,
|
||||||
|
):
|
||||||
"""Filter Dict observation space by the keys.
|
"""Filter Dict observation space by the keys.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
@@ -96,6 +105,7 @@ class FilterObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType])
|
|||||||
):
|
):
|
||||||
"""Constructor for an environment with a dictionary observation space where all :attr:`filter_keys` are in the observation space keys."""
|
"""Constructor for an environment with a dictionary observation space where all :attr:`filter_keys` are in the observation space keys."""
|
||||||
assert isinstance(filter_keys, Sequence)
|
assert isinstance(filter_keys, Sequence)
|
||||||
|
gym.utils.RecordConstructorArgs.__init__(self, filter_keys=filter_keys)
|
||||||
|
|
||||||
# Filters for dictionary space
|
# Filters for dictionary space
|
||||||
if isinstance(env.observation_space, spaces.Dict):
|
if isinstance(env.observation_space, spaces.Dict):
|
||||||
@@ -124,10 +134,11 @@ class FilterObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType])
|
|||||||
"The observation space is empty due to filtering all keys."
|
"The observation space is empty due to filtering all keys."
|
||||||
)
|
)
|
||||||
|
|
||||||
super().__init__(
|
LambdaObservationV0.__init__(
|
||||||
env,
|
self,
|
||||||
lambda obs: {key: obs[key] for key in filter_keys},
|
env=env,
|
||||||
new_observation_space,
|
func=lambda obs: {key: obs[key] for key in filter_keys},
|
||||||
|
observation_space=new_observation_space,
|
||||||
)
|
)
|
||||||
# Filter for tuple observation
|
# Filter for tuple observation
|
||||||
elif isinstance(env.observation_space, spaces.Tuple):
|
elif isinstance(env.observation_space, spaces.Tuple):
|
||||||
@@ -158,10 +169,11 @@ class FilterObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType])
|
|||||||
"The observation space is empty due to filtering all keys."
|
"The observation space is empty due to filtering all keys."
|
||||||
)
|
)
|
||||||
|
|
||||||
super().__init__(
|
LambdaObservationV0.__init__(
|
||||||
env,
|
self,
|
||||||
lambda obs: tuple(obs[key] for key in filter_keys),
|
env=env,
|
||||||
new_observation_spaces,
|
func=lambda obs: tuple(obs[key] for key in filter_keys),
|
||||||
|
observation_space=new_observation_spaces,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -171,7 +183,10 @@ class FilterObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType])
|
|||||||
self.filter_keys: Final[Sequence[str | int]] = filter_keys
|
self.filter_keys: Final[Sequence[str | int]] = filter_keys
|
||||||
|
|
||||||
|
|
||||||
class FlattenObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
|
class FlattenObservationV0(
|
||||||
|
LambdaObservationV0[WrapperObsType, ActType, ObsType],
|
||||||
|
gym.utils.RecordConstructorArgs,
|
||||||
|
):
|
||||||
"""Observation wrapper that flattens the observation.
|
"""Observation wrapper that flattens the observation.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
@@ -190,14 +205,19 @@ class FlattenObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]
|
|||||||
|
|
||||||
def __init__(self, env: gym.Env[ObsType, ActType]):
|
def __init__(self, env: gym.Env[ObsType, ActType]):
|
||||||
"""Constructor for any environment's observation space that implements ``spaces.utils.flatten_space`` and ``spaces.utils.flatten``."""
|
"""Constructor for any environment's observation space that implements ``spaces.utils.flatten_space`` and ``spaces.utils.flatten``."""
|
||||||
super().__init__(
|
gym.utils.RecordConstructorArgs.__init__(self)
|
||||||
env,
|
LambdaObservationV0.__init__(
|
||||||
lambda obs: utils.flatten(env.observation_space, obs),
|
self,
|
||||||
utils.flatten_space(env.observation_space),
|
env=env,
|
||||||
|
func=lambda obs: spaces.utils.flatten(env.observation_space, obs),
|
||||||
|
observation_space=spaces.utils.flatten_space(env.observation_space),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class GrayscaleObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
|
class GrayscaleObservationV0(
|
||||||
|
LambdaObservationV0[WrapperObsType, ActType, ObsType],
|
||||||
|
gym.utils.RecordConstructorArgs,
|
||||||
|
):
|
||||||
"""Observation wrapper that converts an RGB image to grayscale.
|
"""Observation wrapper that converts an RGB image to grayscale.
|
||||||
|
|
||||||
The :attr:`keep_dim` will keep the channel dimension
|
The :attr:`keep_dim` will keep the channel dimension
|
||||||
@@ -228,6 +248,7 @@ class GrayscaleObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsTyp
|
|||||||
and np.all(env.observation_space.high == 255)
|
and np.all(env.observation_space.high == 255)
|
||||||
and env.observation_space.dtype == np.uint8
|
and env.observation_space.dtype == np.uint8
|
||||||
)
|
)
|
||||||
|
gym.utils.RecordConstructorArgs.__init__(self, keep_dim=keep_dim)
|
||||||
|
|
||||||
self.keep_dim: Final[bool] = keep_dim
|
self.keep_dim: Final[bool] = keep_dim
|
||||||
if keep_dim:
|
if keep_dim:
|
||||||
@@ -237,30 +258,35 @@ class GrayscaleObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsTyp
|
|||||||
shape=env.observation_space.shape[:2] + (1,),
|
shape=env.observation_space.shape[:2] + (1,),
|
||||||
dtype=np.uint8,
|
dtype=np.uint8,
|
||||||
)
|
)
|
||||||
super().__init__(
|
LambdaObservationV0.__init__(
|
||||||
env,
|
self,
|
||||||
lambda obs: jp.expand_dims(
|
env=env,
|
||||||
|
func=lambda obs: jp.expand_dims(
|
||||||
jp.sum(
|
jp.sum(
|
||||||
jp.multiply(obs, jp.array([0.2125, 0.7154, 0.0721])), axis=-1
|
jp.multiply(obs, jp.array([0.2125, 0.7154, 0.0721])), axis=-1
|
||||||
).astype(np.uint8),
|
).astype(np.uint8),
|
||||||
axis=-1,
|
axis=-1,
|
||||||
),
|
),
|
||||||
new_observation_space,
|
observation_space=new_observation_space,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
new_observation_space = spaces.Box(
|
new_observation_space = spaces.Box(
|
||||||
low=0, high=255, shape=env.observation_space.shape[:2], dtype=np.uint8
|
low=0, high=255, shape=env.observation_space.shape[:2], dtype=np.uint8
|
||||||
)
|
)
|
||||||
super().__init__(
|
LambdaObservationV0.__init__(
|
||||||
env,
|
self,
|
||||||
lambda obs: jp.sum(
|
env=env,
|
||||||
|
func=lambda obs: jp.sum(
|
||||||
jp.multiply(obs, jp.array([0.2125, 0.7154, 0.0721])), axis=-1
|
jp.multiply(obs, jp.array([0.2125, 0.7154, 0.0721])), axis=-1
|
||||||
).astype(np.uint8),
|
).astype(np.uint8),
|
||||||
new_observation_space,
|
observation_space=new_observation_space,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ResizeObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
|
class ResizeObservationV0(
|
||||||
|
LambdaObservationV0[WrapperObsType, ActType, ObsType],
|
||||||
|
gym.utils.RecordConstructorArgs,
|
||||||
|
):
|
||||||
"""Resizes image observations using OpenCV to shape.
|
"""Resizes image observations using OpenCV to shape.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
@@ -299,14 +325,20 @@ class ResizeObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType])
|
|||||||
new_observation_space = spaces.Box(
|
new_observation_space = spaces.Box(
|
||||||
low=0, high=255, shape=self.shape + env.observation_space.shape[2:]
|
low=0, high=255, shape=self.shape + env.observation_space.shape[2:]
|
||||||
)
|
)
|
||||||
super().__init__(
|
|
||||||
env,
|
gym.utils.RecordConstructorArgs.__init__(self, shape=shape)
|
||||||
lambda obs: cv2.resize(obs, self.shape, interpolation=cv2.INTER_AREA),
|
LambdaObservationV0.__init__(
|
||||||
new_observation_space,
|
self,
|
||||||
|
env=env,
|
||||||
|
func=lambda obs: cv2.resize(obs, self.shape, interpolation=cv2.INTER_AREA),
|
||||||
|
observation_space=new_observation_space,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ReshapeObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
|
class ReshapeObservationV0(
|
||||||
|
LambdaObservationV0[WrapperObsType, ActType, ObsType],
|
||||||
|
gym.utils.RecordConstructorArgs,
|
||||||
|
):
|
||||||
"""Reshapes array based observations to shapes.
|
"""Reshapes array based observations to shapes.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
@@ -336,10 +368,20 @@ class ReshapeObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]
|
|||||||
dtype=env.observation_space.dtype,
|
dtype=env.observation_space.dtype,
|
||||||
)
|
)
|
||||||
self.shape = shape
|
self.shape = shape
|
||||||
super().__init__(env, lambda obs: jp.reshape(obs, shape), new_observation_space)
|
|
||||||
|
gym.utils.RecordConstructorArgs.__init__(self, shape=shape)
|
||||||
|
LambdaObservationV0.__init__(
|
||||||
|
self,
|
||||||
|
env=env,
|
||||||
|
func=lambda obs: jp.reshape(obs, shape),
|
||||||
|
observation_space=new_observation_space,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class RescaleObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
|
class RescaleObservationV0(
|
||||||
|
LambdaObservationV0[WrapperObsType, ActType, ObsType],
|
||||||
|
gym.utils.RecordConstructorArgs,
|
||||||
|
):
|
||||||
"""Linearly rescales observation to between a minimum and maximum value.
|
"""Linearly rescales observation to between a minimum and maximum value.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
@@ -392,10 +434,12 @@ class RescaleObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]
|
|||||||
)
|
)
|
||||||
intercept = gradient * -env.observation_space.low + min_obs
|
intercept = gradient * -env.observation_space.low + min_obs
|
||||||
|
|
||||||
super().__init__(
|
gym.utils.RecordConstructorArgs.__init__(self, min_obs=min_obs, max_obs=max_obs)
|
||||||
env,
|
LambdaObservationV0.__init__(
|
||||||
lambda obs: gradient * obs + intercept,
|
self,
|
||||||
Box(
|
env=env,
|
||||||
|
func=lambda obs: gradient * obs + intercept,
|
||||||
|
observation_space=spaces.Box(
|
||||||
low=min_obs,
|
low=min_obs,
|
||||||
high=max_obs,
|
high=max_obs,
|
||||||
shape=env.observation_space.shape,
|
shape=env.observation_space.shape,
|
||||||
@@ -404,7 +448,10 @@ class RescaleObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class DtypeObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
|
class DtypeObservationV0(
|
||||||
|
LambdaObservationV0[WrapperObsType, ActType, ObsType],
|
||||||
|
gym.utils.RecordConstructorArgs,
|
||||||
|
):
|
||||||
"""Observation wrapper for transforming the dtype of an observation."""
|
"""Observation wrapper for transforming the dtype of an observation."""
|
||||||
|
|
||||||
def __init__(self, env: gym.Env[ObsType, ActType], dtype: Any):
|
def __init__(self, env: gym.Env[ObsType, ActType], dtype: Any):
|
||||||
@@ -445,10 +492,19 @@ class DtypeObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
|
|||||||
"DtypeObservation is only compatible with value / array-based observations."
|
"DtypeObservation is only compatible with value / array-based observations."
|
||||||
)
|
)
|
||||||
|
|
||||||
super().__init__(env, lambda obs: dtype(obs), new_observation_space)
|
gym.utils.RecordConstructorArgs.__init__(self, dtype=dtype)
|
||||||
|
LambdaObservationV0.__init__(
|
||||||
|
self,
|
||||||
|
env=env,
|
||||||
|
func=lambda obs: dtype(obs),
|
||||||
|
observation_space=new_observation_space,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class PixelObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
|
class PixelObservationV0(
|
||||||
|
LambdaObservationV0[WrapperObsType, ActType, ObsType],
|
||||||
|
gym.utils.RecordConstructorArgs,
|
||||||
|
):
|
||||||
"""Augment observations by pixel values.
|
"""Augment observations by pixel values.
|
||||||
|
|
||||||
Observations of this wrapper will be dictionaries of images.
|
Observations of this wrapper will be dictionaries of images.
|
||||||
@@ -461,7 +517,7 @@ class PixelObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
env: Env[ObsType, ActType],
|
env: gym.Env[ObsType, ActType],
|
||||||
pixels_only: bool = True,
|
pixels_only: bool = True,
|
||||||
pixels_key: str = "pixels",
|
pixels_key: str = "pixels",
|
||||||
obs_key: str = "state",
|
obs_key: str = "state",
|
||||||
@@ -478,30 +534,49 @@ class PixelObservationV0(LambdaObservationV0[WrapperObsType, ActType, ObsType]):
|
|||||||
pixels_key: Optional custom string specifying the pixel key. Defaults to "pixels"
|
pixels_key: Optional custom string specifying the pixel key. Defaults to "pixels"
|
||||||
obs_key: Optional custom string specifying the obs key. Defaults to "state"
|
obs_key: Optional custom string specifying the obs key. Defaults to "state"
|
||||||
"""
|
"""
|
||||||
|
gym.utils.RecordConstructorArgs.__init__(
|
||||||
|
self, pixels_only=pixels_only, pixels_key=pixels_key, obs_key=obs_key
|
||||||
|
)
|
||||||
|
|
||||||
assert env.render_mode is not None and env.render_mode != "human"
|
assert env.render_mode is not None and env.render_mode != "human"
|
||||||
env.reset()
|
env.reset()
|
||||||
pixels = env.render()
|
pixels = env.render()
|
||||||
assert pixels is not None and isinstance(pixels, np.ndarray)
|
assert pixels is not None and isinstance(pixels, np.ndarray)
|
||||||
pixel_space = Box(low=0, high=255, shape=pixels.shape, dtype=np.uint8)
|
pixel_space = spaces.Box(low=0, high=255, shape=pixels.shape, dtype=np.uint8)
|
||||||
|
|
||||||
if pixels_only:
|
if pixels_only:
|
||||||
obs_space = pixel_space
|
obs_space = pixel_space
|
||||||
super().__init__(env, lambda _: self.render(), obs_space)
|
LambdaObservationV0.__init__(
|
||||||
elif isinstance(env.observation_space, Dict):
|
self, env=env, func=lambda _: self.render(), observation_space=obs_space
|
||||||
|
)
|
||||||
|
elif isinstance(env.observation_space, spaces.Dict):
|
||||||
assert pixels_key not in env.observation_space.spaces.keys()
|
assert pixels_key not in env.observation_space.spaces.keys()
|
||||||
|
|
||||||
obs_space = Dict({pixels_key: pixel_space, **env.observation_space.spaces})
|
obs_space = spaces.Dict(
|
||||||
super().__init__(
|
{pixels_key: pixel_space, **env.observation_space.spaces}
|
||||||
env, lambda obs: {pixels_key: self.render(), **obs_space}, obs_space
|
)
|
||||||
|
LambdaObservationV0.__init__(
|
||||||
|
self,
|
||||||
|
env=env,
|
||||||
|
func=lambda obs: {pixels_key: self.render(), **obs_space},
|
||||||
|
observation_space=obs_space,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
obs_space = Dict({obs_key: env.observation_space, pixels_key: pixel_space})
|
obs_space = spaces.Dict(
|
||||||
super().__init__(
|
{obs_key: env.observation_space, pixels_key: pixel_space}
|
||||||
env, lambda obs: {obs_key: obs, pixels_key: self.render()}, obs_space
|
)
|
||||||
|
LambdaObservationV0.__init__(
|
||||||
|
self,
|
||||||
|
env=env,
|
||||||
|
func=lambda obs: {obs_key: obs, pixels_key: self.render()},
|
||||||
|
observation_space=obs_space,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class NormalizeObservationV0(ObservationWrapper[WrapperObsType, ActType, ObsType]):
|
class NormalizeObservationV0(
|
||||||
|
gym.ObservationWrapper[WrapperObsType, ActType, ObsType],
|
||||||
|
gym.utils.RecordConstructorArgs,
|
||||||
|
):
|
||||||
"""This wrapper will normalize observations s.t. each coordinate is centered with unit variance.
|
"""This wrapper will normalize observations s.t. each coordinate is centered with unit variance.
|
||||||
|
|
||||||
The property `_update_running_mean` allows to freeze/continue the running mean calculation of the observation
|
The property `_update_running_mean` allows to freeze/continue the running mean calculation of the observation
|
||||||
@@ -520,7 +595,9 @@ class NormalizeObservationV0(ObservationWrapper[WrapperObsType, ActType, ObsType
|
|||||||
env (Env): The environment to apply the wrapper
|
env (Env): The environment to apply the wrapper
|
||||||
epsilon: A stability parameter that is used when scaling the observations.
|
epsilon: A stability parameter that is used when scaling the observations.
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
gym.utils.RecordConstructorArgs.__init__(self, epsilon=epsilon)
|
||||||
|
gym.ObservationWrapper.__init__(self, env)
|
||||||
|
|
||||||
self.obs_rms = RunningMeanStd(shape=self.observation_space.shape)
|
self.obs_rms = RunningMeanStd(shape=self.observation_space.shape)
|
||||||
self.epsilon = epsilon
|
self.epsilon = epsilon
|
||||||
self._update_running_mean = True
|
self._update_running_mean = True
|
||||||
|
@@ -15,7 +15,9 @@ from gymnasium.error import InvalidBound
|
|||||||
from gymnasium.experimental.wrappers.utils import RunningMeanStd
|
from gymnasium.experimental.wrappers.utils import RunningMeanStd
|
||||||
|
|
||||||
|
|
||||||
class LambdaRewardV0(gym.RewardWrapper[ObsType, ActType]):
|
class LambdaRewardV0(
|
||||||
|
gym.RewardWrapper[ObsType, ActType], gym.utils.RecordConstructorArgs
|
||||||
|
):
|
||||||
"""A reward wrapper that allows a custom function to modify the step reward.
|
"""A reward wrapper that allows a custom function to modify the step reward.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
@@ -40,7 +42,8 @@ class LambdaRewardV0(gym.RewardWrapper[ObsType, ActType]):
|
|||||||
env (Env): The environment to apply the wrapper
|
env (Env): The environment to apply the wrapper
|
||||||
func: (Callable): The function to apply to reward
|
func: (Callable): The function to apply to reward
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
gym.utils.RecordConstructorArgs.__init__(self, func=func)
|
||||||
|
gym.RewardWrapper.__init__(self, env)
|
||||||
|
|
||||||
self.func = func
|
self.func = func
|
||||||
|
|
||||||
@@ -53,7 +56,7 @@ class LambdaRewardV0(gym.RewardWrapper[ObsType, ActType]):
|
|||||||
return self.func(reward)
|
return self.func(reward)
|
||||||
|
|
||||||
|
|
||||||
class ClipRewardV0(LambdaRewardV0[ObsType, ActType]):
|
class ClipRewardV0(LambdaRewardV0[ObsType, ActType], gym.utils.RecordConstructorArgs):
|
||||||
"""A wrapper that clips the rewards for an environment between an upper and lower bound.
|
"""A wrapper that clips the rewards for an environment between an upper and lower bound.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
@@ -89,10 +92,17 @@ class ClipRewardV0(LambdaRewardV0[ObsType, ActType]):
|
|||||||
f"Min reward ({min_reward}) must be smaller than max reward ({max_reward})"
|
f"Min reward ({min_reward}) must be smaller than max reward ({max_reward})"
|
||||||
)
|
)
|
||||||
|
|
||||||
super().__init__(env, lambda x: np.clip(x, a_min=min_reward, a_max=max_reward))
|
gym.utils.RecordConstructorArgs.__init__(
|
||||||
|
self, min_reward=min_reward, max_reward=max_reward
|
||||||
|
)
|
||||||
|
LambdaRewardV0.__init__(
|
||||||
|
self, env=env, func=lambda x: np.clip(x, a_min=min_reward, a_max=max_reward)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class NormalizeRewardV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
class NormalizeRewardV0(
|
||||||
|
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
|
||||||
|
):
|
||||||
r"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance.
|
r"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance.
|
||||||
|
|
||||||
The exponential moving average will have variance :math:`(1 - \gamma)^2`.
|
The exponential moving average will have variance :math:`(1 - \gamma)^2`.
|
||||||
@@ -119,7 +129,9 @@ class NormalizeRewardV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
|||||||
epsilon (float): A stability parameter
|
epsilon (float): A stability parameter
|
||||||
gamma (float): The discount factor that is used in the exponential moving average.
|
gamma (float): The discount factor that is used in the exponential moving average.
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
gym.utils.RecordConstructorArgs.__init__(self, gamma=gamma, epsilon=epsilon)
|
||||||
|
gym.Wrapper.__init__(self, env)
|
||||||
|
|
||||||
self.rewards_running_means = RunningMeanStd(shape=())
|
self.rewards_running_means = RunningMeanStd(shape=())
|
||||||
self.discounted_reward: np.array = np.array([0.0])
|
self.discounted_reward: np.array = np.array([0.0])
|
||||||
self.gamma = gamma
|
self.gamma = gamma
|
||||||
|
@@ -18,7 +18,9 @@ from gymnasium.core import ActType, ObsType, RenderFrame
|
|||||||
from gymnasium.error import DependencyNotInstalled
|
from gymnasium.error import DependencyNotInstalled
|
||||||
|
|
||||||
|
|
||||||
class RenderCollectionV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
class RenderCollectionV0(
|
||||||
|
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
|
||||||
|
):
|
||||||
"""Collect rendered frames of an environment such ``render`` returns a ``list[RenderedFrame]``."""
|
"""Collect rendered frames of an environment such ``render`` returns a ``list[RenderedFrame]``."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -34,7 +36,11 @@ class RenderCollectionV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
|||||||
pop_frames (bool): If true, clear the collection frames after ``meth:render`` is called. Default value is ``True``.
|
pop_frames (bool): If true, clear the collection frames after ``meth:render`` is called. Default value is ``True``.
|
||||||
reset_clean (bool): If true, clear the collection frames when ``meth:reset`` is called. Default value is ``True``.
|
reset_clean (bool): If true, clear the collection frames when ``meth:reset`` is called. Default value is ``True``.
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
gym.utils.RecordConstructorArgs.__init__(
|
||||||
|
self, pop_frames=pop_frames, reset_clean=reset_clean
|
||||||
|
)
|
||||||
|
gym.Wrapper.__init__(self, env)
|
||||||
|
|
||||||
assert env.render_mode is not None
|
assert env.render_mode is not None
|
||||||
assert not env.render_mode.endswith("_list")
|
assert not env.render_mode.endswith("_list")
|
||||||
|
|
||||||
@@ -80,7 +86,9 @@ class RenderCollectionV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
|||||||
return frames
|
return frames
|
||||||
|
|
||||||
|
|
||||||
class RecordVideoV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
class RecordVideoV0(
|
||||||
|
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
|
||||||
|
):
|
||||||
"""This wrapper records videos of rollouts.
|
"""This wrapper records videos of rollouts.
|
||||||
|
|
||||||
Usually, you only want to record episodes intermittently, say every hundredth episode.
|
Usually, you only want to record episodes intermittently, say every hundredth episode.
|
||||||
@@ -117,9 +125,18 @@ class RecordVideoV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
|||||||
Otherwise, snippets of the specified length are captured
|
Otherwise, snippets of the specified length are captured
|
||||||
name_prefix (str): Will be prepended to the filename of the recordings
|
name_prefix (str): Will be prepended to the filename of the recordings
|
||||||
disable_logger (bool): Whether to disable moviepy logger or not
|
disable_logger (bool): Whether to disable moviepy logger or not
|
||||||
|
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
gym.utils.RecordConstructorArgs.__init__(
|
||||||
|
self,
|
||||||
|
video_folder=video_folder,
|
||||||
|
episode_trigger=episode_trigger,
|
||||||
|
step_trigger=step_trigger,
|
||||||
|
video_length=video_length,
|
||||||
|
name_prefix=name_prefix,
|
||||||
|
disable_logger=disable_logger,
|
||||||
|
)
|
||||||
|
gym.Wrapper.__init__(self, env)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import moviepy # noqa: F401
|
import moviepy # noqa: F401
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
@@ -277,7 +294,9 @@ class RecordVideoV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
|||||||
logger.warn("Unable to save last video! Did you call close()?")
|
logger.warn("Unable to save last video! Did you call close()?")
|
||||||
|
|
||||||
|
|
||||||
class HumanRenderingV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
class HumanRenderingV0(
|
||||||
|
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
|
||||||
|
):
|
||||||
"""Performs human rendering for an environment that only supports "rgb_array"rendering.
|
"""Performs human rendering for an environment that only supports "rgb_array"rendering.
|
||||||
|
|
||||||
This wrapper is particularly useful when you have implemented an environment that can produce
|
This wrapper is particularly useful when you have implemented an environment that can produce
|
||||||
@@ -317,7 +336,9 @@ class HumanRenderingV0(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
|
|||||||
Args:
|
Args:
|
||||||
env: The environment that is being wrapped
|
env: The environment that is being wrapped
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
gym.utils.RecordConstructorArgs.__init__(self)
|
||||||
|
gym.Wrapper.__init__(self, env)
|
||||||
|
|
||||||
assert env.render_mode in [
|
assert env.render_mode in [
|
||||||
"rgb_array",
|
"rgb_array",
|
||||||
"rgb_array_list",
|
"rgb_array_list",
|
||||||
|
@@ -4,11 +4,13 @@ from __future__ import annotations
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium.core import ActionWrapper, ActType, ObsType
|
from gymnasium.core import ActType, ObsType
|
||||||
from gymnasium.error import InvalidProbability
|
from gymnasium.error import InvalidProbability
|
||||||
|
|
||||||
|
|
||||||
class StickyActionV0(ActionWrapper[ObsType, ActType, ActType]):
|
class StickyActionV0(
|
||||||
|
gym.ActionWrapper[ObsType, ActType, ActType], gym.utils.RecordConstructorArgs
|
||||||
|
):
|
||||||
"""Wrapper which adds a probability of repeating the previous action.
|
"""Wrapper which adds a probability of repeating the previous action.
|
||||||
|
|
||||||
This wrapper follows the implementation proposed by `Machado et al., 2018 <https://arxiv.org/pdf/1709.06009.pdf>`_
|
This wrapper follows the implementation proposed by `Machado et al., 2018 <https://arxiv.org/pdf/1709.06009.pdf>`_
|
||||||
@@ -29,7 +31,11 @@ class StickyActionV0(ActionWrapper[ObsType, ActType, ActType]):
|
|||||||
f"repeat_action_probability should be in the interval [0,1). Received {repeat_action_probability}"
|
f"repeat_action_probability should be in the interval [0,1). Received {repeat_action_probability}"
|
||||||
)
|
)
|
||||||
|
|
||||||
super().__init__(env)
|
gym.utils.RecordConstructorArgs.__init__(
|
||||||
|
self, repeat_action_probability=repeat_action_probability
|
||||||
|
)
|
||||||
|
gym.ActionWrapper.__init__(self, env)
|
||||||
|
|
||||||
self.repeat_action_probability = repeat_action_probability
|
self.repeat_action_probability = repeat_action_probability
|
||||||
self.last_action: ActType | None = None
|
self.last_action: ActType | None = None
|
||||||
|
|
||||||
|
@@ -13,8 +13,8 @@ from typing_extensions import Final
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
import gymnasium.spaces as spaces
|
import gymnasium.spaces as spaces
|
||||||
from gymnasium import Env, ObservationWrapper, Space, Wrapper
|
|
||||||
from gymnasium.core import ActType, ObsType, WrapperActType, WrapperObsType
|
from gymnasium.core import ActType, ObsType, WrapperActType, WrapperObsType
|
||||||
from gymnasium.experimental.vector.utils import (
|
from gymnasium.experimental.vector.utils import (
|
||||||
batch_space,
|
batch_space,
|
||||||
@@ -25,7 +25,9 @@ from gymnasium.experimental.wrappers.utils import create_zero_array
|
|||||||
from gymnasium.spaces import Box, Dict, Tuple
|
from gymnasium.spaces import Box, Dict, Tuple
|
||||||
|
|
||||||
|
|
||||||
class DelayObservationV0(ObservationWrapper[ObsType, ActType, ObsType]):
|
class DelayObservationV0(
|
||||||
|
gym.ObservationWrapper[ObsType, ActType, ObsType], gym.utils.RecordConstructorArgs
|
||||||
|
):
|
||||||
"""Wrapper which adds a delay to the returned observation.
|
"""Wrapper which adds a delay to the returned observation.
|
||||||
|
|
||||||
Before reaching the :attr:`delay` number of timesteps, returned observations is an array of zeros with
|
Before reaching the :attr:`delay` number of timesteps, returned observations is an array of zeros with
|
||||||
@@ -49,15 +51,13 @@ class DelayObservationV0(ObservationWrapper[ObsType, ActType, ObsType]):
|
|||||||
This does not support random delay values, if users are interested, please raise an issue or pull request to add this feature.
|
This does not support random delay values, if users are interested, please raise an issue or pull request to add this feature.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, env: Env[ObsType, ActType], delay: int):
|
def __init__(self, env: gym.Env[ObsType, ActType], delay: int):
|
||||||
"""Initialises the DelayObservation wrapper with an integer.
|
"""Initialises the DelayObservation wrapper with an integer.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
env: The environment to wrap
|
env: The environment to wrap
|
||||||
delay: The number of timesteps to delay observations
|
delay: The number of timesteps to delay observations
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
|
||||||
|
|
||||||
if not np.issubdtype(type(delay), np.integer):
|
if not np.issubdtype(type(delay), np.integer):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"The delay is expected to be an integer, actual type: {type(delay)}"
|
f"The delay is expected to be an integer, actual type: {type(delay)}"
|
||||||
@@ -67,6 +67,9 @@ class DelayObservationV0(ObservationWrapper[ObsType, ActType, ObsType]):
|
|||||||
f"The delay needs to be greater than zero, actual value: {delay}"
|
f"The delay needs to be greater than zero, actual value: {delay}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
gym.utils.RecordConstructorArgs.__init__(self, delay=delay)
|
||||||
|
gym.ObservationWrapper.__init__(self, env)
|
||||||
|
|
||||||
self.delay: Final[int] = int(delay)
|
self.delay: Final[int] = int(delay)
|
||||||
self.observation_queue: Final[deque] = deque()
|
self.observation_queue: Final[deque] = deque()
|
||||||
|
|
||||||
@@ -88,7 +91,10 @@ class DelayObservationV0(ObservationWrapper[ObsType, ActType, ObsType]):
|
|||||||
return create_zero_array(self.observation_space)
|
return create_zero_array(self.observation_space)
|
||||||
|
|
||||||
|
|
||||||
class TimeAwareObservationV0(ObservationWrapper[WrapperObsType, ActType, ObsType]):
|
class TimeAwareObservationV0(
|
||||||
|
gym.ObservationWrapper[WrapperObsType, ActType, ObsType],
|
||||||
|
gym.utils.RecordConstructorArgs,
|
||||||
|
):
|
||||||
"""Augment the observation with time information of the episode.
|
"""Augment the observation with time information of the episode.
|
||||||
|
|
||||||
The :attr:`normalize_time` if ``True`` represents time as a normalized value between [0,1]
|
The :attr:`normalize_time` if ``True`` represents time as a normalized value between [0,1]
|
||||||
@@ -144,7 +150,7 @@ class TimeAwareObservationV0(ObservationWrapper[WrapperObsType, ActType, ObsType
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
env: Env[ObsType, ActType],
|
env: gym.Env[ObsType, ActType],
|
||||||
flatten: bool = False,
|
flatten: bool = False,
|
||||||
normalize_time: bool = True,
|
normalize_time: bool = True,
|
||||||
*,
|
*,
|
||||||
@@ -159,7 +165,13 @@ class TimeAwareObservationV0(ObservationWrapper[WrapperObsType, ActType, ObsType
|
|||||||
otherwise return time as remaining timesteps before truncation
|
otherwise return time as remaining timesteps before truncation
|
||||||
dict_time_key: For environment with a ``Dict`` observation space, the key for the time space. By default, `"time"`.
|
dict_time_key: For environment with a ``Dict`` observation space, the key for the time space. By default, `"time"`.
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
gym.utils.RecordConstructorArgs.__init__(
|
||||||
|
self,
|
||||||
|
flatten=flatten,
|
||||||
|
normalize_time=normalize_time,
|
||||||
|
dict_time_key=dict_time_key,
|
||||||
|
)
|
||||||
|
gym.ObservationWrapper.__init__(self, env)
|
||||||
|
|
||||||
self.flatten: Final[bool] = flatten
|
self.flatten: Final[bool] = flatten
|
||||||
self.normalize_time: Final[bool] = normalize_time
|
self.normalize_time: Final[bool] = normalize_time
|
||||||
@@ -203,14 +215,14 @@ class TimeAwareObservationV0(ObservationWrapper[WrapperObsType, ActType, ObsType
|
|||||||
|
|
||||||
# If to flatten the observation space
|
# If to flatten the observation space
|
||||||
if self.flatten:
|
if self.flatten:
|
||||||
self.observation_space: Space[WrapperObsType] = spaces.flatten_space(
|
self.observation_space: gym.Space[WrapperObsType] = spaces.flatten_space(
|
||||||
observation_space
|
observation_space
|
||||||
)
|
)
|
||||||
self._obs_postprocess_func = lambda obs: spaces.flatten(
|
self._obs_postprocess_func = lambda obs: spaces.flatten(
|
||||||
observation_space, obs
|
observation_space, obs
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.observation_space: Space[WrapperObsType] = observation_space
|
self.observation_space: gym.Space[WrapperObsType] = observation_space
|
||||||
self._obs_postprocess_func = lambda obs: obs
|
self._obs_postprocess_func = lambda obs: obs
|
||||||
|
|
||||||
def observation(self, observation: ObsType) -> WrapperObsType:
|
def observation(self, observation: ObsType) -> WrapperObsType:
|
||||||
@@ -260,7 +272,10 @@ class TimeAwareObservationV0(ObservationWrapper[WrapperObsType, ActType, ObsType
|
|||||||
return super().reset(seed=seed, options=options)
|
return super().reset(seed=seed, options=options)
|
||||||
|
|
||||||
|
|
||||||
class FrameStackObservationV0(Wrapper[WrapperObsType, ActType, ObsType, ActType]):
|
class FrameStackObservationV0(
|
||||||
|
gym.Wrapper[WrapperObsType, ActType, ObsType, ActType],
|
||||||
|
gym.utils.RecordConstructorArgs,
|
||||||
|
):
|
||||||
"""Observation wrapper that stacks the observations in a rolling manner.
|
"""Observation wrapper that stacks the observations in a rolling manner.
|
||||||
|
|
||||||
For example, if the number of stacks is 4, then the returned observation contains
|
For example, if the number of stacks is 4, then the returned observation contains
|
||||||
@@ -286,7 +301,7 @@ class FrameStackObservationV0(Wrapper[WrapperObsType, ActType, ObsType, ActType]
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
env: Env[ObsType, ActType],
|
env: gym.Env[ObsType, ActType],
|
||||||
stack_size: int,
|
stack_size: int,
|
||||||
*,
|
*,
|
||||||
zeros_obs: ObsType | None = None,
|
zeros_obs: ObsType | None = None,
|
||||||
@@ -298,8 +313,6 @@ class FrameStackObservationV0(Wrapper[WrapperObsType, ActType, ObsType, ActType]
|
|||||||
stack_size: The number of frames to stack with zero_obs being used originally.
|
stack_size: The number of frames to stack with zero_obs being used originally.
|
||||||
zeros_obs: Keyword only parameter that allows a custom padding observation at :meth:`reset`
|
zeros_obs: Keyword only parameter that allows a custom padding observation at :meth:`reset`
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
|
||||||
|
|
||||||
if not np.issubdtype(type(stack_size), np.integer):
|
if not np.issubdtype(type(stack_size), np.integer):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"The stack_size is expected to be an integer, actual type: {type(stack_size)}"
|
f"The stack_size is expected to be an integer, actual type: {type(stack_size)}"
|
||||||
@@ -309,6 +322,9 @@ class FrameStackObservationV0(Wrapper[WrapperObsType, ActType, ObsType, ActType]
|
|||||||
f"The stack_size needs to be greater than one, actual value: {stack_size}"
|
f"The stack_size needs to be greater than one, actual value: {stack_size}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
gym.utils.RecordConstructorArgs.__init__(self, stack_size=stack_size)
|
||||||
|
gym.Wrapper.__init__(self, env)
|
||||||
|
|
||||||
self.observation_space = batch_space(env.observation_space, n=stack_size)
|
self.observation_space = batch_space(env.observation_space, n=stack_size)
|
||||||
self.stack_size: Final[int] = stack_size
|
self.stack_size: Final[int] = stack_size
|
||||||
|
|
||||||
|
@@ -8,6 +8,7 @@ These are not intended as API functions, and will not remain stable over time.
|
|||||||
# that verify that our dependencies are actually present.
|
# that verify that our dependencies are actually present.
|
||||||
from gymnasium.utils.colorize import colorize
|
from gymnasium.utils.colorize import colorize
|
||||||
from gymnasium.utils.ezpickle import EzPickle
|
from gymnasium.utils.ezpickle import EzPickle
|
||||||
|
from gymnasium.utils.record_constructor import RecordConstructorArgs
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["colorize", "EzPickle"]
|
__all__ = ["colorize", "EzPickle", "RecordConstructorArgs"]
|
||||||
|
@@ -1,13 +1,15 @@
|
|||||||
"""Class for pickling and unpickling objects via their constructor arguments."""
|
"""Class for pickling and unpickling objects via their constructor arguments."""
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
class EzPickle:
|
class EzPickle:
|
||||||
"""Objects that are pickled and unpickled via their constructor arguments.
|
"""Objects that are pickled and unpickled via their constructor arguments.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> class Dog(Animal, EzPickle): # doctest: +SKIP
|
>>> class Animal: pass
|
||||||
|
>>> class Dog(Animal, EzPickle):
|
||||||
... def __init__(self, furcolor, tailkind="bushy"):
|
... def __init__(self, furcolor, tailkind="bushy"):
|
||||||
... Animal.__init__()
|
... Animal.__init__(self)
|
||||||
... EzPickle.__init__(self, furcolor, tailkind)
|
... EzPickle.__init__(self, furcolor, tailkind)
|
||||||
|
|
||||||
When this object is unpickled, a new ``Dog`` will be constructed by passing the provided furcolor and tailkind into the constructor.
|
When this object is unpickled, a new ``Dog`` will be constructed by passing the provided furcolor and tailkind into the constructor.
|
||||||
@@ -16,7 +18,7 @@ class EzPickle:
|
|||||||
This is generally needed only for environments which wrap C/C++ code, such as MuJoCo and Atari.
|
This is generally needed only for environments which wrap C/C++ code, such as MuJoCo and Atari.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args: Any, **kwargs: Any):
|
||||||
"""Uses the ``args`` and ``kwargs`` from the object's constructor for pickling."""
|
"""Uses the ``args`` and ``kwargs`` from the object's constructor for pickling."""
|
||||||
self._ezpickle_args = args
|
self._ezpickle_args = args
|
||||||
self._ezpickle_kwargs = kwargs
|
self._ezpickle_kwargs = kwargs
|
||||||
|
33
gymnasium/utils/record_constructor.py
Normal file
33
gymnasium/utils/record_constructor.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
"""Allows attributes passed to `RecordConstructorArgs` to be saved. This is used by the `Wrapper.spec` to know the constructor arguments of implemented wrappers."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from copy import deepcopy
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
class RecordConstructorArgs:
|
||||||
|
"""Records all arguments passed to constructor to `_saved_kwargs`.
|
||||||
|
|
||||||
|
This can be used to save and reproduce class constructor arguments.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
If two class inherit from RecordConstructorArgs then the first class to call `RecordConstructorArgs.__init__(self, ...)` will have
|
||||||
|
their kwargs saved will all subsequent `RecordConstructorArgs.__init__` being ignored.
|
||||||
|
|
||||||
|
Therefore, always call `RecordConstructorArgs.__init__` before the `Class.__init__`
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *, _disable_deepcopy: bool = False, **kwargs: Any):
|
||||||
|
"""Records all arguments passed to constructor to `_saved_kwargs`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
_disable_deepcopy: If to not deepcopy the kwargs passed
|
||||||
|
**kwargs: Arguments to save
|
||||||
|
"""
|
||||||
|
# See class docstring for explanation
|
||||||
|
if not hasattr(self, "_saved_kwargs"):
|
||||||
|
if _disable_deepcopy is False:
|
||||||
|
kwargs = deepcopy(kwargs)
|
||||||
|
self._saved_kwargs: dict[str, Any] = kwargs
|
@@ -11,7 +11,7 @@ except ImportError:
|
|||||||
cv2 = None
|
cv2 = None
|
||||||
|
|
||||||
|
|
||||||
class AtariPreprocessing(gym.Wrapper):
|
class AtariPreprocessing(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||||
"""Atari 2600 preprocessing wrapper.
|
"""Atari 2600 preprocessing wrapper.
|
||||||
|
|
||||||
This class follows the guidelines in Machado et al. (2018),
|
This class follows the guidelines in Machado et al. (2018),
|
||||||
@@ -60,7 +60,18 @@ class AtariPreprocessing(gym.Wrapper):
|
|||||||
DependencyNotInstalled: opencv-python package not installed
|
DependencyNotInstalled: opencv-python package not installed
|
||||||
ValueError: Disable frame-skipping in the original env
|
ValueError: Disable frame-skipping in the original env
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
gym.utils.RecordConstructorArgs.__init__(
|
||||||
|
self,
|
||||||
|
noop_max=noop_max,
|
||||||
|
frame_skip=frame_skip,
|
||||||
|
screen_size=screen_size,
|
||||||
|
terminal_on_life_loss=terminal_on_life_loss,
|
||||||
|
grayscale_obs=grayscale_obs,
|
||||||
|
grayscale_newaxis=grayscale_newaxis,
|
||||||
|
scale_obs=scale_obs,
|
||||||
|
)
|
||||||
|
gym.Wrapper.__init__(self, env)
|
||||||
|
|
||||||
if cv2 is None:
|
if cv2 is None:
|
||||||
raise gym.error.DependencyNotInstalled(
|
raise gym.error.DependencyNotInstalled(
|
||||||
"opencv-python package not installed, run `pip install gymnasium[other]` to get dependencies for atari"
|
"opencv-python package not installed, run `pip install gymnasium[other]` to get dependencies for atari"
|
||||||
|
@@ -2,7 +2,7 @@
|
|||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
|
|
||||||
|
|
||||||
class AutoResetWrapper(gym.Wrapper):
|
class AutoResetWrapper(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||||
"""A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`.
|
"""A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`.
|
||||||
|
|
||||||
When calling step causes :meth:`Env.step` to return `terminated=True` or `truncated=True`, :meth:`Env.reset` is called,
|
When calling step causes :meth:`Env.step` to return `terminated=True` or `truncated=True`, :meth:`Env.reset` is called,
|
||||||
@@ -31,7 +31,8 @@ class AutoResetWrapper(gym.Wrapper):
|
|||||||
Args:
|
Args:
|
||||||
env (gym.Env): The environment to apply the wrapper
|
env (gym.Env): The environment to apply the wrapper
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
gym.utils.RecordConstructorArgs.__init__(self)
|
||||||
|
gym.Wrapper.__init__(self, env)
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
"""Steps through the environment with action and resets the environment if a terminated or truncated signal is encountered.
|
"""Steps through the environment with action and resets the environment if a terminated or truncated signal is encountered.
|
||||||
|
@@ -2,11 +2,10 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium import ActionWrapper
|
|
||||||
from gymnasium.spaces import Box
|
from gymnasium.spaces import Box
|
||||||
|
|
||||||
|
|
||||||
class ClipAction(ActionWrapper):
|
class ClipAction(gym.ActionWrapper, gym.utils.RecordConstructorArgs):
|
||||||
"""Clip the continuous action within the valid :class:`Box` observation space bound.
|
"""Clip the continuous action within the valid :class:`Box` observation space bound.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
@@ -28,7 +27,9 @@ class ClipAction(ActionWrapper):
|
|||||||
env: The environment to apply the wrapper
|
env: The environment to apply the wrapper
|
||||||
"""
|
"""
|
||||||
assert isinstance(env.action_space, Box)
|
assert isinstance(env.action_space, Box)
|
||||||
super().__init__(env)
|
|
||||||
|
gym.utils.RecordConstructorArgs.__init__(self)
|
||||||
|
gym.ActionWrapper.__init__(self, env)
|
||||||
|
|
||||||
def action(self, action):
|
def action(self, action):
|
||||||
"""Clips the action within the valid bounds.
|
"""Clips the action within the valid bounds.
|
||||||
|
@@ -68,11 +68,12 @@ class EnvCompatibility(gym.Env):
|
|||||||
"The `gymnasium.make(..., apply_api_compatibility=...)` parameter is deprecated and will be removed in v0.29. "
|
"The `gymnasium.make(..., apply_api_compatibility=...)` parameter is deprecated and will be removed in v0.29. "
|
||||||
"Instead use `gym.make('GymV21Environment-v0', env_name=...)` or `from shimmy import GymV21CompatibilityV0`"
|
"Instead use `gym.make('GymV21Environment-v0', env_name=...)` or `from shimmy import GymV21CompatibilityV0`"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.env = old_env
|
||||||
self.metadata = getattr(old_env, "metadata", {"render_modes": []})
|
self.metadata = getattr(old_env, "metadata", {"render_modes": []})
|
||||||
self.render_mode = render_mode
|
self.render_mode = render_mode
|
||||||
self.reward_range = getattr(old_env, "reward_range", None)
|
self.reward_range = getattr(old_env, "reward_range", None)
|
||||||
self.spec = getattr(old_env, "spec", None)
|
self.spec = getattr(old_env, "spec", None)
|
||||||
self.env = old_env
|
|
||||||
|
|
||||||
self.observation_space = old_env.observation_space
|
self.observation_space = old_env.observation_space
|
||||||
self.action_space = old_env.action_space
|
self.action_space = old_env.action_space
|
||||||
|
@@ -10,12 +10,13 @@ from gymnasium.utils.passive_env_checker import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class PassiveEnvChecker(gym.Wrapper):
|
class PassiveEnvChecker(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||||
"""A passive environment checker wrapper that surrounds the step, reset and render functions to check they follow the gymnasium API."""
|
"""A passive environment checker wrapper that surrounds the step, reset and render functions to check they follow the gymnasium API."""
|
||||||
|
|
||||||
def __init__(self, env):
|
def __init__(self, env):
|
||||||
"""Initialises the wrapper with the environments, run the observation and action space tests."""
|
"""Initialises the wrapper with the environments, run the observation and action space tests."""
|
||||||
super().__init__(env)
|
gym.utils.RecordConstructorArgs.__init__(self)
|
||||||
|
gym.Wrapper.__init__(self, env)
|
||||||
|
|
||||||
assert hasattr(
|
assert hasattr(
|
||||||
env, "action_space"
|
env, "action_space"
|
||||||
|
@@ -6,7 +6,7 @@ import gymnasium as gym
|
|||||||
from gymnasium import spaces
|
from gymnasium import spaces
|
||||||
|
|
||||||
|
|
||||||
class FilterObservation(gym.ObservationWrapper):
|
class FilterObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
|
||||||
"""Filter Dict observation space by the keys.
|
"""Filter Dict observation space by the keys.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
@@ -35,7 +35,8 @@ class FilterObservation(gym.ObservationWrapper):
|
|||||||
ValueError: If the environment's observation space is not :class:`spaces.Dict`
|
ValueError: If the environment's observation space is not :class:`spaces.Dict`
|
||||||
ValueError: If any of the `filter_keys` are not included in the original `env`'s observation space
|
ValueError: If any of the `filter_keys` are not included in the original `env`'s observation space
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
gym.utils.RecordConstructorArgs.__init__(self, filter_keys=filter_keys)
|
||||||
|
gym.ObservationWrapper.__init__(self, env)
|
||||||
|
|
||||||
wrapped_observation_space = env.observation_space
|
wrapped_observation_space = env.observation_space
|
||||||
if not isinstance(wrapped_observation_space, spaces.Dict):
|
if not isinstance(wrapped_observation_space, spaces.Dict):
|
||||||
|
@@ -3,7 +3,7 @@ import gymnasium as gym
|
|||||||
from gymnasium import spaces
|
from gymnasium import spaces
|
||||||
|
|
||||||
|
|
||||||
class FlattenObservation(gym.ObservationWrapper):
|
class FlattenObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
|
||||||
"""Observation wrapper that flattens the observation.
|
"""Observation wrapper that flattens the observation.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
@@ -26,7 +26,9 @@ class FlattenObservation(gym.ObservationWrapper):
|
|||||||
Args:
|
Args:
|
||||||
env: The environment to apply the wrapper
|
env: The environment to apply the wrapper
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
gym.utils.RecordConstructorArgs.__init__(self)
|
||||||
|
gym.ObservationWrapper.__init__(self, env)
|
||||||
|
|
||||||
self.observation_space = spaces.flatten_space(env.observation_space)
|
self.observation_space = spaces.flatten_space(env.observation_space)
|
||||||
|
|
||||||
def observation(self, observation):
|
def observation(self, observation):
|
||||||
|
@@ -97,7 +97,7 @@ class LazyFrames:
|
|||||||
return frame
|
return frame
|
||||||
|
|
||||||
|
|
||||||
class FrameStack(gym.ObservationWrapper):
|
class FrameStack(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
|
||||||
"""Observation wrapper that stacks the observations in a rolling manner.
|
"""Observation wrapper that stacks the observations in a rolling manner.
|
||||||
|
|
||||||
For example, if the number of stacks is 4, then the returned observation contains
|
For example, if the number of stacks is 4, then the returned observation contains
|
||||||
@@ -137,7 +137,11 @@ class FrameStack(gym.ObservationWrapper):
|
|||||||
num_stack (int): The number of frames to stack
|
num_stack (int): The number of frames to stack
|
||||||
lz4_compress (bool): Use lz4 to compress the frames internally
|
lz4_compress (bool): Use lz4 to compress the frames internally
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
gym.utils.RecordConstructorArgs.__init__(
|
||||||
|
self, num_stack=num_stack, lz4_compress=lz4_compress
|
||||||
|
)
|
||||||
|
gym.ObservationWrapper.__init__(self, env)
|
||||||
|
|
||||||
self.num_stack = num_stack
|
self.num_stack = num_stack
|
||||||
self.lz4_compress = lz4_compress
|
self.lz4_compress = lz4_compress
|
||||||
|
|
||||||
|
@@ -5,7 +5,7 @@ import gymnasium as gym
|
|||||||
from gymnasium.spaces import Box
|
from gymnasium.spaces import Box
|
||||||
|
|
||||||
|
|
||||||
class GrayScaleObservation(gym.ObservationWrapper):
|
class GrayScaleObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
|
||||||
"""Convert the image observation from RGB to gray scale.
|
"""Convert the image observation from RGB to gray scale.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
@@ -30,7 +30,9 @@ class GrayScaleObservation(gym.ObservationWrapper):
|
|||||||
keep_dim (bool): If `True`, a singleton dimension will be added, i.e. observations are of the shape AxBx1.
|
keep_dim (bool): If `True`, a singleton dimension will be added, i.e. observations are of the shape AxBx1.
|
||||||
Otherwise, they are of shape AxB.
|
Otherwise, they are of shape AxB.
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
gym.utils.RecordConstructorArgs.__init__(self, keep_dim=keep_dim)
|
||||||
|
gym.ObservationWrapper.__init__(self, env)
|
||||||
|
|
||||||
self.keep_dim = keep_dim
|
self.keep_dim = keep_dim
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
|
@@ -7,7 +7,7 @@ import gymnasium as gym
|
|||||||
from gymnasium.error import DependencyNotInstalled
|
from gymnasium.error import DependencyNotInstalled
|
||||||
|
|
||||||
|
|
||||||
class HumanRendering(gym.Wrapper):
|
class HumanRendering(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||||
"""Performs human rendering for an environment that only supports "rgb_array"rendering.
|
"""Performs human rendering for an environment that only supports "rgb_array"rendering.
|
||||||
|
|
||||||
This wrapper is particularly useful when you have implemented an environment that can produce
|
This wrapper is particularly useful when you have implemented an environment that can produce
|
||||||
@@ -47,7 +47,9 @@ class HumanRendering(gym.Wrapper):
|
|||||||
Args:
|
Args:
|
||||||
env: The environment that is being wrapped
|
env: The environment that is being wrapped
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
gym.utils.RecordConstructorArgs.__init__(self)
|
||||||
|
gym.Wrapper.__init__(self, env)
|
||||||
|
|
||||||
assert env.render_mode in [
|
assert env.render_mode in [
|
||||||
"rgb_array",
|
"rgb_array",
|
||||||
"rgb_array_list",
|
"rgb_array_list",
|
||||||
@@ -64,6 +66,8 @@ class HumanRendering(gym.Wrapper):
|
|||||||
if "human" not in self.metadata["render_modes"]:
|
if "human" not in self.metadata["render_modes"]:
|
||||||
self.metadata["render_modes"].append("human")
|
self.metadata["render_modes"].append("human")
|
||||||
|
|
||||||
|
gym.utils.RecordConstructorArgs.__init__(self)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def render_mode(self):
|
def render_mode(self):
|
||||||
"""Always returns ``'human'``."""
|
"""Always returns ``'human'``."""
|
||||||
|
@@ -45,7 +45,7 @@ def update_mean_var_count_from_moments(
|
|||||||
return new_mean, new_var, new_count
|
return new_mean, new_var, new_count
|
||||||
|
|
||||||
|
|
||||||
class NormalizeObservation(gym.Wrapper):
|
class NormalizeObservation(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||||
"""This wrapper will normalize observations s.t. each coordinate is centered with unit variance.
|
"""This wrapper will normalize observations s.t. each coordinate is centered with unit variance.
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
@@ -60,7 +60,9 @@ class NormalizeObservation(gym.Wrapper):
|
|||||||
env (Env): The environment to apply the wrapper
|
env (Env): The environment to apply the wrapper
|
||||||
epsilon: A stability parameter that is used when scaling the observations.
|
epsilon: A stability parameter that is used when scaling the observations.
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
gym.utils.RecordConstructorArgs.__init__(self, epsilon=epsilon)
|
||||||
|
gym.Wrapper.__init__(self, env)
|
||||||
|
|
||||||
self.num_envs = getattr(env, "num_envs", 1)
|
self.num_envs = getattr(env, "num_envs", 1)
|
||||||
self.is_vector_env = getattr(env, "is_vector_env", False)
|
self.is_vector_env = getattr(env, "is_vector_env", False)
|
||||||
if self.is_vector_env:
|
if self.is_vector_env:
|
||||||
@@ -93,7 +95,7 @@ class NormalizeObservation(gym.Wrapper):
|
|||||||
return (obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.epsilon)
|
return (obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.epsilon)
|
||||||
|
|
||||||
|
|
||||||
class NormalizeReward(gym.core.Wrapper):
|
class NormalizeReward(gym.core.Wrapper, gym.utils.RecordConstructorArgs):
|
||||||
r"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance.
|
r"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance.
|
||||||
|
|
||||||
The exponential moving average will have variance :math:`(1 - \gamma)^2`.
|
The exponential moving average will have variance :math:`(1 - \gamma)^2`.
|
||||||
@@ -116,7 +118,9 @@ class NormalizeReward(gym.core.Wrapper):
|
|||||||
epsilon (float): A stability parameter
|
epsilon (float): A stability parameter
|
||||||
gamma (float): The discount factor that is used in the exponential moving average.
|
gamma (float): The discount factor that is used in the exponential moving average.
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
gym.utils.RecordConstructorArgs.__init__(self, gamma=gamma, epsilon=epsilon)
|
||||||
|
gym.Wrapper.__init__(self, env)
|
||||||
|
|
||||||
self.num_envs = getattr(env, "num_envs", 1)
|
self.num_envs = getattr(env, "num_envs", 1)
|
||||||
self.is_vector_env = getattr(env, "is_vector_env", False)
|
self.is_vector_env = getattr(env, "is_vector_env", False)
|
||||||
self.return_rms = RunningMeanStd(shape=())
|
self.return_rms = RunningMeanStd(shape=())
|
||||||
|
@@ -3,7 +3,7 @@ import gymnasium as gym
|
|||||||
from gymnasium.error import ResetNeeded
|
from gymnasium.error import ResetNeeded
|
||||||
|
|
||||||
|
|
||||||
class OrderEnforcing(gym.Wrapper):
|
class OrderEnforcing(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||||
"""A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`.
|
"""A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
@@ -32,7 +32,11 @@ class OrderEnforcing(gym.Wrapper):
|
|||||||
env: The environment to wrap
|
env: The environment to wrap
|
||||||
disable_render_order_enforcing: If to disable render order enforcing
|
disable_render_order_enforcing: If to disable render order enforcing
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
gym.utils.RecordConstructorArgs.__init__(
|
||||||
|
self, disable_render_order_enforcing=disable_render_order_enforcing
|
||||||
|
)
|
||||||
|
gym.Wrapper.__init__(self, env)
|
||||||
|
|
||||||
self._has_reset: bool = False
|
self._has_reset: bool = False
|
||||||
self._disable_render_order_enforcing: bool = disable_render_order_enforcing
|
self._disable_render_order_enforcing: bool = disable_render_order_enforcing
|
||||||
|
|
||||||
|
@@ -13,7 +13,7 @@ from gymnasium import spaces
|
|||||||
STATE_KEY = "state"
|
STATE_KEY = "state"
|
||||||
|
|
||||||
|
|
||||||
class PixelObservationWrapper(gym.ObservationWrapper):
|
class PixelObservationWrapper(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
|
||||||
"""Augment observations by pixel values.
|
"""Augment observations by pixel values.
|
||||||
|
|
||||||
Observations of this wrapper will be dictionaries of images.
|
Observations of this wrapper will be dictionaries of images.
|
||||||
@@ -79,7 +79,13 @@ class PixelObservationWrapper(gym.ObservationWrapper):
|
|||||||
specified ``pixel_keys``.
|
specified ``pixel_keys``.
|
||||||
TypeError: When an unexpected pixel type is used
|
TypeError: When an unexpected pixel type is used
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
gym.utils.RecordConstructorArgs.__init__(
|
||||||
|
self,
|
||||||
|
pixels_only=pixels_only,
|
||||||
|
render_kwargs=render_kwargs,
|
||||||
|
pixel_keys=pixel_keys,
|
||||||
|
)
|
||||||
|
gym.ObservationWrapper.__init__(self, env)
|
||||||
|
|
||||||
# Avoid side-effects that occur when render_kwargs is manipulated
|
# Avoid side-effects that occur when render_kwargs is manipulated
|
||||||
render_kwargs = copy.deepcopy(render_kwargs)
|
render_kwargs = copy.deepcopy(render_kwargs)
|
||||||
|
@@ -8,7 +8,7 @@ import numpy as np
|
|||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
|
|
||||||
|
|
||||||
class RecordEpisodeStatistics(gym.Wrapper):
|
class RecordEpisodeStatistics(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||||
"""This wrapper will keep track of cumulative rewards and episode lengths.
|
"""This wrapper will keep track of cumulative rewards and episode lengths.
|
||||||
|
|
||||||
At the end of an episode, the statistics of the episode will be added to ``info``
|
At the end of an episode, the statistics of the episode will be added to ``info``
|
||||||
@@ -56,7 +56,9 @@ class RecordEpisodeStatistics(gym.Wrapper):
|
|||||||
env (Env): The environment to apply the wrapper
|
env (Env): The environment to apply the wrapper
|
||||||
deque_size: The size of the buffers :attr:`return_queue` and :attr:`length_queue`
|
deque_size: The size of the buffers :attr:`return_queue` and :attr:`length_queue`
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
gym.utils.RecordConstructorArgs.__init__(self, deque_size=deque_size)
|
||||||
|
gym.Wrapper.__init__(self, env)
|
||||||
|
|
||||||
self.num_envs = getattr(env, "num_envs", 1)
|
self.num_envs = getattr(env, "num_envs", 1)
|
||||||
self.episode_count = 0
|
self.episode_count = 0
|
||||||
self.episode_start_times: np.ndarray = None
|
self.episode_start_times: np.ndarray = None
|
||||||
|
@@ -24,7 +24,7 @@ def capped_cubic_video_schedule(episode_id: int) -> bool:
|
|||||||
return episode_id % 1000 == 0
|
return episode_id % 1000 == 0
|
||||||
|
|
||||||
|
|
||||||
class RecordVideo(gym.Wrapper):
|
class RecordVideo(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||||
"""This wrapper records videos of rollouts.
|
"""This wrapper records videos of rollouts.
|
||||||
|
|
||||||
Usually, you only want to record episodes intermittently, say every hundredth episode.
|
Usually, you only want to record episodes intermittently, say every hundredth episode.
|
||||||
@@ -58,9 +58,17 @@ class RecordVideo(gym.Wrapper):
|
|||||||
Otherwise, snippets of the specified length are captured
|
Otherwise, snippets of the specified length are captured
|
||||||
name_prefix (str): Will be prepended to the filename of the recordings
|
name_prefix (str): Will be prepended to the filename of the recordings
|
||||||
disable_logger (bool): Whether to disable moviepy logger or not.
|
disable_logger (bool): Whether to disable moviepy logger or not.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
gym.utils.RecordConstructorArgs.__init__(
|
||||||
|
self,
|
||||||
|
video_folder=video_folder,
|
||||||
|
episode_trigger=episode_trigger,
|
||||||
|
step_trigger=step_trigger,
|
||||||
|
video_length=video_length,
|
||||||
|
name_prefix=name_prefix,
|
||||||
|
disable_logger=disable_logger,
|
||||||
|
)
|
||||||
|
gym.Wrapper.__init__(self, env)
|
||||||
|
|
||||||
if episode_trigger is None and step_trigger is None:
|
if episode_trigger is None and step_trigger is None:
|
||||||
episode_trigger = capped_cubic_video_schedule
|
episode_trigger = capped_cubic_video_schedule
|
||||||
|
@@ -4,7 +4,7 @@ import copy
|
|||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
|
|
||||||
|
|
||||||
class RenderCollection(gym.Wrapper):
|
class RenderCollection(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||||
"""Save collection of render frames."""
|
"""Save collection of render frames."""
|
||||||
|
|
||||||
def __init__(self, env: gym.Env, pop_frames: bool = True, reset_clean: bool = True):
|
def __init__(self, env: gym.Env, pop_frames: bool = True, reset_clean: bool = True):
|
||||||
@@ -17,7 +17,11 @@ class RenderCollection(gym.Wrapper):
|
|||||||
reset_clean (bool): If true, clear the collection frames when .reset() is called.
|
reset_clean (bool): If true, clear the collection frames when .reset() is called.
|
||||||
Default value is True.
|
Default value is True.
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
gym.utils.RecordConstructorArgs.__init__(
|
||||||
|
self, pop_frames=pop_frames, reset_clean=reset_clean
|
||||||
|
)
|
||||||
|
gym.Wrapper.__init__(self, env)
|
||||||
|
|
||||||
assert env.render_mode is not None
|
assert env.render_mode is not None
|
||||||
assert not env.render_mode.endswith("_list")
|
assert not env.render_mode.endswith("_list")
|
||||||
self.frame_list = []
|
self.frame_list = []
|
||||||
|
@@ -7,7 +7,7 @@ import gymnasium as gym
|
|||||||
from gymnasium.spaces import Box
|
from gymnasium.spaces import Box
|
||||||
|
|
||||||
|
|
||||||
class RescaleAction(gym.ActionWrapper):
|
class RescaleAction(gym.ActionWrapper, gym.utils.RecordConstructorArgs):
|
||||||
"""Affinely rescales the continuous action space of the environment to the range [min_action, max_action].
|
"""Affinely rescales the continuous action space of the environment to the range [min_action, max_action].
|
||||||
|
|
||||||
The base environment :attr:`env` must have an action space of type :class:`spaces.Box`. If :attr:`min_action`
|
The base environment :attr:`env` must have an action space of type :class:`spaces.Box`. If :attr:`min_action`
|
||||||
@@ -47,7 +47,11 @@ class RescaleAction(gym.ActionWrapper):
|
|||||||
), f"expected Box action space, got {type(env.action_space)}"
|
), f"expected Box action space, got {type(env.action_space)}"
|
||||||
assert np.less_equal(min_action, max_action).all(), (min_action, max_action)
|
assert np.less_equal(min_action, max_action).all(), (min_action, max_action)
|
||||||
|
|
||||||
super().__init__(env)
|
gym.utils.RecordConstructorArgs.__init__(
|
||||||
|
self, min_action=min_action, max_action=max_action
|
||||||
|
)
|
||||||
|
gym.ActionWrapper.__init__(self, env)
|
||||||
|
|
||||||
self.min_action = (
|
self.min_action = (
|
||||||
np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + min_action
|
np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + min_action
|
||||||
)
|
)
|
||||||
|
@@ -8,7 +8,7 @@ from gymnasium.error import DependencyNotInstalled
|
|||||||
from gymnasium.spaces import Box
|
from gymnasium.spaces import Box
|
||||||
|
|
||||||
|
|
||||||
class ResizeObservation(gym.ObservationWrapper):
|
class ResizeObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
|
||||||
"""Resize the image observation.
|
"""Resize the image observation.
|
||||||
|
|
||||||
This wrapper works on environments with image observations. More generally,
|
This wrapper works on environments with image observations. More generally,
|
||||||
@@ -36,7 +36,9 @@ class ResizeObservation(gym.ObservationWrapper):
|
|||||||
env: The environment to apply the wrapper
|
env: The environment to apply the wrapper
|
||||||
shape: The shape of the resized observations
|
shape: The shape of the resized observations
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
gym.utils.RecordConstructorArgs.__init__(self, shape=shape)
|
||||||
|
gym.ObservationWrapper.__init__(self, env)
|
||||||
|
|
||||||
if isinstance(shape, int):
|
if isinstance(shape, int):
|
||||||
shape = (shape, shape)
|
shape = (shape, shape)
|
||||||
assert len(shape) == 2 and all(
|
assert len(shape) == 2 and all(
|
||||||
|
@@ -4,7 +4,7 @@ from gymnasium.logger import deprecation
|
|||||||
from gymnasium.utils.step_api_compatibility import step_api_compatibility
|
from gymnasium.utils.step_api_compatibility import step_api_compatibility
|
||||||
|
|
||||||
|
|
||||||
class StepAPICompatibility(gym.Wrapper):
|
class StepAPICompatibility(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||||
r"""A wrapper which can transform an environment from new step API to old and vice-versa.
|
r"""A wrapper which can transform an environment from new step API to old and vice-versa.
|
||||||
|
|
||||||
Old step API refers to step() method returning (observation, reward, done, info)
|
Old step API refers to step() method returning (observation, reward, done, info)
|
||||||
@@ -29,7 +29,11 @@ class StepAPICompatibility(gym.Wrapper):
|
|||||||
env (gym.Env): the env to wrap. Can be in old or new API
|
env (gym.Env): the env to wrap. Can be in old or new API
|
||||||
output_truncation_bool (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
|
output_truncation_bool (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
gym.utils.RecordConstructorArgs.__init__(
|
||||||
|
self, output_truncation_bool=output_truncation_bool
|
||||||
|
)
|
||||||
|
gym.Wrapper.__init__(self, env)
|
||||||
|
|
||||||
self.is_vector_env = isinstance(env.unwrapped, gym.vector.VectorEnv)
|
self.is_vector_env = isinstance(env.unwrapped, gym.vector.VectorEnv)
|
||||||
self.output_truncation_bool = output_truncation_bool
|
self.output_truncation_bool = output_truncation_bool
|
||||||
if not self.output_truncation_bool:
|
if not self.output_truncation_bool:
|
||||||
|
@@ -5,7 +5,7 @@ import gymnasium as gym
|
|||||||
from gymnasium.spaces import Box
|
from gymnasium.spaces import Box
|
||||||
|
|
||||||
|
|
||||||
class TimeAwareObservation(gym.ObservationWrapper):
|
class TimeAwareObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
|
||||||
"""Augment the observation with the current time step in the episode.
|
"""Augment the observation with the current time step in the episode.
|
||||||
|
|
||||||
The observation space of the wrapped environment is assumed to be a flat :class:`Box`.
|
The observation space of the wrapped environment is assumed to be a flat :class:`Box`.
|
||||||
@@ -29,7 +29,9 @@ class TimeAwareObservation(gym.ObservationWrapper):
|
|||||||
Args:
|
Args:
|
||||||
env: The environment to apply the wrapper
|
env: The environment to apply the wrapper
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
gym.utils.RecordConstructorArgs.__init__(self)
|
||||||
|
gym.ObservationWrapper.__init__(self, env)
|
||||||
|
|
||||||
assert isinstance(env.observation_space, Box)
|
assert isinstance(env.observation_space, Box)
|
||||||
assert env.observation_space.dtype == np.float32
|
assert env.observation_space.dtype == np.float32
|
||||||
low = np.append(self.observation_space.low, 0.0)
|
low = np.append(self.observation_space.low, 0.0)
|
||||||
|
@@ -4,7 +4,7 @@ from typing import Optional
|
|||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
|
|
||||||
|
|
||||||
class TimeLimit(gym.Wrapper):
|
class TimeLimit(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||||
"""This wrapper will issue a `truncated` signal if a maximum number of timesteps is exceeded.
|
"""This wrapper will issue a `truncated` signal if a maximum number of timesteps is exceeded.
|
||||||
|
|
||||||
If a truncation is not defined inside the environment itself, this is the only place that the truncation signal is issued.
|
If a truncation is not defined inside the environment itself, this is the only place that the truncation signal is issued.
|
||||||
@@ -26,14 +26,16 @@ class TimeLimit(gym.Wrapper):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
env: The environment to apply the wrapper
|
env: The environment to apply the wrapper
|
||||||
max_episode_steps: An optional max episode steps (if ``Ǹone``, ``env.spec.max_episode_steps`` is used)
|
max_episode_steps: An optional max episode steps (if ``None``, ``env.spec.max_episode_steps`` is used)
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
gym.utils.RecordConstructorArgs.__init__(
|
||||||
|
self, max_episode_steps=max_episode_steps
|
||||||
|
)
|
||||||
|
gym.Wrapper.__init__(self, env)
|
||||||
|
|
||||||
if max_episode_steps is None and self.env.spec is not None:
|
if max_episode_steps is None and self.env.spec is not None:
|
||||||
assert env.spec is not None
|
assert env.spec is not None
|
||||||
max_episode_steps = env.spec.max_episode_steps
|
max_episode_steps = env.spec.max_episode_steps
|
||||||
if self.env.spec is not None:
|
|
||||||
self.env.spec.max_episode_steps = max_episode_steps
|
|
||||||
self._max_episode_steps = max_episode_steps
|
self._max_episode_steps = max_episode_steps
|
||||||
self._elapsed_steps = None
|
self._elapsed_steps = None
|
||||||
|
|
||||||
|
@@ -4,7 +4,7 @@ from typing import Any, Callable
|
|||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
|
|
||||||
|
|
||||||
class TransformObservation(gym.ObservationWrapper):
|
class TransformObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
|
||||||
"""Transform the observation via an arbitrary function :attr:`f`.
|
"""Transform the observation via an arbitrary function :attr:`f`.
|
||||||
|
|
||||||
The function :attr:`f` should be defined on the observation space of the base environment, ``env``, and should, ideally, return values in the same space.
|
The function :attr:`f` should be defined on the observation space of the base environment, ``env``, and should, ideally, return values in the same space.
|
||||||
@@ -29,7 +29,9 @@ class TransformObservation(gym.ObservationWrapper):
|
|||||||
env: The environment to apply the wrapper
|
env: The environment to apply the wrapper
|
||||||
f: A function that transforms the observation
|
f: A function that transforms the observation
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
gym.utils.RecordConstructorArgs.__init__(self, f=f)
|
||||||
|
gym.ObservationWrapper.__init__(self, env)
|
||||||
|
|
||||||
assert callable(f)
|
assert callable(f)
|
||||||
self.f = f
|
self.f = f
|
||||||
|
|
||||||
|
@@ -2,10 +2,9 @@
|
|||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium import RewardWrapper
|
|
||||||
|
|
||||||
|
|
||||||
class TransformReward(RewardWrapper):
|
class TransformReward(gym.RewardWrapper, gym.utils.RecordConstructorArgs):
|
||||||
"""Transform the reward via an arbitrary function.
|
"""Transform the reward via an arbitrary function.
|
||||||
|
|
||||||
Warning:
|
Warning:
|
||||||
@@ -29,7 +28,9 @@ class TransformReward(RewardWrapper):
|
|||||||
env: The environment to apply the wrapper
|
env: The environment to apply the wrapper
|
||||||
f: A function that transforms the reward
|
f: A function that transforms the reward
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
gym.utils.RecordConstructorArgs.__init__(self, f=f)
|
||||||
|
gym.RewardWrapper.__init__(self, env)
|
||||||
|
|
||||||
assert callable(f)
|
assert callable(f)
|
||||||
self.f = f
|
self.f = f
|
||||||
|
|
||||||
|
@@ -5,7 +5,7 @@ from typing import List
|
|||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
|
|
||||||
|
|
||||||
class VectorListInfo(gym.Wrapper):
|
class VectorListInfo(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||||
"""Converts infos of vectorized environments from dict to List[dict].
|
"""Converts infos of vectorized environments from dict to List[dict].
|
||||||
|
|
||||||
This wrapper converts the info format of a
|
This wrapper converts the info format of a
|
||||||
@@ -51,7 +51,9 @@ class VectorListInfo(gym.Wrapper):
|
|||||||
assert getattr(
|
assert getattr(
|
||||||
env, "is_vector_env", False
|
env, "is_vector_env", False
|
||||||
), "This wrapper can only be used in vectorized environments."
|
), "This wrapper can only be used in vectorized environments."
|
||||||
super().__init__(env)
|
|
||||||
|
gym.utils.RecordConstructorArgs.__init__(self)
|
||||||
|
gym.Wrapper.__init__(self, env)
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
"""Steps through the environment, convert dict info to list."""
|
"""Steps through the environment, convert dict info to list."""
|
||||||
|
237
tests/envs/registration/test_env_spec.py
Normal file
237
tests/envs/registration/test_env_spec.py
Normal file
@@ -0,0 +1,237 @@
|
|||||||
|
"""Example file showing usage of env.specstack."""
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
from gymnasium.envs.classic_control import CartPoleEnv
|
||||||
|
from gymnasium.envs.registration import EnvSpec
|
||||||
|
from gymnasium.utils.env_checker import data_equivalence
|
||||||
|
|
||||||
|
|
||||||
|
def test_full_integration():
|
||||||
|
# Create an environment to test with
|
||||||
|
env = gym.make("CartPole-v1", render_mode="rgb_array").unwrapped
|
||||||
|
|
||||||
|
env = gym.wrappers.FlattenObservation(env)
|
||||||
|
env = gym.wrappers.TimeAwareObservation(env)
|
||||||
|
env = gym.wrappers.NormalizeReward(env, gamma=0.8)
|
||||||
|
|
||||||
|
# Generate the spec_stack
|
||||||
|
env_spec = env.spec
|
||||||
|
assert isinstance(env_spec, EnvSpec)
|
||||||
|
# env_spec.pprint()
|
||||||
|
|
||||||
|
# Serialize the spec_stack
|
||||||
|
env_spec_json = env_spec.to_json()
|
||||||
|
assert isinstance(env_spec_json, str)
|
||||||
|
|
||||||
|
# Deserialize the spec_stack
|
||||||
|
recreate_env_spec = EnvSpec.from_json(env_spec_json)
|
||||||
|
# recreate_env_spec.pprint()
|
||||||
|
|
||||||
|
for wrapper_spec, recreated_wrapper_spec in zip(
|
||||||
|
env_spec.applied_wrappers, recreate_env_spec.applied_wrappers
|
||||||
|
):
|
||||||
|
assert wrapper_spec == recreated_wrapper_spec
|
||||||
|
assert recreate_env_spec == env_spec
|
||||||
|
|
||||||
|
# Recreate the environment using the spec_stack
|
||||||
|
recreated_env = gym.make(recreate_env_spec)
|
||||||
|
assert recreated_env.render_mode == "rgb_array"
|
||||||
|
assert isinstance(recreated_env, gym.wrappers.NormalizeReward)
|
||||||
|
assert recreated_env.gamma == 0.8
|
||||||
|
assert isinstance(recreated_env.env, gym.wrappers.TimeAwareObservation)
|
||||||
|
assert isinstance(recreated_env.unwrapped, CartPoleEnv)
|
||||||
|
|
||||||
|
obs, info = env.reset(seed=42)
|
||||||
|
recreated_obs, recreated_info = recreated_env.reset(seed=42)
|
||||||
|
assert data_equivalence(obs, recreated_obs)
|
||||||
|
assert data_equivalence(info, recreated_info)
|
||||||
|
|
||||||
|
action = env.action_space.sample()
|
||||||
|
obs, reward, terminated, truncated, info = env.step(action)
|
||||||
|
(
|
||||||
|
recreated_obs,
|
||||||
|
recreated_reward,
|
||||||
|
recreated_terminated,
|
||||||
|
recreated_truncated,
|
||||||
|
recreated_info,
|
||||||
|
) = recreated_env.step(action)
|
||||||
|
assert data_equivalence(obs, recreated_obs)
|
||||||
|
assert data_equivalence(reward, recreated_reward)
|
||||||
|
assert data_equivalence(terminated, recreated_terminated)
|
||||||
|
assert data_equivalence(truncated, recreated_truncated)
|
||||||
|
assert data_equivalence(info, recreated_info)
|
||||||
|
|
||||||
|
# Test the pprint of the spec_stack
|
||||||
|
spec_stack_output = env_spec.pprint(disable_print=True)
|
||||||
|
json_spec_stack_output = env_spec.pprint(disable_print=True)
|
||||||
|
assert spec_stack_output == json_spec_stack_output
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"env_spec",
|
||||||
|
[
|
||||||
|
gym.spec("CartPole-v1"),
|
||||||
|
gym.make("CartPole-v1").unwrapped.spec,
|
||||||
|
gym.make("CartPole-v1").spec,
|
||||||
|
gym.wrappers.NormalizeReward(gym.make("CartPole-v1")).spec,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_env_spec_to_from_json(env_spec: EnvSpec):
|
||||||
|
json_spec = env_spec.to_json()
|
||||||
|
recreated_env_spec = EnvSpec.from_json(json_spec)
|
||||||
|
|
||||||
|
assert env_spec == recreated_env_spec
|
||||||
|
|
||||||
|
|
||||||
|
def test_wrapped_env_entry_point():
|
||||||
|
def _create_env():
|
||||||
|
_env = gym.make("CartPole-v1", render_mode="rgb_array")
|
||||||
|
_env = gym.wrappers.FlattenObservation(_env)
|
||||||
|
return _env
|
||||||
|
|
||||||
|
gym.register("TestingEnv-v0", entry_point=_create_env)
|
||||||
|
|
||||||
|
env = gym.make("TestingEnv-v0")
|
||||||
|
env = gym.wrappers.TimeAwareObservation(env)
|
||||||
|
env = gym.wrappers.NormalizeReward(env, gamma=0.8)
|
||||||
|
|
||||||
|
recreated_env = gym.make(env.spec)
|
||||||
|
|
||||||
|
obs, info = env.reset(seed=42)
|
||||||
|
recreated_obs, recreated_info = recreated_env.reset(seed=42)
|
||||||
|
assert data_equivalence(obs, recreated_obs)
|
||||||
|
assert data_equivalence(info, recreated_info)
|
||||||
|
|
||||||
|
action = env.action_space.sample()
|
||||||
|
obs, reward, terminated, truncated, info = env.step(action)
|
||||||
|
(
|
||||||
|
recreated_obs,
|
||||||
|
recreated_reward,
|
||||||
|
recreated_terminated,
|
||||||
|
recreated_truncated,
|
||||||
|
recreated_info,
|
||||||
|
) = recreated_env.step(action)
|
||||||
|
assert data_equivalence(obs, recreated_obs)
|
||||||
|
assert data_equivalence(reward, recreated_reward)
|
||||||
|
assert data_equivalence(terminated, recreated_terminated)
|
||||||
|
assert data_equivalence(truncated, recreated_truncated)
|
||||||
|
assert data_equivalence(info, recreated_info)
|
||||||
|
|
||||||
|
del gym.registry["TestingEnv-v0"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_pickling_env_stack():
|
||||||
|
env = gym.make("CartPole-v1", render_mode="rgb_array")
|
||||||
|
|
||||||
|
env = gym.wrappers.FlattenObservation(env)
|
||||||
|
env = gym.wrappers.TimeAwareObservation(env)
|
||||||
|
env = gym.wrappers.NormalizeReward(env, gamma=0.8)
|
||||||
|
|
||||||
|
pickled_env = pickle.loads(pickle.dumps(env))
|
||||||
|
|
||||||
|
obs, info = env.reset(seed=123)
|
||||||
|
pickled_obs, pickled_info = pickled_env.reset(seed=123)
|
||||||
|
|
||||||
|
assert data_equivalence(obs, pickled_obs)
|
||||||
|
assert data_equivalence(info, pickled_info)
|
||||||
|
|
||||||
|
action = env.action_space.sample()
|
||||||
|
obs, reward, terminated, truncated, info = env.step(action)
|
||||||
|
(
|
||||||
|
pickled_obs,
|
||||||
|
pickled_reward,
|
||||||
|
pickled_terminated,
|
||||||
|
pickled_truncated,
|
||||||
|
pickled_info,
|
||||||
|
) = pickled_env.step(action)
|
||||||
|
|
||||||
|
assert data_equivalence(obs, pickled_obs)
|
||||||
|
assert data_equivalence(reward, pickled_reward)
|
||||||
|
assert data_equivalence(terminated, pickled_terminated)
|
||||||
|
assert data_equivalence(truncated, pickled_truncated)
|
||||||
|
assert data_equivalence(info, pickled_info)
|
||||||
|
|
||||||
|
env.close()
|
||||||
|
pickled_env.close()
|
||||||
|
|
||||||
|
|
||||||
|
# flake8: noqa
|
||||||
|
|
||||||
|
|
||||||
|
def test_env_spec_pprint():
|
||||||
|
env = gym.make("CartPole-v1")
|
||||||
|
env_spec = env.spec
|
||||||
|
assert env_spec is not None
|
||||||
|
|
||||||
|
output = env_spec.pprint(disable_print=True)
|
||||||
|
assert (
|
||||||
|
output
|
||||||
|
== """id=CartPole-v1
|
||||||
|
reward_threshold=475.0
|
||||||
|
max_episode_steps=500
|
||||||
|
applied_wrappers=[
|
||||||
|
name=PassiveEnvChecker, kwargs={},
|
||||||
|
name=OrderEnforcing, kwargs={'disable_render_order_enforcing': False},
|
||||||
|
name=TimeLimit, kwargs={'max_episode_steps': 500}
|
||||||
|
]"""
|
||||||
|
)
|
||||||
|
|
||||||
|
output = env_spec.pprint(disable_print=True, include_entry_points=True)
|
||||||
|
assert (
|
||||||
|
output
|
||||||
|
== """id=CartPole-v1
|
||||||
|
entry_point=gymnasium.envs.classic_control.cartpole:CartPoleEnv
|
||||||
|
reward_threshold=475.0
|
||||||
|
max_episode_steps=500
|
||||||
|
applied_wrappers=[
|
||||||
|
name=PassiveEnvChecker, entry_point=gymnasium.wrappers.env_checker:PassiveEnvChecker, kwargs={},
|
||||||
|
name=OrderEnforcing, entry_point=gymnasium.wrappers.order_enforcing:OrderEnforcing, kwargs={'disable_render_order_enforcing': False},
|
||||||
|
name=TimeLimit, entry_point=gymnasium.wrappers.time_limit:TimeLimit, kwargs={'max_episode_steps': 500}
|
||||||
|
]"""
|
||||||
|
)
|
||||||
|
|
||||||
|
output = env_spec.pprint(disable_print=True, print_all=True)
|
||||||
|
assert (
|
||||||
|
output
|
||||||
|
== """id=CartPole-v1
|
||||||
|
entry_point=gymnasium.envs.classic_control.cartpole:CartPoleEnv
|
||||||
|
reward_threshold=475.0
|
||||||
|
nondeterministic=False
|
||||||
|
max_episode_steps=500
|
||||||
|
order_enforce=True
|
||||||
|
autoreset=False
|
||||||
|
disable_env_checker=False
|
||||||
|
applied_api_compatibility=False
|
||||||
|
applied_wrappers=[
|
||||||
|
name=PassiveEnvChecker, kwargs={},
|
||||||
|
name=OrderEnforcing, kwargs={'disable_render_order_enforcing': False},
|
||||||
|
name=TimeLimit, kwargs={'max_episode_steps': 500}
|
||||||
|
]"""
|
||||||
|
)
|
||||||
|
|
||||||
|
env_spec.applied_wrappers = ()
|
||||||
|
output = env_spec.pprint(disable_print=True)
|
||||||
|
assert (
|
||||||
|
output
|
||||||
|
== """id=CartPole-v1
|
||||||
|
reward_threshold=475.0
|
||||||
|
max_episode_steps=500"""
|
||||||
|
)
|
||||||
|
|
||||||
|
output = env_spec.pprint(disable_print=True, print_all=True)
|
||||||
|
assert (
|
||||||
|
output
|
||||||
|
== """id=CartPole-v1
|
||||||
|
entry_point=gymnasium.envs.classic_control.cartpole:CartPoleEnv
|
||||||
|
reward_threshold=475.0
|
||||||
|
nondeterministic=False
|
||||||
|
max_episode_steps=500
|
||||||
|
order_enforce=True
|
||||||
|
autoreset=False
|
||||||
|
disable_env_checker=False
|
||||||
|
applied_api_compatibility=False
|
||||||
|
applied_wrappers=[]"""
|
||||||
|
)
|
@@ -8,6 +8,8 @@ import numpy as np
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
|
from gymnasium import Env
|
||||||
|
from gymnasium.core import ActType, ObsType, WrapperObsType
|
||||||
from gymnasium.envs.classic_control import CartPoleEnv
|
from gymnasium.envs.classic_control import CartPoleEnv
|
||||||
from gymnasium.wrappers import (
|
from gymnasium.wrappers import (
|
||||||
AutoResetWrapper,
|
AutoResetWrapper,
|
||||||
@@ -355,3 +357,69 @@ def test_import_module_during_make():
|
|||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
del gym.registry["RegisterDuringMake-v0"]
|
del gym.registry["RegisterDuringMake-v0"]
|
||||||
|
|
||||||
|
|
||||||
|
class NoRecordArgsWrapper(gym.ObservationWrapper):
|
||||||
|
def __init__(self, env: Env[ObsType, ActType]):
|
||||||
|
super().__init__(env)
|
||||||
|
|
||||||
|
def observation(self, observation: ObsType) -> WrapperObsType:
|
||||||
|
return self.observation_space.sample()
|
||||||
|
|
||||||
|
|
||||||
|
def test_make_env_spec():
|
||||||
|
# make
|
||||||
|
env_1 = gym.make(gym.spec("CartPole-v1"))
|
||||||
|
assert isinstance(env_1, CartPoleEnv)
|
||||||
|
assert env_1 is env_1.unwrapped
|
||||||
|
env_1.close()
|
||||||
|
|
||||||
|
# make with applied wrappers
|
||||||
|
env_2 = gym.wrappers.NormalizeReward(
|
||||||
|
gym.wrappers.TimeAwareObservation(
|
||||||
|
gym.wrappers.FlattenObservation(
|
||||||
|
gym.make("CartPole-v1", render_mode="rgb_array")
|
||||||
|
)
|
||||||
|
),
|
||||||
|
gamma=0.8,
|
||||||
|
)
|
||||||
|
env_2_recreated = gym.make(env_2.spec)
|
||||||
|
assert env_2.spec == env_2_recreated.spec
|
||||||
|
env_2.close()
|
||||||
|
env_2_recreated.close()
|
||||||
|
|
||||||
|
# make with callable entry point
|
||||||
|
gym.register("CartPole-v2", lambda: CartPoleEnv())
|
||||||
|
env_3 = gym.make("CartPole-v2")
|
||||||
|
assert isinstance(env_3.unwrapped, CartPoleEnv)
|
||||||
|
env_3.close()
|
||||||
|
|
||||||
|
# make with wrapper in env-creator
|
||||||
|
gym.register(
|
||||||
|
"CartPole-v3", lambda: gym.wrappers.TimeAwareObservation(CartPoleEnv())
|
||||||
|
)
|
||||||
|
env_4 = gym.make(gym.spec("CartPole-v3"))
|
||||||
|
assert isinstance(env_4, gym.wrappers.TimeAwareObservation)
|
||||||
|
assert isinstance(env_4.env, CartPoleEnv)
|
||||||
|
env_4.close()
|
||||||
|
|
||||||
|
# make with no ezpickle wrapper
|
||||||
|
env_5 = NoRecordArgsWrapper(gym.make("CartPole-v1").unwrapped)
|
||||||
|
env_5.close()
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match=re.escape(
|
||||||
|
"NoRecordArgsWrapper wrapper does not inherit from `gymnasium.utils.RecordConstructorArgs`, therefore, the wrapper cannot be recreated."
|
||||||
|
),
|
||||||
|
):
|
||||||
|
gym.make(env_5.spec)
|
||||||
|
|
||||||
|
# make with no ezpickle wrapper but in the entry point
|
||||||
|
gym.register("CartPole-v4", entry_point=lambda: NoRecordArgsWrapper(CartPoleEnv()))
|
||||||
|
env_6 = gym.make(gym.spec("CartPole-v4"))
|
||||||
|
assert isinstance(env_6, NoRecordArgsWrapper)
|
||||||
|
assert isinstance(env_6.unwrapped, CartPoleEnv)
|
||||||
|
|
||||||
|
del gym.registry["CartPole-v2"]
|
||||||
|
del gym.registry["CartPole-v3"]
|
||||||
|
del gym.registry["CartPole-v4"]
|
||||||
|
@@ -22,7 +22,7 @@ def test_mujoco_action_dimensions(env_spec: EnvSpec):
|
|||||||
* Too many dimensions
|
* Too many dimensions
|
||||||
* Incorrect shape
|
* Incorrect shape
|
||||||
"""
|
"""
|
||||||
env = env_spec.make(disable_env_checker=True)
|
env = env_spec.make()
|
||||||
env.reset()
|
env.reset()
|
||||||
|
|
||||||
# Too few actions
|
# Too few actions
|
||||||
|
@@ -42,7 +42,7 @@ CHECK_ENV_IGNORE_WARNINGS = [
|
|||||||
def test_all_env_api(spec):
|
def test_all_env_api(spec):
|
||||||
"""Check that all environments pass the environment checker with no warnings other than the expected."""
|
"""Check that all environments pass the environment checker with no warnings other than the expected."""
|
||||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||||
env = spec.make(disable_env_checker=True).unwrapped
|
env = spec.make().unwrapped
|
||||||
check_env(env, skip_render_check=True)
|
check_env(env, skip_render_check=True)
|
||||||
|
|
||||||
env.close()
|
env.close()
|
||||||
|
@@ -15,8 +15,8 @@ def verify_environments_match(
|
|||||||
):
|
):
|
||||||
"""Verifies with two environment ids (old and new) are identical in obs, reward and done
|
"""Verifies with two environment ids (old and new) are identical in obs, reward and done
|
||||||
(except info where all old info must be contained in new info)."""
|
(except info where all old info must be contained in new info)."""
|
||||||
old_env = envs.make(old_env_id, disable_env_checker=True)
|
old_env = envs.make(old_env_id)
|
||||||
new_env = envs.make(new_env_id, disable_env_checker=True)
|
new_env = envs.make(new_env_id)
|
||||||
|
|
||||||
old_reset_obs, old_info = old_env.reset(seed=seed)
|
old_reset_obs, old_info = old_env.reset(seed=seed)
|
||||||
new_reset_obs, new_info = new_env.reset(seed=seed)
|
new_reset_obs, new_info = new_env.reset(seed=seed)
|
||||||
|
@@ -106,7 +106,7 @@ class TestNestedDictWrapper:
|
|||||||
observation_space = env.observation_space
|
observation_space = env.observation_space
|
||||||
assert isinstance(observation_space, Dict)
|
assert isinstance(observation_space, Dict)
|
||||||
|
|
||||||
wrapped_env = FlattenObservation(FilterObservation(env, env.obs_keys))
|
wrapped_env = FlattenObservation(FilterObservation(env, list(env.obs_keys)))
|
||||||
assert wrapped_env.observation_space.shape == flat_shape
|
assert wrapped_env.observation_space.shape == flat_shape
|
||||||
|
|
||||||
assert wrapped_env.observation_space.dtype == np.float32
|
assert wrapped_env.observation_space.dtype == np.float32
|
||||||
@@ -114,7 +114,7 @@ class TestNestedDictWrapper:
|
|||||||
@pytest.mark.parametrize("observation_space, flat_shape", NESTED_DICT_TEST_CASES)
|
@pytest.mark.parametrize("observation_space, flat_shape", NESTED_DICT_TEST_CASES)
|
||||||
def test_nested_dicts_ravel(self, observation_space, flat_shape):
|
def test_nested_dicts_ravel(self, observation_space, flat_shape):
|
||||||
env = FakeEnvironment(observation_space=observation_space)
|
env = FakeEnvironment(observation_space=observation_space)
|
||||||
wrapped_env = FlattenObservation(FilterObservation(env, env.obs_keys))
|
wrapped_env = FlattenObservation(FilterObservation(env, list(env.obs_keys)))
|
||||||
obs, info = wrapped_env.reset()
|
obs, info = wrapped_env.reset()
|
||||||
assert obs.shape == wrapped_env.observation_space.shape
|
assert obs.shape == wrapped_env.observation_space.shape
|
||||||
assert isinstance(info, dict)
|
assert isinstance(info, dict)
|
||||||
|
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
|
|
||||||
|
|
||||||
def has_wrapper(wrapped_env: gym.Env, wrapper_type: type) -> bool:
|
def has_wrapper(wrapped_env: gym.Env, wrapper_type: type[gym.Wrapper]) -> bool:
|
||||||
while isinstance(wrapped_env, gym.Wrapper):
|
while isinstance(wrapped_env, gym.Wrapper):
|
||||||
if isinstance(wrapped_env, wrapper_type):
|
if isinstance(wrapped_env, wrapper_type):
|
||||||
return True
|
return True
|
||||||
|
Reference in New Issue
Block a user