mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-19 13:32:03 +00:00
Add Pydocstyle to CI (#2785)
* Added pydocstyle to pre-commit * Added docstrings for tests and updated the tests for autoreset * Add pydocstyle exclude folder to allow slowly adding new docstrings * Add docstrings for setup.py and gym/__init__.py, core.py, error.py and logger.py * Check that all unwrapped environment are of a particular wrapper type * Reverted back to import gym.spaces.Space to gym.spaces * Fixed the __init__.py docstring * Fixed autoreset autoreset test * Updated gym __init__.py top docstring * Remove unnecessary import * Removed "unused error" and make APIerror deprecated at gym 1.0 * Add pydocstyle description to CONTRIBUTING.md * Added docstrings section to CONTRIBUTING.md * Added :meth: and :attr: keywords to docstrings * Added :meth: and :attr: keywords to docstrings * Update the step docstring placing the return type in the as a note. * Updated step return type to include each element * Update maths notation to reward range * Fixed infinity maths notation
This commit is contained in:
@@ -26,6 +26,16 @@ repos:
|
|||||||
hooks:
|
hooks:
|
||||||
- id: isort
|
- id: isort
|
||||||
args: ["--profile", "black"]
|
args: ["--profile", "black"]
|
||||||
|
- repo: https://github.com/pycqa/pydocstyle
|
||||||
|
rev: 6.1.1 # pick a git hash / tag to point to
|
||||||
|
hooks:
|
||||||
|
- id: pydocstyle
|
||||||
|
exclude: ^(gym/version.py)|(gym/(wrappers|envs|spaces|utils|vector)/)|(tests/)
|
||||||
|
args:
|
||||||
|
- --source
|
||||||
|
- --explain
|
||||||
|
- --convention=google
|
||||||
|
additional_dependencies: ["toml"]
|
||||||
- repo: https://github.com/asottile/pyupgrade
|
- repo: https://github.com/asottile/pyupgrade
|
||||||
rev: v2.32.0
|
rev: v2.32.0
|
||||||
hooks:
|
hooks:
|
||||||
|
@@ -43,3 +43,13 @@ The Git hooks can also be run manually with `pre-commit run --all-files`, and if
|
|||||||
|
|
||||||
Additionally, for pull requests, the project runs a number of tests for the whole project using [pytest](https://docs.pytest.org/en/latest/getting-started.html#install-pytest).
|
Additionally, for pull requests, the project runs a number of tests for the whole project using [pytest](https://docs.pytest.org/en/latest/getting-started.html#install-pytest).
|
||||||
These tests can be run locally with `pytest` in the root folder.
|
These tests can be run locally with `pytest` in the root folder.
|
||||||
|
|
||||||
|
## Docstrings
|
||||||
|
Pydocstyle has been added to the pre-commit process such that all new functions follow the [google docstring style](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html).
|
||||||
|
All new functions require either a short docstring, a single line explaining the purpose of a function
|
||||||
|
or a multiline docstring that documents each argument and the return type (if there is one) of the function.
|
||||||
|
In addition, new files and classes require top docstrings that should outline the purpose of the file/class.
|
||||||
|
For classes, code block examples can be provided in the top docstring and not the constructor arguments.
|
||||||
|
|
||||||
|
To check your docstrings are correct, run `pre-commit run --all-files` or `pydocstyle --source --explain --convention=google`.
|
||||||
|
For every docstring that fails, the source and the reason for the failure are provided.
|
@@ -1,3 +1,4 @@
|
|||||||
|
"""Root __init__ of the gym module setting the __all__ of gym modules."""
|
||||||
# isort: skip_file
|
# isort: skip_file
|
||||||
|
|
||||||
from gym import error
|
from gym import error
|
||||||
|
216
gym/core.py
216
gym/core.py
@@ -1,7 +1,8 @@
|
|||||||
|
"""Core API for Environment, Wrapper, ActionWrapper, RewardWrapper and ObservationWrapper."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from typing import Generic, Optional, SupportsFloat, Tuple, TypeVar, Union
|
from typing import Generic, Optional, SupportsFloat, TypeVar, Union
|
||||||
|
|
||||||
from gym import spaces
|
from gym import spaces
|
||||||
from gym.logger import deprecation
|
from gym.logger import deprecation
|
||||||
@@ -13,27 +14,30 @@ ActType = TypeVar("ActType")
|
|||||||
|
|
||||||
|
|
||||||
class Env(Generic[ObsType, ActType]):
|
class Env(Generic[ObsType, ActType]):
|
||||||
"""The main OpenAI Gym class. It encapsulates an environment with
|
r"""The main OpenAI Gym class.
|
||||||
arbitrary behind-the-scenes dynamics. An environment can be
|
|
||||||
partially or fully observed.
|
It encapsulates an environment with arbitrary behind-the-scenes dynamics.
|
||||||
|
An environment can be partially or fully observed.
|
||||||
|
|
||||||
The main API methods that users of this class need to know are:
|
The main API methods that users of this class need to know are:
|
||||||
|
|
||||||
step
|
- :meth:`step` - Takes a step in the environment using an action returning the next observation, reward,
|
||||||
reset
|
if the environment terminated and more information.
|
||||||
render
|
- :meth:`reset` - Resets the environment to an initial state, returning the initial observation.
|
||||||
close
|
- :meth:`render` - Renders the environment observation with modes depending on the output
|
||||||
seed
|
- :meth:`close` - Closes the environment, important for rendering where pygame is imported
|
||||||
|
- :meth:`seed` - Seeds the environment's random number generator, :deprecated: in favor of `Env.reset(seed=seed)`.
|
||||||
|
|
||||||
And set the following attributes:
|
And set the following attributes:
|
||||||
|
|
||||||
action_space: The Space object corresponding to valid actions
|
- :attr:`action_space` - The Space object corresponding to valid actions
|
||||||
observation_space: The Space object corresponding to valid observations
|
- :attr:`observation_space` - The Space object corresponding to valid observations
|
||||||
reward_range: A tuple corresponding to the min and max possible rewards
|
- :attr:`reward_range` - A tuple corresponding to the minimum and maximum possible rewards
|
||||||
|
- :attr:`spec` - An environment spec that contains the information used to initialise the environment from `gym.make`
|
||||||
|
- :attr:`metadata` - The metadata of the environment, i.e. render modes
|
||||||
|
- :attr:`np_random` - The random number generator for the environment
|
||||||
|
|
||||||
Note: a default reward range set to [-inf,+inf] already exists. Set it if you want a narrower range.
|
Note: a default reward range set to :math:`(-\infty,+\infty)` already exists. Set it if you want a narrower range.
|
||||||
|
|
||||||
The methods are accessed publicly as "step", "reset", etc...
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Set this in SOME subclasses
|
# Set this in SOME subclasses
|
||||||
@@ -46,11 +50,11 @@ class Env(Generic[ObsType, ActType]):
|
|||||||
observation_space: spaces.Space[ObsType]
|
observation_space: spaces.Space[ObsType]
|
||||||
|
|
||||||
# Created
|
# Created
|
||||||
_np_random: RandomNumberGenerator | None = None
|
_np_random: Optional[RandomNumberGenerator] = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def np_random(self) -> RandomNumberGenerator:
|
def np_random(self) -> RandomNumberGenerator:
|
||||||
"""Initializes the np_random field if not done already."""
|
"""Returns the environment's internal :attr:`_np_random` that if not set will initialise with a random seed."""
|
||||||
if self._np_random is None:
|
if self._np_random is None:
|
||||||
self._np_random, seed = seeding.np_random()
|
self._np_random, seed = seeding.np_random()
|
||||||
return self._np_random
|
return self._np_random
|
||||||
@@ -60,28 +64,27 @@ class Env(Generic[ObsType, ActType]):
|
|||||||
self._np_random = value
|
self._np_random = value
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
|
def step(self, action: ActType) -> tuple[ObsType, float, bool, dict]:
|
||||||
"""Run one timestep of the environment's dynamics. When end of
|
"""Run one timestep of the environment's dynamics.
|
||||||
episode is reached, you are responsible for calling :meth:`reset`
|
|
||||||
to reset this environment's state.
|
|
||||||
|
|
||||||
Accepts an action and returns a tuple (observation, reward, done, info).
|
When end of episode is reached, you are responsible for calling :meth:`reset` to reset this environment's state.
|
||||||
|
Accepts an action and returns a tuple `(observation, reward, done, info)`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
action (object): an action provided by the agent
|
action (object): an action provided by the agent
|
||||||
|
|
||||||
This method returns a tuple ``(observation, reward, done, info)``
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
observation (object): agent's observation of the current environment. This will be an element of the environment's :attr:`observation_space`. This may, for instance, be a numpy array containing the positions and velocities of certain objects.
|
observation (object): this will be an element of the environment's :attr:`observation_space`.
|
||||||
reward (float) : amount of reward returned after previous action
|
This may, for instance, be a numpy array containing the positions and velocities of certain objects.
|
||||||
done (bool): whether the episode has ended, in which case further :meth:`step` calls will return undefined results. A done signal may be emitted for different reasons: Maybe the task underlying the environment was solved successfully, a certain timelimit was exceeded, or the physics simulation has entered an invalid state. ``info`` may contain additional information regarding the reason for a ``done`` signal.
|
reward (float): The amount of reward returned as a result of taking the action.
|
||||||
info (dict): contains auxiliary diagnostic information (helpful for debugging, learning, and logging). This might, for instance, contain:
|
done (bool): A boolean value for if the episode has ended, in which case further :meth:`step` calls will return undefined results.
|
||||||
|
A done signal may be emitted for different reasons: Maybe the task underlying the environment was solved successfully,
|
||||||
- metrics that describe the agent's performance or
|
a certain timelimit was exceeded, or the physics simulation has entered an invalid state.
|
||||||
- state variables that are hidden from observations or
|
info (dictionary): A dictionary that may contain additional information regarding the reason for a ``done`` signal.
|
||||||
- information that distinguishes truncation and termination or
|
`info` contains auxiliary diagnostic information (helpful for debugging, learning, and logging).
|
||||||
- individual reward terms that are combined to produce the total reward
|
This might, for instance, contain: metrics that describe the agent's performance, state variables that are
|
||||||
|
hidden from observations, information that distinguishes truncation and termination or individual reward terms
|
||||||
|
that are combined to produce the total reward
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@@ -93,28 +96,34 @@ class Env(Generic[ObsType, ActType]):
|
|||||||
return_info: bool = False,
|
return_info: bool = False,
|
||||||
options: Optional[dict] = None,
|
options: Optional[dict] = None,
|
||||||
) -> Union[ObsType, tuple[ObsType, dict]]:
|
) -> Union[ObsType, tuple[ObsType, dict]]:
|
||||||
"""Resets the environment to an initial state and returns an initial
|
"""Resets the environment to an initial state and returns the initial observation.
|
||||||
observation.
|
|
||||||
|
|
||||||
This method should also reset the environment's random number
|
This method can reset the environment's random number generator(s) if ``seed`` is an integer or
|
||||||
generator(s) if ``seed`` is an integer or if the environment has not
|
if the environment has not yet initialized a random number generator.
|
||||||
yet initialized a random number generator. If the environment already
|
If the environment already has a random number generator and :meth:`reset` is called with ``seed=None``,
|
||||||
has a random number generator and :meth:`reset` is called with ``seed=None``,
|
the RNG should not be reset. Moreover, :meth:`reset` should (in the typical use case) be called with an
|
||||||
the RNG should not be reset.
|
|
||||||
Moreover, :meth:`reset` should (in the typical use case) be called with an
|
|
||||||
integer seed right after initialization and then never again.
|
integer seed right after initialization and then never again.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
seed (int or None):
|
seed (optional int): The seed that is used to initialize the environment's PRNG.
|
||||||
The seed that is used to initialize the environment's PRNG. If the environment does not already have a PRNG and ``seed=None`` (the default option) is passed, a seed will be chosen from some source of entropy (e.g. timestamp or /dev/urandom).
|
If the environment does not already have a PRNG and ``seed=None`` (the default option) is passed,
|
||||||
However, if the environment already has a PRNG and ``seed=None`` is passed, the PRNG will *not* be reset. If you pass an integer, the PRNG will be reset even if it already exists. Usually, you want to pass an integer *right after the environment has been initialized and then never again*. Please refer to the minimal example above to see this paradigm in action.
|
a seed will be chosen from some source of entropy (e.g. timestamp or /dev/urandom).
|
||||||
return_info (bool): If true, return additional information along with initial observation. This info should be analogous to the info returned in :meth:`step`
|
However, if the environment already has a PRNG and ``seed=None`` is passed, the PRNG will *not* be reset.
|
||||||
options (dict or None): Additional information to specify how the environment is reset (optional, depending on the specific environment)
|
If you pass an integer, the PRNG will be reset even if it already exists.
|
||||||
|
Usually, you want to pass an integer *right after the environment has been initialized and then never again*.
|
||||||
|
Please refer to the minimal example above to see this paradigm in action.
|
||||||
|
return_info (bool): If true, return additional information along with initial observation.
|
||||||
|
This info should be analogous to the info returned in :meth:`step`
|
||||||
|
options (optional dict): Additional information to specify how the environment is reset (optional,
|
||||||
|
depending on the specific environment)
|
||||||
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
observation (object): Observation of the initial state. This will be an element of :attr:`observation_space` (usually a numpy array) and is analogous to the observation returned by :meth:`step`.
|
observation (object): Observation of the initial state. This will be an element of :attr:`observation_space`
|
||||||
info (optional dictionary): This will *only* be returned if ``return_info=True`` is passed. It contains auxiliary information complementing ``observation``. This dictionary should be analogous to the ``info`` returned by :meth:`step`.
|
(typically a numpy array) and is analogous to the observation returned by :meth:`step`.
|
||||||
|
info (optional dictionary): This will *only* be returned if ``return_info=True`` is passed.
|
||||||
|
It contains auxiliary information complementing ``observation``. This dictionary should be analogous to
|
||||||
|
the ``info`` returned by :meth:`step`.
|
||||||
"""
|
"""
|
||||||
# Initialize the RNG if the seed is manually passed
|
# Initialize the RNG if the seed is manually passed
|
||||||
if seed is not None:
|
if seed is not None:
|
||||||
@@ -124,13 +133,13 @@ class Env(Generic[ObsType, ActType]):
|
|||||||
def render(self, mode="human"):
|
def render(self, mode="human"):
|
||||||
"""Renders the environment.
|
"""Renders the environment.
|
||||||
|
|
||||||
The set of supported modes varies per environment. (And some
|
The set of supported modes varies per environment. (And some
|
||||||
third-party environments may not support rendering at all.)
|
third-party environments may not support rendering at all.)
|
||||||
By convention, if mode is:
|
By convention, if mode is:
|
||||||
|
|
||||||
- human: render to the current display or terminal and
|
- human: render to the current display or terminal and
|
||||||
return nothing. Usually for human consumption.
|
return nothing. Usually for human consumption.
|
||||||
- rgb_array: Return an numpy.ndarray with shape (x, y, 3),
|
- rgb_array: Return a numpy.ndarray with shape (x, y, 3),
|
||||||
representing RGB values for an x-by-y pixel image, suitable
|
representing RGB values for an x-by-y pixel image, suitable
|
||||||
for turning into a video.
|
for turning into a video.
|
||||||
- ansi: Return a string (str) or StringIO.StringIO containing a
|
- ansi: Return a string (str) or StringIO.StringIO containing a
|
||||||
@@ -139,60 +148,64 @@ class Env(Generic[ObsType, ActType]):
|
|||||||
|
|
||||||
Note:
|
Note:
|
||||||
Make sure that your class's metadata 'render_modes' key includes
|
Make sure that your class's metadata 'render_modes' key includes
|
||||||
the list of supported modes. It's recommended to call super()
|
the list of supported modes. It's recommended to call super()
|
||||||
in implementations to use the functionality of this method.
|
in implementations to use the functionality of this method.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> class MyEnv(Env):
|
||||||
|
... metadata = {'render_modes': ['human', 'rgb_array']}
|
||||||
|
...
|
||||||
|
... def render(self, mode='human'):
|
||||||
|
... if mode == 'rgb_array':
|
||||||
|
... return np.array(...) # return RGB frame suitable for video
|
||||||
|
... elif mode == 'human':
|
||||||
|
... ... # pop up a window and render
|
||||||
|
... else:
|
||||||
|
... super(MyEnv, self).render(mode=mode) # just raise an exception
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
mode (str): the mode to render with
|
mode: the mode to render with, valid modes are `env.metadata["render_modes"]`
|
||||||
|
|
||||||
Example::
|
|
||||||
|
|
||||||
class MyEnv(Env):
|
|
||||||
metadata = {'render_modes': ['human', 'rgb_array']}
|
|
||||||
|
|
||||||
def render(self, mode='human'):
|
|
||||||
if mode == 'rgb_array':
|
|
||||||
return np.array(...) # return RGB frame suitable for video
|
|
||||||
elif mode == 'human':
|
|
||||||
... # pop up a window and render
|
|
||||||
else:
|
|
||||||
super(MyEnv, self).render(mode=mode) # just raise an exception
|
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
"""Override close in your subclass to perform any necessary cleanup.
|
"""Override close in your subclass to perform any necessary cleanup.
|
||||||
|
|
||||||
Environments will automatically close() themselves when
|
Environments will automatically :meth:`close()` themselves when
|
||||||
garbage collected or when the program exits.
|
garbage collected or when the program exits.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def seed(self, seed=None):
|
def seed(self, seed=None):
|
||||||
"""Sets the seed for this env's random number generator(s).
|
""":deprecated: function that sets the seed for the environment's random number generator(s).
|
||||||
|
|
||||||
|
Use `env.reset(seed=seed)` as the new API for setting the seed of the environment.
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
Some environments use multiple pseudorandom number generators.
|
Some environments use multiple pseudorandom number generators.
|
||||||
We want to capture all such seeds used in order to ensure that
|
We want to capture all such seeds used in order to ensure that
|
||||||
there aren't accidental correlations between multiple generators.
|
there aren't accidental correlations between multiple generators.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seed (Optional int): The seed value for the random number generator
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list<bigint>: Returns the list of seeds used in this env's random
|
seeds (List[int]): Returns the list of seeds used in this environment's random
|
||||||
number generators. The first value in the list should be the
|
number generators. The first value in the list should be the
|
||||||
"main" seed, or the value which a reproducer should pass to
|
"main" seed, or the value which a reproducer should pass to
|
||||||
'seed'. Often, the main seed equals the provided 'seed', but
|
'seed'. Often, the main seed equals the provided 'seed', but
|
||||||
this won't be true if seed=None, for example.
|
this won't be true if `seed=None`, for example.
|
||||||
"""
|
"""
|
||||||
deprecation(
|
deprecation(
|
||||||
"Function `env.seed(seed)` is marked as deprecated and will be removed in the future. "
|
"Function `env.seed(seed)` is marked as deprecated and will be removed in the future. "
|
||||||
"Please use `env.reset(seed=seed) instead."
|
"Please use `env.reset(seed=seed)` instead."
|
||||||
)
|
)
|
||||||
self._np_random, seed = seeding.np_random(seed)
|
self._np_random, seed = seeding.np_random(seed)
|
||||||
return [seed]
|
return [seed]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def unwrapped(self) -> Env:
|
def unwrapped(self) -> Env:
|
||||||
"""Completely unwrap this env.
|
"""Returns the base non-wrapped environment.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
gym.Env: The base non-wrapped gym.Env instance
|
gym.Env: The base non-wrapped gym.Env instance
|
||||||
@@ -200,6 +213,7 @@ class Env(Generic[ObsType, ActType]):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
|
"""Returns a string of the environment with the spec id if specified."""
|
||||||
if self.spec is None:
|
if self.spec is None:
|
||||||
return f"<{type(self).__name__} instance>"
|
return f"<{type(self).__name__} instance>"
|
||||||
else:
|
else:
|
||||||
@@ -217,71 +231,81 @@ class Env(Generic[ObsType, ActType]):
|
|||||||
|
|
||||||
|
|
||||||
class Wrapper(Env[ObsType, ActType]):
|
class Wrapper(Env[ObsType, ActType]):
|
||||||
"""Wraps the environment to allow a modular transformation.
|
"""Wraps an environment to allow a modular transformation of the :meth:`step` and :meth:`reset` methods.
|
||||||
|
|
||||||
This class is the base class for all wrappers. The subclass could override
|
This class is the base class for all wrappers. The subclass could override
|
||||||
some methods to change the behavior of the original environment without touching the
|
some methods to change the behavior of the original environment without touching the
|
||||||
original code.
|
original code.
|
||||||
|
|
||||||
.. note::
|
Note:
|
||||||
|
|
||||||
Don't forget to call ``super().__init__(env)`` if the subclass overrides :meth:`__init__`.
|
Don't forget to call ``super().__init__(env)`` if the subclass overrides :meth:`__init__`.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, env: Env):
|
def __init__(self, env: Env):
|
||||||
|
"""Wraps an environment to allow a modular transformation of the :meth:`step` and :meth:`reset` methods.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env: The environment to wrap
|
||||||
|
"""
|
||||||
self.env = env
|
self.env = env
|
||||||
|
|
||||||
self._action_space: spaces.Space | None = None
|
self._action_space: Optional[spaces.Space] = None
|
||||||
self._observation_space: spaces.Space | None = None
|
self._observation_space: Optional[spaces.Space] = None
|
||||||
self._reward_range: tuple[SupportsFloat, SupportsFloat] | None = None
|
self._reward_range: Optional[tuple[SupportsFloat, SupportsFloat]] = None
|
||||||
self._metadata: dict | None = None
|
self._metadata: Optional[dict] = None
|
||||||
|
|
||||||
def __getattr__(self, name):
|
def __getattr__(self, name):
|
||||||
|
"""Returns an attribute with ``name``, unless ``name`` starts with an underscore."""
|
||||||
if name.startswith("_"):
|
if name.startswith("_"):
|
||||||
raise AttributeError(f"accessing private attribute '{name}' is prohibited")
|
raise AttributeError(f"accessing private attribute '{name}' is prohibited")
|
||||||
return getattr(self.env, name)
|
return getattr(self.env, name)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def spec(self):
|
def spec(self):
|
||||||
|
"""Returns the environment specification."""
|
||||||
return self.env.spec
|
return self.env.spec
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def class_name(cls):
|
def class_name(cls):
|
||||||
|
"""Returns the class name of the wrapper."""
|
||||||
return cls.__name__
|
return cls.__name__
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def action_space(self) -> spaces.Space[ActType]:
|
def action_space(self) -> spaces.Space[ActType]:
|
||||||
|
"""Returns the action space of the environment."""
|
||||||
if self._action_space is None:
|
if self._action_space is None:
|
||||||
return self.env.action_space
|
return self.env.action_space
|
||||||
return self._action_space
|
return self._action_space
|
||||||
|
|
||||||
@action_space.setter
|
@action_space.setter
|
||||||
def action_space(self, space):
|
def action_space(self, space: spaces.Space):
|
||||||
self._action_space = space
|
self._action_space = space
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def observation_space(self) -> spaces.Space:
|
def observation_space(self) -> spaces.Space:
|
||||||
|
"""Returns the observation space of the environment."""
|
||||||
if self._observation_space is None:
|
if self._observation_space is None:
|
||||||
return self.env.observation_space
|
return self.env.observation_space
|
||||||
return self._observation_space
|
return self._observation_space
|
||||||
|
|
||||||
@observation_space.setter
|
@observation_space.setter
|
||||||
def observation_space(self, space):
|
def observation_space(self, space: spaces.Space):
|
||||||
self._observation_space = space
|
self._observation_space = space
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def reward_range(self) -> tuple[SupportsFloat, SupportsFloat]:
|
def reward_range(self) -> tuple[SupportsFloat, SupportsFloat]:
|
||||||
|
"""Return the reward range of the environment."""
|
||||||
if self._reward_range is None:
|
if self._reward_range is None:
|
||||||
return self.env.reward_range
|
return self.env.reward_range
|
||||||
return self._reward_range
|
return self._reward_range
|
||||||
|
|
||||||
@reward_range.setter
|
@reward_range.setter
|
||||||
def reward_range(self, value):
|
def reward_range(self, value: tuple[SupportsFloat, SupportsFloat]):
|
||||||
self._reward_range = value
|
self._reward_range = value
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def metadata(self) -> dict:
|
def metadata(self) -> dict:
|
||||||
|
"""Returns the environment metadata."""
|
||||||
if self._metadata is None:
|
if self._metadata is None:
|
||||||
return self.env.metadata
|
return self.env.metadata
|
||||||
return self._metadata
|
return self._metadata
|
||||||
@@ -290,34 +314,45 @@ class Wrapper(Env[ObsType, ActType]):
|
|||||||
def metadata(self, value):
|
def metadata(self, value):
|
||||||
self._metadata = value
|
self._metadata = value
|
||||||
|
|
||||||
def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
|
def step(self, action: ActType) -> tuple[ObsType, float, bool, dict]:
|
||||||
|
"""Steps through the environment with action."""
|
||||||
return self.env.step(action)
|
return self.env.step(action)
|
||||||
|
|
||||||
def reset(self, **kwargs) -> Union[ObsType, tuple[ObsType, dict]]:
|
def reset(self, **kwargs) -> Union[ObsType, tuple[ObsType, dict]]:
|
||||||
|
"""Resets the environment with kwargs."""
|
||||||
return self.env.reset(**kwargs)
|
return self.env.reset(**kwargs)
|
||||||
|
|
||||||
def render(self, **kwargs):
|
def render(self, **kwargs):
|
||||||
|
"""Renders the environment with kwargs."""
|
||||||
return self.env.render(**kwargs)
|
return self.env.render(**kwargs)
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
|
"""Closes the environment."""
|
||||||
return self.env.close()
|
return self.env.close()
|
||||||
|
|
||||||
def seed(self, seed=None):
|
def seed(self, seed=None):
|
||||||
|
"""Seeds the environment."""
|
||||||
return self.env.seed(seed)
|
return self.env.seed(seed)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
|
"""Returns the wrapper name and the unwrapped environment string."""
|
||||||
return f"<{type(self).__name__}{self.env}>"
|
return f"<{type(self).__name__}{self.env}>"
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
|
"""Returns the string representation of the wrapper."""
|
||||||
return str(self)
|
return str(self)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def unwrapped(self) -> Env:
|
def unwrapped(self) -> Env:
|
||||||
|
"""Returns the base environment of the wrapper."""
|
||||||
return self.env.unwrapped
|
return self.env.unwrapped
|
||||||
|
|
||||||
|
|
||||||
class ObservationWrapper(Wrapper):
|
class ObservationWrapper(Wrapper):
|
||||||
|
"""A wrapper that can modify observations using :meth:`observation` for :meth:`reset` and :meth:`step`."""
|
||||||
|
|
||||||
def reset(self, **kwargs):
|
def reset(self, **kwargs):
|
||||||
|
"""Resets the environment, returning a modified observation using :meth:`self.observation`."""
|
||||||
if kwargs.get("return_info", False):
|
if kwargs.get("return_info", False):
|
||||||
obs, info = self.env.reset(**kwargs)
|
obs, info = self.env.reset(**kwargs)
|
||||||
return self.observation(obs), info
|
return self.observation(obs), info
|
||||||
@@ -325,38 +360,43 @@ class ObservationWrapper(Wrapper):
|
|||||||
return self.observation(self.env.reset(**kwargs))
|
return self.observation(self.env.reset(**kwargs))
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
|
"""Returns a modified observation using :meth:`self.observation` after calling :meth:`env.step`."""
|
||||||
observation, reward, done, info = self.env.step(action)
|
observation, reward, done, info = self.env.step(action)
|
||||||
return self.observation(observation), reward, done, info
|
return self.observation(observation), reward, done, info
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def observation(self, observation):
|
def observation(self, observation):
|
||||||
|
"""Returns a modified observation."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class RewardWrapper(Wrapper):
|
class RewardWrapper(Wrapper):
|
||||||
def reset(self, **kwargs):
|
"""A wrapper that can modify the returning reward from a step."""
|
||||||
return self.env.reset(**kwargs)
|
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
|
"""Modifies the reward using :meth:`self.reward` after the environment :meth:`env.step`."""
|
||||||
observation, reward, done, info = self.env.step(action)
|
observation, reward, done, info = self.env.step(action)
|
||||||
return observation, self.reward(reward), done, info
|
return observation, self.reward(reward), done, info
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def reward(self, reward):
|
def reward(self, reward):
|
||||||
|
"""Returns a modified ``reward``."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class ActionWrapper(Wrapper):
|
class ActionWrapper(Wrapper):
|
||||||
def reset(self, **kwargs):
|
"""A wrapper that can modify the action before :meth:`env.step`."""
|
||||||
return self.env.reset(**kwargs)
|
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
|
"""Runs the environment :meth:`env.step` using the modified ``action`` from :meth:`self.action`."""
|
||||||
return self.env.step(self.action(action))
|
return self.env.step(self.action(action))
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def action(self, action):
|
def action(self, action):
|
||||||
|
"""Returns a modified action before :meth:`env.step` is called."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def reverse_action(self, action):
|
def reverse_action(self, action):
|
||||||
|
"""Returns a reversed ``action``."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
151
gym/error.py
151
gym/error.py
@@ -1,122 +1,76 @@
|
|||||||
|
"""Set of Error classes for gym."""
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
|
||||||
class Error(Exception):
|
class Error(Exception):
|
||||||
pass
|
"""Error superclass."""
|
||||||
|
|
||||||
|
|
||||||
# Local errors
|
# Local errors
|
||||||
|
|
||||||
|
|
||||||
class Unregistered(Error):
|
class Unregistered(Error):
|
||||||
"""Raised when the user requests an item from the registry that does
|
"""Raised when the user requests an item from the registry that does not actually exist."""
|
||||||
not actually exist.
|
|
||||||
"""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class UnregisteredEnv(Unregistered):
|
class UnregisteredEnv(Unregistered):
|
||||||
"""Raised when the user requests an env from the registry that does
|
"""Raised when the user requests an env from the registry that does not actually exist."""
|
||||||
not actually exist.
|
|
||||||
"""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class NamespaceNotFound(UnregisteredEnv):
|
class NamespaceNotFound(UnregisteredEnv):
|
||||||
"""Raised when the user requests an env from the registry where the
|
"""Raised when the user requests an env from the registry where the namespace doesn't exist."""
|
||||||
namespace doesn't exist.
|
|
||||||
"""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class NameNotFound(UnregisteredEnv):
|
class NameNotFound(UnregisteredEnv):
|
||||||
"""Raised when the user requests an env from the registry where the
|
"""Raised when the user requests an env from the registry where the name doesn't exist."""
|
||||||
name doesn't exist.
|
|
||||||
"""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class VersionNotFound(UnregisteredEnv):
|
class VersionNotFound(UnregisteredEnv):
|
||||||
"""Raised when the user requests an env from the registry where the
|
"""Raised when the user requests an env from the registry where the version doesn't exist."""
|
||||||
version doesn't exist.
|
|
||||||
"""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class UnregisteredBenchmark(Unregistered):
|
class UnregisteredBenchmark(Unregistered):
|
||||||
"""Raised when the user requests an env from the registry that does
|
"""Raised when the user requests an env from the registry that does not actually exist."""
|
||||||
not actually exist.
|
|
||||||
"""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class DeprecatedEnv(Error):
|
class DeprecatedEnv(Error):
|
||||||
"""Raised when the user requests an env from the registry with an
|
"""Raised when the user requests an env from the registry with an older version number than the latest env with the same name."""
|
||||||
older version number than the latest env with the same name.
|
|
||||||
"""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class RegistrationError(Error):
|
class RegistrationError(Error):
|
||||||
"""Raised when the user attempts to register an invalid env.
|
"""Raised when the user attempts to register an invalid env. For example, an unversioned env when a versioned env exists."""
|
||||||
For example, an unversioned env when a versioned env exists.
|
|
||||||
"""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class UnseedableEnv(Error):
|
class UnseedableEnv(Error):
|
||||||
"""Raised when the user tries to seed an env that does not support
|
"""Raised when the user tries to seed an env that does not support seeding."""
|
||||||
seeding.
|
|
||||||
"""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class DependencyNotInstalled(Error):
|
class DependencyNotInstalled(Error):
|
||||||
pass
|
"""Raised when the user has not installed a dependency."""
|
||||||
|
|
||||||
|
|
||||||
class UnsupportedMode(Exception):
|
class UnsupportedMode(Error):
|
||||||
"""Raised when the user requests a rendering mode not supported by the
|
"""Raised when the user requests a rendering mode not supported by the environment."""
|
||||||
environment.
|
|
||||||
"""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ResetNeeded(Exception):
|
class ResetNeeded(Error):
|
||||||
"""When the monitor is active, raised when the user tries to step an
|
"""When the order enforcing is violated, i.e. step or render is called before reset."""
|
||||||
environment that's already done.
|
|
||||||
"""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ResetNotAllowed(Exception):
|
class ResetNotAllowed(Error):
|
||||||
"""When the monitor is active, raised when the user tries to step an
|
"""When the monitor is active, raised when the user tries to step an environment that's not yet done."""
|
||||||
environment that's not yet done.
|
|
||||||
"""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class InvalidAction(Exception):
|
class InvalidAction(Error):
|
||||||
"""Raised when the user performs an action not contained within the
|
"""Raised when the user performs an action not contained within the action space."""
|
||||||
action space
|
|
||||||
"""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
# API errors
|
# API errors
|
||||||
|
|
||||||
|
|
||||||
class APIError(Error):
|
class APIError(Error):
|
||||||
|
"""Deprecated, to be removed at gym 1.0."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
message=None,
|
message=None,
|
||||||
@@ -125,8 +79,11 @@ class APIError(Error):
|
|||||||
json_body=None,
|
json_body=None,
|
||||||
headers=None,
|
headers=None,
|
||||||
):
|
):
|
||||||
|
"""Initialise API error."""
|
||||||
super().__init__(message)
|
super().__init__(message)
|
||||||
|
|
||||||
|
warnings.warn("APIError is deprecated and will be removed at gym 1.0")
|
||||||
|
|
||||||
if http_body and hasattr(http_body, "decode"):
|
if http_body and hasattr(http_body, "decode"):
|
||||||
try:
|
try:
|
||||||
http_body = http_body.decode("utf-8")
|
http_body = http_body.decode("utf-8")
|
||||||
@@ -141,6 +98,7 @@ class APIError(Error):
|
|||||||
self.request_id = self.headers.get("request-id", None)
|
self.request_id = self.headers.get("request-id", None)
|
||||||
|
|
||||||
def __unicode__(self):
|
def __unicode__(self):
|
||||||
|
"""Returns the message prefixed with the request id if request_id is not None, otherwise returns the _message."""
|
||||||
if self.request_id is not None:
|
if self.request_id is not None:
|
||||||
msg = self._message or "<empty message>"
|
msg = self._message or "<empty message>"
|
||||||
return f"Request {self.request_id}: {msg}"
|
return f"Request {self.request_id}: {msg}"
|
||||||
@@ -148,14 +106,17 @@ class APIError(Error):
|
|||||||
return self._message
|
return self._message
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
|
"""Returns the __unicode__."""
|
||||||
return self.__unicode__()
|
return self.__unicode__()
|
||||||
|
|
||||||
|
|
||||||
class APIConnectionError(APIError):
|
class APIConnectionError(APIError):
|
||||||
pass
|
"""Deprecated, to be removed at gym 1.0."""
|
||||||
|
|
||||||
|
|
||||||
class InvalidRequestError(APIError):
|
class InvalidRequestError(APIError):
|
||||||
|
"""Deprecated, to be removed at gym 1.0."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
message,
|
message,
|
||||||
@@ -165,83 +126,69 @@ class InvalidRequestError(APIError):
|
|||||||
json_body=None,
|
json_body=None,
|
||||||
headers=None,
|
headers=None,
|
||||||
):
|
):
|
||||||
|
"""Initialises the invalid request error."""
|
||||||
super().__init__(message, http_body, http_status, json_body, headers)
|
super().__init__(message, http_body, http_status, json_body, headers)
|
||||||
self.param = param
|
self.param = param
|
||||||
|
|
||||||
|
|
||||||
class AuthenticationError(APIError):
|
class AuthenticationError(APIError):
|
||||||
pass
|
"""Deprecated, to be removed at gym 1.0."""
|
||||||
|
|
||||||
|
|
||||||
class RateLimitError(APIError):
|
class RateLimitError(APIError):
|
||||||
pass
|
"""Deprecated, to be removed at gym 1.0."""
|
||||||
|
|
||||||
|
|
||||||
# Video errors
|
# Video errors
|
||||||
|
|
||||||
|
|
||||||
class VideoRecorderError(Error):
|
class VideoRecorderError(Error):
|
||||||
pass
|
"""Unused error."""
|
||||||
|
|
||||||
|
|
||||||
class InvalidFrame(Error):
|
class InvalidFrame(Error):
|
||||||
pass
|
"""Error message when an invalid frame is captured."""
|
||||||
|
|
||||||
|
|
||||||
# Wrapper errors
|
# Wrapper errors
|
||||||
|
|
||||||
|
|
||||||
class DoubleWrapperError(Error):
|
class DoubleWrapperError(Error):
|
||||||
pass
|
"""Error message for when using double wrappers."""
|
||||||
|
|
||||||
|
|
||||||
class WrapAfterConfigureError(Error):
|
class WrapAfterConfigureError(Error):
|
||||||
pass
|
"""Error message for using wrap after configure."""
|
||||||
|
|
||||||
|
|
||||||
class RetriesExceededError(Error):
|
class RetriesExceededError(Error):
|
||||||
pass
|
"""Error message for retries exceeding set number."""
|
||||||
|
|
||||||
|
|
||||||
# Vectorized environments errors
|
# Vectorized environments errors
|
||||||
|
|
||||||
|
|
||||||
class AlreadyPendingCallError(Exception):
|
class AlreadyPendingCallError(Exception):
|
||||||
"""
|
"""Raised when `reset`, or `step` is called asynchronously (e.g. with `reset_async`, or `step_async` respectively), and `reset_async`, or `step_async` (respectively) is called again (without a complete call to `reset_wait`, or `step_wait` respectively)."""
|
||||||
Raised when `reset`, or `step` is called asynchronously (e.g. with
|
|
||||||
`reset_async`, or `step_async` respectively), and `reset_async`, or
|
|
||||||
`step_async` (respectively) is called again (without a complete call to
|
|
||||||
`reset_wait`, or `step_wait` respectively).
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, message, name):
|
def __init__(self, message: str, name: str):
|
||||||
|
"""Initialises the exception with name attributes."""
|
||||||
super().__init__(message)
|
super().__init__(message)
|
||||||
self.name = name
|
self.name = name
|
||||||
|
|
||||||
|
|
||||||
class NoAsyncCallError(Exception):
|
class NoAsyncCallError(Exception):
|
||||||
"""
|
"""Raised when an asynchronous `reset`, or `step` is not running, but `reset_wait`, or `step_wait` (respectively) is called."""
|
||||||
Raised when an asynchronous `reset`, or `step` is not running, but
|
|
||||||
`reset_wait`, or `step_wait` (respectively) is called.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, message, name):
|
def __init__(self, message: str, name: str):
|
||||||
|
"""Initialises the exception with name attributes."""
|
||||||
super().__init__(message)
|
super().__init__(message)
|
||||||
self.name = name
|
self.name = name
|
||||||
|
|
||||||
|
|
||||||
class ClosedEnvironmentError(Exception):
|
class ClosedEnvironmentError(Exception):
|
||||||
"""
|
"""Trying to call `reset`, or `step`, while the environment is closed."""
|
||||||
Trying to call `reset`, or `step`, while the environment is closed.
|
|
||||||
"""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class CustomSpaceError(Exception):
|
class CustomSpaceError(Exception):
|
||||||
"""
|
"""The space is a custom gym.Space instance, and is not supported by `AsyncVectorEnv` with `shared_memory=True`."""
|
||||||
The space is a custom gym.Space instance, and is not supported by
|
|
||||||
`AsyncVectorEnv` with `shared_memory=True`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
@@ -1,3 +1,4 @@
|
|||||||
|
"""Set of functions for logging messages."""
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Optional, Type
|
from typing import Optional, Type
|
||||||
@@ -16,20 +17,20 @@ min_level = 30
|
|||||||
warnings.simplefilter("once", DeprecationWarning)
|
warnings.simplefilter("once", DeprecationWarning)
|
||||||
|
|
||||||
|
|
||||||
def set_level(level: int) -> None:
|
def set_level(level: int):
|
||||||
"""
|
"""Set logging threshold on current logger."""
|
||||||
Set logging threshold on current logger.
|
|
||||||
"""
|
|
||||||
global min_level
|
global min_level
|
||||||
min_level = level
|
min_level = level
|
||||||
|
|
||||||
|
|
||||||
def debug(msg: str, *args: object):
|
def debug(msg: str, *args: object):
|
||||||
|
"""Logs a debug message to the user."""
|
||||||
if min_level <= DEBUG:
|
if min_level <= DEBUG:
|
||||||
print(f"DEBUG: {msg % args}", file=sys.stderr)
|
print(f"DEBUG: {msg % args}", file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
def info(msg: str, *args: object):
|
def info(msg: str, *args: object):
|
||||||
|
"""Logs an info message to the user."""
|
||||||
if min_level <= INFO:
|
if min_level <= INFO:
|
||||||
print(f"INFO: {msg % args}", file=sys.stderr)
|
print(f"INFO: {msg % args}", file=sys.stderr)
|
||||||
|
|
||||||
@@ -40,6 +41,14 @@ def warn(
|
|||||||
category: Optional[Type[Warning]] = None,
|
category: Optional[Type[Warning]] = None,
|
||||||
stacklevel: int = 1,
|
stacklevel: int = 1,
|
||||||
):
|
):
|
||||||
|
"""Raises a warning to the user if the min_level <= WARN.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
msg: The message to warn the user
|
||||||
|
*args: Additional information to warn the user
|
||||||
|
category: The category of warning
|
||||||
|
stacklevel: The stack level to raise to
|
||||||
|
"""
|
||||||
if min_level <= WARN:
|
if min_level <= WARN:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
colorize(f"WARN: {msg % args}", "yellow"),
|
colorize(f"WARN: {msg % args}", "yellow"),
|
||||||
@@ -49,10 +58,12 @@ def warn(
|
|||||||
|
|
||||||
|
|
||||||
def deprecation(msg: str, *args: object):
|
def deprecation(msg: str, *args: object):
|
||||||
|
"""Logs a deprecation warning to users."""
|
||||||
warn(msg, *args, category=DeprecationWarning, stacklevel=2)
|
warn(msg, *args, category=DeprecationWarning, stacklevel=2)
|
||||||
|
|
||||||
|
|
||||||
def error(msg: str, *args: object):
|
def error(msg: str, *args: object):
|
||||||
|
"""Logs an error message if min_level <= ERROR in red on the sys.stderr."""
|
||||||
if min_level <= ERROR:
|
if min_level <= ERROR:
|
||||||
print(colorize(f"ERROR: {msg % args}", "red"), file=sys.stderr)
|
print(colorize(f"ERROR: {msg % args}", "red"), file=sys.stderr)
|
||||||
|
|
||||||
|
@@ -1,5 +1,6 @@
|
|||||||
"""A set of common utilities used within the environments. These are
|
"""A set of common utilities used within the environments.
|
||||||
not intended as API functions, and will not remain stable over time.
|
|
||||||
|
These are not intended as API functions, and will not remain stable over time.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# These submodules should not have any import-time dependencies.
|
# These submodules should not have any import-time dependencies.
|
||||||
|
@@ -33,7 +33,7 @@ class AutoResetWrapper(gym.Wrapper):
|
|||||||
new observation from after calling self.env.reset() is returned
|
new observation from after calling self.env.reset() is returned
|
||||||
by self.step() alongside the terminal reward and done state from the
|
by self.step() alongside the terminal reward and done state from the
|
||||||
previous episode. If you need the terminal state from the previous
|
previous episode. If you need the terminal state from the previous
|
||||||
episode, you need to retrieve it via the the "terminal_observation" key
|
episode, you need to retrieve it via the "terminal_observation" key
|
||||||
in the info dict. Make sure you know what you're doing if you
|
in the info dict. Make sure you know what you're doing if you
|
||||||
use this wrapper!
|
use this wrapper!
|
||||||
"""
|
"""
|
||||||
|
1
setup.py
1
setup.py
@@ -1,3 +1,4 @@
|
|||||||
|
"""Sets up the project."""
|
||||||
import itertools
|
import itertools
|
||||||
import os.path
|
import os.path
|
||||||
import sys
|
import sys
|
||||||
|
@@ -4,8 +4,7 @@ import numpy as np
|
|||||||
import pytest
|
import pytest
|
||||||
from numpy.testing import assert_array_equal
|
from numpy.testing import assert_array_equal
|
||||||
|
|
||||||
from gym import Space
|
from gym.spaces import Box, Dict, MultiDiscrete, Space, Tuple
|
||||||
from gym.spaces import Box, Dict, MultiDiscrete, Tuple
|
|
||||||
from gym.vector.utils.spaces import batch_space, iterate
|
from gym.vector.utils.spaces import batch_space, iterate
|
||||||
from tests.vector.utils import CustomSpace, assert_rng_equal, custom_spaces, spaces
|
from tests.vector.utils import CustomSpace, assert_rng_equal, custom_spaces, spaces
|
||||||
|
|
||||||
|
@@ -1,4 +1,6 @@
|
|||||||
from typing import Optional
|
"""Tests that gym.wrappers.AutoResetWrapper operates as expected."""
|
||||||
|
|
||||||
|
from typing import Generator, Optional
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -10,33 +12,31 @@ from tests.envs.spec_list import spec_list
|
|||||||
|
|
||||||
|
|
||||||
class DummyResetEnv(gym.Env):
|
class DummyResetEnv(gym.Env):
|
||||||
"""
|
"""A dummy environment which returns ascending numbers starting at `0` when :meth:`self.step()` is called.
|
||||||
A dummy environment which returns ascending numbers starting
|
|
||||||
at 0 when self.step() is called. After the third call to self.step()
|
After the third call to :meth:`self.step()` done is true.
|
||||||
done is true. Info dicts are also returned containing the same number
|
Info dicts are also returned containing the same number returned as an observation, accessible via the key "count".
|
||||||
returned as an observation, accessible via the key "count".
|
This environment is provided for the purpose of testing the autoreset wrapper.
|
||||||
This environment is provided for the purpose of testing the
|
|
||||||
autoreset wrapper.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
metadata = {}
|
metadata = {}
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
"""Initialise the DummyResetEnv."""
|
||||||
self.action_space = gym.spaces.Box(
|
self.action_space = gym.spaces.Box(
|
||||||
low=np.array([-1.0]), high=np.array([1.0]), dtype=np.float64
|
low=np.array([0]), high=np.array([2]), dtype=np.int64
|
||||||
)
|
|
||||||
self.observation_space = gym.spaces.Box(
|
|
||||||
low=np.array([-1.0]), high=np.array([1.0])
|
|
||||||
)
|
)
|
||||||
|
self.observation_space = gym.spaces.Discrete(2)
|
||||||
self.count = 0
|
self.count = 0
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action: int):
|
||||||
|
"""Steps the DummyEnv with the incremented step, reward and done `if self.count > 2` and updated info."""
|
||||||
self.count += 1
|
self.count += 1
|
||||||
return (
|
return (
|
||||||
np.array([self.count]),
|
np.array([self.count]), # Obs
|
||||||
1 if self.count > 2 else 0,
|
self.count > 2, # Reward
|
||||||
self.count > 2,
|
self.count > 2, # Done
|
||||||
{"count": self.count},
|
{"count": self.count}, # Info
|
||||||
)
|
)
|
||||||
|
|
||||||
def reset(
|
def reset(
|
||||||
@@ -46,6 +46,7 @@ class DummyResetEnv(gym.Env):
|
|||||||
return_info: Optional[bool] = False,
|
return_info: Optional[bool] = False,
|
||||||
options: Optional[dict] = None
|
options: Optional[dict] = None
|
||||||
):
|
):
|
||||||
|
"""Resets the DummyEnv to return the count array and info with count."""
|
||||||
self.count = 0
|
self.count = 0
|
||||||
if not return_info:
|
if not return_info:
|
||||||
return np.array([self.count])
|
return np.array([self.count])
|
||||||
@@ -53,79 +54,78 @@ class DummyResetEnv(gym.Env):
|
|||||||
return np.array([self.count]), {"count": self.count}
|
return np.array([self.count]), {"count": self.count}
|
||||||
|
|
||||||
|
|
||||||
def test_autoreset_reset_info():
|
def unwrap_env(env) -> Generator[gym.Wrapper, None, None]:
|
||||||
env = gym.make("CartPole-v1")
|
"""Unwraps an environment yielding all wrappers around environment."""
|
||||||
env = AutoResetWrapper(env)
|
while isinstance(env, gym.Wrapper):
|
||||||
ob_space = env.observation_space
|
yield type(env)
|
||||||
obs = env.reset()
|
env = env.env
|
||||||
assert ob_space.contains(obs)
|
|
||||||
obs = env.reset(return_info=False)
|
|
||||||
assert ob_space.contains(obs)
|
|
||||||
obs, info = env.reset(return_info=True)
|
|
||||||
assert ob_space.contains(obs)
|
|
||||||
assert isinstance(info, dict)
|
|
||||||
env.close()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("spec", spec_list, ids=[spec.id for spec in spec_list])
|
@pytest.mark.parametrize("spec", spec_list, ids=[spec.id for spec in spec_list])
|
||||||
def test_make_autoreset_true(spec):
|
def test_make_autoreset_true(spec):
|
||||||
"""
|
"""Tests gym.make with `autoreset=True`, and check that the reset actually happens.
|
||||||
Note: This test assumes that the outermost wrapper is AutoResetWrapper
|
|
||||||
so if that is being changed in the future, this test will break and need
|
Note: This test assumes that the outermost wrapper is AutoResetWrapper so if that
|
||||||
to be updated.
|
is being changed in the future, this test will break and need to be updated.
|
||||||
Note: This test assumes that all first-party environments will terminate in a finite
|
Note: This test assumes that all first-party environments will terminate in a finite
|
||||||
amount of time with random actions, which is true as of the time of adding this test.
|
amount of time with random actions, which is true as of the time of adding this test.
|
||||||
"""
|
"""
|
||||||
with pytest.warns(None):
|
with pytest.warns(None):
|
||||||
env = spec.make(autoreset=True)
|
env = gym.make(spec.id, autoreset=True)
|
||||||
|
assert AutoResetWrapper in unwrap_env(env)
|
||||||
|
|
||||||
env.reset(seed=0)
|
env.reset(seed=0)
|
||||||
env.action_space.seed(0)
|
|
||||||
|
|
||||||
env.unwrapped.reset = MagicMock(side_effect=env.unwrapped.reset)
|
env.unwrapped.reset = MagicMock(side_effect=env.unwrapped.reset)
|
||||||
|
|
||||||
done = False
|
done = False
|
||||||
while not done:
|
while not done:
|
||||||
obs, reward, done, info = env.step(env.action_space.sample())
|
obs, reward, done, info = env.step(env.action_space.sample())
|
||||||
|
|
||||||
assert isinstance(env, AutoResetWrapper)
|
|
||||||
assert env.unwrapped.reset.called
|
assert env.unwrapped.reset.called
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("spec", spec_list, ids=[spec.id for spec in spec_list])
|
@pytest.mark.parametrize("spec", spec_list, ids=[spec.id for spec in spec_list])
|
||||||
def test_make_autoreset_false(spec):
|
def test_gym_make_autoreset(spec):
|
||||||
|
"""Tests that `gym.make` autoreset wrapper is applied only when `gym.make(..., autoreset=True)`."""
|
||||||
with pytest.warns(None):
|
with pytest.warns(None):
|
||||||
env = spec.make(autoreset=False)
|
env = gym.make(spec.id)
|
||||||
assert not isinstance(env, AutoResetWrapper)
|
assert AutoResetWrapper not in unwrap_env(env)
|
||||||
|
env.close()
|
||||||
|
|
||||||
|
with pytest.warns(None):
|
||||||
|
env = gym.make(spec.id, autoreset=False)
|
||||||
|
assert AutoResetWrapper not in unwrap_env(env)
|
||||||
|
env.close()
|
||||||
|
|
||||||
|
with pytest.warns(None):
|
||||||
|
env = gym.make(spec.id, autoreset=True)
|
||||||
|
assert AutoResetWrapper in unwrap_env(env)
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("spec", spec_list, ids=[spec.id for spec in spec_list])
|
def test_autoreset_wrapper_autoreset():
|
||||||
def test_make_autoreset_default_false(spec):
|
"""Tests the autoreset wrapper actually automatically resets correctly."""
|
||||||
with pytest.warns(None):
|
|
||||||
env = spec.make()
|
|
||||||
assert not isinstance(env, AutoResetWrapper)
|
|
||||||
env.close()
|
|
||||||
|
|
||||||
|
|
||||||
def test_autoreset_autoreset():
|
|
||||||
env = DummyResetEnv()
|
env = DummyResetEnv()
|
||||||
env = AutoResetWrapper(env)
|
env = AutoResetWrapper(env)
|
||||||
|
|
||||||
obs, info = env.reset(return_info=True)
|
obs, info = env.reset(return_info=True)
|
||||||
assert obs == np.array([0])
|
assert obs == np.array([0])
|
||||||
assert info == {"count": 0}
|
assert info == {"count": 0}
|
||||||
action = 1
|
|
||||||
|
action = 0
|
||||||
obs, reward, done, info = env.step(action)
|
obs, reward, done, info = env.step(action)
|
||||||
assert obs == np.array([1])
|
assert obs == np.array([1])
|
||||||
assert reward == 0
|
assert reward == 0
|
||||||
assert done is False
|
assert done is False
|
||||||
assert info == {"count": 1}
|
assert info == {"count": 1}
|
||||||
|
|
||||||
obs, reward, done, info = env.step(action)
|
obs, reward, done, info = env.step(action)
|
||||||
assert obs == np.array([2])
|
assert obs == np.array([2])
|
||||||
assert done is False
|
assert done is False
|
||||||
assert reward == 0
|
assert reward == 0
|
||||||
assert info == {"count": 2}
|
assert info == {"count": 2}
|
||||||
|
|
||||||
obs, reward, done, info = env.step(action)
|
obs, reward, done, info = env.step(action)
|
||||||
assert obs == np.array([0])
|
assert obs == np.array([0])
|
||||||
assert done is True
|
assert done is True
|
||||||
@@ -135,14 +135,11 @@ def test_autoreset_autoreset():
|
|||||||
"terminal_observation": np.array([3]),
|
"terminal_observation": np.array([3]),
|
||||||
"terminal_info": {"count": 3},
|
"terminal_info": {"count": 3},
|
||||||
}
|
}
|
||||||
|
|
||||||
obs, reward, done, info = env.step(action)
|
obs, reward, done, info = env.step(action)
|
||||||
assert obs == np.array([1])
|
assert obs == np.array([1])
|
||||||
assert reward == 0
|
assert reward == 0
|
||||||
assert done is False
|
assert done is False
|
||||||
assert info == {"count": 1}
|
assert info == {"count": 1}
|
||||||
obs, reward, done, info = env.step(action)
|
|
||||||
assert obs == np.array([2])
|
|
||||||
assert reward == 0
|
|
||||||
assert done is False
|
|
||||||
assert info == {"count": 2}
|
|
||||||
env.close()
|
env.close()
|
||||||
|
Reference in New Issue
Block a user