mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-19 13:32:03 +00:00
Add Pydocstyle to CI (#2785)
* Added pydocstyle to pre-commit * Added docstrings for tests and updated the tests for autoreset * Add pydocstyle exclude folder to allow slowly adding new docstrings * Add docstrings for setup.py and gym/__init__.py, core.py, error.py and logger.py * Check that all unwrapped environment are of a particular wrapper type * Reverted back to import gym.spaces.Space to gym.spaces * Fixed the __init__.py docstring * Fixed autoreset autoreset test * Updated gym __init__.py top docstring * Remove unnecessary import * Removed "unused error" and make APIerror deprecated at gym 1.0 * Add pydocstyle description to CONTRIBUTING.md * Added docstrings section to CONTRIBUTING.md * Added :meth: and :attr: keywords to docstrings * Added :meth: and :attr: keywords to docstrings * Update the step docstring placing the return type in the as a note. * Updated step return type to include each element * Update maths notation to reward range * Fixed infinity maths notation
This commit is contained in:
@@ -26,6 +26,16 @@ repos:
|
||||
hooks:
|
||||
- id: isort
|
||||
args: ["--profile", "black"]
|
||||
- repo: https://github.com/pycqa/pydocstyle
|
||||
rev: 6.1.1 # pick a git hash / tag to point to
|
||||
hooks:
|
||||
- id: pydocstyle
|
||||
exclude: ^(gym/version.py)|(gym/(wrappers|envs|spaces|utils|vector)/)|(tests/)
|
||||
args:
|
||||
- --source
|
||||
- --explain
|
||||
- --convention=google
|
||||
additional_dependencies: ["toml"]
|
||||
- repo: https://github.com/asottile/pyupgrade
|
||||
rev: v2.32.0
|
||||
hooks:
|
||||
|
@@ -43,3 +43,13 @@ The Git hooks can also be run manually with `pre-commit run --all-files`, and if
|
||||
|
||||
Additionally, for pull requests, the project runs a number of tests for the whole project using [pytest](https://docs.pytest.org/en/latest/getting-started.html#install-pytest).
|
||||
These tests can be run locally with `pytest` in the root folder.
|
||||
|
||||
## Docstrings
|
||||
Pydocstyle has been added to the pre-commit process such that all new functions follow the [google docstring style](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html).
|
||||
All new functions require either a short docstring, a single line explaining the purpose of a function
|
||||
or a multiline docstring that documents each argument and the return type (if there is one) of the function.
|
||||
In addition, new files and classes require top docstrings that should outline the purpose of the file/class.
|
||||
For classes, code block examples can be provided in the top docstring and not the constructor arguments.
|
||||
|
||||
To check your docstrings are correct, run `pre-commit run --all-files` or `pydocstyle --source --explain --convention=google`.
|
||||
For all docstrings that fail, the source and reason for the failure are provided.
|
@@ -1,3 +1,4 @@
|
||||
"""Root __init__ of the gym module setting the __all__ of gym modules."""
|
||||
# isort: skip_file
|
||||
|
||||
from gym import error
|
||||
|
216
gym/core.py
216
gym/core.py
@@ -1,7 +1,8 @@
|
||||
"""Core API for Environment, Wrapper, ActionWrapper, RewardWrapper and ObservationWrapper."""
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import Generic, Optional, SupportsFloat, Tuple, TypeVar, Union
|
||||
from typing import Generic, Optional, SupportsFloat, TypeVar, Union
|
||||
|
||||
from gym import spaces
|
||||
from gym.logger import deprecation
|
||||
@@ -13,27 +14,30 @@ ActType = TypeVar("ActType")
|
||||
|
||||
|
||||
class Env(Generic[ObsType, ActType]):
|
||||
"""The main OpenAI Gym class. It encapsulates an environment with
|
||||
arbitrary behind-the-scenes dynamics. An environment can be
|
||||
partially or fully observed.
|
||||
r"""The main OpenAI Gym class.
|
||||
|
||||
It encapsulates an environment with arbitrary behind-the-scenes dynamics.
|
||||
An environment can be partially or fully observed.
|
||||
|
||||
The main API methods that users of this class need to know are:
|
||||
|
||||
step
|
||||
reset
|
||||
render
|
||||
close
|
||||
seed
|
||||
- :meth:`step` - Takes a step in the environment using an action returning the next observation, reward,
|
||||
if the environment terminated and more information.
|
||||
- :meth:`reset` - Resets the environment to an initial state, returning the initial observation.
|
||||
- :meth:`render` - Renders the environment observation with modes depending on the output
|
||||
- :meth:`close` - Closes the environment, important for rendering where pygame is imported
|
||||
- :meth:`seed` - Seeds the environment's random number generator, :deprecated: in favor of `Env.reset(seed=seed)`.
|
||||
|
||||
And set the following attributes:
|
||||
|
||||
action_space: The Space object corresponding to valid actions
|
||||
observation_space: The Space object corresponding to valid observations
|
||||
reward_range: A tuple corresponding to the min and max possible rewards
|
||||
- :attr:`action_space` - The Space object corresponding to valid actions
|
||||
- :attr:`observation_space` - The Space object corresponding to valid observations
|
||||
- :attr:`reward_range` - A tuple corresponding to the minimum and maximum possible rewards
|
||||
- :attr:`spec` - An environment spec that contains the information used to initialise the environment from `gym.make`
|
||||
- :attr:`metadata` - The metadata of the environment, i.e. render modes
|
||||
- :attr:`np_random` - The random number generator for the environment
|
||||
|
||||
Note: a default reward range set to [-inf,+inf] already exists. Set it if you want a narrower range.
|
||||
|
||||
The methods are accessed publicly as "step", "reset", etc...
|
||||
Note: a default reward range set to :math:`(-\infty,+\infty)` already exists. Set it if you want a narrower range.
|
||||
"""
|
||||
|
||||
# Set this in SOME subclasses
|
||||
@@ -46,11 +50,11 @@ class Env(Generic[ObsType, ActType]):
|
||||
observation_space: spaces.Space[ObsType]
|
||||
|
||||
# Created
|
||||
_np_random: RandomNumberGenerator | None = None
|
||||
_np_random: Optional[RandomNumberGenerator] = None
|
||||
|
||||
@property
|
||||
def np_random(self) -> RandomNumberGenerator:
|
||||
"""Initializes the np_random field if not done already."""
|
||||
"""Returns the environment's internal :attr:`_np_random` that if not set will initialise with a random seed."""
|
||||
if self._np_random is None:
|
||||
self._np_random, seed = seeding.np_random()
|
||||
return self._np_random
|
||||
@@ -60,28 +64,27 @@ class Env(Generic[ObsType, ActType]):
|
||||
self._np_random = value
|
||||
|
||||
@abstractmethod
|
||||
def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
|
||||
"""Run one timestep of the environment's dynamics. When end of
|
||||
episode is reached, you are responsible for calling :meth:`reset`
|
||||
to reset this environment's state.
|
||||
def step(self, action: ActType) -> tuple[ObsType, float, bool, dict]:
|
||||
"""Run one timestep of the environment's dynamics.
|
||||
|
||||
Accepts an action and returns a tuple (observation, reward, done, info).
|
||||
When end of episode is reached, you are responsible for calling :meth:`reset` to reset this environment's state.
|
||||
Accepts an action and returns a tuple `(observation, reward, done, info)`.
|
||||
|
||||
Args:
|
||||
action (object): an action provided by the agent
|
||||
|
||||
This method returns a tuple ``(observation, reward, done, info)``
|
||||
|
||||
Returns:
|
||||
observation (object): agent's observation of the current environment. This will be an element of the environment's :attr:`observation_space`. This may, for instance, be a numpy array containing the positions and velocities of certain objects.
|
||||
reward (float) : amount of reward returned after previous action
|
||||
done (bool): whether the episode has ended, in which case further :meth:`step` calls will return undefined results. A done signal may be emitted for different reasons: Maybe the task underlying the environment was solved successfully, a certain timelimit was exceeded, or the physics simulation has entered an invalid state. ``info`` may contain additional information regarding the reason for a ``done`` signal.
|
||||
info (dict): contains auxiliary diagnostic information (helpful for debugging, learning, and logging). This might, for instance, contain:
|
||||
|
||||
- metrics that describe the agent's performance or
|
||||
- state variables that are hidden from observations or
|
||||
- information that distinguishes truncation and termination or
|
||||
- individual reward terms that are combined to produce the total reward
|
||||
observation (object): this will be an element of the environment's :attr:`observation_space`.
|
||||
This may, for instance, be a numpy array containing the positions and velocities of certain objects.
|
||||
reward (float): The amount of reward returned as a result of taking the action.
|
||||
done (bool): A boolean value for if the episode has ended, in which case further :meth:`step` calls will return undefined results.
|
||||
A done signal may be emitted for different reasons: Maybe the task underlying the environment was solved successfully,
|
||||
a certain timelimit was exceeded, or the physics simulation has entered an invalid state.
|
||||
info (dictionary): A dictionary that may contain additional information regarding the reason for a ``done`` signal.
|
||||
`info` contains auxiliary diagnostic information (helpful for debugging, learning, and logging).
|
||||
This might, for instance, contain: metrics that describe the agent's performance state, variables that are
|
||||
hidden from observations, information that distinguishes truncation and termination or individual reward terms
|
||||
that are combined to produce the total reward
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -93,28 +96,34 @@ class Env(Generic[ObsType, ActType]):
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None,
|
||||
) -> Union[ObsType, tuple[ObsType, dict]]:
|
||||
"""Resets the environment to an initial state and returns an initial
|
||||
observation.
|
||||
"""Resets the environment to an initial state and returns the initial observation.
|
||||
|
||||
This method should also reset the environment's random number
|
||||
generator(s) if ``seed`` is an integer or if the environment has not
|
||||
yet initialized a random number generator. If the environment already
|
||||
has a random number generator and :meth:`reset` is called with ``seed=None``,
|
||||
the RNG should not be reset.
|
||||
Moreover, :meth:`reset` should (in the typical use case) be called with an
|
||||
This method can reset the environment's random number generator(s) if ``seed`` is an integer or
|
||||
if the environment has not yet initialized a random number generator.
|
||||
If the environment already has a random number generator and :meth:`reset` is called with ``seed=None``,
|
||||
the RNG should not be reset. Moreover, :meth:`reset` should (in the typical use case) be called with an
|
||||
integer seed right after initialization and then never again.
|
||||
|
||||
Args:
|
||||
seed (int or None):
|
||||
The seed that is used to initialize the environment's PRNG. If the environment does not already have a PRNG and ``seed=None`` (the default option) is passed, a seed will be chosen from some source of entropy (e.g. timestamp or /dev/urandom).
|
||||
However, if the environment already has a PRNG and ``seed=None`` is passed, the PRNG will *not* be reset. If you pass an integer, the PRNG will be reset even if it already exists. Usually, you want to pass an integer *right after the environment has been initialized and then never again*. Please refer to the minimal example above to see this paradigm in action.
|
||||
return_info (bool): If true, return additional information along with initial observation. This info should be analogous to the info returned in :meth:`step`
|
||||
options (dict or None): Additional information to specify how the environment is reset (optional, depending on the specific environment)
|
||||
seed (optional int): The seed that is used to initialize the environment's PRNG.
|
||||
If the environment does not already have a PRNG and ``seed=None`` (the default option) is passed,
|
||||
a seed will be chosen from some source of entropy (e.g. timestamp or /dev/urandom).
|
||||
However, if the environment already has a PRNG and ``seed=None`` is passed, the PRNG will *not* be reset.
|
||||
If you pass an integer, the PRNG will be reset even if it already exists.
|
||||
Usually, you want to pass an integer *right after the environment has been initialized and then never again*.
|
||||
Please refer to the minimal example above to see this paradigm in action.
|
||||
return_info (bool): If true, return additional information along with initial observation.
|
||||
This info should be analogous to the info returned in :meth:`step`
|
||||
options (optional dict): Additional information to specify how the environment is reset (optional,
|
||||
depending on the specific environment)
|
||||
|
||||
|
||||
Returns:
|
||||
observation (object): Observation of the initial state. This will be an element of :attr:`observation_space` (usually a numpy array) and is analogous to the observation returned by :meth:`step`.
|
||||
info (optional dictionary): This will *only* be returned if ``return_info=True`` is passed. It contains auxiliary information complementing ``observation``. This dictionary should be analogous to the ``info`` returned by :meth:`step`.
|
||||
observation (object): Observation of the initial state. This will be an element of :attr:`observation_space`
|
||||
(typically a numpy array) and is analogous to the observation returned by :meth:`step`.
|
||||
info (optional dictionary): This will *only* be returned if ``return_info=True`` is passed.
|
||||
It contains auxiliary information complementing ``observation``. This dictionary should be analogous to
|
||||
the ``info`` returned by :meth:`step`.
|
||||
"""
|
||||
# Initialize the RNG if the seed is manually passed
|
||||
if seed is not None:
|
||||
@@ -124,13 +133,13 @@ class Env(Generic[ObsType, ActType]):
|
||||
def render(self, mode="human"):
|
||||
"""Renders the environment.
|
||||
|
||||
The set of supported modes varies per environment. (And some
|
||||
A set of supported modes varies per environment. (And some
|
||||
third-party environments may not support rendering at all.)
|
||||
By convention, if mode is:
|
||||
|
||||
- human: render to the current display or terminal and
|
||||
return nothing. Usually for human consumption.
|
||||
- rgb_array: Return an numpy.ndarray with shape (x, y, 3),
|
||||
- rgb_array: Return a numpy.ndarray with shape (x, y, 3),
|
||||
representing RGB values for an x-by-y pixel image, suitable
|
||||
for turning into a video.
|
||||
- ansi: Return a string (str) or StringIO.StringIO containing a
|
||||
@@ -139,60 +148,64 @@ class Env(Generic[ObsType, ActType]):
|
||||
|
||||
Note:
|
||||
Make sure that your class's metadata 'render_modes' key includes
|
||||
the list of supported modes. It's recommended to call super()
|
||||
in implementations to use the functionality of this method.
|
||||
the list of supported modes. It's recommended to call super()
|
||||
in implementations to use the functionality of this method.
|
||||
|
||||
Example:
|
||||
>>> class MyEnv(Env):
|
||||
... metadata = {'render_modes': ['human', 'rgb_array']}
|
||||
...
|
||||
... def render(self, mode='human'):
|
||||
... if mode == 'rgb_array':
|
||||
... return np.array(...) # return RGB frame suitable for video
|
||||
... elif mode == 'human':
|
||||
... ... # pop up a window and render
|
||||
... else:
|
||||
... super(MyEnv, self).render(mode=mode) # just raise an exception
|
||||
|
||||
Args:
|
||||
mode (str): the mode to render with
|
||||
|
||||
Example::
|
||||
|
||||
class MyEnv(Env):
|
||||
metadata = {'render_modes': ['human', 'rgb_array']}
|
||||
|
||||
def render(self, mode='human'):
|
||||
if mode == 'rgb_array':
|
||||
return np.array(...) # return RGB frame suitable for video
|
||||
elif mode == 'human':
|
||||
... # pop up a window and render
|
||||
else:
|
||||
super(MyEnv, self).render(mode=mode) # just raise an exception
|
||||
mode: the mode to render with, valid modes are `env.metadata["render_modes"]`
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def close(self):
|
||||
"""Override close in your subclass to perform any necessary cleanup.
|
||||
|
||||
Environments will automatically close() themselves when
|
||||
Environments will automatically :meth:`close()` themselves when
|
||||
garbage collected or when the program exits.
|
||||
"""
|
||||
pass
|
||||
|
||||
def seed(self, seed=None):
|
||||
"""Sets the seed for this env's random number generator(s).
|
||||
""":deprecated: function that sets the seed for the environment's random number generator(s).
|
||||
|
||||
Use `env.reset(seed=seed)` as the new API for setting the seed of the environment.
|
||||
|
||||
Note:
|
||||
Some environments use multiple pseudorandom number generators.
|
||||
We want to capture all such seeds used in order to ensure that
|
||||
there aren't accidental correlations between multiple generators.
|
||||
|
||||
Args:
|
||||
seed (Optional int): The seed value for the random number generator
|
||||
|
||||
Returns:
|
||||
list<bigint>: Returns the list of seeds used in this env's random
|
||||
seeds (List[int]): Returns the list of seeds used in this environment's random
|
||||
number generators. The first value in the list should be the
|
||||
"main" seed, or the value which a reproducer should pass to
|
||||
'seed'. Often, the main seed equals the provided 'seed', but
|
||||
this won't be true if seed=None, for example.
|
||||
this won't be true if `seed=None`, for example.
|
||||
"""
|
||||
deprecation(
|
||||
"Function `env.seed(seed)` is marked as deprecated and will be removed in the future. "
|
||||
"Please use `env.reset(seed=seed) instead."
|
||||
"Please use `env.reset(seed=seed)` instead."
|
||||
)
|
||||
self._np_random, seed = seeding.np_random(seed)
|
||||
return [seed]
|
||||
|
||||
@property
|
||||
def unwrapped(self) -> Env:
|
||||
"""Completely unwrap this env.
|
||||
"""Returns the base non-wrapped environment.
|
||||
|
||||
Returns:
|
||||
gym.Env: The base non-wrapped gym.Env instance
|
||||
@@ -200,6 +213,7 @@ class Env(Generic[ObsType, ActType]):
|
||||
return self
|
||||
|
||||
def __str__(self):
|
||||
"""Returns a string of the environment with the spec id if specified."""
|
||||
if self.spec is None:
|
||||
return f"<{type(self).__name__} instance>"
|
||||
else:
|
||||
@@ -217,71 +231,81 @@ class Env(Generic[ObsType, ActType]):
|
||||
|
||||
|
||||
class Wrapper(Env[ObsType, ActType]):
|
||||
"""Wraps the environment to allow a modular transformation.
|
||||
"""Wraps an environment to allow a modular transformation of the :meth:`step` and :meth:`reset` methods.
|
||||
|
||||
This class is the base class for all wrappers. The subclass could override
|
||||
some methods to change the behavior of the original environment without touching the
|
||||
original code.
|
||||
|
||||
.. note::
|
||||
|
||||
Note:
|
||||
Don't forget to call ``super().__init__(env)`` if the subclass overrides :meth:`__init__`.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, env: Env):
|
||||
"""Wraps an environment to allow a modular transformation of the :meth:`step` and :meth:`reset` methods.
|
||||
|
||||
Args:
|
||||
env: The environment to wrap
|
||||
"""
|
||||
self.env = env
|
||||
|
||||
self._action_space: spaces.Space | None = None
|
||||
self._observation_space: spaces.Space | None = None
|
||||
self._reward_range: tuple[SupportsFloat, SupportsFloat] | None = None
|
||||
self._metadata: dict | None = None
|
||||
self._action_space: Optional[spaces.Space] = None
|
||||
self._observation_space: Optional[spaces.Space] = None
|
||||
self._reward_range: Optional[tuple[SupportsFloat, SupportsFloat]] = None
|
||||
self._metadata: Optional[dict] = None
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""Returns an attribute with ``name``, unless ``name`` starts with an underscore."""
|
||||
if name.startswith("_"):
|
||||
raise AttributeError(f"accessing private attribute '{name}' is prohibited")
|
||||
return getattr(self.env, name)
|
||||
|
||||
@property
|
||||
def spec(self):
|
||||
"""Returns the environment specification."""
|
||||
return self.env.spec
|
||||
|
||||
@classmethod
|
||||
def class_name(cls):
|
||||
"""Returns the class name of the wrapper."""
|
||||
return cls.__name__
|
||||
|
||||
@property
|
||||
def action_space(self) -> spaces.Space[ActType]:
|
||||
"""Returns the action space of the environment."""
|
||||
if self._action_space is None:
|
||||
return self.env.action_space
|
||||
return self._action_space
|
||||
|
||||
@action_space.setter
|
||||
def action_space(self, space):
|
||||
def action_space(self, space: spaces.Space):
|
||||
self._action_space = space
|
||||
|
||||
@property
|
||||
def observation_space(self) -> spaces.Space:
|
||||
"""Returns the observation space of the environment."""
|
||||
if self._observation_space is None:
|
||||
return self.env.observation_space
|
||||
return self._observation_space
|
||||
|
||||
@observation_space.setter
|
||||
def observation_space(self, space):
|
||||
def observation_space(self, space: spaces.Space):
|
||||
self._observation_space = space
|
||||
|
||||
@property
|
||||
def reward_range(self) -> tuple[SupportsFloat, SupportsFloat]:
|
||||
"""Return the reward range of the environment."""
|
||||
if self._reward_range is None:
|
||||
return self.env.reward_range
|
||||
return self._reward_range
|
||||
|
||||
@reward_range.setter
|
||||
def reward_range(self, value):
|
||||
def reward_range(self, value: tuple[SupportsFloat, SupportsFloat]):
|
||||
self._reward_range = value
|
||||
|
||||
@property
|
||||
def metadata(self) -> dict:
|
||||
"""Returns the environment metadata."""
|
||||
if self._metadata is None:
|
||||
return self.env.metadata
|
||||
return self._metadata
|
||||
@@ -290,34 +314,45 @@ class Wrapper(Env[ObsType, ActType]):
|
||||
def metadata(self, value):
|
||||
self._metadata = value
|
||||
|
||||
def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
|
||||
def step(self, action: ActType) -> tuple[ObsType, float, bool, dict]:
|
||||
"""Steps through the environment with action."""
|
||||
return self.env.step(action)
|
||||
|
||||
def reset(self, **kwargs) -> Union[ObsType, tuple[ObsType, dict]]:
|
||||
"""Resets the environment with kwargs."""
|
||||
return self.env.reset(**kwargs)
|
||||
|
||||
def render(self, **kwargs):
|
||||
"""Renders the environment with kwargs."""
|
||||
return self.env.render(**kwargs)
|
||||
|
||||
def close(self):
|
||||
"""Closes the environment."""
|
||||
return self.env.close()
|
||||
|
||||
def seed(self, seed=None):
|
||||
"""Seeds the environment."""
|
||||
return self.env.seed(seed)
|
||||
|
||||
def __str__(self):
|
||||
"""Returns the wrapper name and the unwrapped environment string."""
|
||||
return f"<{type(self).__name__}{self.env}>"
|
||||
|
||||
def __repr__(self):
|
||||
"""Returns the string representation of the wrapper."""
|
||||
return str(self)
|
||||
|
||||
@property
|
||||
def unwrapped(self) -> Env:
|
||||
"""Returns the base environment of the wrapper."""
|
||||
return self.env.unwrapped
|
||||
|
||||
|
||||
class ObservationWrapper(Wrapper):
|
||||
"""A wrapper that can modify observations using :meth:`observation` for :meth:`reset` and :meth:`step`."""
|
||||
|
||||
def reset(self, **kwargs):
|
||||
"""Resets the environment, returning a modified observation using :meth:`self.observation`."""
|
||||
if kwargs.get("return_info", False):
|
||||
obs, info = self.env.reset(**kwargs)
|
||||
return self.observation(obs), info
|
||||
@@ -325,38 +360,43 @@ class ObservationWrapper(Wrapper):
|
||||
return self.observation(self.env.reset(**kwargs))
|
||||
|
||||
def step(self, action):
|
||||
"""Returns a modified observation using :meth:`self.observation` after calling :meth:`env.step`."""
|
||||
observation, reward, done, info = self.env.step(action)
|
||||
return self.observation(observation), reward, done, info
|
||||
|
||||
@abstractmethod
|
||||
def observation(self, observation):
|
||||
"""Returns a modified observation."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class RewardWrapper(Wrapper):
|
||||
def reset(self, **kwargs):
|
||||
return self.env.reset(**kwargs)
|
||||
"""A wrapper that can modify the returning reward from a step."""
|
||||
|
||||
def step(self, action):
|
||||
"""Modifies the reward using :meth:`self.reward` after the environment :meth:`env.step`."""
|
||||
observation, reward, done, info = self.env.step(action)
|
||||
return observation, self.reward(reward), done, info
|
||||
|
||||
@abstractmethod
|
||||
def reward(self, reward):
|
||||
"""Returns a modified ``reward``."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ActionWrapper(Wrapper):
|
||||
def reset(self, **kwargs):
|
||||
return self.env.reset(**kwargs)
|
||||
"""A wrapper that can modify the action before :meth:`env.step`."""
|
||||
|
||||
def step(self, action):
|
||||
"""Runs the environment :meth:`env.step` using the modified ``action`` from :meth:`self.action`."""
|
||||
return self.env.step(self.action(action))
|
||||
|
||||
@abstractmethod
|
||||
def action(self, action):
|
||||
"""Returns a modified action before :meth:`env.step` is called."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def reverse_action(self, action):
|
||||
"""Returns a reversed ``action``."""
|
||||
raise NotImplementedError
|
||||
|
151
gym/error.py
151
gym/error.py
@@ -1,122 +1,76 @@
|
||||
"""Set of Error classes for gym."""
|
||||
import warnings
|
||||
|
||||
|
||||
class Error(Exception):
|
||||
pass
|
||||
"""Error superclass."""
|
||||
|
||||
|
||||
# Local errors
|
||||
|
||||
|
||||
class Unregistered(Error):
|
||||
"""Raised when the user requests an item from the registry that does
|
||||
not actually exist.
|
||||
"""
|
||||
|
||||
pass
|
||||
"""Raised when the user requests an item from the registry that does not actually exist."""
|
||||
|
||||
|
||||
class UnregisteredEnv(Unregistered):
|
||||
"""Raised when the user requests an env from the registry that does
|
||||
not actually exist.
|
||||
"""
|
||||
|
||||
pass
|
||||
"""Raised when the user requests an env from the registry that does not actually exist."""
|
||||
|
||||
|
||||
class NamespaceNotFound(UnregisteredEnv):
|
||||
"""Raised when the user requests an env from the registry where the
|
||||
namespace doesn't exist.
|
||||
"""
|
||||
|
||||
pass
|
||||
"""Raised when the user requests an env from the registry where the namespace doesn't exist."""
|
||||
|
||||
|
||||
class NameNotFound(UnregisteredEnv):
|
||||
"""Raised when the user requests an env from the registry where the
|
||||
name doesn't exist.
|
||||
"""
|
||||
|
||||
pass
|
||||
"""Raised when the user requests an env from the registry where the name doesn't exist."""
|
||||
|
||||
|
||||
class VersionNotFound(UnregisteredEnv):
|
||||
"""Raised when the user requests an env from the registry where the
|
||||
version doesn't exist.
|
||||
"""
|
||||
|
||||
pass
|
||||
"""Raised when the user requests an env from the registry where the version doesn't exist."""
|
||||
|
||||
|
||||
class UnregisteredBenchmark(Unregistered):
|
||||
"""Raised when the user requests an env from the registry that does
|
||||
not actually exist.
|
||||
"""
|
||||
|
||||
pass
|
||||
"""Raised when the user requests an env from the registry that does not actually exist."""
|
||||
|
||||
|
||||
class DeprecatedEnv(Error):
|
||||
"""Raised when the user requests an env from the registry with an
|
||||
older version number than the latest env with the same name.
|
||||
"""
|
||||
|
||||
pass
|
||||
"""Raised when the user requests an env from the registry with an older version number than the latest env with the same name."""
|
||||
|
||||
|
||||
class RegistrationError(Error):
|
||||
"""Raised when the user attempts to register an invalid env.
|
||||
For example, an unversioned env when a versioned env exists.
|
||||
"""
|
||||
|
||||
pass
|
||||
"""Raised when the user attempts to register an invalid env. For example, an unversioned env when a versioned env exists."""
|
||||
|
||||
|
||||
class UnseedableEnv(Error):
|
||||
"""Raised when the user tries to seed an env that does not support
|
||||
seeding.
|
||||
"""
|
||||
|
||||
pass
|
||||
"""Raised when the user tries to seed an env that does not support seeding."""
|
||||
|
||||
|
||||
class DependencyNotInstalled(Error):
|
||||
pass
|
||||
"""Raised when the user has not installed a dependency."""
|
||||
|
||||
|
||||
class UnsupportedMode(Exception):
|
||||
"""Raised when the user requests a rendering mode not supported by the
|
||||
environment.
|
||||
"""
|
||||
|
||||
pass
|
||||
class UnsupportedMode(Error):
|
||||
"""Raised when the user requests a rendering mode not supported by the environment."""
|
||||
|
||||
|
||||
class ResetNeeded(Exception):
|
||||
"""When the monitor is active, raised when the user tries to step an
|
||||
environment that's already done.
|
||||
"""
|
||||
|
||||
pass
|
||||
class ResetNeeded(Error):
|
||||
"""When the order enforcing is violated, i.e. step or render is called before reset."""
|
||||
|
||||
|
||||
class ResetNotAllowed(Exception):
|
||||
"""When the monitor is active, raised when the user tries to step an
|
||||
environment that's not yet done.
|
||||
"""
|
||||
|
||||
pass
|
||||
class ResetNotAllowed(Error):
|
||||
"""When the monitor is active, raised when the user tries to step an environment that's not yet done."""
|
||||
|
||||
|
||||
class InvalidAction(Exception):
|
||||
"""Raised when the user performs an action not contained within the
|
||||
action space
|
||||
"""
|
||||
|
||||
pass
|
||||
class InvalidAction(Error):
|
||||
"""Raised when the user performs an action not contained within the action space."""
|
||||
|
||||
|
||||
# API errors
|
||||
|
||||
|
||||
class APIError(Error):
|
||||
"""Deprecated, to be removed at gym 1.0."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message=None,
|
||||
@@ -125,8 +79,11 @@ class APIError(Error):
|
||||
json_body=None,
|
||||
headers=None,
|
||||
):
|
||||
"""Initialise API error."""
|
||||
super().__init__(message)
|
||||
|
||||
warnings.warn("APIError is deprecated and will be removed at gym 1.0")
|
||||
|
||||
if http_body and hasattr(http_body, "decode"):
|
||||
try:
|
||||
http_body = http_body.decode("utf-8")
|
||||
@@ -141,6 +98,7 @@ class APIError(Error):
|
||||
self.request_id = self.headers.get("request-id", None)
|
||||
|
||||
def __unicode__(self):
|
||||
"""Returns a string, if request_id is not None then make message other use the _message."""
|
||||
if self.request_id is not None:
|
||||
msg = self._message or "<empty message>"
|
||||
return f"Request {self.request_id}: {msg}"
|
||||
@@ -148,14 +106,17 @@ class APIError(Error):
|
||||
return self._message
|
||||
|
||||
def __str__(self):
|
||||
"""Returns the __unicode__."""
|
||||
return self.__unicode__()
|
||||
|
||||
|
||||
class APIConnectionError(APIError):
|
||||
pass
|
||||
"""Deprecated, to be removed at gym 1.0."""
|
||||
|
||||
|
||||
class InvalidRequestError(APIError):
|
||||
"""Deprecated, to be removed at gym 1.0."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message,
|
||||
@@ -165,83 +126,69 @@ class InvalidRequestError(APIError):
|
||||
json_body=None,
|
||||
headers=None,
|
||||
):
|
||||
"""Initialises the invalid request error."""
|
||||
super().__init__(message, http_body, http_status, json_body, headers)
|
||||
self.param = param
|
||||
|
||||
|
||||
class AuthenticationError(APIError):
|
||||
pass
|
||||
"""Deprecated, to be removed at gym 1.0."""
|
||||
|
||||
|
||||
class RateLimitError(APIError):
|
||||
pass
|
||||
"""Deprecated, to be removed at gym 1.0."""
|
||||
|
||||
|
||||
# Video errors
|
||||
|
||||
|
||||
class VideoRecorderError(Error):
|
||||
pass
|
||||
"""Unused error."""
|
||||
|
||||
|
||||
class InvalidFrame(Error):
|
||||
pass
|
||||
"""Error message when an invalid frame is captured."""
|
||||
|
||||
|
||||
# Wrapper errors
|
||||
|
||||
|
||||
class DoubleWrapperError(Error):
|
||||
pass
|
||||
"""Error message for when using double wrappers."""
|
||||
|
||||
|
||||
class WrapAfterConfigureError(Error):
|
||||
pass
|
||||
"""Error message for using wrap after configure."""
|
||||
|
||||
|
||||
class RetriesExceededError(Error):
|
||||
pass
|
||||
"""Error message for retries exceeding set number."""
|
||||
|
||||
|
||||
# Vectorized environments errors
|
||||
|
||||
|
||||
class AlreadyPendingCallError(Exception):
|
||||
"""
|
||||
Raised when `reset`, or `step` is called asynchronously (e.g. with
|
||||
`reset_async`, or `step_async` respectively), and `reset_async`, or
|
||||
`step_async` (respectively) is called again (without a complete call to
|
||||
`reset_wait`, or `step_wait` respectively).
|
||||
"""
|
||||
"""Raised when `reset`, or `step` is called asynchronously (e.g. with `reset_async`, or `step_async` respectively), and `reset_async`, or `step_async` (respectively) is called again (without a complete call to `reset_wait`, or `step_wait` respectively)."""
|
||||
|
||||
def __init__(self, message, name):
|
||||
def __init__(self, message: str, name: str):
|
||||
"""Initialises the exception with name attributes."""
|
||||
super().__init__(message)
|
||||
self.name = name
|
||||
|
||||
|
||||
class NoAsyncCallError(Exception):
|
||||
"""
|
||||
Raised when an asynchronous `reset`, or `step` is not running, but
|
||||
`reset_wait`, or `step_wait` (respectively) is called.
|
||||
"""
|
||||
"""Raised when an asynchronous `reset`, or `step` is not running, but `reset_wait`, or `step_wait` (respectively) is called."""
|
||||
|
||||
def __init__(self, message, name):
|
||||
def __init__(self, message: str, name: str):
|
||||
"""Initialises the exception with name attributes."""
|
||||
super().__init__(message)
|
||||
self.name = name
|
||||
|
||||
|
||||
class ClosedEnvironmentError(Exception):
|
||||
"""
|
||||
Trying to call `reset`, or `step`, while the environment is closed.
|
||||
"""
|
||||
|
||||
pass
|
||||
"""Trying to call `reset`, or `step`, while the environment is closed."""
|
||||
|
||||
|
||||
class CustomSpaceError(Exception):
|
||||
"""
|
||||
The space is a custom gym.Space instance, and is not supported by
|
||||
`AsyncVectorEnv` with `shared_memory=True`.
|
||||
"""
|
||||
|
||||
pass
|
||||
"""The space is a custom gym.Space instance, and is not supported by `AsyncVectorEnv` with `shared_memory=True`."""
|
||||
|
@@ -1,3 +1,4 @@
|
||||
"""Set of functions for logging messages."""
|
||||
import sys
|
||||
import warnings
|
||||
from typing import Optional, Type
|
||||
@@ -16,20 +17,20 @@ min_level = 30
|
||||
warnings.simplefilter("once", DeprecationWarning)
|
||||
|
||||
|
||||
def set_level(level: int) -> None:
|
||||
"""
|
||||
Set logging threshold on current logger.
|
||||
"""
|
||||
def set_level(level: int):
|
||||
"""Set logging threshold on current logger."""
|
||||
global min_level
|
||||
min_level = level
|
||||
|
||||
|
||||
def debug(msg: str, *args: object):
|
||||
"""Logs a debug message to the user."""
|
||||
if min_level <= DEBUG:
|
||||
print(f"DEBUG: {msg % args}", file=sys.stderr)
|
||||
|
||||
|
||||
def info(msg: str, *args: object):
|
||||
"""Logs an info message to the user."""
|
||||
if min_level <= INFO:
|
||||
print(f"INFO: {msg % args}", file=sys.stderr)
|
||||
|
||||
@@ -40,6 +41,14 @@ def warn(
|
||||
category: Optional[Type[Warning]] = None,
|
||||
stacklevel: int = 1,
|
||||
):
|
||||
"""Raises a warning to the user if the min_level <= WARN.
|
||||
|
||||
Args:
|
||||
msg: The message to warn the user
|
||||
*args: Additional information to warn the user
|
||||
category: The category of warning
|
||||
stacklevel: The stack level to raise to
|
||||
"""
|
||||
if min_level <= WARN:
|
||||
warnings.warn(
|
||||
colorize(f"WARN: {msg % args}", "yellow"),
|
||||
@@ -49,10 +58,12 @@ def warn(
|
||||
|
||||
|
||||
def deprecation(msg: str, *args: object):
|
||||
"""Logs a deprecation warning to users."""
|
||||
warn(msg, *args, category=DeprecationWarning, stacklevel=2)
|
||||
|
||||
|
||||
def error(msg: str, *args: object):
|
||||
"""Logs an error message if min_level <= ERROR in red on the sys.stderr."""
|
||||
if min_level <= ERROR:
|
||||
print(colorize(f"ERROR: {msg % args}", "red"), file=sys.stderr)
|
||||
|
||||
|
@@ -1,5 +1,6 @@
|
||||
"""A set of common utilities used within the environments. These are
|
||||
not intended as API functions, and will not remain stable over time.
|
||||
"""A set of common utilities used within the environments.
|
||||
|
||||
These are not intended as API functions, and will not remain stable over time.
|
||||
"""
|
||||
|
||||
# These submodules should not have any import-time dependencies.
|
||||
|
@@ -33,7 +33,7 @@ class AutoResetWrapper(gym.Wrapper):
|
||||
new observation from after calling self.env.reset() is returned
|
||||
by self.step() alongside the terminal reward and done state from the
|
||||
previous episode. If you need the terminal state from the previous
|
||||
episode, you need to retrieve it via the the "terminal_observation" key
|
||||
episode, you need to retrieve it via the "terminal_observation" key
|
||||
in the info dict. Make sure you know what you're doing if you
|
||||
use this wrapper!
|
||||
"""
|
||||
|
1
setup.py
1
setup.py
@@ -1,3 +1,4 @@
|
||||
"""Sets up the project."""
|
||||
import itertools
|
||||
import os.path
|
||||
import sys
|
||||
|
@@ -4,8 +4,7 @@ import numpy as np
|
||||
import pytest
|
||||
from numpy.testing import assert_array_equal
|
||||
|
||||
from gym import Space
|
||||
from gym.spaces import Box, Dict, MultiDiscrete, Tuple
|
||||
from gym.spaces import Box, Dict, MultiDiscrete, Space, Tuple
|
||||
from gym.vector.utils.spaces import batch_space, iterate
|
||||
from tests.vector.utils import CustomSpace, assert_rng_equal, custom_spaces, spaces
|
||||
|
||||
|
@@ -1,4 +1,6 @@
|
||||
from typing import Optional
|
||||
"""Tests the gym.wrapper.AutoResetWrapper operates as expected."""
|
||||
|
||||
from typing import Generator, Optional
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import numpy as np
|
||||
@@ -10,33 +12,31 @@ from tests.envs.spec_list import spec_list
|
||||
|
||||
|
||||
class DummyResetEnv(gym.Env):
|
||||
"""
|
||||
A dummy environment which returns ascending numbers starting
|
||||
at 0 when self.step() is called. After the third call to self.step()
|
||||
done is true. Info dicts are also returned containing the same number
|
||||
returned as an observation, accessible via the key "count".
|
||||
This environment is provided for the purpose of testing the
|
||||
autoreset wrapper.
|
||||
"""A dummy environment which returns ascending numbers starting at `0` when :meth:`self.step()` is called.
|
||||
|
||||
After the third call to :meth:`self.step()` done is true.
|
||||
Info dicts are also returned containing the same number returned as an observation, accessible via the key "count".
|
||||
This environment is provided for the purpose of testing the autoreset wrapper.
|
||||
"""
|
||||
|
||||
metadata = {}
|
||||
|
||||
def __init__(self):
|
||||
"""Initialise the DummyResetEnv."""
|
||||
self.action_space = gym.spaces.Box(
|
||||
low=np.array([-1.0]), high=np.array([1.0]), dtype=np.float64
|
||||
)
|
||||
self.observation_space = gym.spaces.Box(
|
||||
low=np.array([-1.0]), high=np.array([1.0])
|
||||
low=np.array([0]), high=np.array([2]), dtype=np.int64
|
||||
)
|
||||
self.observation_space = gym.spaces.Discrete(2)
|
||||
self.count = 0
|
||||
|
||||
def step(self, action):
|
||||
def step(self, action: int):
|
||||
"""Steps the DummyEnv with the incremented step, reward and done `if self.count > 2` and updated info."""
|
||||
self.count += 1
|
||||
return (
|
||||
np.array([self.count]),
|
||||
1 if self.count > 2 else 0,
|
||||
self.count > 2,
|
||||
{"count": self.count},
|
||||
np.array([self.count]), # Obs
|
||||
self.count > 2, # Reward
|
||||
self.count > 2, # Done
|
||||
{"count": self.count}, # Info
|
||||
)
|
||||
|
||||
def reset(
|
||||
@@ -46,6 +46,7 @@ class DummyResetEnv(gym.Env):
|
||||
return_info: Optional[bool] = False,
|
||||
options: Optional[dict] = None
|
||||
):
|
||||
"""Resets the DummyEnv to return the count array and info with count."""
|
||||
self.count = 0
|
||||
if not return_info:
|
||||
return np.array([self.count])
|
||||
@@ -53,79 +54,78 @@ class DummyResetEnv(gym.Env):
|
||||
return np.array([self.count]), {"count": self.count}
|
||||
|
||||
|
||||
def test_autoreset_reset_info():
|
||||
env = gym.make("CartPole-v1")
|
||||
env = AutoResetWrapper(env)
|
||||
ob_space = env.observation_space
|
||||
obs = env.reset()
|
||||
assert ob_space.contains(obs)
|
||||
obs = env.reset(return_info=False)
|
||||
assert ob_space.contains(obs)
|
||||
obs, info = env.reset(return_info=True)
|
||||
assert ob_space.contains(obs)
|
||||
assert isinstance(info, dict)
|
||||
env.close()
|
||||
def unwrap_env(env) -> Generator[gym.Wrapper, None, None]:
|
||||
"""Unwraps an environment yielding all wrappers around environment."""
|
||||
while isinstance(env, gym.Wrapper):
|
||||
yield type(env)
|
||||
env = env.env
|
||||
|
||||
|
||||
@pytest.mark.parametrize("spec", spec_list, ids=[spec.id for spec in spec_list])
|
||||
def test_make_autoreset_true(spec):
|
||||
"""
|
||||
Note: This test assumes that the outermost wrapper is AutoResetWrapper
|
||||
so if that is being changed in the future, this test will break and need
|
||||
to be updated.
|
||||
"""Tests gym.make with `autoreset=True`, and check that the reset actually happens.
|
||||
|
||||
Note: This test assumes that the outermost wrapper is AutoResetWrapper so if that
|
||||
is being changed in the future, this test will break and need to be updated.
|
||||
Note: This test assumes that all first-party environments will terminate in a finite
|
||||
amount of time with random actions, which is true as of the time of adding this test.
|
||||
amount of time with random actions, which is true as of the time of adding this test.
|
||||
"""
|
||||
with pytest.warns(None):
|
||||
env = spec.make(autoreset=True)
|
||||
env = gym.make(spec.id, autoreset=True)
|
||||
assert AutoResetWrapper in unwrap_env(env)
|
||||
|
||||
env.reset(seed=0)
|
||||
env.action_space.seed(0)
|
||||
|
||||
env.unwrapped.reset = MagicMock(side_effect=env.unwrapped.reset)
|
||||
|
||||
done = False
|
||||
while not done:
|
||||
obs, reward, done, info = env.step(env.action_space.sample())
|
||||
|
||||
assert isinstance(env, AutoResetWrapper)
|
||||
assert env.unwrapped.reset.called
|
||||
env.close()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("spec", spec_list, ids=[spec.id for spec in spec_list])
|
||||
def test_make_autoreset_false(spec):
|
||||
def test_gym_make_autoreset(spec):
|
||||
"""Tests that `gym.make` autoreset wrapper is applied only when `gym.make(..., autoreset=True)`."""
|
||||
with pytest.warns(None):
|
||||
env = spec.make(autoreset=False)
|
||||
assert not isinstance(env, AutoResetWrapper)
|
||||
env = gym.make(spec.id)
|
||||
assert AutoResetWrapper not in unwrap_env(env)
|
||||
env.close()
|
||||
|
||||
with pytest.warns(None):
|
||||
env = gym.make(spec.id, autoreset=False)
|
||||
assert AutoResetWrapper not in unwrap_env(env)
|
||||
env.close()
|
||||
|
||||
with pytest.warns(None):
|
||||
env = gym.make(spec.id, autoreset=True)
|
||||
assert AutoResetWrapper in unwrap_env(env)
|
||||
env.close()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("spec", spec_list, ids=[spec.id for spec in spec_list])
|
||||
def test_make_autoreset_default_false(spec):
|
||||
with pytest.warns(None):
|
||||
env = spec.make()
|
||||
assert not isinstance(env, AutoResetWrapper)
|
||||
env.close()
|
||||
|
||||
|
||||
def test_autoreset_autoreset():
|
||||
def test_autoreset_wrapper_autoreset():
|
||||
"""Tests the autoreset wrapper actually automatically resets correctly."""
|
||||
env = DummyResetEnv()
|
||||
env = AutoResetWrapper(env)
|
||||
|
||||
obs, info = env.reset(return_info=True)
|
||||
assert obs == np.array([0])
|
||||
assert info == {"count": 0}
|
||||
action = 1
|
||||
|
||||
action = 0
|
||||
obs, reward, done, info = env.step(action)
|
||||
assert obs == np.array([1])
|
||||
assert reward == 0
|
||||
assert done is False
|
||||
assert info == {"count": 1}
|
||||
|
||||
obs, reward, done, info = env.step(action)
|
||||
assert obs == np.array([2])
|
||||
assert done is False
|
||||
assert reward == 0
|
||||
assert info == {"count": 2}
|
||||
|
||||
obs, reward, done, info = env.step(action)
|
||||
assert obs == np.array([0])
|
||||
assert done is True
|
||||
@@ -135,14 +135,11 @@ def test_autoreset_autoreset():
|
||||
"terminal_observation": np.array([3]),
|
||||
"terminal_info": {"count": 3},
|
||||
}
|
||||
|
||||
obs, reward, done, info = env.step(action)
|
||||
assert obs == np.array([1])
|
||||
assert reward == 0
|
||||
assert done is False
|
||||
assert info == {"count": 1}
|
||||
obs, reward, done, info = env.step(action)
|
||||
assert obs == np.array([2])
|
||||
assert reward == 0
|
||||
assert done is False
|
||||
assert info == {"count": 2}
|
||||
|
||||
env.close()
|
||||
|
Reference in New Issue
Block a user