Add Pydocstyle to CI (#2785)

* Added pydocstyle to pre-commit

* Added docstrings for tests and updated the tests for autoreset

* Add pydocstyle exclude folder to allow slowly adding new docstrings

* Add docstrings for setup.py and gym/__init__.py, core.py, error.py and logger.py

* Check that all unwrapped environment are of a particular wrapper type

* Reverted back to import gym.spaces.Space to gym.spaces

* Fixed the __init__.py docstring

* Fixed autoreset autoreset test

* Updated gym __init__.py top docstring

* Remove unnecessary import

* Removed "unused error" and make APIerror deprecated at gym 1.0

* Add pydocstyle description to CONTRIBUTING.md

* Added docstrings section to CONTRIBUTING.md

* Added :meth: and :attr: keywords to docstrings

* Added :meth: and :attr: keywords to docstrings

* Update the step docstring placing the return type in the as a note.

* Updated step return type to include each element

* Update maths notation to reward range

* Fixed infinity maths notation
This commit is contained in:
Mark Towers
2022-05-10 15:35:45 +01:00
committed by GitHub
parent 31e6f23e67
commit 1c62d3c6ad
11 changed files with 272 additions and 255 deletions

View File

@@ -26,6 +26,16 @@ repos:
hooks:
- id: isort
args: ["--profile", "black"]
- repo: https://github.com/pycqa/pydocstyle
rev: 6.1.1 # pick a git hash / tag to point to
hooks:
- id: pydocstyle
exclude: ^(gym/version.py)|(gym/(wrappers|envs|spaces|utils|vector)/)|(tests/)
args:
- --source
- --explain
- --convention=google
additional_dependencies: ["toml"]
- repo: https://github.com/asottile/pyupgrade
rev: v2.32.0
hooks:

View File

@@ -43,3 +43,13 @@ The Git hooks can also be run manually with `pre-commit run --all-files`, and if
Additionally, for pull requests, the project runs a number of tests for the whole project using [pytest](https://docs.pytest.org/en/latest/getting-started.html#install-pytest).
These tests can be run locally with `pytest` in the root folder.
## Docstrings
Pydocstyle has been added to the pre-commit process such that all new functions follow the (google docstring style)[https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html].
All new functions require either a short docstring, a single line explaining the purpose of a function
or a multiline docstring that documents each argument and the return type (if there is one) of the function.
In addition, new file and class require top docstrings that should outline the purpose of the file/class.
For classes, code block examples can be provided in the top docstring and not the constructor arguments.
To check your docstrings are correct, run `pre-commit run --al-files` or `pydocstyle --source --explain --convention=google`.
If all docstrings that fail, the source and reason for the failure is provided.

View File

@@ -1,3 +1,4 @@
"""Root __init__ of the gym module setting the __all__ of gym modules."""
# isort: skip_file
from gym import error

View File

@@ -1,7 +1,8 @@
"""Core API for Environment, Wrapper, ActionWrapper, RewardWrapper and ObservationWrapper."""
from __future__ import annotations
from abc import abstractmethod
from typing import Generic, Optional, SupportsFloat, Tuple, TypeVar, Union
from typing import Generic, Optional, SupportsFloat, TypeVar, Union
from gym import spaces
from gym.logger import deprecation
@@ -13,27 +14,30 @@ ActType = TypeVar("ActType")
class Env(Generic[ObsType, ActType]):
"""The main OpenAI Gym class. It encapsulates an environment with
arbitrary behind-the-scenes dynamics. An environment can be
partially or fully observed.
r"""The main OpenAI Gym class.
It encapsulates an environment with arbitrary behind-the-scenes dynamics.
An environment can be partially or fully observed.
The main API methods that users of this class need to know are:
step
reset
render
close
seed
- :meth:`step` - Takes a step in the environment using an action returning the next observation, reward,
if the environment terminated and more information.
- :meth:`reset` - Resets the environment to an initial state, returning the initial observation.
- :meth:`render` - Renders the environment observation with modes depending on the output
- :meth:`close` - Closes the environment, important for rendering where pygame is imported
- :meth:`seed` - Seeds the environment's random number generator, :deprecated: in favor of `Env.reset(seed=seed)`.
And set the following attributes:
action_space: The Space object corresponding to valid actions
observation_space: The Space object corresponding to valid observations
reward_range: A tuple corresponding to the min and max possible rewards
- :attr:`action_space` - The Space object corresponding to valid actions
- :attr:`observation_space` - The Space object corresponding to valid observations
- :attr:`reward_range` - A tuple corresponding to the minimum and maximum possible rewards
- :attr:`spec` - An environment spec that contains the information used to initialise the environment from `gym.make`
- :attr:`metadata` - The metadata of the environment, i.e. render modes
- :attr:`np_random` - The random number generator for the environment
Note: a default reward range set to [-inf,+inf] already exists. Set it if you want a narrower range.
The methods are accessed publicly as "step", "reset", etc...
Note: a default reward range set to :math:`(-\infty,+\infty)` already exists. Set it if you want a narrower range.
"""
# Set this in SOME subclasses
@@ -46,11 +50,11 @@ class Env(Generic[ObsType, ActType]):
observation_space: spaces.Space[ObsType]
# Created
_np_random: RandomNumberGenerator | None = None
_np_random: Optional[RandomNumberGenerator] = None
@property
def np_random(self) -> RandomNumberGenerator:
"""Initializes the np_random field if not done already."""
"""Returns the environment's internal :attr:`_np_random` that if not set will initialise with a random seed."""
if self._np_random is None:
self._np_random, seed = seeding.np_random()
return self._np_random
@@ -60,28 +64,27 @@ class Env(Generic[ObsType, ActType]):
self._np_random = value
@abstractmethod
def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
"""Run one timestep of the environment's dynamics. When end of
episode is reached, you are responsible for calling :meth:`reset`
to reset this environment's state.
def step(self, action: ActType) -> tuple[ObsType, float, bool, dict]:
"""Run one timestep of the environment's dynamics.
Accepts an action and returns a tuple (observation, reward, done, info).
When end of episode is reached, you are responsible for calling :meth:`reset` to reset this environment's state.
Accepts an action and returns a tuple `(observation, reward, done, info)`.
Args:
action (object): an action provided by the agent
This method returns a tuple ``(observation, reward, done, info)``
Returns:
observation (object): agent's observation of the current environment. This will be an element of the environment's :attr:`observation_space`. This may, for instance, be a numpy array containing the positions and velocities of certain objects.
reward (float) : amount of reward returned after previous action
done (bool): whether the episode has ended, in which case further :meth:`step` calls will return undefined results. A done signal may be emitted for different reasons: Maybe the task underlying the environment was solved successfully, a certain timelimit was exceeded, or the physics simulation has entered an invalid state. ``info`` may contain additional information regarding the reason for a ``done`` signal.
info (dict): contains auxiliary diagnostic information (helpful for debugging, learning, and logging). This might, for instance, contain:
- metrics that describe the agent's performance or
- state variables that are hidden from observations or
- information that distinguishes truncation and termination or
- individual reward terms that are combined to produce the total reward
observation (object): this will be an element of the environment's :attr:`observation_space`.
This may, for instance, be a numpy array containing the positions and velocities of certain objects.
reward (float): The amount of reward returned as a result of taking the action.
done (bool): A boolean value for if the episode has ended, in which case further :meth:`step` calls will return undefined results.
A done signal may be emitted for different reasons: Maybe the task underlying the environment was solved successfully,
a certain timelimit was exceeded, or the physics simulation has entered an invalid state.
info (dictionary): A dictionary that may contain additional information regarding the reason for a ``done`` signal.
`info` contains auxiliary diagnostic information (helpful for debugging, learning, and logging).
This might, for instance, contain: metrics that describe the agent's performance state, variables that are
hidden from observations, information that distinguishes truncation and termination or individual reward terms
that are combined to produce the total reward
"""
raise NotImplementedError
@@ -93,28 +96,34 @@ class Env(Generic[ObsType, ActType]):
return_info: bool = False,
options: Optional[dict] = None,
) -> Union[ObsType, tuple[ObsType, dict]]:
"""Resets the environment to an initial state and returns an initial
observation.
"""Resets the environment to an initial state and returns the initial observation.
This method should also reset the environment's random number
generator(s) if ``seed`` is an integer or if the environment has not
yet initialized a random number generator. If the environment already
has a random number generator and :meth:`reset` is called with ``seed=None``,
the RNG should not be reset.
Moreover, :meth:`reset` should (in the typical use case) be called with an
This method can reset the environment's random number generator(s) if ``seed`` is an integer or
if the environment has not yet initialized a random number generator.
If the environment already has a random number generator and :meth:`reset` is called with ``seed=None``,
the RNG should not be reset. Moreover, :meth:`reset` should (in the typical use case) be called with an
integer seed right after initialization and then never again.
Args:
seed (int or None):
The seed that is used to initialize the environment's PRNG. If the environment does not already have a PRNG and ``seed=None`` (the default option) is passed, a seed will be chosen from some source of entropy (e.g. timestamp or /dev/urandom).
However, if the environment already has a PRNG and ``seed=None`` is passed, the PRNG will *not* be reset. If you pass an integer, the PRNG will be reset even if it already exists. Usually, you want to pass an integer *right after the environment has been initialized and then never again*. Please refer to the minimal example above to see this paradigm in action.
return_info (bool): If true, return additional information along with initial observation. This info should be analogous to the info returned in :meth:`step`
options (dict or None): Additional information to specify how the environment is reset (optional, depending on the specific environment)
seed (optional int): The seed that is used to initialize the environment's PRNG.
If the environment does not already have a PRNG and ``seed=None`` (the default option) is passed,
a seed will be chosen from some source of entropy (e.g. timestamp or /dev/urandom).
However, if the environment already has a PRNG and ``seed=None`` is passed, the PRNG will *not* be reset.
If you pass an integer, the PRNG will be reset even if it already exists.
Usually, you want to pass an integer *right after the environment has been initialized and then never again*.
Please refer to the minimal example above to see this paradigm in action.
return_info (bool): If true, return additional information along with initial observation.
This info should be analogous to the info returned in :meth:`step`
options (optional dict): Additional information to specify how the environment is reset (optional,
depending on the specific environment)
Returns:
observation (object): Observation of the initial state. This will be an element of :attr:`observation_space` (usually a numpy array) and is analogous to the observation returned by :meth:`step`.
info (optional dictionary): This will *only* be returned if ``return_info=True`` is passed. It contains auxiliary information complementing ``observation``. This dictionary should be analogous to the ``info`` returned by :meth:`step`.
observation (object): Observation of the initial state. This will be an element of :attr:`observation_space`
(typically a numpy array) and is analogous to the observation returned by :meth:`step`.
info (optional dictionary): This will *only* be returned if ``return_info=True`` is passed.
It contains auxiliary information complementing ``observation``. This dictionary should be analogous to
the ``info`` returned by :meth:`step`.
"""
# Initialize the RNG if the seed is manually passed
if seed is not None:
@@ -124,13 +133,13 @@ class Env(Generic[ObsType, ActType]):
def render(self, mode="human"):
"""Renders the environment.
The set of supported modes varies per environment. (And some
A set of supported modes varies per environment. (And some
third-party environments may not support rendering at all.)
By convention, if mode is:
- human: render to the current display or terminal and
return nothing. Usually for human consumption.
- rgb_array: Return an numpy.ndarray with shape (x, y, 3),
- rgb_array: Return a numpy.ndarray with shape (x, y, 3),
representing RGB values for an x-by-y pixel image, suitable
for turning into a video.
- ansi: Return a string (str) or StringIO.StringIO containing a
@@ -142,57 +151,61 @@ class Env(Generic[ObsType, ActType]):
the list of supported modes. It's recommended to call super()
in implementations to use the functionality of this method.
Example:
>>> class MyEnv(Env):
... metadata = {'render_modes': ['human', 'rgb_array']}
...
... def render(self, mode='human'):
... if mode == 'rgb_array':
... return np.array(...) # return RGB frame suitable for video
... elif mode == 'human':
... ... # pop up a window and render
... else:
... super(MyEnv, self).render(mode=mode) # just raise an exception
Args:
mode (str): the mode to render with
Example::
class MyEnv(Env):
metadata = {'render_modes': ['human', 'rgb_array']}
def render(self, mode='human'):
if mode == 'rgb_array':
return np.array(...) # return RGB frame suitable for video
elif mode == 'human':
... # pop up a window and render
else:
super(MyEnv, self).render(mode=mode) # just raise an exception
mode: the mode to render with, valid modes are `env.metadata["render_modes"]`
"""
raise NotImplementedError
def close(self):
"""Override close in your subclass to perform any necessary cleanup.
Environments will automatically close() themselves when
Environments will automatically :meth:`close()` themselves when
garbage collected or when the program exits.
"""
pass
def seed(self, seed=None):
"""Sets the seed for this env's random number generator(s).
""":deprecated: function that sets the seed for the environment's random number generator(s).
Use `env.reset(seed=seed)` as the new API for setting the seed of the environment.
Note:
Some environments use multiple pseudorandom number generators.
We want to capture all such seeds used in order to ensure that
there aren't accidental correlations between multiple generators.
Args:
seed(Optional int): The seed value for the random number geneartor
Returns:
list<bigint>: Returns the list of seeds used in this env's random
seeds (List[int]): Returns the list of seeds used in this environment's random
number generators. The first value in the list should be the
"main" seed, or the value which a reproducer should pass to
'seed'. Often, the main seed equals the provided 'seed', but
this won't be true if seed=None, for example.
this won't be true `if seed=None`, for example.
"""
deprecation(
"Function `env.seed(seed)` is marked as deprecated and will be removed in the future. "
"Please use `env.reset(seed=seed) instead."
"Please use `env.reset(seed=seed)` instead."
)
self._np_random, seed = seeding.np_random(seed)
return [seed]
@property
def unwrapped(self) -> Env:
"""Completely unwrap this env.
"""Returns the base non-wrapped environment.
Returns:
gym.Env: The base non-wrapped gym.Env instance
@@ -200,6 +213,7 @@ class Env(Generic[ObsType, ActType]):
return self
def __str__(self):
"""Returns a string of the environment with the spec id if specified."""
if self.spec is None:
return f"<{type(self).__name__} instance>"
else:
@@ -217,71 +231,81 @@ class Env(Generic[ObsType, ActType]):
class Wrapper(Env[ObsType, ActType]):
"""Wraps the environment to allow a modular transformation.
"""Wraps an environment to allow a modular transformation of the :meth:`step` and :meth:`reset` methods.
This class is the base class for all wrappers. The subclass could override
some methods to change the behavior of the original environment without touching the
original code.
.. note::
Note:
Don't forget to call ``super().__init__(env)`` if the subclass overrides :meth:`__init__`.
"""
def __init__(self, env: Env):
"""Wraps an environment to allow a modular transformation of the :meth:`step` and :meth:`reset` methods.
Args:
env: The environment to wrap
"""
self.env = env
self._action_space: spaces.Space | None = None
self._observation_space: spaces.Space | None = None
self._reward_range: tuple[SupportsFloat, SupportsFloat] | None = None
self._metadata: dict | None = None
self._action_space: Optional[spaces.Space] = None
self._observation_space: Optional[spaces.Space] = None
self._reward_range: Optional[tuple[SupportsFloat, SupportsFloat]] = None
self._metadata: Optional[dict] = None
def __getattr__(self, name):
"""Returns an attribute with ``name``, unless ``name`` starts with an underscore."""
if name.startswith("_"):
raise AttributeError(f"accessing private attribute '{name}' is prohibited")
return getattr(self.env, name)
@property
def spec(self):
"""Returns the environment specification."""
return self.env.spec
@classmethod
def class_name(cls):
"""Returns the class name of the wrapper."""
return cls.__name__
@property
def action_space(self) -> spaces.Space[ActType]:
"""Returns the action space of the environment."""
if self._action_space is None:
return self.env.action_space
return self._action_space
@action_space.setter
def action_space(self, space):
def action_space(self, space: spaces.Space):
self._action_space = space
@property
def observation_space(self) -> spaces.Space:
"""Returns the observation space of the environment."""
if self._observation_space is None:
return self.env.observation_space
return self._observation_space
@observation_space.setter
def observation_space(self, space):
def observation_space(self, space: spaces.Space):
self._observation_space = space
@property
def reward_range(self) -> tuple[SupportsFloat, SupportsFloat]:
"""Return the reward range of the environment."""
if self._reward_range is None:
return self.env.reward_range
return self._reward_range
@reward_range.setter
def reward_range(self, value):
def reward_range(self, value: tuple[SupportsFloat, SupportsFloat]):
self._reward_range = value
@property
def metadata(self) -> dict:
"""Returns the environment metadata."""
if self._metadata is None:
return self.env.metadata
return self._metadata
@@ -290,34 +314,45 @@ class Wrapper(Env[ObsType, ActType]):
def metadata(self, value):
self._metadata = value
def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
def step(self, action: ActType) -> tuple[ObsType, float, bool, dict]:
"""Steps through the environment with action."""
return self.env.step(action)
def reset(self, **kwargs) -> Union[ObsType, tuple[ObsType, dict]]:
"""Resets the environment with kwargs."""
return self.env.reset(**kwargs)
def render(self, **kwargs):
"""Renders the environment with kwargs."""
return self.env.render(**kwargs)
def close(self):
"""Closes the environment."""
return self.env.close()
def seed(self, seed=None):
"""Seeds the environment."""
return self.env.seed(seed)
def __str__(self):
"""Returns the wrapper name and the unwrapped environment string."""
return f"<{type(self).__name__}{self.env}>"
def __repr__(self):
"""Returns the string representation of the wrapper."""
return str(self)
@property
def unwrapped(self) -> Env:
"""Returns the base environment of the wrapper."""
return self.env.unwrapped
class ObservationWrapper(Wrapper):
"""A wrapper that can modify observations using :meth:`observation` for :meth:`reset` and :meth:`step`."""
def reset(self, **kwargs):
"""Resets the environment, returning a modified observation using :meth:`self.observation`."""
if kwargs.get("return_info", False):
obs, info = self.env.reset(**kwargs)
return self.observation(obs), info
@@ -325,38 +360,43 @@ class ObservationWrapper(Wrapper):
return self.observation(self.env.reset(**kwargs))
def step(self, action):
"""Returns a modified observation using :meth:`self.observation` after calling :meth:`env.step`."""
observation, reward, done, info = self.env.step(action)
return self.observation(observation), reward, done, info
@abstractmethod
def observation(self, observation):
"""Returns a modified observation."""
raise NotImplementedError
class RewardWrapper(Wrapper):
def reset(self, **kwargs):
return self.env.reset(**kwargs)
"""A wrapper that can modify the returning reward from a step."""
def step(self, action):
"""Modifies the reward using :meth:`self.reward` after the environment :meth:`env.step`."""
observation, reward, done, info = self.env.step(action)
return observation, self.reward(reward), done, info
@abstractmethod
def reward(self, reward):
"""Returns a modified ``reward``."""
raise NotImplementedError
class ActionWrapper(Wrapper):
def reset(self, **kwargs):
return self.env.reset(**kwargs)
"""A wrapper that can modify the action before :meth:`env.step`."""
def step(self, action):
"""Runs the environment :meth:`env.step` using the modified ``action`` from :meth:`self.action`."""
return self.env.step(self.action(action))
@abstractmethod
def action(self, action):
"""Returns a modified action before :meth:`env.step` is called."""
raise NotImplementedError
@abstractmethod
def reverse_action(self, action):
"""Returns a reversed ``action``."""
raise NotImplementedError

View File

@@ -1,122 +1,76 @@
"""Set of Error classes for gym."""
import warnings
class Error(Exception):
pass
"""Error superclass."""
# Local errors
class Unregistered(Error):
"""Raised when the user requests an item from the registry that does
not actually exist.
"""
pass
"""Raised when the user requests an item from the registry that does not actually exist."""
class UnregisteredEnv(Unregistered):
"""Raised when the user requests an env from the registry that does
not actually exist.
"""
pass
"""Raised when the user requests an env from the registry that does not actually exist."""
class NamespaceNotFound(UnregisteredEnv):
"""Raised when the user requests an env from the registry where the
namespace doesn't exist.
"""
pass
"""Raised when the user requests an env from the registry where the namespace doesn't exist."""
class NameNotFound(UnregisteredEnv):
"""Raised when the user requests an env from the registry where the
name doesn't exist.
"""
pass
"""Raised when the user requests an env from the registry where the name doesn't exist."""
class VersionNotFound(UnregisteredEnv):
"""Raised when the user requests an env from the registry where the
version doesn't exist.
"""
pass
"""Raised when the user requests an env from the registry where the version doesn't exist."""
class UnregisteredBenchmark(Unregistered):
"""Raised when the user requests an env from the registry that does
not actually exist.
"""
pass
"""Raised when the user requests an env from the registry that does not actually exist."""
class DeprecatedEnv(Error):
"""Raised when the user requests an env from the registry with an
older version number than the latest env with the same name.
"""
pass
"""Raised when the user requests an env from the registry with an older version number than the latest env with the same name."""
class RegistrationError(Error):
"""Raised when the user attempts to register an invalid env.
For example, an unversioned env when a versioned env exists.
"""
pass
"""Raised when the user attempts to register an invalid env. For example, an unversioned env when a versioned env exists."""
class UnseedableEnv(Error):
"""Raised when the user tries to seed an env that does not support
seeding.
"""
pass
"""Raised when the user tries to seed an env that does not support seeding."""
class DependencyNotInstalled(Error):
pass
"""Raised when the user has not installed a dependency."""
class UnsupportedMode(Exception):
"""Raised when the user requests a rendering mode not supported by the
environment.
"""
pass
class UnsupportedMode(Error):
"""Raised when the user requests a rendering mode not supported by the environment."""
class ResetNeeded(Exception):
"""When the monitor is active, raised when the user tries to step an
environment that's already done.
"""
pass
class ResetNeeded(Error):
"""When the order enforcing is violated, i.e. step or render is called before reset."""
class ResetNotAllowed(Exception):
"""When the monitor is active, raised when the user tries to step an
environment that's not yet done.
"""
pass
class ResetNotAllowed(Error):
"""When the monitor is active, raised when the user tries to step an environment that's not yet done."""
class InvalidAction(Exception):
"""Raised when the user performs an action not contained within the
action space
"""
pass
class InvalidAction(Error):
"""Raised when the user performs an action not contained within the action space."""
# API errors
class APIError(Error):
"""Deprecated, to be removed at gym 1.0."""
def __init__(
self,
message=None,
@@ -125,8 +79,11 @@ class APIError(Error):
json_body=None,
headers=None,
):
"""Initialise API error."""
super().__init__(message)
warnings.warn("APIError is deprecated and will be removed at gym 1.0")
if http_body and hasattr(http_body, "decode"):
try:
http_body = http_body.decode("utf-8")
@@ -141,6 +98,7 @@ class APIError(Error):
self.request_id = self.headers.get("request-id", None)
def __unicode__(self):
"""Returns a string, if request_id is not None then make message other use the _message."""
if self.request_id is not None:
msg = self._message or "<empty message>"
return f"Request {self.request_id}: {msg}"
@@ -148,14 +106,17 @@ class APIError(Error):
return self._message
def __str__(self):
"""Returns the __unicode__."""
return self.__unicode__()
class APIConnectionError(APIError):
pass
"""Deprecated, to be removed at gym 1.0."""
class InvalidRequestError(APIError):
"""Deprecated, to be removed at gym 1.0."""
def __init__(
self,
message,
@@ -165,83 +126,69 @@ class InvalidRequestError(APIError):
json_body=None,
headers=None,
):
"""Initialises the invalid request error."""
super().__init__(message, http_body, http_status, json_body, headers)
self.param = param
class AuthenticationError(APIError):
pass
"""Deprecated, to be removed at gym 1.0."""
class RateLimitError(APIError):
pass
"""Deprecated, to be removed at gym 1.0."""
# Video errors
class VideoRecorderError(Error):
pass
"""Unused error."""
class InvalidFrame(Error):
pass
"""Error message when an invalid frame is captured."""
# Wrapper errors
class DoubleWrapperError(Error):
pass
"""Error message for when using double wrappers."""
class WrapAfterConfigureError(Error):
pass
"""Error message for using wrap after configure."""
class RetriesExceededError(Error):
pass
"""Error message for retries exceeding set number."""
# Vectorized environments errors
class AlreadyPendingCallError(Exception):
"""
Raised when `reset`, or `step` is called asynchronously (e.g. with
`reset_async`, or `step_async` respectively), and `reset_async`, or
`step_async` (respectively) is called again (without a complete call to
`reset_wait`, or `step_wait` respectively).
"""
"""Raised when `reset`, or `step` is called asynchronously (e.g. with `reset_async`, or `step_async` respectively), and `reset_async`, or `step_async` (respectively) is called again (without a complete call to `reset_wait`, or `step_wait` respectively)."""
def __init__(self, message, name):
def __init__(self, message: str, name: str):
"""Initialises the exception with name attributes."""
super().__init__(message)
self.name = name
class NoAsyncCallError(Exception):
"""
Raised when an asynchronous `reset`, or `step` is not running, but
`reset_wait`, or `step_wait` (respectively) is called.
"""
"""Raised when an asynchronous `reset`, or `step` is not running, but `reset_wait`, or `step_wait` (respectively) is called."""
def __init__(self, message, name):
def __init__(self, message: str, name: str):
"""Initialises the exception with name attributes."""
super().__init__(message)
self.name = name
class ClosedEnvironmentError(Exception):
"""
Trying to call `reset`, or `step`, while the environment is closed.
"""
pass
"""Trying to call `reset`, or `step`, while the environment is closed."""
class CustomSpaceError(Exception):
"""
The space is a custom gym.Space instance, and is not supported by
`AsyncVectorEnv` with `shared_memory=True`.
"""
pass
"""The space is a custom gym.Space instance, and is not supported by `AsyncVectorEnv` with `shared_memory=True`."""

View File

@@ -1,3 +1,4 @@
"""Set of functions for logging messages."""
import sys
import warnings
from typing import Optional, Type
@@ -16,20 +17,20 @@ min_level = 30
warnings.simplefilter("once", DeprecationWarning)
def set_level(level: int) -> None:
"""
Set logging threshold on current logger.
"""
def set_level(level: int):
"""Set logging threshold on current logger."""
global min_level
min_level = level
def debug(msg: str, *args: object):
"""Logs a debug message to the user."""
if min_level <= DEBUG:
print(f"DEBUG: {msg % args}", file=sys.stderr)
def info(msg: str, *args: object):
"""Logs an info message to the user."""
if min_level <= INFO:
print(f"INFO: {msg % args}", file=sys.stderr)
@@ -40,6 +41,14 @@ def warn(
category: Optional[Type[Warning]] = None,
stacklevel: int = 1,
):
"""Raises a warning to the user if the min_level <= WARN.
Args:
msg: The message to warn the user
*args: Additional information to warn the user
category: The category of warning
stacklevel: The stack level to raise to
"""
if min_level <= WARN:
warnings.warn(
colorize(f"WARN: {msg % args}", "yellow"),
@@ -49,10 +58,12 @@ def warn(
def deprecation(msg: str, *args: object):
"""Logs a deprecation warning to users."""
warn(msg, *args, category=DeprecationWarning, stacklevel=2)
def error(msg: str, *args: object):
"""Logs an error message if min_level <= ERROR in red on the sys.stderr."""
if min_level <= ERROR:
print(colorize(f"ERROR: {msg % args}", "red"), file=sys.stderr)

View File

@@ -1,5 +1,6 @@
"""A set of common utilities used within the environments. These are
not intended as API functions, and will not remain stable over time.
"""A set of common utilities used within the environments.
These are not intended as API functions, and will not remain stable over time.
"""
# These submodules should not have any import-time dependencies.

View File

@@ -33,7 +33,7 @@ class AutoResetWrapper(gym.Wrapper):
new observation from after calling self.env.reset() is returned
by self.step() alongside the terminal reward and done state from the
previous episode . If you need the terminal state from the previous
episode, you need to retrieve it via the the "terminal_observation" key
episode, you need to retrieve it via the "terminal_observation" key
in the info dict. Make sure you know what you're doing if you
use this wrapper!
"""

View File

@@ -1,3 +1,4 @@
"""Setups the project."""
import itertools
import os.path
import sys

View File

@@ -4,8 +4,7 @@ import numpy as np
import pytest
from numpy.testing import assert_array_equal
from gym import Space
from gym.spaces import Box, Dict, MultiDiscrete, Tuple
from gym.spaces import Box, Dict, MultiDiscrete, Space, Tuple
from gym.vector.utils.spaces import batch_space, iterate
from tests.vector.utils import CustomSpace, assert_rng_equal, custom_spaces, spaces

View File

@@ -1,4 +1,6 @@
from typing import Optional
"""Tests the gym.wrapper.AutoResetWrapper operates as expected."""
from typing import Generator, Optional
from unittest.mock import MagicMock
import numpy as np
@@ -10,33 +12,31 @@ from tests.envs.spec_list import spec_list
class DummyResetEnv(gym.Env):
"""
A dummy environment which returns ascending numbers starting
at 0 when self.step() is called. After the third call to self.step()
done is true. Info dicts are also returned containing the same number
returned as an observation, accessible via the key "count".
This environment is provided for the purpose of testing the
autoreset wrapper.
"""A dummy environment which returns ascending numbers starting at `0` when :meth:`self.step()` is called.
After the second call to :meth:`self.step()` done is true.
Info dicts are also returned containing the same number returned as an observation, accessible via the key "count".
This environment is provided for the purpose of testing the autoreset wrapper.
"""
metadata = {}
def __init__(self):
"""Initialise the DummyResetEnv."""
self.action_space = gym.spaces.Box(
low=np.array([-1.0]), high=np.array([1.0]), dtype=np.float64
)
self.observation_space = gym.spaces.Box(
low=np.array([-1.0]), high=np.array([1.0])
low=np.array([0]), high=np.array([2]), dtype=np.int64
)
self.observation_space = gym.spaces.Discrete(2)
self.count = 0
def step(self, action):
def step(self, action: int):
"""Steps the DummyEnv with the incremented step, reward and done `if self.count > 1` and updated info."""
self.count += 1
return (
np.array([self.count]),
1 if self.count > 2 else 0,
self.count > 2,
{"count": self.count},
np.array([self.count]), # Obs
self.count > 2, # Reward
self.count > 2, # Done
{"count": self.count}, # Info
)
def reset(
@@ -46,6 +46,7 @@ class DummyResetEnv(gym.Env):
return_info: Optional[bool] = False,
options: Optional[dict] = None
):
"""Resets the DummyEnv to return the count array and info with count."""
self.count = 0
if not return_info:
return np.array([self.count])
@@ -53,79 +54,78 @@ class DummyResetEnv(gym.Env):
return np.array([self.count]), {"count": self.count}
def test_autoreset_reset_info():
env = gym.make("CartPole-v1")
env = AutoResetWrapper(env)
ob_space = env.observation_space
obs = env.reset()
assert ob_space.contains(obs)
obs = env.reset(return_info=False)
assert ob_space.contains(obs)
obs, info = env.reset(return_info=True)
assert ob_space.contains(obs)
assert isinstance(info, dict)
env.close()
def unwrap_env(env) -> Generator[gym.Wrapper, None, None]:
"""Unwraps an environment yielding all wrappers around environment."""
while isinstance(env, gym.Wrapper):
yield type(env)
env = env.env
@pytest.mark.parametrize("spec", spec_list, ids=[spec.id for spec in spec_list])
def test_make_autoreset_true(spec):
"""
Note: This test assumes that the outermost wrapper is AutoResetWrapper
so if that is being changed in the future, this test will break and need
to be updated.
"""Tests gym.make with `autoreset=True`, and check that the reset actually happens.
Note: This test assumes that the outermost wrapper is AutoResetWrapper so if that
is being changed in the future, this test will break and need to be updated.
Note: This test assumes that all first-party environments will terminate in a finite
amount of time with random actions, which is true as of the time of adding this test.
"""
with pytest.warns(None):
env = spec.make(autoreset=True)
env = gym.make(spec.id, autoreset=True)
assert AutoResetWrapper in unwrap_env(env)
env.reset(seed=0)
env.action_space.seed(0)
env.unwrapped.reset = MagicMock(side_effect=env.unwrapped.reset)
done = False
while not done:
obs, reward, done, info = env.step(env.action_space.sample())
assert isinstance(env, AutoResetWrapper)
assert env.unwrapped.reset.called
env.close()
@pytest.mark.parametrize("spec", spec_list, ids=[spec.id for spec in spec_list])
def test_make_autoreset_false(spec):
def test_gym_make_autoreset(spec):
"""Tests that `gym.make` autoreset wrapper is applied only when `gym.make(..., autoreset=True)`."""
with pytest.warns(None):
env = spec.make(autoreset=False)
assert not isinstance(env, AutoResetWrapper)
env = gym.make(spec.id)
assert AutoResetWrapper not in unwrap_env(env)
env.close()
with pytest.warns(None):
env = gym.make(spec.id, autoreset=False)
assert AutoResetWrapper not in unwrap_env(env)
env.close()
with pytest.warns(None):
env = gym.make(spec.id, autoreset=True)
assert AutoResetWrapper in unwrap_env(env)
env.close()
@pytest.mark.parametrize("spec", spec_list, ids=[spec.id for spec in spec_list])
def test_make_autoreset_default_false(spec):
with pytest.warns(None):
env = spec.make()
assert not isinstance(env, AutoResetWrapper)
env.close()
def test_autoreset_autoreset():
def test_autoreset_wrapper_autoreset():
"""Tests the autoreset wrapper actually automatically resets correctly."""
env = DummyResetEnv()
env = AutoResetWrapper(env)
obs, info = env.reset(return_info=True)
assert obs == np.array([0])
assert info == {"count": 0}
action = 1
action = 0
obs, reward, done, info = env.step(action)
assert obs == np.array([1])
assert reward == 0
assert done is False
assert info == {"count": 1}
obs, reward, done, info = env.step(action)
assert obs == np.array([2])
assert done is False
assert reward == 0
assert info == {"count": 2}
obs, reward, done, info = env.step(action)
assert obs == np.array([0])
assert done is True
@@ -135,14 +135,11 @@ def test_autoreset_autoreset():
"terminal_observation": np.array([3]),
"terminal_info": {"count": 3},
}
obs, reward, done, info = env.step(action)
assert obs == np.array([1])
assert reward == 0
assert done is False
assert info == {"count": 1}
obs, reward, done, info = env.step(action)
assert obs == np.array([2])
assert reward == 0
assert done is False
assert info == {"count": 2}
env.close()